/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.util;

import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.awt.BorderLayout;
import java.awt.Canvas;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.EventQueue;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Point;
import java.awt.Rectangle;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.io.StringWriter;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.UIManager;
import javax.swing.UnsupportedLookAndFeelException;

/**
 * Accumulates (guess, gold) label counts and renders them as a text table, as
 * per-label contingency statistics, or as a simple Swing grid.
 *
 * <p>Counts are kept in a {@link ConcurrentHashMap}, and single-cell updates and
 * reads are atomic map operations. Reports that scan the whole table
 * ({@link #printTable()}, {@link #getContingency(Object)}, the GUI) are not
 * atomic snapshots if other threads keep adding counts while they run.
 *
 * @param <U> the label type
 */
public class ConfusionMatrix<U> {
    /** Prefix for the short column/row names ("C1", "C2", ...) used when real labels are off. */
    private static final String CLASS_PREFIX = "C";
    /** Number pattern for precision/recall/etc.: at most five fraction digits. */
    private static final String FORMAT = "#.#####";
    /** Formatter used when rendering {@link Contingency} statistics. */
    protected DecimalFormat format;
    /** Width to which the left-most (row header) column is padded. */
    private int leftPadSize = 16;
    /** Width to which every other column is padded. */
    private int delimPadSize = 8;
    /** When true, tables show the labels themselves instead of "C1", "C2", ... */
    private boolean useRealLabels = false;
    /** Count of observations per (guess, gold) pair; absent pairs count as zero. */
    private final ConcurrentHashMap<Pair<U, U>, Integer> confTable = new ConcurrentHashMap<>();

    /** Creates an empty matrix whose statistics are formatted with the default locale's symbols. */
    public ConfusionMatrix() {
        this.format = new DecimalFormat(FORMAT);
    }

    /**
     * Creates an empty matrix whose statistics are formatted with the given locale's symbols.
     *
     * @param locale locale supplying the decimal format symbols
     */
    public ConfusionMatrix(Locale locale) {
        this.format = new DecimalFormat(FORMAT, new DecimalFormatSymbols(locale));
    }

    /** Returns the same text as {@link #printTable()}. */
    @Override
    public String toString() {
        return this.printTable();
    }

    /**
     * Sets the padded width of the row-header column.
     *
     * @param newPadSize new width in characters
     */
    public void setLeftPadSize(int newPadSize) {
        this.leftPadSize = newPadSize;
    }

    /**
     * Sets the padded width of every count column.
     *
     * @param newPadSize new width in characters
     */
    public void setDelimPadSize(int newPadSize) {
        this.delimPadSize = newPadSize;
    }

    /**
     * Chooses between real labels and short "C1", "C2", ... placeholders in the text table.
     *
     * @param useRealLabels true to print the labels themselves
     */
    public void setUseRealLabels(boolean useRealLabels) {
        this.useRealLabels = useRealLabels;
    }

    /**
     * Records one observation of {@code guess} against {@code gold}.
     *
     * @param guess the predicted label
     * @param gold the reference label
     */
    public void add(U guess, U gold) {
        this.add(guess, gold, 1);
    }

    /**
     * Adds {@code increment} to the count stored for the (guess, gold) pair,
     * creating the entry if it does not exist yet.
     *
     * @param guess the predicted label
     * @param gold the reference label
     * @param increment amount to add to the pair's count
     */
    public synchronized void add(U guess, U gold, int increment) {
        // merge() does the read-modify-write as a single atomic map operation.
        this.confTable.merge(new Pair<>(guess, gold), increment, Integer::sum);
    }

    /**
     * Returns the count stored for the (guess, gold) pair.
     *
     * @param guess the predicted label
     * @param gold the reference label
     * @return the stored count, or 0 when the pair has never been added
     */
    public Integer get(U guess, U gold) {
        return this.confTable.getOrDefault(new Pair<>(guess, gold), 0);
    }

    /**
     * Collects every label that occurs as a guess or as a gold value.
     *
     * @return a newly created set of labels; modifying it does not affect this matrix
     */
    public Set<U> uniqueLabels() {
        Set<U> labels = new HashSet<>();
        for (Pair<U, U> pair : this.confTable.keySet()) {
            labels.add(pair.first());
            labels.add(pair.second());
        }
        return labels;
    }

    /**
     * Collapses the matrix into a binary contingency table for one label.
     *
     * @param positiveLabel the label treated as the positive class; all others are negative
     * @return true/false positive/negative totals and the statistics derived from them
     */
    public Contingency getContingency(U positiveLabel) {
        int tp = 0;
        int fp = 0;
        int tn = 0;
        int fn = 0;
        for (Map.Entry<Pair<U, U>, Integer> entry : this.confTable.entrySet()) {
            int count = entry.getValue();
            boolean guessPositive = labelsEqual(entry.getKey().first(), positiveLabel);
            boolean goldPositive = labelsEqual(entry.getKey().second(), positiveLabel);
            if (guessPositive) {
                if (goldPositive) {
                    tp += count;
                } else {
                    fp += count;
                }
            } else if (goldPositive) {
                fn += count;
            } else {
                tn += count;
            }
        }
        return new Contingency(tp, fp, tn, fn);
    }

    /** Null-safe equality check for two labels. */
    private static boolean labelsEqual(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Compares two labels by their natural ordering; both must be mutually comparable. */
    @SuppressWarnings("unchecked") // callers only pass values already checked to be Comparable
    private static int compareNaturally(Object a, Object b) {
        return ((Comparable<Object>) a).compareTo(b);
    }

    /**
     * Returns the labels in a stable display order: natural order when every
     * label is {@link Comparable}, otherwise ordered by string representation.
     * Distinct labels that happen to share a string representation are all kept.
     */
    private List<U> sortKeys() {
        List<U> sorted = new ArrayList<>(this.uniqueLabels());
        if (sorted.isEmpty()) {
            return Collections.emptyList();
        }
        boolean allComparable = true;
        for (U label : sorted) {
            if (!(label instanceof Comparable)) {
                allComparable = false;
                break;
            }
        }
        if (allComparable) {
            Collections.sort(sorted, ConfusionMatrix::compareNaturally);
        } else {
            Collections.sort(sorted, (a, b) -> String.valueOf(a).compareTo(String.valueOf(b)));
        }
        return sorted;
    }

    /** Sums the counts of every entry whose gold label equals {@code gold}. */
    private int goldMarginal(U gold) {
        int sum = 0;
        for (Map.Entry<Pair<U, U>, Integer> entry : this.confTable.entrySet()) {
            if (labelsEqual(entry.getKey().second(), gold)) {
                sum += entry.getValue();
            }
        }
        return sum;
    }

    /** Sums the counts of every entry whose guessed label equals {@code guess}. */
    private int guessMarginal(U guess) {
        int sum = 0;
        for (Map.Entry<Pair<U, U>, Integer> entry : this.confTable.entrySet()) {
            if (labelsEqual(entry.getKey().first(), guess)) {
                sum += entry.getValue();
            }
        }
        return sum;
    }

    /** Returns the header text for a label: the label itself, or "C" plus its 1-based index. */
    private String getPlaceHolder(int index, U label) {
        if (this.useRealLabels) {
            return String.valueOf(label);
        }
        return CLASS_PREFIX + (index + 1);
    }

    /**
     * Renders the matrix as text: a header row, one row per guessed label with
     * its marginal, a row of gold marginals, and then one line of contingency
     * statistics per label.
     *
     * @return the rendered table, or {@code "Empty table!"} when nothing has been added
     */
    public String printTable() {
        List<U> sortedLabels = this.sortKeys();
        if (this.confTable.isEmpty()) {
            return "Empty table!";
        }
        StringBuilder out = new StringBuilder();

        // Header row.
        out.append(StringUtils.padLeft("Guess/Gold", this.leftPadSize));
        for (int i = 0; i < sortedLabels.size(); ++i) {
            out.append(StringUtils.padLeft(this.getPlaceHolder(i, sortedLabels.get(i)), this.delimPadSize));
        }
        out.append("    Marg. (Guess)");
        out.append("\n");

        // One row per guessed label, followed by that guess's marginal.
        for (int guessI = 0; guessI < sortedLabels.size(); ++guessI) {
            U guess = sortedLabels.get(guessI);
            out.append(StringUtils.padLeft(this.getPlaceHolder(guessI, guess), this.leftPadSize));
            for (U gold : sortedLabels) {
                out.append(StringUtils.padLeft(this.get(guess, gold).toString(), this.delimPadSize));
            }
            out.append(StringUtils.padLeft(Integer.toString(this.guessMarginal(guess)), this.delimPadSize));
            out.append("\n");
        }

        // Gold marginals.
        out.append(StringUtils.padLeft("Marg. (Gold)", this.leftPadSize));
        for (U gold : sortedLabels) {
            out.append(StringUtils.padLeft(Integer.toString(this.goldMarginal(gold)), this.delimPadSize));
        }
        out.append("\n\n");

        // Legend (when placeholders are in use) plus per-label statistics.
        for (int labelI = 0; labelI < sortedLabels.size(); ++labelI) {
            U classLabel = sortedLabels.get(labelI);
            out.append(StringUtils.padLeft(this.getPlaceHolder(labelI, classLabel), this.leftPadSize));
            if (!this.useRealLabels) {
                out.append(" = ");
                out.append(String.valueOf(classLabel));
            }
            out.append(StringUtils.padLeft("", this.delimPadSize));
            out.append(this.getContingency(classLabel).toString());
            out.append("\n");
        }
        return out.toString();
    }

    /** Opens a window showing this matrix as a colored grid. */
    public void gui() {
        ConfusionGrid gui = new ConfusionGrid();
        gui.setVisible(true);
    }

    /**
     * Small demo: fills a two-label matrix and shows it in the GUI.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        ConfusionMatrix<String> confusion = new ConfusionMatrix<String>();
        confusion.add("a", "a");
        confusion.add("a", "b");
        confusion.add("b", "a");
        confusion.add("a", "a");
        confusion.add("b", "b");
        confusion.add("b", "b");
        confusion.add("a", "b");
        confusion.gui();
    }

    /** Creates, on the AWT event queue, a frame containing a {@link Grid} for the enclosing matrix. */
    private class ConfusionGrid
    extends Canvas {
        public ConfusionGrid() {
            EventQueue.invokeLater(() -> {
                try {
                    UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
                }
                catch (ClassNotFoundException | IllegalAccessException | InstantiationException | UnsupportedLookAndFeelException ignored) {
                    // Best effort only: if the system look and feel cannot be
                    // installed, the frame is shown with the current one.
                }
                JFrame frame = new JFrame("Confusion Matrix");
                frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
                frame.setLayout(new BorderLayout());
                frame.add(new Grid());
                frame.pack();
                frame.setLocationRelativeTo(null);
                frame.setVisible(true);
            });
        }

        /**
         * Panel that paints the matrix as a (labels + 1) x (labels + 1) grid:
         * a header row and column, then one cell per (guess, gold) pair. The
         * cell under the mouse is overlaid with its guess, gold and count.
         * The label list is captured when the panel is created; counts are
         * read from the matrix on every paint.
         */
        public class Grid
        extends JPanel {
            private int columnCount;
            private int rowCount;
            /** Cell rectangles from the most recent paint, in row-major order. */
            private List<Rectangle> cells;
            /** Grid coordinates (column, row) under the mouse, or null when outside the grid. */
            private Point selectedCell;
            /** Labels shown on both axes, fixed at construction so they always match the grid size. */
            private final List<U> labels;

            public Grid() {
                this.labels = new ArrayList<>(ConfusionMatrix.this.uniqueLabels());
                this.columnCount = this.labels.size() + 1;
                this.rowCount = this.labels.size() + 1;
                this.cells = new ArrayList<Rectangle>(this.columnCount * this.rowCount);
                MouseAdapter mouseHandler = new MouseAdapter(){

                    @Override
                    public void mouseMoved(MouseEvent e) {
                        Grid.this.selectedCell = Grid.this.cellAt(e.getX(), e.getY());
                        Grid.this.repaint();
                    }
                };
                this.addMouseMotionListener(mouseHandler);
            }

            /**
             * Maps a pixel position to grid coordinates, using the same cell
             * size and centring offsets as {@link #paintComponent(Graphics)}.
             *
             * @return the (column, row) of the cell, or null when the position
             *         is outside the grid or the panel is too small to have cells
             */
            private Point cellAt(int px, int py) {
                int cellWidth = this.getWidth() / this.columnCount;
                int cellHeight = this.getHeight() / this.rowCount;
                if (cellWidth <= 0 || cellHeight <= 0) {
                    return null;
                }
                int relX = px - (this.getWidth() - this.columnCount * cellWidth) / 2;
                int relY = py - (this.getHeight() - this.rowCount * cellHeight) / 2;
                // Checked before dividing: integer division would round -1 up to cell 0.
                if (relX < 0 || relY < 0) {
                    return null;
                }
                int column = relX / cellWidth;
                int row = relY / cellHeight;
                if (column >= this.columnCount || row >= this.rowCount) {
                    return null;
                }
                return new Point(column, row);
            }

            /**
             * Paints the hover overlay for one cell: a white background with
             * the guess, gold and count written on three lines in a doubled font.
             *
             * @param g2d graphics to draw on; its font is restored before returning
             * @param cell rectangle of the hovered cell
             * @param guess label of the hovered row
             * @param gold label of the hovered column
             */
            public void onMouseOver(Graphics2D g2d, Rectangle cell, U guess, U gold) {
                int x = (int)((double)cell.getLocation().x + cell.getWidth() / 5.0);
                int y = (int)((double)cell.getLocation().y + cell.getHeight() / 5.0);
                Integer value = ConfusionMatrix.this.get(guess, gold);
                String[] lines = new String[]{
                    "Guess: " + String.valueOf(guess),
                    "Gold: " + String.valueOf(gold),
                    "Value: " + value
                };
                Font bak = g2d.getFont();
                g2d.setFont(bak.deriveFont((float)bak.getSize() * 2.0f));
                g2d.setColor(Color.WHITE);
                g2d.fill(cell);
                g2d.setColor(Color.BLACK);
                for (String line : lines) {
                    // Advance first, so the first line sits one line-height below the start.
                    y += g2d.getFontMetrics().getHeight();
                    g2d.drawString(line, x, y);
                }
                g2d.setFont(bak);
            }

            @Override
            public Dimension getPreferredSize() {
                return new Dimension(800, 800);
            }

            @Override
            public void invalidate() {
                this.cells.clear();
                super.invalidate();
            }

            /**
             * Returns count/max clamped to [0, 1], or 0 when max is not positive,
             * so that the result is always usable as a color intensity.
             */
            private double ratio(int count, int max) {
                if (max <= 0) {
                    return 0.0;
                }
                return Math.max(0.0, Math.min(1.0, (double)count / (double)max));
            }

            @Override
            protected void paintComponent(Graphics g) {
                super.paintComponent(g);
                // Rebuilt on every paint; otherwise each repaint would append another full set.
                this.cells.clear();
                Graphics2D g2d = (Graphics2D)g.create();
                try {
                    // Set on the copy that does the drawing, not on the caller's Graphics.
                    g2d.setFont(new Font("Arial", Font.PLAIN, 10));
                    int width = this.getWidth();
                    int height = this.getHeight();
                    int cellWidth = width / this.columnCount;
                    int cellHeight = height / this.rowCount;
                    int xOffset = (width - this.columnCount * cellWidth) / 2;
                    int yOffset = (height - this.rowCount * cellHeight) / 2;

                    // Largest count on and off the diagonal; used to scale cell colors.
                    int maxDiag = 0;
                    int maxOffdiag = 0;
                    for (Map.Entry<Pair<U, U>, Integer> entry : ConfusionMatrix.this.confTable.entrySet()) {
                        Pair<U, U> key = entry.getKey();
                        if (ConfusionMatrix.labelsEqual(key.first(), key.second())) {
                            maxDiag = Math.max(maxDiag, entry.getValue());
                        } else {
                            maxOffdiag = Math.max(maxOffdiag, entry.getValue());
                        }
                    }

                    for (int row = 0; row < this.rowCount; ++row) {
                        for (int col = 0; col < this.columnCount; ++col) {
                            String text;
                            int x = xOffset + col * cellWidth;
                            int y = yOffset + row * cellHeight;
                            float textX = (float)x + (float)cellWidth / 3.0f;
                            float textY = (float)y + (float)cellHeight / 2.0f;
                            Color bg = Color.WHITE;
                            if (row == 0 && col == 0) {
                                text = "V guess | gold >";
                            } else if (row == 0) {
                                text = String.valueOf(this.labels.get(col - 1));
                            } else if (col == 0) {
                                text = String.valueOf(this.labels.get(row - 1));
                            } else {
                                int count = ConfusionMatrix.this.get(this.labels.get(row - 1), this.labels.get(col - 1));
                                text = "" + count;
                                if (row == col) {
                                    // Correct predictions: white fading towards green.
                                    double good = this.ratio(count, maxDiag);
                                    bg = new Color((int)(255.0 - 255.0 * good), (int)(255.0 - 255.0 * good / 2.0), (int)(255.0 - 255.0 * good));
                                } else {
                                    // Errors: white fading towards red.
                                    double bad = this.ratio(count, maxOffdiag);
                                    bg = new Color((int)(255.0 - 255.0 * bad / 2.0), (int)(255.0 - 255.0 * bad), (int)(255.0 - 255.0 * bad));
                                }
                            }
                            Rectangle cell = new Rectangle(x, y, cellWidth, cellHeight);
                            g2d.setColor(bg);
                            g2d.fill(cell);
                            g2d.setColor(Color.BLACK);
                            g2d.drawString(text, textX, textY);
                            this.cells.add(cell);
                        }
                    }

                    // Overlay details for the hovered data cell (headers have x == 0 or y == 0).
                    Point hovered = this.selectedCell;
                    if (hovered != null && hovered.x > 0 && hovered.y > 0) {
                        int index = hovered.x + hovered.y * this.columnCount;
                        if (index < this.cells.size() && hovered.x <= this.labels.size() && hovered.y <= this.labels.size()) {
                            this.onMouseOver(g2d, this.cells.get(index), this.labels.get(hovered.y - 1), this.labels.get(hovered.x - 1));
                        }
                    }
                } finally {
                    g2d.dispose();
                }
            }
        }
    }

    /**
     * Binary contingency counts for one label, with precision, recall,
     * specificity and F1 derived from them at construction time.
     *
     * <p>A statistic whose denominator is zero is stored as NaN by the
     * floating-point division and is rendered as {@code "n/a"} by {@link #toString()}.
     */
    public class Contingency {
        private final double tp;
        private final double fp;
        private final double tn;
        private final double fn;
        private final double prec;
        private final double recall;
        private final double spec;
        private final double f1;

        /**
         * @param tp_ true positives
         * @param fp_ false positives
         * @param tn_ true negatives
         * @param fn_ false negatives
         */
        public Contingency(int tp_, int fp_, int tn_, int fn_) {
            this.tp = tp_;
            this.fp = fp_;
            this.tn = tn_;
            this.fn = fn_;
            this.prec = this.tp / (this.tp + this.fp);
            this.recall = this.tp / (this.tp + this.fn);
            this.spec = this.tn / (this.fp + this.tn);
            this.f1 = 2.0 * this.prec * this.recall / (this.prec + this.recall);
        }

        /** Formats {@code value} with the enclosing matrix's number format, or "n/a" when undefined. */
        private String render(boolean defined, double value) {
            return defined ? ConfusionMatrix.this.format.format(value) : "n/a";
        }

        /** Returns {@code "prec=..., recall=..., spec=..., f1=..."}, with "n/a" for undefined values. */
        @Override
        public String toString() {
            return String.join(", ",
                "prec=" + this.render(this.tp + this.fp > 0.0, this.prec),
                "recall=" + this.render(this.tp + this.fn > 0.0, this.recall),
                "spec=" + this.render(this.fp + this.tn > 0.0, this.spec),
                // A NaN precision or recall makes the sum NaN, which is not > 0, so this yields "n/a".
                "f1=" + this.render(this.prec + this.recall > 0.0, this.f1));
        }

        /** @return F1 = 2 * precision * recall / (precision + recall); may be NaN */
        public double f1() {
            return this.f1;
        }

        /** @return tp / (tp + fp); NaN when there are no positive guesses */
        public double precision() {
            return this.prec;
        }

        /** @return tp / (tp + fn); NaN when there are no positive gold labels */
        public double recall() {
            return this.recall;
        }

        /** @return tn / (fp + tn); NaN when there are no negative gold labels */
        public double spec() {
            return this.spec;
        }
    }
}

