
Is it possible to use tf.data with tf operations while utilizing jax or torch as the backend? #20722

Closed
innat opened this issue Jan 3, 2025 · 1 comment

innat commented Jan 3, 2025

Apart from using tensorflow as the backend, what is the proper approach to using basic TF operations (e.g. tf.concat) inside tf.data pipelines? The following code works with the tensorflow backend, but fails with torch or jax.

import os
os.environ["KERAS_BACKEND"] = "jax" # tensorflow, torch, jax

import numpy as np
import keras
from keras import layers
import tensorflow as tf

aug_model = keras.Sequential([
    keras.Input(shape=(224, 224, 3)),
    layers.RandomFlip("horizontal_and_vertical")
])

def augment_data_tf(x, y):
    combined = tf.concat([x, y], axis=-1)
    z = aug_model(combined)
    x = z[..., :3]
    y = z[..., 3:]
    return x, y

a = np.ones((4, 224, 224, 3)).astype(np.float32)
b = np.ones((4, 224, 224, 2)).astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(
    augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-7-2d25b0c0bbad> in <cell line: 3>()
      1 dataset = tf.data.Dataset.from_tensor_slices((a, b))
      2 dataset = dataset.batch(3, drop_remainder=True)
----> 3 dataset = dataset.map(
      4     augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
      5 )

25 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _convert_to_array_if_dtype_fails(x)
   4102     dtypes.dtype(x)
   4103   except TypeError:
-> 4104     return np.asarray(x)
   4105   else:
   4106     return x

NotImplementedError: in user code:

    File "<ipython-input-5-ca4b074b58a5>", line 6, in augment_data_tf  *
        z = aug_model(combined)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.10/dist-packages/optree/ops.py", line 752, in tree_map
        return treespec.unflatten(map(func, *flat_args))
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4252, in asarray
        return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4058, in array
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4058, in <listcomp>
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4104, in _convert_to_array_if_dtype_fails
        return np.asarray(x)

    NotImplementedError: Cannot convert a symbolic tf.Tensor (concat:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
innat changed the title from "NotImplementedError: Cannot convert a symbolic tf.Tensor (concat:0) to a numpy array." to "Is it possible to use tf.data with tf operations while utilizing jax or torch as the backend?" on Jan 4, 2025
fchollet (Collaborator) commented Jan 4, 2025

The part of your pipeline that is fundamentally backend-dependent is Sequential, because it will seek to convert its inputs to backend-native tensors.

This is generally true of all Keras layers. However, some layers are special-cased to be compatible with tf.data regardless of your backend. This is true of all augmentation and preprocessing layers, including RandomFlip.

So here are two options:

  1. Use a list of tf.data-compatible layers, applied in a loop.
aug_layers = [
    layers.RandomFlip("horizontal_and_vertical"),
]

def augment_data_tf(x, y):
    z = tf.concat([x, y], axis=-1)
    for layer in aug_layers:
        z = layer(z)
    x = z[..., :3]
    y = z[..., 3:]
    return x, y

a = np.ones((4, 224, 224, 3)).astype(np.float32)
b = np.ones((4, 224, 224, 2)).astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(
    augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
)
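
For reference, a quick sanity check of the mapped dataset (a minimal sketch, assuming the snippet above has already run):

for x_batch, y_batch in dataset.take(1):
    # batch(3, drop_remainder=True) over 4 samples yields one batch of 3
    print(x_batch.shape)  # (3, 224, 224, 3)
    print(y_batch.shape)  # (3, 224, 224, 2)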
  2. Use the Pipeline class, which is designed for exactly this.
pipeline = keras.layers.Pipeline([
    layers.RandomFlip("horizontal_and_vertical"),
])

def augment_data_tf(x, y):
    z = tf.concat([x, y], axis=-1)
    z = pipeline(z)
    x = z[..., :3]
    y = z[..., 3:]
    return x, y

a = np.ones((4, 224, 224, 3)).astype(np.float32)
b = np.ones((4, 224, 224, 2)).astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(
    augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
)
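
Either way, the resulting tf.data.Dataset can then be passed straight to model.fit(), even with KERAS_BACKEND="jax": Keras 3 converts the tf.data batches to backend-native tensors during training. A minimal sketch (the model below is hypothetical, chosen only so the output shape matches y):

model = keras.Sequential([
    keras.Input(shape=(224, 224, 3)),
    layers.Conv2D(2, 3, padding="same"),  # output (224, 224, 2), matching y
])
model.compile(optimizer="adam", loss="mse")
model.fit(dataset, epochs=1)  # augmented tf.data batches, trained on the jax backend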

fchollet closed this as completed on Jan 4, 2025