/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.TransformStatistics;

/**
 * Abstract base for encoders that read one input column of a {@link CacheBlock} and write
 * encoded values into an output {@link MatrixBlock}.
 *
 * <p>Subclasses supply the per-row codes through {@link #getCodeCol}. This class provides:
 * <ul>
 *   <li>the dense and sparse apply loops;</li>
 *   <li>apply-time statistics bookkeeping;</li>
 *   <li>construction of (optionally row-partitioned) build and apply tasks for
 *       {@link DependencyThreadPool}.</li>
 * </ul>
 *
 * <p>Encoders are ordered by {@link #compareTo} solely by their encoder type. Two encoders of
 * the same type therefore compare as 0.
 */
public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder> {
    protected static final Log LOG = LogFactory.getLog(ColumnEncoder.class.getName());

    // NOTE(review): not read in this class; presumably global overrides for the number of row
    // partitions used by apply/build, with -1 meaning "unset" -- confirm against callers.
    public static int APPLY_ROW_BLOCKS_PER_COLUMN = -1;
    public static int BUILD_ROW_BLOCKS_PER_COLUMN = -1;

    private static final long serialVersionUID = 2299156350718979064L;

    /** Number of rows encoded per tile in {@link #applyDense}. */
    private static final int DENSE_APPLY_TILE = 64;

    /**
     * Input column handled by this encoder. The value -1 marks the encoder as not applicable
     * (see {@link #isApplicable()}). It appears to be 1-based, because {@link #applySparse}
     * uses {@code _colID - 1} as an offset.
     */
    protected int _colID;
    /** Row indexes accumulated via {@link #addSparseRowsWZeros}; {@code null} until the first call. */
    protected ArrayList<Integer> _sparseRowsWZeros = null;
    /** Estimated metadata size, set through {@link #setEstMetaSize}. */
    protected long _estMetaSize = 0L;
    /** Estimated number of distinct values, set through {@link #setEstNumDistincts}. */
    protected int _estNumDistincts = 0;
    /** Requested number of row partitions for {@link #getBuildTasks}. */
    protected int _nBuildPartitions = 0;
    /** Requested number of row partitions for {@link #getApplyTasks}. */
    protected int _nApplyPartitions = 0;

    /**
     * Hook for encoders that need an embedding matrix. The default implementation does
     * nothing.
     *
     * @param embeddings embedding matrix (ignored by default)
     */
    public void initEmbeddings(MatrixBlock embeddings) {
    }

    /**
     * @param colID input column handled by this encoder (-1 = not applicable)
     */
    protected ColumnEncoder(int colID) {
        _colID = colID;
    }

    /**
     * Encodes the input column starting at row 0 with block size -1.
     * NOTE(review): -1 presumably means "through the last row" in UtilFunctions.getEndIndex;
     * confirm.
     */
    @Override
    public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol) {
        return apply(in, out, outputCol, 0, -1);
    }

    /**
     * Encodes a row range of the input column into column {@code outputCol} of {@code out}.
     * It dispatches to the sparse or dense path based on the output format. When
     * {@code DMLScript.STATISTICS} is enabled, the elapsed time is added to the
     * transform-specific apply counter.
     *
     * @param in        input block
     * @param out       output block, already allocated in its final format
     * @param outputCol output column to write
     * @param rowStart  first row to encode
     * @param blk       number of rows to encode (passed to UtilFunctions.getEndIndex)
     * @return {@code out}
     */
    public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (out.isInSparseFormat()) {
            applySparse(in, out, outputCol, rowStart, blk);
        } else {
            applyDense(in, out, outputCol, rowStart, blk);
        }
        if (DMLScript.STATISTICS) {
            final long t = System.nanoTime() - t0;
            switch (getTransformType()) {
                case RECODE:
                    TransformStatistics.incRecodeApplyTime(t);
                    break;
                case BIN:
                    TransformStatistics.incBinningApplyTime(t);
                    break;
                case DUMMYCODE:
                    TransformStatistics.incDummyCodeApplyTime(t);
                    break;
                case WORD_EMBEDDING:
                    TransformStatistics.incWordEmbeddingApplyTime(t);
                    break;
                case FEATURE_HASH:
                    TransformStatistics.incFeatureHashingApplyTime(t);
                    break;
                case PASS_THROUGH:
                    TransformStatistics.incPassThroughApplyTime(t);
                    break;
                default:
                    // UDF and N_A have no dedicated apply counter.
                    break;
            }
        }
        return out;
    }

    /**
     * Returns the code for a single row of the input column.
     * NOTE(review): not called in this class; semantics defined by subclasses.
     */
    protected abstract double getCode(CacheBlock<?> in, int row);

    /**
     * Returns the codes for rows {@code [startInd, endInd)} of the input column. Element
     * {@code i} of the result belongs to row {@code startInd + i}.
     *
     * @param tmp optional scratch buffer the implementation may reuse. It is {@code null} from
     *            {@link #applySparse} and a {@value #DENSE_APPLY_TILE}-element array from
     *            {@link #applyDense}.
     */
    protected abstract double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp);

    /**
     * Writes codes for the row range into a sparse output by filling the CSR arrays directly.
     * Each row {@code r} receives column index {@code outputCol} and its code at position
     * {@code rowPointers[r] + (_colID - 1)}.
     * NOTE(review): assumes the CSR rows were pre-sized with a slot reserved for this
     * encoder -- confirm with the allocating caller.
     */
    protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
        mcsr = false; // overrides the default: the CSR path below is always taken
        final int index = _colID - 1;
        final int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        final double[] codes = getCodeCol(in, rowStart, rowEnd, null);
        if (rowEnd <= rowStart) {
            return;
        }
        final SparseBlock sblock = out.getSparseBlock();
        if (mcsr) {
            for (int r = rowStart; r < rowEnd; r++) {
                final SparseRowVector row = (SparseRowVector) sblock.get(r);
                row.values()[index] = codes[r - rowStart];
                row.indexes()[index] = outputCol;
            }
        } else {
            // The block and its backing arrays are loop-invariant, so fetch them once.
            final SparseBlockCSR csrblock = (SparseBlockCSR) sblock;
            final int[] rptr = csrblock.rowPointers();
            final int[] cix = csrblock.indexes();
            final double[] vals = csrblock.values();
            for (int r = rowStart; r < rowEnd; r++) {
                final int pos = rptr[r] + index;
                cix[pos] = outputCol;
                vals[pos] = codes[r - rowStart];
            }
        }
    }

    /**
     * Writes codes for the row range into a dense output, {@value #DENSE_APPLY_TILE} rows at a
     * time. Each tile whose rows lie in a single underlying array is written through direct
     * array offsets. All other tiles fall back to {@link DenseBlock#set(int, int, double)}.
     */
    protected void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        final int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        final double[] tmp = new double[DENSE_APPLY_TILE];
        final DenseBlock outB = out.getDenseBlock();
        for (int i = rowStart; i < rowEnd; i += DENSE_APPLY_TILE) {
            final int tileEnd = Math.min(i + DENSE_APPLY_TILE, rowEnd);
            // Contiguity is checked with row indexes of this tile (last row inclusive), not with
            // the block size. NOTE(review): assumes isContiguous(rl, ru) takes an inclusive upper
            // row index -- confirm against DenseBlock.
            if (outB.isContiguous(i, tileEnd - 1)) {
                applyDenseTileContiguous(in, outB, outputCol, i, tileEnd, tmp);
            } else {
                applyDenseTileGeneric(in, outB, outputCol, i, tileEnd, tmp);
            }
        }
    }

    /**
     * Writes rows {@code [s, e)} through the backing array of row {@code s}. Consecutive rows
     * are {@code getDim(1)} elements apart.
     */
    private void applyDenseTileContiguous(CacheBlock<?> in, DenseBlock out, int outputCol, int s, int e, double[] tmp) {
        final double[] codes = getCodeCol(in, s, e, tmp);
        final double[] vals = out.values(s);
        final int nCol = out.getDim(1);
        for (int i = 0, off = out.pos(s) + outputCol; i < e - s; i++, off += nCol) {
            vals[off] = codes[i];
        }
    }

    /** Writes rows {@code [s, e)} cell by cell via {@link DenseBlock#set(int, int, double)}. */
    private void applyDenseTileGeneric(CacheBlock<?> in, DenseBlock out, int outputCol, int s, int e, double[] tmp) {
        final double[] codes = getCodeCol(in, s, e, tmp);
        for (int i = s; i < e; i++) {
            out.set(i, outputCol, codes[i - s]);
        }
    }

    /** @return the transform kind, used to choose the statistics counter in {@link #apply}. */
    protected abstract TransformType getTransformType();

    /** @return {@code true} unless the column id is -1 */
    public boolean isApplicable() {
        return _colID != -1;
    }

    /** @return {@code true} if {@code colID} equals this encoder's column id */
    public boolean isApplicable(int colID) {
        return colID == _colID;
    }

    /** No-op by default. */
    @Override
    public void prepareBuildPartial() {
    }

    /** @return number of output columns produced; 1 by default */
    public int getDomainSize() {
        return 1;
    }

    /** No-op by default. */
    @Override
    public void buildPartial(FrameBlock in) {
    }

    /** No-op by default. */
    public void build(CacheBlock<?> in, double[] equiHeightMaxs) {
    }

    /** No-op by default. */
    public void build(CacheBlock<?> in, Map<Integer, double[]> equiHeightMaxs) {
    }

    /**
     * Merges {@code other} into this encoder.
     *
     * @throws DMLRuntimeException always, unless overridden
     */
    public void mergeAt(ColumnEncoder other) {
        throw new DMLRuntimeException(getClass().getSimpleName() + " does not support merging with " + other.getClass().getSimpleName());
    }

    /** No-op by default. */
    @Override
    public void updateIndexRanges(long[] beginDims, long[] endDims, int colOffset) {
    }

    /** @return {@code null} by default */
    public MatrixBlock getColMapping(FrameBlock meta) {
        return null;
    }

    /** Serializes the column id only. */
    @Override
    public void writeExternal(ObjectOutput os) throws IOException {
        os.writeInt(_colID);
    }

    /** Restores the column id written by {@link #writeExternal}. */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        _colID = in.readInt();
    }

    public int getColID() {
        return _colID;
    }

    public void setColID(int colID) {
        _colID = colID;
    }

    /** Adds {@code columnOffset} to the column id. */
    public void shiftCol(int columnOffset) {
        _colID += columnOffset;
    }

    public void setEstMetaSize(long estSize) {
        _estMetaSize = estSize;
    }

    public long getEstMetaSize() {
        return _estMetaSize;
    }

    public void setEstNumDistincts(int numDistincts) {
        _estNumDistincts = numDistincts;
    }

    public int getEstNumDistincts() {
        return _estNumDistincts;
    }

    /** Orders encoders by {@code EncoderFactory.getEncoderType}; not consistent with equals. */
    @Override
    public int compareTo(ColumnEncoder o) {
        return Integer.compare(EncoderFactory.getEncoderType(this), EncoderFactory.getEncoderType(o));
    }

    /**
     * Creates the build tasks for this encoder.
     *
     * <p>If {@code getBlockSizes} yields a single block, the result is one
     * {@link #getBuildTask}. Otherwise it contains one {@link #getPartialBuildTask} per row
     * block, sharing a result map, followed by a {@link #getPartialMergeBuildTask} that
     * depends on all partial tasks.
     */
    public List<DependencyTask<?>> getBuildTasks(CacheBlock<?> in) {
        final List<Callable<Object>> tasks = new ArrayList<>();
        List<List<? extends Callable<?>>> dep = null;
        final int nRows = in.getNumRows();
        final int[] blockSizes = UtilFunctions.getBlockSizes(nRows, _nBuildPartitions);
        if (blockSizes.length == 1) {
            tasks.add(getBuildTask(in));
        } else {
            final HashMap<Integer, Object> ret = new HashMap<>();
            int startRow = 0;
            for (int blockSize : blockSizes) {
                tasks.add(getPartialBuildTask(in, startRow, blockSize, ret));
                startRow += blockSize;
            }
            tasks.add(getPartialMergeBuildTask(ret));
            // Partial tasks have no dependencies; the merge task (last) depends on all of them.
            dep = new ArrayList<>(Collections.nCopies(tasks.size() - 1, null));
            dep.add(tasks.subList(0, tasks.size() - 1));
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    /** @throws DMLRuntimeException always, unless overridden by an encoder that needs building */
    public Callable<Object> getBuildTask(CacheBlock<?> in) {
        throw new DMLRuntimeException("Trying to get the Build task of an Encoder which does not require building");
    }

    /** @throws DMLRuntimeException always, unless overridden by an encoder supporting partial builds */
    public Callable<Object> getPartialBuildTask(CacheBlock<?> in, int startRow, int blockSize, HashMap<Integer, Object> ret) {
        throw new DMLRuntimeException("Trying to get the PartialBuild task of an Encoder which does not support  partial building");
    }

    /** @throws DMLRuntimeException always, unless overridden by an encoder supporting partial builds */
    public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> ret) {
        throw new DMLRuntimeException("Trying to get the BuildMergeTask task of an Encoder which does not support partial building");
    }

    /**
     * Creates one apply task per row block returned by {@code getBlockSizes}. The task is
     * sparse or dense depending on the output format.
     *
     * <p>With more than one task, a trailing no-op task is appended that depends on all apply
     * tasks (presumably a completion barrier for this column).
     */
    public List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, MatrixBlock out, int outputCol) {
        final List<Callable<Object>> tasks = new ArrayList<>();
        List<List<? extends Callable<?>>> dep = null;
        final int[] blockSizes = UtilFunctions.getBlockSizes(in.getNumRows(), _nApplyPartitions);
        int startRow = 0;
        for (int blockSize : blockSizes) {
            if (out.isInSparseFormat()) {
                tasks.add(getSparseTask(in, out, outputCol, startRow, blockSize));
            } else {
                tasks.add(getDenseTask(in, out, outputCol, startRow, blockSize));
            }
            startRow += blockSize;
        }
        if (tasks.size() > 1) {
            // The apply tasks are independent; the appended no-op task depends on all of them.
            dep = new ArrayList<>(Collections.nCopies(tasks.size(), null));
            tasks.add(() -> null);
            dep.add(tasks.subList(0, tasks.size() - 1));
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    /** @return a task that calls {@link #apply(CacheBlock, MatrixBlock, int, int, int)} on the range */
    protected ColumnApplyTask<? extends ColumnEncoder> getSparseTask(CacheBlock<?> in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new ColumnApplyTask<>(this, in, out, outputCol, startRow, blk);
    }

    /** @return a task that calls {@link #apply(CacheBlock, MatrixBlock, int, int, int)} on the range */
    protected ColumnApplyTask<? extends ColumnEncoder> getDenseTask(CacheBlock<?> in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new ColumnApplyTask<>(this, in, out, outputCol, startRow, blk);
    }

    /**
     * @return a new set copied from the rows collected by {@link #addSparseRowsWZeros}, or
     *         {@code null} if none were ever added
     */
    public Set<Integer> getSparseRowsWZeros() {
        return _sparseRowsWZeros == null ? null : new HashSet<>(_sparseRowsWZeros);
    }

    /**
     * Appends row indexes to {@code _sparseRowsWZeros}, creating the list on first use.
     * Synchronized on this encoder so that concurrent apply tasks can call it.
     */
    protected synchronized void addSparseRowsWZeros(ArrayList<Integer> sparseRowsWZeros) {
        if (_sparseRowsWZeros == null) {
            _sparseRowsWZeros = new ArrayList<>();
        }
        _sparseRowsWZeros.addAll(sparseRowsWZeros);
    }

    protected void setBuildRowBlocksPerColumn(int nPart) {
        _nBuildPartitions = nPart;
    }

    protected void setApplyRowBlocksPerColumn(int nPart) {
        _nApplyPartitions = nPart;
    }

    /**
     * Task that applies an encoder to one row range of the input.
     *
     * @param <T> concrete encoder type
     */
    protected static class ColumnApplyTask<T extends ColumnEncoder> implements Callable<Object> {
        protected final T _encoder;
        protected final CacheBlock<?> _input;
        protected final MatrixBlock _out;
        protected final int _outputCol;
        protected final int _startRow;
        protected final int _blk;

        /** Creates a task with row start 0 and block size -1. */
        protected ColumnApplyTask(T encoder, CacheBlock<?> input, MatrixBlock out, int outputCol) {
            this(encoder, input, out, outputCol, 0, -1);
        }

        protected ColumnApplyTask(T encoder, CacheBlock<?> input, MatrixBlock out, int outputCol, int startRow, int blk) {
            _encoder = encoder;
            _input = input;
            _out = out;
            _outputCol = outputCol;
            _startRow = startRow;
            _blk = blk;
        }

        /** Runs the apply for the configured range; always returns {@code null}. */
        @Override
        public Object call() throws Exception {
            assert _outputCol >= 0;
            _encoder.apply(_input, _out, _outputCol, _startRow, _blk);
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<Encoder: " + _encoder.getClass().getSimpleName() + "; ColId: " + _encoder._colID + ">";
        }
    }

    /** Encoder kinds. The declaration order is preserved deliberately, in case callers rely on ordinals. */
    public enum EncoderType {
        Recode,
        FeatureHash,
        PassThrough,
        Bin,
        Dummycode,
        Omit,
        MVImpute,
        Composite,
        WordEmbedding
    }

    /** Transform kinds, used by {@link #apply} to pick a statistics counter. */
    protected enum TransformType {
        BIN,
        RECODE,
        DUMMYCODE,
        FEATURE_HASH,
        PASS_THROUGH,
        UDF,
        WORD_EMBEDDING,
        N_A
    }
}

