Does Keras 3 work with tf.distribute.MultiWorkerMirroredStrategy? #20585

Open
justinvyu opened this issue Dec 4, 2024 · 6 comments
Labels
stat:awaiting keras-eng Awaiting response from Keras engineer type:Bug

Comments

justinvyu commented Dec 4, 2024

This user guide shows Keras 3 usage with tf.distribute.MirroredStrategy, which is only useful for single-node multi-GPU data parallel training.

However, tf.distribute.MultiWorkerMirroredStrategy is required for multi-node data parallel training.

This example from TensorFlow's docs does not work with keras-nightly-3.7.0.dev2024120303 and tf-nightly-2.19.0.dev20241203, and the problem has existed for several TensorFlow releases (since tensorflow==2.16, when the bundled Keras was bumped to Keras 3; see tensorflow/tensorflow#72388).

Question: Does Keras 3 support tf.distribute.MultiWorkerMirroredStrategy for TF distributed training, or does it only support tf.distribute.MirroredStrategy?

Reproduction

Try running this example as a colab notebook: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
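The tutorial's main.py reduces to roughly the sketch below. To be clear, this is a condensed stand-in, not the tutorial verbatim: the dataset and model here replace the tutorial's mnist_dataset() and build_and_compile_cnn_model(), and a single-worker TF_CONFIG is set so the script can be launched as one process.

```python
import json
import os

# Single-worker cluster spec so this runs as one process. The tutorial
# launches several copies of the script with different "index" values.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["localhost:20587"]},
    "task": {"type": "worker", "index": 0},
})

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Stand-in for the tutorial's mnist_dataset(): all-zero MNIST-shaped data.
def mnist_dataset(batch_size):
    x = np.zeros((256, 28, 28), dtype=np.float32)
    y = np.zeros((256,), dtype=np.int64)
    return tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(batch_size)

# Stand-in for build_and_compile_cnn_model(): a minimal classifier.
with strategy.scope():
    multi_worker_model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])
    multi_worker_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer="adam",
        metrics=["accuracy"],
    )

multi_worker_dataset = mnist_dataset(batch_size=64)

# With Keras 3 as tf.keras, the following call is what raises the
# PerReplica ValueError shown below:
#   multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
```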

You will encounter this error:

%%bash
python main.py

Traceback (most recent call last):
  File "/content/main.py", line 22, in <module>
    multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/constant_op.py", line 108, in convert_to_eager_tensor
    return ops.EagerTensor(value, ctx.device_name, dtype)
ValueError: Attempt to convert a value (PerReplica:{
  0: <tf.Tensor: shape=(64, 28, 28), dtype=float32, numpy=
array([[[0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)>
}) with an unsupported type (<class 'tensorflow.python.distribute.values.PerReplica'>) to a Tensor.
dryglicki (Contributor) commented Dec 4, 2024

Welcome to the party, pal: #20329

I will say, it's good to know that you're not using input dictionaries, which removes one layer of abstraction, and it still doesn't work.

dhantule (Contributor) commented Dec 5, 2024

Hi @justinvyu,

Thanks for reporting this. You can use tf.distribute.MirroredStrategy to implement data parallelism; tf.distribute.MultiWorkerMirroredStrategy does not appear to be supported with Keras 3 yet, as mentioned in this documentation. You can read more about data parallelism here.
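For reference, the single-node pattern that does work with Keras 3 looks roughly like this (a minimal sketch using random stand-in data shaped like the MNIST example above):

```python
import numpy as np
import tensorflow as tf

# MirroredStrategy replicates the model across the GPUs on this one machine;
# with no GPUs present, it falls back to a single device.
strategy = tf.distribute.MirroredStrategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer="adam",
    )

# Random stand-in data; per-replica batches are sliced from the global batch.
x = np.random.rand(128, 28, 28).astype("float32")
y = np.random.randint(0, 10, size=(128,))
history = model.fit(x, y, batch_size=64, epochs=1, verbose=0)
```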

@justinvyu (Author)

Thanks @dhantule. Do you know when the tf.distribute.MultiWorkerMirroredStrategy support is slated on the roadmap? Is there somewhere that I can read about the roadmap?

dhantule (Contributor) commented Dec 6, 2024

Hi @justinvyu, data parallelism and distributed tuning can be combined: you can run multiple trials of training and use data parallelism within each trial to speed it up. For example, if you have 8 workers with 2 GPUs each, you can run 8 parallel trials, with each trial training on 2 GPUs using tf.distribute.MirroredStrategy.

@dhantule dhantule added the keras-team-review-pending Pending review by a Keras team member. label Dec 6, 2024
@richardliaw

Hi @dhantule, we still need to be able to train a single model across multiple nodes, even without distributed tuning.

fchollet (Collaborator) commented Dec 8, 2024

Keras 3 does not work with tf.distribute.MultiWorkerMirroredStrategy.

Whether support will be added is a question for the Keras team at Google to answer. @jeffcarp, do you know?

If your problem is, "I have many machines, each with 1 or more GPUs, and I want to train a single model on all of them", then the answer is Keras 3 + JAX + the keras.distribution.ModelParallel API. It will do a much better job of utilizing your devices. And it works in multi-node settings, to be clear.

@sachinprasadhs sachinprasadhs added stat:awaiting keras-eng Awaiting response from Keras engineer and removed keras-team-review-pending Pending review by a Keras team member. labels Dec 12, 2024

7 participants