/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules.views;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;

/**
 * Base rule that merges rows produced by a subclass-supplied plan into the rows of an existing
 * aggregate relation. The existing rows are the second input (input 1) of the matched {@link Union};
 * presumably these are the current materialized-view contents, given the package name. TODO confirm
 * against the operand trees of the subclasses.
 *
 * <p>The rewrite performs these steps:
 * <ol>
 *   <li>Append a literal {@code TRUE} "flag" column to the existing rows.</li>
 *   <li>RIGHT-join them to the subclass plan on the group keys, using {@code IS NOT DISTINCT FROM}
 *       so that {@code NULL} keys match.</li>
 *   <li>Project the group keys from the right side. Project each aggregate value as the
 *       left-side value, the right-side value, or their combination (see
 *       {@link #createAggregateNode}). The choice depends on which side is {@code NULL}.</li>
 *   <li>Filter with a subclass-supplied condition that can inspect the flag column. The flag is
 *       {@code NULL} for right rows that had no matching left row.</li>
 * </ol>
 *
 * @param <T> the plan type produced by {@link #createJoinRightInput}
 */
public abstract class HiveAggregateIncrementalRewritingRuleBase<
        T extends HiveAggregateIncrementalRewritingRuleBase.IncrementalComputePlan>
extends RelOptRule {
    /** Ordinal of the {@link Aggregate} within the matched operand tree, as passed to {@code call.rel}. */
    private final int aggregateIndex;

    /**
     * @param operand root operand of the rule; ordinal 1 must match a {@link Union}
     * @param relBuilderFactory factory for the builder used to construct the rewritten plan
     * @param description rule description
     * @param aggregateIndex ordinal of the {@link Aggregate} within the matched operand tree
     */
    protected HiveAggregateIncrementalRewritingRuleBase(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description, int aggregateIndex) {
        super(operand, relBuilderFactory, description);
        this.aggregateIndex = aggregateIndex;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate agg = (Aggregate) call.rel(this.aggregateIndex);
        Union union = (Union) call.rel(1);
        RelBuilder relBuilder = call.builder();
        RexBuilder rexBuilder = relBuilder.getRexBuilder();

        RelNode existingRows = union.getInput(1);
        T plan = this.createJoinRightInput(call);
        if (plan == null) {
            // The subclass declined to rewrite this match.
            return;
        }
        RelNode rightInput = plan.rightInput;
        RelNode leftInput = appendFlagColumn(existingRows, relBuilder, rexBuilder);

        int groupCount = agg.getGroupCount();
        List<AggregateCall> aggCalls = agg.getAggCallList();
        int totalCount = groupCount + aggCalls.size();
        // Index of the first right-input column in the join row. This assumes the left input holds
        // exactly groupCount + aggCalls.size() columns, followed by the appended flag column.
        // NOTE(review): assumption inherited from the original code; not checked here.
        int rightOffset = totalCount + 1;
        List<RelDataTypeField> leftFields = leftInput.getRowType().getFieldList();
        List<RelDataTypeField> rightFields = rightInput.getRowType().getFieldList();

        List<RexNode> projExprs = new ArrayList<>();
        List<RexNode> joinConjs = new ArrayList<>();

        // Group keys: join null-safely on each key and emit the key from the right side.
        // After a RIGHT join, the right-side key is never null-extended.
        for (int pos = 0; pos < groupCount; ++pos) {
            RexInputRef leftRef = rexBuilder.makeInputRef(leftFields.get(pos).getType(), pos);
            RexInputRef rightRef = rexBuilder.makeInputRef(rightFields.get(pos).getType(), rightOffset + pos);
            projExprs.add(rightRef);
            joinConjs.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, ImmutableList.of(leftRef, rightRef)));
        }
        RexNode joinCond = RexUtil.composeConjunction(rexBuilder, joinConjs);
        RelNode join = relBuilder.push(leftInput).push(rightInput).join(JoinRelType.RIGHT, joinCond).build();

        // Aggregate values: CASE WHEN left IS NULL THEN right WHEN right IS NULL THEN left
        // ELSE combine(left, right) END. The result is cast to the type of the matching output
        // column of the rule's root relational expression (call.rel(0)).
        List<RelDataTypeField> outputFields = call.rel(0).getRowType().getFieldList();
        for (int i = 0; i < aggCalls.size(); ++i) {
            int pos = groupCount + i;
            RexInputRef leftRef = rexBuilder.makeInputRef(leftFields.get(pos).getType(), pos);
            RexInputRef rightRef = rexBuilder.makeInputRef(rightFields.get(pos).getType(), rightOffset + pos);
            AggregateCall aggregateCall = aggCalls.get(i);
            RexNode combined = this.createAggregateNode(aggregateCall.getAggregation(), leftRef, rightRef, rexBuilder);
            RexNode leftNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, leftRef);
            RexNode rightNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, rightRef);
            RexNode caseExpression = rexBuilder.makeCall(SqlStdOperatorTable.CASE, leftNull, rightRef, rightNull, leftRef, combined);
            // projExprs.size() == pos here: groupCount keys plus i aggregate expressions so far.
            projExprs.add(rexBuilder.makeCast(outputFields.get(pos).getType(), caseExpression));
        }

        // The flag is the last left-side column. It is TRUE where a left row matched and NULL where
        // the RIGHT join null-extended the left side.
        int flagIndex = leftInput.getRowType().getFieldCount() - 1;
        RexInputRef flagNode = rexBuilder.makeInputRef(join.getRowType().getFieldList().get(flagIndex).getType(), flagIndex);
        RelNode newNode = relBuilder.push(join)
            .filter(this.createFilterCondition(plan, flagNode, projExprs, relBuilder))
            .project(projExprs)
            .build();
        call.transformTo(newNode);
    }

    /**
     * Projects every column of {@code input} unchanged and appends a literal {@code TRUE} column.
     *
     * @return the projected relational expression
     */
    private static RelNode appendFlagColumn(RelNode input, RelBuilder relBuilder, RexBuilder rexBuilder) {
        List<RelDataTypeField> fields = input.getRowType().getFieldList();
        List<RexNode> cols = new ArrayList<>(fields.size() + 1);
        for (int i = 0; i < fields.size(); ++i) {
            cols.add(rexBuilder.makeInputRef(fields.get(i).getType(), i));
        }
        cols.add(rexBuilder.makeLiteral(true));
        return relBuilder.push(input).project(cols).build();
    }

    /**
     * Builds the plan whose rows are merged into the existing rows. Its relational expression
     * becomes the right input of the RIGHT join.
     *
     * @param call the current rule match
     * @return the plan, or {@code null} to abandon the rewrite for this match
     */
    protected abstract T createJoinRightInput(RelOptRuleCall call);

    /**
     * Combines a left-side and a right-side aggregate value when both are non-null.
     * {@code SUM}, {@code SUM0} and {@code COUNT} are combined by addition.
     *
     * @throws AssertionError for any other aggregate kind
     */
    protected RexNode createAggregateNode(SqlAggFunction aggCall, RexNode leftRef, RexNode rightRef, RexBuilder rexBuilder) {
        switch (aggCall.getKind()) {
            case SUM:
            case SUM0:
            case COUNT:
                return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, ImmutableList.of(rightRef, leftRef));
            default:
                throw new AssertionError("Found an aggregation that could not be recognized: " + String.valueOf(aggCall));
        }
    }

    /**
     * Builds the condition applied to the join output before the final projection.
     *
     * @param plan the plan returned by {@link #createJoinRightInput}
     * @param flagNode reference to the flag column; {@code NULL} where no left row matched
     * @param projExprs the final projection expressions, in output order
     * @param relBuilder the builder in use for this rewrite
     * @return the filter condition
     */
    protected abstract RexNode createFilterCondition(T plan, RexNode flagNode, List<RexNode> projExprs, RelBuilder relBuilder);

    /** Holder for the right-hand input of the merge join; subclasses may carry extra state. */
    protected static class IncrementalComputePlan {
        protected final RelNode rightInput;

        public IncrementalComputePlan(RelNode rightInput) {
            this.rightInput = rightInput;
        }
    }
}

