/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.aggregate;

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.RateDoubleAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.RateIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.RateLongAggregatorFunctionSupplier;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.planner.ToAggregator;

/**
 * ES|QL {@code RATE} aggregate: computes the rate of a counter field, expressed per {@code unit}
 * (one second when no unit is given).
 * <p>
 * Children are {@code [field, filter, timestamp]} or, when a unit is present,
 * {@code [field, filter, timestamp, unit]}.
 */
public class Rate extends AggregateFunction implements OptionalArgument, ToAggregator {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Rate", Rate::new);

    /** Unit applied when the query does not specify one. */
    private static final TimeValue DEFAULT_UNIT = TimeValue.timeValueSeconds(1);

    private final Expression timestamp;
    /** Optional rate unit; {@code null} means {@link #DEFAULT_UNIT}. */
    private final Expression unit;

    @FunctionInfo(
        returnType = { "double" },
        description = "compute the rate of a counter field. Available in METRICS command only",
        type = FunctionType.AGGREGATE
    )
    public Rate(
        Source source,
        @Param(name = "field", type = { "counter_long|counter_integer|counter_double" }, description = "counter field") Expression field,
        Expression timestamp,
        @Param(optional = true, name = "unit", type = { "time_duration" }, description = "the unit") Expression unit
    ) {
        this(source, field, Literal.TRUE, timestamp, unit);
    }

    /** Builds from a parameter list {@code [timestamp]} or {@code [timestamp, unit]}, as read from the wire. */
    private Rate(Source source, Expression field, Expression filter, List<Expression> children) {
        this(source, field, filter, children.get(0), children.size() > 1 ? children.get(1) : null);
    }

    /**
     * Full constructor including the aggregate filter. Public so that the longest public
     * constructor matches the properties reported by {@link #info()}.
     *
     * @param unit optional unit; may be {@code null}
     */
    public Rate(Source source, Expression field, Expression filter, Expression timestamp, Expression unit) {
        super(source, field, filter, unit != null ? List.of(timestamp, unit) : List.of(timestamp));
        this.timestamp = timestamp;
        this.unit = unit;
    }

    public Rate(StreamInput in) throws IOException {
        this(
            Source.readFrom((PlanStreamInput) in),
            in.readNamedWriteable(Expression.class),
            // Filter and parameter list are only on the wire from 8.16 onwards.
            in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE,
            in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)
                ? in.readNamedWriteableCollectionAsList(Expression.class)
                // Older peers: timestamp followed by an optional unit; nullSafeList drops a missing unit.
                : CollectionUtils.nullSafeList(in.readNamedWriteable(Expression.class), in.readOptionalNamedWriteable(Expression.class))
        );
    }

    /** Legacy (pre-8.16) serialization of the parameters: timestamp, then an optional unit. */
    @Override
    protected void deprecatedWriteParams(StreamOutput out) throws IOException {
        out.writeNamedWriteable(timestamp);
        out.writeOptionalNamedWriteable(unit);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    /** Creates a rate whose timestamp is the not-yet-resolved {@code @timestamp} attribute. */
    public static Rate withUnresolvedTimestamp(Source source, Expression field, Expression unit) {
        return new Rate(source, field, new UnresolvedAttribute(source, "@timestamp"), unit);
    }

    @Override
    protected NodeInfo<Rate> info() {
        // The filter must be part of the node info: otherwise a node rebuilt from its info would
        // go through the 4-arg constructor and silently reset the filter to Literal.TRUE.
        return NodeInfo.create(this, Rate::new, field(), filter(), timestamp, unit);
    }

    /**
     * Expects {@code [field, filter, timestamp, unit]} when this rate has a unit,
     * otherwise {@code [field, filter, timestamp]}.
     *
     * @throws IllegalArgumentException if the child count does not match
     */
    @Override
    public Rate replaceChildren(List<Expression> newChildren) {
        if (unit != null) {
            if (newChildren.size() == 4) {
                return new Rate(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), newChildren.get(3));
            }
            assert false : "expected 4 children for field, filter, @timestamp, and unit; got " + newChildren;
            throw new IllegalArgumentException("expected 4 children for field, filter, @timestamp, and unit; got " + newChildren);
        }
        if (newChildren.size() == 3) {
            return new Rate(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), null);
        }
        assert false : "expected 3 children for field, filter and @timestamp; got " + newChildren;
        throw new IllegalArgumentException("expected 3 children for field, filter and @timestamp; got " + newChildren);
    }

    @Override
    public Rate withFilter(Expression filter) {
        return new Rate(source(), field(), filter, timestamp, unit);
    }

    @Override
    public DataType dataType() {
        return DataType.DOUBLE;
    }

    @Override
    protected Expression.TypeResolution resolveType() {
        Expression.TypeResolution resolution = TypeResolutions.isType(
            field(),
            DataType::isCounter,
            sourceText(),
            TypeResolutions.ParamOrdinal.FIRST,
            "counter_long",
            "counter_integer",
            "counter_double"
        );
        if (unit != null) {
            // NOTE(review): whole-number units pass type resolution here, but unitInMillis() only
            // accepts a folded Duration, so they fail later at planning time. Confirm intended semantics.
            resolution = resolution.and(
                TypeResolutions.isType(
                    unit,
                    dt -> dt.isWholeNumber() || DataType.isTemporalAmount(dt),
                    sourceText(),
                    TypeResolutions.ParamOrdinal.SECOND,
                    "time_duration"
                )
            );
        }
        return resolution;
    }

    /**
     * Returns the rate unit in milliseconds: {@link #DEFAULT_UNIT} when no unit was given,
     * otherwise the folded {@link Duration}.
     *
     * @throws IllegalArgumentException if the unit is not foldable, fails to fold, or does not fold to a Duration
     */
    long unitInMillis() {
        if (unit == null) {
            return DEFAULT_UNIT.millis();
        }
        if (unit.foldable() == false) {
            throw invalidUnit(null);
        }
        final Object foldValue;
        try {
            foldValue = unit.fold(FoldContext.small());
        } catch (Exception e) {
            // Chain the cause so the underlying fold failure is not lost.
            throw invalidUnit(e);
        }
        if (foldValue instanceof Duration duration) {
            return duration.toMillis();
        }
        throw invalidUnit(null);
    }

    /** Builds the "invalid unit" error for this function, optionally chaining a cause. */
    private IllegalArgumentException invalidUnit(Exception cause) {
        return new IllegalArgumentException("function [" + sourceText() + "] has invalid unit [" + unit.sourceText() + "]", cause);
    }

    @Override
    public AggregatorFunctionSupplier supplier() {
        final long unitInMillis = unitInMillis();
        final DataType type = field().dataType();
        return switch (type) {
            case COUNTER_LONG -> new RateLongAggregatorFunctionSupplier(unitInMillis);
            case COUNTER_INTEGER -> new RateIntAggregatorFunctionSupplier(unitInMillis);
            case COUNTER_DOUBLE -> new RateDoubleAggregatorFunctionSupplier(unitInMillis);
            default -> throw EsqlIllegalArgumentException.illegalDataType(type);
        };
    }

    @Override
    public String toString() {
        if (unit != null) {
            return "rate(" + field() + "," + unit + ")";
        }
        return "rate(" + field() + ")";
    }

    Expression timestamp() {
        return timestamp;
    }

    Expression unit() {
        return unit;
    }
}

