Skip to content

Commit

Permalink
Handle a null ExecutionResult in GraphqlService (#5816)
Browse files Browse the repository at this point in the history
Motivation:

If `GraphQL.execute(input)` returns a `CompletableFuture` completing
exceptionally, `NullPointException` is raised while handling
`ExecutionResult`.
```java
java.lang.NullPointerException: Cannot invoke "graphql.ExecutionResult.getData()" because "executionResult" is null
	at com.linecorp.armeria.server.graphql.DefaultGraphqlService.lambda$execute$1(DefaultGraphqlService.java:117)
	at java.base/java.util.concurrent.CompletableFuture.uniHandle(CompletableFuture.java:934)
	at java.base/java.util.concurrent.CompletableFuture$UniHandle.tryFire(CompletableFuture.java:911)
```

Modifications:

- Check if `cause != null` before accessing `executionResult` in the
callback of `executeGraphql(ctx, input)`.

Result:

- `NullPointerException` is no longer raised when `GraphqlService`
handles errors.
- Closes #5815
  • Loading branch information
ikhoon committed Jul 26, 2024
1 parent ed776bb commit 11652aa
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.graphql.protocol.GraphqlRequest;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.internal.server.graphql.protocol.GraphqlUtil;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.graphql.protocol.AbstractGraphqlService;
Expand All @@ -49,7 +50,7 @@ final class DefaultGraphqlService extends AbstractGraphqlService implements Grap
private final GraphQL graphQL;

private final Function<? super ServiceRequestContext,
? extends DataLoaderRegistry> dataLoaderRegistryFunction;
? extends DataLoaderRegistry> dataLoaderRegistryFunction;

private final boolean useBlockingTaskExecutor;

Expand Down Expand Up @@ -111,25 +112,35 @@ public CompletableFuture<ExecutionResult> executeGraphql(ServiceRequestContext c

private HttpResponse execute(
ServiceRequestContext ctx, ExecutionInput input, MediaType produceType) {
final CompletableFuture<ExecutionResult> future = executeGraphql(ctx, input);
return HttpResponse.of(
future.handle((executionResult, cause) -> {
if (executionResult.getData() instanceof Publisher) {
logger.warn("executionResult.getData() returns a {} that is not supported yet.",
executionResult.getData().toString());

return HttpResponse.ofJson(HttpStatus.NOT_IMPLEMENTED,
produceType,
toSpecification(
"Use GraphQL over WebSocket for subscription"));
}

if (executionResult.getErrors().isEmpty() && cause == null) {
return HttpResponse.ofJson(produceType, executionResult.toSpecification());
}

return errorHandler.handle(ctx, input, executionResult, cause);
}));
try {
final CompletableFuture<ExecutionResult> future = executeGraphql(ctx, input);
return HttpResponse.of(
future.handle((executionResult, cause) -> {
if (cause != null) {
cause = Exceptions.peel(cause);
return errorHandler.handle(ctx, input, null, cause);
}

if (executionResult.getData() instanceof Publisher) {
logger.warn("Use GraphQL over WebSocket for subscription. " +
"executionResult.getData(): {}", executionResult.getData().toString());

return HttpResponse.ofJson(HttpStatus.NOT_IMPLEMENTED,
produceType,
toSpecification(
"Use GraphQL over WebSocket for subscription"));
}

if (executionResult.getErrors().isEmpty()) {
return HttpResponse.ofJson(produceType, executionResult.toSpecification());
}

return errorHandler.handle(ctx, input, executionResult, null);
}));
} catch (Throwable cause) {
cause = Exceptions.peel(cause);
return errorHandler.handle(ctx, input, null, cause);
}
}

static Map<String, Object> toSpecification(String message) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ static GraphqlErrorHandler of() {
*/
@Nullable
HttpResponse handle(
ServiceRequestContext ctx, ExecutionInput input, ExecutionResult result, @Nullable Throwable cause);
ServiceRequestContext ctx, ExecutionInput input, @Nullable ExecutionResult result,
@Nullable Throwable cause);

/**
* Returns a composed {@link GraphqlErrorHandler} that applies this first and the specified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,83 +20,163 @@

import java.io.File;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.internal.testing.AnticipatedException;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import graphql.GraphQL;
import graphql.GraphQLError;
import graphql.GraphqlErrorException;
import graphql.execution.instrumentation.Instrumentation;
import graphql.execution.instrumentation.InstrumentationState;
import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters;
import graphql.schema.DataFetcher;
import graphql.schema.GraphQLSchema;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.SchemaGenerator;
import graphql.schema.idl.SchemaParser;
import graphql.schema.idl.TypeDefinitionRegistry;

class GraphqlErrorHandlerTest {

private static final AtomicBoolean shouldFailRequests = new AtomicBoolean();

private static GraphQL newGraphQL() throws Exception {
final File graphqlSchemaFile =
new File(GraphqlErrorHandlerTest.class.getResource("/testing/graphql/test.graphqls").toURI());
final SchemaParser schemaParser = new SchemaParser();
final SchemaGenerator schemaGenerator = new SchemaGenerator();
final TypeDefinitionRegistry typeRegistry = new TypeDefinitionRegistry();
typeRegistry.merge(schemaParser.parse(graphqlSchemaFile));
final RuntimeWiring.Builder runtimeWiringBuilder = RuntimeWiring.newRuntimeWiring();
final DataFetcher<String> foo = dataFetcher("foo");
runtimeWiringBuilder.type("Query",
typeWiring -> typeWiring.dataFetcher("foo", foo));
final DataFetcher<String> error = dataFetcher("error");
runtimeWiringBuilder.type("Query",
typeWiring -> typeWiring.dataFetcher("error", error));

final GraphQLSchema graphQLSchema = schemaGenerator.makeExecutableSchema(typeRegistry,
runtimeWiringBuilder.build());
final Instrumentation instrumentation = new Instrumentation() {
@Override
public InstrumentationState createState(
InstrumentationCreateStateParameters parameters) {
if (shouldFailRequests.get()) {
throw new AnticipatedException("external exception");
} else {
return Instrumentation.super.createState(parameters);
}
}
};

return new GraphQL.Builder(graphQLSchema)
.instrumentation(instrumentation)
.build();
}

private static final GraphqlErrorHandler errorHandler
= (ctx, input, result, cause) -> {
if (result == null) {
assertThat(cause).isNotNull();
return HttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR, MediaType.PLAIN_TEXT,
cause.getMessage());
}
final List<GraphQLError> errors = result.getErrors();
if (errors.stream().map(GraphQLError::getMessage).anyMatch(m -> m.endsWith("foo"))) {
return HttpResponse.of(HttpStatus.BAD_REQUEST);
}
return null;
};

private static DataFetcher<String> dataFetcher(String value) {
return environment -> {
final ServiceRequestContext ctx = GraphqlServiceContexts.get(environment);
// Make sure that a ServiceRequestContext is available
assertThat(ServiceRequestContext.current()).isSameAs(ctx);
throw GraphqlErrorException.newErrorException().message(value).build();
};
}

@RegisterExtension
static ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) throws Exception {
final File graphqlSchemaFile =
new File(getClass().getResource("/testing/graphql/test.graphqls").toURI());

final GraphqlErrorHandler errorHandler
= (ctx, input, result, cause) -> {
final List<GraphQLError> errors = result.getErrors();
if (errors.stream().map(GraphQLError::getMessage).anyMatch(m -> m.endsWith("foo"))) {
return HttpResponse.of(HttpStatus.BAD_REQUEST);
}
return null;
};

final GraphqlService service =
GraphqlService.builder()
.schemaFile(graphqlSchemaFile)
.runtimeWiring(c -> {
final DataFetcher<String> foo = dataFetcher("foo");
c.type("Query",
typeWiring -> typeWiring.dataFetcher("foo", foo));
final DataFetcher<String> error = dataFetcher("error");
c.type("Query",
typeWiring -> typeWiring.dataFetcher("error", error));
})
.graphql(newGraphQL())
.errorHandler(errorHandler)
.build();
sb.service("/graphql", service);
}
};

private static DataFetcher<String> dataFetcher(String value) {
return environment -> {
final ServiceRequestContext ctx = GraphqlServiceContexts.get(environment);
assertThat(ctx.eventLoop().inEventLoop()).isTrue();
// Make sure that a ServiceRequestContext is available
assertThat(ServiceRequestContext.current()).isSameAs(ctx);
throw GraphqlErrorException.newErrorException().message(value).build();
};
@RegisterExtension
static ServerExtension blockingServer = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) throws Exception {

final GraphqlService service =
GraphqlService.builder()
.graphql(newGraphQL())
.useBlockingTaskExecutor(true)
.errorHandler(errorHandler)
.build();
sb.service("/graphql", service);
}
};

@BeforeEach
void setUp() {
shouldFailRequests.set(false);
}

@Test
void handledError() {
@ValueSource(booleans = { true, false })
@ParameterizedTest
void handledError(boolean blocking) {
final HttpRequest request = HttpRequest.builder().post("/graphql")
.content(MediaType.GRAPHQL, "{foo}")
.build();
final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server;
final AggregatedHttpResponse response = server.blockingWebClient().execute(request);
assertThat(response.status()).isEqualTo(HttpStatus.BAD_REQUEST);
}

@Test
void unhandledError() {
@ValueSource(booleans = { true, false })
@ParameterizedTest
void unhandledGraphqlError(boolean blocking) {
final HttpRequest request = HttpRequest.builder().post("/graphql")
.content(MediaType.GRAPHQL, "{error}")
.build();
final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server;
final AggregatedHttpResponse response = server.blockingWebClient().execute(request);
assertThat(response.status()).isEqualTo(HttpStatus.OK);
}

@ValueSource(booleans = { true, false })
@ParameterizedTest
void unhandledException(boolean blocking) {
shouldFailRequests.set(true);
final HttpRequest request = HttpRequest.builder().post("/graphql")
.content(MediaType.GRAPHQL, "{error}")
.build();
final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server;
final AggregatedHttpResponse response = server.blockingWebClient().execute(request);
assertThat(response.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR);
assertThat(response.contentUtf8()).isEqualTo("external exception");
}
}

0 comments on commit 11652aa

Please sign in to comment.