Skip to content

Commit

Permalink
fix (jax backend)(general.py): adding support for Flax variables to t…
Browse files Browse the repository at this point in the history
…he function.
  • Loading branch information
YushaArif99 committed Sep 21, 2024
1 parent 83f133a commit eeda5fc
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def current_backend_str() -> str:
def is_native_array(x, /, *, exclusive=False):
if exclusive:
return isinstance(x, NativeArray)
elif any(cls in str(x.__class__) for cls in ['flax.nnx.nnx.variables.Param', 'flax.core.scope.Variable']):
# ensure flax Variables(linen, nnx) classify as a native array if `exclusive` is False
return True
return isinstance(
x,
(
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def is_native_array(
x
The input to check
exclusive
Whether to check if the data type is exclusively an array, rather than a
Whether to check if the input x is exclusively an array, rather than a
variable or traced array.
Returns
Expand Down

0 comments on commit eeda5fc

Please sign in to comment.