From 6919098a5e5ce3342db577d99cd4595ff35f36da Mon Sep 17 00:00:00 2001 From: Jaeho Yoo Date: Tue, 31 Dec 2024 14:18:18 +0900 Subject: [PATCH] Add timeout handling for external routing group selector --- docs/routing-rules.md | 18 +++++++++++ .../ha/config/RulesExternalConfiguration.java | 13 ++++++++ .../router/ExternalRoutingGroupSelector.java | 30 +++++++++++++++++-- .../TestRoutingGroupSelectorExternal.java | 25 ++++++++++++++++ 4 files changed, 83 insertions(+), 3 deletions(-) diff --git a/docs/routing-rules.md b/docs/routing-rules.md index 78def563c..2246e5487 100644 --- a/docs/routing-rules.md +++ b/docs/routing-rules.md @@ -38,6 +38,9 @@ routingRules: excludeHeaders: - 'Authorization' - 'Accept-Encoding' + requestConfig: + idleTimeout: 1m + requestTimeout: 5m ``` * Redirect URLs are not supported. @@ -50,6 +53,21 @@ If there is error parsing the routing rules configuration file, an error is logged, and requests are routed using the routing group header `X-Trino-Routing-Group` as default. +### Configuring Request Parameters with `requestConfig` + +The `requestConfig` parameter allows you to customize various aspects of the +HTTP requests sent by the Trino Gateway. +By specifying key-value pairs, you can control settings such as timeouts. + +#### Available Configuration Options + +| Key | Description | Example Value | +|-----------------------------------|-----------------------------------------------------------------------|---------------| +| `idleTimeout` | Sets the idle timeout duration for the request. | `1m` | +| `requestTimeout` | Sets the total timeout duration for the request. | `5m` | + +*Note*: Durations should be specified same as format mentioned in [Trino](https://trino.io/docs/current/admin/properties.html#duration). + ### Use an external service for routing rules You can use an external service for processing your routing by setting the diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/RulesExternalConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/RulesExternalConfiguration.java index 4601e0566..7bc2ab621 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/RulesExternalConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/RulesExternalConfiguration.java @@ -13,12 +13,15 @@ */ package io.trino.gateway.ha.config; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class RulesExternalConfiguration { private String urlPath; private List excludeHeaders; + private Map requestConfig = new HashMap<>(); public String getUrlPath() { @@ -39,4 +42,14 @@ public void setExcludeHeaders(List excludeHeaders) { this.excludeHeaders = excludeHeaders; } + + public Map getRequestConfig() + { + return requestConfig; + } + + public void setRequestConfig(Map requestConfig) + { + this.requestConfig = requestConfig; + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java index c083f3122..3d40ec01f 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java @@ -25,6 +25,7 @@ import io.airlift.http.client.jetty.JettyHttpClient; import io.airlift.json.JsonCodec; import io.airlift.log.Logger; +import io.airlift.units.Duration; import io.trino.gateway.ha.config.RequestAnalyzerConfig; import io.trino.gateway.ha.config.RulesExternalConfiguration; import io.trino.gateway.ha.router.schema.RoutingGroupExternalBody; @@ -33,8 +34,10 @@ import java.net.URI; import java.net.URISyntaxException; +import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Consumer; import static com.google.common.net.HttpHeaders.CONTENT_TYPE; import static com.google.common.net.MediaType.JSON_UTF_8; @@ -50,6 +53,7 @@ public class ExternalRoutingGroupSelector { private static final Logger log = Logger.get(ExternalRoutingGroupSelector.class); private final Set excludeHeaders; + private final Map requestConfig; private final URI uri; private final HttpClient httpClient; private final RequestAnalyzerConfig requestAnalyzerConfig; @@ -65,6 +69,7 @@ public class ExternalRoutingGroupSelector .add("Content-Length") .addAll(rulesExternalConfiguration.getExcludeHeaders()) .build(); + this.requestConfig = rulesExternalConfiguration.getRequestConfig(); this.requestAnalyzerConfig = requestAnalyzerConfig; trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig); @@ -87,12 +92,14 @@ public String findRoutingGroup(HttpServletRequest servletRequest) try { RoutingGroupExternalBody requestBody = createRequestBody(servletRequest); requestBodyGenerator = jsonBodyGenerator(ROUTING_GROUP_EXTERNAL_BODY_JSON_CODEC, requestBody); - request = preparePost() + Request.Builder requestBuilder = preparePost() .addHeader(CONTENT_TYPE, JSON_UTF_8.toString()) .addHeaders(getValidHeaders(servletRequest)) .setUri(uri) - .setBodyGenerator(requestBodyGenerator) - .build(); + .setBodyGenerator(requestBodyGenerator); + applyRequestConfig(requestBuilder); + + request = requestBuilder.build(); // Execute the request and get the response RoutingGroupExternalResponse response = httpClient.execute(request, ROUTING_GROUP_EXTERNAL_RESPONSE_JSON_RESPONSE_HANDLER); @@ -148,4 +155,21 @@ private Multimap getValidHeaders(HttpServletRequest servletReque } return headers; } + + private void applyRequestConfig(Request.Builder requestBuilder) + { + Map> configActions = Map.of( + "idleTimeout", value -> requestBuilder.setIdleTimeout(Duration.valueOf(value)), + "requestTimeout", value -> requestBuilder.setRequestTimeout(Duration.valueOf(value))); + + requestConfig.forEach((key, value) -> { + Consumer action = configActions.get(key); + if (action != null) { + action.accept(value); + } + else { + log.warn("Unknown request config key: %s", key); + } + }); + } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelectorExternal.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelectorExternal.java index a1c90179d..3ade49790 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelectorExternal.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelectorExternal.java @@ -19,6 +19,7 @@ import io.airlift.http.client.JsonResponseHandler; import io.airlift.http.client.Request; import io.airlift.json.JsonCodec; +import io.airlift.units.Duration; import io.trino.gateway.ha.config.RequestAnalyzerConfig; import io.trino.gateway.ha.config.RulesExternalConfiguration; import io.trino.gateway.ha.router.schema.RoutingGroupExternalBody; @@ -41,6 +42,7 @@ import java.util.Collections; import java.util.Enumeration; import java.util.List; +import java.util.Map; import java.util.Optional; import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; @@ -196,6 +198,29 @@ void testExcludeHeader() assertThat(validHeaders.size()).isEqualTo(1); } + @Test + void testRequestConfig() + throws NoSuchMethodException, InvocationTargetException, IllegalAccessException + { + RulesExternalConfiguration rulesExternalConfiguration = provideRoutingRuleExternalConfig(); + rulesExternalConfiguration.setRequestConfig(Map.of("requestTimeout", "1m", "idleTimeout", "30s")); + RoutingGroupSelector routingGroupSelector = + RoutingGroupSelector.byRoutingExternal(rulesExternalConfiguration, requestAnalyzerConfig); + + Request.Builder requestBuilder = mock(Request.Builder.class); + + ArgumentCaptor timeoutCaptor = ArgumentCaptor.forClass(Duration.class); + when(requestBuilder.setRequestTimeout(timeoutCaptor.capture())).thenReturn(requestBuilder); + when(requestBuilder.setIdleTimeout(timeoutCaptor.capture())).thenReturn(requestBuilder); + + Method applyRequestConfig = ExternalRoutingGroupSelector.class.getDeclaredMethod("applyRequestConfig", Request.Builder.class); + applyRequestConfig.setAccessible(true); + applyRequestConfig.invoke(routingGroupSelector, requestBuilder); + + List capturedDurations = timeoutCaptor.getAllValues(); + assertThat(capturedDurations).containsExactlyInAnyOrder(Duration.valueOf("1m"), Duration.valueOf("30s")); + } + private HttpServletRequest prepareMockRequest() { HttpServletRequest mockRequest = mock(HttpServletRequest.class);