/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.functions.ExtractBlockForBinaryReblock;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysds.runtime.matrix.data.LibMatrixNative;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.utils.NativeHelper;
import scala.Tuple2;

/**
 * Spark instruction for the DNN operators {@code conv2d}, {@code conv2d_bias_add},
 * {@code maxpooling} and {@code relu_maxpooling}.
 *
 * <p>The input matrix is re-blocked into one-row blocks, and every block is processed
 * independently by a map-side partition function ({@link RDDConv2dMapMMFunction}). The filter
 * (and, for {@code conv2d_bias_add}, the bias) matrix is shipped to the executors as a Spark
 * broadcast. The parser additionally recognizes {@code maxpooling_backward},
 * {@code conv2d_backward_filter}, {@code conv2d_backward_data} and {@code bias_add}, but
 * {@link #processInstruction(ExecutionContext)} rejects those opcodes as not implemented.
 */
public class DnnSPInstruction extends UnarySPInstruction {
    /** Second matrix input: the filter for conv2d, the bias for conv2d_bias_add; null for (relu_)maxpooling. */
    private final CPOperand _in2;
    /** Third matrix input: the filter for conv2d_bias_add; null for all other opcodes. */
    private final CPOperand _in3;
    /** Four input-shape operands; entries 1..3 are read as C, H and W. Null for bias_add. */
    private final ArrayList<CPOperand> _input_shape;
    /** Four filter-shape operands; entries 0, 2 and 3 are read as K, R and S. Null for bias_add. */
    private final ArrayList<CPOperand> _filter_shape;
    /** Stride operands (stride_h, stride_w); empty for bias_add. */
    private final ArrayList<CPOperand> _stride;
    /** Padding operands (pad_h, pad_w); empty for bias_add. */
    private final ArrayList<CPOperand> _padding;

    /**
     * Canonical constructor; unused inputs are passed as {@code null}.
     */
    private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr,
            ArrayList<CPOperand> stride, ArrayList<CPOperand> padding,
            ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape) {
        super(SPInstruction.SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr);
        _in2 = in2;
        _in3 = in3;
        _stride = stride;
        _padding = padding;
        _input_shape = input_shape;
        _filter_shape = filter_shape;
    }

    /**
     * Parses a Spark DNN instruction string.
     *
     * <p>Field layouts after the opcode:
     * <ul>
     *   <li>maxpooling / relu_maxpooling: input, twelve parameters, output;</li>
     *   <li>maxpooling_backward / conv2d / conv2d_backward_filter / conv2d_backward_data:
     *       two inputs, twelve parameters, output;</li>
     *   <li>conv2d_bias_add: three inputs, twelve parameters, output;</li>
     *   <li>bias_add: two inputs, output.</li>
     * </ul>
     * The twelve parameters are, in order: two stride, two padding, four input-shape and
     * four filter-shape operands.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not one of the opcodes listed above
     */
    public static DnnSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) {
            InstructionUtils.checkNumFields(parts, 14);
            return withConvParams(opcode, str, parts,
                    parseOperand(parts[1]), null, null, parseOperand(parts[14]), 2);
        }
        if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("conv2d")
                || opcode.equalsIgnoreCase("conv2d_backward_filter")
                || opcode.equalsIgnoreCase("conv2d_backward_data")) {
            InstructionUtils.checkNumFields(parts, 15);
            return withConvParams(opcode, str, parts,
                    parseOperand(parts[1]), parseOperand(parts[2]), null, parseOperand(parts[15]), 3);
        }
        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            InstructionUtils.checkNumFields(parts, 16);
            return withConvParams(opcode, str, parts,
                    parseOperand(parts[1]), parseOperand(parts[2]), parseOperand(parts[3]),
                    parseOperand(parts[16]), 4);
        }
        if (opcode.equalsIgnoreCase("bias_add")) {
            InstructionUtils.checkNumFields(parts, 3);
            // bias_add carries no convolution parameters: stride/padding stay empty, shapes unset.
            return new DnnSPInstruction(parseOperand(parts[1]), parseOperand(parts[2]), null,
                    parseOperand(parts[3]), opcode, str,
                    new ArrayList<CPOperand>(), new ArrayList<CPOperand>(), null, null);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a DnnSPInstruction: " + str);
    }

    /** Creates an operand of unknown value/data type and initializes it from one instruction field. */
    private static CPOperand parseOperand(String part) {
        CPOperand op = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        op.split(part);
        return op;
    }

    /** Wraps {@code count} consecutive instruction fields starting at {@code from} as operands. */
    private static ArrayList<CPOperand> parseOperands(String[] parts, int from, int count) {
        ArrayList<CPOperand> ret = new ArrayList<>(count);
        for (int i = from; i < from + count; i++) {
            ret.add(new CPOperand(parts[i]));
        }
        return ret;
    }

    /**
     * Builds an instruction whose twelve parameter operands start at {@code offset} in
     * {@code parts}: two stride, two padding, four input-shape and four filter-shape operands.
     */
    private static DnnSPInstruction withConvParams(String opcode, String str, String[] parts,
            CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, int offset) {
        return new DnnSPInstruction(in, in2, in3, out, opcode, str,
                parseOperands(parts, offset, 2),
                parseOperands(parts, offset + 2, 2),
                parseOperands(parts, offset + 4, 4),
                parseOperands(parts, offset + 8, 4));
    }

    /**
     * Returns the binary-block RDD of variable {@code name}.
     *
     * <p>Unless it already has block size {@code numRowsPerBlock}, the RDD is re-blocked to that
     * block size via {@link ExtractBlockForBinaryReblock} and merged by key.
     * NOTE(review): if the block size applies to both dimensions, a block size of 1 yields 1x1
     * blocks. The partition function below rejects any block whose column index is not 1, so
     * multi-column inputs would fail there. Confirm against the ExtractBlockForBinaryReblock
     * semantics.
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> reblockAsRectangularMatrices(SparkExecutionContext sec, String name, int numRowsPerBlock) {
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(name);
        DataCharacteristics mcRdd = sec.getDataCharacteristics(name);
        if (mcRdd.getBlocksize() != numRowsPerBlock) {
            MatrixCharacteristics mcOut = new MatrixCharacteristics(mcRdd);
            mcOut.setBlocksize(numRowsPerBlock);
            in1 = RDDAggregateUtils.mergeByKey(in1.flatMapToPair(new ExtractBlockForBinaryReblock(mcRdd, mcOut)));
        }
        return in1;
    }

    /** Reads matrix {@code name} from the context, releases it, and broadcasts it to the executors. */
    private static Broadcast<MatrixBlock> getBroadcast(SparkExecutionContext sec, String name) {
        MatrixBlock mb = sec.getMatrixInput(name);
        sec.releaseMatrixInput(name);
        return sec.getSparkContext().broadcast(mb);
    }

    /**
     * Executes the instruction.
     *
     * <p>Reblocks the input into one-row blocks and broadcasts the filter (and, for
     * conv2d_bias_add, the bias). It then applies {@link RDDConv2dMapMMFunction} to every
     * partition and registers the resulting RDD, its lineage and its metadata under the output
     * variable.
     *
     * @throws DMLRuntimeException if the opcode is not executable here, or if the output would
     *         have more than {@link Integer#MAX_VALUE} columns
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        boolean isConv = instOpcode.equalsIgnoreCase("conv2d") || instOpcode.equalsIgnoreCase("conv2d_bias_add");
        boolean isMaxPool = instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling");
        if (!isConv && !isMaxPool) {
            throw new DMLRuntimeException("Not implemented: " + instOpcode);
        }

        // Resolve the scalar parameters and validate the output width first, so that an
        // unsupported size fails before any broadcast is created or the output variable is bound.
        int pad_h = getScalarInput(ec, _padding, 0);
        int pad_w = getScalarInput(ec, _padding, 1);
        int stride_h = getScalarInput(ec, _stride, 0);
        int stride_w = getScalarInput(ec, _stride, 1);
        int C = getScalarInput(ec, _input_shape, 1);
        int H = getScalarInput(ec, _input_shape, 2);
        int W = getScalarInput(ec, _input_shape, 3);
        int K = getScalarInput(ec, _filter_shape, 0);
        int R = getScalarInput(ec, _filter_shape, 2);
        int S = getScalarInput(ec, _filter_shape, 3);
        int P = (int) DnnUtils.getP(H, R, stride_h, pad_h);
        int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w);
        // Pooling keeps the channel count C; convolution produces K output channels.
        long numCols = (long) (isMaxPool ? C : K) * P * Q;
        if (numCols > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("The current operator doesnot support large outputs.");
        }

        final int numRowsPerBlock = 1;
        String rddVar = input1.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> inputRDD = reblockAsRectangularMatrices(sec, rddVar, numRowsPerBlock);
        DataCharacteristics mcRdd = sec.getDataCharacteristics(rddVar);

        Broadcast<MatrixBlock> filterBroadcast = null;
        Broadcast<MatrixBlock> biasBroadcast = null;
        if (instOpcode.equalsIgnoreCase("conv2d")) {
            filterBroadcast = getBroadcast(sec, _in2.getName());
        } else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
            filterBroadcast = getBroadcast(sec, _in3.getName());
            biasBroadcast = getBroadcast(sec, _in2.getName());
        }

        // N = numRowsPerBlock: every block is processed as a batch of one row.
        DnnParameters params = new DnnParameters(numRowsPerBlock, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, 1);
        boolean enableNativeBLAS = NativeHelper.isNativeLibraryLoaded();
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = inputRDD.mapPartitionsToPair(
                new RDDConv2dMapMMFunction(filterBroadcast, params, instOpcode, biasBroadcast, mcRdd.getRows(), enableNativeBLAS),
                true);

        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar);
        // nnz is passed as -1 (presumably "unknown" — confirm against MatrixCharacteristics).
        sec.setMetaData(output.getName(), new MetaDataFormat(
                new MatrixCharacteristics(mcRdd.getRows(), numCols, numRowsPerBlock, -1L), Types.FileFormat.BINARY));
    }

    /** Reads operand {@code index} of {@code aL} as a scalar and narrows it to an int. */
    private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) {
        return (int) ec.getScalarInput(aL.get(index)).getLongValue();
    }

    /**
     * Partition function that applies conv2d, conv2d_bias_add or max pooling to each one-row
     * block of the reblocked input. Each output block keeps its input block's index.
     */
    private static class RDDConv2dMapMMFunction
            implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -2106155380020232155L;
        private final Broadcast<MatrixBlock> filterBroadcast;
        private final Broadcast<MatrixBlock> biasBroadcast;
        private final DnnParameters params;
        private final String instOpcode;
        private final boolean enableNative;
        /** Number of rows of the input matrix; used to validate incoming block indexes. */
        private final long numRows;

        public RDDConv2dMapMMFunction(Broadcast<MatrixBlock> filterBroadcast, DnnParameters params, String instOpcode, Broadcast<MatrixBlock> biasBroadcast, long numRows, boolean enableNativeBLAS) {
            this.filterBroadcast = filterBroadcast;
            this.params = params;
            this.instOpcode = instOpcode;
            this.biasBroadcast = biasBroadcast;
            this.numRows = numRows;
            this.enableNative = enableNativeBLAS;
        }

        /**
         * Computes the output block for one input block.
         *
         * <p>If the result is known to be all zeros, an output block constructed with the
         * {@code true} (sparse) flag is returned without running the kernel. This happens when:
         * <ul>
         *   <li>the input or the filter is empty (conv2d);</li>
         *   <li>additionally the bias is empty (conv2d_bias_add);</li>
         *   <li>the input is empty (pooling).</li>
         * </ul>
         */
        private MatrixBlock processRectangularBlock(MatrixBlock matBlock) throws Exception {
            if (instOpcode.equalsIgnoreCase("conv2d")) {
                MatrixBlock filter = filterBroadcast.getValue();
                if (filter.isEmptyBlock() || matBlock.isEmptyBlock()) {
                    return emptyOutput(params.K);
                }
                MatrixBlock outputBlock = new MatrixBlock(params.N, params.K * params.P * params.Q, false).allocateDenseBlock();
                conv2d(matBlock, filter, outputBlock);
                return outputBlock;
            }
            if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
                MatrixBlock filter = filterBroadcast.getValue();
                MatrixBlock bias = biasBroadcast.getValue();
                if ((filter.isEmptyBlock() || matBlock.isEmptyBlock()) && bias.isEmptyBlock()) {
                    return emptyOutput(params.K);
                }
                MatrixBlock outputBlock = new MatrixBlock(params.N, params.K * params.P * params.Q, false).allocateDenseBlock();
                if (!bias.isEmptyBlock()) {
                    params.bias = bias;
                }
                conv2d(matBlock, filter, outputBlock);
                return outputBlock;
            }
            if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) {
                if (matBlock.isEmptyBlock()) {
                    return emptyOutput(params.C);
                }
                MatrixBlock outputBlock = new MatrixBlock(params.N, params.C * params.P * params.Q, false).allocateBlock();
                // Plain max pooling starts from the lowest double. relu_maxpooling keeps the
                // zero-initialized block, so its results are presumably floored at 0 (the relu).
                if (instOpcode.equalsIgnoreCase("maxpooling")) {
                    outputBlock.getDenseBlock().set(-Double.MAX_VALUE);
                }
                LibMatrixDNN.pooling(matBlock, outputBlock, params, LibMatrixDNN.PoolingType.MAX);
                return outputBlock;
            }
            // Not produced by parseInstruction/processInstruction of the enclosing class.
            if (instOpcode.equalsIgnoreCase("avgpooling") || instOpcode.equalsIgnoreCase("relu_avgpooling")) {
                if (matBlock.isEmptyBlock()) {
                    return emptyOutput(params.C);
                }
                MatrixBlock outputBlock = new MatrixBlock(params.N, params.C * params.P * params.Q, false).allocateBlock();
                LibMatrixDNN.pooling(matBlock, outputBlock, params, LibMatrixDNN.PoolingType.AVG);
                return outputBlock;
            }
            throw new RuntimeException("Not implemented");
        }

        /** Creates the all-zero output block of width {@code channels * P * Q}. */
        private MatrixBlock emptyOutput(int channels) {
            return new MatrixBlock(params.N, channels * params.P * params.Q, true);
        }

        /** Runs conv2d with the native library when it was loaded on the driver, else with LibMatrixDNN. */
        private void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock) {
            if (enableNative) {
                LibMatrixNative.conv2d(input, filter, outputBlock, params);
            } else {
                LibMatrixDNN.conv2d(input, filter, outputBlock, params);
            }
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) throws Exception {
            return new MapsideDnnPartitionIterator(arg0);
        }

        /** Lazily maps each (index, one-row block) pair of a partition to its output block. */
        private class MapsideDnnPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public MapsideDnnPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
                MatrixIndexes ix = arg._1;
                if (ix.getRowIndex() > RDDConv2dMapMMFunction.this.numRows || ix.getColumnIndex() != 1L) {
                    throw new RuntimeException("Expected the inputs to be reblocked as rectangular RDD");
                }
                MatrixBlock out = RDDConv2dMapMMFunction.this.processRectangularBlock(arg._2);
                if (out.getNumRows() != 1) {
                    throw new RuntimeException("Expected the output to have 1 row");
                }
                return new Tuple2<>(ix, out);
            }
        }
    }
}

