/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.cost;

import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.cost.ACostEstimate;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Cost estimator that sums a set of per-operation cost formulas, each weighted by the
 * corresponding instruction count held in an {@link InstructionTypeCounter}.
 *
 * <p>Three entry points are provided: one for estimated column-group statistics
 * ({@link #getCostSafe(CompressedSizeInfoColGroup)}), one for raw dimensions
 * ({@link #getCost(int, int, int, int, double)}) and one for an uncompressed
 * {@link MatrixBlock} ({@link #getCost(MatrixBlock)}).
 */
public class ComputationCostEstimator extends ACostEstimate {
    private static final long serialVersionUID = -1205636215389161815L;

    /**
     * Minimum fraction of rows (largest-off instances divided by the number of rows) a group must
     * reach before those rows are subtracted from the number of rows scanned.
     */
    private static final double cvThreshold = 0.2;

    /** Instruction counts used as weights for the individual cost terms. */
    private final InstructionTypeCounter ins;

    /**
     * Creates an estimator backed by the given counter and logs the new instance when debug
     * logging is enabled.
     *
     * @param counts the instruction counts to weight the cost terms with
     */
    protected ComputationCostEstimator(InstructionTypeCounter counts) {
        this.ins = counts;
        if (LOG.isDebugEnabled()) {
            LOG.debug((Object)this);
        }
    }

    /**
     * Creates an estimator from individual instruction counts by wrapping them in a new
     * {@link InstructionTypeCounter}. The arguments are forwarded to that constructor unchanged
     * and in the same order.
     */
    public ComputationCostEstimator(int scans, int decompressions, int overlappingDecompressions, int leftMultiplications, int rightMultiplications, int compressedMultiplication, int dictOps, int indexing, boolean isDensifying) {
        this.ins = new InstructionTypeCounter(scans, decompressions, overlappingDecompressions, leftMultiplications, rightMultiplications, compressedMultiplication, dictOps, indexing, isDensifying);
    }

    /**
     * Estimates the cost of a column group from its size information.
     *
     * <p>The group is classified (empty, constant, incompressable, dominated by a common entry,
     * or general) and the matching dimensions are passed on to
     * {@link #getCost(int, int, int, int, double)}.
     *
     * @param g the estimated size information of the column group
     * @return the estimated cost
     */
    @Override
    protected double getCostSafe(CompressedSizeInfoColGroup g) {
        final int nVals = g.getNumVals();
        final int nCols = g.getColumns().size();
        final int nRows = g.getNumRows();
        // The tiny epsilon keeps the sparsity strictly positive.
        final double sparsity = nCols < 3 || this.ins.isDensifying() ? 1.0 : g.getTupleSparsity() + 1.0E-10;
        // NOTE(review): presumably the number of rows holding the most frequent tuple — confirm
        // against CompressedSizeInfoColGroup. It is a row count (it is subtracted from nRows
        // below), so it has to be normalized before it can be compared with a fractional threshold.
        final int largestOff = g.getLargestOffInstances();
        final double commonFraction = nRows > 0 ? largestOff / (double) nRows : 0.0;

        if (g.isEmpty() && !this.ins.isDensifying()) {
            // Empty and staying sparse: a single value at a very small sparsity.
            return this.getCost(nRows, 1, nCols, 1, 1.0E-5);
        }
        if (g.isEmpty() || g.isConst()) {
            // Constant group, or empty group under a densifying workload.
            return this.getCost(nRows, 1, nCols, 1, 1.0);
        }
        if (g.isIncompressable()) {
            // Saturate instead of overflowing int for very large row counts; a wrapped negative
            // value could otherwise produce a spurious negative-cost exception.
            final int tripleRows = (int) Math.min(3L * nRows, (long) Integer.MAX_VALUE);
            return this.getCost(tripleRows, nRows, nCols, tripleRows, sparsity);
        }
        if (commonFraction > cvThreshold) {
            // Rows covered by the common entry are not counted as scanned.
            return this.getCost(nRows, nRows - largestOff, nCols, nVals, sparsity);
        }
        return this.getCost(nRows, nRows, nCols, nVals, sparsity);
    }

    /**
     * Computes the cost for the given dimensions as the sum of all weighted cost terms plus a
     * constant of 100.
     *
     * @param nRows        number of rows
     * @param nRowsScanned number of rows that have to be scanned
     * @param nCols        number of columns
     * @param nVals        number of distinct values (tuples)
     * @param sparsity     sparsity of the values; treated as 1.0 when {@code nCols < 3} or when it
     *                     exceeds 0.4
     * @return the estimated cost
     * @throws DMLCompressionException if the resulting cost is negative
     */
    public double getCost(int nRows, int nRowsScanned, int nCols, int nVals, double sparsity) {
        final double effectiveSparsity = nCols < 3 || sparsity > 0.4 ? 1.0 : sparsity;
        double cost = 0.0;
        cost += this.leftMultCost(nRowsScanned, nRows, nCols, nVals, effectiveSparsity);
        cost += this.scanCost(nRowsScanned, nVals, nCols, effectiveSparsity);
        cost += this.dictionaryOpsCost(nVals, nCols, effectiveSparsity);
        // Argument order follows the helper's (nVals, nCols, sparsity) signature so that the
        // allocation term is based on the number of values.
        cost += this.rightMultCost(nVals, nCols, effectiveSparsity);
        cost += this.decompressionCost(nVals, nCols, nRowsScanned, effectiveSparsity);
        cost += this.overlappingDecompressionCost(nRowsScanned);
        cost += this.compressedMultiplicationCost(nRowsScanned, nRows, nVals, nCols, effectiveSparsity);
        // Constant added to every estimate.
        cost += 100.0;
        if (cost < 0.0) {
            throw new DMLCompressionException("Invalid negative cost: " + cost);
        }
        return cost;
    }

    /**
     * @return whether the underlying instruction counter is flagged as densifying
     */
    public boolean isDense() {
        return this.ins.isDensifying();
    }

    /**
     * Estimates the cost of operating on an uncompressed matrix block, using its row count in
     * place of the number of distinct values.
     *
     * @param mb the matrix block
     * @return the estimated cost
     * @throws DMLCompressionException if the resulting cost is negative
     */
    @Override
    public double getCost(MatrixBlock mb) {
        final double nCols = mb.getNumColumns();
        final double nRows = mb.getNumRows();
        final double sparsity = nCols < 3.0 || this.ins.isDensifying() ? 1.0 : mb.getSparsity();
        double cost = 0.0;
        cost += this.dictionaryOpsCost(nRows, nCols, sparsity);
        cost += this.leftMultCost(0.0, nRows * nCols * sparsity + nCols);
        cost += this.rightMultCost(nRows * nCols * sparsity, nRows * nCols);
        cost += this.scanCost(0.0, nRows, nCols, sparsity);
        cost += this.compressedMultiplicationCost(0.0, 0.0, nRows, nCols, sparsity);
        if (cost < 0.0) {
            throw new DMLCompressionException("Invalid negative cost : " + cost);
        }
        return cost;
    }

    /**
     * Delegates the cost computation to the column group itself.
     *
     * @param cg    the column group
     * @param nRows the number of rows
     * @return the cost reported by {@code cg}
     */
    @Override
    public double getCost(AColGroup cg, int nRows) {
        return cg.getCost(this, nRows);
    }

    /**
     * @return {@code true} if the counter holds at least one left, right or compressed
     *         multiplication
     */
    @Override
    public boolean shouldSparsify() {
        return this.ins.getLeftMultiplications() > 0 || this.ins.getCompressedMultiplications() > 0 || this.ins.getRightMultiplications() > 0;
    }

    /** Dictionary operations: two units per (sparsity-scaled) dictionary cell. */
    private double dictionaryOpsCost(double nVals, double nCols, double sparsity) {
        return (double)this.ins.getDictionaryOps() * sparsity * nVals * nCols * 2.0;
    }

    /**
     * Left multiplication: a pre-scaling term over the rows plus a post-scaling term over the
     * dictionary cells.
     */
    private double leftMultCost(double nRowsScanned, double nRows, double nCols, double nVals, double sparsity) {
        double preScalingCost = Math.max(nRowsScanned, nRows) * 2.0;
        // The pre-scaling term is dropped when the number of values equals the number of columns
        // (or is one less) and exceeds 1000.
        if ((nCols == nVals || nCols == nVals + 1.0) && nVals > 1000.0) {
            preScalingCost = 0.0;
        }
        final double postScalingCost = sparsity * nVals * nCols;
        return this.leftMultCost(preScalingCost, postScalingCost);
    }

    /** Weights the summed left-multiplication terms by the left-multiplication count. */
    private double leftMultCost(double preAggregateCost, double postScalingCost) {
        return (double)this.ins.getLeftMultiplications() * (preAggregateCost + postScalingCost);
    }

    /** Right multiplication: a term over the dictionary cells plus an allocation term of nVals. */
    private double rightMultCost(double nVals, double nCols, double sparsity) {
        final double preMultiplicationCost = sparsity * nCols * nVals;
        final double allocationCost = nVals;
        return this.rightMultCost(preMultiplicationCost, allocationCost);
    }

    /** Weights the summed right-multiplication terms by the right-multiplication count. */
    private double rightMultCost(double preMultiplicationCost, double allocationCost) {
        return (double)this.ins.getRightMultiplications() * (preMultiplicationCost + allocationCost);
    }

    /** Decompression: one unit per (sparsity-scaled) scanned cell; {@code nVals} is not used. */
    private double decompressionCost(double nVals, double nCols, double nRowsScanned, double sparsity) {
        return (double)this.ins.getDecompressions() * (nCols * nRowsScanned * sparsity);
    }

    /** Overlapping decompression: one unit per row. */
    private double overlappingDecompressionCost(double nRows) {
        return (double)this.ins.getOverlappingDecompressions() * nRows;
    }

    /** Scan: the scanned rows plus the (sparsity-scaled) dictionary cells. */
    private double scanCost(double nRowsScanned, double nVals, double nCols, double sparsity) {
        return (double)this.ins.getScans() * (nRowsScanned + nVals * nCols * sparsity);
    }

    /**
     * Compressed multiplication: the larger of the scanned rows and a tenth of all rows, plus the
     * (sparsity-scaled) dictionary cells.
     */
    private double compressedMultiplicationCost(double nRowsScanned, double nRows, double nVals, double nCols, double sparsity) {
        return (double)this.ins.getCompressedMultiplications() * (Math.max(nRowsScanned, nRows / 10.0) + nVals * nCols * sparsity);
    }

    /**
     * @return the superclass description followed by {@code " : "} and the instruction counter
     */
    @Override
    public String toString() {
        return super.toString() + " : " + this.ins.toString();
    }
}

