[Inference] Refactor PIR-TRT weight code #70762

Merged · 8 commits · Jan 13, 2025
42 changes: 19 additions & 23 deletions python/paddle/tensorrt/converter.py
@@ -16,8 +16,6 @@
import hashlib
import logging

import numpy as np

import paddle

paddle.base.core.register_paddle_plugin()
@@ -27,7 +25,6 @@
from paddle import pir
from paddle.base.core import clear_shape_info, get_value_shape_range_info
from paddle.base.log_helper import get_logger
from paddle.pir.core import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE

from .impls.activation import * # noqa: F403
from .impls.attribute import * # noqa: F403
@@ -49,6 +46,7 @@
from .register import converter_registry
from .util import (
TensorRTConfigManager,
TensorRTConstantManager,
get_cache_path,
get_trt_version,
get_trt_version_list,
@@ -71,17 +69,14 @@ def __init__(self, paddle_program, scope, trt_config=None):
self.scope = scope
self.program = paddle_program
self.trt_config = trt_config
constant_manager = TensorRTConstantManager()
params = paddle_program.global_block().all_parameters()
param_dict = {}
# save parameters
for v in params:
name = v.get_defining_op().attrs()["parameter_name"]
if self.scope.find_var(name) is None:
weight_array = None
else:
weight_array = np.array(self.scope.var(name).get_tensor())
param_dict.update({name: weight_array})
self.param_dict = param_dict
weight_tensor = self.scope.var(name).get_tensor()
constant_manager.set_constant_value(name, weight_tensor, v)

self.input_info = {}
self.trt_output_value_map = {}
@@ -154,6 +149,7 @@ def convert_subgraph_to_trt(self, program, group_op):
max_value_map = {}
input_names = []
new_input_values = []
constant_manager = TensorRTConstantManager()

# Because one of the inputs to pd_op.concat is builtin.combine,
# during the conversion process using the converter,
@@ -174,26 +170,22 @@
defining_op = value.get_defining_op()
if defining_op.name() == "builtin.parameter":
param_name = defining_op.attrs()["parameter_name"]
weight = trt.Weights(self.param_dict[param_name])
weight = trt.Weights(
constant_manager.get_constant_value(param_name)
)
value_to_trt_tensor[value.id] = weight
elif defining_op.name() == "builtin.constant":
constant_value_name = defining_op.attrs()["value"]
constant_tensor = self.scope.var(
constant_value_name
).get_tensor()
out_dtype = np.dtype(
_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[value.dtype]
constant_manager.set_constant_value(
constant_value_name, constant_tensor, value
)
if out_dtype == np.dtype("float64"):
out_dtype = np.dtype("float32")
if out_dtype == np.dtype("int64"):
out_dtype = np.dtype("int32")
constant_data = np.array(constant_tensor, dtype=out_dtype)
if len(constant_data) == 0:
value_to_trt_tensor[value.id] = None
else:
constant_tensor = trt.Weights(constant_data)
value_to_trt_tensor[value.id] = constant_tensor
constant_tensor = trt.Weights(
constant_manager.get_constant_value(constant_value_name)
)
value_to_trt_tensor[value.id] = constant_tensor
else:
shape = value.shape
dtype = map_dtype(value.dtype.name)
@@ -572,7 +564,11 @@ def convert_program_to_trt(self):
)
for op in self.program.global_block().ops:
if op.name() == "builtin.parameter":
if not save_one_parameter:
parameter_name = op.attrs()["parameter_name"]
if (
not save_one_parameter
and "constant_folding" not in parameter_name
):
save_one_parameter = True
continue
if op.results()[0].use_empty():
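Taken together, the converter.py changes drop the per-converter param_dict of numpy arrays and instead register every weight once with the shared TensorRTConstantManager, which is then queried by name while the TensorRT network is built. A minimal sketch of the new flow, assuming program and scope are the PIR program and scope already passed to PaddleToTensorRTConverter (the helper names register_parameters / lookup_weight are illustrative, not part of the PR):

import tensorrt as trt

from paddle.tensorrt.util import TensorRTConstantManager


def register_parameters(program, scope):
    # Register each parameter tensor once; the manager is a process-wide singleton,
    # so later lookups anywhere in the converter see the same cache.
    constant_manager = TensorRTConstantManager()
    for v in program.global_block().all_parameters():
        name = v.get_defining_op().attrs()["parameter_name"]
        weight_tensor = scope.var(name).get_tensor()
        constant_manager.set_constant_value(name, weight_tensor, v)


def lookup_weight(param_name):
    # Inside convert_subgraph_to_trt the cached numpy array is wrapped into
    # trt.Weights instead of being read from self.param_dict.
    return trt.Weights(TensorRTConstantManager().get_constant_value(param_name))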
9 changes: 6 additions & 3 deletions python/paddle/tensorrt/converter_utils.py
@@ -19,7 +19,7 @@
import numpy as np
import tensorrt as trt

from paddle.tensorrt.util import TensorRTConfigManager
from paddle.tensorrt.util import TensorRTConfigManager, TensorRTConstantManager

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
@@ -171,8 +171,8 @@ def add_elementwise_layer(network, paddle_op, inputs, op_type):
def add_1D_constant_layer(network, data, dtype=np.int32, is_scalar=False):
if not isinstance(data, list):
data = [data]
constant_data = np.array(data, dtype=dtype)
shape = () if is_scalar else (len(data),)
constant_data = np.array(data, dtype=dtype)
constant_layer = network.add_constant(shape, constant_data)
return constant_layer.get_output(0)

@@ -569,8 +569,11 @@ def convert_conv2d(network, paddle_op, inputs):

def get_input_constant_value(paddle_op, inputs, input_index):
input_op = paddle_op.operands()[input_index].source().get_defining_op()
constant_manager = TensorRTConstantManager()
if input_op.name() == "builtin.constant":
return inputs[input_index].numpy().tolist()
return constant_manager.get_constant_value(
input_op.attrs()["value"]
).tolist()
elif input_op.name() == "pd_op.full_int_array":
return input_op.attrs()["value"]
elif input_op.name() == "pd_op.full":
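With this change get_input_constant_value serves builtin.constant inputs from the constant manager's cached numpy array (via .tolist()) rather than from the runtime input tensor. Converters that expect a scalar now unwrap the first element of the returned list, as the creation.py and math.py hunks below show. A short hedged sketch of that caller-side pattern (paddle_op and inputs stand for the operands a converter normally receives):

fill_value = get_input_constant_value(paddle_op, inputs, 1)
if fill_value is not None:
    # The constant arrives as a list, so scalar consumers index element 0
    # before building a constant layer.
    fill_value = fill_value[0]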
6 changes: 2 additions & 4 deletions python/paddle/tensorrt/export.py
@@ -39,6 +39,7 @@
forbid_op_lower_trt,
mark_builtin_op,
run_pir_pass,
run_trt_partition,
warmup_shape_infer,
)

@@ -264,7 +265,6 @@ def convert_to_trt(program, trt_config, scope):
# run pir pass (including trt_op_marker_pass)
program_with_pir = run_pir_pass(
program,
partition_mode=False,
disable_passes=trt_config.disable_passes,
scope=scope,
)
@@ -286,9 +286,7 @@ def convert_to_trt(program, trt_config, scope):
mark_builtin_op(program)

# run pir pass (including trt_sub_graph_extract_pass)
program_with_pir = run_pir_pass(
program, partition_mode=True, scope=scope
)
program_with_pir = run_trt_partition(program)

# Step4: run TRTConverter (would lower group_op into tensorrt_engine_op)
converter = PaddleToTensorRTConverter(
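With the partition_mode flag removed, run_pir_pass only schedules the marker/fusion passes and subgraph extraction moves into the separate run_trt_partition helper. A minimal sketch of the resulting two-step lowering, assuming the shape-warmup and engine-conversion steps elided from this hunk run in between (the wrapper name lower_to_trt is illustrative):

from paddle.tensorrt.util import (
    mark_builtin_op,
    run_pir_pass,
    run_trt_partition,
)


def lower_to_trt(program, scope, disable_passes=[]):
    # Step 1: marker / fusion passes (trt_op_marker_pass, optional constant folding, ...).
    program = run_pir_pass(program, disable_passes=disable_passes, scope=scope)
    # ... warmup_shape_infer and min/opt/max shape collection happen here ...
    mark_builtin_op(program)
    # Step 2: carve the marked ops into group ops via trt_sub_graph_extract_pass.
    return run_trt_partition(program)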
1 change: 1 addition & 0 deletions python/paddle/tensorrt/impls/creation.py
@@ -151,6 +151,7 @@ def full_like_converter(network, paddle_op, inputs):

fill_value = get_input_constant_value(paddle_op, inputs, 1)
if fill_value is not None:
fill_value = fill_value[0]
value = network.add_constant(
(1,),
np.array(
1 change: 1 addition & 0 deletions python/paddle/tensorrt/impls/math.py
@@ -66,6 +66,7 @@ def scale_converter(network, paddle_op, inputs):

scale = get_input_constant_value(paddle_op, inputs, 1)
if scale is not None:
scale = scale[0]
has_scale_tensor = False
if is_int:
scale_tensor = add_1D_constant_layer(
99 changes: 80 additions & 19 deletions python/paddle/tensorrt/util.py
@@ -15,6 +15,8 @@
import logging
import os

import numpy as np

import paddle

try:
@@ -23,6 +25,7 @@
pass
from paddle import pir
from paddle.base.log_helper import get_logger
from paddle.pir.core import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
@@ -49,32 +52,61 @@ def map_dtype(pd_dtype):
raise TypeError(f"Unsupported dtype: {pd_dtype}")


def run_pir_pass(program, partition_mode=False, disable_passes=[], scope=None):
def all_ops_into_trt(program):
for op in program.global_block().ops:
if (
op.name() == "pd_op.fetch"
or op.name() == "pd_op.data"
or op.name().split('.')[0] == "builtin"
):
continue
if op.has_attr("__l_trt__") is False:
return False
if op.attrs()["__l_trt__"] is False:
return False
_logger.info("All ops convert to trt.")
return True


def run_pir_pass(program, disable_passes=[], scope=None):
def _add_pass_(pm, passes, disable_passes):
for pass_item in passes:
for pass_name, pass_attr in pass_item.items():
if pass_name in disable_passes:
continue
pm.add_pass(pass_name, pass_attr)

pm = pir.PassManager(opt_level=4)
pm.enable_print_statistics()
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm, program)
if scope is None:
scope = paddle.static.global_scope()
place = paddle.CUDAPlace(0)

# run marker pass
passes = [
{'trt_op_marker_pass': {}},
{
'constant_folding_pass': {
"__place__": place,
"__param_scope__": scope,
}
},
{'conv2d_add_fuse_pass': {}},
{'trt_op_marker_pass': {}}, # for fusion op
]
if partition_mode:
passes = [{'trt_sub_graph_extract_pass': {}}]

for pass_item in passes:
for pass_name, pass_attr in pass_item.items():
if pass_name in disable_passes:
continue
pm.add_pass(pass_name, pass_attr)
_add_pass_(pm, passes, disable_passes)
pm.run(program)

# run other passes
pm.clear()
passes = []
if all_ops_into_trt(program):
# only run constant_folding_pass when all ops into trt
passes.append(
{
'constant_folding_pass': {
"__place__": place,
"__param_scope__": scope,
}
}
)

passes.append({'conv2d_add_fuse_pass': {}})
passes.append({'trt_op_marker_pass': {}}) # for op that created by pass
_add_pass_(pm, passes, disable_passes)
pm.run(program)

# delete unused op
Expand All @@ -86,6 +118,15 @@ def run_pir_pass(program, partition_mode=False, disable_passes=[], scope=None):
return program


def run_trt_partition(program):
pm = pir.PassManager(opt_level=4)
pm.enable_print_statistics()
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm, program)
pm.add_pass("trt_sub_graph_extract_pass", {})
pm.run(program)
return program


def forbid_op_lower_trt(program, disabled_ops):
if isinstance(disabled_ops, str):
disabled_ops = [disabled_ops]
@@ -198,6 +239,28 @@ def get_force_fp32_ops(self):
return []


class TensorRTConstantManager:
_instance = None

def __new__(cls, trt_config=None):
if not cls._instance:
cls._instance = super().__new__(cls)
cls._instance.constant_dict = {}
return cls._instance

def set_constant_value(self, name, tensor_data, value):
out_dtype = np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[value.dtype])
if out_dtype == np.dtype("float64"):
out_dtype = np.dtype("float32")
if out_dtype == np.dtype("int64"):
out_dtype = np.dtype("int32")
constant_array = np.array(tensor_data, dtype=out_dtype)
self.constant_dict.update({name: constant_array})

def get_constant_value(self, name):
return self.constant_dict[name]


# In TensorRT FP16 inference, this function sets the precision of specific
# operators to FP32, ensuring numerical accuracy for these operations.
def support_fp32_mix_precision(op_type, layer, trt_config=None):
@@ -223,8 +286,6 @@ def weight_to_tensor(network, paddle_value, trt_tensor, use_op_name):
]
if use_op_name in forbid_cast_op:
return trt_tensor
if paddle_value.get_defining_op().name() == "builtin.constant":
return trt_tensor
input_shape = paddle_value.shape
if type(trt_tensor) == trt.Weights:
return network.add_constant(input_shape, trt_tensor).get_output(0)
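The new TensorRTConstantManager is a singleton: every construction returns the same instance, so weights registered in PaddleToTensorRTConverter.__init__ are visible to get_input_constant_value in converter_utils.py without threading a dict through the call chain. set_constant_value also narrows float64/int64 data to float32/int32, replacing the inline dtype handling removed from converter.py. A small illustrative sketch (the registration call is commented out because weight_tensor and pir_value are placeholders for a real scope tensor and its pir.Value):

from paddle.tensorrt.util import TensorRTConstantManager

mgr_a = TensorRTConstantManager()
mgr_b = TensorRTConstantManager()
assert mgr_a is mgr_b  # __new__ caches one instance with a shared constant_dict

# Hypothetical registration: the value's dtype is mapped through
# _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE, then float64 -> float32 and int64 -> int32.
# mgr_a.set_constant_value("fc_weight", weight_tensor, pir_value)
# np_weight = mgr_b.get_constant_value("fc_weight")  # same cache through any instance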
5 changes: 2 additions & 3 deletions test/tensorrt/tensorrt_test_base.py
@@ -28,6 +28,7 @@
from paddle.tensorrt.util import (
mark_builtin_op,
run_pir_pass,
run_trt_partition,
warmup_shape_infer,
)

@@ -264,7 +265,6 @@ def check_trt_result(self, rtol=1e-4, atol=1e-4, precision_mode="fp32"):
# run pir pass(including some constant fold pass, dead code elimination pass, fusion pass and trt_op_marker_pass)
main_program = run_pir_pass(
main_program,
partition_mode=False,
disable_passes=self.disable_passes,
)

@@ -285,7 +285,7 @@ def check_trt_result(self, rtol=1e-4, atol=1e-4, precision_mode="fp32"):
mark_builtin_op(main_program)

# run trt_sub_graph_extract_pass()
program_with_trt = run_pir_pass(main_program, partition_mode=True)
program_with_trt = run_trt_partition(main_program)

# run TRTConverter(would lower group_op into tensorrt_engine_op)
trt_config = None
@@ -340,7 +340,6 @@ def check_marker(self, expected_result):
)
main_program = run_pir_pass(
main_program,
partition_mode=False,
disable_passes=self.disable_passes,
)
marker_result = False