Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support adding gRPC interceptors using annotation '@GrpcInterceptor' #5397

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.server.grpc;

import java.lang.annotation.ElementType;
import java.lang.annotation.Repeatable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction;

import io.grpc.ServerInterceptor;

/**
* Specifies a {@link ServerInterceptor} class which intercepts requests and responses of a gRPC service or its
* methods.
*/
@Repeatable(GrpcInterceptors.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.TYPE, ElementType.METHOD })
public @interface GrpcInterceptor {

/**
* {@link GrpcExceptionHandlerFunction} implementation type.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* {@link GrpcExceptionHandlerFunction} implementation type.
* {@link ServerInterceptor} implementation type.

*/
Class<? extends ServerInterceptor> value();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.server.grpc;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* The containing annotation type for {@link GrpcInterceptor}.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.TYPE, ElementType.METHOD })
public @interface GrpcInterceptors {

/**
* An array of {@link GrpcInterceptor}s.
*/
GrpcInterceptor[] value();
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,33 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import static org.reflections.ReflectionUtils.withModifier;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.curioswitch.common.protobuf.json.MessageMarshaller;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.CaseFormat;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;

import com.linecorp.armeria.common.DependencyInjector;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.SerializationFormat;
Expand All @@ -51,6 +60,8 @@
import com.linecorp.armeria.common.grpc.GrpcStatusFunction;
import com.linecorp.armeria.common.grpc.protocol.AbstractMessageDeframer;
import com.linecorp.armeria.common.grpc.protocol.ArmeriaMessageFramer;
import com.linecorp.armeria.internal.common.ReflectiveDependencyInjector;
import com.linecorp.armeria.internal.server.annotation.AnnotationUtil;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.HttpServiceWithRoutes;
import com.linecorp.armeria.server.Server;
Expand All @@ -68,6 +79,7 @@
import io.grpc.MethodDescriptor;
import io.grpc.ServerInterceptor;
import io.grpc.ServerInterceptors;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
Expand Down Expand Up @@ -938,6 +950,65 @@ private ImmutableList.Builder<ServerInterceptor> interceptors() {
return interceptors;
}

/**
* Get the Method from the MethodDescriptor for the given gRPC service. This method can be used to
* get annotations applied to the method.
*
* @param clazz The class of the service.
* @param methodDescriptor The method descriptor to get the method for.
* @return The method for the given method descriptor.
*/
private Optional<Method> getMethodFromMethodDescriptor(Class<?> clazz,
MethodDescriptor<?, ?> methodDescriptor) {
final String methodName = methodDescriptor.getBareMethodName();

if (methodName == null) {
return Optional.empty();
}

final String matchingMethodName = CaseFormat.UPPER_CAMEL
.converterTo(CaseFormat.LOWER_CAMEL)
.convert(methodName);

if (matchingMethodName == null) {
return Optional.empty();
}

return InternalReflectionUtils.getAllSortedMethods(clazz, withModifier(Modifier.PUBLIC))
.stream()
.filter(m -> matchingMethodName.equals(m.getName()))
.findFirst();
}

/**
* Get the interceptors for the given method created by using annotations.
* @param clazz The class of the service.
* @param method The method to get interceptors for.
* @param dependencyInjector The dependency injector to use.
* @param globalInterceptors The global interceptors to use. This comes from the builder.
* @return The list of interceptors for the given method in order.
*/
private List<ServerInterceptor> getInterceptorsFromAnnotations(Class<?> clazz, Method method,
DependencyInjector dependencyInjector,
List<ServerInterceptor> globalInterceptors) {
final List<ServerInterceptor> methodAndClassInterceptors =
AnnotationUtil.getAnnotatedInstances(method, clazz,
GrpcInterceptor.class,
ServerInterceptor.class,
dependencyInjector).build().reverse();

return Stream.concat(globalInterceptors.stream(), methodAndClassInterceptors.stream())
.collect(Collectors.toList());
}

private String calculateServicePath(Entry entry, MethodDescriptor<?, ?> methodDescriptor) {
// Use the path of method descriptor instead of the path of service. We are adding a
// single method to the registry as a service opposed to adding the entire service
// to the registry. The reason is that we can't intercept individual methods if we
// add the service as a whole to the registry.
return entry.path() + '/' + methodDescriptor.getBareMethodName();
}

/**
* Constructs a new {@link GrpcService} that can be bound to
* {@link ServerBuilder}. It is recommended to bind the service to a server using
Expand All @@ -946,7 +1017,6 @@ private ImmutableList.Builder<ServerInterceptor> interceptors() {
* without interfering with other services.
*/
public GrpcService build() {
final HandlerRegistry handlerRegistry;
if (USE_COROUTINE_CONTEXT_INTERCEPTOR) {
final ServerInterceptor coroutineContextInterceptor =
new ArmeriaCoroutineContextInterceptor(useBlockingTaskExecutor);
Expand Down Expand Up @@ -981,28 +1051,108 @@ public GrpcService build() {
grpcExceptionHandler = exceptionHandler;
}

// We will copy the service and methods inside the old registry to new one and intercept them.
final HandlerRegistry.Builder newRegistryBuilder = new HandlerRegistry.Builder();

if (grpcExceptionHandler != null) {
registryBuilder.setDefaultExceptionHandler(grpcExceptionHandler);
newRegistryBuilder.setDefaultExceptionHandler(grpcExceptionHandler);
}

if (interceptors != null) {
minwoox marked this conversation as resolved.
Show resolved Hide resolved
final HandlerRegistry.Builder newRegistryBuilder = new HandlerRegistry.Builder();
final ImmutableList<ServerInterceptor> interceptors = this.interceptors.build();
for (Entry entry : registryBuilder.entries()) {
final MethodDescriptor<?, ?> methodDescriptor = entry.method();
final ServerServiceDefinition intercepted =
ServerInterceptors.intercept(entry.service(), interceptors);
newRegistryBuilder.addService(entry.path(), intercepted, methodDescriptor, entry.type(),
entry.additionalDecorators());
}
if (grpcExceptionHandler != null) {
newRegistryBuilder.setDefaultExceptionHandler(grpcExceptionHandler);
// Interceptors passed via the grpc service builder.
final ImmutableList<ServerInterceptor> globalInterceptors;

if (this.interceptors == null) {
globalInterceptors = ImmutableList.of();
} else {
globalInterceptors = this.interceptors.build();
}

// Use reflection to parse annotations and get interceptors.
final DependencyInjector dependencyInjector = new ReflectiveDependencyInjector();

// Copy services, method to the new registry builder and intercept them.
for (Entry entry : registryBuilder.entries()) {
final MethodDescriptor<?, ?> methodDescriptor = entry.method();

if (entry.type() != null && methodDescriptor == null) {
// A "Service" entry thus there is no method descriptor.

final List<MethodDescriptor<?, ?>> serverMethodDescriptors =
entry.service().getMethods().stream()
.map(ServerMethodDefinition::getMethodDescriptor)
.collect(Collectors.toList());

final boolean shouldSplitServiceToMethod = serverMethodDescriptors.stream().anyMatch(
methodDescriptor1 -> {
final Optional<Method> methodOption =
getMethodFromMethodDescriptor(entry.type(), methodDescriptor1);

if (methodOption.isPresent()) {
final List<ServerInterceptor> allInterceptors =
getInterceptorsFromAnnotations(entry.type(), methodOption.get(),
dependencyInjector, ImmutableList.of());

return !allInterceptors.isEmpty();
}

return false;
}
);

if (shouldSplitServiceToMethod) {
// Add all methods of the service to the new registry builder one by one and intercept them.
for (MethodDescriptor<?, ?> serverMethodDescriptor : serverMethodDescriptors) {
final Optional<Method> methodOption =
getMethodFromMethodDescriptor(entry.type(), serverMethodDescriptor);

if (methodOption.isPresent()) {
final List<ServerInterceptor> allInterceptors =
getInterceptorsFromAnnotations(entry.type(), methodOption.get(),
dependencyInjector, globalInterceptors);

final ServerServiceDefinition intercepted =
ServerInterceptors.intercept(entry.service(), allInterceptors);

final String path = calculateServicePath(entry, serverMethodDescriptor);
newRegistryBuilder.addService(path, intercepted,
serverMethodDescriptor, entry.type(),
ImmutableList.copyOf(entry.additionalDecorators()));
}
}
} else {
// No need to split service into individual methods if there are no interceptors.
final ServerServiceDefinition intercepted =
ServerInterceptors.intercept(entry.service(), globalInterceptors);
newRegistryBuilder.addService(entry.path(), intercepted, methodDescriptor,
entry.type(), entry.additionalDecorators());
}
} else if (entry.type() != null) {
// A "Method" entry
final Optional<Method> methodOption =
getMethodFromMethodDescriptor(entry.type(), methodDescriptor);

if (methodOption.isPresent()) {
final List<ServerInterceptor> allInterceptors =
getInterceptorsFromAnnotations(entry.type(), methodOption.get(),
dependencyInjector, globalInterceptors);

final ServerServiceDefinition intercepted = ServerInterceptors.intercept(entry.service(),
allInterceptors);
newRegistryBuilder.addService(entry.path(), intercepted, methodDescriptor,
entry.type(), entry.additionalDecorators());
}
} else {
// Others
// Only intercept the service with global interceptors.
final ServerServiceDefinition intercepted = ServerInterceptors.intercept(entry.service(),
globalInterceptors);
newRegistryBuilder.addService(entry.path(), intercepted, methodDescriptor,
entry.type(), entry.additionalDecorators());
}
handlerRegistry = newRegistryBuilder.build();
} else {
handlerRegistry = registryBuilder.build();
}

final HandlerRegistry handlerRegistry = newRegistryBuilder.build();

GrpcService grpcService = new FramedGrpcService(
handlerRegistry,
firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance()),
Expand Down
Loading