/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.List;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;

/**
 * Logical optimizer rule that simplifies chains of range comparisons ({@code >}, {@code >=}, {@code <}, {@code <=})
 * which share the same (semantically equal) left-hand expression and have a foldable right-hand side.
 * <ul>
 *   <li>In a conjunction the most restrictive bound is kept: {@code a > 3 AND a > 2 -> a > 3}.</li>
 *   <li>In a disjunction the least restrictive bound is kept: {@code a > 2 OR a > 3 -> a > 2}.</li>
 *   <li>In a conjunction a {@code !=} made redundant by a range is dropped, tightening an inclusive bound when the
 *       excluded value sits exactly on it: {@code a != 2 AND a >= 2 -> a > 2}.</li>
 * </ul>
 * When nothing can be combined, the original expression instance is returned unchanged.
 */
public final class CombineBinaryComparisons extends OptimizerRules.OptimizerExpressionRule<BinaryLogic> {

    public CombineBinaryComparisons() {
        super(OptimizerRules.TransformDirection.DOWN);
    }

    /**
     * Dispatches to the conjunction or disjunction combiner.
     *
     * @param e   the logical expression being visited
     * @param ctx optimizer context, used for its fold context
     * @return the simplified expression, or {@code e} itself if it is neither {@link And} nor {@link Or}
     *         or nothing could be combined
     */
    @Override
    public Expression rule(BinaryLogic e, LogicalOptimizerContext ctx) {
        if (e instanceof And) {
            return combine(ctx.foldCtx(), (And) e);
        }
        if (e instanceof Or) {
            return combine(ctx.foldCtx(), (Or) e);
        }
        return e;
    }

    /**
     * Combines the range comparisons of a conjunction and drops {@code !=} terms implied by those ranges.
     * When a change is made, the rebuilt conjunction lists the non-range terms first, then the range comparisons.
     */
    private static Expression combine(FoldContext ctx, And and) {
        List<BinaryComparison> bcs = new ArrayList<>();
        List<Expression> exps = new ArrayList<>();
        boolean changed = false;
        List<Expression> andExps = Predicates.splitAnd(and);
        // Stable sort that moves NotEquals to the end (keeping every other relative order), so each NotEquals
        // is checked against all the range comparisons collected before it.
        andExps.sort((o1, o2) -> Boolean.compare(o1 instanceof NotEquals, o2 instanceof NotEquals));
        for (Expression ex : andExps) {
            if (ex instanceof BinaryComparison && (ex instanceof Equals || ex instanceof NotEquals) == false) {
                BinaryComparison bc = (BinaryComparison) ex;
                if (bc.right().foldable() && findExistingComparison(ctx, bc, bcs, true)) {
                    // bc was merged into (or is subsumed by) an already collected comparison
                    changed = true;
                } else {
                    bcs.add(bc);
                }
            } else if (ex instanceof NotEquals) {
                NotEquals neq = (NotEquals) ex;
                if (neq.right().foldable() && notEqualsIsRemovableFromConjunction(ctx, neq, bcs)) {
                    // the inequality is either superfluous or has been folded into an updated range
                    changed = true;
                } else {
                    exps.add(ex);
                }
            } else {
                exps.add(ex);
            }
        }
        return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, bcs)) : and;
    }

    /**
     * Combines the range comparisons of a disjunction. When a change is made, the rebuilt disjunction lists the
     * non-comparison terms first, then the binary comparisons.
     */
    private static Expression combine(FoldContext ctx, Or or) {
        List<BinaryComparison> bcs = new ArrayList<>();
        List<Expression> exps = new ArrayList<>();
        boolean changed = false;
        for (Expression ex : Predicates.splitOr(or)) {
            if (ex instanceof BinaryComparison) {
                BinaryComparison bc = (BinaryComparison) ex;
                if (bc.right().foldable() && findExistingComparison(ctx, bc, bcs, false)) {
                    changed = true;
                } else {
                    bcs.add(bc);
                }
            } else {
                exps.add(ex);
            }
        }
        return changed ? Predicates.combineOr(CollectionUtils.combine(exps, bcs)) : or;
    }

    /**
     * Tries to merge {@code main} with the first collected comparison that bounds the same expression from the same
     * side (both lower bounds {@code >}/{@code >=}, or both upper bounds {@code <}/{@code <=}) and has a foldable
     * right-hand side. If {@code main} is the better bound (tighter for AND, looser for OR), it replaces that
     * comparison in place; otherwise the existing one is kept.
     * <p>
     * Examples, AND: {@code a > 3 AND a > 2 -> a > 3}, {@code a > 2 AND a >= 2 -> a > 2},
     * {@code a < 2 AND a < 3 -> a < 2}, {@code a < 2 AND a <= 2 -> a < 2}.
     * OR: {@code a > 2 OR a > 3 -> a > 2}, {@code a >= 2 OR a > 2 -> a >= 2},
     * {@code a < 2 OR a < 3 -> a < 3}, {@code a <= 2 OR a < 2 -> a <= 2}.
     *
     * @param ctx         fold context used to evaluate the right-hand sides
     * @param main        the comparison being added; its right-hand side must be foldable
     * @param bcs         comparisons collected so far; may be modified in place
     * @param conjunctive {@code true} when combining an AND, {@code false} for an OR
     * @return {@code true} if {@code main} was absorbed and must not be added to {@code bcs} by the caller;
     *         {@code false} if no candidate was found, or the first candidate's value is not comparable
     */
    private static boolean findExistingComparison(FoldContext ctx, BinaryComparison main, List<BinaryComparison> bcs, boolean conjunctive) {
        Object value = main.right().fold(ctx);
        for (int i = 0; i < bcs.size(); i++) {
            BinaryComparison other = bcs.get(i);
            if (other.right().foldable() == false) {
                continue;
            }
            boolean bothLower = isLowerBound(main) && isLowerBound(other);
            boolean bothUpper = isUpperBound(main) && isUpperBound(other);
            if ((bothLower || bothUpper) == false || main.left().semanticEquals(other.left()) == false) {
                continue;
            }
            Integer compare = BinaryComparison.compare(value, other.right().fold(ctx));
            if (compare == null) {
                // not comparable: stop searching, the caller keeps main as a separate term
                return false;
            }
            // Positive when main's constant is the more restrictive bound: a larger value for a lower bound,
            // a smaller value for an upper bound. signum() first, so negating cannot overflow.
            int sign = Integer.signum(compare);
            int tightness = bothLower ? sign : -sign;
            boolean mainStrict = isStrict(main);
            boolean otherStrict = isStrict(other);
            // On equal constants a strict bound is tighter than an inclusive one.
            boolean replace = conjunctive
                ? tightness > 0 || (tightness == 0 && mainStrict && otherStrict == false)
                : tightness < 0 || (tightness == 0 && mainStrict == false && otherStrict);
            if (replace) {
                bcs.set(i, main);
            }
            return true;
        }
        return false;
    }

    /**
     * Checks whether a conjunctive {@code !=} is implied by, or can be folded into, a range comparison already in
     * {@code bcs} on the same expression. When the excluded value sits exactly on an inclusive bound, that bound is
     * replaced in place by its strict counterpart.
     * <p>
     * Upper bounds: {@code a != 2 AND a < 3 -> nop}, {@code a != 2 AND a <= 2 -> a < 2},
     * {@code a != 2 AND a < 1 -> a < 1}.
     * Lower bounds: {@code a != 2 AND a > 1 -> nop}, {@code a != 2 AND a >= 2 -> a > 2},
     * {@code a != 2 AND a > 3 -> a > 3}.
     *
     * @return {@code true} if the {@code !=} can be dropped from the conjunction
     */
    private static boolean notEqualsIsRemovableFromConjunction(FoldContext ctx, NotEquals notEquals, List<BinaryComparison> bcs) {
        Object neqVal = notEquals.right().fold(ctx);
        for (int i = 0; i < bcs.size(); i++) {
            BinaryComparison bc = bcs.get(i);
            if (notEquals.left().semanticEquals(bc.left()) == false) {
                continue;
            }
            boolean upper = isUpperBound(bc);
            if (upper == false && isLowerBound(bc) == false) {
                continue; // not a range comparison
            }
            Integer comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold(ctx)) : null;
            if (comp == null) {
                continue; // not evaluable or not comparable
            }
            if (upper) {
                if (comp < 0) {
                    continue; // excluded value lies inside the range: both terms are needed
                }
                if (comp == 0 && bc instanceof LessThanOrEqual) {
                    bcs.set(i, new LessThan(bc.source(), bc.left(), bc.right(), bc.zoneId()));
                }
            } else {
                if (comp > 0) {
                    continue; // excluded value lies inside the range: both terms are needed
                }
                if (comp == 0 && bc instanceof GreaterThanOrEqual) {
                    bcs.set(i, new GreaterThan(bc.source(), bc.left(), bc.right(), bc.zoneId()));
                }
            }
            return true;
        }
        return false;
    }

    /** Whether {@code bc} is a lower bound ({@code >} or {@code >=}). */
    private static boolean isLowerBound(BinaryComparison bc) {
        return bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual;
    }

    /** Whether {@code bc} is an upper bound ({@code <} or {@code <=}). */
    private static boolean isUpperBound(BinaryComparison bc) {
        return bc instanceof LessThan || bc instanceof LessThanOrEqual;
    }

    /** Whether {@code bc} is a strict (exclusive) range comparison ({@code >} or {@code <}). */
    private static boolean isStrict(BinaryComparison bc) {
        return bc instanceof GreaterThan || bc instanceof LessThan;
    }
}

