/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import org.apache.calcite.util.Pair;
import org.apache.commons.collections4.multimap.ArrayListValuedHashMap;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.optimizer.graph.OperatorGraph;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.RuntimeValuesInfo;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.SemiJoinBranchInfo;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.SelectDesc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Fixes "parallel edges" in the operator graph: pairs of operator clusters that are connected by
 * more than one parent/child operator link.
 *
 * <p>For every (parent cluster, child cluster) pair connected by more than one edge, one edge is
 * left untouched and every remaining edge {@code p -> o} is rewritten into
 * {@code p -> SEL -> RS -> o}. The new RS carries a clone of {@code p}'s {@link ReduceSinkDesc}
 * (the log message calls it a "concentrator RS"). The SEL reads {@code p}'s output columns and
 * renames them back to the input column names used by {@code p}'s column expressions.
 *
 * <p>While the rewrite runs, semijoin branches from {@link ParseContext#getRsToSemiJoinBranchInfo()}
 * are temporarily materialized as real RS -> TS parent/child links (see
 * {@link MaterializeSemiJoinEdges}), so the operator-level parent links examined here include them.
 */
public class ParallelEdgeFixer extends Transform {
    protected static final Logger LOG = LoggerFactory.getLogger(ParallelEdgeFixer.class);

    /**
     * Runs the parallel-edge fix on the given parse context.
     *
     * @param pctx the parse context whose operator tree is rewritten in place
     * @return the same {@code pctx} instance
     * @throws SemanticException if the rewrite fails; any other exception (including one thrown
     *     while detaching the temporarily materialized semijoin edges) is wrapped in one
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        // The graph is built before the semijoin links are materialized.
        OperatorGraph og = new OperatorGraph(pctx);
        // Semijoin branches exist as real parent/child links only inside this try block;
        // close() detaches them again and re-keys the semijoin bookkeeping maps.
        try (MaterializeSemiJoinEdges ignored = new MaterializeSemiJoinEdges(pctx)) {
            fixParallelEdges(og);
        } catch (SemanticException e) {
            throw e;
        } catch (Exception e) {
            throw new SemanticException(e);
        }
        return pctx;
    }

    /**
     * Collects all operator edges between connected clusters and rewrites every edge of a cluster
     * pair except one when that pair has more than one edge.
     */
    private void fixParallelEdges(OperatorGraph og) throws SemanticException {
        ActualEdgePredicate actualEdgePredicate = new ActualEdgePredicate();

        // (parent cluster, child cluster) -> every (parent operator, child operator) link between them
        ArrayListValuedHashMap<Pair<OperatorGraph.Cluster, OperatorGraph.Cluster>, Pair<Operator<?>, Operator<?>>>
            edgeOperators = new ArrayListValuedHashMap<>();
        for (OperatorGraph.Cluster cluster : og.getClusters()) {
            for (OperatorGraph.Cluster parentCluster : cluster.parentClusters(actualEdgePredicate)) {
                Set<Operator<?>> parentOperators = parentCluster.getMembers();
                for (Operator<?> operator : cluster.getMembers()) {
                    for (Operator<? extends OperatorDesc> parentOperator : operator.getParentOperators()) {
                        if (!parentOperators.contains(parentOperator)) {
                            continue;
                        }
                        edgeOperators.put(
                            new Pair<OperatorGraph.Cluster, OperatorGraph.Cluster>(parentCluster, cluster),
                            new Pair<Operator<?>, Operator<?>>(parentOperator, operator));
                    }
                }
            }
        }

        // Guards against rewriting the same operator link twice, e.g. when it was recorded more than once.
        Set<Pair<Operator<?>, Operator<?>>> processedEdges = new HashSet<>();
        for (Pair<OperatorGraph.Cluster, OperatorGraph.Cluster> key : edgeOperators.keySet()) {
            List<Pair<Operator<?>, Operator<?>>> values = edgeOperators.get(key);
            if (values.size() <= 1) {
                continue;
            }
            // Sort into a deterministic order; it decides which edge removeOneEdge leaves in place.
            values.sort(new OperatorPairComparator());
            removeOneEdge(values);
            for (Pair<Operator<?>, Operator<?>> pair : values) {
                if (processedEdges.contains(pair)) {
                    continue;
                }
                fixParallelEdge(pair.left, pair.right);
                processedEdges.add(pair);
            }
        }
    }

    /**
     * Removes from {@code values} the one edge that should be left untouched.
     *
     * <p>That edge is the single edge for which {@link #isParallelEdgeSupported} is false. If every
     * edge is supported, it is the last edge in list order.
     *
     * @throws RuntimeException if more than one edge is unsupported
     */
    private void removeOneEdge(List<Pair<Operator<?>, Operator<?>>> values) {
        Pair<Operator<?>, Operator<?>> toKeep = null;
        for (Pair<Operator<?>, Operator<?>> pair : values) {
            if (this.isParallelEdgeSupported(pair)) {
                continue;
            }
            if (toKeep != null) {
                throw new RuntimeException("More than one operators which may not reshuffled!");
            }
            toKeep = pair;
        }
        if (toKeep == null) {
            toKeep = values.get(values.size() - 1);
        }
        values.remove(toKeep);
    }

    /**
     * Tells whether the edge {@code pair.left -> pair.right} can be rewritten by
     * {@link #fixParallelEdge}.
     *
     * @return {@code false} if the parent is a {@link ReduceSinkOperator} whose column mapping
     *     cannot be inverted (see {@link #colMappingInverseKeys}); otherwise {@code true} exactly
     *     when the child is a {@link MapJoinOperator} or a {@link TableScanOperator}
     */
    public boolean isParallelEdgeSupported(Pair<Operator<?>, Operator<?>> pair) {
        Operator<?> parent = pair.left;
        if (parent instanceof ReduceSinkOperator
            && !ParallelEdgeFixer.colMappingInverseKeys((ReduceSinkOperator) parent).isPresent()) {
            return false;
        }
        Operator<?> child = pair.right;
        return child instanceof MapJoinOperator || child instanceof TableScanOperator;
    }

    /**
     * Rewrites the edge {@code p -> o} into {@code p -> SEL -> RS -> o}.
     *
     * <p>{@code p} must be a reduce sink: its conf is cast to {@link ReduceSinkDesc}.
     */
    private void fixParallelEdge(Operator<? extends OperatorDesc> p, Operator<?> o) throws SemanticException {
        LOG.info("Fixing parallel by adding a concentrator RS between {} -> {}", p, o);
        ReduceSinkDesc conf = (ReduceSinkDesc) p.getConf();
        // Cloned before the original conf's output name and tag are changed below, so the new RS
        // keeps the original values.
        ReduceSinkDesc newConf = (ReduceSinkDesc) conf.clone();
        Operator<SelectDesc> newSEL = this.buildSEL(p, conf);
        Operator<ReduceSinkDesc> newRS = OperatorFactory.getAndMakeChild(
            p.getCompilationOpContext(), newConf, new ArrayList<Operator<? extends OperatorDesc>>());
        conf.setOutputName("forward_to_" + newRS);
        conf.setTag(0);
        // Give the new RS its own key column list instead of sharing the original conf's list.
        newConf.setKeyCols(new ArrayList<ExprNodeDesc>(conf.getKeyCols()));
        newRS.setSchema(new RowSchema(p.getSchema()));

        // Rewire p -> newSEL -> newRS -> o.
        p.replaceChild(o, newSEL);
        newSEL.setParentOperators(Lists.<Operator<? extends OperatorDesc>>newArrayList(p));
        newSEL.setChildOperators(Lists.<Operator<? extends OperatorDesc>>newArrayList(newRS));
        newRS.setParentOperators(Lists.<Operator<? extends OperatorDesc>>newArrayList(newSEL));
        newRS.setChildOperators(Lists.<Operator<? extends OperatorDesc>>newArrayList(o));
        o.replaceParent(p, newRS);
    }

    /**
     * Builds a SEL that reads each inverse-key output column of the reduce sink {@code p} and emits
     * it under the input column name its mapping expression refers to.
     *
     * <p>Presumably this lets the cloned RS conf's expressions resolve against the same column
     * names as the original RS (TODO confirm).
     *
     * @throws SemanticException if {@code p}'s column mapping cannot be inverted, or a mapping
     *     expression is neither a column nor a constant
     */
    private Operator<SelectDesc> buildSEL(Operator<? extends OperatorDesc> p, ReduceSinkDesc conf) throws SemanticException {
        ArrayList<ExprNodeDesc> colList = new ArrayList<ExprNodeDesc>();
        ArrayList<String> outputColumnNames = new ArrayList<String>();
        ArrayList<ColumnInfo> newColumns = new ArrayList<ColumnInfo>();
        Set<String> inverseKeys = ParallelEdgeFixer.colMappingInverseKeys((ReduceSinkOperator) p)
            .orElseThrow(() -> new SemanticException("Cannot invert column mapping of " + p));
        for (String colName : inverseKeys) {
            ExprNodeDesc expr = conf.getColumnExprMap().get(colName);
            ExprNodeColumnDesc colRef = new ExprNodeColumnDesc(expr.getTypeInfo(), colName, colName, false);
            colList.add(colRef);
            String newColName = ParallelEdgeFixer.extractColumnName(expr);
            outputColumnNames.add(newColName);
            ColumnInfo newColInfo = new ColumnInfo(p.getSchema().getColumnInfo(colName));
            newColInfo.setInternalName(newColName);
            newColumns.add(newColInfo);
        }
        SelectDesc selConf = new SelectDesc(colList, outputColumnNames);
        Operator<SelectDesc> newSEL = OperatorFactory.getAndMakeChild(
            p.getCompilationOpContext(), selConf, new ArrayList<Operator<? extends OperatorDesc>>());
        newSEL.setSchema(new RowSchema(newColumns));
        return newSEL;
    }

    /**
     * Returns the column an RS mapping expression reads.
     *
     * <p>For a column reference this is the referenced column. For a constant it is the column the
     * constant was folded from. NOTE(review): that may presumably be null for constants not folded
     * from a column — confirm.
     *
     * @throws SemanticException for any other expression kind
     */
    private static String extractColumnName(ExprNodeDesc expr) throws SemanticException {
        if (expr instanceof ExprNodeColumnDesc) {
            return ((ExprNodeColumnDesc) expr).getColumn();
        }
        if (expr instanceof ExprNodeConstantDesc) {
            return ((ExprNodeConstantDesc) expr).getFoldedFromCol();
        }
        throw new SemanticException("unexpected mapping expression!");
    }

    /**
     * Computes an inverse of {@code rs}'s column-expression map, as a set of RS output column names.
     *
     * <p>Entries whose output column is in the RS schema contribute a mapping (input column to
     * output column). When several such entries read the same input column, only the one visited
     * last in map iteration order is kept. Entries not in the schema only require that their input
     * column is also reachable through some schema entry.
     *
     * @return the sorted output column names of the inverse mapping, or empty in any of these
     *     cases: a schema column has no mapping expression, a mapping expression is neither a
     *     column nor a constant, or a required input column is not reachable
     */
    public static Optional<Set<String>> colMappingInverseKeys(ReduceSinkOperator rs) {
        HashMap<String, String> ret = new HashMap<String, String>();
        Map<String, ExprNodeDesc> exprMap = rs.getColumnExprMap();
        HashSet<String> neededColumns = new HashSet<String>();
        if (!rs.getSchema().getColumnNames().stream().allMatch(exprMap::containsKey)) {
            return Optional.empty();
        }
        try {
            for (Map.Entry<String, ExprNodeDesc> e : exprMap.entrySet()) {
                String columnName = ParallelEdgeFixer.extractColumnName(e.getValue());
                if (rs.getSchema().getColumnInfo(e.getKey()) == null) {
                    // Not an RS output column: only require that its input stays reachable.
                    neededColumns.add(columnName);
                    continue;
                }
                ret.put(columnName, e.getKey());
            }
            neededColumns.removeAll(ret.keySet());
            if (!neededColumns.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of(new TreeSet<String>(ret.values()));
        } catch (SemanticException e) {
            // An expression kind that cannot be inverted means the mapping is not invertible.
            return Optional.empty();
        }
    }

    /**
     * Scoped helper that turns every semijoin branch into a real RS -> TS parent/child link on
     * construction. {@link #close()} removes those links again.
     */
    static class MaterializeSemiJoinEdges
    implements AutoCloseable {
        private final ParseContext pctx;

        public MaterializeSemiJoinEdges(ParseContext pctx) {
            this.pctx = pctx;
            this.addSJEdges();
        }

        /** Links each semijoin RS to its target TS as child and parent. */
        private void addSJEdges() {
            Map<ReduceSinkOperator, SemiJoinBranchInfo> rs2sj = this.pctx.getRsToSemiJoinBranchInfo();
            for (Map.Entry<ReduceSinkOperator, SemiJoinBranchInfo> e : rs2sj.entrySet()) {
                ReduceSinkOperator rs = e.getKey();
                TableScanOperator ts = e.getValue().getTsOp();
                rs.getChildOperators().add(ts);
                ts.getParentOperators().add(rs);
            }
        }

        /**
         * Detaches the materialized links again.
         *
         * <p>For each semijoin branch, this follows single-child links from the original RS down to
         * its TS. The operator directly above the TS (which may be an RS inserted by the fixer) is
         * unlinked from the TS and replaces the original RS as key in the semijoin map and, when
         * present, in the runtime-values map.
         *
         * @throws SemanticException if an operator on that path does not have exactly one child
         */
        private void removeSJEdges() throws SemanticException {
            LinkedHashMap<ReduceSinkOperator, SemiJoinBranchInfo> rs2sj = new LinkedHashMap<ReduceSinkOperator, SemiJoinBranchInfo>();
            for (Map.Entry<ReduceSinkOperator, SemiJoinBranchInfo> e : this.pctx.getRsToSemiJoinBranchInfo().entrySet()) {
                Operator<? extends OperatorDesc> rs = e.getKey();
                SemiJoinBranchInfo sji = e.getValue();
                TableScanOperator ts = sji.getTsOp();
                while (true) {
                    if (rs.getChildOperators().size() != 1) {
                        throw new SemanticException("Unexpected number of children");
                    }
                    Operator<? extends OperatorDesc> child = rs.getChildOperators().get(0);
                    if (child == ts) {
                        break;
                    }
                    rs = child;
                }
                rs.getChildOperators().clear();
                ts.getParentOperators().remove(rs);
                rs2sj.put((ReduceSinkOperator) rs, sji);
                if (!this.pctx.getRsToRuntimeValuesInfoMap().containsKey(e.getKey())) {
                    continue;
                }
                RuntimeValuesInfo rvi = this.pctx.getRsToRuntimeValuesInfoMap().remove(e.getKey());
                this.pctx.getRsToRuntimeValuesInfoMap().put((ReduceSinkOperator) rs, rvi);
            }
            this.pctx.setRsToSemiJoinBranchInfo(rs2sj);
        }

        @Override
        public void close() throws Exception {
            this.removeSJEdges();
        }
    }

    /** Accepts only FLOW, SEMIJOIN and BROADCAST edges when looking up parent clusters. */
    private static class ActualEdgePredicate
    implements OperatorGraph.OperatorEdgePredicate {
        private static final EnumSet<OperatorGraph.EdgeType> ACCEPTABLE_EDGE_TYPES = EnumSet.of(
            OperatorGraph.EdgeType.FLOW, OperatorGraph.EdgeType.SEMIJOIN, OperatorGraph.EdgeType.BROADCAST);

        private ActualEdgePredicate() {
        }

        @Override
        public boolean accept(Operator<?> s, Operator<?> t, OperatorGraph.OpEdge opEdge) {
            return ACCEPTABLE_EDGE_TYPES.contains(opEdge.getEdgeType());
        }
    }

    /**
     * Orders operator pairs by the concatenation of the parent's and the child's {@code toString()}.
     * This gives a deterministic order that does not depend on hash iteration.
     */
    private static class OperatorPairComparator
    implements Comparator<Pair<Operator<?>, Operator<?>>> {
        private OperatorPairComparator() {
        }

        @Override
        public int compare(Pair<Operator<?>, Operator<?>> o1, Pair<Operator<?>, Operator<?>> o2) {
            return this.sig(o1).compareTo(this.sig(o2));
        }

        private String sig(Pair<Operator<?>, Operator<?>> pair) {
            return pair.left.toString() + pair.right.toString();
        }
    }
}

