Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle a null ExecutionResult in GraphqlService #5816

Merged
merged 3 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.graphql.protocol.GraphqlRequest;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.internal.server.graphql.protocol.GraphqlUtil;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.graphql.protocol.AbstractGraphqlService;
Expand All @@ -49,7 +50,7 @@ final class DefaultGraphqlService extends AbstractGraphqlService implements Grap
private final GraphQL graphQL;

private final Function<? super ServiceRequestContext,
? extends DataLoaderRegistry> dataLoaderRegistryFunction;
? extends DataLoaderRegistry> dataLoaderRegistryFunction;

private final boolean useBlockingTaskExecutor;

Expand Down Expand Up @@ -111,25 +112,35 @@ public CompletableFuture<ExecutionResult> executeGraphql(ServiceRequestContext c

private HttpResponse execute(
ServiceRequestContext ctx, ExecutionInput input, MediaType produceType) {
final CompletableFuture<ExecutionResult> future = executeGraphql(ctx, input);
return HttpResponse.of(
future.handle((executionResult, cause) -> {
if (executionResult.getData() instanceof Publisher) {
logger.warn("executionResult.getData() returns a {} that is not supported yet.",
executionResult.getData().toString());

return HttpResponse.ofJson(HttpStatus.NOT_IMPLEMENTED,
produceType,
toSpecification(
"Use GraphQL over WebSocket for subscription"));
}

if (executionResult.getErrors().isEmpty() && cause == null) {
return HttpResponse.ofJson(produceType, executionResult.toSpecification());
}

return errorHandler.handle(ctx, input, executionResult, cause);
}));
try {
final CompletableFuture<ExecutionResult> future = executeGraphql(ctx, input);
return HttpResponse.of(
future.handle((executionResult, cause) -> {
if (cause != null) {
cause = Exceptions.peel(cause);
return errorHandler.handle(ctx, input, null, cause);
}

if (executionResult.getData() instanceof Publisher) {
logger.warn("Use GraphQL over WebSocket for subscription. " +
"executionResult.getData(): {}", executionResult.getData().toString());

return HttpResponse.ofJson(HttpStatus.NOT_IMPLEMENTED,
produceType,
toSpecification(
"Use GraphQL over WebSocket for subscription"));
}

if (executionResult.getErrors().isEmpty()) {
return HttpResponse.ofJson(produceType, executionResult.toSpecification());
}

return errorHandler.handle(ctx, input, executionResult, null);
}));
} catch (Throwable cause) {
cause = Exceptions.peel(cause);
return errorHandler.handle(ctx, input, null, cause);
}
}

static Map<String, Object> toSpecification(String message) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ static GraphqlErrorHandler of() {
*/
@Nullable
HttpResponse handle(
ServiceRequestContext ctx, ExecutionInput input, ExecutionResult result, @Nullable Throwable cause);
ServiceRequestContext ctx, ExecutionInput input, @Nullable ExecutionResult result,
@Nullable Throwable cause);

/**
* Returns a composed {@link GraphqlErrorHandler} that applies this first and the specified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,83 +20,163 @@

import java.io.File;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.internal.testing.AnticipatedException;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import graphql.GraphQL;
import graphql.GraphQLError;
import graphql.GraphqlErrorException;
import graphql.execution.instrumentation.Instrumentation;
import graphql.execution.instrumentation.InstrumentationState;
import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters;
import graphql.schema.DataFetcher;
import graphql.schema.GraphQLSchema;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.SchemaGenerator;
import graphql.schema.idl.SchemaParser;
import graphql.schema.idl.TypeDefinitionRegistry;

class GraphqlErrorHandlerTest {

private static final AtomicBoolean shouldFailRequests = new AtomicBoolean();

private static GraphQL newGraphQL() throws Exception {
final File graphqlSchemaFile =
new File(GraphqlErrorHandlerTest.class.getResource("/testing/graphql/test.graphqls").toURI());
final SchemaParser schemaParser = new SchemaParser();
final SchemaGenerator schemaGenerator = new SchemaGenerator();
final TypeDefinitionRegistry typeRegistry = new TypeDefinitionRegistry();
typeRegistry.merge(schemaParser.parse(graphqlSchemaFile));
final RuntimeWiring.Builder runtimeWiringBuilder = RuntimeWiring.newRuntimeWiring();
final DataFetcher<String> foo = dataFetcher("foo");
runtimeWiringBuilder.type("Query",
typeWiring -> typeWiring.dataFetcher("foo", foo));
final DataFetcher<String> error = dataFetcher("error");
runtimeWiringBuilder.type("Query",
typeWiring -> typeWiring.dataFetcher("error", error));

final GraphQLSchema graphQLSchema = schemaGenerator.makeExecutableSchema(typeRegistry,
runtimeWiringBuilder.build());
final Instrumentation instrumentation = new Instrumentation() {
@Override
public InstrumentationState createState(
InstrumentationCreateStateParameters parameters) {
if (shouldFailRequests.get()) {
throw new AnticipatedException("external exception");
} else {
return Instrumentation.super.createState(parameters);
}
}
};

return new GraphQL.Builder(graphQLSchema)
.instrumentation(instrumentation)
.build();
}

private static final GraphqlErrorHandler errorHandler
= (ctx, input, result, cause) -> {
if (result == null) {
assertThat(cause).isNotNull();
return HttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR, MediaType.PLAIN_TEXT,
cause.getMessage());
}
final List<GraphQLError> errors = result.getErrors();
if (errors.stream().map(GraphQLError::getMessage).anyMatch(m -> m.endsWith("foo"))) {
return HttpResponse.of(HttpStatus.BAD_REQUEST);
}
return null;
};

private static DataFetcher<String> dataFetcher(String value) {
return environment -> {
final ServiceRequestContext ctx = GraphqlServiceContexts.get(environment);
// Make sure that a ServiceRequestContext is available
assertThat(ServiceRequestContext.current()).isSameAs(ctx);
throw GraphqlErrorException.newErrorException().message(value).build();
};
}

@RegisterExtension
static ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) throws Exception {
final File graphqlSchemaFile =
new File(getClass().getResource("/testing/graphql/test.graphqls").toURI());

final GraphqlErrorHandler errorHandler
= (ctx, input, result, cause) -> {
final List<GraphQLError> errors = result.getErrors();
if (errors.stream().map(GraphQLError::getMessage).anyMatch(m -> m.endsWith("foo"))) {
return HttpResponse.of(HttpStatus.BAD_REQUEST);
}
return null;
};

final GraphqlService service =
GraphqlService.builder()
.schemaFile(graphqlSchemaFile)
.runtimeWiring(c -> {
final DataFetcher<String> foo = dataFetcher("foo");
c.type("Query",
typeWiring -> typeWiring.dataFetcher("foo", foo));
final DataFetcher<String> error = dataFetcher("error");
c.type("Query",
typeWiring -> typeWiring.dataFetcher("error", error));
})
.graphql(newGraphQL())
.errorHandler(errorHandler)
.build();
sb.service("/graphql", service);
}
};

private static DataFetcher<String> dataFetcher(String value) {
return environment -> {
final ServiceRequestContext ctx = GraphqlServiceContexts.get(environment);
assertThat(ctx.eventLoop().inEventLoop()).isTrue();
// Make sure that a ServiceRequestContext is available
assertThat(ServiceRequestContext.current()).isSameAs(ctx);
throw GraphqlErrorException.newErrorException().message(value).build();
};
@RegisterExtension
static ServerExtension blockingServer = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) throws Exception {

final GraphqlService service =
GraphqlService.builder()
.graphql(newGraphQL())
.useBlockingTaskExecutor(true)
.errorHandler(errorHandler)
.build();
sb.service("/graphql", service);
}
};

@BeforeEach
void setUp() {
shouldFailRequests.set(false);
}

@Test
void handledError() {
@ValueSource(booleans = { true, false })
@ParameterizedTest
void handledError(boolean blocking) {
final HttpRequest request = HttpRequest.builder().post("/graphql")
.content(MediaType.GRAPHQL, "{foo}")
.build();
final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server;
final AggregatedHttpResponse response = server.blockingWebClient().execute(request);
assertThat(response.status()).isEqualTo(HttpStatus.BAD_REQUEST);
}

@Test
void unhandledError() {
@ValueSource(booleans = { true, false })
@ParameterizedTest
void unhandledGraphqlError(boolean blocking) {
final HttpRequest request = HttpRequest.builder().post("/graphql")
.content(MediaType.GRAPHQL, "{error}")
.build();
final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server;
final AggregatedHttpResponse response = server.blockingWebClient().execute(request);
assertThat(response.status()).isEqualTo(HttpStatus.OK);
}

@ValueSource(booleans = { true, false })
@ParameterizedTest
void unhandledException(boolean blocking) {
shouldFailRequests.set(true);
final HttpRequest request = HttpRequest.builder().post("/graphql")
.content(MediaType.GRAPHQL, "{error}")
.build();
final ServerExtension server = blocking ? blockingServer : GraphqlErrorHandlerTest.server;
final AggregatedHttpResponse response = server.blockingWebClient().execute(request);
assertThat(response.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR);
assertThat(response.contentUtf8()).isEqualTo("external exception");
}
}
Loading