/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.optimization.StochasticCalculateMethods;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Base class for differentiable functions that can be evaluated on a mini-batch of data points
 * ("stochastic" evaluation), caching the result of the most recent batch.
 *
 * <p>Subclasses implement {@link #dataDimension()} (the number of data points batches are drawn
 * from) and {@link #calculateStochastic(double[], double[], int[])}, which this class relies on to
 * set {@code value} and {@code derivative} for the given batch indices. When {@link #method} is not
 * {@link StochasticCalculateMethods#ExternalFiniteDifference}, {@code HdotV} is presumably filled by
 * the subclass as well (TODO confirm against implementations).
 */
public abstract class AbstractStochasticCachingDiffFunction
extends AbstractCachingDiffFunction {
    private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractStochasticCachingDiffFunction.class);
    // Cleared by incrementBatch(); checked by stochasticEnsure() before reusing cached values.
    public boolean hasNewVals = true;
    // When true, the next evaluation reuses the previous batch instead of drawing a new one.
    public boolean recalculatePrevBatch = false;
    // When true, the next evaluation returns immediately, leaving the cached values untouched.
    public boolean returnPreviousValues = false;
    protected int lastBatchSize = 0;
    protected int[] lastBatch = null;
    protected int[] thisBatch = null;
    protected double[] lastXBatch = null;
    protected double[] lastVBatch = null;
    protected int lastElement = 0;
    protected double[] HdotV = null;
    protected double[] gradPerturbed = null;
    protected double[] xPerturbed = null;
    // Position of the next batch within the current ordering of the data.
    protected int curElement = 0;
    protected List<Integer> allIndices = null;
    // Fixed seed so that batch sampling is reproducible between runs.
    protected Random randGenerator = new Random(1L);
    protected boolean scaleUp = false;
    private int[] shuffledArray = null;
    public StochasticCalculateMethods method = StochasticCalculateMethods.ExternalFiniteDifference;
    public SamplingMethod sampleMethod = SamplingMethod.RandomWithoutReplacement;
    protected double finiteDifferenceStepSize = 1.0E-4;

    /**
     * Advances the random generator by drawing and discarding {@code numTimes} values, e.g. to
     * reproduce the sampling state of an earlier run.
     *
     * @param numTimes number of values to draw from the generator
     */
    public void incrementRandom(int numTimes) {
        log.info("incrementing random " + numTimes + " times.");
        for (int i = 0; i < numTimes; ++i) {
            this.randGenerator.nextInt(this.dataDimension());
        }
    }

    /**
     * Sets whether batch values and gradients are multiplied by {@code dataDimension() / batchSize}.
     * When enabled, batch results are on the scale of a full-data evaluation.
     *
     * @param toScaleUp whether to scale batch results up
     */
    public void scaleUp(boolean toScaleUp) {
        this.scaleUp = toScaleUp;
    }

    /**
     * Computes the function on the data points whose indices are given in the third argument.
     * Implementations must set {@code value} and {@code derivative}.
     */
    public abstract void calculateStochastic(double[] var1, double[] var2, int[] var3);

    /** Returns the number of data points that batches are sampled from. */
    public abstract int dataDimension();

    /** Invalidates the cached batch position and gain vector in addition to the superclass cache. */
    @Override
    protected void clearCache() {
        super.clearCache();
        // A NaN entry can never compare equal, so the next Arrays.equals check misses.
        if (this.lastXBatch != null) {
            this.lastXBatch[0] = Double.NaN;
        }
        if (this.lastVBatch != null) {
            this.lastVBatch[0] = Double.NaN;
        }
    }

    /** Returns an all-zero starting point of length {@code domainDimension()}. */
    @Override
    public double[] initial() {
        return new double[this.domainDimension()];
    }

    /**
     * Moves the batch position back by {@code batchSize}, clamping at zero.
     *
     * @param batchSize number of positions to step back
     */
    public void decrementBatch(int batchSize) {
        this.curElement -= batchSize;
        if (this.curElement < 0) {
            this.curElement = 0;
        }
    }

    /**
     * Moves the batch position forward by {@code batchSize} and resets the reuse flags.
     *
     * @param batchSize number of positions to step forward
     */
    public void incrementBatch(int batchSize) {
        this.curElement += batchSize;
        this.hasNewVals = false;
        this.recalculatePrevBatch = false;
        this.returnPreviousValues = false;
    }

    /**
     * Fills {@code thisBatch} with {@code batchSize} data indices chosen according to
     * {@link #sampleMethod}, advancing {@code curElement} where the method is sequential.
     *
     * @param batchSize number of indices to draw
     * @throws IllegalStateException if no usable sampling method is selected
     */
    protected void getBatch(int batchSize) {
        if (this.thisBatch == null || this.thisBatch.length != batchSize) {
            this.thisBatch = new int[batchSize];
        }
        final int numData = this.dataDimension();
        if (this.sampleMethod == SamplingMethod.Shuffled) {
            // Walk one fixed random permutation of the data, wrapping around at the end.
            // Rebuild it if the data size changed so indices stay in range.
            if (this.shuffledArray == null || this.shuffledArray.length != numData) {
                this.shuffledArray = ArrayMath.range(0, numData);
                shuffleInPlace(this.shuffledArray, this.randGenerator);
            }
            for (int i = 0; i < batchSize; ++i) {
                this.thisBatch[i] = this.shuffledArray[(this.curElement + i) % numData];
            }
            this.curElement = (this.curElement + batchSize) % numData;
        } else if (this.sampleMethod == SamplingMethod.RandomWithReplacement) {
            for (int i = 0; i < batchSize; ++i) {
                this.thisBatch[i] = this.randGenerator.nextInt(numData);
            }
        } else if (this.sampleMethod == SamplingMethod.Ordered) {
            for (int i = 0; i < batchSize; ++i) {
                this.thisBatch[i] = (this.curElement + i) % numData;
            }
            this.curElement = (this.curElement + batchSize) % numData;
        } else if (this.sampleMethod == SamplingMethod.RandomWithoutReplacement) {
            if (this.allIndices == null || this.allIndices.size() != numData) {
                this.allIndices = new ArrayList<>(numData);
                for (int i = 0; i < numData; ++i) {
                    this.allIndices.add(i);
                }
                Collections.shuffle(this.allIndices, this.randGenerator);
            }
            for (int i = 0; i < batchSize; ++i) {
                this.thisBatch[i] = this.allIndices.get((this.curElement + i) % this.allIndices.size());
            }
            // Reshuffle once a batch has run past the end of the current permutation.
            if (this.curElement + batchSize > numData) {
                Collections.shuffle(this.allIndices, this.randGenerator);
            }
            this.curElement = (this.curElement + batchSize) % this.allIndices.size();
        } else {
            throw new IllegalStateException("NO SAMPLING METHOD SELECTED");
        }
    }

    /** Fisher-Yates shuffle of {@code a} in place, drawing from {@code rand}. */
    private static void shuffleInPlace(int[] a, Random rand) {
        for (int i = a.length - 1; i > 0; --i) {
            int j = rand.nextInt(i + 1);
            int tmp = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }

    /**
     * Makes sure {@code value}, {@code derivative} (and whatever {@code calculateStochastic}
     * fills) correspond to {@code x}, {@code v} and a batch of size {@code batchSize}.
     * It either reuses cached results or draws a batch and recomputes.
     */
    private void stochasticEnsure(double[] x, double[] v, int batchSize) {
        if (this.lastXBatch == null) {
            this.lastXBatch = new double[this.domainDimension()];
            log.info("Setting previous position (x).");
        }
        if (this.lastVBatch == null) {
            this.lastVBatch = new double[this.domainDimension()];
            log.info("Setting previous gain (v)");
        }
        if (this.derivative == null) {
            this.derivative = new double[this.domainDimension()];
            log.info("Setting Derivative.");
        }
        if (this.HdotV == null) {
            this.HdotV = new double[this.domainDimension()];
            log.info("Setting HdotV.");
        }
        if (this.lastBatch == null) {
            this.lastBatch = new int[batchSize];
            log.info("Setting last batch");
        }
        if (this.recalculatePrevBatch && batchSize == this.lastBatch.length) {
            // Re-evaluate on exactly the same data points as the previous call.
            this.thisBatch = this.lastBatch;
        } else {
            if (this.returnPreviousValues) {
                this.returnPreviousValues = false;
                return;
            }
            // NOTE(review): lastElement is set to curElement at the end of every evaluation, so this
            // cache hit only triggers after curElement was moved externally (e.g. decrementBatch).
            // Confirm that this condition is intended.
            if (!this.hasNewVals && this.lastElement != this.curElement && this.lastBatchSize == batchSize && Arrays.equals(x, this.lastXBatch) && Arrays.equals(v, this.lastVBatch) && Arrays.equals(this.thisBatch, this.lastBatch)) {
                return;
            }
            this.getBatch(batchSize);
        }
        AbstractStochasticCachingDiffFunction.copy(this.lastXBatch, x);
        if (this.lastBatch.length != batchSize) {
            this.lastBatch = new int[batchSize];
        }
        System.arraycopy(this.thisBatch, 0, this.lastBatch, 0, this.thisBatch.length);
        if (v != null) {
            AbstractStochasticCachingDiffFunction.copy(this.lastVBatch, v);
        }
        this.lastBatchSize = batchSize;
        this.calculateStochastic(x, v, this.thisBatch);
        if (this.scaleUp) {
            // Rescale so the batch estimate is on the scale of the full data set.
            double ratio = (double)this.dataDimension() / (double)batchSize;
            for (int i = 0; i < x.length; ++i) {
                this.derivative[i] = this.derivative[i] * ratio;
            }
            this.value = ratio * this.value;
        }
        this.incrementBatch(batchSize);
        this.lastElement = this.curElement;
    }

    /** Returns the function value at {@code x} on a batch of {@code batchSize} data points. */
    public double valueAt(double[] x, int batchSize) {
        this.stochasticEnsure(x, null, batchSize);
        return this.value;
    }

    /** Returns the gradient at {@code x} on a batch of {@code batchSize} data points. */
    public double[] derivativeAt(double[] x, int batchSize) {
        this.stochasticEnsure(x, null, batchSize);
        return this.derivative;
    }

    /** Returns the function value at {@code x} (with gain vector {@code v}) on a batch. */
    public double valueAt(double[] x, double[] v, int batchSize) {
        this.stochasticEnsure(x, v, batchSize);
        return this.value;
    }

    /** Returns the gradient at {@code x} (with gain vector {@code v}) on a batch. */
    public double[] derivativeAt(double[] x, double[] v, int batchSize) {
        this.stochasticEnsure(x, v, batchSize);
        return this.derivative;
    }

    /**
     * Approximates the Hessian-vector product H·v by a forward difference of gradients,
     * {@code (grad(x + h*v) - curDerivative) / h}, evaluated on the current batch.
     * {@code value} and {@code derivative} are restored to their pre-perturbation contents.
     */
    private void getHdotVFiniteDifference(double[] x, double[] v, double[] curDerivative) {
        double h = this.finiteDifferenceStepSize;
        double hInv = 1.0 / h;
        if (this.gradPerturbed == null) {
            this.gradPerturbed = new double[x.length];
            log.info("Setting approximate gradient.");
        }
        if (this.xPerturbed == null) {
            this.xPerturbed = new double[x.length];
            log.info("Setting perturbed.");
        }
        if (this.HdotV == null) {
            this.HdotV = new double[x.length];
            log.info("Setting H dot V.");
        }
        for (int i = 0; i < x.length; ++i) {
            this.xPerturbed[i] = x[i] + h * v[i];
        }
        double prevValue = this.value;
        this.recalculatePrevBatch = true;
        // Evaluate the gradient at the perturbed point on the same batch (thisBatch).
        this.calculateStochastic(this.xPerturbed, null, this.thisBatch);
        for (int i = 0; i < x.length; ++i) {
            double tmp = this.derivative[i] - curDerivative[i];
            this.HdotV[i] = hInv * tmp;
        }
        // Undo the side effects of the perturbed evaluation.
        System.arraycopy(curDerivative, 0, this.derivative, 0, this.derivative.length);
        this.value = prevValue;
        this.hasNewVals = false;
        this.recalculatePrevBatch = false;
        this.returnPreviousValues = false;
    }

    /**
     * Returns H·v at {@code x} on a batch of {@code batchSize} data points.
     *
     * @throws RuntimeException if {@link #method} is ExternalFiniteDifference, which needs the
     *     current derivative; use the overload taking {@code curDerivative} instead
     */
    public double[] HdotVAt(double[] x, double[] v, int batchSize) {
        if (this.method == StochasticCalculateMethods.ExternalFiniteDifference) {
            throw new RuntimeException("Attempt to use ExternalFiniteDifference without passing currentDerivative");
        }
        this.stochasticEnsure(x, v, batchSize);
        return this.HdotV;
    }

    /**
     * Returns H·v at {@code x}. Under ExternalFiniteDifference it uses the previous batch and
     * {@code curDerivative}; otherwise it evaluates on a batch of {@code batchSize} data points.
     */
    public double[] HdotVAt(double[] x, double[] v, double[] curDerivative, int batchSize) {
        if (this.method == StochasticCalculateMethods.ExternalFiniteDifference) {
            this.getHdotVFiniteDifference(x, v, curDerivative);
        } else {
            this.stochasticEnsure(x, v, batchSize);
        }
        return this.HdotV;
    }

    /**
     * Returns H·v at {@code x} using a batch the size of the whole data set. The batch position
     * is then stepped back so the call does not advance sampling.
     *
     * @throws RuntimeException if {@link #method} is ExternalFiniteDifference
     */
    public double[] HdotVAt(double[] x, double[] v) {
        if (this.method == StochasticCalculateMethods.ExternalFiniteDifference) {
            log.info("Attempt to use ExternalFiniteDifference without passing currentDerivative");
            throw new RuntimeException("Attempt to use ExternalFiniteDifference without passing currentDerivative");
        }
        this.stochasticEnsure(x, v, this.dataDimension());
        this.decrementBatch(this.dataDimension());
        return this.HdotV;
    }

    /** Returns the gradient computed by the most recent evaluation. */
    public double[] lastDerivative() {
        return this.derivative;
    }

    /** Returns the value computed by the most recent evaluation. */
    @Override
    public double lastValue() {
        return this.value;
    }

    /** Strategies for choosing the data indices that form a batch. */
    public enum SamplingMethod {
        NoneSpecified,
        RandomWithReplacement,
        RandomWithoutReplacement,
        Ordered,
        Shuffled;

    }
}

