/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFClassifier;
import edu.stanford.nlp.ie.crf.CRFNonLinearLogConditionalObjectiveFunction;
import edu.stanford.nlp.ie.crf.CRFNonLinearSecondOrderLogConditionalObjectiveFunction;
import edu.stanford.nlp.ie.crf.CliquePotentialFunction;
import edu.stanford.nlp.ie.crf.NonLinearCliquePotentialFunction;
import edu.stanford.nlp.ie.crf.NonLinearSecondOrderCliquePotentialFunction;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Evaluator;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.StochasticDiffFunctionTester;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.util.ConvertByteArray;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Quadruple;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.zip.GZIPInputStream;

/**
 * A {@link CRFClassifier} whose clique potentials are computed by a nonlinear function:
 * {@link NonLinearSecondOrderCliquePotentialFunction} when {@code flags.secondOrderNonLinear} is set,
 * otherwise {@link NonLinearCliquePotentialFunction}.
 *
 * <p>Raw feature indices produced by the superclass are remapped to positions in
 * {@code nodeFeatureIndicesMap} (clique 0) or {@code edgeFeatureIndicesMap} (all other cliques)
 * before training and testing.
 */
public class CRFClassifierNonlinear<IN extends CoreMap>
extends CRFClassifier<IN> {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CRFClassifierNonlinear.class);

    /** Linear weights; only populated when {@code flags.secondOrderNonLinear} is false. */
    private double[][] linearWeights;
    /** Edge input-layer weights; only populated when {@code flags.secondOrderNonLinear} is true. */
    private double[][] inputLayerWeights4Edge;
    /** Edge output-layer weights; only populated when {@code flags.secondOrderNonLinear} is true. */
    private double[][] outputLayerWeights4Edge;
    /** Input-layer weights, populated in both modes. */
    private double[][] inputLayerWeights;
    /** Output-layer weights, populated in both modes. */
    private double[][] outputLayerWeights;

    protected CRFClassifierNonlinear() {
        super(new SeqClassifierFlags());
    }

    public CRFClassifierNonlinear(Properties props) {
        super(props);
    }

    public CRFClassifierNonlinear(SeqClassifierFlags flags) {
        super(flags);
    }

    /**
     * Delegates to the superclass, then remaps the feature indices of the returned data through
     * {@link #transformDocData(int[][][])}. Labels and feature values are passed through unchanged.
     */
    @Override
    public Triple<int[][][], int[], double[][][]> documentToDataAndLabels(List<IN> document) {
        Triple<int[][][], int[], double[][][]> result = super.documentToDataAndLabels(document);
        int[][][] data = this.transformDocData(result.first());
        return new Triple<>(data, result.second(), result.third());
    }

    /**
     * Maps raw feature indices to their position in the node map (clique index 0) or the edge map
     * (any other clique index).
     *
     * @param docData indexed as [position][clique][feature]
     * @return a new array of the same shape holding the remapped indices
     * @throws RuntimeException if a feature is missing from the relevant map
     */
    private int[][][] transformDocData(int[][][] docData) {
        int[][][] transData = new int[docData.length][][];
        for (int i = 0; i < docData.length; ++i) {
            transData[i] = new int[docData[i].length][];
            for (int j = 0; j < docData[i].length; ++j) {
                int[] cliqueFeatures = docData[i][j];
                transData[i][j] = new int[cliqueFeatures.length];
                for (int n = 0; n < cliqueFeatures.length; ++n) {
                    int transFeatureIndex;
                    if (j == 0) {
                        transFeatureIndex = this.nodeFeatureIndicesMap.indexOf(cliqueFeatures[n]);
                        if (transFeatureIndex == -1) {
                            throw new RuntimeException("node cliqueFeatures[n]=" + cliqueFeatures[n] + " not found, nodeFeatureIndicesMap.size=" + this.nodeFeatureIndicesMap.size());
                        }
                    } else {
                        transFeatureIndex = this.edgeFeatureIndicesMap.indexOf(cliqueFeatures[n]);
                        if (transFeatureIndex == -1) {
                            throw new RuntimeException("edge cliqueFeatures[n]=" + cliqueFeatures[n] + " not found, edgeFeatureIndicesMap.size=" + this.edgeFeatureIndicesMap.size());
                        }
                    }
                    transData[i][j][n] = transFeatureIndex;
                }
            }
        }
        return transData;
    }

    /** Lazily builds (and caches) the nonlinear clique potential function from the trained weights. */
    @Override
    protected CliquePotentialFunction getCliquePotentialFunctionForTest() {
        if (this.cliquePotentialFunction == null) {
            if (this.flags.secondOrderNonLinear) {
                this.cliquePotentialFunction = new NonLinearSecondOrderCliquePotentialFunction(this.inputLayerWeights4Edge, this.outputLayerWeights4Edge, this.inputLayerWeights, this.outputLayerWeights, this.flags);
            } else {
                this.cliquePotentialFunction = new NonLinearCliquePotentialFunction(this.linearWeights, this.inputLayerWeights, this.outputLayerWeights, this.flags);
            }
        }
        return this.cliquePotentialFunction;
    }

    /**
     * Trains the nonlinear CRF and stores the resulting weight matrices in this instance's fields.
     *
     * @return always {@code null}; the trained weights live in the instance fields rather than in a
     *     flat array
     */
    @Override
    protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr, double[][][][] featureVals) {
        if (this.flags.secondOrderNonLinear) {
            CRFNonLinearSecondOrderLogConditionalObjectiveFunction func = new CRFNonLinearSecondOrderLogConditionalObjectiveFunction(data, labels, this.windowSize, this.classIndex, this.labelIndices, this.map, this.flags, this.nodeFeatureIndicesMap.size(), this.edgeFeatureIndicesMap.size());
            this.cliquePotentialFunctionHelper = func;
            double[] allWeights = this.trainWeightsUsingNonLinearCRF(func, evaluators);
            Quadruple<double[][], double[][], double[][], double[][]> params = func.separateWeights(allWeights);
            this.inputLayerWeights4Edge = params.first();
            this.outputLayerWeights4Edge = params.second();
            this.inputLayerWeights = params.third();
            this.outputLayerWeights = params.fourth();
        } else {
            CRFNonLinearLogConditionalObjectiveFunction func = new CRFNonLinearLogConditionalObjectiveFunction(data, labels, this.windowSize, this.classIndex, this.labelIndices, this.map, this.flags, this.nodeFeatureIndicesMap.size(), this.edgeFeatureIndicesMap.size(), featureVals);
            if (this.flags.useAdaGradFOBOS) {
                func.gradientsOnly = true;
            }
            this.cliquePotentialFunctionHelper = func;
            double[] allWeights = this.trainWeightsUsingNonLinearCRF(func, evaluators);
            Triple<double[][], double[][], double[][]> params = func.separateWeights(allWeights);
            this.linearWeights = params.first();
            this.inputLayerWeights = params.second();
            this.outputLayerWeights = params.third();
        }
        return null;
    }

    /**
     * Minimizes {@code func} starting from either {@code func.initial()} or the gzipped weight file
     * named by {@code flags.initialWeights}.
     *
     * <p>If {@code flags.testObjFunction} is set, this runs the stochastic tester and then calls
     * {@code System.exit(1)} regardless of the outcome.
     *
     * @throws RuntimeException if the initial-weight file cannot be read (the I/O error is attached as
     *     the cause), or if {@code flags.checkGradient} is set and the gradient check fails
     */
    private double[] trainWeightsUsingNonLinearCRF(AbstractCachingDiffFunction func, Evaluator[] evaluators) {
        Minimizer<DiffFunction> minimizer = this.getMinimizer(0, evaluators);
        double[] initialWeights;
        if (this.flags.initialWeights == null) {
            initialWeights = func.initial();
        } else {
            log.info("Reading initial weights from file " + this.flags.initialWeights);
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(new FileInputStream(this.flags.initialWeights))))) {
                initialWeights = ConvertByteArray.readDoubleArr(dis);
            } catch (IOException e) {
                // Keep the underlying I/O failure as the cause so the real problem is diagnosable.
                throw new RuntimeException("Could not read from double initial weight file " + this.flags.initialWeights, e);
            }
        }
        log.info("numWeights: " + initialWeights.length);
        if (this.flags.testObjFunction) {
            StochasticDiffFunctionTester tester = new StochasticDiffFunctionTester(func);
            if (tester.testSumOfBatches(initialWeights, 1.0E-4)) {
                log.info("Testing complete... exiting");
                System.exit(1);
            } else {
                log.info("Testing failed....exiting");
                System.exit(1);
            }
        }
        if (this.flags.checkGradient) {
            if (func.gradientCheck()) {
                log.info("gradient check passed");
            } else {
                throw new RuntimeException("gradient check failed");
            }
        }
        return minimizer.minimize(func, this.flags.tolerance, initialWeights);
    }

    /**
     * Appends this class's state to the superclass's text serialization, in this order:
     * <ol>
     *   <li>the node feature map, then the edge feature map;</li>
     *   <li>the two edge matrices (second-order mode) or {@code linearWeights} (otherwise);</li>
     *   <li>the input-layer and output-layer matrices.</li>
     * </ol>
     * Must stay in sync with {@link #loadTextClassifier(BufferedReader)}.
     */
    @Override
    protected void serializeTextClassifier(PrintWriter pw) throws Exception {
        super.serializeTextClassifier(pw);
        writeFeatureIndexMap(pw, "nodeFeatureIndicesMap", this.nodeFeatureIndicesMap);
        writeFeatureIndexMap(pw, "edgeFeatureIndicesMap", this.edgeFeatureIndicesMap);
        if (this.flags.secondOrderNonLinear) {
            writeWeightMatrix(pw, "inputLayerWeights4Edge", this.inputLayerWeights4Edge);
            writeWeightMatrix(pw, "outputLayerWeights4Edge", this.outputLayerWeights4Edge);
        } else {
            writeWeightMatrix(pw, "linearWeights", this.linearWeights);
        }
        writeWeightMatrix(pw, "inputLayerWeights", this.inputLayerWeights);
        writeWeightMatrix(pw, "outputLayerWeights", this.outputLayerWeights);
    }

    /** Writes a {@code "<name>.size()=\t<n>"} header followed by one {@code "i\tvalue"} line per entry. */
    private static void writeFeatureIndexMap(PrintWriter pw, String name, Index<Integer> map) {
        pw.printf("%s.size()=\t%d%n", name, map.size());
        for (int i = 0; i < map.size(); ++i) {
            pw.printf("%d\t%d%n", i, map.get(i));
        }
    }

    /**
     * Writes a {@code "<name>.length=\t<rows>"} header, then one line per row of the form
     * {@code "<rowLength>\t<w1> <w2> ..."}.
     */
    private static void writeWeightMatrix(PrintWriter pw, String name, double[][] matrix) {
        pw.printf("%s.length=\t%d%n", name, matrix.length);
        for (double[] row : matrix) {
            List<Double> values = new ArrayList<>(row.length);
            for (double w : row) {
                values.add(w);
            }
            pw.printf("%d\t%s%n", row.length, StringUtils.join(values, " "));
        }
    }

    /**
     * Reads back the state written by {@link #serializeTextClassifier(PrintWriter)}, after the
     * superclass has consumed its own portion.
     *
     * @throws RuntimeException on a malformed or truncated section
     */
    @Override
    protected void loadTextClassifier(BufferedReader br) throws Exception {
        super.loadTextClassifier(br);
        this.nodeFeatureIndicesMap = readFeatureIndexMap(br, "nodeFeatureIndicesMap", "format error in nodeFeatureIndicesMap");
        this.edgeFeatureIndicesMap = readFeatureIndexMap(br, "edgeFeatureIndicesMap", "format error");
        if (this.flags.secondOrderNonLinear) {
            this.inputLayerWeights4Edge = readWeightMatrix(br, "inputLayerWeights4Edge");
            this.outputLayerWeights4Edge = readWeightMatrix(br, "outputLayerWeights4Edge");
        } else {
            this.linearWeights = readWeightMatrix(br, "linearWeights");
        }
        this.inputLayerWeights = readWeightMatrix(br, "inputLayerWeights");
        this.outputLayerWeights = readWeightMatrix(br, "outputLayerWeights");
    }

    /**
     * Returns the next line from {@code br}, failing with a descriptive error instead of letting a
     * truncated file surface later as a {@code NullPointerException}.
     */
    private static String readLineOrFail(BufferedReader br, String section) throws IOException {
        String line = br.readLine();
        if (line == null) {
            throw new RuntimeException("format error: unexpected end of input while reading " + section);
        }
        return line;
    }

    /**
     * Reads a section written by {@link #writeFeatureIndexMap}. Each entry's index must equal its line
     * position.
     *
     * @param headerError message used when the section header does not match
     */
    private static Index<Integer> readFeatureIndexMap(BufferedReader br, String name, String headerError) throws IOException {
        String[] toks = readLineOrFail(br, name).split("\\t");
        if (toks.length < 2 || !toks[0].equals(name + ".size()=")) {
            throw new RuntimeException(headerError);
        }
        int size = Integer.parseInt(toks[1]);
        Index<Integer> map = new HashIndex<>();
        for (int count = 0; count < size; ++count) {
            toks = readLineOrFail(br, name).split("\\t");
            if (toks.length < 2 || count != Integer.parseInt(toks[0])) {
                throw new RuntimeException("format error");
            }
            map.add(Integer.parseInt(toks[1]));
        }
        return map;
    }

    /** Reads a section written by {@link #writeWeightMatrix}, including rows of length zero. */
    private static double[][] readWeightMatrix(BufferedReader br, String name) throws IOException {
        String[] toks = readLineOrFail(br, name).split("\\t");
        if (toks.length < 2 || !toks[0].equals(name + ".length=")) {
            throw new RuntimeException("format error");
        }
        int numRows = Integer.parseInt(toks[1]);
        double[][] matrix = new double[numRows][];
        for (int row = 0; row < numRows; ++row) {
            // The -1 limit keeps the trailing empty field produced for an empty row ("0\t"). Without it,
            // the value column would be dropped.
            String[] rowToks = readLineOrFail(br, name).split("\\t", -1);
            int rowLength = Integer.parseInt(rowToks[0]);
            matrix[row] = new double[rowLength];
            if (rowLength == 0) {
                continue;
            }
            if (rowToks.length < 2) {
                throw new RuntimeException("weights format error");
            }
            String[] values = rowToks[1].split(" ");
            if (rowLength != values.length) {
                throw new RuntimeException("weights format error");
            }
            for (int k = 0; k < rowLength; ++k) {
                matrix[row][k] = Double.parseDouble(values[k]);
            }
        }
        return matrix;
    }

    /**
     * Writes this class's state after the superclass's state. Uses the same field order as
     * {@link #loadClassifier(ObjectInputStream, Properties)}.
     */
    @Override
    public void serializeClassifier(ObjectOutputStream oos) {
        try {
            super.serializeClassifier(oos);
            oos.writeObject(this.nodeFeatureIndicesMap);
            oos.writeObject(this.edgeFeatureIndicesMap);
            if (this.flags.secondOrderNonLinear) {
                oos.writeObject(this.inputLayerWeights4Edge);
                oos.writeObject(this.outputLayerWeights4Edge);
            } else {
                oos.writeObject(this.linearWeights);
            }
            oos.writeObject(this.inputLayerWeights);
            oos.writeObject(this.outputLayerWeights);
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    }

    /**
     * Reads back the state written by {@link #serializeClassifier(ObjectOutputStream)}.
     *
     * <p>NOTE(review): this uses Java native deserialization. Only load models from trusted sources.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void loadClassifier(ObjectInputStream ois, Properties props) throws ClassCastException, IOException, ClassNotFoundException {
        super.loadClassifier(ois, props);
        this.nodeFeatureIndicesMap = (Index<Integer>) ois.readObject();
        this.edgeFeatureIndicesMap = (Index<Integer>) ois.readObject();
        if (this.flags.secondOrderNonLinear) {
            this.inputLayerWeights4Edge = (double[][]) ois.readObject();
            this.outputLayerWeights4Edge = (double[][]) ois.readObject();
        } else {
            this.linearWeights = (double[][]) ois.readObject();
        }
        this.inputLayerWeights = (double[][]) ois.readObject();
        this.outputLayerWeights = (double[][]) ois.readObject();
    }
}

