/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.Range;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;

/**
 * Optimizer rule that propagates equality constraints of the form {@code field == <foldable>} into the other
 * comparisons on the same field inside a conjunction ({@link And}) or a disjunction ({@link Or}).
 * <p>
 * Examples of rewrites performed by this rule:
 * <ul>
 *   <li>{@code a = 1 AND a = 2} becomes {@code FALSE}</li>
 *   <li>{@code a = 2 AND a < 2} becomes {@code FALSE}; {@code a = 2 AND a < 3} becomes {@code a = 2}</li>
 *   <li>{@code a = 2 AND a != 3} becomes {@code a = 2}; {@code a = 2 AND a != 2} becomes {@code FALSE}</li>
 *   <li>{@code a = 2 OR a != 2} becomes {@code TRUE}; {@code a = 2 OR a != 5} becomes {@code a != 5}</li>
 *   <li>{@code a = 2 OR a > 2} becomes {@code a >= 2}; {@code a = 2 OR 2 < a < 5} becomes {@code 2 <= a < 5}</li>
 * </ul>
 * The rule is registered with {@link OptimizerRules.TransformDirection#DOWN}.
 */
public final class PropagateEquals
extends OptimizerRules.OptimizerExpressionRule<BinaryLogic> {
    public PropagateEquals() {
        super(OptimizerRules.TransformDirection.DOWN);
    }

    /**
     * Dispatches to the conjunction or disjunction variant; any other {@link BinaryLogic} is returned unchanged.
     *
     * @param e   the logical expression to simplify
     * @param ctx optimizer context, used for its fold context when evaluating constant operands
     * @return the simplified expression, or {@code e} itself when nothing could be propagated
     */
    @Override
    public Expression rule(BinaryLogic e, LogicalOptimizerContext ctx) {
        if (e instanceof And) {
            return propagate((And) e, ctx);
        }
        if (e instanceof Or) {
            return propagate((Or) e, ctx);
        }
        return e;
    }

    /**
     * Simplifies a conjunction by checking every {@code field == constant} against the ranges, not-equals and
     * inequalities on the same field. A contradiction collapses the whole conjunction to {@code FALSE}; a
     * constraint already implied by the equality is dropped.
     */
    private static Expression propagate(And and, LogicalOptimizerContext ctx) {
        List<Range> ranges = new ArrayList<>();
        // Only equalities, not-equalities and inequalities with a foldable right side are tracked separately.
        List<Equals> equals = new ArrayList<>();
        List<NotEquals> notEquals = new ArrayList<>();
        List<BinaryComparison> inequalities = new ArrayList<>();
        // Catch-all for every other conjunct: it receives arbitrary predicate kinds, so it must be typed as
        // Expression (typing it as Equals and casting into it throws ClassCastException).
        List<Expression> exps = new ArrayList<>();
        boolean changed = false;

        for (Expression ex : Predicates.splitAnd(and)) {
            if (ex instanceof Range) {
                ranges.add((Range) ex);
            } else if (ex instanceof Equals) {
                Equals otherEq = (Equals) ex;
                // Equalities on date/time fields are never propagated; they are kept verbatim.
                // NOTE(review): presumably because such an equality may not be a point comparison - confirm.
                if (otherEq.right().foldable() && DataType.isDateTime(otherEq.left().dataType()) == false) {
                    for (Equals eq : equals) {
                        if (otherEq.left().semanticEquals(eq.left())) {
                            Integer comp = BinaryComparison.compare(eq.right().fold(ctx.foldCtx()), otherEq.right().fold(ctx.foldCtx()));
                            // a field cannot equal two different values at the same time
                            if (comp != null && comp != 0) {
                                return falseLiteral(and);
                            }
                        }
                    }
                    equals.add(otherEq);
                } else {
                    exps.add(otherEq);
                }
            } else if (ex instanceof GreaterThan
                || ex instanceof GreaterThanOrEqual
                || ex instanceof LessThan
                || ex instanceof LessThanOrEqual) {
                BinaryComparison bc = (BinaryComparison) ex;
                if (bc.right().foldable()) {
                    inequalities.add(bc);
                } else {
                    exps.add(ex);
                }
            } else if (ex instanceof NotEquals) {
                NotEquals otherNotEq = (NotEquals) ex;
                if (otherNotEq.right().foldable()) {
                    notEquals.add(otherNotEq);
                } else {
                    exps.add(ex);
                }
            } else {
                exps.add(ex);
            }
        }

        for (Equals eq : equals) {
            Object eqValue = eq.right().fold(ctx.foldCtx());

            // Equals AND Range on the same field: FALSE if the value lies outside the range, otherwise the range
            // is redundant and dropped.
            // NOTE(review): the range is dropped even when its bounds are not foldable or not comparable with
            // eqValue (compare returned null); verify Range bounds are always foldable constants here.
            for (Iterator<Range> it = ranges.iterator(); it.hasNext();) {
                Range range = it.next();
                if (range.value().semanticEquals(eq.left()) == false) {
                    continue;
                }
                if (range.lower().foldable()) {
                    Integer comp = BinaryComparison.compare(range.lower().fold(ctx.foldCtx()), eqValue);
                    // value below the lower bound, or sitting on an excluded lower bound
                    if (comp != null && (comp > 0 || (comp == 0 && range.includeLower() == false))) {
                        return falseLiteral(and);
                    }
                }
                if (range.upper().foldable()) {
                    Integer comp = BinaryComparison.compare(range.upper().fold(ctx.foldCtx()), eqValue);
                    // value above the upper bound, or sitting on an excluded upper bound
                    if (comp != null && (comp < 0 || (comp == 0 && range.includeUpper() == false))) {
                        return falseLiteral(and);
                    }
                }
                it.remove();
                changed = true;
            }

            // Equals AND NotEquals on the same field: a = 1 AND a != 1 -> FALSE; a = 1 AND a != 2 -> a = 1
            for (Iterator<NotEquals> it = notEquals.iterator(); it.hasNext();) {
                NotEquals neq = it.next();
                if (eq.left().semanticEquals(neq.left()) == false) {
                    continue;
                }
                Integer comp = BinaryComparison.compare(eqValue, neq.right().fold(ctx.foldCtx()));
                if (comp == null) {
                    continue;
                }
                if (comp == 0) {
                    return falseLiteral(and);
                }
                it.remove();
                changed = true;
            }

            // Equals AND inequality on the same field: FALSE when contradictory, otherwise the inequality is implied
            for (Iterator<BinaryComparison> it = inequalities.iterator(); it.hasNext();) {
                BinaryComparison bc = it.next();
                if (eq.left().semanticEquals(bc.left()) == false) {
                    continue;
                }
                Integer comp = BinaryComparison.compare(eqValue, bc.right().fold(ctx.foldCtx()));
                if (comp == null) {
                    continue;
                }
                boolean contradiction;
                if (bc instanceof LessThan || bc instanceof LessThanOrEqual) {
                    // a = 2 AND a < 2, or a = 2 AND a </<= 1
                    contradiction = comp > 0 || (comp == 0 && bc instanceof LessThan);
                } else if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) {
                    // a = 2 AND a > 2, or a = 2 AND a >/>= 3
                    contradiction = comp < 0 || (comp == 0 && bc instanceof GreaterThan);
                } else {
                    contradiction = false;
                }
                if (contradiction) {
                    return falseLiteral(and);
                }
                it.remove();
                changed = true;
            }
        }

        return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : and;
    }

    /**
     * Simplifies a disjunction by folding each {@code field == constant} into a not-equals, range or inequality
     * on the same field that already covers it (widening an exclusive bound to inclusive where needed).
     * {@code a = c OR a != c} collapses to {@code TRUE}.
     */
    private static Expression propagate(Or or, LogicalOptimizerContext ctx) {
        List<Expression> exps = new ArrayList<>();
        List<Equals> equals = new ArrayList<>(); // Equals with a foldable right term
        List<NotEquals> notEquals = new ArrayList<>(); // NotEquals with a foldable right term
        List<Range> ranges = new ArrayList<>();
        List<BinaryComparison> inequalities = new ArrayList<>(); // other comparisons with a foldable right term

        for (Expression ex : Predicates.splitOr(or)) {
            if (ex instanceof Equals) {
                Equals eq = (Equals) ex;
                if (eq.right().foldable()) {
                    equals.add(eq);
                } else {
                    exps.add(ex);
                }
            } else if (ex instanceof NotEquals) {
                NotEquals neq = (NotEquals) ex;
                if (neq.right().foldable()) {
                    notEquals.add(neq);
                } else {
                    exps.add(ex);
                }
            } else if (ex instanceof Range) {
                ranges.add((Range) ex);
            } else if (ex instanceof BinaryComparison) {
                BinaryComparison bc = (BinaryComparison) ex;
                if (bc.right().foldable()) {
                    inequalities.add(bc);
                } else {
                    exps.add(ex);
                }
            } else {
                exps.add(ex);
            }
        }

        boolean updated = false;
        for (Iterator<Equals> it = equals.iterator(); it.hasNext();) {
            Equals eq = it.next();
            Object eqValue = eq.right().fold(ctx.foldCtx());
            boolean removeEquals = false;

            // Equals OR NotEquals: a = 2 OR a != 2 -> TRUE; a = 2 OR a != 5 -> a != 5
            for (NotEquals neq : notEquals) {
                if (eq.left().semanticEquals(neq.left()) == false) {
                    continue;
                }
                Integer comp = BinaryComparison.compare(eqValue, neq.right().fold(ctx.foldCtx()));
                if (comp == null) {
                    continue;
                }
                if (comp == 0) {
                    return Literal.TRUE;
                }
                removeEquals = true;
                break;
            }

            // Short-circuit keeps the original order: ranges are tried before plain inequalities.
            if (removeEquals == false) {
                removeEquals = absorbedByRange(eq, eqValue, ranges, ctx) || absorbedByInequality(eq, eqValue, inequalities, ctx);
            }
            if (removeEquals) {
                it.remove();
                updated = true;
            }
        }

        return updated ? Predicates.combineOr(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : or;
    }

    /**
     * Tries to absorb {@code eq} into the first range on the same field that covers {@code eqValue}, replacing
     * that range in {@code ranges} with an inclusive-bound copy when the value sits on an excluded bound.
     *
     * @return {@code true} if the equality became redundant and should be removed from the disjunction
     */
    private static boolean absorbedByRange(Equals eq, Object eqValue, List<Range> ranges, LogicalOptimizerContext ctx) {
        // index loop: the list element may be replaced in place
        for (int i = 0; i < ranges.size(); i++) {
            Range range = ranges.get(i);
            if (eq.left().semanticEquals(range.value()) == false) {
                continue;
            }
            Integer lowerComp = range.lower().foldable() ? BinaryComparison.compare(eqValue, range.lower().fold(ctx.foldCtx())) : null;
            Integer upperComp = range.upper().foldable() ? BinaryComparison.compare(eqValue, range.upper().fold(ctx.foldCtx())) : null;

            if (lowerComp != null && lowerComp == 0) {
                // a = 2 OR 2 < a < ? -> 2 <= a < ?  (already-inclusive bound: the equality is simply superfluous)
                if (range.includeLower() == false) {
                    ranges.set(i, new Range(range.source(), range.value(), range.lower(), true, range.upper(), range.includeUpper(), range.zoneId()));
                }
                return true;
            }
            if (upperComp != null && upperComp == 0) {
                // a = 2 OR ? < a < 2 -> ? < a <= 2
                if (range.includeUpper() == false) {
                    ranges.set(i, new Range(range.source(), range.value(), range.lower(), range.includeLower(), range.upper(), true, range.zoneId()));
                }
                return true;
            }
            if (lowerComp != null && upperComp != null && lowerComp > 0 && upperComp < 0) {
                // a = 2 OR 1 < a < 3 -> 1 < a < 3
                return true;
            }
        }
        return false;
    }

    /**
     * Tries to absorb {@code eq} into the first inequality on the same field that covers {@code eqValue},
     * turning a strict {@code >}/{@code <} into {@code >=}/{@code <=} when the value equals the limit.
     *
     * @return {@code true} if the equality became redundant and should be removed from the disjunction
     */
    private static boolean absorbedByInequality(Equals eq, Object eqValue, List<BinaryComparison> inequalities, LogicalOptimizerContext ctx) {
        // index loop: the list element may be replaced in place
        for (int i = 0; i < inequalities.size(); i++) {
            BinaryComparison bc = inequalities.get(i);
            if (eq.left().semanticEquals(bc.left()) == false) {
                continue;
            }
            Integer comp = BinaryComparison.compare(eqValue, bc.right().fold(ctx.foldCtx()));
            if (comp == null) {
                continue;
            }
            if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) {
                if (comp < 0) {
                    continue; // a = 1 OR a > 2: unrelated
                }
                if (comp == 0 && bc instanceof GreaterThan) {
                    // a = 2 OR a > 2 -> a >= 2
                    inequalities.set(i, new GreaterThanOrEqual(bc.source(), bc.left(), bc.right(), bc.zoneId()));
                }
                return true;
            }
            if (bc instanceof LessThan || bc instanceof LessThanOrEqual) {
                if (comp > 0) {
                    continue; // a = 2 OR a < 1: unrelated
                }
                if (comp == 0 && bc instanceof LessThan) {
                    // a = 2 OR a < 2 -> a <= 2
                    inequalities.set(i, new LessThanOrEqual(bc.source(), bc.left(), bc.right(), bc.zoneId()));
                }
                return true;
            }
        }
        return false;
    }

    /** Builds the {@code FALSE} literal that replaces a contradictory conjunction, keeping its source. */
    private static Literal falseLiteral(And and) {
        return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN);
    }
}

