Skip to content

Commit

Permalink
pd_op.shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
lizexu123 committed Jan 22, 2025
1 parent c9374a9 commit 2ac9d61
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 0 deletions.
38 changes: 38 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ DEFINE_GENERAL_PATTERN(Sign, paddle::dialect::SignOp)
DEFINE_GENERAL_PATTERN(Round, paddle::dialect::RoundOp)
DEFINE_GENERAL_PATTERN(Numel, paddle::dialect::NumelOp)
DEFINE_GENERAL_PATTERN(Pool3d, paddle::dialect::Pool3dOp)
DEFINE_GENERAL_PATTERN(ShuffleChannel, paddle::dialect::ShuffleChannelOp)

#undef DEFINE_GENERAL_PATTERN

Expand Down Expand Up @@ -2407,6 +2408,41 @@ class YoloBoxOpPattern
}
};

class FullBatchSizeLikeOpPattern
: public pir::OpRewritePattern<paddle::dialect::FullBatchSizeLikeOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::FullBatchSizeLikeOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::FullBatchSizeLikeOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
if (!op->HasAttribute("input_dim_idx")) {
VLOG(3) << "pd_op.full_batch_size_like must has input_dim_idx attribute";
return false;
}
if (!op->HasAttribute("output_dim_idx")) {
VLOG(3) << "pd_op.full_batch_size_like must has output_dim_idx attribute";
return false;
}
if (!op->HasAttribute("shape")) {
VLOG(3) << "pd_op.full_batch_size_like must has shape attribute";
return false;
}
pir::Value input = op.operand_source(0);
auto input_type = pir::GetDataTypeFromValue(input);
if (!input_type.isa<pir::Float32Type>()) {
VLOG(3) << "pd_op.full_batch_size_like only support float32.";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
Expand Down Expand Up @@ -2485,6 +2521,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ADD_PATTERN(Round)
ADD_PATTERN(Numel)
ADD_PATTERN(Pool3d)
ADD_PATTERN(ShuffleChannel)
#if IS_TRT_VERSION_GE(8600)
ADD_PATTERN(Layer_norm)
#endif
Expand Down Expand Up @@ -2573,6 +2610,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(
std::make_unique<FusedBiasDropoutResidualLayerNormOpPattern>(context));
ps.Add(std::make_unique<YoloBoxOpPattern>(context));
ps.Add(std::make_unique<FullBatchSizeLikeOpPattern>(context));
return ps;
}
};
Expand Down
78 changes: 78 additions & 0 deletions python/paddle/tensorrt/impls/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
get_shape_tensor_element,
get_trt_plugin,
trt_concat,
trt_div,
trt_gather,
trt_prod,
trt_shape,
trt_sub,
Expand Down Expand Up @@ -490,3 +492,79 @@ def affine_channel_converter(network, paddle_op, inputs):
out_tensor = shuffle_layer2.get_output(0)

return out_tensor


@converter_registry.register("pd_op.shuffle_channel", trt_version="8.x")
def shuffle_channel_converter(network, paddle_op, inputs):
input = inputs[0]
group = paddle_op.attrs().get("group")
input_shape_tensor = trt_shape(network, input)
batch_shape_tensor = get_shape_tensor_element(
network, input_shape_tensor, 0
)
channel_shape_tensor = get_shape_tensor_element(
network, input_shape_tensor, 1
)
group_tensor = add_1D_constant_layer(network, group)
new_channel_shape_tensor = trt_div(
network, channel_shape_tensor, group_tensor
)
shape_dim2 = [2, 3]
shape_dim2_tensor = trt_gather(network, input_shape_tensor, shape_dim2)
itensors = []
itensors.append(batch_shape_tensor)
itensors.append(group_tensor)
itensors.append(new_channel_shape_tensor)
itensors.append(shape_dim2_tensor)
reshape_tensor = trt_concat(network, itensors)
layer = network.add_shuffle(input)
layer.set_input(1, reshape_tensor)
transpose_embed = trt.Permutation([0, 2, 1, 3, 4])
layer.second_transpose = transpose_embed
output = layer.get_output(0)
output_layer = network.add_shuffle(output)
output_layer.set_input(1, input_shape_tensor)
return output_layer.get_output(0)


@converter_registry.register("pd_op.full_batch_size_like", trt_version="8.x")
def full_batch_size_like_converter(network, paddle_op, inputs):
input = inputs[0]
input_dim_idx = paddle_op.attrs().get("input_dim_idx")
output_dim_idx = paddle_op.attrs().get("output_dim_idx")
value = paddle_op.attrs().get("value")
shape = paddle_op.attrs().get("shape")
value = float(value)

input_shape_tensor = trt_shape(network, input)
batch_tensor = get_shape_tensor_element(
network, input_shape_tensor, input_dim_idx
)

shape_attr_tensor = add_1D_constant_layer(network, shape)

gather_output_shape_indices = []
for i in range(len(shape)):
if i == output_dim_idx:
gather_output_shape_indices.append(len(shape))
continue
gather_output_shape_indices.append(i)

concat_inputs = [shape_attr_tensor, batch_tensor]
concat_tensor = trt_concat(network, concat_inputs)
out_shape_tensor = trt_gather(
network, concat_tensor, gather_output_shape_indices
)

layer = network.add_fill(shape=(), op=trt.FillOperation.LINSPACE)

value_tensor = add_1D_constant_layer(network, [value], is_scalar=True)

beta_vec = [0.0] * len(shape)
beta_tensor = add_1D_constant_layer(network, beta_vec, is_scalar=False)

layer.set_input(0, out_shape_tensor)
layer.set_input(1, value_tensor)
layer.set_input(2, beta_tensor)

return layer.get_output(0)
51 changes: 51 additions & 0 deletions test/tensorrt/test_converter_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,5 +656,56 @@ def test_fp16_trt_result(self):
self.check_trt_result(precision_mode="fp16")


def shuffle_channel_wrapper(x, group=1):
return _C_ops.shuffle_channel(x, group)


class TestShuffleChannelTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = shuffle_channel_wrapper
self.api_args = {
"x": np.random.random((10, 16, 4, 4)).astype("float32"),
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [10, 16, 4, 4]}
self.opt_shape = {"x": [10, 16, 4, 4]}
self.max_shape = {"x": [10, 16, 4, 4]}

def test_fp32_trt_result(self):
self.check_trt_result()

def test_fp16_trt_result(self):
self.check_trt_result(precision_mode="fp16")


def full_batch_size_like_wrapper(x, dtype, value, batch_dim):
place = paddle.CPUPlace()
out_shape = [-1, 5, 1]
return _C_ops.full_batch_size_like(
x, out_shape, dtype, value, batch_dim, batch_dim, place
)


class TestFullBatchSizeLikeTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = full_batch_size_like_wrapper
self.api_args = {
"x": np.random.random((2, 3, 4)).astype("float32"),
"dtype": paddle.float32,
"value": 2.0,
"batch_dim": 0,
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 3, 4]}
self.opt_shape = {"x": [3, 3, 4]}
self.max_shape = {"x": [4, 3, 4]}

def test_fp32_trt_result(self):
self.check_trt_result()

def test_fp16_trt_result(self):
self.check_trt_result(precision_mode="fp16")


if __name__ == '__main__':
unittest.main()

0 comments on commit 2ac9d61

Please sign in to comment.