/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;

import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction;
import org.apache.iotdb.db.queryengine.common.QueryId;
import org.apache.iotdb.db.queryengine.common.SessionInfo;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTableScanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.DeviceTableScanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanOptimizer;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.Util;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.FunctionCall;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
import org.apache.tsfile.utils.Pair;

public class PushAggregationIntoTableScan
implements PlanOptimizer {
    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        if (!context.getAnalysis().isQuery() || !context.getAnalysis().containsAggregationQuery()) {
            return plan;
        }
        return plan.accept(new Rewriter(), new Context(context.getQueryContext().getQueryId(), context.getMetadata(), context.sessionInfo(), context.getSymbolAllocator()));
    }

    private static class Rewriter
    extends PlanVisitor<PlanNode, Context> {
        private Rewriter() {
        }

        @Override
        public PlanNode visitPlan(PlanNode node, Context context) {
            PlanNode newNode = node.clone();
            for (PlanNode child : node.getChildren()) {
                newNode.addChild(child.accept(this, context));
            }
            return newNode;
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, Context context) {
            PlanNode child = node.getChild().accept(this, context);
            node = (AggregationNode)node.clone();
            node.setChild(child);
            DeviceTableScanNode tableScanNode = null;
            ProjectNode projectNode = null;
            if (child instanceof DeviceTableScanNode) {
                tableScanNode = (DeviceTableScanNode)child;
            }
            if (child instanceof ProjectNode && (projectNode = (ProjectNode)child).getChild() instanceof DeviceTableScanNode) {
                tableScanNode = (DeviceTableScanNode)projectNode.getChild();
            }
            if (tableScanNode == null || tableScanNode instanceof AggregationTableScanNode) {
                return node;
            }
            if (tableScanNode.containsNonAlignedDevice()) {
                return node;
            }
            PushDownLevel pushDownLevel = this.calculatePushDownLevel(node.getAggregations().values(), node.getGroupingKeys(), projectNode, tableScanNode, context.session, context.metadata);
            if (pushDownLevel == PushDownLevel.NOOP) {
                return node;
            }
            if (pushDownLevel == PushDownLevel.PARTIAL) {
                Pair<AggregationNode, AggregationNode> result = Util.split(node, context.symbolAllocator, context.queryId);
                AggregationTableScanNode aggregationTableScanNode = AggregationTableScanNode.combineAggregationAndTableScan(context.queryId.genPlanNodeId(), (AggregationNode)result.right, projectNode, tableScanNode);
                ((AggregationNode)result.left).setChild(aggregationTableScanNode);
                return (PlanNode)result.left;
            }
            return AggregationTableScanNode.combineAggregationAndTableScan(context.queryId.genPlanNodeId(), node, projectNode, tableScanNode);
        }

        private PushDownLevel calculatePushDownLevel(Collection<AggregationNode.Aggregation> values, List<Symbol> groupingKeys, ProjectNode projectNode, DeviceTableScanNode tableScanNode, SessionInfo session, Metadata metadata) {
            boolean singleDeviceEntry;
            boolean hasProject = projectNode != null;
            Map<Symbol, Expression> assignments = hasProject ? projectNode.getAssignments().getMap() : null;
            for (AggregationNode.Aggregation aggregation : values) {
                if (aggregation.isDistinct()) {
                    return PushDownLevel.NOOP;
                }
                if (!hasProject || !aggregation.getArguments().stream().anyMatch(argument -> !(assignments.get(Symbol.from(argument)) instanceof SymbolReference))) continue;
                return PushDownLevel.NOOP;
            }
            boolean bl = singleDeviceEntry = tableScanNode.getDeviceEntries().size() < 2;
            if (groupingKeys.isEmpty()) {
                if (singleDeviceEntry) {
                    return PushDownLevel.COMPLETE;
                }
                return PushDownLevel.PARTIAL;
            }
            ArrayList dateBinFunctionsOfTime = new ArrayList();
            if (groupingKeys.stream().anyMatch(groupingKey -> hasProject && !(assignments.get(groupingKey) instanceof SymbolReference) && !this.isDateBinFunctionOfTime((Expression)assignments.get(groupingKey), dateBinFunctionsOfTime, tableScanNode) || tableScanNode.isMeasurementOrTimeColumn((Symbol)groupingKey)) || dateBinFunctionsOfTime.size() > 1) {
                return PushDownLevel.NOOP;
            }
            if (singleDeviceEntry || ImmutableSet.copyOf(groupingKeys).containsAll(tableScanNode.getIdColumnsInTableStore(metadata, session))) {
                return PushDownLevel.COMPLETE;
            }
            return PushDownLevel.PARTIAL;
        }

        private boolean isDateBinFunctionOfTime(Expression expression, List<FunctionCall> dateBinFunctionsOfTime, DeviceTableScanNode tableScanNode) {
            if (expression instanceof FunctionCall) {
                FunctionCall function = (FunctionCall)expression;
                if (TableBuiltinScalarFunction.DATE_BIN.getFunctionName().equals(function.getName().toString()) && function.getArguments().get(2) instanceof SymbolReference && tableScanNode.isTimeColumn(Symbol.from(function.getArguments().get(2)))) {
                    dateBinFunctionsOfTime.add(function);
                    return true;
                }
            }
            return false;
        }
    }

    private static class Context {
        private final QueryId queryId;
        private final Metadata metadata;
        private final SessionInfo session;
        private final SymbolAllocator symbolAllocator;

        public Context(QueryId queryId, Metadata metadata, SessionInfo session, SymbolAllocator symbolAllocator) {
            this.queryId = queryId;
            this.metadata = metadata;
            this.session = session;
            this.symbolAllocator = symbolAllocator;
        }
    }

    private static enum PushDownLevel {
        NOOP,
        PARTIAL,
        COMPLETE;

    }
}

