From 96700eb9a6b9f70dbb24958df76e42079e71fff3 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Tue, 31 Dec 2024 18:10:05 +0000 Subject: [PATCH] Fix: pytree warning --- src/brevitas/backport/fx/immutable_collections.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas/backport/fx/immutable_collections.py b/src/brevitas/backport/fx/immutable_collections.py index 0144e4701..35f3e23ad 100644 --- a/src/brevitas/backport/fx/immutable_collections.py +++ b/src/brevitas/backport/fx/immutable_collections.py @@ -43,7 +43,7 @@ from typing import Any, Dict, List, Tuple -from torch.utils._pytree import _register_pytree_node +from torch.utils._pytree import register_pytree_node from torch.utils._pytree import Context from ._compatibility import compatibility @@ -111,5 +111,5 @@ def _immutable_list_unflatten(values: List[Any], context: Context) -> List[Any]: return immutable_list(values) -_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) -_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) +register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) +register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)