Apart from using tensorflow as the backend, what is the proper approach to using basic operations (e.g. tf.concat) inside tf.data API pipelines? The following code works with the tensorflow backend, but not with torch or jax.
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
[<ipython-input-7-2d25b0c0bbad>](https://localhost:8080/#) in <cell line: 3>()
1 dataset = tf.data.Dataset.from_tensor_slices((a, b))
2 dataset = dataset.batch(3, drop_remainder=True)
----> 3 dataset = dataset.map(
4 augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
5 )
25 frames
[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in _convert_to_array_if_dtype_fails(x)
4102 dtypes.dtype(x)
4103 except TypeError:
-> 4104 return np.asarray(x)
4105 else:
4106 return x
NotImplementedError: in user code:
File "<ipython-input-5-ca4b074b58a5>", line 6, in augment_data_tf *
z = aug_model(combined)
File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler **
raise e.with_traceback(filtered_tb) from None
File "/usr/local/lib/python3.10/dist-packages/optree/ops.py", line 752, in tree_map
return treespec.unflatten(map(func, *flat_args))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4252, in asarray
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4058, in array
leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4058, in <listcomp>
leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4104, in _convert_to_array_if_dtype_fails
return np.asarray(x)
NotImplementedError: Cannot convert a symbolic tf.Tensor (concat:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
innat changed the title from "NotImplementedError: Cannot convert a symbolic tf.Tensor (concat:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported." to "Is it possible to use tf.data with tf operations while utilizing jax or torch as the backend?" on Jan 4, 2025
The part of your pipeline that is fundamentally backend-dependent is Sequential, because it will seek to convert its inputs to backend-native tensors. That conversion is what fails in the traceback above: inside dataset.map the inputs are symbolic tf.Tensors, which cannot be turned into jax or torch arrays.
This is generally true of all Keras layers. However, some layers are special-cased to be compatible with tf.data regardless of your backend. This is true of all augmentation and preprocessing layers, including RandomFlip.
So here are two options:
Use a list of tf.data compatible layers, applied in a loop.
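As a rough sketch of the first option, the tf.concat call can stay as a plain TensorFlow op and the augmentation layers can be applied one by one instead of being wrapped in Sequential. Several things here are assumptions for illustration, not taken from the original code: the shapes of `a` and `b`, the concat axis, the name `aug_layers`, and the particular pair of RandomFlip layers. The backend is normally selected through the KERAS_BACKEND environment variable before importing keras; this sketch leaves it untouched.

```python
import numpy as np
import tensorflow as tf
import keras

# Hypothetical stand-ins for the issue's `a` and `b` (shapes are assumed).
a = np.random.rand(6, 8, 8, 3).astype("float32")
b = np.random.rand(6, 8, 8, 3).astype("float32")

# A plain Python list of tf.data-compatible preprocessing layers,
# used instead of keras.Sequential (which converts inputs to backend tensors).
aug_layers = [
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomFlip("vertical"),
]


def augment_data_tf(x, y):
    # tf.* ops are fine here: tf.data traces this function as a TF graph.
    combined = tf.concat([x, y], axis=-1)  # assumed: concat along channels
    # Apply each layer in turn rather than calling a Sequential model.
    for layer in aug_layers:
        combined = layer(combined)
    return combined


dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE)
```

Each element of the resulting dataset is a batch in which the two inputs are stacked along the last axis and then flipped together, so paired inputs receive the same random transform.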