/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.DoFnMapPartitionsFactory;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.functions;
import org.apache.spark.storage.StorageLevel;
import scala.Tuple2;
import scala.reflect.ClassTag;

/**
 * Batch translator for {@link ParDo.MultiOutput}: runs the {@link DoFn} over each partition of the
 * input {@link Dataset} and registers one output {@link Dataset} per output {@link TupleTag}.
 *
 * <p>Splittable DoFns, DoFns using states and timers, and DoFns annotated with
 * {@code @RequiresTimeSortedInput} are rejected with an {@link IllegalStateException}.
 */
class ParDoTranslatorBatch<InputT, OutputT>
extends TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>> {
    /** Class tag for the per-tag {@link RDD}s of windowed values (MEMORY_ONLY path). */
    private static final ClassTag<WindowedValue<Object>> WINDOWED_VALUE_CTAG = ClassTag.apply(WindowedValue.class);
    /** Class tag for the {@link RDD} of {@code (tag index, windowed value)} pairs (MEMORY_ONLY path). */
    private static final ClassTag<Tuple2<Integer, WindowedValue<Object>>> TUPLE2_CTAG = ClassTag.apply(Tuple2.class);

    ParDoTranslatorBatch() {
    }

    /**
     * Translates {@code transform} and registers a {@link Dataset} for every output of the current
     * transform in {@code cxt}.
     *
     * <p>With a single output, the DoFn results are mapped directly into the output dataset. With
     * several outputs, the DoFn runs once and emits {@code (tag index, value)} pairs. These pairs are
     * persisted and then filtered into one dataset per tag. If the configured storage level is
     * {@code MEMORY_ONLY}, they are persisted as an RDD; otherwise they are persisted as a Dataset
     * at that storage level.
     *
     * @param transform the ParDo to translate
     * @param cxt the translation context supplying inputs, outputs, datasets and encoders
     * @throws IllegalStateException if the DoFn is splittable, stateful, or requires time-sorted input
     * @throws IOException as declared by the translator contract
     */
    @Override
    public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt) throws IOException {
        String stepName = cxt.getCurrentTransform().getFullName();
        SparkCommonPipelineOptions opts = cxt.getOptions().as(SparkCommonPipelineOptions.class);
        StorageLevel storageLevel = StorageLevel.fromString(opts.getStorageLevel());

        // Reject DoFn features this translator does not implement.
        DoFn<InputT, OutputT> doFn = transform.getFn();
        Preconditions.checkState(!DoFnSignatures.isSplittable(doFn),
                "Not expected to directly translate splittable DoFn, should have been overridden: %s", doFn);
        Preconditions.checkState(!DoFnSignatures.isStateful(doFn),
                "States and timers are not supported for the moment.");
        Preconditions.checkState(!DoFnSignatures.requiresTimeSortedInput(doFn),
                "@RequiresTimeSortedInput is not supported for the moment");

        TupleTag<OutputT> mainOutputTag = transform.getMainOutputTag();
        DoFnSchemaInformation doFnSchema = ParDoTranslation.getSchemaInformation(cxt.getCurrentTransform());
        @SuppressWarnings("unchecked")
        PCollection<InputT> input = (PCollection<InputT>) cxt.getInput();

        // NOTE(review): the factory is kept raw because its type parameters are declared outside this
        // file. As a result, the output functions passed to create(...) receive untyped (tag, value)
        // arguments, and create(...) yields a raw Fun1.
        DoFnMapPartitionsFactory factory = new DoFnMapPartitionsFactory(stepName, doFn, doFnSchema,
                cxt.getSerializableOptions(), input, mainOutputTag, cxt.getOutputs(), transform.getSideInputs(),
                createBroadcastSideInputs(transform.getSideInputs().values(), cxt));
        Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input);

        if (cxt.getOutputs().size() > 1) {
            // Give each output tag a 0-based index. ImmutableMap keeps insertion order, so the
            // encoder list built from tags.keySet() is aligned: encoders.get(i) belongs to tag i.
            ImmutableMap<TupleTag<?>, Integer> tags = ImmutableMap.copyOf(zipwithIndex(cxt.getOutputs().keySet()));
            List<Encoder<WindowedValue<Object>>> encoders = createEncoders(cxt.getOutputs(), tags.keySet(), cxt);
            // Pair every DoFn output with the index of the tag it was emitted to.
            ScalaInterop.Fun1 doFnMapper = factory.create((tag, value) -> ScalaInterop.tuple(tags.get(tag), value));

            if (StorageLevel.MEMORY_ONLY().equals(storageLevel)) {
                // Persist the tagged pairs as an RDD. Encoders are applied only when each per-tag
                // dataset is created below.
                RDD<Tuple2<Integer, WindowedValue<Object>>> allTagsRDD =
                        inputDs.rdd().mapPartitions(doFnMapper, false, TUPLE2_CTAG);
                allTagsRDD.persist();
                for (Map.Entry<TupleTag<?>, Integer> e : tags.entrySet()) {
                    @SuppressWarnings("unchecked")
                    TupleTag<Object> key = (TupleTag<Object>) e.getKey();
                    Integer id = e.getValue();
                    RDD<WindowedValue<Object>> rddByTag = allTagsRDD
                            .filter(ScalaInterop.fun1(t -> t._1.equals(id)))
                            .map(ScalaInterop.fun1(Tuple2::_2), WINDOWED_VALUE_CTAG);
                    cxt.putDataset(cxt.getOutput(key), cxt.getSparkSession().createDataset(rddByTag, encoders.get(id)));
                }
            } else {
                // Persist as a Dataset whose columns are named by tag index. Each output then keeps
                // the rows where its own column is non-null.
                Dataset<Tuple2<Integer, WindowedValue<Object>>> allTagsDS =
                        inputDs.mapPartitions(doFnMapper, EncoderHelpers.oneOfEncoder(encoders));
                allTagsDS.persist(storageLevel);
                for (Map.Entry<TupleTag<?>, Integer> e : tags.entrySet()) {
                    @SuppressWarnings("unchecked")
                    TupleTag<Object> key = (TupleTag<Object>) e.getKey();
                    Integer id = e.getValue();
                    // Raw cast: Column.as(...) returns TypedColumn<Object, U>, but select(TypedColumn)
                    // needs the column's input type to match the Dataset's element type.
                    @SuppressWarnings({"unchecked", "rawtypes"})
                    TypedColumn<Tuple2<Integer, WindowedValue<Object>>, WindowedValue<Object>> col =
                            (TypedColumn) functions.col(id.toString()).as(encoders.get(id));
                    cxt.putDataset(cxt.getOutput(key), allTagsDS.filter(col.isNotNull()).select(col));
                }
            }
        } else {
            // Single output: DoFn results are the output values themselves.
            PCollection<OutputT> output = cxt.getOutput(mainOutputTag);
            Dataset<WindowedValue<OutputT>> mainDS =
                    inputDs.mapPartitions(factory.create((tag, value) -> value), cxt.windowedEncoder(output.getCoder()));
            cxt.putDataset(output, mainDS);
        }
    }

    /**
     * Creates one windowed-value encoder per tag in {@code columns}, in iteration order. Each encoder
     * is built from the coder of the output {@link PCollection} registered for that tag.
     *
     * @param outputs output PCollections keyed by tag
     * @param columns the tags to create encoders for; this order defines the list order
     * @param ctx the translation context used to build encoders
     * @return the encoders, aligned with the iteration order of {@code columns}
     * @throws NullPointerException if a tag has no entry in {@code outputs}
     */
    private List<Encoder<WindowedValue<Object>>> createEncoders(Map<TupleTag<?>, PCollection<?>> outputs,
            Iterable<TupleTag<?>> columns, Context ctx) {
        return Streams.stream(columns)
                .map(tag -> ctx.windowedEncoder(getCoder(outputs.get(tag), tag)))
                .collect(Collectors.toList());
    }

    /**
     * Returns the coder of {@code pc}, viewed as a coder of {@code Object}.
     *
     * @throws NullPointerException if {@code pc} is null; the message names {@code tag}
     */
    @SuppressWarnings("unchecked")
    private Coder<Object> getCoder(@Nullable PCollection<?> pc, TupleTag<?> tag) {
        if (pc == null) {
            throw new NullPointerException("No PCollection for tag " + tag);
        }
        // Unchecked: the element type is erased at runtime, and callers only need Coder<Object>.
        return (Coder<Object>) pc.getCoder();
    }

    /**
     * Materializes and broadcasts every side input. For each view, its dataset is collected to the
     * driver, each element is encoded with a full windowed-value coder, and the byte arrays are
     * broadcast and registered under the view's tag id.
     *
     * @throws NullPointerException if a view has no backing PCollection
     */
    private SideInputBroadcast createBroadcastSideInputs(Collection<PCollectionView<?>> sideInputs, Context context) {
        SideInputBroadcast sideInputBroadcast = new SideInputBroadcast();
        for (PCollectionView<?> sideInput : sideInputs) {
            @SuppressWarnings("unchecked")
            PCollection<Object> pc = (PCollection<Object>) sideInput.getPCollection();
            if (pc == null) {
                throw new NullPointerException("PCollection for SideInput is null");
            }
            // Raw window coder: its exact window type is not needed here.
            @SuppressWarnings("rawtypes")
            Coder windowCoder = pc.getWindowingStrategy().getWindowFn().windowCoder();
            @SuppressWarnings("unchecked")
            WindowedValue.FullWindowedValueCoder<Object> windowedValueCoder =
                    WindowedValue.getFullCoder(pc.getCoder(), windowCoder);
            // Collects the whole side input to the driver before broadcasting it.
            Dataset<?> broadcastSet = context.getSideInputDataset(sideInput);
            List<?> values = broadcastSet.collectAsList();
            List<byte[]> codedValues = new ArrayList<>(values.size());
            for (Object value : values) {
                @SuppressWarnings("unchecked")
                WindowedValue<Object> windowedValue = (WindowedValue<Object>) value;
                codedValues.add(CoderHelpers.toByteArray(windowedValue, windowedValueCoder));
            }
            sideInputBroadcast.add(sideInput.getTagInternal().getId(), context.broadcast(codedValues), windowedValueCoder);
        }
        return sideInputBroadcast;
    }

    /**
     * Pairs each element of {@code col} with its 0-based position in iteration order.
     *
     * @param col the elements to index
     * @return a new list of immutable {@code (element, index)} entries, in iteration order
     */
    private static <T> Collection<Map.Entry<T, Integer>> zipwithIndex(Collection<T> col) {
        List<Map.Entry<T, Integer>> zipped = new ArrayList<>(col.size());
        int index = 0;
        for (T element : col) {
            zipped.add(new AbstractMap.SimpleImmutableEntry<>(element, index++));
        }
        return zipped;
    }
}

