/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.parser.lexparser.IntDependency;
import edu.stanford.nlp.parser.lexparser.IntTaggedWord;
import edu.stanford.nlp.parser.lexparser.Lexicon;
import edu.stanford.nlp.parser.lexparser.MLEDependencyGrammar;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * A dependency grammar that extends {@link MLEDependencyGrammar} with estimates pooled
 * over "similar words" read from external files.
 *
 * <p>At construction time two similar-word files, {@code simWords/HeadArg.5} and
 * {@code simWords/ArgHead.5}, are loaded; the paths are relative, so they resolve against
 * the current working directory. {@link #scoreTB(IntDependency)} interpolates the
 * head-word-specific argument probability with an estimate pooled over the heads listed
 * as similar to the actual head.
 *
 * <p>NOTE(review): similarity scores are turned into weights via {@code exp(-50 * score)}
 * in {@link #probSimilarWordAvg}, so presumably a lower score means "more similar"
 * (distance-like) — confirm against the tool that produced the files.
 */
public class ChineseSimWordAvgDepGrammar
extends MLEDependencyGrammar {
    private static final Redwood.RedwoodChannels log = Redwood.channels(ChineseSimWordAvgDepGrammar.class);
    private static final long serialVersionUID = -1845503582705055342L;
    /** Pseudo-count given to the similar-word estimate in {@link #probSimilarWordAvg}. */
    private static final double simSmooth = 10.0;
    /** Multiplier applied to a similarity score before exponentiating it into a weight. */
    private static final double simWeightScale = -50.0;
    private static final String argHeadFile = "simWords/ArgHead.5";
    private static final String headArgFile = "simWords/HeadArg.5";
    /** Matches lines of the form {@code sim(word1/cat1:word2/cat2)=score}. */
    private static final Pattern simLinePattern = Pattern.compile("sim\\((.+)/(.+):(.+)/(.+)\\)=(.+)");
    /** Similar-word lists read from {@code argHeadFile}, keyed by (word index, basic tag category). */
    private final Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simArgMap;
    /** Similar-word lists read from {@code headArgFile}, keyed by (word index, basic tag category). */
    private final Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simHeadMap;
    /** When true, per-dependency smoothing traces are printed to standard output. */
    private static final boolean debug = true;
    private static final boolean verbose = false;
    /** Event counts recorded by {@link #probSimilarWordAvg}; printed by {@link #dumpSimWordAvgStats}. */
    private final ClassicCounter<String> statsCounter = new ClassicCounter<>();

    /**
     * Builds the grammar and eagerly loads both similar-word files.
     *
     * @throws RuntimeException if either similar-word file cannot be read
     */
    public ChineseSimWordAvgDepGrammar(TreebankLangParserParams tlpParams, boolean directional, boolean distance, boolean coarseDistance, boolean basicCategoryTagsInDependencyGrammar, Options op, Index<String> wordIndex, Index<String> tagIndex) {
        super(tlpParams, directional, distance, coarseDistance, basicCategoryTagsInDependencyGrammar, op, wordIndex, tagIndex);
        this.simHeadMap = this.getMap(headArgFile);
        this.simArgMap = this.getMap(argHeadFile);
    }

    /**
     * Reads a similar-word file into a map from (word index, category) to the list of
     * (similar word index, category, score) triples listed for it, in file order.
     *
     * <p>Words on both sides of each entry are added to {@code wordIndex}. Lines that do not
     * match {@code sim(word/cat:word/cat)=score}, or whose score is not a number, are logged
     * and skipped.
     *
     * @param filename path of the file to read, decoded as UTF-8
     * @return a new mutable map; never {@code null}
     * @throws RuntimeException if the file cannot be opened or read (the cause is attached)
     */
    public Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> getMap(String filename) {
        Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> hashMap = Generics.newHashMap();
        // Separate resources so the file stream is closed even if the reader construction fails.
        try (InputStream in = new FileInputStream(filename);
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                Matcher m = simLinePattern.matcher(line);
                if (!m.matches()) {
                    log.info("Ill-formed line in similar word map file: " + line);
                    continue;
                }
                double score;
                try {
                    score = Double.parseDouble(m.group(5));
                } catch (NumberFormatException e) {
                    // Treat a non-numeric score like any other malformed line instead of aborting the load.
                    log.info("Ill-formed line in similar word map file: " + line);
                    continue;
                }
                Pair<Integer, String> key = new Pair<Integer, String>(this.wordIndex.addToIndex(m.group(1)), m.group(2));
                List<Triple<Integer, String, Double>> similar = hashMap.get(key);
                if (similar == null) {
                    similar = new ArrayList<>();
                    hashMap.put(key, similar);
                }
                similar.add(new Triple<Integer, String, Double>(this.wordIndex.addToIndex(m.group(3)), m.group(4), score));
            }
        } catch (IOException e) {
            throw new RuntimeException("Problem reading similar words file: " + filename, e);
        }
        return hashMap;
    }

    /**
     * Scores a dependency as {@code depWeight * log(p)}, with p from the similar-word
     * smoothed estimate {@link #probTBwithSimWords}.
     */
    @Override
    public double scoreTB(IntDependency dependency) {
        return this.op.testOptions.depWeight * Math.log(this.probTBwithSimWords(dependency));
    }

    /** Sets the lexicon whose scores {@link #probSimilarWordAvg} uses. */
    public void setLex(Lexicon lex) {
        this.lex = lex;
    }

    /** Logs the counters accumulated by {@link #probSimilarWordAvg}. */
    public void dumpSimWordAvgStats() {
        log.info("SimWordAvg stats:");
        log.info(this.statsCounter);
    }

    /**
     * Estimates the probability of {@code dependency} from argument counts. The
     * head-word-specific estimate is additionally smoothed with counts pooled over the
     * words that {@code simHeadMap} lists as similar to the head.
     *
     * <p>For an ordinary argument the result is
     * {@code (interp * pb_aTW_hTWd + (1 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds}, where
     * {@code pb_aTW_hTWd = (c_aTW_hTWd + 17.7 * pSim_aTW_hTd + 35.4 * p_aTW_hTd) / (c_hTWd + 17.7 + 35.4)}.
     * NaN results and results below 1e-40 are clamped to 0.
     *
     * @param dependency the head/argument dependency to score
     * @return the probability; 1.0 when the argument is pruned under {@code prunePunc}; for
     *     argument word index -2, 0.0 at the root and the stop probability otherwise
     */
    private double probTBwithSimWords(IntDependency dependency) {
        boolean leftHeaded = dependency.leftHeaded && this.directional;
        IntTaggedWord head = dependency.head;
        IntTaggedWord arg = dependency.arg;
        IntTaggedWord unknownHead = new IntTaggedWord(-1, head.tag);
        IntTaggedWord unknownArg = new IntTaggedWord(-1, arg.tag);
        double pb_stop_hTWds = this.getStopProb(dependency);
        boolean isRoot = this.rootTW(head);
        // Word index -2 presumably encodes the STOP argument (TODO confirm against IntTaggedWord).
        if (arg.word == -2) {
            return isRoot ? 0.0 : pb_stop_hTWds;
        }
        double pb_go_hTWds = isRoot ? 1.0 : 1.0 - pb_stop_hTWds;
        short valenceBinDistance = this.valenceBin(dependency.distance);

        double c_aTW_hTWd = this.argCount(head, arg, leftHeaded, valenceBinDistance);
        double c_aT_hTWd = this.argCount(head, unknownArg, leftHeaded, valenceBinDistance);
        double c_hTWd = this.argCount(head, this.wildTW, leftHeaded, valenceBinDistance);
        double c_aTW_hTd = this.argCount(unknownHead, arg, leftHeaded, valenceBinDistance);
        double c_aT_hTd = this.argCount(unknownHead, unknownArg, leftHeaded, valenceBinDistance);
        double c_hTd = this.argCount(unknownHead, this.wildTW, leftHeaded, valenceBinDistance);
        // Presumably the (false, -1) key holds direction- and distance-independent totals.
        double c_aTW = this.argCount(this.wildTW, arg, false, -1);
        double c_aT = this.argCount(this.wildTW, unknownArg, false, -1);

        double p_aTW_hTd = c_hTd > 0.0 ? c_aTW_hTd / c_hTd : 0.0;
        double p_aT_hTd = c_hTd > 0.0 ? c_aT_hTd / c_hTd : 0.0;
        double p_aTW_aT = c_aTW > 0.0 ? c_aTW / c_aT : 1.0;
        double pb_aT_hTWd = (c_aT_hTWd + this.smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + this.smooth_aT_hTWd);

        double pSim_aTW_hTd = this.probArgGivenSimilarHeads(dependency, leftHeaded, valenceBinDistance);
        if (debug && pSim_aTW_hTd > 0.0) {
            System.out.println(dependency + "\t" + pSim_aTW_hTd);
        }
        final double smoothSim_aTW_hTWd = 17.7;
        final double smooth_aTW_hTWd = 35.4;
        double pb_aTW_hTWd = (c_aTW_hTWd + smoothSim_aTW_hTWd * pSim_aTW_hTd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smoothSim_aTW_hTWd + smooth_aTW_hTWd);
        if (debug) {
            System.out.println(dependency);
            System.out.println(c_aTW_hTWd + " + " + smoothSim_aTW_hTWd + " * " + pSim_aTW_hTd + " + " + smooth_aTW_hTWd + " * " + p_aTW_hTd);
            System.out.println("--------------------------------  = " + pb_aTW_hTWd);
            System.out.println(c_hTWd + " + " + smoothSim_aTW_hTWd + " + " + smooth_aTW_hTWd);
            System.out.println();
        }
        double score = (this.interp * pb_aTW_hTWd + (1.0 - this.interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds;
        if (this.op.testOptions.prunePunc && this.pruneTW(arg)) {
            return 1.0;
        }
        if (Double.isNaN(score) || score < 1.0E-40) {
            score = 0.0;
        }
        return score;
    }

    /**
     * Pools argument counts over the heads that {@code simHeadMap} lists as similar to
     * {@code dependency.head}. Each similar head keeps the real head's tag and substitutes
     * the similar word.
     *
     * @return pooled count(arg | similar heads) / pooled count(any arg | similar heads), or 0
     *     when there are no similar heads or the pooled head count is 0
     */
    private double probArgGivenSimilarHeads(IntDependency dependency, boolean leftHeaded, short valenceBinDistance) {
        List<Triple<Integer, String, Double>> sim2head = this.simHeadMap.get(new Pair<Integer, String>(dependency.head.word, this.stringBasicCategory(dependency.head.tag)));
        if (sim2head == null) {
            return 0.0;
        }
        double cSim_aTW_hTd = 0.0;
        double cSim_hTd = 0.0;
        for (Triple<Integer, String, Double> simHead : sim2head) {
            IntTaggedWord hWord = new IntTaggedWord(simHead.first, dependency.head.tag);
            // Query with the same (direction, valence-binned distance) key as every other
            // argCounter lookup in probTBwithSimWords; the raw dependency.leftHeaded/distance
            // key used previously was inconsistent with them.
            cSim_aTW_hTd += this.argCount(hWord, dependency.arg, leftHeaded, valenceBinDistance);
            cSim_hTd += this.argCount(hWord, this.wildTW, leftHeaded, valenceBinDistance);
        }
        return cSim_hTd > 0.0 ? cSim_aTW_hTd / cSim_hTd : 0.0;
    }

    /** Returns the {@code argCounter} count for the dependency key built from the given parts. */
    private double argCount(IntTaggedWord head, IntTaggedWord arg, boolean leftHeaded, int distance) {
        return this.argCounter.getCount(new IntDependency(head, arg, leftHeaded, distance));
    }

    /**
     * Returns {@code exp(lex.score(itw, 0, word, null))}. {@code lex} must be non-null
     * (for example, set through {@link #setLex}).
     */
    private double lexProb(IntTaggedWord itw) {
        return Math.exp(this.lex.score(itw, 0, this.wordIndex.get(itw.word), null));
    }

    /**
     * Alternative estimate that averages treebank probabilities over similar heads and/or
     * arguments and smooths the average with the plain {@code probTB} estimate.
     *
     * <p>Each similar entry contributes with weight {@code exp(-50 * score)}. When arguments
     * are substituted, each term is divided by the lexicon probability of the substitute, and
     * the average is rescaled by the lexicon probability of the real argument. The result is
     * {@code (countHead * regProb + simSmooth * simProb) / (countHead + simSmooth)}. Updates
     * {@code statsCounter}. Not called from elsewhere in this class.
     */
    private double probSimilarWordAvg(IntDependency dep) {
        double regProb = this.probTB(dep);
        this.statsCounter.incrementCount("total");
        List<Triple<Integer, String, Double>> sim2arg = this.simArgMap.get(new Pair<Integer, String>(dep.arg.word, this.stringBasicCategory(dep.arg.tag)));
        List<Triple<Integer, String, Double>> sim2head = this.simHeadMap.get(new Pair<Integer, String>(dep.head.word, this.stringBasicCategory(dep.head.tag)));
        if (sim2head == null && sim2arg == null) {
            return regProb;
        }
        int numT = this.tagIndex.size();
        double sumScores = 0.0;
        double sumWeights = 0.0;
        if (sim2head == null) {
            this.statsCounter.incrementCount("aSim");
            for (Triple<Integer, String, Double> simArg : sim2arg) {
                double weight = Math.exp(simWeightScale * simArg.third);
                for (int tag = 0; tag < numT; ++tag) {
                    if (!this.stringBasicCategory(tag).equals(simArg.second)) continue;
                    IntTaggedWord tempArg = new IntTaggedWord(simArg.first, tag);
                    double probArg = this.lexProb(tempArg);
                    if (probArg == 0.0) continue;
                    IntDependency tempDep = new IntDependency(dep.head, tempArg, dep.leftHeaded, dep.distance);
                    sumScores += this.probTB(tempDep) * weight / probArg;
                    sumWeights += weight;
                }
            }
        } else if (sim2arg == null) {
            this.statsCounter.incrementCount("hSim");
            for (Triple<Integer, String, Double> simHead : sim2head) {
                double weight = Math.exp(simWeightScale * simHead.third);
                for (int tag = 0; tag < numT; ++tag) {
                    if (!this.stringBasicCategory(tag).equals(simHead.second)) continue;
                    IntTaggedWord tempHead = new IntTaggedWord(simHead.first, tag);
                    IntDependency tempDep = new IntDependency(tempHead, dep.arg, dep.leftHeaded, dep.distance);
                    sumScores += this.probTB(tempDep) * weight;
                    sumWeights += weight;
                }
            }
        } else {
            this.statsCounter.incrementCount("hSim");
            this.statsCounter.incrementCount("aSim");
            this.statsCounter.incrementCount("aSim&hSim");
            for (Triple<Integer, String, Double> simArg : sim2arg) {
                double argWeight = Math.exp(simWeightScale * simArg.third);
                for (int aTag = 0; aTag < numT; ++aTag) {
                    if (!this.stringBasicCategory(aTag).equals(simArg.second)) continue;
                    IntTaggedWord tempArg = new IntTaggedWord(simArg.first, aTag);
                    double probArg = this.lexProb(tempArg);
                    if (probArg == 0.0) continue;
                    for (Triple<Integer, String, Double> simHead : sim2head) {
                        double weight = Math.exp(simWeightScale * simHead.third) * argWeight;
                        for (int hTag = 0; hTag < numT; ++hTag) {
                            if (!this.stringBasicCategory(hTag).equals(simHead.second)) continue;
                            // The head must carry its own candidate tag (previously aTag was used by mistake).
                            IntTaggedWord tempHead = new IntTaggedWord(simHead.first, hTag);
                            IntDependency tempDep = new IntDependency(tempHead, tempArg, dep.leftHeaded, dep.distance);
                            sumScores += this.probTB(tempDep) * weight / probArg;
                            sumWeights += weight;
                        }
                    }
                }
            }
        }
        if (sumWeights == 0.0) {
            // No similar word contributed (no tag matched, or every lexicon probability was 0).
            // 0/0 would be NaN, so fall back to the plain estimate, as when no similar words exist.
            return regProb;
        }
        // Same (direction, valence-binned distance) key that probTBwithSimWords uses for c_hTWd.
        double countHead = this.argCount(dep.head, this.wildTW, dep.leftHeaded && this.directional, this.valenceBin(dep.distance));
        double simProb;
        if (sim2arg == null) {
            simProb = sumScores / sumWeights;
        } else {
            simProb = this.lexProb(dep.arg) * sumScores / sumWeights;
        }
        if (simProb == 0.0) {
            this.statsCounter.incrementCount("simProbZero");
        }
        if (regProb == 0.0) {
            this.statsCounter.incrementCount("regProbZero");
        }
        double smoothProb = (countHead * regProb + simSmooth * simProb) / (countHead + simSmooth);
        if (smoothProb == 0.0) {
            this.statsCounter.incrementCount("smoothProbZero");
        }
        return smoothProb;
    }

    /** Maps a tag index to its basic category string via the treebank language pack. */
    private String stringBasicCategory(int tag) {
        return this.tlp.basicCategory(this.tagIndex.get(tag));
    }
}

