From 3b8958a181dbbd892efb2df25875306b01156958 Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Fri, 5 Jul 2024 17:15:13 +0900 Subject: [PATCH 1/7] client contexts are cancellable --- .../client/AbstractHttpRequestHandler.java | 45 +- .../armeria/client/HttpClientDelegate.java | 60 +- .../armeria/client/HttpResponseWrapper.java | 41 +- .../armeria/common/RequestContext.java | 3 + .../common/logging/DefaultRequestLog.java | 3 +- .../client/DefaultClientRequestContext.java | 26 +- .../common/CancellationScheduler.java | 70 ++- .../common/DefaultCancellationScheduler.java | 575 +++++++----------- .../common/NoopCancellationScheduler.java | 13 +- .../server/DefaultServiceRequestContext.java | 4 +- .../server/AbstractHttpResponseHandler.java | 18 +- .../client/ContextCancellationTest.java | 350 +++++++++++ .../CountingConnectionPoolListener.java | 2 +- .../armeria/client/DelegatingHttpRequest.java | 85 +++ .../client/Http1ConnectionReuseTest.java | 54 +- .../client/HttpClientResponseTimeoutTest.java | 45 +- .../armeria/common/ContextPushHookTest.java | 4 +- .../common/CancellationSchedulerTest.java | 226 +++++-- 18 files changed, 1071 insertions(+), 553 deletions(-) create mode 100644 core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java create mode 100644 core/src/test/java/com/linecorp/armeria/client/DelegatingHttpRequest.java diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java index a25a5b9fd6b..915a1826df3 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java @@ -45,6 +45,8 @@ import com.linecorp.armeria.internal.client.ClientRequestContextExtension; import com.linecorp.armeria.internal.client.DecodedHttpResponse; import com.linecorp.armeria.internal.client.HttpSession; +import com.linecorp.armeria.internal.common.CancellationScheduler; +import com.linecorp.armeria.internal.common.CancellationScheduler.CancellationTask; import com.linecorp.armeria.internal.common.RequestContextUtil; import com.linecorp.armeria.unsafe.PooledObjects; @@ -90,6 +92,7 @@ enum State { private ScheduledFuture timeoutFuture; private State state = State.NEEDS_TO_WRITE_FIRST_HEADER; private boolean loggedRequestFirstBytesTransferred; + private boolean failed; AbstractHttpRequestHandler(Channel ch, ClientHttpObjectEncoder encoder, HttpResponseDecoder responseDecoder, DecodedHttpResponse originalRes, @@ -193,9 +196,25 @@ final boolean tryInitialize() { () -> failAndReset(WriteTimeoutException.get()), timeoutMillis, TimeUnit.MILLISECONDS); } + final CancellationScheduler scheduler = cancellationScheduler(); + if (scheduler != null) { + scheduler.updateTask(newCancellationTask()); + } return true; } + private CancellationTask newCancellationTask() { + return cause -> { + if (ch.eventLoop().inEventLoop()) { + try (SafeCloseable ignored = RequestContextUtil.pop()) { + failAndReset(cause); + } + } else { + ch.eventLoop().execute(() -> failAndReset(cause)); + } + }; + } + RequestHeaders mergedRequestHeaders(RequestHeaders headers) { final HttpHeaders internalHeaders; final ClientRequestContextExtension ctxExtension = ctx.as(ClientRequestContextExtension.class); @@ -354,6 +373,10 @@ final void failRequest(Throwable cause) { } private void fail(Throwable cause) { + if (failed) { + return; + } + failed = true; state = State.DONE; cancel(); logBuilder.endRequest(cause); @@ -368,9 +391,20 @@ private void fail(Throwable cause) { logBuilder.endResponse(cause); originalRes.close(cause); } + + final CancellationScheduler scheduler = cancellationScheduler(); + if (scheduler != null) { + // best-effort attempt to cancel the scheduled timeout task so that RequestContext#cause + // isn't set unnecessarily + scheduler.cancelScheduled(); + } } final void failAndReset(Throwable cause) { + if (failed) { + return; + } + if (cause instanceof WriteTimeoutException) { final HttpSession session = HttpSession.get(ch); // Mark the session as unhealthy so that subsequent requests do not use it. @@ -395,7 +429,7 @@ final void failAndReset(Throwable cause) { error = Http2Error.INTERNAL_ERROR; } - if (error.code() != Http2Error.CANCEL.code()) { + if (error.code() != Http2Error.CANCEL.code() && cause != ctx.cancellationCause()) { Exceptions.logIfUnexpected(logger, ch, HttpSession.get(ch).protocol(), "a request publisher raised an exception", cause); @@ -416,4 +450,13 @@ final boolean cancelTimeout() { this.timeoutFuture = null; return timeoutFuture.cancel(false); } + + @Nullable + private CancellationScheduler cancellationScheduler() { + final ClientRequestContextExtension ctxExt = ctx.as(ClientRequestContextExtension.class); + if (ctxExt != null) { + return ctxExt.responseCancellationScheduler(); + } + return null; + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java index e54959ac233..f3d60aea270 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java @@ -33,9 +33,9 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.logging.ClientConnectionTimings; import com.linecorp.armeria.common.logging.ClientConnectionTimingsBuilder; -import com.linecorp.armeria.common.logging.RequestLogBuilder; import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.internal.client.ClientPendingThrowableUtil; +import com.linecorp.armeria.internal.client.ClientRequestContextExtension; import com.linecorp.armeria.internal.client.DecodedHttpResponse; import com.linecorp.armeria.internal.client.HttpSession; import com.linecorp.armeria.internal.client.PooledChannel; @@ -63,13 +63,13 @@ final class HttpClientDelegate implements HttpClient { public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Exception { final Throwable throwable = ClientPendingThrowableUtil.pendingThrowable(ctx); if (throwable != null) { - return earlyFailedResponse(throwable, ctx, req); + return earlyFailedResponse(throwable, ctx); } if (req != ctx.request()) { return earlyFailedResponse( new IllegalStateException("ctx.request() does not match the actual request; " + "did you forget to call ctx.updateRequest() in your decorator?"), - ctx, req); + ctx); } final Endpoint endpoint = ctx.endpoint(); @@ -84,7 +84,7 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex // and response created here will be exposed only when `EndpointGroup.select()` returned `null`. // // See `DefaultClientRequestContext.init()` for more information. - return earlyFailedResponse(EmptyEndpointGroupException.get(ctx.endpointGroup()), ctx, req); + return earlyFailedResponse(EmptyEndpointGroupException.get(ctx.endpointGroup()), ctx); } final SessionProtocol protocol = ctx.sessionProtocol(); @@ -92,13 +92,19 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex try { proxyConfig = getProxyConfig(protocol, endpoint); } catch (Throwable t) { - return earlyFailedResponse(t, ctx, req); + return earlyFailedResponse(t, ctx); + } + + final Throwable cancellationCause = ctx.cancellationCause(); + if (cancellationCause != null) { + return earlyFailedResponse(cancellationCause, ctx); } final Endpoint endpointWithPort = endpoint.withDefaultPort(ctx.sessionProtocol()); final EventLoop eventLoop = ctx.eventLoop().withoutContext(); // TODO(ikhoon) Use ctx.exchangeType() to create an optimized HttpResponse for non-streaming response. final DecodedHttpResponse res = new DecodedHttpResponse(eventLoop); + updateCancellationTask(ctx, req, res); final ClientConnectionTimingsBuilder timingsBuilder = ClientConnectionTimings.builder(); @@ -115,7 +121,7 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex acquireConnectionAndExecute(ctx, resolved, req, res, timingsBuilder, proxyConfig); } else { ctx.logBuilder().session(null, ctx.sessionProtocol(), timingsBuilder.build()); - earlyFailedResponse(cause, ctx, req, res); + ctx.cancel(cause); } }); } @@ -123,6 +129,23 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex return res; } + private static void updateCancellationTask(ClientRequestContext ctx, HttpRequest req, + DecodedHttpResponse res) { + final ClientRequestContextExtension ctxExt = ctx.as(ClientRequestContextExtension.class); + if (ctxExt == null) { + return; + } + ctxExt.responseCancellationScheduler().updateTask(cause -> { + try (SafeCloseable ignored = RequestContextUtil.pop()) { + final UnprocessedRequestException ure = UnprocessedRequestException.of(cause); + req.abort(ure); + ctx.logBuilder().endRequest(ure); + res.close(ure); + ctx.logBuilder().endResponse(ure); + } + }); + } + private void resolveAddress(Endpoint endpoint, ClientRequestContext ctx, BiConsumer<@Nullable Endpoint, @Nullable Throwable> onComplete) { @@ -169,7 +192,7 @@ private void acquireConnectionAndExecute0(ClientRequestContext ctx, Endpoint end try { pool = factory.pool(ctx.eventLoop().withoutContext()); } catch (Throwable t) { - earlyFailedResponse(t, ctx, req, res); + ctx.cancel(t); return; } final SessionProtocol protocol = ctx.sessionProtocol(); @@ -185,7 +208,7 @@ private void acquireConnectionAndExecute0(ClientRequestContext ctx, Endpoint end if (cause == null) { doExecute(newPooledChannel, ctx, req, res); } else { - earlyFailedResponse(cause, ctx, req, res); + ctx.cancel(cause); } return null; }); @@ -224,29 +247,12 @@ private static void logSession(ClientRequestContext ctx, @Nullable PooledChannel } } - private static HttpResponse earlyFailedResponse(Throwable t, ClientRequestContext ctx, HttpRequest req) { + private static HttpResponse earlyFailedResponse(Throwable t, ClientRequestContext ctx) { final UnprocessedRequestException cause = UnprocessedRequestException.of(t); - handleEarlyRequestException(ctx, req, cause); + ctx.cancel(cause); return HttpResponse.ofFailure(cause); } - private static void earlyFailedResponse(Throwable t, ClientRequestContext ctx, HttpRequest req, - DecodedHttpResponse res) { - final UnprocessedRequestException cause = UnprocessedRequestException.of(t); - handleEarlyRequestException(ctx, req, cause); - res.close(cause); - } - - private static void handleEarlyRequestException(ClientRequestContext ctx, - HttpRequest req, Throwable cause) { - try (SafeCloseable ignored = RequestContextUtil.pop()) { - req.abort(cause); - final RequestLogBuilder logBuilder = ctx.logBuilder(); - logBuilder.endRequest(cause); - logBuilder.endResponse(cause); - } - } - private static void doExecute(PooledChannel pooledChannel, ClientRequestContext ctx, HttpRequest req, DecodedHttpResponse res) { final Channel channel = pooledChannel.get(); diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java b/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java index d9adc43ac95..db8ebf945b7 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java @@ -37,10 +37,12 @@ import com.linecorp.armeria.common.stream.StreamWriter; import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.common.util.Exceptions; +import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.internal.client.ClientRequestContextExtension; import com.linecorp.armeria.internal.client.DecodedHttpResponse; import com.linecorp.armeria.internal.common.CancellationScheduler; import com.linecorp.armeria.internal.common.CancellationScheduler.CancellationTask; +import com.linecorp.armeria.internal.common.RequestContextUtil; import com.linecorp.armeria.unsafe.PooledObjects; import io.netty.channel.EventLoop; @@ -213,7 +215,7 @@ void close(@Nullable Throwable cause, boolean cancel) { } done = true; closed = true; - cancelTimeoutOrLog(cause, cancel); + cancelTimeoutAndLog(cause, cancel); final HttpRequest request = ctx.request(); assert request != null; if (cause != null) { @@ -250,32 +252,24 @@ private void cancelAction(@Nullable Throwable cause) { } } - private void cancelTimeoutOrLog(@Nullable Throwable cause, boolean cancel) { - CancellationScheduler responseCancellationScheduler = null; + private void cancelTimeoutAndLog(@Nullable Throwable cause, boolean cancel) { final ClientRequestContextExtension ctxExtension = ctx.as(ClientRequestContextExtension.class); if (ctxExtension != null) { - responseCancellationScheduler = ctxExtension.responseCancellationScheduler(); + // best-effort attempt to cancel the scheduled timeout task so that RequestContext#cause + // isn't set unnecessarily + ctxExtension.responseCancellationScheduler().cancelScheduled(); } - if (responseCancellationScheduler == null || !responseCancellationScheduler.isFinished()) { - if (responseCancellationScheduler != null) { - responseCancellationScheduler.clearTimeout(false); - } - // There's no timeout or the response has not been timed out. - if (cancel) { - cancelAction(cause); - } else { - closeAction(cause); - } + if (cancel) { + cancelAction(cause); return; } if (delegate.isOpen()) { closeAction(cause); } - // Response has been timed out already. - // Log only when it's not a ResponseTimeoutException. - if (cause instanceof ResponseTimeoutException) { + // the context has been cancelled either by timeout or by user invocation + if (cause == ctx.cancellationCause()) { return; } @@ -297,7 +291,8 @@ void initTimeout() { if (ctxExtension != null) { final CancellationScheduler responseCancellationScheduler = ctxExtension.responseCancellationScheduler(); - responseCancellationScheduler.start(newCancellationTask()); + responseCancellationScheduler.updateTask(newCancellationTask()); + responseCancellationScheduler.start(); } } @@ -310,9 +305,13 @@ public boolean canSchedule() { @Override public void run(Throwable cause) { - delegate.close(cause); - ctx.request().abort(cause); - ctx.logBuilder().endResponse(cause); + if (ctx.eventLoop().inEventLoop()) { + try (SafeCloseable ignored = RequestContextUtil.pop()) { + close(cause); + } + } else { + ctx.eventLoop().withoutContext().execute(() -> close(cause)); + } } }; } diff --git a/core/src/main/java/com/linecorp/armeria/common/RequestContext.java b/core/src/main/java/com/linecorp/armeria/common/RequestContext.java index 685764622f1..1f0935431ba 100644 --- a/core/src/main/java/com/linecorp/armeria/common/RequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/common/RequestContext.java @@ -467,6 +467,9 @@ default void setRequestAutoAbortDelay(Duration delay) { /** * Returns the cause of cancellation, {@code null} if the request has not been cancelled. + * Note that there is no guarantee that the cancellation cause is equivalent to the cause of failure + * for {@link HttpRequest} or {@link HttpResponse}. Refer to {@link RequestLog#requestCause()} + * or {@link RequestLog#responseCause()} for the exact reason why a request or response failed. */ @Nullable Throwable cancellationCause(); diff --git a/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java b/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java index b91bc58a247..3f0f552c084 100644 --- a/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java +++ b/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java @@ -428,7 +428,8 @@ private void updateFlags(int flags) { private static void completeSatisfiedFutures(RequestLogFuture[] satisfiedFutures, RequestLog log, RequestContext ctx) { if (!ctx.eventLoop().inEventLoop()) { - ctx.eventLoop().execute(() -> completeSatisfiedFutures(satisfiedFutures, log, ctx)); + ctx.eventLoop().withoutContext().execute( + () -> completeSatisfiedFutures(satisfiedFutures, log, ctx)); return; } for (RequestLogFuture f : satisfiedFutures) { diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java index f6b0478c4bb..bfb7baf0ef2 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java @@ -66,13 +66,16 @@ import com.linecorp.armeria.common.logging.RequestLogBuilder; import com.linecorp.armeria.common.logging.RequestLogProperty; import com.linecorp.armeria.common.util.ReleasableHolder; +import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.common.util.TextFormatter; import com.linecorp.armeria.common.util.TimeoutMode; import com.linecorp.armeria.common.util.UnmodifiableFuture; import com.linecorp.armeria.internal.common.CancellationScheduler; +import com.linecorp.armeria.internal.common.CancellationScheduler.CancellationTask; import com.linecorp.armeria.internal.common.HeaderOverridingHttpRequest; import com.linecorp.armeria.internal.common.NonWrappingRequestContext; import com.linecorp.armeria.internal.common.RequestContextExtension; +import com.linecorp.armeria.internal.common.RequestContextUtil; import com.linecorp.armeria.internal.common.SchemeAndAuthority; import com.linecorp.armeria.internal.common.stream.FixedStreamMessage; import com.linecorp.armeria.internal.common.util.ChannelUtil; @@ -445,7 +448,7 @@ private void acquireEventLoop(EndpointGroup endpointGroup) { options().factory().acquireEventLoop(sessionProtocol(), endpointGroup, endpoint); eventLoop = releasableEventLoop.get(); log.whenComplete().thenAccept(unused -> releasableEventLoop.release()); - responseCancellationScheduler.init(eventLoop()); + initializeResponseCancellationScheduler(); } } @@ -545,12 +548,27 @@ private DefaultClientRequestContext(DefaultClientRequestContext ctx, // the root context. if (endpoint == null || ctx.endpoint() == endpoint && ctx.log.children().isEmpty()) { eventLoop = ctx.eventLoop().withoutContext(); - responseCancellationScheduler.init(eventLoop()); + initializeResponseCancellationScheduler(); } else { acquireEventLoop(endpoint); } } + private void initializeResponseCancellationScheduler() { + final CancellationTask cancellationTask = cause -> { + try (SafeCloseable ignored = RequestContextUtil.pop()) { + final HttpRequest request = request(); + if (request != null) { + request.abort(cause); + } + log.endRequest(cause); + log.endResponse(cause); + } + }; + responseCancellationScheduler.init(eventLoop().withoutContext()); + responseCancellationScheduler.updateTask(cancellationTask); + } + @Nullable private Consumer copyThreadLocalCustomizer() { final ClientThreadLocalState state = ClientThreadLocalState.get(); @@ -926,13 +944,13 @@ public CompletableFuture whenResponseCancelled() { @Deprecated @Override public CompletableFuture whenResponseTimingOut() { - return responseCancellationScheduler.whenTimingOut(); + return whenResponseCancelling().handle((v, e) -> null); } @Deprecated @Override public CompletableFuture whenResponseTimedOut() { - return responseCancellationScheduler.whenTimedOut(); + return whenResponseCancelled().handle((v, e) -> null); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java index f8cfe3576b5..2145907e50f 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java @@ -26,10 +26,16 @@ public interface CancellationScheduler { static CancellationScheduler ofClient(long timeoutNanos) { + if (timeoutNanos == 0) { + timeoutNanos = Long.MAX_VALUE; + } return new DefaultCancellationScheduler(timeoutNanos, false); } static CancellationScheduler ofServer(long timeoutNanos) { + if (timeoutNanos == 0) { + timeoutNanos = Long.MAX_VALUE; + } return new DefaultCancellationScheduler(timeoutNanos, true); } @@ -51,36 +57,52 @@ static CancellationScheduler noop() { return NoopCancellationScheduler.INSTANCE; } - CancellationTask noopCancellationTask = new CancellationTask() { - @Override - public boolean canSchedule() { - return true; - } - - @Override - public void run(Throwable cause) { /* no-op */ } - }; + CancellationTask noopCancellationTask = cause -> {}; void initAndStart(EventExecutor eventLoop, CancellationTask task); void init(EventExecutor eventLoop); - void start(CancellationTask task); + /** + * Starts the scheduler task. If a timeout has already been configured, then scheduling is done. + * If the timeout is undefined, then the task won't be scheduled. If a timeout has already been reached + * the execution will be done from the designated event loop. Note that this behavior + * differs from {@link #setTimeoutNanos(TimeoutMode, long)} where a task is invoked immediately in the + * same thread. + * This is mostly due to how armeria uses this API - if this behavior is to be changed, + * we should make sure all locations invoking {@link #start()} can handle exceptions on invocation. + */ + void start(); + /** + * Clears the timeout. If a scheduled task exists, a best effort is made to cancel it. + */ void clearTimeout(); - void clearTimeout(boolean resetTimeout); + /** + * Cancels the scheduled timeout task if exists. + * @return true if a timeout task doesn't exist, or a task has been cancelled. + */ + boolean cancelScheduled(); void setTimeoutNanos(TimeoutMode mode, long timeoutNanos); - void finishNow(); + default void finishNow() { + finishNow(null); + } void finishNow(@Nullable Throwable cause); boolean isFinished(); - @Nullable Throwable cause(); + @Nullable + Throwable cause(); + /** + * Before the scheduler has started, the configured timeout will be returned regardless of the + * {@link TimeoutMode}. If the scheduler has already started, the timeout since + * {@link #startTimeNanos()} will be returned. + */ long timeoutNanos(); long startTimeNanos(); @@ -89,18 +111,18 @@ public void run(Throwable cause) { /* no-op */ } CompletableFuture whenCancelled(); - @Deprecated - CompletableFuture whenTimingOut(); - - @Deprecated - CompletableFuture whenTimedOut(); + /** + * Updates the task that will be executed once this scheduler completes either by the configured timeout, + * or immediately via {@link #finishNow()}. If the scheduler hasn't completed yet, the task will simply + * be updated. If the scheduler has already been triggered for completion, the supplied + * {@link CancellationTask} will be executed after the currently set task has finished executing. + */ + void updateTask(CancellationTask cancellationTask); enum State { INIT, - INACTIVE, - SCHEDULED, - FINISHING, - FINISHED + PENDING, + FINISHED, } /** @@ -110,7 +132,9 @@ interface CancellationTask { /** * Returns {@code true} if the cancellation task can be scheduled. */ - boolean canSchedule(); + default boolean canSchedule() { + return true; + } /** * Invoked by the scheduler with the cause of cancellation. diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java index 536b3e8a0bc..1488b8616e0 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java @@ -22,17 +22,17 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import com.google.common.annotations.VisibleForTesting; import com.google.common.math.LongMath; import com.linecorp.armeria.client.ResponseTimeoutException; -import com.linecorp.armeria.common.TimeoutException; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.util.Ticker; import com.linecorp.armeria.common.util.TimeoutMode; import com.linecorp.armeria.common.util.UnmodifiableFuture; +import com.linecorp.armeria.internal.common.util.ReentrantShortLock; import com.linecorp.armeria.server.HttpResponseException; import com.linecorp.armeria.server.HttpStatusException; import com.linecorp.armeria.server.RequestTimeoutException; @@ -50,24 +50,6 @@ final class DefaultCancellationScheduler implements CancellationScheduler { whenCancelledUpdater = AtomicReferenceFieldUpdater.newUpdater( DefaultCancellationScheduler.class, CancellationFuture.class, "whenCancelled"); - private static final AtomicReferenceFieldUpdater - whenTimingOutUpdater = AtomicReferenceFieldUpdater.newUpdater( - DefaultCancellationScheduler.class, TimeoutFuture.class, "whenTimingOut"); - - private static final AtomicReferenceFieldUpdater - whenTimedOutUpdater = AtomicReferenceFieldUpdater.newUpdater( - DefaultCancellationScheduler.class, TimeoutFuture.class, "whenTimedOut"); - - private static final AtomicReferenceFieldUpdater - pendingTaskUpdater = AtomicReferenceFieldUpdater.newUpdater( - DefaultCancellationScheduler.class, Runnable.class, "pendingTask"); - - private static final AtomicLongFieldUpdater pendingTimeoutNanosUpdater = - AtomicLongFieldUpdater.newUpdater(DefaultCancellationScheduler.class, "pendingTimeoutNanos"); - - private static final Runnable noopPendingTask = () -> { - }; - static final CancellationScheduler serverFinishedCancellationScheduler = finished0(true); static final CancellationScheduler clientFinishedCancellationScheduler = finished0(false); @@ -76,25 +58,21 @@ final class DefaultCancellationScheduler implements CancellationScheduler { private long startTimeNanos; @Nullable private EventExecutor eventLoop; - @Nullable - private CancellationTask task; - @Nullable - private volatile Runnable pendingTask; + private volatile CancellationTask task = noopCancellationTask; @Nullable private ScheduledFuture scheduledFuture; + private long setFromNowStartNanos; + private TimeoutMode timeoutMode = TimeoutMode.SET_FROM_START; + @Nullable + private volatile Throwable cause; + private final Ticker ticker; + private final ReentrantShortLock lock = new ReentrantShortLock(); + + private final boolean server; @Nullable private volatile CancellationFuture whenCancelling; @Nullable private volatile CancellationFuture whenCancelled; - @Nullable - private volatile TimeoutFuture whenTimingOut; - @Nullable - private volatile TimeoutFuture whenTimedOut; - @SuppressWarnings("FieldMayBeFinal") - private volatile long pendingTimeoutNanos; - private final boolean server; - @Nullable - private Throwable cause; @VisibleForTesting DefaultCancellationScheduler(long timeoutNanos) { @@ -102,9 +80,14 @@ final class DefaultCancellationScheduler implements CancellationScheduler { } DefaultCancellationScheduler(long timeoutNanos, boolean server) { + this(timeoutNanos, server, Ticker.systemTicker()); + } + + @VisibleForTesting + DefaultCancellationScheduler(long timeoutNanos, boolean server, Ticker ticker) { this.timeoutNanos = timeoutNanos; - pendingTimeoutNanos = timeoutNanos; this.server = server; + this.ticker = ticker; } /** @@ -112,259 +95,196 @@ final class DefaultCancellationScheduler implements CancellationScheduler { */ @Override public void initAndStart(EventExecutor eventLoop, CancellationTask task) { - init(eventLoop); - if (!eventLoop.inEventLoop()) { - eventLoop.execute(() -> start(task)); - } else { - start(task); + lock.lock(); + try { + init(eventLoop); + updateTask(task); + start(); + } finally { + lock.unlock(); } } @Override public void init(EventExecutor eventLoop) { - checkState(this.eventLoop == null, "Can't init() more than once"); - this.eventLoop = eventLoop; + lock.lock(); + try { + checkState(this.eventLoop == null, "Can't init() more than once"); + this.eventLoop = eventLoop; + } finally { + lock.unlock(); + } } @Override - public void start(CancellationTask task) { - assert eventLoop != null; - assert eventLoop.inEventLoop(); - if (isFinished()) { - assert cause != null; - task.run(cause); - return; - } - if (this.task != null) { - // just replace the task - this.task = task; - return; - } - this.task = task; - startTimeNanos = System.nanoTime(); - if (timeoutNanos != 0) { - state = State.SCHEDULED; - scheduledFuture = - eventLoop.schedule(() -> invokeTask(null), timeoutNanos, NANOSECONDS); - } else { - state = State.INACTIVE; - } - for (;;) { - final Runnable pendingTask = this.pendingTask; - if (pendingTaskUpdater.compareAndSet(this, pendingTask, noopPendingTask)) { - if (pendingTask != null) { - pendingTask.run(); - } - break; + public void start() { + lock.lock(); + try { + if (state != State.INIT) { + return; + } + state = State.PENDING; + startTimeNanos = ticker.read(); + if (timeoutMode == TimeoutMode.SET_FROM_NOW) { + final long elapsedTimeNanos = startTimeNanos - setFromNowStartNanos; + timeoutNanos = LongMath.saturatedSubtract(timeoutNanos, elapsedTimeNanos); + } + if (timeoutNanos != Long.MAX_VALUE) { + scheduledFuture = eventLoop().schedule(() -> invokeTask(null), timeoutNanos, NANOSECONDS); } + } finally { + lock.unlock(); } } @Override public void clearTimeout() { - clearTimeout(true); - } - - @Override - public void clearTimeout(boolean resetTimeout) { - if (timeoutNanos() == 0) { - return; - } - if (isInitialized()) { - if (eventLoop.inEventLoop()) { - clearTimeout0(resetTimeout); - } else { - eventLoop.execute(() -> clearTimeout0(resetTimeout)); + lock.lock(); + try { + if (timeoutNanos == Long.MAX_VALUE) { + return; } - } else { - if (resetTimeout) { - setPendingTimeoutNanos(0); + timeoutNanos = Long.MAX_VALUE; + if (isStarted()) { + cancelScheduled(); } - addPendingTask(() -> clearTimeout0(resetTimeout)); + } finally { + lock.unlock(); } } - private boolean clearTimeout0(boolean resetTimeout) { - assert eventLoop != null && eventLoop.inEventLoop(); - if (state != State.SCHEDULED) { - return true; - } - if (resetTimeout) { - timeoutNanos = 0; - } - assert scheduledFuture != null; - final boolean cancelled = scheduledFuture.cancel(false); - scheduledFuture = null; - if (cancelled) { - state = State.INACTIVE; + @Override + public boolean cancelScheduled() { + lock.lock(); + try { + if (scheduledFuture == null) { + return true; + } + final boolean cancelled = scheduledFuture.cancel(false); + scheduledFuture = null; + return cancelled; + } finally { + lock.unlock(); } - return cancelled; } @Override public void setTimeoutNanos(TimeoutMode mode, long timeoutNanos) { - switch (mode) { - case SET_FROM_NOW: - setTimeoutNanosFromNow(timeoutNanos); - break; - case SET_FROM_START: - setTimeoutNanosFromStart(timeoutNanos); - break; - case EXTEND: - extendTimeoutNanos(timeoutNanos); - break; + lock.lock(); + final ScheduleResult result; + try { + switch (mode) { + case SET_FROM_NOW: + result = setTimeoutNanosFromNow(timeoutNanos); + break; + case SET_FROM_START: + result = setTimeoutNanosFromStart(timeoutNanos); + break; + case EXTEND: + result = extendTimeoutNanos(timeoutNanos); + break; + default: + throw new Error(); + } + } finally { + lock.unlock(); + } + if (result == ScheduleResult.INVOKE_IMMEDIATELY) { + invokeTask(null); } } - private void setTimeoutNanosFromStart(long timeoutNanos) { + private ScheduleResult setTimeoutNanosFromStart(long timeoutNanos) { checkArgument(timeoutNanos >= 0, "timeoutNanos: %s (expected: >= 0)", timeoutNanos); - if (timeoutNanos == 0) { + if (timeoutNanos == Long.MAX_VALUE) { clearTimeout(); - return; + return ScheduleResult.INVOKE_LATER; } - if (isInitialized()) { - if (eventLoop.inEventLoop()) { - setTimeoutNanosFromStart0(timeoutNanos); - } else { - eventLoop.execute(() -> setTimeoutNanosFromStart0(timeoutNanos)); - } - } else { - setPendingTimeoutNanos(timeoutNanos); - addPendingTask(() -> setTimeoutNanosFromStart0(timeoutNanos)); + if (isStarted()) { + return setTimeoutNanosFromStart0(timeoutNanos); } + this.timeoutNanos = timeoutNanos; + timeoutMode = TimeoutMode.SET_FROM_START; + return ScheduleResult.INVOKE_LATER; } - private void setTimeoutNanosFromStart0(long timeoutNanos) { - assert eventLoop != null && eventLoop.inEventLoop(); - final long passedTimeNanos = System.nanoTime() - startTimeNanos; - final long newTimeoutNanos = LongMath.saturatedSubtract(timeoutNanos, passedTimeNanos); + private ScheduleResult setTimeoutNanosFromStart0(long timeoutNanos) { + final long newTimeoutNanos; + if (timeoutNanos != Long.MAX_VALUE) { + final long passedTimeNanos = ticker.read() - startTimeNanos; + newTimeoutNanos = LongMath.saturatedSubtract(timeoutNanos, passedTimeNanos); + } else { + newTimeoutNanos = timeoutNanos; + } + + timeoutMode = TimeoutMode.SET_FROM_START; + this.timeoutNanos = timeoutNanos; if (newTimeoutNanos <= 0) { - invokeTask(null); - return; + return ScheduleResult.INVOKE_IMMEDIATELY; } // Cancel the previously scheduled timeout, if exists. - clearTimeout0(true); - this.timeoutNanos = timeoutNanos; - state = State.SCHEDULED; - scheduledFuture = eventLoop.schedule(() -> invokeTask(null), newTimeoutNanos, NANOSECONDS); + if (cancelScheduled() && !isFinished() && newTimeoutNanos != Long.MAX_VALUE) { + scheduledFuture = eventLoop().schedule(() -> invokeTask(null), newTimeoutNanos, NANOSECONDS); + } + return ScheduleResult.INVOKE_LATER; } - private void extendTimeoutNanos(long adjustmentNanos) { - if (adjustmentNanos == 0 || timeoutNanos() == 0) { - return; + private ScheduleResult extendTimeoutNanos(long adjustmentNanos) { + if (timeoutNanos == Long.MAX_VALUE || adjustmentNanos == Long.MAX_VALUE) { + return ScheduleResult.INVOKE_LATER; } - if (isInitialized()) { - if (eventLoop.inEventLoop()) { - extendTimeoutNanos0(adjustmentNanos); - } else { - eventLoop.execute(() -> extendTimeoutNanos0(adjustmentNanos)); - } - } else { - addPendingTimeoutNanos(adjustmentNanos); - addPendingTask(() -> extendTimeoutNanos0(adjustmentNanos)); + if (isStarted()) { + return extendTimeoutNanos0(adjustmentNanos); } + timeoutNanos = LongMath.saturatedAdd(timeoutNanos, adjustmentNanos); + return ScheduleResult.INVOKE_LATER; } - private void extendTimeoutNanos0(long adjustmentNanos) { - assert eventLoop != null && eventLoop.inEventLoop() && task != null; - if (state != State.SCHEDULED || !task.canSchedule()) { - return; - } + private ScheduleResult extendTimeoutNanos0(long adjustmentNanos) { final long timeoutNanos = this.timeoutNanos; - // Cancel the previously scheduled timeout, if exists. - clearTimeout0(true); this.timeoutNanos = LongMath.saturatedAdd(timeoutNanos, adjustmentNanos); - if (timeoutNanos <= 0) { - invokeTask(null); - return; + + if (this.timeoutNanos <= 0) { + return ScheduleResult.INVOKE_IMMEDIATELY; + } + // Cancel the previously scheduled timeout, if exists. + if (cancelScheduled() && !isFinished()) { + scheduledFuture = eventLoop().schedule(() -> invokeTask(null), this.timeoutNanos, NANOSECONDS); } - state = State.SCHEDULED; - scheduledFuture = eventLoop.schedule(() -> invokeTask(null), this.timeoutNanos, NANOSECONDS); + return ScheduleResult.INVOKE_LATER; } - private void setTimeoutNanosFromNow(long timeoutNanos) { + private ScheduleResult setTimeoutNanosFromNow(long timeoutNanos) { checkArgument(timeoutNanos > 0, "timeoutNanos: %s (expected: > 0)", timeoutNanos); - if (isInitialized()) { - if (eventLoop.inEventLoop()) { - setTimeoutNanosFromNow0(timeoutNanos); - } else { - final long eventLoopStartTimeNanos = System.nanoTime(); - eventLoop.execute(() -> { - final long passedTimeNanos0 = System.nanoTime() - eventLoopStartTimeNanos; - final long timeoutNanos0 = Math.max(1, timeoutNanos - passedTimeNanos0); - setTimeoutNanosFromNow0(timeoutNanos0); - }); - } - } else { - final long pendingTaskRegisterTimeNanos = System.nanoTime(); - setPendingTimeoutNanos(timeoutNanos); - addPendingTask(() -> { - final long passedTimeNanos0 = System.nanoTime() - pendingTaskRegisterTimeNanos; - final long timeoutNanos0 = Math.max(1, timeoutNanos - passedTimeNanos0); - setTimeoutNanosFromNow0(timeoutNanos0); - }); + if (isStarted()) { + return setTimeoutNanosFromNow0(timeoutNanos); } + setFromNowStartNanos = ticker.read(); + timeoutMode = TimeoutMode.SET_FROM_NOW; + this.timeoutNanos = timeoutNanos; + return ScheduleResult.INVOKE_LATER; } - private void setTimeoutNanosFromNow0(long newTimeoutNanos) { + private ScheduleResult setTimeoutNanosFromNow0(long newTimeoutNanos) { assert newTimeoutNanos > 0; - assert eventLoop != null && eventLoop.inEventLoop() && task != null; - if (isFinishing() || !task.canSchedule()) { - return; - } - // Cancel the previously scheduled timeout, if exists. - clearTimeout0(true); - final long passedTimeNanos = System.nanoTime() - startTimeNanos; + final long passedTimeNanos = ticker.read() - startTimeNanos; timeoutNanos = LongMath.saturatedAdd(newTimeoutNanos, passedTimeNanos); - - state = State.SCHEDULED; - scheduledFuture = eventLoop.schedule(() -> invokeTask(null), newTimeoutNanos, NANOSECONDS); - } - - @Override - public void finishNow() { - finishNow(null); - } - - @Override - public void finishNow(@Nullable Throwable cause) { - if (isFinishing()) { - return; - } - assert eventLoop != null; - if (!eventLoop.inEventLoop()) { - eventLoop.execute(() -> finishNow(cause)); - return; - } - if (isInitialized()) { - finishNow0(cause); - } else { - start(noopCancellationTask); - finishNow0(cause); + timeoutMode = TimeoutMode.SET_FROM_NOW; + // Cancel the previously scheduled timeout, if exists. + if (cancelScheduled() && !isFinished() && newTimeoutNanos != Long.MAX_VALUE) { + scheduledFuture = eventLoop().schedule(() -> invokeTask(null), newTimeoutNanos, NANOSECONDS); } + return ScheduleResult.INVOKE_LATER; } - private void finishNow0(@Nullable Throwable cause) { - assert eventLoop != null && eventLoop.inEventLoop() && task != null; - if (isFinishing() || !task.canSchedule()) { - return; - } - if (state == State.SCHEDULED) { - if (clearTimeout0(false)) { - invokeTask(cause); - } - } else { - invokeTask(cause); - } + private EventExecutor eventLoop() { + assert eventLoop != null; + return eventLoop; } @Override - public boolean isFinished() { - return state == State.FINISHED; - } - - private boolean isFinishing() { - return state == State.FINISHED || state == State.FINISHING; + public void finishNow(@Nullable Throwable cause) { + invokeTask(cause); } @Override @@ -375,7 +295,7 @@ public Throwable cause() { @Override public long timeoutNanos() { - return isInitialized() ? timeoutNanos : pendingTimeoutNanos; + return timeoutNanos == Long.MAX_VALUE ? 0 : timeoutNanos; } @Override @@ -383,125 +303,64 @@ public long startTimeNanos() { return startTimeNanos; } - @Override - public CompletableFuture whenCancelling() { - final CancellationFuture whenCancelling = this.whenCancelling; - if (whenCancelling != null) { - return whenCancelling; - } - final CancellationFuture cancellationFuture = new CancellationFuture(); - if (whenCancellingUpdater.compareAndSet(this, null, cancellationFuture)) { - return cancellationFuture; - } else { - return this.whenCancelling; - } - } - - @Override - public CompletableFuture whenCancelled() { - final CancellationFuture whenCancelled = this.whenCancelled; - if (whenCancelled != null) { - return whenCancelled; - } - final CancellationFuture cancellationFuture = new CancellationFuture(); - if (whenCancelledUpdater.compareAndSet(this, null, cancellationFuture)) { - return cancellationFuture; - } else { - return this.whenCancelled; - } + private boolean isStarted() { + return state != State.INIT; } @Override - @Deprecated - public CompletableFuture whenTimingOut() { - final TimeoutFuture whenTimingOut = this.whenTimingOut; - if (whenTimingOut != null) { - return whenTimingOut; - } - final TimeoutFuture timeoutFuture = new TimeoutFuture(); - if (whenTimingOutUpdater.compareAndSet(this, null, timeoutFuture)) { - whenCancelling().thenAccept(cause -> { - if (cause instanceof TimeoutException) { - timeoutFuture.doComplete(); - } - }); - return timeoutFuture; - } else { - return this.whenTimingOut; - } + public boolean isFinished() { + return state == State.FINISHED; } @Override - @Deprecated - public CompletableFuture whenTimedOut() { - final TimeoutFuture whenTimedOut = this.whenTimedOut; - if (whenTimedOut != null) { - return whenTimedOut; - } - final TimeoutFuture timeoutFuture = new TimeoutFuture(); - if (whenTimedOutUpdater.compareAndSet(this, null, timeoutFuture)) { - whenCancelled().thenAccept(cause -> { - if (cause instanceof TimeoutException) { - timeoutFuture.doComplete(); - } - }); - return timeoutFuture; - } else { - return this.whenTimedOut; + public void updateTask(CancellationTask task) { + lock.lock(); + try { + if (state != State.FINISHED) { + // if the task hasn't been run yet + this.task = task; + return; + } + } finally { + lock.unlock(); } - } - private boolean isInitialized() { - return pendingTask == noopPendingTask && eventLoop != null; - } - - private void addPendingTask(Runnable pendingTask) { - if (!pendingTaskUpdater.compareAndSet(this, null, pendingTask)) { - for (;;) { - final Runnable oldPendingTask = this.pendingTask; - assert oldPendingTask != null; - if (oldPendingTask == noopPendingTask) { - assert eventLoop != null; - eventLoop.execute(pendingTask); - break; - } - final Runnable newPendingTask = () -> { - oldPendingTask.run(); - pendingTask.run(); - }; - if (pendingTaskUpdater.compareAndSet(this, oldPendingTask, newPendingTask)) { - break; - } + whenCancelled().thenAccept(cause -> { + if (task.canSchedule()) { + task.run(cause); } - } + }); } - private void setPendingTimeoutNanos(long pendingTimeoutNanos) { - for (;;) { - final long oldPendingTimeoutNanos = this.pendingTimeoutNanos; - if (pendingTimeoutNanosUpdater.compareAndSet(this, oldPendingTimeoutNanos, pendingTimeoutNanos)) { - break; + private void invokeTask(@Nullable Throwable cause) { + lock.lock(); + try { + if (state == State.FINISHED) { + return; } + state = State.FINISHED; + cancelScheduled(); + // set the cause + cause = getFinalCause(cause); + this.cause = cause; + } finally { + lock.unlock(); } - } - private void addPendingTimeoutNanos(long pendingTimeoutNanos) { - for (;;) { - final long oldPendingTimeoutNanos = this.pendingTimeoutNanos; - final long newPendingTimeoutNanos = LongMath.saturatedAdd(oldPendingTimeoutNanos, - pendingTimeoutNanos); - if (pendingTimeoutNanosUpdater.compareAndSet(this, oldPendingTimeoutNanos, - newPendingTimeoutNanos)) { - break; - } + if (task.canSchedule()) { + ((CancellationFuture) whenCancelling()).doComplete(cause); } - } - private void invokeTask(@Nullable Throwable cause) { - if (task == null) { - return; + if (task.canSchedule()) { + assert !lock.isHeldByCurrentThread() : "Currently locked by lock: [" + lock + "], with count: " + + lock.getHoldCount(); + task.run(cause); } + ((CancellationFuture) whenCancelled()).doComplete(cause); + } + + private Throwable getFinalCause(@Nullable Throwable cause) { if (cause instanceof HttpStatusException || cause instanceof HttpResponseException) { // Log the requestCause only when an Http{Status,Response}Exception was created with a cause. cause = cause.getCause(); @@ -514,26 +373,45 @@ private void invokeTask(@Nullable Throwable cause) { cause = ResponseTimeoutException.get(); } } + return cause; + } - // Set FINISHING to preclude executing other timeout operations from the callbacks of `whenCancelling()` - state = State.FINISHING; - if (task.canSchedule()) { - ((CancellationFuture) whenCancelling()).doComplete(cause); + @VisibleForTesting + State state() { + return state; + } + + @Override + public CompletableFuture whenCancelling() { + final CancellationFuture whenCancelling = this.whenCancelling; + if (whenCancelling != null) { + return whenCancelling; + } + final CancellationFuture cancellationFuture = new CancellationFuture(); + if (whenCancellingUpdater.compareAndSet(this, null, cancellationFuture)) { + return cancellationFuture; + } else { + return this.whenCancelling; } - // Set state first to prevent duplicate execution - state = State.FINISHED; + } - // The returned value of `canSchedule()` could've been changed by the callbacks of `whenCancelling()` - if (task.canSchedule()) { - task.run(cause); + @Override + public CompletableFuture whenCancelled() { + final CancellationFuture whenCancelled = this.whenCancelled; + if (whenCancelled != null) { + return whenCancelled; + } + final CancellationFuture cancellationFuture = new CancellationFuture(); + if (whenCancelledUpdater.compareAndSet(this, null, cancellationFuture)) { + return cancellationFuture; + } else { + return this.whenCancelled; } - this.cause = cause; - ((CancellationFuture) whenCancelled()).doComplete(cause); } - @VisibleForTesting - State state() { - return state; + private enum ScheduleResult { + INVOKE_LATER, + INVOKE_IMMEDIATELY, } private static class CancellationFuture extends UnmodifiableFuture { @@ -543,14 +421,9 @@ protected void doComplete(@Nullable Throwable cause) { } } - private static class TimeoutFuture extends UnmodifiableFuture { - void doComplete() { - doComplete(null); - } - } - private static CancellationScheduler finished0(boolean server) { - final CancellationScheduler cancellationScheduler = new DefaultCancellationScheduler(0, server); + final CancellationScheduler cancellationScheduler = + new DefaultCancellationScheduler(Long.MAX_VALUE, server); cancellationScheduler.initAndStart(ImmediateEventExecutor.INSTANCE, noopCancellationTask); cancellationScheduler.finishNow(); return cancellationScheduler; diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/NoopCancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/NoopCancellationScheduler.java index bc30d28dd2b..046d246278e 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/NoopCancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/NoopCancellationScheduler.java @@ -45,7 +45,7 @@ public void init(EventExecutor eventLoop) { } @Override - public void start(CancellationTask task) { + public void start() { } @Override @@ -53,7 +53,8 @@ public void clearTimeout() { } @Override - public void clearTimeout(boolean resetTimeout) { + public boolean cancelScheduled() { + return false; } @Override @@ -100,12 +101,6 @@ public CompletableFuture whenCancelled() { } @Override - public CompletableFuture whenTimingOut() { - return VOID_FUTURE; - } - - @Override - public CompletableFuture whenTimedOut() { - return VOID_FUTURE; + public void updateTask(CancellationTask cancellationTask) { } } diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java index fb4dea410dc..2e6e2906ffa 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java @@ -388,13 +388,13 @@ public CompletableFuture whenRequestCancelled() { @Deprecated @Override public CompletableFuture whenRequestTimingOut() { - return requestCancellationScheduler.whenTimingOut(); + return requestCancellationScheduler.whenCancelling().handle((v, e) -> null); } @Deprecated @Override public CompletableFuture whenRequestTimedOut() { - return requestCancellationScheduler.whenTimedOut(); + return requestCancellationScheduler.whenCancelled().handle((v, e) -> null); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractHttpResponseHandler.java b/core/src/main/java/com/linecorp/armeria/server/AbstractHttpResponseHandler.java index 446c51f8bd8..e49e5b8a120 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractHttpResponseHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractHttpResponseHandler.java @@ -241,14 +241,15 @@ final void maybeWriteAccessLog() { */ final void scheduleTimeout() { // Schedule the initial request timeout with the timeoutNanos in the CancellationScheduler - reqCtx.requestCancellationScheduler().start(newCancellationTask()); + reqCtx.requestCancellationScheduler().updateTask(newCancellationTask()); + reqCtx.requestCancellationScheduler().start(); } /** * Clears the scheduled request timeout. */ final void clearTimeout() { - reqCtx.requestCancellationScheduler().clearTimeout(false); + reqCtx.requestCancellationScheduler().cancelScheduled(); } final CancellationTask newCancellationTask() { @@ -260,8 +261,17 @@ public boolean canSchedule() { @Override public void run(Throwable cause) { - // This method will be invoked only when `canSchedule()` returns true. - assert !isDone(); + if (ctx.executor().inEventLoop()) { + doCancel(cause); + } else { + ctx.executor().execute(() -> doCancel(cause)); + } + } + + private void doCancel(Throwable cause) { + if (isDone()) { + return; + } if (cause instanceof ClosedStreamException) { // A stream or connection was already closed by a client diff --git a/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java b/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java new file mode 100644 index 00000000000..04b9b341b78 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java @@ -0,0 +1,350 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.reactivestreams.Subscriber; + +import com.google.common.collect.Sets; + +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.logging.RequestLogAccess; +import com.linecorp.armeria.common.stream.SubscriptionOption; +import com.linecorp.armeria.internal.testing.MockAddressResolverGroup; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.testing.junit5.common.EventLoopGroupExtension; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.netty.util.AttributeMap; +import io.netty.util.concurrent.EventExecutor; + +class ContextCancellationTest { + + private static final Set requests = Sets.newConcurrentHashSet(); + private static final BlockingQueue callbackThreads = new LinkedBlockingQueue<>(); + private static final Set callbackContexts = Sets.newConcurrentHashSet(); + private static final String eventLoopThreadPrefix = "context-cancellation-test"; + private static final String HEADER = "x-request-id"; + + @RegisterExtension + static EventLoopGroupExtension eventLoopGroup = new EventLoopGroupExtension(4, eventLoopThreadPrefix); + + @RegisterExtension + static ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + sb.service("/", (ctx, req) -> { + requests.add(req.headers().get(HEADER)); + return HttpResponse.streaming(); + }); + } + }; + + @BeforeEach + void beforeEach() { + requests.clear(); + callbackThreads.clear(); + callbackContexts.clear(); + } + + @Test + void cancel_beforeDelegate(TestInfo testInfo) { + final Throwable t = new Throwable(); + final CountingConnectionPoolListener connListener = new CountingConnectionPoolListener(); + final AtomicReference ctxRef = new AtomicReference<>(); + try (ClientFactory cf = ClientFactory + .builder() + .connectionPoolListener(connListener) + .workerGroup(eventLoopGroup.get(), false) + .build()) { + final HttpResponse res = server.webClient(cb -> { + cb.decorator((delegate, ctx, req) -> { + ctx.cancel(t); + ctxRef.set(ctx); + return delegate.execute(ctx, req); + }); + cb.decorator(TestInfoHeaderDecorator.newDecorator(testInfo)); + cb.decorator(AttachCallbacksDecorator.newDecorator()); + cb.factory(cf); + }).get("/"); + assertThatThrownBy(() -> res.aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(UnprocessedRequestException.class) + .hasRootCause(t); + assertThat(connListener.opened()).isEqualTo(0); + assertThat(requests).doesNotContain(testInfo.getDisplayName()); + // don't validate the thread since we haven't started with event loop scheduling yet + validateCallbackChecks(null); + } + } + + @Test + void cancel_beforeConnection(TestInfo testInfo) { + final Throwable t = new Throwable(); + final AtomicReference ctxRef = new AtomicReference<>(); + final CountingConnectionPoolListener connListener = new CountingConnectionPoolListener(); + try (ClientFactory cf = ClientFactory + .builder() + .workerGroup(eventLoopGroup.get(), false) + .addressResolverGroupFactory( + eventLoop -> MockAddressResolverGroup.of(ignored -> { + ctxRef.get().cancel(t); + try { + return InetAddress.getByName("127.0.0.1"); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + })) + .connectionPoolListener(connListener).build()) { + final HttpResponse res = WebClient.builder("http://foo.com:" + server.httpPort()) + .decorator((delegate, ctx, req) -> { + ctxRef.set(ctx); + return delegate.execute(ctx, req); + }) + .decorator(TestInfoHeaderDecorator.newDecorator(testInfo)) + .decorator(AttachCallbacksDecorator.newDecorator()) + .factory(cf) + .build() + .execute(HttpRequest.streaming(HttpMethod.POST, "/")); + assertThatThrownBy(() -> res.aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(UnprocessedRequestException.class) + .hasRootCause(t); + assertThat(requests).doesNotContain(testInfo.getDisplayName()); + // don't validate the thread since we haven't started with event loop scheduling yet + validateCallbackChecks(null); + } + } + + @Test + void cancel_afterConnection(TestInfo testInfo) { + final Throwable t = new Throwable(); + final AtomicReference ctxRef = new AtomicReference<>(); + final CountingConnectionPoolListener connListener = new CountingConnectionPoolListener() { + @Override + public void connectionOpen(SessionProtocol protocol, InetSocketAddress remoteAddr, + InetSocketAddress localAddr, AttributeMap attrs) + throws Exception { + super.connectionOpen(protocol, remoteAddr, localAddr, attrs); + ctxRef.get().cancel(t); + } + }; + try (ClientFactory cf = ClientFactory + .builder() + .workerGroup(eventLoopGroup.get(), false) + .connectionPoolListener(connListener) + .build()) { + final HttpResponse res = server.webClient(cb -> { + cb.decorator((delegate, ctx, req) -> { + ctxRef.set(ctx); + return delegate.execute(ctx, req); + }); + cb.decorator(TestInfoHeaderDecorator.newDecorator(testInfo)); + cb.decorator(AttachCallbacksDecorator.newDecorator()); + cb.factory(cf); + }).execute(HttpRequest.streaming(HttpMethod.POST, "/")); + assertThatThrownBy(() -> res.aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(UnprocessedRequestException.class) + .hasRootCause(t); + assertThat(connListener.opened()).isEqualTo(1); + assertThat(requests).doesNotContain(testInfo.getDisplayName()); + validateCallbackChecks(eventLoopThreadPrefix); + } + } + + @Test + void cancel_beforeSubscribe(TestInfo testInfo) { + final Throwable t = new Throwable(); + final AtomicReference ctxRef = new AtomicReference<>(); + final CountingConnectionPoolListener connListener = new CountingConnectionPoolListener(); + try (ClientFactory cf = ClientFactory + .builder() + .workerGroup(eventLoopGroup.get(), false) + .connectionPoolListener(connListener) + .build()) { + final HttpResponse res = server.webClient(cb -> { + cb.decorator((delegate, ctx, req) -> { + ctxRef.set(ctx); + return delegate.execute(ctx, req); + }); + cb.decorator(TestInfoHeaderDecorator.newDecorator(testInfo)); + cb.decorator(AttachCallbacksDecorator.newDecorator()); + cb.factory(cf); + }).execute(new DelegatingHttpRequest(HttpRequest.streaming(HttpMethod.POST, "/")) { + @Override + public void subscribe(Subscriber subscriber, EventExecutor executor, + SubscriptionOption... options) { + ctxRef.get().cancel(t); + } + }); + assertThatThrownBy(() -> res.aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(UnprocessedRequestException.class) + .hasRootCause(t); + assertThat(connListener.opened()).isEqualTo(1); + assertThat(requests).doesNotContain(testInfo.getDisplayName()); + validateCallbackChecks(eventLoopThreadPrefix); + } + } + + @Test + void cancel_beforeWriteFinished(TestInfo testInfo) { + final Throwable t = new Throwable(); + final AtomicReference ctxRef = new AtomicReference<>(); + final CountingConnectionPoolListener connListener = new CountingConnectionPoolListener(); + try (ClientFactory cf = ClientFactory + .builder() + .workerGroup(eventLoopGroup.get(), false) + .connectionPoolListener(connListener) + .build()) { + final HttpResponse res = server.webClient(cb -> { + cb.decorator((delegate, ctx, req) -> { + ctxRef.set(ctx); + return delegate.execute(ctx, req); + }); + cb.decorator(TestInfoHeaderDecorator.newDecorator(testInfo)); + cb.decorator(AttachCallbacksDecorator.newDecorator()); + cb.factory(cf); + }).execute(new DelegatingHttpRequest(HttpRequest.streaming(HttpMethod.POST, "/")) { + @Override + public void subscribe(Subscriber subscriber, EventExecutor executor, + SubscriptionOption... options) { + super.subscribe(subscriber, executor, options); + ctxRef.get().cancel(t); + } + }); + assertThatThrownBy(() -> res.aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCause(t); + assertThat(connListener.opened()).isEqualTo(1); + validateCallbackChecks(eventLoopThreadPrefix); + } + } + + @Test + void cancel_waitingForResponse(TestInfo testInfo) { + final Throwable t = new Throwable(); + final CountingConnectionPoolListener connListener = new CountingConnectionPoolListener(); + try (ClientFactory cf = ClientFactory + .builder() + .workerGroup(eventLoopGroup.get(), false) + .connectionPoolListener(connListener) + .build(); + ClientRequestContextCaptor captor = Clients.newContextCaptor()) { + final HttpResponse res = server.webClient(cb -> { + cb.factory(cf); + cb.decorator(TestInfoHeaderDecorator.newDecorator(testInfo)); + cb.decorator(AttachCallbacksDecorator.newDecorator()); + }).get("/"); + await().untilAsserted(() -> assertThat(requests).contains(testInfo.getDisplayName())); + captor.get().cancel(t); + assertThatThrownBy(() -> res.aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCause(t); + assertThat(connListener.opened()).isEqualTo(1); + validateCallbackChecks(eventLoopThreadPrefix); + } + } + + static void validateCallbackChecks(@Nullable String expectedPrefix) { + assertThat(callbackContexts).isEmpty(); + if (expectedPrefix != null) { + assertThat(callbackThreads).allSatisfy(t -> assertThat(t.getName()).startsWith(expectedPrefix)); + } + } + + private static class TestInfoHeaderDecorator extends SimpleDecoratingHttpClient { + + private final TestInfo testInfo; + + static Function newDecorator(TestInfo testInfo) { + return delegate -> new TestInfoHeaderDecorator(delegate, testInfo); + } + + /** + * Creates a new instance that decorates the specified {@link HttpClient}. + */ + protected TestInfoHeaderDecorator(HttpClient delegate, TestInfo testInfo) { + super(delegate); + this.testInfo = testInfo; + } + + @Override + public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Exception { + final RequestHeaders requestHeaders = req.headers().toBuilder() + .add(HEADER, testInfo.getDisplayName()) + .build(); + req = req.withHeaders(requestHeaders); + ctx.updateRequest(req); + return unwrap().execute(ctx, req); + } + } + + private static final class AttachCallbacksDecorator extends SimpleDecoratingHttpClient { + + static Function newDecorator() { + return AttachCallbacksDecorator::new; + } + + private AttachCallbacksDecorator(HttpClient delegate) { + super(delegate); + } + + @Override + public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Exception { + attachCallbackChecks(ctx.log()); + return unwrap().execute(ctx, req); + } + + private static void attachCallbackChecks(RequestLogAccess log) { + final Runnable runnable = () -> { + callbackThreads.add(Thread.currentThread()); + final RequestContext ctx = RequestContext.currentOrNull(); + if (ctx != null) { + callbackContexts.add(ctx); + } + }; + log.whenRequestComplete().thenRun(runnable); + log.whenComplete().thenRun(runnable); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/CountingConnectionPoolListener.java b/core/src/test/java/com/linecorp/armeria/client/CountingConnectionPoolListener.java index 07596021df2..23b014e5597 100644 --- a/core/src/test/java/com/linecorp/armeria/client/CountingConnectionPoolListener.java +++ b/core/src/test/java/com/linecorp/armeria/client/CountingConnectionPoolListener.java @@ -26,7 +26,7 @@ /** * A {@link ConnectionPoolListener} to count the number of connections which have been open and closed. */ -public final class CountingConnectionPoolListener implements ConnectionPoolListener { +public class CountingConnectionPoolListener implements ConnectionPoolListener { private final AtomicInteger opened = new AtomicInteger(); private final AtomicInteger closed = new AtomicInteger(); diff --git a/core/src/test/java/com/linecorp/armeria/client/DelegatingHttpRequest.java b/core/src/test/java/com/linecorp/armeria/client/DelegatingHttpRequest.java new file mode 100644 index 00000000000..3ef4cf5d398 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/DelegatingHttpRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import java.util.concurrent.CompletableFuture; + +import org.reactivestreams.Subscriber; + +import com.linecorp.armeria.common.AggregatedHttpRequest; +import com.linecorp.armeria.common.AggregationOptions; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.stream.SubscriptionOption; + +import io.netty.util.concurrent.EventExecutor; + +class DelegatingHttpRequest implements HttpRequest { + + private final HttpRequest delegate; + + DelegatingHttpRequest(HttpRequest delegate) { + this.delegate = delegate; + } + + @Override + public RequestHeaders headers() { + return delegate.headers(); + } + + @Override + public CompletableFuture aggregate(AggregationOptions options) { + return delegate.aggregate(options); + } + + @Override + public boolean isOpen() { + return delegate.isOpen(); + } + + @Override + public boolean isEmpty() { + return delegate.isEmpty(); + } + + @Override + public long demand() { + return delegate.demand(); + } + + @Override + public CompletableFuture whenComplete() { + return delegate.whenComplete(); + } + + @Override + public void subscribe(Subscriber subscriber, EventExecutor executor, + SubscriptionOption... options) { + delegate.subscribe(subscriber, executor, options); + } + + @Override + public void abort() { + delegate.abort(); + } + + @Override + public void abort(Throwable cause) { + delegate.abort(cause); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/Http1ConnectionReuseTest.java b/core/src/test/java/com/linecorp/armeria/client/Http1ConnectionReuseTest.java index 9ede0490d20..07bc771e2a0 100644 --- a/core/src/test/java/com/linecorp/armeria/client/Http1ConnectionReuseTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/Http1ConnectionReuseTest.java @@ -24,27 +24,20 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import org.reactivestreams.Subscriber; -import com.linecorp.armeria.common.AggregatedHttpRequest; import com.linecorp.armeria.common.AggregatedHttpResponse; -import com.linecorp.armeria.common.AggregationOptions; import com.linecorp.armeria.common.HttpMethod; -import com.linecorp.armeria.common.HttpObject; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; -import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.SessionProtocol; -import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.testing.junit5.server.ServerExtension; -import io.netty.util.concurrent.EventExecutor; - class Http1ConnectionReuseTest { private static final List remoteAddresses = new ArrayList<>(3); + private static final HttpRequest REQUEST = HttpRequest.of(HttpMethod.GET, "/"); @RegisterExtension static final ServerExtension server = new ServerExtension() { @@ -73,54 +66,11 @@ void returnToThePoolAfterRequestIsComplete() { } private static HttpRequest httpRequest(CompletableFuture future) { - return new HttpRequest() { - private final HttpRequest delegate = HttpRequest.of(HttpMethod.GET, "/"); - - @Override - public RequestHeaders headers() { - return delegate.headers(); - } - - @Override - public boolean isOpen() { - return delegate.isOpen(); - } - - @Override - public boolean isEmpty() { - return delegate.isEmpty(); - } - - @Override - public long demand() { - return delegate.demand(); - } - + return new DelegatingHttpRequest(REQUEST) { @Override public CompletableFuture whenComplete() { return future; } - - @Override - public void subscribe(Subscriber subscriber, EventExecutor executor, - SubscriptionOption... options) { - delegate.subscribe(subscriber, executor, options); - } - - @Override - public void abort() { - delegate.abort(); - } - - @Override - public void abort(Throwable cause) { - delegate.abort(cause); - } - - @Override - public CompletableFuture aggregate(AggregationOptions options) { - return delegate.aggregate(options); - } }; } diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpClientResponseTimeoutTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpClientResponseTimeoutTest.java index ce8785c5da4..694080c9806 100644 --- a/core/src/test/java/com/linecorp/armeria/client/HttpClientResponseTimeoutTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/HttpClientResponseTimeoutTest.java @@ -39,7 +39,6 @@ import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.CancellationException; import com.linecorp.armeria.common.HttpResponse; -import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.TimeoutException; import com.linecorp.armeria.common.util.TimeoutMode; import com.linecorp.armeria.server.ServerBuilder; @@ -76,7 +75,8 @@ void shouldSetResponseTimeoutWithNoTimeout() { @ParameterizedTest @ArgumentsSource(TimeoutDecoratorSource.class) - void setRequestTimeoutAtPendingTimeoutTask(Consumer timeoutCustomizer) { + void setRequestTimeoutAtPendingTimeoutTask(Consumer timeoutCustomizer, + boolean unprocessed) { final WebClient client = WebClient .builder(server.httpUri()) .option(ClientOptions.RESPONSE_TIMEOUT_MILLIS.newValue(30L)) @@ -86,11 +86,20 @@ void setRequestTimeoutAtPendingTimeoutTask(Consumer { - assertThatThrownBy(() -> client.get("/no-timeout").aggregate().join()) - .isInstanceOf(CompletionException.class) - .hasCauseInstanceOf(ResponseTimeoutException.class); - }); + if (unprocessed) { + await().timeout(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThatThrownBy(() -> client.get("/no-timeout").aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(UnprocessedRequestException.class) + .hasRootCauseInstanceOf(ResponseTimeoutException.class); + }); + } else { + await().timeout(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThatThrownBy(() -> client.get("/no-timeout").aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(ResponseTimeoutException.class); + }); + } } @Test @@ -133,7 +142,7 @@ void timeoutWithContext() { await().timeout(Duration.ofSeconds(5)).untilAsserted(() -> { assertThatThrownBy(response::join) .isInstanceOf(CompletionException.class) - .hasCauseInstanceOf(ResponseTimeoutException.class); + .hasRootCauseInstanceOf(ResponseTimeoutException.class); }); assertThat(cctx.isTimedOut()).isTrue(); @@ -152,7 +161,8 @@ void cancel() { .build(); assertThatThrownBy(() -> client.get("/no-timeout").aggregate().join()) .isInstanceOf(CompletionException.class) - .hasCauseInstanceOf(CancellationException.class); + .hasCauseInstanceOf(UnprocessedRequestException.class) + .hasRootCauseInstanceOf(CancellationException.class); } @Test @@ -169,7 +179,7 @@ void cancelWithContext() { await().timeout(Duration.ofSeconds(5)).untilAsserted(() -> { assertThatThrownBy(response::join) .isInstanceOf(CompletionException.class) - .hasCauseInstanceOf(CancellationException.class); + .hasRootCauseInstanceOf(CancellationException.class); }); assertThat(cctx.isCancelled()).isTrue(); @@ -190,7 +200,7 @@ void cancelWithException() { await().timeout(Duration.ofSeconds(5)).untilAsserted(() -> { assertThatThrownBy(response::join) .isInstanceOf(CompletionException.class) - .hasCauseInstanceOf(IllegalStateException.class); + .hasRootCauseInstanceOf(IllegalStateException.class); }); assertThat(cctx.isCancelled()).isTrue(); @@ -231,12 +241,15 @@ void timeoutWithWebClientPreparation(long timeoutMillisForRequest, long timeoutM private static class TimeoutDecoratorSource implements ArgumentsProvider { @Override public Stream provideArguments(ExtensionContext extensionContext) { - final Stream> timeoutCustomizers = Stream.of( - ctx -> ctx.setResponseTimeoutMillis(TimeoutMode.SET_FROM_NOW, 1000), - ctx -> ctx.setResponseTimeoutMillis(TimeoutMode.SET_FROM_START, 1000), - RequestContext::timeoutNow + return Stream.of( + Arguments.of( + (Consumer) ctx -> ctx.setResponseTimeoutMillis( + TimeoutMode.SET_FROM_NOW, 1000), false), + Arguments.of( + (Consumer) ctx -> ctx.setResponseTimeoutMillis( + TimeoutMode.SET_FROM_START, 1000), false), + Arguments.of((Consumer) ClientRequestContext::timeoutNow, true) ); - return timeoutCustomizers.map(Arguments::of); } } } diff --git a/core/src/test/java/com/linecorp/armeria/common/ContextPushHookTest.java b/core/src/test/java/com/linecorp/armeria/common/ContextPushHookTest.java index acd776f8d88..6d76d7b4dc7 100644 --- a/core/src/test/java/com/linecorp/armeria/common/ContextPushHookTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/ContextPushHookTest.java @@ -135,7 +135,9 @@ void shouldRunHooksWhenContextIsPushed() { hookEvents.clear(); response = client.get("http://foo.com:" + server.httpPort() + "/virtualhost"); assertThat(response.status()).isEqualTo(HttpStatus.OK); - assertThat(hookEvents).containsExactly( + // we don't do containsExactly here because there is no easy way to guarantee that + // all context hooks from the previous request have been completed + assertThat(hookEvents).contains( "ClientBuilder/push", "ClientContext/push", "ServerBuilder/push", diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/CancellationSchedulerTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/CancellationSchedulerTest.java index 380ae29d716..2739d9362b1 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/CancellationSchedulerTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/CancellationSchedulerTest.java @@ -25,7 +25,12 @@ import static org.awaitility.Awaitility.await; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -48,15 +53,7 @@ class CancellationSchedulerTest { private static final EventExecutor eventExecutor = CommonPools.workerGroup().next(); - private static final CancellationTask noopTask = new CancellationTask() { - @Override - public boolean canSchedule() { - return true; - } - - @Override - public void run(Throwable cause) {} - }; + private static final CancellationTask noopTask = cause -> {}; private static void executeInEventLoop(long initTimeoutNanos, Consumer task) { @@ -95,7 +92,7 @@ void extendTimeout() { @Test void setTimeoutFromNow() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(500)); assertTimeoutWithTolerance(scheduler.timeoutNanos(), MILLISECONDS.toNanos(500)); @@ -104,7 +101,7 @@ void setTimeoutFromNow() { @Test void setTimeoutFromNowZero() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); assertThatThrownBy(() -> scheduler.setTimeoutNanos(SET_FROM_NOW, 0)) .isInstanceOf(IllegalArgumentException.class) @@ -114,7 +111,7 @@ void setTimeoutFromNowZero() { @Test void setTimeoutFromNowMultipleNonZero() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(500)); }); @@ -122,7 +119,7 @@ void setTimeoutFromNowMultipleNonZero() { @Test void cancelTimeoutBeforeDeadline() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); scheduler.clearTimeout(); assertThat(scheduler.isFinished()).isFalse(); @@ -131,7 +128,7 @@ void cancelTimeoutBeforeDeadline() { @Test void cancelTimeoutAfterDeadline() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.finishNow(); scheduler.clearTimeout(); assertThat(scheduler.isFinished()).isTrue(); @@ -141,14 +138,14 @@ void cancelTimeoutAfterDeadline() { @Test void cancelTimeoutBySettingTimeoutZero() { executeInEventLoop(1000, scheduler -> { - scheduler.setTimeoutNanos(SET_FROM_START, 0); - assertThat(scheduler.state()).isEqualTo(CancellationScheduler.State.INACTIVE); + scheduler.setTimeoutNanos(SET_FROM_START, Long.MAX_VALUE); + assertThat(scheduler.state()).isEqualTo(State.PENDING); }); } @Test void scheduleTimeoutWhenFinished() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.finishNow(); assertThat(scheduler.isFinished()).isTrue(); scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); @@ -158,7 +155,7 @@ void scheduleTimeoutWhenFinished() { @Test void extendTimeoutWhenScheduled() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { final long timeoutNanos = MILLISECONDS.toNanos(1000); scheduler.setTimeoutNanos(SET_FROM_NOW, timeoutNanos); final long currentTimeoutNanos = scheduler.timeoutNanos(); @@ -170,7 +167,7 @@ void extendTimeoutWhenScheduled() { @Test void extendTimeoutWhenFinished() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.finishNow(); assertThat(scheduler.isFinished()).isTrue(); scheduler.setTimeoutNanos(EXTEND, MILLISECONDS.toNanos(1000)); @@ -180,7 +177,7 @@ void extendTimeoutWhenFinished() { @Test void cancelTimeoutWhenScheduled() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); scheduler.clearTimeout(); }); @@ -188,7 +185,7 @@ void cancelTimeoutWhenScheduled() { @Test void cancelTimeoutWhenFinished() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.finishNow(); scheduler.clearTimeout(); assertThat(scheduler.isFinished()).isTrue(); @@ -197,7 +194,7 @@ void cancelTimeoutWhenFinished() { @Test void finishWhenFinished() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.finishNow(); assertThat(scheduler.isFinished()).isTrue(); scheduler.finishNow(); @@ -209,7 +206,7 @@ void finishWhenFinished() { void setTimeoutFromStartAfterClear() { final AtomicBoolean completed = new AtomicBoolean(); - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.clearTimeout(); final long newTimeoutNanos = MILLISECONDS.toNanos(1123); scheduler.setTimeoutNanos(SET_FROM_START, newTimeoutNanos); @@ -227,7 +224,7 @@ void setTimeoutFromStartAfterClear() { @Test void setTimeoutFromStartAfterClearAndFinished() { final AtomicBoolean completed = new AtomicBoolean(); - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.clearTimeout(); eventExecutor.schedule(() -> { final long newTimeoutNanos = MILLISECONDS.toNanos(1123); @@ -241,7 +238,7 @@ void setTimeoutFromStartAfterClearAndFinished() { @Test void cancellationCause() { - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { scheduler.finishNow(new IllegalStateException()); assertThat(scheduler.isFinished()).isTrue(); assertThat(scheduler.cause()).isInstanceOf(IllegalStateException.class); @@ -251,11 +248,11 @@ void cancellationCause() { @Test void whenTimingOutAndWhenTimedOut() { final AtomicReference schedulerRef = new AtomicReference<>(); - final AtomicReference> whenTimedOutRef = new AtomicReference<>(); + final AtomicReference> whenTimedOutRef = new AtomicReference<>(); final AtomicBoolean completed = new AtomicBoolean(); final AtomicBoolean passed = new AtomicBoolean(); eventExecutor.execute(() -> { - final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(0); + final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(Long.MAX_VALUE); final CancellationTask task = new CancellationTask() { @Override public boolean canSchedule() { @@ -265,9 +262,9 @@ public boolean canSchedule() { @Override public void run(Throwable cause) { assertThat(cause).isInstanceOf(RequestTimeoutException.class); - assertThat(scheduler.whenTimingOut()).isDone(); + assertThat(scheduler.whenCancelling()).isDone(); assertThat(scheduler.isFinished()).isTrue(); - assertThat(scheduler.whenTimedOut()).isNotDone(); + assertThat(scheduler.whenCancelled()).isNotDone(); passed.set(true); } }; @@ -275,10 +272,10 @@ public void run(Throwable cause) { assertThat(scheduler.isFinished()).isFalse(); scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); - assertThat(scheduler.state()).isEqualTo(CancellationScheduler.State.SCHEDULED); + assertThat(scheduler.state()).isEqualTo(State.PENDING); schedulerRef.set(scheduler); - whenTimedOutRef.set(scheduler.whenTimedOut()); + whenTimedOutRef.set(scheduler.whenCancelled()); completed.set(true); }); await().untilTrue(passed); @@ -291,7 +288,7 @@ public void run(Throwable cause) { void whenTimingOutAndWhenTimedOut2() { final AtomicReference> whenTimingOutRef = new AtomicReference<>(); final AtomicReference> whenTimedOutRef = new AtomicReference<>(); - executeInEventLoop(0, scheduler -> { + executeInEventLoop(Long.MAX_VALUE, scheduler -> { final CompletableFuture whenTimingOut = scheduler.whenCancelling(); final CompletableFuture whenTimedOut = scheduler.whenCancelled(); assertThat(whenTimingOut).isNotDone(); @@ -321,7 +318,8 @@ void whenCancellingAndWhenCancelled(boolean server) { } eventExecutor.execute(() -> { - final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(0, server); + final DefaultCancellationScheduler scheduler = + new DefaultCancellationScheduler(Long.MAX_VALUE, server); final CancellationTask task = new CancellationTask() { @Override public boolean canSchedule() { @@ -341,7 +339,7 @@ public void run(Throwable cause) { assertThat(scheduler.isFinished()).isFalse(); scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); - assertThat(scheduler.state()).isEqualTo(CancellationScheduler.State.SCHEDULED); + assertThat(scheduler.state()).isEqualTo(State.PENDING); schedulerRef.set(scheduler); whenCancellingRef.set(scheduler.whenCancelling()); @@ -365,7 +363,7 @@ void pendingTimeout() { scheduler.setTimeoutNanos(SET_FROM_NOW, 1000); assertThat(scheduler.timeoutNanos()).isEqualTo(1000); - scheduler.clearTimeout(false); + scheduler.cancelScheduled(); assertThat(scheduler.timeoutNanos()).isEqualTo(1000); scheduler.clearTimeout(); assertThat(scheduler.timeoutNanos()).isZero(); @@ -390,7 +388,7 @@ void evaluatePendingTimeout() { assertTimeoutWithTolerance(scheduler.timeoutNanos(), MILLISECONDS.toNanos(1000)); scheduler = new DefaultCancellationScheduler(MILLISECONDS.toNanos(1000)); - scheduler.clearTimeout(false); + scheduler.cancelScheduled(); scheduler.initAndStart(eventExecutor, noopTask); assertThat(scheduler.timeoutNanos()).isEqualTo(MILLISECONDS.toNanos(1000)); @@ -432,8 +430,8 @@ void multiple_ClearTimeoutInWhenCancelling() { final AtomicBoolean completed = new AtomicBoolean(); final CancellationScheduler scheduler = new DefaultCancellationScheduler(MILLISECONDS.toNanos(100)); scheduler.whenCancelling().thenRun(() -> { - scheduler.clearTimeout(false); - scheduler.clearTimeout(false); + scheduler.cancelScheduled(); + scheduler.cancelScheduled(); completed.set(true); }); eventExecutor.execute(() -> { @@ -446,7 +444,7 @@ void multiple_ClearTimeoutInWhenCancelling() { @Test void immediateFinishTriggersCompletion() { - final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(0); + final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(Long.MAX_VALUE); scheduler.init(eventExecutor); final Throwable throwable = new Throwable(); @@ -465,7 +463,7 @@ void immediateFinishTriggersCompletion() { @ParameterizedTest @ValueSource(booleans = {true, false}) void immediateFinishWithoutCause(boolean server) { - final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(0, server); + final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(Long.MAX_VALUE, server); scheduler.init(eventExecutor); @@ -484,6 +482,154 @@ void immediateFinishWithoutCause(boolean server) { } } + @Test + void immediateCancellation() { + // Tests that there is no need to go through the event loop for task invocation + final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(Long.MAX_VALUE); + scheduler.init(eventExecutor); + final AtomicReference throwableRef = new AtomicReference<>(); + scheduler.updateTask(throwableRef::set); + + final Throwable throwable = new Throwable(); + scheduler.finishNow(throwable); + assertThat(scheduler.cause()).isSameAs(throwable); + assertThat(throwableRef.get()).isSameAs(throwable); + } + + @Test + void concurrentUpdateTask_onlyOneExecutedIfNotFinished() throws Exception { + final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(Long.MAX_VALUE); + scheduler.init(eventExecutor); + scheduler.start(); + final int numTasks = 10; + final AtomicInteger atomicInteger = new AtomicInteger(); + final ExecutorService executor = Executors.newFixedThreadPool(numTasks); + final CountDownLatch waitLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(numTasks); + try { + for (int i = 0; i < numTasks; i++) { + executor.execute(() -> { + try { + waitLatch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + scheduler.updateTask(cause -> atomicInteger.incrementAndGet()); + doneLatch.countDown(); + }); + } + } finally { + executor.shutdown(); + } + waitLatch.countDown(); + doneLatch.await(); + + scheduler.finishNow(); + assertThat(atomicInteger.get()).isEqualTo(1); + } + + @Test + void concurrentUpdateTask_allExecutedIfFinished() throws Exception { + final DefaultCancellationScheduler scheduler = new DefaultCancellationScheduler(Long.MAX_VALUE); + scheduler.init(eventExecutor); + scheduler.start(); + final int numTasks = 10; + final AtomicInteger atomicInteger = new AtomicInteger(); + final ExecutorService executor = Executors.newFixedThreadPool(numTasks); + final CountDownLatch waitLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(numTasks); + try { + for (int i = 0; i < numTasks; i++) { + executor.execute(() -> { + try { + waitLatch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + scheduler.updateTask(cause -> atomicInteger.incrementAndGet()); + doneLatch.countDown(); + }); + } + } finally { + executor.shutdown(); + } + scheduler.finishNow(); + waitLatch.countDown(); + doneLatch.await(); + + assertThat(atomicInteger.get()).isEqualTo(numTasks); + } + + @Test + void timeoutNanos_fromNow() throws Exception { + final AtomicLong ticker = new AtomicLong(); + final DefaultCancellationScheduler scheduler = + new DefaultCancellationScheduler(1000, false, ticker::get); + scheduler.init(eventExecutor); + assertThat(scheduler.timeoutNanos()).isEqualTo(1000); + + scheduler.setTimeoutNanos(SET_FROM_NOW, 5000); + assertThat(scheduler.timeoutNanos()).isEqualTo(5000); + + ticker.addAndGet(2000); + + scheduler.start(); + + // 5000 (set from now) - 2000 (elapsed time until start) + assertThat(scheduler.timeoutNanos()).isEqualTo(3000); + + ticker.addAndGet(1000); + scheduler.setTimeoutNanos(SET_FROM_NOW, 5000); + // 1000 (since start time) + 5000 (set from now) + assertThat(scheduler.timeoutNanos()).isEqualTo(6000); + + scheduler.setTimeoutNanos(EXTEND, 1000); + // 6000 (previous timeout) + 1000 (extend) + assertThat(scheduler.timeoutNanos()).isEqualTo(7000); + + scheduler.clearTimeout(); + } + + @Test + void extendNanos_immediateExecution() throws Exception { + final AtomicLong ticker = new AtomicLong(); + final AtomicReference throwableRef = new AtomicReference<>(); + final DefaultCancellationScheduler scheduler = + new DefaultCancellationScheduler(1000, false, ticker::get); + scheduler.init(eventExecutor); + scheduler.updateTask(throwableRef::set); + scheduler.setTimeoutNanos(EXTEND, -2000); + assertThat(scheduler.timeoutNanos()).isEqualTo(-1000); + + // because the scheduler didn't start yet + assertThat(throwableRef).hasNullValue(); + + scheduler.setTimeoutNanos(EXTEND, 2000); + assertThat(scheduler.timeoutNanos()).isEqualTo(1000); + scheduler.start(); + scheduler.setTimeoutNanos(EXTEND, -2000); + assertThat(throwableRef).hasValueMatching(t -> t instanceof ResponseTimeoutException); + + scheduler.clearTimeout(); + } + + @Test + void zero_notInfinite() throws Exception { + final AtomicLong ticker = new AtomicLong(); + final AtomicReference throwableRef = new AtomicReference<>(); + final DefaultCancellationScheduler scheduler = + new DefaultCancellationScheduler(1000, false, ticker::get); + scheduler.init(eventExecutor); + scheduler.updateTask(throwableRef::set); + + assertThat(scheduler.timeoutNanos()).isEqualTo(1000); + scheduler.setTimeoutNanos(EXTEND, -1000); + assertThat(scheduler.timeoutNanos()).isEqualTo(0); + + scheduler.finishNow(); + assertThat(throwableRef).doesNotHaveNullValue(); + } + static void assertTimeoutWithTolerance(long actualNanos, long expectedNanos) { assertTimeoutWithTolerance(actualNanos, expectedNanos, MILLISECONDS.toNanos(200)); } From e6f95c16cce331bce2f5a46d7c0a8ce03ef6d915 Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Tue, 9 Jul 2024 17:22:50 +0900 Subject: [PATCH 2/7] minor nit --- .../linecorp/armeria/client/AbstractHttpRequestHandler.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java index 915a1826df3..b1ea42f6052 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java @@ -200,6 +200,11 @@ final boolean tryInitialize() { if (scheduler != null) { scheduler.updateTask(newCancellationTask()); } + if (ctx.isCancelled()) { + // The previous cancellation task wraps the cause with an UnprocessedRequestException + // so we return early + return false; + } return true; } From 06012a1e86657707b5878b9cf93f5aff2ebd0a90 Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Thu, 1 Aug 2024 14:32:47 +0900 Subject: [PATCH 3/7] address comments by @ikhoon --- .../armeria/internal/common/CancellationScheduler.java | 3 ++- .../internal/common/DefaultCancellationScheduler.java | 4 ++-- .../com/linecorp/armeria/common/ContextPushHookTest.java | 4 +--- .../armeria/internal/common/CancellationSchedulerTest.java | 6 +++--- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java index 2145907e50f..04f40a680bb 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java @@ -121,13 +121,14 @@ default void finishNow() { enum State { INIT, - PENDING, + SCHEDULED, FINISHED, } /** * A cancellation task invoked by the scheduler when its timeout exceeds or invoke by the user. */ + @FunctionalInterface interface CancellationTask { /** * Returns {@code true} if the cancellation task can be scheduled. diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java index e5792084926..39dda7ce3e4 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java @@ -123,11 +123,11 @@ public void start() { if (state != State.INIT) { return; } - state = State.PENDING; + state = State.SCHEDULED; startTimeNanos = ticker.read(); if (timeoutMode == TimeoutMode.SET_FROM_NOW) { final long elapsedTimeNanos = startTimeNanos - setFromNowStartNanos; - timeoutNanos = LongMath.saturatedSubtract(timeoutNanos, elapsedTimeNanos); + timeoutNanos = Long.max(LongMath.saturatedSubtract(timeoutNanos, elapsedTimeNanos), 0); } if (timeoutNanos != Long.MAX_VALUE) { scheduledFuture = eventLoop().schedule(() -> invokeTask(null), timeoutNanos, NANOSECONDS); diff --git a/core/src/test/java/com/linecorp/armeria/common/ContextPushHookTest.java b/core/src/test/java/com/linecorp/armeria/common/ContextPushHookTest.java index 6d76d7b4dc7..acd776f8d88 100644 --- a/core/src/test/java/com/linecorp/armeria/common/ContextPushHookTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/ContextPushHookTest.java @@ -135,9 +135,7 @@ void shouldRunHooksWhenContextIsPushed() { hookEvents.clear(); response = client.get("http://foo.com:" + server.httpPort() + "/virtualhost"); assertThat(response.status()).isEqualTo(HttpStatus.OK); - // we don't do containsExactly here because there is no easy way to guarantee that - // all context hooks from the previous request have been completed - assertThat(hookEvents).contains( + assertThat(hookEvents).containsExactly( "ClientBuilder/push", "ClientContext/push", "ServerBuilder/push", diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/CancellationSchedulerTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/CancellationSchedulerTest.java index 2739d9362b1..2b51fa69e88 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/CancellationSchedulerTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/CancellationSchedulerTest.java @@ -139,7 +139,7 @@ void cancelTimeoutAfterDeadline() { void cancelTimeoutBySettingTimeoutZero() { executeInEventLoop(1000, scheduler -> { scheduler.setTimeoutNanos(SET_FROM_START, Long.MAX_VALUE); - assertThat(scheduler.state()).isEqualTo(State.PENDING); + assertThat(scheduler.state()).isEqualTo(State.SCHEDULED); }); } @@ -272,7 +272,7 @@ public void run(Throwable cause) { assertThat(scheduler.isFinished()).isFalse(); scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); - assertThat(scheduler.state()).isEqualTo(State.PENDING); + assertThat(scheduler.state()).isEqualTo(State.SCHEDULED); schedulerRef.set(scheduler); whenTimedOutRef.set(scheduler.whenCancelled()); @@ -339,7 +339,7 @@ public void run(Throwable cause) { assertThat(scheduler.isFinished()).isFalse(); scheduler.setTimeoutNanos(SET_FROM_NOW, MILLISECONDS.toNanos(1000)); - assertThat(scheduler.state()).isEqualTo(State.PENDING); + assertThat(scheduler.state()).isEqualTo(State.SCHEDULED); schedulerRef.set(scheduler); whenCancellingRef.set(scheduler.whenCancelling()); From 5cb756b03356742fbb8d7795d152b43b037c1082 Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Thu, 8 Aug 2024 16:21:07 +0900 Subject: [PATCH 4/7] address comment by @trustin --- .../internal/common/CancellationScheduler.java | 10 ++++------ .../common/DefaultCancellationScheduler.java | 12 ++++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java index 04f40a680bb..18dc7276153 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java @@ -16,6 +16,8 @@ package com.linecorp.armeria.internal.common; +import static com.linecorp.armeria.internal.common.DefaultCancellationScheduler.translateTimeoutNanos; + import java.util.concurrent.CompletableFuture; import com.linecorp.armeria.common.annotation.Nullable; @@ -26,16 +28,12 @@ public interface CancellationScheduler { static CancellationScheduler ofClient(long timeoutNanos) { - if (timeoutNanos == 0) { - timeoutNanos = Long.MAX_VALUE; - } + timeoutNanos = translateTimeoutNanos(timeoutNanos); return new DefaultCancellationScheduler(timeoutNanos, false); } static CancellationScheduler ofServer(long timeoutNanos) { - if (timeoutNanos == 0) { - timeoutNanos = Long.MAX_VALUE; - } + timeoutNanos = translateTimeoutNanos(timeoutNanos); return new DefaultCancellationScheduler(timeoutNanos, true); } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java index 39dda7ce3e4..0c4daedc522 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java @@ -432,4 +432,16 @@ private static CancellationScheduler finished0(boolean server) { cancellationScheduler.finishNow(); return cancellationScheduler; } + + static long translateTimeoutNanos(long timeoutNanos) { + if (timeoutNanos == Long.MAX_VALUE) { + // If the user specified MAX_VALUE, then use MAX_VALUE-1 since MAX_VALUE means no scheduling + timeoutNanos = Long.MAX_VALUE - 1; + } + if (timeoutNanos == 0) { + // If the user specified 0, then use MAX_VALUE which means no scheduling + timeoutNanos = Long.MAX_VALUE; + } + return timeoutNanos; + } } From 36c094ccb9657c361ce2cdb157219fe9ce580a42 Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Fri, 9 Aug 2024 10:29:03 +0900 Subject: [PATCH 5/7] address comments by @minwoox --- .../common/logging/DefaultRequestLog.java | 3 +- .../common/DefaultCancellationScheduler.java | 41 ++----------------- 2 files changed, 5 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java b/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java index 85e243a909e..7cd223bb086 100644 --- a/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java +++ b/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java @@ -428,8 +428,7 @@ private void updateFlags(int flags) { private static void completeSatisfiedFutures(RequestLogFuture[] satisfiedFutures, RequestLog log, RequestContext ctx) { if (!ctx.eventLoop().inEventLoop()) { - ctx.eventLoop().withoutContext().execute( - () -> completeSatisfiedFutures(satisfiedFutures, log, ctx)); + ctx.eventLoop().execute(() -> completeSatisfiedFutures(satisfiedFutures, log, ctx)); return; } for (RequestLogFuture f : satisfiedFutures) { diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java index 0c4daedc522..7682ac42861 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java @@ -22,7 +22,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import com.google.common.annotations.VisibleForTesting; import com.google.common.math.LongMath; @@ -42,14 +41,6 @@ final class DefaultCancellationScheduler implements CancellationScheduler { - private static final AtomicReferenceFieldUpdater - whenCancellingUpdater = AtomicReferenceFieldUpdater.newUpdater( - DefaultCancellationScheduler.class, CancellationFuture.class, "whenCancelling"); - - private static final AtomicReferenceFieldUpdater - whenCancelledUpdater = AtomicReferenceFieldUpdater.newUpdater( - DefaultCancellationScheduler.class, CancellationFuture.class, "whenCancelled"); - static final CancellationScheduler serverFinishedCancellationScheduler = finished0(true); static final CancellationScheduler clientFinishedCancellationScheduler = finished0(false); @@ -69,10 +60,8 @@ final class DefaultCancellationScheduler implements CancellationScheduler { private final ReentrantShortLock lock = new ReentrantShortLock(); private final boolean server; - @Nullable - private volatile CancellationFuture whenCancelling; - @Nullable - private volatile CancellationFuture whenCancelled; + private final CancellationFuture whenCancelling = new CancellationFuture(); + private final CancellationFuture whenCancelled = new CancellationFuture(); @VisibleForTesting DefaultCancellationScheduler(long timeoutNanos) { @@ -383,34 +372,12 @@ State state() { @Override public CompletableFuture whenCancelling() { - final CancellationFuture whenCancelling = this.whenCancelling; - if (whenCancelling != null) { - return whenCancelling; - } - final CancellationFuture cancellationFuture = new CancellationFuture(); - if (whenCancellingUpdater.compareAndSet(this, null, cancellationFuture)) { - return cancellationFuture; - } else { - final CancellationFuture oldWhenCancelling = this.whenCancelling; - assert oldWhenCancelling != null; - return oldWhenCancelling; - } + return whenCancelling; } @Override public CompletableFuture whenCancelled() { - final CancellationFuture whenCancelled = this.whenCancelled; - if (whenCancelled != null) { - return whenCancelled; - } - final CancellationFuture cancellationFuture = new CancellationFuture(); - if (whenCancelledUpdater.compareAndSet(this, null, cancellationFuture)) { - return cancellationFuture; - } else { - final CancellationFuture oldWhenCancelled = this.whenCancelled; - assert oldWhenCancelled != null; - return oldWhenCancelled; - } + return whenCancelled; } private enum ScheduleResult { From 610f83d813751be4f4269139790a80df754c0e53 Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Fri, 9 Aug 2024 11:06:43 +0900 Subject: [PATCH 6/7] don't validate log completion thread local before a connection is opened --- .../com/linecorp/armeria/client/ContextCancellationTest.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java b/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java index 04b9b341b78..4424599e7ac 100644 --- a/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java @@ -111,8 +111,6 @@ void cancel_beforeDelegate(TestInfo testInfo) { .hasRootCause(t); assertThat(connListener.opened()).isEqualTo(0); assertThat(requests).doesNotContain(testInfo.getDisplayName()); - // don't validate the thread since we haven't started with event loop scheduling yet - validateCallbackChecks(null); } } @@ -149,8 +147,6 @@ void cancel_beforeConnection(TestInfo testInfo) { .hasCauseInstanceOf(UnprocessedRequestException.class) .hasRootCause(t); assertThat(requests).doesNotContain(testInfo.getDisplayName()); - // don't validate the thread since we haven't started with event loop scheduling yet - validateCallbackChecks(null); } } From cc8db8310b7dbe154cf4758d729ec59a29fc1ded Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Fri, 9 Aug 2024 11:35:49 +0900 Subject: [PATCH 7/7] handle flaky --- .../linecorp/armeria/client/ContextCancellationTest.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java b/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java index 4424599e7ac..37c2859c85b 100644 --- a/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ContextCancellationTest.java @@ -48,6 +48,8 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.logging.RequestLogAccess; import com.linecorp.armeria.common.stream.SubscriptionOption; +import com.linecorp.armeria.common.util.SafeCloseable; +import com.linecorp.armeria.internal.common.RequestContextUtil; import com.linecorp.armeria.internal.testing.MockAddressResolverGroup; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.testing.junit5.common.EventLoopGroupExtension; @@ -245,7 +247,9 @@ void cancel_beforeWriteFinished(TestInfo testInfo) { public void subscribe(Subscriber subscriber, EventExecutor executor, SubscriptionOption... options) { super.subscribe(subscriber, executor, options); - ctxRef.get().cancel(t); + try (SafeCloseable ignored = RequestContextUtil.pop()) { + ctxRef.get().cancel(t); + } } }); assertThatThrownBy(() -> res.aggregate().join())