/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.plugin;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.SearchShardsGroup;
import org.elasticsearch.action.search.SearchShardsRequest;
import org.elasticsearch.action.search.SearchShardsResponse;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.compute.operator.DriverProfile;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.compute.operator.exchange.ExchangeSink;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
import org.elasticsearch.compute.operator.exchange.RemoteSink;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.shard.ShardNotFoundException;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.plugin.ComputeContext;
import org.elasticsearch.xpack.esql.plugin.ComputeListener;
import org.elasticsearch.xpack.esql.plugin.ComputeResponse;
import org.elasticsearch.xpack.esql.plugin.ComputeService;
import org.elasticsearch.xpack.esql.plugin.DataNodeRequest;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.session.Configuration;

final class DataNodeComputeHandler
implements TransportRequestHandler<DataNodeRequest> {
    private final ComputeService computeService;
    private final SearchService searchService;
    private final TransportService transportService;
    private final ExchangeService exchangeService;
    private final Executor esqlExecutor;

    DataNodeComputeHandler(ComputeService computeService, SearchService searchService, TransportService transportService, ExchangeService exchangeService, Executor esqlExecutor) {
        this.computeService = computeService;
        this.searchService = searchService;
        this.transportService = transportService;
        this.exchangeService = exchangeService;
        this.esqlExecutor = esqlExecutor;
        transportService.registerRequestHandler("indices:data/read/esql/data", esqlExecutor, DataNodeRequest::new, (TransportRequestHandler)this);
    }

    void startComputeOnDataNodes(String sessionId, String clusterAlias, CancellableTask parentTask, Configuration configuration, PhysicalPlan dataNodePlan, Set<String> concreteIndices, OriginalIndices originalIndices, ExchangeSourceHandler exchangeSource, Runnable runOnTaskFailure, ActionListener<ComputeResponse> outListener) {
        QueryBuilder requestFilter = PlannerUtils.requestTimestampFilter(dataNodePlan);
        ActionListener listener = ActionListener.runAfter(outListener, () -> ((Releasable)exchangeSource.addEmptySink()).close());
        long startTimeInNanos = System.nanoTime();
        this.lookupDataNodes((Task)parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, (ActionListener<DataNodeResult>)ActionListener.wrap(dataNodeResult -> {
            try (ComputeListener computeListener = new ComputeListener(this.transportService.getThreadPool(), runOnTaskFailure, (ActionListener<List<DriverProfile>>)listener.map(profiles -> {
                TimeValue took = TimeValue.timeValueNanos((long)(System.nanoTime() - startTimeInNanos));
                return new ComputeResponse((List<DriverProfile>)profiles, took, dataNodeResult.totalShards(), dataNodeResult.totalShards(), dataNodeResult.skippedShards(), 0);
            }));){
                for (DataNode node : dataNodeResult.dataNodes()) {
                    QueryPragmas queryPragmas = configuration.pragmas();
                    String childSessionId = this.computeService.newChildSession(sessionId);
                    ActionListener nodeListener = computeListener.acquireCompute().map(ComputeResponse::getProfiles);
                    ExchangeService.openExchange((TransportService)this.transportService, (Transport.Connection)node.connection, (String)childSessionId, (int)queryPragmas.exchangeBufferSize(), (Executor)this.esqlExecutor, (ActionListener)nodeListener.delegateFailureAndWrap((l, unused) -> {
                        RemoteSink remoteSink = this.exchangeService.newRemoteSink((Task)parentTask, childSessionId, this.transportService, node.connection);
                        exchangeSource.addRemoteSink(remoteSink, true, () -> {}, queryPragmas.concurrentExchangeClients(), computeListener.acquireAvoid());
                        boolean sameNode = this.transportService.getLocalNode().getId().equals(node.connection.getNode().getId());
                        DataNodeRequest dataNodeRequest = new DataNodeRequest(childSessionId, configuration, clusterAlias, node.shardIds, node.aliasFilters, dataNodePlan, originalIndices.indices(), originalIndices.indicesOptions(), !sameNode && queryPragmas.nodeLevelReduction());
                        this.transportService.sendChildRequest(node.connection, "indices:data/read/esql/data", (TransportRequest)dataNodeRequest, (Task)parentTask, TransportRequestOptions.EMPTY, (TransportResponseHandler)new ActionListenerResponseHandler(nodeListener, ComputeResponse::new, this.esqlExecutor));
                    }));
                }
            }
        }, arg_0 -> ((ActionListener)listener).onFailure(arg_0)));
    }

    private void acquireSearchContexts(String clusterAlias, List<ShardId> shardIds, Configuration configuration, Map<Index, AliasFilter> aliasFilters, ActionListener<List<SearchContext>> listener) {
        ArrayList<IndexShard> targetShards = new ArrayList<IndexShard>();
        try {
            for (ShardId shardId : shardIds) {
                IndexShard indexShard = this.searchService.getIndicesService().indexServiceSafe(shardId.getIndex()).getShard(shardId.id());
                targetShards.add(indexShard);
            }
        }
        catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        ActionRunnable doAcquire = ActionRunnable.supply(listener, () -> {
            ArrayList<SearchContext> searchContexts = new ArrayList<SearchContext>(targetShards.size());
            boolean success = false;
            try {
                for (IndexShard shard : targetShards) {
                    AliasFilter aliasFilter = aliasFilters.getOrDefault(shard.shardId().getIndex(), AliasFilter.EMPTY);
                    ShardSearchRequest shardRequest = new ShardSearchRequest(shard.shardId(), configuration.absoluteStartedTimeInMillis(), aliasFilter, clusterAlias);
                    SearchContext context = this.searchService.createSearchContext(shardRequest, SearchService.NO_TIMEOUT);
                    searchContexts.add(context);
                }
                for (SearchContext searchContext : searchContexts) {
                    searchContext.preProcess();
                }
                success = true;
                ArrayList<SearchContext> arrayList = searchContexts;
                return arrayList;
            }
            finally {
                if (!success) {
                    IOUtils.close(searchContexts);
                }
            }
        });
        AtomicBoolean waitedForRefreshes = new AtomicBoolean();
        try (RefCountingRunnable refs = new RefCountingRunnable(() -> {
            if (waitedForRefreshes.get()) {
                this.esqlExecutor.execute((Runnable)doAcquire);
            } else {
                doAcquire.run();
            }
        });){
            for (IndexShard targetShard : targetShards) {
                Releasable ref = refs.acquire();
                targetShard.ensureShardSearchActive(await -> {
                    try (Releasable releasable = ref;){
                        if (await.booleanValue()) {
                            waitedForRefreshes.set(true);
                        }
                    }
                });
            }
        }
    }

    private void lookupDataNodes(Task parentTask, String clusterAlias, QueryBuilder filter, Set<String> concreteIndices, OriginalIndices originalIndices, ActionListener<DataNodeResult> listener) {
        ActionListener searchShardsListener = listener.map(resp -> {
            HashMap<String, DiscoveryNode> nodes = new HashMap<String, DiscoveryNode>();
            for (DiscoveryNode node : resp.getNodes()) {
                nodes.put(node.getId(), node);
            }
            HashMap<String, List> nodeToShards = new HashMap<String, List>();
            HashMap nodeToAliasFilters = new HashMap();
            int totalShards = 0;
            int skippedShards = 0;
            for (SearchShardsGroup group : resp.getGroups()) {
                ShardId shardId = group.shardId();
                if (group.allocatedNodes().isEmpty()) {
                    throw new ShardNotFoundException(group.shardId(), "no shard copies found {}", new Object[]{group.shardId()});
                }
                if (!concreteIndices.contains(shardId.getIndexName())) continue;
                ++totalShards;
                if (group.skipped()) {
                    ++skippedShards;
                    continue;
                }
                String targetNode = (String)group.allocatedNodes().get(0);
                nodeToShards.computeIfAbsent(targetNode, k -> new ArrayList()).add(shardId);
                AliasFilter aliasFilter = (AliasFilter)resp.getAliasFilters().get(shardId.getIndex().getUUID());
                if (aliasFilter == null) continue;
                nodeToAliasFilters.computeIfAbsent(targetNode, k -> new HashMap()).put(shardId.getIndex(), aliasFilter);
            }
            ArrayList<DataNode> dataNodes = new ArrayList<DataNode>(nodeToShards.size());
            for (Map.Entry e : nodeToShards.entrySet()) {
                DiscoveryNode node = (DiscoveryNode)nodes.get(e.getKey());
                Map<Index, AliasFilter> aliasFilters = nodeToAliasFilters.getOrDefault(e.getKey(), Map.of());
                dataNodes.add(new DataNode(this.transportService.getConnection(node), (List)e.getValue(), aliasFilters));
            }
            return new DataNodeResult(dataNodes, totalShards, skippedShards);
        });
        SearchShardsRequest searchShardsRequest = new SearchShardsRequest(originalIndices.indices(), originalIndices.indicesOptions(), filter, null, null, false, clusterAlias);
        this.transportService.sendChildRequest(this.transportService.getLocalNode(), EsqlSearchShardsAction.TYPE.name(), (TransportRequest)searchShardsRequest, parentTask, TransportRequestOptions.EMPTY, (TransportResponseHandler)new ActionListenerResponseHandler(searchShardsListener, SearchShardsResponse::new, this.esqlExecutor));
    }

    private void runComputeOnDataNode(CancellableTask task, String externalId, PhysicalPlan reducePlan, DataNodeRequest request, ActionListener<ComputeResponse> listener) {
        try (ComputeListener computeListener = new ComputeListener(this.transportService.getThreadPool(), this.computeService.cancelQueryOnFailure(task), (ActionListener<List<DriverProfile>>)listener.map(ComputeResponse::new));){
            ActionListener<Void> parentListener = computeListener.acquireAvoid();
            try {
                ExchangeSinkHandler internalSink = this.exchangeService.createSinkHandler(request.sessionId(), request.pragmas().exchangeBufferSize());
                DataNodeRequestExecutor dataNodeRequestExecutor = new DataNodeRequestExecutor(request, task, internalSink, request.configuration().pragmas().maxConcurrentShardsPerNode(), computeListener);
                dataNodeRequestExecutor.start();
                ExchangeSinkHandler externalSink = this.exchangeService.getSinkHandler(externalId);
                task.addListener(() -> this.exchangeService.finishSinkHandler(externalId, (Exception)new TaskCancelledException(task.getReasonCancelled())));
                ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(1, this.esqlExecutor);
                exchangeSource.addRemoteSink((arg_0, arg_1) -> ((ExchangeSinkHandler)internalSink).fetchPageAsync(arg_0, arg_1), true, () -> {}, 1, ActionListener.noop());
                ActionListener<List<DriverProfile>> reductionListener = computeListener.acquireCompute();
                this.computeService.runCompute(task, new ComputeContext(request.sessionId(), request.clusterAlias(), List.of(), request.configuration(), new FoldContext(request.pragmas().foldLimit().getBytes()), () -> ((ExchangeSourceHandler)exchangeSource).createExchangeSource(), () -> externalSink.createExchangeSink(() -> {})), reducePlan, (ActionListener<List<DriverProfile>>)ActionListener.wrap(resp -> externalSink.addCompletionListener(ActionListener.running(() -> {
                    this.exchangeService.finishSinkHandler(externalId, null);
                    reductionListener.onResponse(resp);
                })), e -> {
                    this.exchangeService.finishSinkHandler(externalId, e);
                    reductionListener.onFailure(e);
                }));
                parentListener.onResponse(null);
            }
            catch (Exception e2) {
                this.exchangeService.finishSinkHandler(externalId, e2);
                this.exchangeService.finishSinkHandler(request.sessionId(), e2);
                parentListener.onFailure(e2);
            }
        }
    }

    public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) {
        ChannelActionListener listener = new ChannelActionListener(channel);
        PhysicalPlan physicalPlan = request.plan();
        if (!(physicalPlan instanceof ExchangeSinkExec)) {
            listener.onFailure((Exception)new IllegalStateException("expected exchange sink for a remote compute; got " + String.valueOf((Object)request.plan())));
            return;
        }
        ExchangeSinkExec plan = (ExchangeSinkExec)physicalPlan;
        PhysicalPlan reductionPlan = ComputeService.reductionPlan(plan, request.runNodeLevelReduction());
        String sessionId = request.sessionId();
        request = new DataNodeRequest(sessionId + "[n]", request.configuration(), request.clusterAlias(), request.shardIds(), request.aliasFilters(), request.plan(), request.indices(), request.indicesOptions(), request.runNodeLevelReduction());
        this.runComputeOnDataNode((CancellableTask)task, sessionId, reductionPlan, request, (ActionListener<ComputeResponse>)listener);
    }

    private class DataNodeRequestExecutor {
        private final DataNodeRequest request;
        private final CancellableTask parentTask;
        private final ExchangeSinkHandler exchangeSink;
        private final ComputeListener computeListener;
        private final int maxConcurrentShards;
        private final ExchangeSink blockingSink;

        DataNodeRequestExecutor(DataNodeRequest request, CancellableTask parentTask, ExchangeSinkHandler exchangeSink, int maxConcurrentShards, ComputeListener computeListener) {
            this.request = request;
            this.parentTask = parentTask;
            this.exchangeSink = exchangeSink;
            this.computeListener = computeListener;
            this.maxConcurrentShards = maxConcurrentShards;
            this.blockingSink = exchangeSink.createExchangeSink(() -> {});
        }

        void start() {
            this.parentTask.addListener(() -> DataNodeComputeHandler.this.exchangeService.finishSinkHandler(this.request.sessionId(), (Exception)new TaskCancelledException(this.parentTask.getReasonCancelled())));
            this.runBatch(0);
        }

        private void runBatch(int startBatchIndex) {
            Configuration configuration = this.request.configuration();
            String clusterAlias = this.request.clusterAlias();
            String sessionId = this.request.sessionId();
            final int endBatchIndex = Math.min(startBatchIndex + this.maxConcurrentShards, this.request.shardIds().size());
            List<ShardId> shardIds = this.request.shardIds().subList(startBatchIndex, endBatchIndex);
            ActionListener<List<DriverProfile>> batchListener = new ActionListener<List<DriverProfile>>(){
                final ActionListener<List<DriverProfile>> ref;
                {
                    this.ref = DataNodeRequestExecutor.this.computeListener.acquireCompute();
                }

                public void onResponse(List<DriverProfile> result) {
                    try {
                        DataNodeRequestExecutor.this.onBatchCompleted(endBatchIndex);
                    }
                    finally {
                        this.ref.onResponse(result);
                    }
                }

                public void onFailure(Exception e) {
                    try {
                        DataNodeComputeHandler.this.exchangeService.finishSinkHandler(DataNodeRequestExecutor.this.request.sessionId(), e);
                    }
                    finally {
                        this.ref.onFailure(e);
                    }
                }
            };
            DataNodeComputeHandler.this.acquireSearchContexts(clusterAlias, shardIds, configuration, this.request.aliasFilters(), (ActionListener<List<SearchContext>>)ActionListener.wrap(arg_0 -> this.lambda$runBatch$4(sessionId, clusterAlias, configuration, (ActionListener)batchListener, arg_0), arg_0 -> ((ActionListener)batchListener).onFailure(arg_0)));
        }

        private void onBatchCompleted(int lastBatchIndex) {
            if (lastBatchIndex < this.request.shardIds().size() && !this.exchangeSink.isFinished()) {
                this.runBatch(lastBatchIndex);
            } else {
                ActionListener<Void> completionListener = this.computeListener.acquireAvoid();
                this.exchangeSink.addCompletionListener(ActionListener.runAfter(completionListener, () -> DataNodeComputeHandler.this.exchangeService.finishSinkHandler(this.request.sessionId(), null)));
                this.blockingSink.finish();
            }
        }

        private /* synthetic */ void lambda$runBatch$4(String sessionId, String clusterAlias, Configuration configuration, ActionListener batchListener, List searchContexts) throws Exception {
            assert (ThreadPool.assertCurrentThreadPool((String[])new String[]{"search", "esql_worker"}));
            ComputeContext computeContext = new ComputeContext(sessionId, clusterAlias, searchContexts, configuration, configuration.newFoldContext(), null, () -> this.exchangeSink.createExchangeSink(() -> {}));
            DataNodeComputeHandler.this.computeService.runCompute(this.parentTask, computeContext, this.request.plan(), (ActionListener<List<DriverProfile>>)batchListener);
        }
    }

    record DataNode(Transport.Connection connection, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters) {
    }

    record DataNodeResult(List<DataNode> dataNodes, int totalShards, int skippedShards) {
    }
}

