Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lizexu123 committed Jan 22, 2025
1 parent 117071a commit 9320669
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions python/paddle/tensorrt/impls/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
get_shape_tensor_element,
get_trt_plugin,
trt_concat,
trt_gather,
trt_div,
trt_prod,
trt_shape,
trt_sub,
Expand Down Expand Up @@ -490,3 +492,78 @@ def affine_channel_converter(network, paddle_op, inputs):
out_tensor = shuffle_layer2.get_output(0)

return out_tensor


@converter_registry.register("pd_op.shuffle_channel", trt_version="8.x")
def shuffle_channel_converter(network, paddle_op, inputs):
    """Convert pd_op.shuffle_channel to TensorRT layers.

    Splits the channel axis into (group, C/group), swaps those two axes,
    then restores the original shape. Assumes a 4-D input laid out as
    [N, C, H, W] (dims 2 and 3 are gathered as the spatial tail).
    """
    x = inputs[0]
    group = paddle_op.attrs().get("group")

    # Dynamic shape of the input, queried at build time as a shape tensor.
    in_shape = trt_shape(network, x)
    batch = get_shape_tensor_element(network, in_shape, 0)
    channels = get_shape_tensor_element(network, in_shape, 1)
    group_const = add_1D_constant_layer(network, group)
    # Channels per group after the split: C / group.
    channels_per_group = trt_div(network, channels, group_const)
    # Spatial dims [H, W] gathered from the input shape.
    spatial = trt_gather(network, in_shape, [2, 3])

    # Target shape for the first reshape: [N, group, C/group, H, W].
    split_shape = trt_concat(
        network, [batch, group_const, channels_per_group, spatial]
    )
    shuffle = network.add_shuffle(x)
    shuffle.set_input(1, split_shape)
    # Swap the group axis with the channels-per-group axis.
    shuffle.second_transpose = trt.Permutation([0, 2, 1, 3, 4])

    # Collapse back to the original input shape.
    restore = network.add_shuffle(shuffle.get_output(0))
    restore.set_input(1, in_shape)
    return restore.get_output(0)

@converter_registry.register("pd_op.full_batch_size_like", trt_version="8.x")
def full_batch_size_like_converter(network, paddle_op, inputs):
    """Convert pd_op.full_batch_size_like to TensorRT layers.

    Builds a tensor filled with the constant ``value`` whose shape equals
    the ``shape`` attribute, except that dimension ``output_dim_idx`` is
    taken from dimension ``input_dim_idx`` of the runtime input shape.
    """
    x = inputs[0]
    attrs = paddle_op.attrs()
    input_dim_idx = attrs.get("input_dim_idx")
    output_dim_idx = attrs.get("output_dim_idx")
    fill_value = float(attrs.get("value"))
    shape = attrs.get("shape")

    # Runtime dim pulled from the input's shape at `input_dim_idx`.
    in_shape = trt_shape(network, x)
    runtime_dim = get_shape_tensor_element(network, in_shape, input_dim_idx)

    static_shape = add_1D_constant_layer(network, shape)

    # Gather indices into the concatenation [*shape, runtime_dim]:
    # position `output_dim_idx` points at the appended runtime element
    # (index len(shape)); every other position keeps its static value.
    gather_indices = [
        len(shape) if i == output_dim_idx else i for i in range(len(shape))
    ]
    combined = trt_concat(network, [static_shape, runtime_dim])
    out_shape = trt_gather(network, combined, gather_indices)

    # A LINSPACE fill whose per-axis deltas are all zero produces a
    # constant-valued tensor of the requested (dynamic) shape.
    fill = network.add_fill(shape=(), op=trt.FillOperation.LINSPACE)
    start = add_1D_constant_layer(network, [fill_value], is_scalar=True)
    deltas = add_1D_constant_layer(
        network, [0.0] * len(shape), is_scalar=False
    )
    fill.set_input(0, out_shape)
    fill.set_input(1, start)
    fill.set_input(2, deltas)
    return fill.get_output(0)

0 comments on commit 9320669

Please sign in to comment.