/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.analysis;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensearch.sql.analysis.AnalysisContext;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ExpressionNodeVisitor;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.conditional.cases.CaseClause;
import org.opensearch.sql.expression.conditional.cases.WhenClause;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor;
import org.opensearch.sql.planner.logical.LogicalWindow;

/**
 * Expression visitor that rewrites an analyzed expression so that any
 * sub-expression registered from a logical plan is replaced by a
 * {@link ReferenceExpression}.
 *
 * <p>The replacement table is built once, in the constructor, by walking the
 * supplied {@link LogicalPlan} with {@link ExpressionMapBuilder}. Entries are
 * created for aggregators and group-by expressions of aggregation nodes and for
 * the window function of window nodes.
 */
public class ExpressionReferenceOptimizer
        extends ExpressionNodeVisitor<Expression, AnalysisContext> {
    /** Used to re-compile a function by name after its arguments have been rewritten. */
    private final BuiltinFunctionRepository repository;

    /**
     * Maps an expression found in the plan to the reference that replaces it.
     * Only {@link ExpressionMapBuilder} writes to it, and every stored value is a
     * newly created {@link ReferenceExpression}, so values are never {@code null}.
     */
    private final Map<Expression, Expression> expressionMap = new HashMap<>();

    /**
     * Creates the optimizer and immediately collects the replaceable expressions
     * from {@code logicalPlan}.
     *
     * @param repository function repository used by {@link #visitFunction}
     * @param logicalPlan plan that is walked to populate the replacement table
     */
    public ExpressionReferenceOptimizer(BuiltinFunctionRepository repository, LogicalPlan logicalPlan) {
        this.repository = repository;
        logicalPlan.accept(new ExpressionMapBuilder(), null);
    }

    /**
     * Entry point: dispatches {@code analyzed} to this visitor.
     *
     * @param analyzed expression to rewrite
     * @param context analysis context handed through to every visit method
     * @return the result of visiting {@code analyzed}
     */
    public Expression optimize(Expression analyzed, AnalysisContext context) {
        return analyzed.accept(this, context);
    }

    /** Fallback for expression kinds without a dedicated handler: returned unchanged. */
    @Override
    public Expression visitNode(Expression node, AnalysisContext context) {
        return node;
    }

    /**
     * Returns the registered reference for {@code node} when there is one.
     * Otherwise rewrites each argument and re-compiles the function under the same
     * name through the repository.
     */
    @Override
    public Expression visitFunction(FunctionExpression node, AnalysisContext context) {
        Expression reference = expressionMap.get(node);
        if (reference != null) {
            return reference;
        }

        List<Expression> rewrittenArgs = node.getArguments().stream()
                .map(argument -> argument.accept(this, context))
                .collect(Collectors.toList());
        Object compiled = repository.compile(
                context.getFunctionProperties(), node.getFunctionName(), rewrittenArgs);
        Expression optimized = (Expression) compiled;

        // Carry the score-tracking flag over from the original function to the
        // re-compiled one.
        if (optimized instanceof OpenSearchFunctions.OpenSearchFunction) {
            OpenSearchFunctions.OpenSearchFunction target =
                    (OpenSearchFunctions.OpenSearchFunction) optimized;
            OpenSearchFunctions.OpenSearchFunction source =
                    (OpenSearchFunctions.OpenSearchFunction) node;
            target.setScoreTracked(source.isScoreTracked());
        }
        return optimized;
    }

    /** Returns the registered reference for the aggregator, or the aggregator itself if none. */
    @Override
    public Expression visitAggregator(Aggregator<?> node, AnalysisContext context) {
        Expression reference = expressionMap.get(node);
        if (reference != null) {
            return reference;
        }
        return node;
    }

    /**
     * Returns the registered reference for the named expression when there is one;
     * otherwise continues with the expression it delegates to.
     */
    @Override
    public Expression visitNamed(NamedExpression node, AnalysisContext context) {
        Expression reference = expressionMap.get(node);
        if (reference != null) {
            return reference;
        }
        return node.getDelegated().accept(this, context);
    }

    /**
     * Returns the registered reference for the whole CASE when there is one.
     * Otherwise builds a new {@link CaseClause} from the rewritten WHEN clauses and
     * the rewritten default result, which stays {@code null} when absent.
     */
    @Override
    public Expression visitCase(CaseClause node, AnalysisContext context) {
        Expression reference = expressionMap.get(node);
        if (reference != null) {
            return reference;
        }

        List<WhenClause> rewrittenWhens = node.getWhenClauses().stream()
                .map(when -> (WhenClause) when.accept(this, context))
                .collect(Collectors.toList());
        Expression rewrittenDefault = node.getDefaultResult() == null
                ? null
                : node.getDefaultResult().accept(this, context);
        return new CaseClause(rewrittenWhens, rewrittenDefault);
    }

    /** Builds a new {@link WhenClause} from the rewritten condition and rewritten result. */
    @Override
    public Expression visitWhen(WhenClause node, AnalysisContext context) {
        Expression rewrittenCondition = node.getCondition().accept(this, context);
        Expression rewrittenResult = node.getResult().accept(this, context);
        return new WhenClause(rewrittenCondition, rewrittenResult);
    }

    /**
     * Plan visitor that fills the enclosing optimizer's replacement table.
     * It neither uses a context nor produces a result.
     */
    class ExpressionMapBuilder
            extends LogicalPlanNodeVisitor<Void, Void> {

        /** Default handling: visit every child of the plan node. */
        @Override
        public Void visitNode(LogicalPlan plan, Void context) {
            for (LogicalPlan child : plan.getChild()) {
                child.accept(this, context);
            }
            return null;
        }

        /**
         * Registers a reference for every aggregator and every group-by expression,
         * keyed by the delegated expression. Aggregators are referenced by name,
         * group-by expressions by name or alias. The children of the aggregation
         * are not visited.
         */
        @Override
        public Void visitAggregation(LogicalAggregation plan, Void context) {
            plan.getAggregatorList().forEach(aggregator -> {
                Expression delegated = aggregator.getDelegated();
                expressionMap.put(
                        delegated,
                        new ReferenceExpression(aggregator.getName(), aggregator.type()));
            });
            plan.getGroupByList().forEach(groupKey -> {
                Expression delegated = groupKey.getDelegated();
                expressionMap.put(
                        delegated,
                        new ReferenceExpression(groupKey.getNameOrAlias(), groupKey.type()));
            });
            return null;
        }

        /**
         * Registers a reference for the window function, keyed by the named
         * expression itself, then continues into the window node's children.
         */
        @Override
        public Void visitWindow(LogicalWindow plan, Void context) {
            NamedExpression windowFunction = plan.getWindowFunction();
            ReferenceExpression reference =
                    new ReferenceExpression(windowFunction.getName(), windowFunction.type());
            expressionMap.put(windowFunction, reference);
            return visitNode(plan, context);
        }
    }
}

