/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.test;

import cc.mallet.grmm.inference.AbstractBeliefPropagation;
import cc.mallet.grmm.inference.AbstractInferencer;
import cc.mallet.grmm.inference.BruteForceInferencer;
import cc.mallet.grmm.inference.ExactSampler;
import cc.mallet.grmm.inference.GibbsSampler;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.inference.JunctionTree;
import cc.mallet.grmm.inference.JunctionTreeInferencer;
import cc.mallet.grmm.inference.LoopyBP;
import cc.mallet.grmm.inference.RandomGraphs;
import cc.mallet.grmm.inference.SamplingInferencer;
import cc.mallet.grmm.inference.TRP;
import cc.mallet.grmm.inference.TreeBP;
import cc.mallet.grmm.inference.VariableElimination;
import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.CPT;
import cc.mallet.grmm.types.DirectedModel;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Tree;
import cc.mallet.grmm.types.UndirectedModel;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.GeneralUtils;
import cc.mallet.grmm.util.ModelReader;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.Matrix;
import cc.mallet.types.Matrixn;
import cc.mallet.types.tests.TestSerializable;
import cc.mallet.util.CollectionUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import cc.mallet.util.Randoms;
import cc.mallet.util.Timing;
import gnu.trove.TDoubleArrayList;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;
import junit.framework.AssertionFailedError;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

public class TestInference
extends TestCase {
    /** Logger for this test class. */
    private static Logger logger = MalletLogger.getLogger(TestInference.class.getName());
    /** Tolerance used when comparing approximate results (appx marginals, TRP marginals, query values) against exact ones. */
    private static double APPX_EPSILON = 0.15;
    /** Exact inferencers; testMarginals requires all of these to agree with each other. */
    public final Class[] algorithms = new Class[]{BruteForceInferencer.class, VariableElimination.class, JunctionTreeInferencer.class};
    /** Approximate inferencers, instantiated reflectively by testQuery, testSerializable and constructAllAppxInferencers. */
    public final Class[] appxAlgs = new Class[]{TRP.class, LoopyBP.class};
    /** Inferencers used by the joint-probability tests; entry 0 (junction tree) is the reference in testFactorizedJoint. */
    public final Class[] allAlgs = new Class[]{JunctionTreeInferencer.class, TRP.class, LoopyBP.class};
    /** Tree-only inferencers; not referenced in the visible part of this file. */
    public final Class[] treeAlgs = new Class[]{TreeBP.class};
    // Mutable list of test models; createTestTrees appends the trees to it.
    // NOTE(review): presumably initialised (e.g. from createTestModels) in setUp, which is not visible here.
    List modelsList;
    // Models iterated by most tests; presumably built from modelsList in setUp — verify.
    UndirectedModel[] models;
    // Tree-structured models built by createTestTrees.
    FactorGraph[] trees;
    // treeMargs[t][mdl.getIndex(var)] = brute-force marginal of var in trees[t]; filled by computeTestTreeMargs.
    Factor[][] treeMargs;
    // Matches the position of createJtChain() inside the trees array built by createTestTrees.
    private static final int JT_CHAIN_TEST_TREE = 2;
    // XML descriptions of the spanning trees over V0, V1, V2 (the labelled triangle) fed to TRP.TreeListFactory in testTrpTreeList.
    private static String[] treeStrs = new String[]{"<TREE>  <VAR NAME='V0'>    <FACTOR VARS='V0 V1'>      <VAR NAME='V1'/>    </FACTOR>    <FACTOR VARS='V0 V2'>      <VAR NAME='V2'/>    </FACTOR>  </VAR></TREE>", "<TREE>  <VAR NAME='V1'>    <FACTOR VARS='V0 V1'>      <VAR NAME='V0'/>    </FACTOR>    <FACTOR VARS='V1 V2'>      <VAR NAME='V2'/>    </FACTOR>  </VAR></TREE>", "<TREE>  <VAR NAME='V0'>    <FACTOR VARS='V0 V1'>      <VAR NAME='V1'>  <FACTOR VARS='V1 V2'>    <VAR NAME='V2'/>  </FACTOR></VAR>    </FACTOR>  </VAR></TREE>", "<TREE>  <VAR NAME='V2'>    <FACTOR VARS='V2 V1'>      <VAR NAME='V1'/>    </FACTOR>    <FACTOR VARS='V0 V2'>      <VAR NAME='V0'/>    </FACTOR>  </VAR></TREE>"};
    // Text description of a 2x2 grid (x00..x11) with Unary factors on u and Potts factors on alpha.
    // NOTE(review): not used in the visible part of this file; presumably parsed with ModelReader further down — verify.
    private String gridStr = "VAR alpha u : CONTINUOUS\nalpha ~ Uniform -1.0 1.0\nu ~ Uniform -2.0 2.0\nx00 ~ Unary u\nx10 ~ Unary u\nx01 ~ Unary u\nx11 ~ Unary u\nx00 x01 ~ Potts alpha\nx00 x10 ~ Potts alpha\nx01 x11 ~ Potts alpha\nx10 x11 ~ Potts alpha\n";

    /**
     * Creates a named test case, as required by JUnit 3's {@link TestCase}.
     *
     * @param name name of the test method this instance will run
     */
    public TestInference(String name) {
        super(name);
    }

    /**
     * Builds a chain of five binary variables, v0 - v1 - v2 - v3 - v4, where
     * every neighbouring pair carries the same pairwise table
     * {0.9, 0.1, 0.1, 0.9}.
     *
     * @return the chain model; if construction throws, the stack trace is
     *         printed and the test is failed
     */
    private static UndirectedModel createChainGraph() {
        final int length = 5;
        Variable[] nodes = new Variable[length];
        UndirectedModel chain = new UndirectedModel();
        try {
            for (int k = 0; k < length; k++) {
                nodes[k] = new Variable(2);
            }
            // One table instance is shared by every edge factor.
            double[] pairTable = {0.9, 0.1, 0.1, 0.9};
            for (int k = 1; k < length; k++) {
                Variable[] edge = {nodes[k - 1], nodes[k]};
                chain.addFactor(new TableFactor(edge, pairTable));
            }
        } catch (Exception e) {
            e.printStackTrace();
            TestInference.assertTrue(false);
        }
        return chain;
    }

    /**
     * Builds a three-variable binary cycle (v0-v1, v1-v2, v2-v0) with fixed
     * pairwise tables, plus a unary factor {0.35, 0.65} on v0.
     *
     * @return the triangle model
     */
    private static UndirectedModel createTriangle() {
        Variable[] v = new Variable[3];
        for (int k = 0; k < v.length; k++) {
            v[k] = new Variable(2);
        }
        UndirectedModel triangle = new UndirectedModel(v);
        triangle.addFactor(v[0], v[1], new double[]{0.2, 0.8, 0.1, 0.9});
        triangle.addFactor(v[1], v[2], new double[]{0.7, 0.3, 0.5, 0.5});
        triangle.addFactor(v[2], v[0], new double[]{0.6, 0.4, 0.8, 0.2});
        TableFactor unary = new TableFactor(new Variable[]{v[0]}, new double[]{0.35, 0.65});
        triangle.addFactor(unary);
        return triangle;
    }

    /**
     * Creates a pairwise factor over (v1, v2) whose entries are drawn from
     * {@code r.nextDouble()}, filled in row-major order (v1 outer, v2 inner).
     *
     * @param r  source of randomness
     * @param v1 first variable (row dimension)
     * @param v2 second variable (column dimension)
     * @return the random pairwise factor
     */
    private static TableFactor randomEdgePotential(Random r, Variable v1, Variable v2) {
        int rows = v1.getNumOutcomes();
        int cols = v2.getNumOutcomes();
        Matrixn table = new Matrixn(new int[]{rows, cols});
        for (int a = 0; a < rows; a++) {
            for (int b = 0; b < cols; b++) {
                table.setValue(new int[]{a, b}, r.nextDouble());
            }
        }
        return new TableFactor(new Variable[]{v1, v2}, (Matrix) table);
    }

    /**
     * Creates a unary factor over {@code v} whose entries are random values
     * squeezed into [0.2, 0.8) by {@link #rescale(double)}.
     *
     * @param r source of randomness
     * @param v the variable the factor is over
     * @return the random unary factor (not normalized)
     */
    private static TableFactor randomNodePotential(Random r, Variable v) {
        int size = v.getNumOutcomes();
        Matrixn table = new Matrixn(new int[]{size});
        int idx = 0;
        while (idx < size) {
            double raw = r.nextDouble();
            table.setSingleValue(idx, TestInference.rescale(raw));
            idx++;
        }
        return new TableFactor(new Variable[]{v}, (Matrix) table);
    }

    /**
     * Affinely maps {@code d} to {@code 0.2 + 0.6 * d}; for d in [0, 1)
     * the result lies in [0.2, 0.8).
     *
     * @param d value to map
     * @return the rescaled value
     */
    private static double rescale(double d) {
        final double offset = 0.2;
        final double width = 0.6;
        return offset + width * d;
    }

    /**
     * Builds a random pairwise model over {@code numV} variables with
     * {@code numOutcomes} outcomes each.
     * <p>
     * Pass 1 flips a coin for every pair (a, b) with a &lt; b and adds a random
     * edge factor on heads. A vertex that adds no edge to a later vertex gets a
     * normalized random unary factor instead. Pass 2 adds a random edge factor
     * for every pair that {@code isConnected} still reports as unconnected.
     *
     * @param numV        number of variables
     * @param numOutcomes outcomes per variable
     * @param r           source of randomness (consumed in a fixed order)
     * @return the random model
     */
    private static UndirectedModel createRandomGraph(int numV, int numOutcomes, Random r) {
        Variable[] vars = new Variable[numV];
        for (int k = 0; k < numV; k++) {
            vars[k] = new Variable(numOutcomes);
        }
        UndirectedModel model = new UndirectedModel(vars);

        for (int a = 0; a < numV; a++) {
            boolean addedEdge = false;
            for (int b = a + 1; b < numV; b++) {
                if (r.nextBoolean()) {
                    addedEdge = true;
                    model.addFactor(TestInference.randomEdgePotential(r, vars[a], vars[b]));
                }
            }
            if (!addedEdge) {
                TableFactor unary = TestInference.randomNodePotential(r, vars[a]);
                unary.normalize();
                model.addFactor(unary);
            }
        }

        for (int a = 0; a < numV; a++) {
            for (int b = a + 1; b < numV; b++) {
                if (!model.isConnected(vars[a], vars[b])) {
                    model.addFactor(TestInference.randomEdgePotential(r, vars[a], vars[b]));
                }
            }
        }
        return model;
    }

    /**
     * Builds a {@code w} x {@code h} grid model.
     * <p>
     * Each cell variable gets a random number of outcomes in
     * [2, maxOutcomes]. Each cell is linked by a random pairwise factor to its
     * neighbour at (x+1, y) and at (x, y+1) when those exist.
     *
     * @param w           grid width
     * @param h           grid height
     * @param maxOutcomes maximum outcomes per variable (must be at least 2)
     * @param r           source of randomness
     * @return the grid model
     */
    public static UndirectedModel createRandomGrid(int w, int h, int maxOutcomes, Random r) {
        Variable[][] cells = new Variable[w][h];
        UndirectedModel grid = new UndirectedModel(w * h);
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                cells[x][y] = new Variable(2 + r.nextInt(maxOutcomes - 1));
            }
        }
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                Variable here = cells[x][y];
                if (x + 1 < w) {
                    grid.addFactor(TestInference.randomEdgePotential(r, here, cells[x + 1][y]));
                }
                if (y + 1 < h) {
                    grid.addFactor(TestInference.randomEdgePotential(r, here, cells[x][y + 1]));
                }
            }
        }
        return grid;
    }

    /**
     * Builds a random model over {@code nnodes} variables, each with
     * between 2 and {@code maxOutcomes} outcomes.
     * <p>
     * Pass 1 visits each pair (a, b) with a &lt; b. If {@code isConnected}
     * reports the pair unconnected and a coin flip succeeds, it adds a random
     * edge factor. Pass 2 adds (and prints) a "forced edge" for every pair
     * still reported as unconnected.
     *
     * @param nnodes      number of variables
     * @param maxOutcomes maximum outcomes per variable (at least 2)
     * @param r           source of randomness
     * @return the model
     */
    private UndirectedModel createRandomTree(int nnodes, int maxOutcomes, Random r) {
        Variable[] nodes = new Variable[nnodes];
        UndirectedModel tree = new UndirectedModel(nnodes);
        for (int k = 0; k < nnodes; k++) {
            nodes[k] = new Variable(r.nextInt(maxOutcomes - 1) + 2);
        }
        for (int a = 0; a < nnodes; a++) {
            for (int b = a + 1; b < nnodes; b++) {
                // Short-circuit matters: no coin is flipped for pairs already connected.
                boolean skip = tree.isConnected(nodes[a], nodes[b]) || !r.nextBoolean();
                if (!skip) {
                    tree.addFactor(TestInference.randomEdgePotential(r, nodes[a], nodes[b]));
                }
            }
        }
        for (int a = 0; a < nnodes; a++) {
            for (int b = a + 1; b < nnodes; b++) {
                if (!tree.isConnected(nodes[a], nodes[b])) {
                    System.out.println("forced edge: " + a + " " + b);
                    tree.addFactor(TestInference.randomEdgePotential(r, nodes[a], nodes[b]));
                }
            }
        }
        return tree;
    }

    /**
     * Builds the fixed set of general (possibly loopy) test models.
     * <p>
     * The list contains the triangle, the chain, four random graphs and two
     * random grids. All random ones share one {@code Random(42)}, so the
     * construction order below fixes their contents.
     *
     * @return a mutable list of the models
     */
    public static List createTestModels() {
        Random r = new Random(42L);
        List<FactorGraph> result = new ArrayList<FactorGraph>();
        result.add(TestInference.createTriangle());
        result.add(TestInference.createChainGraph());
        result.add(TestInference.createRandomGraph(3, 2, r));
        result.add(TestInference.createRandomGraph(3, 3, r));
        result.add(TestInference.createRandomGraph(6, 3, r));
        result.add(TestInference.createRandomGraph(8, 2, r));
        result.add(TestInference.createRandomGrid(3, 2, 4, r));
        result.add(TestInference.createRandomGrid(4, 3, 2, r));
        return result;
    }

    /**
     * Checks that every inferencer in {@code allAlgs} reports log-joint
     * {@code -log 8} for every assignment of
     * {@code RandomGraphs.createUniformChain(3)}.
     * <p>
     * The expected value is the one for 8 equally likely assignments.
     */
    public void testUniformJoint() throws Exception {
        FactorGraph chain = RandomGraphs.createUniformChain(3);
        double expected = -Math.log(8.0);
        for (Class algClass : this.allAlgs) {
            Inferencer inf = (Inferencer) algClass.newInstance();
            inf.computeMarginals(chain);
            for (AssignmentIterator it = chain.assignmentIterator(); it.hasNext(); it.advance()) {
                double actual = inf.lookupLogJoint(it.assignment());
                TestInference.assertEquals("Incorrect joint for inferencer " + inf, expected, actual, 1.0E-5);
            }
        }
    }

    /**
     * Checks that {@code log(lookupJoint)} matches {@code lookupLogJoint}
     * (within 1e-5) for the all-zeros assignment of a single fixed model,
     * {@code models[13]}, under every inferencer in {@code allAlgs}.
     */
    public void testJointConsistent() throws Exception {
        final int mdlIdx = 13;
        for (int i = 0; i < this.allAlgs.length; i++) {
            Inferencer inf = (Inferencer) this.allAlgs[i].newInstance();
            try {
                UndirectedModel mdl = this.models[mdlIdx];
                inf.computeMarginals(mdl);
                Assignment allZero = new Assignment(mdl, new int[mdl.numVariables()]);
                double fromLinear = Math.log(inf.lookupJoint(allZero));
                double fromLog = inf.lookupLogJoint(allZero);
                TestInference.assertEquals(fromLinear, fromLog, 1.0E-5);
            } catch (UnsupportedOperationException e) {
                // NOTE(review): logged as "Skipping", but the exception is rethrown, so the test fails.
                logger.warning("Skipping (" + mdlIdx + "," + i + ")\n" + e);
                throw e;
            }
        }
    }

    /**
     * Compares log-joints across inferencers.
     * <p>
     * Every algorithm in {@code allAlgs} is run on every model (TRP with a
     * fixed seed). For each later algorithm and each assignment, its log-joint
     * must be within 0.2 of {@code allAlgs[0]}'s. In addition,
     * {@code allAlgs[0]}'s linear joint must equal exp of its log-joint.
     * Lookups that throw {@code UnsupportedOperationException} are logged and
     * skipped.
     */
    public void testFactorizedJoint() throws Exception {
        int numAlgs = this.allAlgs.length;
        int numModels = this.models.length;
        Inferencer[][] infs = new Inferencer[numAlgs][numModels];
        for (int a = 0; a < numAlgs; a++) {
            for (int m = 0; m < numModels; m++) {
                Inferencer inf = (Inferencer) this.allAlgs[a].newInstance();
                if (inf instanceof TRP) {
                    ((TRP) inf).setRandomSeed(1231234L);
                }
                try {
                    inf.computeMarginals(this.models[m]);
                } catch (UnsupportedOperationException e) {
                    logger.warning("Skipping (" + m + "," + a + ")\n" + e);
                    throw e;
                }
                infs[a][m] = inf;
            }
        }

        final int reference = 0;
        for (int a = 1; a < numAlgs; a++) {
            for (int m = 0; m < numModels; m++) {
                Inferencer inf1 = infs[reference][m];
                Inferencer inf2 = infs[a][m];
                if (inf1 == null || inf2 == null) {
                    continue;
                }
                AssignmentIterator it = this.models[m].assignmentIterator();
                while (it.hasNext()) {
                    try {
                        Assignment assn = (Assignment) it.next();
                        double joint1 = inf1.lookupLogJoint(assn);
                        double joint2 = inf2.lookupLogJoint(assn);
                        logger.finest("logJoint: " + inf1 + " " + inf2 + "  Model " + m + "  Assn: " + assn + "  INF1: " + joint1 + "\n" + "  INF2: " + joint2 + "\n");
                        String mismatch = "logJoint not equal btwn " + GeneralUtils.classShortName(inf1) + " " + " and " + GeneralUtils.classShortName(inf2) + "\n" + "  Model " + m + "\n" + "  INF1: " + joint1 + "\n" + "  INF2: " + joint2 + "\n";
                        TestInference.assertTrue(mismatch, Math.abs(joint1 - joint2) < 0.2);
                        double linear = inf1.lookupJoint(assn);
                        TestInference.assertTrue("logJoint & joint not consistent\n  Model " + m + "\n" + assn, Maths.almostEquals(linear, Math.exp(joint1)));
                    } catch (UnsupportedOperationException e) {
                        logger.warning("Skipping " + inf1 + " -> " + inf2 + "\n" + e);
                    }
                }
            }
        }
    }

    /**
     * Cross-checks single-variable marginals across inference algorithms.
     * <p>
     * Phase 1 runs every exact algorithm in {@code algorithms} on every model.
     * All pairs must agree under {@code Factor.almostEquals} with its default
     * tolerance. A {@code null} marginal (left by {@code collectAllMarginals}
     * when a lookup was unsupported) is not compared.
     * <p>
     * Phase 2 runs every inferencer from {@code constructAllAppxInferencers}.
     * Each marginal is compared, within {@code APPX_EPSILON}, against the one
     * produced by {@code algorithms[0]}. An
     * {@code UnsupportedOperationException} from an
     * {@code AbstractBeliefPropagation} inferencer skips that model; from any
     * other inferencer it is rethrown.
     * <p>
     * On any mismatch, diagnostics are printed and the assertion error is rethrown.
     */
    public void testMarginals() throws Exception {
        int mdl;
        // joints[model][algorithm][variable]: exact algorithms fill slots [0, numExactAlgs),
        // approximate inferencers follow at numExactAlgs + appxIdx.
        Factor[][][] joints = new Factor[this.models.length][][];
        Inferencer[] appxInferencers = this.constructAllAppxInferencers();
        int numExactAlgs = this.algorithms.length;
        int numAppxAlgs = appxInferencers.length;
        int numAlgs = numExactAlgs + numAppxAlgs;
        for (mdl = 0; mdl < this.models.length; ++mdl) {
            joints[mdl] = new Factor[numAlgs][];
        }
        // Phase 1a: a fresh exact inferencer per (algorithm, model).
        for (int i = 0; i < this.algorithms.length; ++i) {
            for (int mdl2 = 0; mdl2 < this.models.length; ++mdl2) {
                Inferencer alg = (Inferencer)this.algorithms[i].newInstance();
                logger.fine("Computing marginals for model " + mdl2 + " alg " + alg);
                alg.computeMarginals(this.models[mdl2]);
                joints[mdl2][i] = this.collectAllMarginals(this.models[mdl2], alg);
            }
        }
        logger.fine("Checking that results are consistent...");
        // Phase 1b: every ordered pair of exact algorithms (including alg1 == alg2) must agree.
        for (mdl = 0; mdl < this.models.length; ++mdl) {
            int maxV = this.models[mdl].numVariables();
            for (int vrt = 0; vrt < maxV; ++vrt) {
                for (int alg1 = 0; alg1 < this.algorithms.length; ++alg1) {
                    for (int alg2 = 0; alg2 < this.algorithms.length; ++alg2) {
                        Factor joint1 = joints[mdl][alg1][vrt];
                        Factor joint2 = joints[mdl][alg2][vrt];
                        try {
                            if (joint1 == null || joint2 == null) continue;
                            TestInference.assertTrue((boolean)joint1.almostEquals(joint2));
                            continue;
                        }
                        catch (AssertionFailedError e) {
                            int i;
                            System.out.println("\n************************************\nTest FAILED\n\n");
                            System.out.println("Model " + mdl + " Vertex " + vrt);
                            System.out.println("Algs " + alg1 + " and " + alg2 + " not consistent.");
                            System.out.println("MARGINAL from " + alg1);
                            System.out.println(joint1);
                            System.out.println("MARGINAL from " + alg2);
                            System.out.println(joint2);
                            System.out.println("Marginals from " + alg1 + ":");
                            for (i = 0; i < maxV; ++i) {
                                System.out.println(joints[mdl][alg1][i]);
                            }
                            System.out.println("Marginals from " + alg2 + ":");
                            for (i = 0; i < maxV; ++i) {
                                System.out.println(joints[mdl][alg2][i]);
                            }
                            this.models[mdl].dump();
                            throw e;
                        }
                    }
                }
            }
        }
        logger.fine("Checking the approximate algorithms...");
        // Reference slot for phase 2: marginals from algorithms[0].
        int alg2 = 0;
        for (int appxIdx = 0; appxIdx < appxInferencers.length; ++appxIdx) {
            // The same inferencer instance is reused for every model.
            Inferencer alg = appxInferencers[appxIdx];
            for (int mdl3 = 0; mdl3 < this.models.length; ++mdl3) {
                logger.finer("Running inference alg " + alg + " with model " + mdl3);
                try {
                    alg.computeMarginals(this.models[mdl3]);
                }
                catch (UnsupportedOperationException e) {
                    if (alg instanceof AbstractBeliefPropagation) {
                        logger.warning("Skipping model " + mdl3 + " for alg " + alg + "\nInference unsupported.");
                        continue;
                    }
                    throw e;
                }
                int vrt = 0;
                int alg1 = numExactAlgs + appxIdx;
                int maxV = this.models[mdl3].numVariables();
                joints[mdl3][alg1] = new Factor[maxV];
                // Variables are indexed by variablesSet() iteration order, as in collectAllMarginals.
                // NOTE(review): assumes that iteration order is stable between the two passes — verify.
                for (Variable var : this.models[mdl3].variablesSet()) {
                    logger.finer("Lookup marginal for model " + mdl3 + " vrt " + var + " alg " + alg);
                    Factor ptl = alg.lookupMarginal(var);
                    joints[mdl3][alg1][vrt] = ptl.duplicate();
                    ++vrt;
                }
                for (vrt = 0; vrt < maxV; ++vrt) {
                    Factor joint1 = joints[mdl3][alg1][vrt];
                    // NOTE(review): joint2 is null if collectAllMarginals skipped this variable,
                    // which would surface as a NullPointerException inside almostEquals.
                    Factor joint2 = joints[mdl3][alg2][vrt];
                    try {
                        TestInference.assertTrue((boolean)joint1.almostEquals(joint2, APPX_EPSILON));
                        continue;
                    }
                    catch (AssertionFailedError e) {
                        int i;
                        System.out.println("\n************************************\nAppx Marginal Test FAILED\n\n");
                        System.out.println("Inferencer: " + alg);
                        System.out.println("Model " + mdl3 + " Vertex " + vrt);
                        System.out.println(joint1.dumpToString());
                        System.out.println(joint2.dumpToString());
                        this.models[mdl3].dump();
                        System.out.println("All marginals:");
                        for (i = 0; i < maxV; ++i) {
                            System.out.println(joints[mdl3][alg1][i].dumpToString());
                        }
                        System.out.println("Correct marginals:");
                        for (i = 0; i < maxV; ++i) {
                            System.out.println(joints[mdl3][alg2][i].dumpToString());
                        }
                        throw e;
                    }
                }
            }
        }
        System.out.println("Tested " + this.models.length + " undirected models.");
    }

    /**
     * Builds every approximate inferencer exercised by testMarginals.
     * <p>
     * The array holds, in order: one default instance of each class in
     * {@code appxAlgs}; TRP and LoopyBP using
     * {@code SumProductMessageStrategy(0.8)}; a Gibbs-sampling inferencer; and
     * an exact-sampling inferencer.
     *
     * @return the inferencers, in the order listed above
     */
    private Inferencer[] constructAllAppxInferencers() throws IllegalAccessException, InstantiationException {
        List<Object> built = new ArrayList<Object>(2 * this.appxAlgs.length);
        for (Class algClass : this.appxAlgs) {
            built.add(algClass.newInstance());
        }
        // NOTE(review): 0.8 is presumably a message damping factor — confirm in SumProductMessageStrategy.
        built.add(new TRP().setMessager(new AbstractBeliefPropagation.SumProductMessageStrategy(0.8)));
        built.add(new LoopyBP().setMessager(new AbstractBeliefPropagation.SumProductMessageStrategy(0.8)));
        built.add(new SamplingInferencer(new GibbsSampler(10000), 10000));
        built.add(new SamplingInferencer(new ExactSampler(), 1000));
        return built.toArray(new Inferencer[built.size()]);
    }

    /**
     * Builds fresh max-product variants of the junction-tree, TRP and
     * loopy-BP inferencers, in that order.
     *
     * @return a new array of the three inferencers
     */
    private Inferencer[] constructMaxProductInferencers() throws IllegalAccessException, InstantiationException {
        Inferencer jt = JunctionTreeInferencer.createForMaxProduct();
        Inferencer trp = TRP.createForMaxProduct();
        Inferencer loopy = LoopyBP.createForMaxProduct();
        return new Inferencer[]{jt, trp, loopy};
    }

    /**
     * Looks up the single-variable marginal of every variable in {@code mdl}.
     * <p>
     * Slots follow {@code variablesSet()} iteration order. A slot stays
     * {@code null} when {@code lookupMarginal} throws
     * {@code UnsupportedOperationException}; a warning is logged in that case.
     *
     * @param mdl model whose variables are queried
     * @param alg inferencer that has already run on {@code mdl}
     * @return one marginal (or {@code null}) per variable
     */
    private Factor[] collectAllMarginals(FactorGraph mdl, Inferencer alg) {
        Factor[] collector = new Factor[mdl.numVariables()];
        int slot = 0;
        for (Variable var : mdl.variablesSet()) {
            try {
                Factor marginal = alg.lookupMarginal(var);
                collector[slot] = marginal;
                assert marginal != null : "Query returned null for model " + mdl + " vertex " + var + " alg " + alg;
            } catch (UnsupportedOperationException e) {
                logger.warning("Warning: Skipping model " + mdl + " for alg " + alg + "\n  Inference unsupported.");
            }
            slot++;
        }
        return collector;
    }

    /**
     * Checks {@code Inferencer.query} for the non-TRP approximate inferencers.
     * <p>
     * For each model, a random subset of 2-4 variables (capped at the model
     * size) is set to all zeros. The query result must be within
     * {@code APPX_EPSILON} of the brute-force marginal probability of that
     * assignment.
     */
    public void testQuery() throws Exception {
        Random rand = new Random(15667L);
        for (int mdlIdx = 0; mdlIdx < this.models.length; ++mdlIdx) {
            UndirectedModel mdl = this.models[mdlIdx];
            int size = Math.min(rand.nextInt(3) + 2, mdl.varSet().size());
            Collection vars = CollectionUtils.subset(mdl.variablesSet(), size, rand);
            // vars is a raw Collection, so toArray(T[]) erases to Object[] (JLS 4.8);
            // the explicit cast is required for this assignment to compile.
            Variable[] varArr = (Variable[]) vars.toArray(new Variable[0]);
            // Size the all-zeros assignment from the subset actually returned.
            Assignment assn = new Assignment(varArr, new int[varArr.length]);
            BruteForceInferencer brute = new BruteForceInferencer();
            Factor joint = brute.joint(mdl);
            double marginal = joint.marginalize(vars).value(assn);
            for (int algIdx = 0; algIdx < this.appxAlgs.length; ++algIdx) {
                Inferencer alg = (Inferencer) this.appxAlgs[algIdx].newInstance();
                if (alg instanceof TRP) {
                    continue;
                }
                double returned = alg.query(mdl, assn);
                TestInference.assertEquals("Failure on model " + mdlIdx + " alg " + alg, marginal, returned, APPX_EPSILON);
            }
        }
        logger.info("Test testQuery passed.");
    }

    /**
     * Runs the serialization round-trip check for every exact, approximate
     * and max-product inferencer, in that order.
     */
    public void testSerializable() throws Exception {
        for (Class exactClass : this.algorithms) {
            this.testSerializationForAlg((Inferencer) exactClass.newInstance());
        }
        for (Class appxClass : this.appxAlgs) {
            this.testSerializationForAlg((Inferencer) appxClass.newInstance());
        }
        for (Inferencer maxAlg : this.constructMaxProductInferencers()) {
            this.testSerializationForAlg(maxAlg);
        }
    }

    /**
     * For each model, clones {@code alg} via serialization, runs both the
     * original and the clone, and requires their marginals to match.
     *
     * @param alg inferencer to round-trip; it is also run on every model
     */
    private void testSerializationForAlg(Inferencer alg) throws IOException, ClassNotFoundException {
        for (UndirectedModel mdl : this.models) {
            // Clone before running on this model, so the copy carries state from earlier models only.
            Inferencer copy = (Inferencer) TestSerializable.cloneViaSerialization(alg);
            alg.computeMarginals(mdl);
            Factor[] before = this.collectAllMarginals(mdl, alg);
            copy.computeMarginals(mdl);
            Factor[] after = this.collectAllMarginals(mdl, copy);
            this.compareMarginals("Error comparing marginals after serialzation on model " + mdl, before, after);
        }
    }

    /**
     * Asserts that two marginal arrays agree element-wise within 0.001.
     * <p>
     * Entries that are {@code null} in both arrays are ignored. These are
     * variables that {@code collectAllMarginals} skipped because the lookup was
     * unsupported. A {@code null} on only one side fails the test instead of
     * raising a {@code NullPointerException}.
     *
     * @param msg  prefix for failure messages
     * @param pre  reference marginals
     * @param post marginals to compare against {@code pre}
     */
    private void compareMarginals(String msg, Factor[] pre, Factor[] post) {
        TestInference.assertEquals(msg + "\nMarginal arrays differ in length", pre.length, post.length);
        for (int i = 0; i < pre.length; ++i) {
            Factor ptl1 = pre[i];
            Factor ptl2 = post[i];
            if (ptl1 == null && ptl2 == null) {
                continue;
            }
            if (ptl1 == null || ptl2 == null) {
                TestInference.fail(msg + "\nMarginal " + i + " missing on one side:\n" + ptl1 + "\n" + ptl2);
            } else if (!ptl1.almostEquals(ptl2, 0.001)) {
                // Build the (potentially large) dumps only when the comparison actually fails.
                TestInference.fail(msg + "\n" + ptl1.dumpToString() + "\n" + ptl2.dumpToString());
            }
        }
    }

    /**
     * Disabled check (JUnit 3 only runs {@code test*} methods) of the number
     * of messages sent by TRP and LoopyBP on each model.
     * <p>
     * TRP is expected to send (numVariables - 1) * 2 messages per iteration,
     * and LoopyBP edgeCount * 2 messages per iteration.
     */
    public void ignoreTestNumMessages() {
        for (UndirectedModel mdl : this.models) {
            TRP trp = new TRP();
            trp.computeMarginals(mdl);
            int trpExpected = (mdl.numVariables() - 1) * 2 * trp.iterationsUsed();
            TestInference.assertEquals(trpExpected, (int) trp.getTotalMessagesSent());

            LoopyBP loopy = new LoopyBP();
            loopy.computeMarginals(mdl);
            int loopyExpected = mdl.getEdgeSet().size() * 2 * loopy.iterationsUsed();
            TestInference.assertEquals(loopyExpected, (int) loopy.getTotalMessagesSent());
        }
    }

    /**
     * Builds a four-variable binary chain n0 - n1 - n2 - n3 whose three
     * pairwise factors use fixed tables, each normalized before being added.
     *
     * @return the chain model
     */
    private UndirectedModel createJtChain() {
        final int numNodes = 4;
        Variable[] nodes = new Variable[numNodes];
        for (int k = 0; k < numNodes; k++) {
            nodes[k] = new Variable(2);
        }
        double[][] tables = {
            {1.0, 2.0, 5.0, 4.0},
            {4.0, 2.0, 4.0, 1.0},
            {7.0, 3.0, 6.0, 9.0}
        };
        TableFactor[] pots = new TableFactor[numNodes - 1];
        for (int k = 0; k < pots.length; k++) {
            pots[k] = new TableFactor(new Variable[]{nodes[k], nodes[k + 1]}, tables[k]);
            pots[k].normalize();
        }
        UndirectedModel chain = new UndirectedModel();
        for (TableFactor pot : pots) {
            chain.addFactor(pot);
        }
        return chain;
    }

    /**
     * Builds the tree-structured test models into {@code trees} and appends
     * them, in the same order, to {@code modelsList}.
     * <p>
     * All random trees share one {@code Random(185)}, so the construction order
     * fixes their contents. Index 2 is the createJtChain() model.
     */
    private void createTestTrees() {
        Random r = new Random(185L);
        List<FactorGraph> built = new ArrayList<FactorGraph>();
        built.add(RandomGraphs.createUniformChain(2));
        built.add(RandomGraphs.createUniformChain(4));
        built.add(this.createJtChain());
        built.add(TestInference.createRandomGrid(5, 1, 3, r));
        built.add(TestInference.createRandomGrid(6, 1, 2, r));
        built.add(this.createRandomTree(10, 2, r));
        built.add(this.createRandomTree(10, 2, r));
        built.add(this.createRandomTree(8, 3, r));
        built.add(this.createRandomTree(8, 3, r));
        this.trees = built.toArray(new FactorGraph[built.size()]);
        this.modelsList.addAll(built);
    }

    /**
     * Fills {@code treeMargs} with brute-force single-variable marginals.
     * <p>
     * Entry {@code treeMargs[t][trees[t].getIndex(var)]} holds the marginal of
     * {@code var} computed from the brute-force joint of {@code trees[t]}.
     */
    private void computeTestTreeMargs() {
        BruteForceInferencer brute = new BruteForceInferencer();
        this.treeMargs = new Factor[this.trees.length][];
        for (int t = 0; t < this.trees.length; t++) {
            FactorGraph tree = this.trees[t];
            Factor joint = brute.joint(tree);
            Factor[] margs = new Factor[tree.numVariables()];
            this.treeMargs[t] = margs;
            for (Iterator it = tree.variablesIterator(); it.hasNext(); ) {
                Variable var = (Variable) it.next();
                margs[tree.getIndex(var)] = joint.marginalize(var);
            }
        }
    }

    /**
     * For every model's junction tree, checks that each parent/child sepset
     * potential is defined over exactly the intersection of the two cliques.
     */
    public void testJtConsistency() {
        for (UndirectedModel mdl : this.models) {
            JunctionTree jt = new JunctionTreeInferencer().buildJunctionTree(mdl);
            for (Iterator it = jt.getVerticesIterator(); it.hasNext(); ) {
                VarSet parent = (VarSet) it.next();
                for (VarSet child : jt.getChildren(parent)) {
                    Factor sepset = jt.getSepsetPot(parent, child);
                    VarSet shared = parent.intersection(child);
                    TestInference.assertTrue(shared.equals(sepset.varSet()));
                }
            }
        }
    }

    /**
     * Asserts that TRP's joint matches {@code joint} within 0.01 for every
     * assignment to {@code joint}'s variables.
     * <p>
     * On failure, prints the offending assignment, both probabilities, the
     * expected joint and a TRP dump, then rethrows.
     *
     * @param joint exact joint distribution
     * @param trp   TRP inferencer that has already run on the model
     */
    private void compareTrpJoint(Factor joint, TRP trp) {
        Assignment assn = null;
        double trpProb = 0.0;
        double exactProb = 0.0;
        try {
            AssignmentIterator it = new HashVarSet(joint.varSet()).assignmentIterator();
            while (it.hasNext()) {
                assn = (Assignment) it.next();
                trpProb = trp.lookupJoint(assn);
                exactProb = joint.value(assn);
                TestInference.assertTrue(Math.abs(trpProb - exactProb) < 0.01);
            }
        } catch (AssertionFailedError e) {
            System.out.println("*****************************************\nTEST FAILURE in compareTrpJoint");
            System.out.println("*****************************************\nat");
            System.out.println(assn);
            System.out.println("Expected: " + exactProb);
            System.out.println("TRP: " + trpProb);
            System.out.println("*****************************************\nExpected joint");
            System.out.println(joint);
            System.out.println("*****************************************\nTRP dump");
            trp.dump();
            throw e;
        }
    }

    /**
     * Runs TRP for 200 iterations on the triangle model and compares it with
     * brute force.
     * <p>
     * The joints are compared via compareTrpJoint, then every single-variable
     * and factor marginal is checked within {@code APPX_EPSILON}. On failure,
     * the model, TRP state and exact marginals are printed before rethrowing.
     */
    public void testTrp() {
        UndirectedModel model = TestInference.createTriangle();
        TRP trp = new TRP().setTerminator(new TRP.IterationTerminator(200));
        BruteForceInferencer brute = new BruteForceInferencer();
        Factor joint = brute.joint(model);
        trp.computeMarginals(model);
        this.compareTrpJoint(joint, trp);
        try {
            Iterator it = model.variablesIterator();
            while (it.hasNext()) {
                Variable var = (Variable) it.next();
                Factor trpMarg = trp.lookupMarginal(var);
                Factor exactMarg = joint.marginalize(var);
                TestInference.assertTrue(trpMarg.almostEquals(exactMarg, APPX_EPSILON));
            }
            it = model.factorsIterator();
            while (it.hasNext()) {
                Factor factor = (Factor) it.next();
                Factor trpMarg = trp.lookupMarginal(factor.varSet());
                Factor exactMarg = joint.marginalize(factor.varSet());
                TestInference.assertTrue(trpMarg.almostEquals(exactMarg, APPX_EPSILON));
            }
        } catch (AssertionFailedError e) {
            System.out.println("\n*************************************\nTEST FAILURE in compareTrpMargs");
            System.out.println("*************************************\nComplete model:\n\n");
            model.dump();
            System.out.println("*************************************\nTRP margs:\n\n");
            trp.dump();
            System.out.println("**************************************\nAll correct margs:\n");
            // Exact marginals are loop-invariant: compute them once, not once per variable.
            brute.computeMarginals(model);
            Iterator it2 = model.variablesIterator();
            while (it2.hasNext()) {
                Variable v2 = (Variable) it2.next();
                System.out.println(brute.lookupMarginal(v2));
            }
            throw e;
        }
    }

    /**
     * Checks that TRP's {@code lookupJoint} equals exp of its
     * {@code lookupLogJoint} for every assignment of the triangle model,
     * after 25 iterations.
     */
    public void testTrpJoint() {
        UndirectedModel model = TestInference.createTriangle();
        TRP trp = new TRP().setTerminator(new TRP.IterationTerminator(25));
        trp.computeMarginals(model);
        for (AssignmentIterator it = new HashVarSet(model.variablesSet()).assignmentIterator(); it.hasNext(); ) {
            Assignment assn = (Assignment) it.next();
            double logProb = trp.lookupLogJoint(assn);
            double prob = trp.lookupJoint(assn);
            TestInference.assertTrue(Maths.almostEquals(Math.exp(logProb), prob));
        }
        logger.info("Test trpJoint passed.");
    }

    /**
     * Checks that running TRP on a model does not alter it: the brute-force
     * joint computed after TRP runs must match the one computed before.
     */
    public void testTrpNonDestructivity() {
        UndirectedModel model = TestInference.createTriangle();
        TRP trp = new TRP(new TRP.IterationTerminator(25));
        BruteForceInferencer brute = new BruteForceInferencer();
        Factor before = brute.joint(model);
        trp.computeMarginals(model);
        Factor after = brute.joint(model);
        TestInference.assertTrue(before.almostEquals(after));
        logger.info("Test trpNonDestructivity passed.");
    }

    public void testTrpReuse() {
        TRP trp1 = new TRP(new TRP.IterationTerminator(25));
        for (int i = 0; i < this.models.length; ++i) {
            trp1.computeMarginals(this.models[i]);
        }
        logger.info("Please ensure that all instantiations above run for 25 iterations.");
        UndirectedModel mdl = this.models[0];
        final Tree tree = trp1.new TRP.AlmostRandomTreeFactory().nextTree(mdl);
        TRP trp2 = new TRP(new TRP.TreeFactory(){

            public Tree nextTree(FactorGraph mdl) {
                return tree;
            }
        });
        trp2.computeMarginals(mdl);
        logger.info("Ensure that the above instantiation ran for 1000 iterations with a warning.");
    }

    /**
     * Runs TRP on the triangle model (variables labelled V0..V2) with the default
     * convergence terminator. The trees are parsed from {@code treeStrs} via
     * {@code TRP.TreeListFactory.makeFromReaders}. Every single-variable marginal must
     * match brute-force inference.
     */
    public void testTrpTreeList() {
        UndirectedModel model = TestInference.createTriangle();
        // Labels are set before parsing, presumably so the tree strings can refer to variables by name.
        model.getVariable(0).setLabel("V0");
        model.getVariable(1).setLabel("V1");
        model.getVariable(2).setLabel("V2");

        ArrayList<StringReader> treeReaders = new ArrayList<StringReader>(treeStrs.length);
        for (String treeStr : treeStrs) {
            treeReaders.add(new StringReader(treeStr));
        }

        TRP trp = new TRP()
                .setTerminator(new TRP.DefaultConvergenceTerminator())
                .setFactory(TRP.TreeListFactory.makeFromReaders(model, treeReaders));
        trp.computeMarginals(model);

        BruteForceInferencer brute = new BruteForceInferencer();
        brute.computeMarginals(model);
        this.compareMarginals("", model, trp, brute);
    }

    /**
     * For every test model, checks that each variable round-trips through its index:
     * {@code mdl.get(mdl.getIndex(v))} must return the very same Variable object.
     */
    public void testUndirectedIndices() {
        for (int mdlIdx = 0; mdlIdx < this.models.length; mdlIdx++) {
            UndirectedModel mdl = this.models[mdlIdx];
            for (Iterator vars = mdl.variablesIterator(); vars.hasNext(); ) {
                Variable original = (Variable) vars.next();
                Variable roundTripped = mdl.get(mdl.getIndex(original));
                String msg = "Mismatch in Variable index for " + original + " vs " + roundTripped
                        + " in model " + mdlIdx + "\n" + mdl;
                assertTrue(msg, original == roundTripped);
            }
        }
        logger.info("Test undirectedIndices passed.");
    }

    /**
     * For each tree model, checks that max-product TRP limited to one iteration
     * ({@code IterationTerminator(1)}) gives the same normalized single-variable marginals
     * as max-product TreeBP.
     */
    public void testTrpViterbiEquiv() {
        for (FactorGraph mdl : this.trees) {
            TreeBP maxprod = TreeBP.createForMaxProduct();
            TRP trp = TRP.createForMaxProduct().setTerminator(new TRP.IterationTerminator(1));
            maxprod.computeMarginals(mdl);
            trp.computeMarginals(mdl);
            for (Iterator vars = mdl.variablesIterator(); vars.hasNext(); ) {
                Variable var = (Variable) vars.next();
                Factor fromBp = maxprod.lookupMarginal(var);
                Factor fromTrp = trp.lookupMarginal(var);
                // normalize() is applied in place to both factors before comparing.
                fromBp.normalize();
                fromTrp.normalize();
                String msg = "TRP 1 iter maxprod propagation not the same as plain maxProd!\nTrp "
                        + fromTrp.dumpToString() + "\n Plain maxprod " + fromBp.dumpToString();
                assertTrue(msg, fromBp.almostEquals(fromTrp));
            }
        }
    }

    /**
     * For each tree model, checks that sum-product TRP limited to one iteration matches
     * plain sum-product TreeBP. Two things are compared:
     * <ul>
     *   <li>the log joint of the all-0 and the all-1 assignments (within 1e-5);</li>
     *   <li>every normalized single-variable marginal (via {@code almostEquals}).</li>
     * </ul>
     */
    public void testTrpOnTrees() {
        for (int mdlIdx = 0; mdlIdx < this.trees.length; ++mdlIdx) {
            FactorGraph mdl = this.trees[mdlIdx];
            TreeBP bp = new TreeBP();
            TRP trp = new TRP().setTerminator(new TRP.IterationTerminator(1));
            bp.computeMarginals(mdl);
            trp.computeMarginals(mdl);

            // Every variable at outcome 0.
            int[] outcomes = new int[mdl.numVariables()];
            Assignment assn = new Assignment(mdl, outcomes);
            assertEquals(bp.lookupLogJoint(assn), trp.lookupLogJoint(assn), 1.0E-5);
            // Every variable at outcome 1 (assumes each variable has at least two outcomes).
            Arrays.fill(outcomes, 1);
            assn = new Assignment(mdl, outcomes);
            assertEquals(bp.lookupLogJoint(assn), trp.lookupLogJoint(assn), 1.0E-5);

            Iterator it = mdl.variablesIterator();
            while (it.hasNext()) {
                Variable var = (Variable) it.next();
                Factor margBp = bp.lookupMarginal(var);
                Factor margTrp = trp.lookupMarginal(var);
                margBp.normalize();
                margTrp.normalize();
                // The reference here is sum-product TreeBP, so the message says "plain bp".
                // It previously said "plain maxProd".
                assertTrue("TRP 1 iter bp propagation not the same as plain bp!\nTrp "
                        + margTrp.dumpToString() + "\n Plain bp " + margBp.dumpToString(),
                        margBp.almostEquals(margTrp));
            }
        }
    }

    /**
     * For each tree model, checks that max-product TRP, run with its default terminator,
     * gives the same single-variable marginals as max-product TreeBP. Unlike
     * testTrpViterbiEquiv, the factors are compared without normalizing them first.
     */
    public void testTrpViterbiEquiv2() {
        for (FactorGraph mdl : this.trees) {
            TreeBP maxprod = TreeBP.createForMaxProduct();
            TRP trp = TRP.createForMaxProduct();
            maxprod.computeMarginals(mdl);
            trp.computeMarginals(mdl);
            for (Iterator vars = mdl.variablesIterator(); vars.hasNext(); ) {
                Variable var = (Variable) vars.next();
                Factor fromBp = maxprod.lookupMarginal(var);
                Factor fromTrp = trp.lookupMarginal(var);
                String msg = "TRP maxprod propagation not the same as plain maxProd!\nTrp "
                        + fromTrp + "\n Plain maxprod " + fromBp;
                assertTrue(msg, fromBp.almostEquals(fromTrp));
            }
        }
    }

    /**
     * For each tree model, compares max-product TreeBP single-variable marginals with the
     * result of {@code extractMax(var)} on the brute-force joint. Both factors are
     * normalized in place before the comparison.
     */
    public void testTreeViterbi() {
        for (FactorGraph mdl : this.trees) {
            BruteForceInferencer brute = new BruteForceInferencer();
            TreeBP maxprod = TreeBP.createForMaxProduct();
            Factor joint = brute.joint(mdl);
            maxprod.computeMarginals(mdl);
            for (Iterator vars = mdl.variablesIterator(); vars.hasNext(); ) {
                Variable var = (Variable) vars.next();
                Factor computed = maxprod.lookupMarginal(var);
                Factor expected = joint.extractMax(var);
                computed.normalize();
                expected.normalize();
                assertTrue("Maximization failed! Normalized returns:\n" + computed + "\nTrue: " + expected,
                        computed.almostEquals(expected));
            }
        }
        logger.info("Test treeViterbi passed: " + this.trees.length + " models.");
    }

    /**
     * For each test model, runs max-product junction-tree inference on a junction tree
     * built from the model. Each normalized single-variable marginal must match, within
     * 0.01, the normalized {@code extractMax(var)} of the brute-force joint.
     */
    public void testJtViterbi() {
        for (int mdlIdx = 0; mdlIdx < this.models.length; ++mdlIdx) {
            UndirectedModel mdl = this.models[mdlIdx];
            BruteForceInferencer brute = new BruteForceInferencer();
            JunctionTreeInferencer maxprod = JunctionTreeInferencer.createForMaxProduct();
            JunctionTree jt = maxprod.buildJunctionTree(mdl);
            Factor joint = brute.joint(mdl);
            maxprod.computeMarginals(jt);
            Iterator it = mdl.variablesIterator();
            while (it.hasNext()) {
                Variable var = (Variable) it.next();
                // Normalize duplicates so the factors returned by the inferencer and the joint
                // are not mutated.
                Factor maxPot = maxprod.lookupMarginal(var).duplicate().normalize();
                Factor trueMaxPot = joint.extractMax(var).duplicate().normalize();
                assertTrue("Maximization failed on model " + mdlIdx + " ! Normalized returns:\n"
                        + maxPot.dumpToString() + "\nTrue: " + trueMaxPot.dumpToString(),
                        maxPot.almostEquals(trueMaxPot, 0.01));
            }
        }
        logger.info("Test jtViterbi passed.");
    }

    /**
     * For every test model and every inferencer from {@code constructMaxProductInferencers()},
     * checks that each single-variable marginal has the same argmax as
     * {@code extractMax(var)} on the brute-force joint. TRP inferencers are seeded with 42.
     * On a mismatch, a warning is logged and the model is dumped before the assertion fails.
     */
    public void testMaxMarginals() throws Exception {
        for (int mdlIdx = 0; mdlIdx < this.models.length; mdlIdx++) {
            UndirectedModel mdl = this.models[mdlIdx];
            Factor joint = new BruteForceInferencer().joint(mdl);
            for (Inferencer inf : this.constructMaxProductInferencers()) {
                if (inf instanceof TRP) {
                    ((TRP) inf).setRandomSeed(42L);
                }
                inf.computeMarginals(mdl);
                for (Iterator vars = mdl.variablesIterator(); vars.hasNext(); ) {
                    Variable var = (Variable) vars.next();
                    Factor maxPot = inf.lookupMarginal(var);
                    Factor trueMaxPot = joint.extractMax(var);
                    if (maxPot.argmax() != trueMaxPot.argmax()) {
                        logger.warning("Argmax not equal on model " + mdlIdx + " inferencer " + inf
                                + " !\n  Factors:\nReturned: " + maxPot + "\nTrue: " + trueMaxPot);
                        System.err.println("Dump of model " + mdlIdx + " ***");
                        mdl.dump();
                        assertTrue(maxPot.argmax() == trueMaxPot.argmax());
                    }
                }
            }
        }
        logger.info("Test maxMarginals passed.");
    }

    /**
     * For each tree model, runs sum-product TreeBP and compares every single-variable
     * marginal (tolerance 0.011) with the reference stored in {@code treeMargs}, indexed
     * by model and by variable index. On failure, the message, the model, and the BP
     * messages are printed before the assertion error is rethrown.
     */
    public void testBeliefPropagation() {
        for (int mdlIdx = 0; mdlIdx < this.trees.length; mdlIdx++) {
            FactorGraph mdl = this.trees[mdlIdx];
            TreeBP prop = new TreeBP();
            prop.computeMarginals(mdl);
            for (Iterator vars = mdl.variablesIterator(); vars.hasNext(); ) {
                Variable var = (Variable) vars.next();
                Factor expected = this.treeMargs[mdlIdx][mdl.getIndex(var)];
                Factor actual = prop.lookupMarginal(var);
                String msg = "Test failed on graph " + mdlIdx + " vertex " + var + "\n" + "Model: " + mdl
                        + "\nExpected: " + expected.dumpToString() + "\nActual: " + actual.dumpToString();
                try {
                    assertTrue(msg, expected.almostEquals(actual, 0.011));
                } catch (AssertionFailedError failure) {
                    System.out.println(failure.getMessage());
                    System.out.println("*******************************************\nMODEL:\n");
                    mdl.dump();
                    System.out.println("*******************************************\nMESSAGES:\n");
                    ((AbstractBeliefPropagation) prop).dump();
                    throw failure;
                }
            }
        }
        logger.info("Test beliefPropagation passed.");
    }

    /**
     * For each tree model, checks that TreeBP's {@code lookupJoint} equals the brute-force
     * joint probability, within 1e-15, for every assignment of the model.
     */
    public void testBpJoint() {
        for (FactorGraph mdl : this.trees) {
            TreeBP bp = new TreeBP();
            BruteForceInferencer brute = new BruteForceInferencer();
            brute.computeMarginals(mdl);
            bp.computeMarginals(mdl);
            for (AssignmentIterator assignments = mdl.assignmentIterator(); assignments.hasNext(); ) {
                Assignment assn = (Assignment) assignments.next();
                assertEquals(brute.lookupJoint(assn), bp.lookupJoint(assn), 1.0E-15);
            }
        }
    }

    /**
     * Compares junction-tree marginals with brute-force marginals on the small directed
     * model built by {@code createDirectedModel()}.
     */
    public void testDirectedJt() {
        DirectedModel network = this.createDirectedModel();
        BruteForceInferencer reference = new BruteForceInferencer();
        reference.computeMarginals(network);
        JunctionTreeInferencer junctionTree = new JunctionTreeInferencer();
        junctionTree.computeMarginals(network);
        this.compareMarginals("Error comparing junction tree to brute on directed model!",
                network, reference, junctionTree);
    }

    /**
     * Builds a three-variable directed model over binary variables A, B, C, with CPT
     * parameters drawn from a Dirichlet(2, 1.0) using seed 13413.
     *
     * <p>A and B each get a single-variable CPT. The third CPT spans (A, B, C) with C passed
     * as its CPT variable; presumably this makes C conditioned on A and B. Its table holds
     * one Dirichlet draw per (A, B) configuration.
     */
    private DirectedModel createDirectedModel() {
        final int numOutcomes = 2;
        Randoms rng = new Randoms(13413);
        Dirichlet prior = new Dirichlet(numOutcomes, 1.0);
        double[] tableA = prior.randomVector(rng);
        double[] tableB = prior.randomVector(rng);

        int numParentConfigs = numOutcomes * numOutcomes;
        TDoubleArrayList tableC = new TDoubleArrayList(numParentConfigs * numOutcomes);
        for (int cfg = 0; cfg < numParentConfigs; cfg++) {
            tableC.add(prior.randomVector(rng));
        }

        Variable a = new Variable(numOutcomes);
        Variable b = new Variable(numOutcomes);
        Variable c = new Variable(numOutcomes);
        Variable[] abc = new Variable[] {a, b, c};

        DirectedModel network = new DirectedModel();
        network.addFactor(new CPT(new TableFactor(a, tableA), a));
        network.addFactor(new CPT(new TableFactor(b, tableB), b));
        network.addFactor(new CPT(new TableFactor(abc, tableC.toNativeArray()), c));
        return network;
    }

    /**
     * Asserts that {@code inf1} and {@code inf2} report the same single-variable marginal
     * (within 1e-5) for every variable of {@code fg}, visited in index order. On failure the
     * message is {@code msg} followed by both factors' dumps. The callers visible in this
     * file run computeMarginals on both inferencers before calling this.
     */
    private void compareMarginals(String msg, FactorGraph fg, Inferencer inf1, Inferencer inf2) {
        int idx = 0;
        while (idx < fg.numVariables()) {
            Variable var = fg.get(idx);
            Factor first = inf1.lookupMarginal(var);
            Factor second = inf2.lookupMarginal(var);
            assertTrue(msg + "\n" + first.dumpToString() + "\n" + second.dumpToString(),
                    first.almostEquals(second, 1.0E-5));
            idx++;
        }
    }

    /**
     * JUnit fixture. Creates the undirected test models and the test trees, copies the
     * models into the {@code models} array, and then calls {@code computeTestTreeMargs()}.
     * That call presumably fills the reference tree marginals in {@code treeMargs}.
     */
    @Override
    protected void setUp() {
        modelsList = TestInference.createTestModels();
        createTestTrees();
        models = modelsList.toArray(new UndirectedModel[0]);
        computeTestTreeMargs();
    }

    /**
     * Checks that multiplying a factor over no variables by a 2x2 table factor yields a
     * factor almostEquals to the 2x2 factor. The empty factor is printed to stdout first.
     */
    public void testMultiply() {
        TableFactor empty = new TableFactor(new Variable[0]);
        System.out.println(empty);
        Variable[] pair = new Variable[] {new Variable(2), new Variable(2)};
        TableFactor table = new TableFactor(pair, new double[] {1.0, 3.0, 5.0, 6.0});
        Factor product = empty.multiply(table);
        assertTrue("Should be equal: " + table + "\n" + product, table.almostEquals(product));
    }

    /**
     * Takes the first two variables of models[0] and draws 10 random edge potentials
     * over them (seed 3214123). For each potential, and for each of the two variables,
     * checks that marginalizing the LogTableFactor agrees with converting the marginal of
     * the plain TableFactor to a LogTableFactor.
     */
    public void testLogMarginalize() {
        Iterator vars = this.models[0].variablesIterator();
        Variable v1 = (Variable) vars.next();
        Variable v2 = (Variable) vars.next();
        Random rand = new Random(3214123L);
        for (int trial = 0; trial < 10; trial++) {
            TableFactor ptl = TestInference.randomEdgePotential(rand, v1, v2);

            Factor logThenMarg1 = new LogTableFactor(ptl).marginalize(v1);
            LogTableFactor margThenLog1 = new LogTableFactor((AbstractTableFactor) ptl.marginalize(v1));
            assertTrue("LogMarg failed! Correct: " + margThenLog1 + " Log-marg: " + logThenMarg1,
                    logThenMarg1.almostEquals(margThenLog1));

            Factor logThenMarg2 = new LogTableFactor(ptl).marginalize(v2);
            LogTableFactor margThenLog2 = new LogTableFactor((AbstractTableFactor) ptl.marginalize(v2));
            assertTrue(logThenMarg2.almostEquals(margThenLog2));
        }
    }

    /**
     * For 10 random edge potentials (seed 3214123) over the first two variables of
     * models[0], checks that normalizing a LogTableFactor copy agrees with normalizing a
     * plain duplicate of the TableFactor.
     */
    public void testLogNormalize() {
        Iterator vars = this.models[0].variablesIterator();
        Variable v1 = (Variable) vars.next();
        Variable v2 = (Variable) vars.next();
        Random rand = new Random(3214123L);
        for (int trial = 0; trial < 10; trial++) {
            TableFactor ptl = TestInference.randomEdgePotential(rand, v1, v2);
            LogTableFactor logNormed = new LogTableFactor(ptl);
            Factor plainNormed = ptl.duplicate();
            logNormed.normalize();
            plainNormed.normalize();
            assertTrue("LogNormalize failed! Correct: " + plainNormed + " Log-normed: " + logNormed,
                    logNormed.almostEquals(plainNormed));
        }
    }

    /**
     * For 10 random pairs (a, b) in [0, 1) (seed 3214123), checks that
     * {@code Maths.sumLogProb(log a, log b)} equals {@code log(a + b)} within 1e-5.
     */
    public void testSumLogProb() {
        Random rand = new Random(3214123L);
        for (int trial = 0; trial < 10; trial++) {
            double a = rand.nextDouble();
            double b = rand.nextDouble();
            double expected = Math.log(a + b);
            double actual = Maths.sumLogProb(Math.log(a), Math.log(b));
            assertEquals(expected, actual, 1.0E-5);
        }
    }

    /**
     * Builds a three-variable binary chain whose second edge factor contains zero entries
     * ({1, 0, 0, 1}), dumps it, and runs TreeBP on it. No assertion is made; the test only
     * checks that computeMarginals completes without throwing.
     */
    public void testInfiniteCost() {
        Variable[] chain = new Variable[3];
        for (int idx = 0; idx < chain.length; idx++) {
            chain[idx] = new Variable(2);
        }
        FactorGraph mdl = new FactorGraph(chain);
        mdl.addFactor(chain[0], chain[1], new double[] {2.0, 6.0, 4.0, 8.0});
        mdl.addFactor(chain[1], chain[2], new double[] {1.0, 0.0, 0.0, 1.0});
        mdl.dump();
        new TreeBP().computeMarginals(mdl);
    }

    /**
     * Exercises the junction-tree inference cache on every test model.
     *
     * <p>First the cache entry for JunctionTreeInferencer is cleared on each model. Then
     * marginals are computed twice with fresh inferencers: the second run must give
     * almostEquals marginals for every variable. Wall-clock time of both passes is logged.
     */
    public void testJtCaching() {
        for (UndirectedModel model : this.models) {
            model.setInferenceCache(JunctionTreeInferencer.class, null);
        }

        Factor[][] margs = new Factor[this.models.length][];
        long preCacheStart = System.currentTimeMillis();
        for (int i = 0; i < this.models.length; i++) {
            UndirectedModel model = this.models[i];
            JunctionTreeInferencer inf = new JunctionTreeInferencer();
            inf.computeMarginals(model);
            margs[i] = new Factor[model.numVariables()];
            int j = 0;
            for (Iterator vars = model.variablesIterator(); vars.hasNext(); j++) {
                Variable var = (Variable) vars.next();
                margs[i][j] = inf.lookupMarginal(var);
            }
        }
        long preCacheMillis = System.currentTimeMillis() - preCacheStart;
        logger.info("Pre-cache took " + preCacheMillis + " ms.");

        long postCacheStart = System.currentTimeMillis();
        for (int i = 0; i < this.models.length; i++) {
            UndirectedModel model = this.models[i];
            JunctionTreeInferencer inf = new JunctionTreeInferencer();
            inf.computeMarginals(model);
            int j = 0;
            for (Iterator vars = model.variablesIterator(); vars.hasNext(); j++) {
                Variable var = (Variable) vars.next();
                assertTrue(margs[i][j].almostEquals(inf.lookupMarginal(var)));
            }
        }
        long postCacheMillis = System.currentTimeMillis() - postCacheStart;
        logger.info("Post-cache took " + postCacheMillis + " ms.");
    }

    /**
     * Checks {@code findVariable} on models[0]. Every variable must be found, as the same
     * object, when looked up by a fresh String copy of its label. An unknown label must
     * return null.
     */
    public void testFindVariable() {
        UndirectedModel mdl = this.models[0];
        Iterator it = mdl.variablesIterator();
        while (it.hasNext()) {
            Variable var = (Variable) it.next();
            // A new String instance is passed, presumably so the lookup must compare label
            // contents rather than rely on reference identity.
            String name = new String(var.getLabel());
            assertTrue(var == mdl.findVariable(name));
        }
        assertTrue(mdl.findVariable("xsdfasdf") == null);
    }

    /**
     * Checks TreeBP's default marginal lookups on trees[2]:
     * <ul>
     *   <li>looking up a one-variable VarSet must agree with looking up the Variable itself;</li>
     *   <li>looking up a three-variable VarSet must throw UnsupportedOperationException.</li>
     * </ul>
     */
    public void testDefaultLookupMarginal() {
        TreeBP inf = new TreeBP();
        FactorGraph mdl = this.trees[2];
        Variable var = mdl.get(0);
        inf.computeMarginals(mdl);

        HashVarSet varSet = new HashVarSet(new Variable[] {var});
        Factor ptl1 = inf.lookupMarginal(varSet);
        Factor ptl2 = inf.lookupMarginal(var);
        assertTrue(ptl1.almostEquals(ptl2));

        Variable var2 = mdl.get(1);
        Variable var3 = mdl.get(2);
        HashVarSet c2 = new HashVarSet(new Variable[] {var, var2, var3});
        try {
            inf.lookupMarginal(c2);
            fail("Expected an UnsupportedOperationException with clique " + c2);
        } catch (UnsupportedOperationException expected) {
            // This is the outcome the test requires: TreeBP must reject a marginal lookup over
            // this three-variable set. fail() throws AssertionFailedError, which is not caught here.
        }
    }

    /**
     * Builds a model of four binary variables that has only single-variable potentials
     * (drawn with seed 67) and no edges. It then checks two things for LoopyBP:
     * <ul>
     *   <li>each marginal equals the corresponding normalized potential;</li>
     *   <li>for every assignment, the joint equals the product of the normalized potentials
     *       within 1e-5.</li>
     * </ul>
     */
    public void testDisconnectedModel() {
        final int numVars = 4;
        Variable[] vars = new Variable[numVars];
        for (int i = 0; i < numVars; i++) {
            vars[i] = new Variable(2);
        }
        UndirectedModel mdl = new UndirectedModel(vars);
        Random r = new Random(67L);
        Factor[] normed = new Factor[numVars];
        for (int i = 0; i < numVars; i++) {
            Factor potential = TestInference.randomNodePotential(r, vars[i]);
            Factor normalizedCopy = potential.duplicate();
            normalizedCopy.normalize();
            normed[i] = normalizedCopy;
            ((FactorGraph) mdl).addFactor(potential);
        }
        mdl.dump();

        LoopyBP inf = new LoopyBP();
        inf.computeMarginals(mdl);
        for (int i = 0; i < numVars; i++) {
            Factor marg = inf.lookupMarginal(vars[i]);
            assertTrue("Marginals not equal!\n   True: " + normed[i] + "\n   Returned " + marg,
                    marg.almostEquals(normed[i]));
        }

        for (AssignmentIterator assignments = mdl.assignmentIterator(); assignments.hasNext(); ) {
            Assignment assn = (Assignment) assignments.next();
            double trueProb = 1.0;
            for (Factor f : normed) {
                trueProb *= f.value(assn);
            }
            assertEquals(trueProb, inf.lookupJoint(assn), 1.0E-5);
        }
    }

    /**
     * Manual micro-benchmark, not a test (its name has no "test" prefix). For random
     * 2-outcome and then 45-outcome edge potentials (seed 7732847), it marginalizes onto
     * each variable 1000 times and logs the elapsed wall-clock milliseconds.
     */
    public void timeMarginalization() {
        Random r = new Random(7732847L);
        int[] outcomeCounts = new int[] {2, 45};
        String[] labels = new String[] {
            "Marginalization (2-outcome) took ",
            "Marginalization (45-outcome) took "
        };
        for (int k = 0; k < outcomeCounts.length; k++) {
            Variable[] pair = new Variable[] {new Variable(outcomeCounts[k]), new Variable(outcomeCounts[k])};
            TableFactor ptl = TestInference.randomEdgePotential(r, pair[0], pair[1]);
            long start = System.currentTimeMillis();
            for (int rep = 0; rep < 1000; rep++) {
                // Results are discarded; only the time taken matters.
                ptl.marginalize(pair[0]);
                ptl.marginalize(pair[1]);
            }
            long end = System.currentTimeMillis();
            logger.info(labels[k] + (end - start) + " ms.");
        }
    }

    /**
     * Helper for manual runs, not a test (its name has no "test" prefix). It runs
     * junction-tree inference on every test model and looks up every single-variable
     * marginal, discarding the results.
     */
    public void runJunctionTree() {
        for (UndirectedModel model : this.models) {
            JunctionTreeInferencer inf = new JunctionTreeInferencer();
            inf.computeMarginals(model);
            for (Iterator vars = model.variablesIterator(); vars.hasNext(); ) {
                inf.lookupMarginal((Variable) vars.next());
            }
        }
    }

    /**
     * Checks that {@code Assignment.setValue} updates an assignment in place. Setting the
     * first variable to 1 must change only that variable's value.
     */
    public void testDestructiveAssignment() {
        Variable first = new Variable(2);
        Variable second = new Variable(2);
        Assignment assn = new Assignment(new Variable[] {first, second}, new int[] {0, 1});
        assertEquals(0, (int) assn.get(first));
        assertEquals(1, (int) assn.get(second));
        assn.setValue(first, 1);
        assertEquals(1, (int) assn.get(first));
        assertEquals(1, (int) assn.get(second));
    }

    /**
     * Runs LoopyBP on a random 5x5 grid of binary variables (seed 67) and requires it to
     * report more than 8 iterations.
     */
    public void testLoopyConvergence() {
        UndirectedModel grid = TestInference.createRandomGrid(5, 5, 2, new Random(67L));
        LoopyBP loopy = new LoopyBP();
        loopy.computeMarginals(grid);
        int used = loopy.iterationsUsed();
        assertTrue(used > 8);
    }

    /**
     * Runs TRP on a graph with a single binary variable and potential {1, 2}. The marginal
     * must have two entries, approximately 1/3 and 2/3 (tolerance 1e-4).
     */
    public void testSingletonGraph() {
        Variable v = new Variable(2);
        FactorGraph mdl = new FactorGraph(new Variable[] {v});
        mdl.addFactor(new TableFactor(v, new double[] {1.0, 2.0}));
        TRP trp = new TRP();
        trp.computeMarginals(mdl);
        AbstractTableFactor marginal = (AbstractTableFactor) trp.lookupMarginal(v);
        double[] values = marginal.toValueArray();
        assertEquals(2, values.length);
        assertEquals(0.33333, values[0], 1.0E-4);
        assertEquals(0.66666, values[1], 1.0E-4);
    }

    /**
     * Checks LoopyBP's caching with {@code setUseCaching(true)}. The inferencer is run on
     * models[4], then on models[5], then on models[4] again. The second run on models[4]
     * must reproduce the first marginal of its variable 0 (within 1e-4). It must also report
     * exactly one iteration, whereas the first run must have used more than 2.
     */
    public void testLoopyCaching() {
        UndirectedModel firstModel = this.models[4];
        UndirectedModel otherModel = this.models[5];
        Variable var = firstModel.get(0);
        LoopyBP inferencer = new LoopyBP();
        inferencer.setUseCaching(true);

        inferencer.computeMarginals(firstModel);
        Factor before = inferencer.lookupMarginal(var);
        assertTrue(inferencer.iterationsUsed() > 2);

        inferencer.computeMarginals(otherModel);
        inferencer.computeMarginals(firstModel);
        Factor after = inferencer.lookupMarginal(var);
        assertTrue("Huh? Original potential:" + before + "After: " + after, before.almostEquals(after, 1.0E-4));
        assertEquals(1, inferencer.iterationsUsed());
    }

    /**
     * Reuses one JunctionTreeInferencer on models[0] and then models[1], and checks the
     * resulting junction tree. A breadth-first walk from the root must reach as many
     * clusters as there are cluster potentials.
     */
    public void testJunctionTreeConnectedFromRoot() {
        JunctionTreeInferencer jti = new JunctionTreeInferencer();
        jti.computeMarginals(this.models[0]);
        jti.computeMarginals(this.models[1]);
        JunctionTree jt = jti.lookupJunctionTree();

        int numReached = 0;
        LinkedList<Object> frontier = new LinkedList<Object>();
        frontier.add(jt.getRoot());
        while (!frontier.isEmpty()) {
            VarSet cluster = (VarSet) frontier.removeFirst();
            frontier.addAll(jt.getChildren(cluster));
            numReached++;
        }
        assertEquals(jt.clusterPotentials().size(), numReached);
    }

    /**
     * Timing/stress run of LoopyBP on a large uniform chain. If inference hits an
     * OutOfMemoryError, the total number of BP messages sent is printed and the error is
     * rethrown.
     */
    public void testBpLargeModels() {
        Timing timing = new Timing();
        // NOTE(review): testTrpLargeModels uses 8192 nodes; 8196 here may be a typo. Confirm.
        FactorGraph chain = RandomGraphs.createUniformChain(8196);
        timing.tick("Model creation");
        LoopyBP inf = new LoopyBP();
        try {
            ((AbstractInferencer) inf).computeMarginals(chain);
        } catch (OutOfMemoryError oom) {
            System.out.println("OUT OF MEMORY: Messages sent " + AbstractBeliefPropagation.getTotalMessagesSent());
            throw oom;
        }
        timing.tick("Inference time (Random sched BP)");
    }

    /**
     * Timing/stress run of TRP on a uniform chain of 8192 nodes. Nothing is asserted; only
     * model-creation and inference times are ticked.
     */
    public void testTrpLargeModels() {
        Timing timing = new Timing();
        FactorGraph chain = RandomGraphs.createUniformChain(8192);
        timing.tick("Model creation");
        new TRP().computeMarginals(chain);
        timing.tick("Inference time (TRP)");
    }

    /**
     * Creates a table factor over {@code (var1, var2)} whose entries are drawn from
     * {@code r.nextDouble()}, i.e. uniformly in [0, 1).
     *
     * <p>The table size is the product of the two variables' outcome counts, so the method
     * works for variables of any arity. For two binary variables it draws four values.
     *
     * @param var1 first variable of the edge
     * @param var2 second variable of the edge
     * @param r    source of randomness
     * @return a new TableFactor over {var1, var2}
     */
    private Factor createEdgePtl(Variable var1, Variable var2, Random r) {
        double[] dbls = new double[var1.getNumOutcomes() * var2.getNumOutcomes()];
        for (int i = 0; i < dbls.length; ++i) {
            dbls[i] = r.nextDouble();
        }
        return new TableFactor(new Variable[] {var1, var2}, dbls);
    }

    /**
     * Parses {@code gridStr} with ModelReader and samples its continuous variables
     * (seed 3214). It then slices the model at that assignment and runs junction-tree
     * inference on the slice. Nothing is asserted; the test passes if no exception is thrown.
     */
    public void testJtConstant() throws IOException {
        ModelReader reader = new ModelReader();
        FactorGraph masterFg = reader.readModel(new BufferedReader(new StringReader(this.gridStr)));
        JunctionTreeInferencer jt = new JunctionTreeInferencer();
        Assignment sampled = masterFg.sampleContinuousVars(new Randoms(3214));
        FactorGraph sliced = (FactorGraph) masterFg.slice(sampled);
        jt.computeMarginals(sliced);
    }

    /** Returns a JUnit TestSuite built reflectively from this class. */
    public static Test suite() {
        TestSuite all = new TestSuite(TestInference.class);
        return all;
    }

    /**
     * Command-line entry point using the JUnit text runner. With arguments, each argument
     * is treated as a test method name and only those tests run. Without arguments, the
     * whole {@link #suite()} runs.
     */
    public static void main(String[] args) throws Exception {
        TestSuite theSuite;
        if (args.length == 0) {
            theSuite = (TestSuite) TestInference.suite();
        } else {
            theSuite = new TestSuite();
            for (String testName : args) {
                theSuite.addTest(new TestInference(testName));
            }
        }
        TestRunner.run(theSuite);
    }
}

