-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlist_plugins.py
72 lines (55 loc) · 2.63 KB
/
list_plugins.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python3
# lix19937
import ctypes
import logging
import tensorrt as trt
# Root logger configuration for this script: DEBUG and above, timestamped lines.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# TensorRT's own logger (INFO severity and above); handed to
# trt.init_libnvinfer_plugins when the script registers plugins.
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
# https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/infer/Plugin/IPluginCreator.html
def get_all_plugin_details(plugin_registry, *, key=None):
    """Collect metadata for every plugin creator in a TensorRT plugin registry.

    Args:
        plugin_registry: Object exposing ``plugin_creator_list`` (for example
            the result of ``trt.get_plugin_registry()``).
        key: Optional callable mapping a creator to its key in the result.
            Defaults to the creator's ``name``. Pass e.g.
            ``lambda c: (c.name, c.plugin_version, c.plugin_namespace)`` to keep
            creators that share a name but differ in version/namespace.

    Returns:
        dict mapping each creator's key to a dict with the keys
        ``"tensorrt_version"``, ``"plugin_version"``, ``"plugin_namespace"``
        and ``"PluginFields"``. ``"PluginFields"`` is a list of
        ``{"name", "data", "type", "size"}`` dicts (``type`` stringified), and
        is empty when the creator reports no fields.

    Note:
        With the default key, creators sharing a name collide and the one
        listed last in ``plugin_creator_list`` wins.
    """
    details = {}
    for creator in plugin_registry.plugin_creator_list:
        entry_key = creator.name if key is None else key(creator)
        entry = {
            "tensorrt_version": creator.tensorrt_version,
            "plugin_version": creator.plugin_version,
            "plugin_namespace": creator.plugin_namespace,
            "PluginFields": [],
        }
        fields = creator.field_names
        # Skip falsy collections (e.g. None or empty) -- some creators declare no fields.
        if fields:
            entry["PluginFields"] = [
                {
                    "name": field.name,
                    "data": field.data,
                    "type": str(field.type),
                    "size": field.size,
                }
                for field in list(fields)
            ]
        details[entry_key] = entry
    return details
def get_all_plugin_names(plugin_registry):
    """Return the ``name`` of every creator in ``plugin_registry.plugin_creator_list``.

    Names are returned in registry order; duplicates are kept.
    """
    creators = plugin_registry.plugin_creator_list
    names = [creator.name for creator in creators]
    return names
# Example usage:
#   (1) python list_plugins.py
#   (2) python list_plugins.py --plugins CustomIPluginV2/CustomPlugin.so
#   (3) python list_plugins.py --plugins CustomIPluginV2/CustomPlugin.so /mnt/TensorRT/build/out/libnvinfer_plugin.so
#   (4) python list_plugins.py --details   # also dump version/namespace/field details
def main(argv=None):
    """List registered TensorRT plugins, optionally loading extra plugin libraries first.

    Args:
        argv: Argument list to parse; ``None`` means argparse uses ``sys.argv[1:]``.
    """
    import argparse
    from pprint import pprint

    parser = argparse.ArgumentParser(description="Script to list registered TensorRT plugins. Can optionally load custom plugin libraries.")
    parser.add_argument("-p", "--plugins", nargs="*", default=[], help="Path to a plugin (.so) library file. Accepts multiple arguments.")
    parser.add_argument("-d", "--details", action="store_true",
                        help="Also print version, namespace and field details for each registered plugin.")
    args = parser.parse_args(argv)

    for plugin_library in args.plugins:
        # Example default plugin library: "/usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so"
        logger.info("Loading plugin library: %s", plugin_library)
        try:
            # RTLD_GLOBAL makes the library's symbols visible to libraries loaded later.
            ctypes.CDLL(plugin_library, mode=ctypes.RTLD_GLOBAL)
        except OSError as err:
            # Missing/unloadable .so: report cleanly (exit status 2) instead of a traceback.
            parser.error("failed to load plugin library {!r}: {}".format(plugin_library, err))

    logger.info("Registering plugins...")
    # Initialize TensorRT's bundled plugins in the default ("") namespace. Custom
    # libraries loaded above are expected to register their own creators when
    # loaded (e.g. via REGISTER_TENSORRT_PLUGIN) -- confirm for each library.
    trt.init_libnvinfer_plugins(TRT_LOGGER, "")

    # Get plugin registry to view the registered plugins
    plugin_registry = trt.get_plugin_registry()

    if args.details:
        logger.info("Registered Plugin Details:")
        pprint(get_all_plugin_details(plugin_registry))

    logger.info("Registered Plugin Names:")
    pprint(get_all_plugin_names(plugin_registry))


if __name__ == "__main__":
    main()