Skip to content

Commit

Permalink
Add tests for externalization
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakdsci committed Nov 12, 2024
1 parent 13b5638 commit 1da3d7a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions compiler/bindings/python/test/tools/import_onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,64 @@ def run_tool(*argv: str):


ONNX_FILE_PATH = os.path.join(os.path.dirname(__file__), "testdata", "LeakyReLU.onnx")
LARGE_WEIGHTS_ONNX_FILE_PATH = os.path.join(
os.path.dirname(__file__), "testdata", "conv.onnx"
)


class ImportOnnxwithExternalizationTest(unittest.TestCase):
def setUp(self):
with tempfile.NamedTemporaryFile(delete=False) as f:
self.outputPath = f.name

def tearDown(self) -> None:
if os.path.exists(self.outputPath):
os.unlink(self.outputPath)
if os.path.exists("custom_params_file.irpa"):
os.unlink("custom_params_file.irpa")
if os.path.exists(str(self.outputPath) + "_params.irpa"):
os.unlink(str(self.outputPath) + "_params.irpa")

def testExternalizeWeightsDefaultThreshold(self):
run_tool(
LARGE_WEIGHTS_ONNX_FILE_PATH, "--externalize-params", "-o", self.outputPath
)
with open(self.outputPath, "rt") as f:
contents = f.read()
self.assertIn("util.global", contents)
self.assertIn("util.global.load", contents)
assert os.path.isfile(str(self.outputPath) + "_params.irpa")

def testExternalizeParamsSaveCustomPath(self):
run_tool(
LARGE_WEIGHTS_ONNX_FILE_PATH,
"--externalize-params",
"--save-params-to",
"custom_params_file.irpa",
"-o",
self.outputPath,
)
with open(self.outputPath, "rt") as f:
contents = f.read()
self.assertIn("util.global", contents)
self.assertIn("util.global.load", contents)
assert os.path.isfile("custom_params_file.irpa")

def testExternalizeTooHighThreshold(self):
num_elements_weights = 1 * 256 * 100 * 100 + 1
run_tool(
LARGE_WEIGHTS_ONNX_FILE_PATH,
"--externalize-params",
"--num-elements-threshold",
str(num_elements_weights),
"-o",
self.outputPath,
)
with open(self.outputPath, "rt") as f:
contents = f.read()
self.assertNotIn("util.global", contents)
self.assertNotIn("util.global.load", contents)
self.assertIn("onnx.Constant", contents)


class ImportOnnxTest(unittest.TestCase):
Expand Down
Binary file not shown.

0 comments on commit 1da3d7a

Please sign in to comment.