/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shenyu.plugin.mcp.server.callback;

import com.google.common.collect.Maps;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.shenyu.common.dto.MetaData;
import org.apache.shenyu.common.utils.GsonUtils;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.context.ShenyuContext;
import org.apache.shenyu.plugin.base.cache.MetaDataCache;
import org.apache.shenyu.plugin.mcp.server.definition.ShenyuToolDefinition;
import org.apache.shenyu.plugin.mcp.server.holder.ShenyuMcpExchangeHolder;
import org.apache.shenyu.plugin.mcp.server.request.BodyWriterExchange;
import org.apache.shenyu.plugin.mcp.server.request.ParameterFormatter;
import org.apache.shenyu.plugin.mcp.server.request.RequestConfig;
import org.apache.shenyu.plugin.mcp.server.request.RequestConfigHelper;
import org.apache.shenyu.plugin.mcp.server.response.NonCommittingMcpResponseDecorator;
import org.apache.shenyu.plugin.mcp.server.response.ShenyuMcpResponseDecorator;
import org.apache.shenyu.plugin.mcp.server.session.McpSessionHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;

public class ShenyuToolCallback
implements ToolCallback {
    private static final Logger LOG = LoggerFactory.getLogger(ShenyuToolCallback.class);
    private static final int DEFAULT_TIMEOUT_SECONDS = 30;
    private static final String MCP_TOOL_CALL_ATTR = "MCP_TOOL_CALL";
    private static final String MCP_SESSION_ID_ATTR = "MCP_TOOL_SESSION_ID";
    private static final String STREAMABLE_HTTP_PATH = "/streamablehttp";
    private final ToolDefinition toolDefinition;

    public ShenyuToolCallback(ToolDefinition toolDefinition) {
        this.toolDefinition = Objects.requireNonNull(toolDefinition, "ToolDefinition cannot be null");
    }

    @NonNull
    public ToolDefinition getToolDefinition() {
        return this.toolDefinition;
    }

    @NonNull
    public String call(@NonNull String input) {
        return this.call(input, new ToolContext((Map)Maps.newHashMap()));
    }

    @NonNull
    public String call(@NonNull String input, ToolContext toolContext) {
        Objects.requireNonNull(input, "Input cannot be null");
        Objects.requireNonNull(toolContext, "ToolContext cannot be null");
        LOG.debug("Executing tool call for definition '{}' with input length: {} chars", (Object)this.toolDefinition.name(), (Object)input.length());
        try {
            McpSyncServerExchange mcpExchange = this.extractMcpExchange(toolContext);
            String sessionId = this.extractSessionId(mcpExchange);
            ShenyuToolDefinition shenyuTool = this.validateToolDefinition();
            String configStr = this.extractRequestConfig(shenyuTool);
            ServerWebExchange originExchange = this.getOriginExchange(sessionId);
            ShenyuPluginChain chain = this.getPluginChain(originExchange);
            return this.executeToolCall(originExchange, chain, sessionId, configStr, input);
        }
        catch (Exception e) {
            LOG.error("Failed to process tool call for '{}': {}", new Object[]{this.toolDefinition.name(), e.getMessage(), e});
            throw new RuntimeException("Tool execution failed: " + e.getMessage(), e);
        }
    }

    private ShenyuToolDefinition validateToolDefinition() {
        if (!(this.toolDefinition instanceof ShenyuToolDefinition)) {
            throw new IllegalStateException("Tool definition must be of type ShenyuToolDefinition, got: " + this.toolDefinition.getClass().getSimpleName());
        }
        return (ShenyuToolDefinition)this.toolDefinition;
    }

    private String extractRequestConfig(ShenyuToolDefinition definition) {
        String config = definition.requestConfig();
        if (!StringUtils.hasText((String)config)) {
            throw new IllegalStateException("Request configuration cannot be empty");
        }
        LOG.debug("Using request configuration with length: {} chars", (Object)config.length());
        return config;
    }

    private ShenyuPluginChain getPluginChain(ServerWebExchange exchange) {
        ShenyuPluginChain chain = (ShenyuPluginChain)exchange.getAttribute("chain");
        Assert.notNull((Object)chain, (String)"ShenyuPluginChain cannot be null");
        return chain;
    }

    private String executeToolCall(ServerWebExchange originExchange, ShenyuPluginChain chain, String sessionId, String configStr, String input) {
        CompletableFuture<String> responseFuture = new CompletableFuture<String>();
        ServerWebExchange decoratedExchange = this.buildDecoratedExchange(originExchange, responseFuture, sessionId, configStr, input);
        LOG.debug("Executing plugin chain for session: {}", (Object)sessionId);
        boolean isTemporarySession = sessionId.startsWith("temp_");
        chain.execute(decoratedExchange).doOnSubscribe(s -> LOG.debug("Plugin chain subscribed for session: {}", (Object)sessionId)).doOnError(e -> {
            LOG.error("Plugin chain execution failed for session {}: {}", new Object[]{sessionId, e.getMessage(), e});
            responseFuture.completeExceptionally((Throwable)e);
        }).doOnSuccess(v -> LOG.debug("Plugin chain completed successfully for session: {}", (Object)sessionId)).doOnCancel(() -> {
            LOG.warn("Plugin chain execution cancelled for session: {}", (Object)sessionId);
            responseFuture.completeExceptionally(new RuntimeException("Execution was cancelled"));
        }).doFinally(signalType -> {
            if (isTemporarySession) {
                LOG.debug("Cleaning up temporary session: {} (signal: {})", (Object)sessionId, signalType);
                ShenyuMcpExchangeHolder.remove(sessionId);
            }
        }).subscribe();
        try {
            String result = responseFuture.get(30L, TimeUnit.SECONDS);
            LOG.debug("Tool call completed successfully for session: {}", (Object)sessionId);
            return result;
        }
        catch (Exception e2) {
            LOG.error("Timeout or error waiting for response for session {}: {}", new Object[]{sessionId, e2.getMessage(), e2});
            if (isTemporarySession) {
                LOG.debug("Emergency cleanup of temporary session on error: {}", (Object)sessionId);
                ShenyuMcpExchangeHolder.remove(sessionId);
            }
            throw new RuntimeException("Tool execution timeout or error: " + e2.getMessage(), e2);
        }
    }

    private ServerWebExchange buildDecoratedExchange(ServerWebExchange originExchange, CompletableFuture<String> responseFuture, String sessionId, String configStr, String input) {
        JsonObject inputJson = this.parseInput(input);
        RequestConfig requestConfig = this.buildRequestConfig(configStr, inputJson);
        ServerHttpRequest decoratedRequest = this.buildDecoratedRequest(originExchange, sessionId, requestConfig);
        ServerHttpResponseDecorator responseDecorator = this.createResponseDecorator(originExchange, sessionId, responseFuture, configStr);
        ServerWebExchange decoratedExchange = originExchange.mutate().request(decoratedRequest).response((ServerHttpResponse)responseDecorator).build();
        ServerWebExchange finalExchange = this.handleRequestBody(decoratedExchange, requestConfig);
        this.configureShenyuContext(finalExchange, sessionId, requestConfig.getPath(), configStr);
        return finalExchange;
    }

    private JsonObject parseInput(String input) {
        try {
            JsonObject inputJson = (JsonObject)GsonUtils.getInstance().fromJson(input, JsonObject.class);
            if (Objects.isNull(inputJson)) {
                throw new IllegalArgumentException("Invalid input JSON format");
            }
            return inputJson;
        }
        catch (Exception e) {
            LOG.error("Failed to parse input JSON: {}", (Object)e.getMessage());
            throw new IllegalArgumentException("Invalid JSON format: " + e.getMessage(), e);
        }
    }

    private RequestConfig buildRequestConfig(String configStr, JsonObject inputJson) {
        RequestConfigHelper configHelper = new RequestConfigHelper(configStr);
        JsonObject requestTemplate = configHelper.getRequestTemplate();
        JsonObject argsPosition = configHelper.getArgsPosition();
        String urlTemplate = configHelper.getUrlTemplate();
        String method = configHelper.getMethod();
        boolean argsToJsonBody = configHelper.isArgsToJsonBody();
        String path = RequestConfigHelper.buildPath(urlTemplate, argsPosition, inputJson);
        JsonObject bodyJson = this.buildFormattedBodyJson(argsToJsonBody, argsPosition, inputJson);
        return new RequestConfig(method, path, bodyJson, requestTemplate, argsToJsonBody);
    }

    private JsonObject buildFormattedBodyJson(boolean argsToJsonBody, JsonObject argsPosition, JsonObject inputJson) {
        JsonObject bodyJson = new JsonObject();
        if (!argsToJsonBody) {
            return bodyJson;
        }
        for (String key : argsPosition.keySet()) {
            String position = argsPosition.get(key).getAsString();
            if (!position.startsWith("body") || !inputJson.has(key)) continue;
            JsonElement value = inputJson.get(key);
            JsonElement formattedValue = this.formatBodyParameterValue(value, key);
            if ("body".equals(position)) {
                bodyJson.add(key, formattedValue);
                continue;
            }
            if (!position.startsWith("body.")) continue;
            String[] pathParts = position.substring(5).split("\\.");
            this.setNestedValue(bodyJson, pathParts, formattedValue);
        }
        return bodyJson;
    }

    private JsonElement formatBodyParameterValue(JsonElement value, String paramName) {
        if (value.isJsonPrimitive() && value.getAsJsonPrimitive().isString()) {
            String stringValue = value.getAsString();
            JsonElement parsed = ParameterFormatter.tryParseJsonString(stringValue);
            if (!parsed.equals(value)) {
                LOG.debug("Parsed JSON string parameter '{}' into {}", (Object)paramName, (Object)(parsed.isJsonArray() ? "array" : "object"));
            }
            return parsed;
        }
        return value;
    }

    private void setNestedValue(JsonObject jsonObject, String[] pathParts, JsonElement value) {
        JsonObject current = jsonObject;
        for (int i = 0; i < pathParts.length - 1; ++i) {
            String part = pathParts[i];
            if (!current.has(part)) {
                current.add(part, (JsonElement)new JsonObject());
            }
            current = current.getAsJsonObject(part);
        }
        current.add(pathParts[pathParts.length - 1], value);
    }

    private ServerHttpRequest buildDecoratedRequest(ServerWebExchange originExchange, String sessionId, RequestConfig requestConfig) {
        ServerHttpRequest.Builder requestBuilder = originExchange.getRequest().mutate().method(HttpMethod.valueOf((String)requestConfig.getMethod())).header("sessionId", new String[]{sessionId}).header("Accept", new String[]{"application/json"});
        this.addCustomHeaders(requestBuilder, requestConfig);
        this.configureContentType(requestBuilder, requestConfig.getMethod());
        this.setTargetUri(requestBuilder, originExchange, requestConfig.getPath());
        return requestBuilder.build();
    }

    private void addCustomHeaders(ServerHttpRequest.Builder requestBuilder, RequestConfig requestConfig) {
        if (requestConfig.getRequestTemplate().has("headers")) {
            for (JsonElement headerElem : requestConfig.getRequestTemplate().getAsJsonArray("headers")) {
                JsonObject headerObj = headerElem.getAsJsonObject();
                requestBuilder.header(headerObj.get("key").getAsString(), new String[]{headerObj.get("value").getAsString()});
            }
        }
    }

    private void configureContentType(ServerHttpRequest.Builder requestBuilder, String method) {
        if (this.isRequestBodyMethod(method)) {
            requestBuilder.header("Content-Type", new String[]{"application/json"});
        } else {
            requestBuilder.headers(httpHeaders -> httpHeaders.remove((Object)"Content-Type"));
        }
    }

    private void setTargetUri(ServerHttpRequest.Builder requestBuilder, ServerWebExchange originExchange, String path) {
        try {
            URI oldUri = originExchange.getRequest().getURI();
            String newUriStr = oldUri.getScheme() + "://" + oldUri.getAuthority() + path;
            requestBuilder.uri(new URI(newUriStr));
        }
        catch (URISyntaxException e) {
            throw new RuntimeException("Invalid URI construction: " + e.getMessage(), e);
        }
    }

    private ServerHttpResponseDecorator createResponseDecorator(ServerWebExchange originExchange, String sessionId, CompletableFuture<String> responseFuture, String configStr) {
        RequestConfigHelper configHelper = new RequestConfigHelper(configStr);
        JsonObject responseTemplate = configHelper.getResponseTemplate();
        if (this.isStreamableHttpProtocol(originExchange)) {
            LOG.debug("Using non-committing decorator for Streamable HTTP protocol, session: {}", (Object)sessionId);
            return new NonCommittingMcpResponseDecorator(originExchange.getResponse(), sessionId, responseFuture, responseTemplate);
        }
        LOG.debug("Using standard decorator for SSE protocol, session: {}", (Object)sessionId);
        return new ShenyuMcpResponseDecorator(originExchange.getResponse(), sessionId, responseFuture, responseTemplate);
    }

    private ServerWebExchange handleRequestBody(ServerWebExchange decoratedExchange, RequestConfig requestConfig) {
        if (this.isRequestBodyMethod(requestConfig.getMethod()) && requestConfig.getBodyJson().size() > 0) {
            return new BodyWriterExchange(decoratedExchange, requestConfig.getBodyJson().toString());
        }
        return decoratedExchange;
    }

    private void configureShenyuContext(ServerWebExchange decoratedExchange, String sessionId, String decoratedPath, String configStr) {
        ShenyuContext shenyuContext = (ShenyuContext)decoratedExchange.getAttribute("context");
        if (Objects.nonNull(shenyuContext)) {
            this.configureMetadata(decoratedExchange, decoratedPath, shenyuContext);
            RequestConfigHelper configHelper = new RequestConfigHelper(configStr);
            MetaData metaData = MetaDataCache.getInstance().obtain(configHelper.getUrlTemplate());
            if (Objects.nonNull(metaData) && Boolean.TRUE.equals(metaData.getEnabled())) {
                decoratedExchange.getAttributes().put("metaData", metaData);
                shenyuContext.setRpcType(metaData.getRpcType());
            }
            shenyuContext.setPath(decoratedPath);
            shenyuContext.setRealUrl(decoratedPath);
            LOG.debug("Configured RpcType to HTTP for tool call, session: {}", (Object)sessionId);
            decoratedExchange.getAttributes().put("context", shenyuContext);
            decoratedExchange.getAttributes().put(MCP_TOOL_CALL_ATTR, true);
            decoratedExchange.getAttributes().put(MCP_SESSION_ID_ATTR, sessionId);
        }
    }

    private void configureMetadata(ServerWebExchange decoratedExchange, String decoratedPath, ShenyuContext shenyuContext) {
        MetaData metaData = MetaDataCache.getInstance().obtain(decoratedPath);
        if (Objects.nonNull(metaData) && Boolean.TRUE.equals(metaData.getEnabled())) {
            decoratedExchange.getAttributes().put("metaData", metaData);
            shenyuContext.setRpcType(metaData.getRpcType());
            LOG.debug("Applied metadata for path: {}", (Object)decoratedPath);
        }
    }

    private boolean isStreamableHttpProtocol(ServerWebExchange exchange) {
        String uri = exchange.getRequest().getURI().getRawPath();
        boolean isStreamable = uri.contains(STREAMABLE_HTTP_PATH) || uri.endsWith(STREAMABLE_HTTP_PATH);
        LOG.debug("Protocol detection - URI: {}, isStreamableHttp: {}", (Object)uri, (Object)isStreamable);
        return isStreamable;
    }

    private boolean isRequestBodyMethod(String method) {
        return "POST".equalsIgnoreCase(method) || "PUT".equalsIgnoreCase(method) || "PATCH".equalsIgnoreCase(method);
    }

    private McpSyncServerExchange extractMcpExchange(ToolContext toolContext) {
        McpSyncServerExchange exchange = McpSessionHelper.getMcpSyncServerExchange(toolContext);
        if (Objects.isNull(exchange)) {
            throw new IllegalStateException("Failed to retrieve MCP sync server exchange from context");
        }
        return exchange;
    }

    private String extractSessionId(McpSyncServerExchange mcpExchange) {
        String sessionId;
        try {
            sessionId = McpSessionHelper.getSessionId(mcpExchange);
        }
        catch (IllegalAccessException | NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
        if (StringUtils.hasText((String)sessionId)) {
            LOG.debug("Extracted session ID: {}", (Object)sessionId);
            return sessionId;
        }
        throw new IllegalStateException("Session ID is empty \u2013 it should have been set earlier by handleMessageEndpoint");
    }

    private ServerWebExchange getOriginExchange(String sessionId) {
        ServerWebExchange exchange = ShenyuMcpExchangeHolder.get(sessionId);
        if (Objects.nonNull(exchange)) {
            LOG.debug("Found existing exchange for session: {}", (Object)sessionId);
            return exchange;
        }
        throw new IllegalStateException("No ServerWebExchange found for session '" + sessionId + "'. It should have been stored by handleMessageEndpoint before the tool was invoked.");
    }
}

