
Jax Support within Lightning #20458

Open
ludwigwinkler opened this issue Nov 28, 2024 · 2 comments
Assignees: lantiga
Labels: design (Includes a design discussion), feature (Is an improvement or enhancement)

Comments


ludwigwinkler commented Nov 28, 2024

Description & Motivation

Hi,

a while back I did some work with Jax, but I didn't want to miss the nice logging and boilerplate-reduction features of Lightning, which is built around PyTorch.
After some thinking and tinkering I settled on a solution that required three steps:

  • Set `self.automatic_optimization=False` to do our own Jax-based optimization
  • Load data in numpy mode to make it accessible to Jax
  • With automatic optimization disabled, execute the forward, backward and gradient optimization steps myself

A repo with examples can be found here: https://github.com/ludwigwinkler/JaxLightning (the name is somewhat unoriginal).

With those three changes I was able to run and train Jax models (the Equinox flavor) within Lightning, keeping the logging, code structure and general data and model flow of a Lightning module.
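
Concretely, a minimal sketch of those three steps might look like the following (assuming Equinox plus Optax; the class name `JaxModule`, the MSE loss and the Adam optimizer are illustrative choices, not the repo's exact code or a proposed API):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import lightning as L
import optax


class JaxModule(L.LightningModule):
    def __init__(self, model: eqx.Module, lr: float = 1e-3):
        super().__init__()
        # Step 1: take the optimization loop away from Lightning.
        self.automatic_optimization = False
        self.model = model
        self.optim = optax.adam(lr)
        self.opt_state = self.optim.init(eqx.filter(model, eqx.is_array))

    @staticmethod
    @eqx.filter_value_and_grad
    def loss_fn(model, x, y):
        pred = jax.vmap(model)(x)
        return jnp.mean((pred - y) ** 2)

    def training_step(self, batch, batch_idx):
        # Step 2: the dataloader yields numpy arrays, which JAX consumes directly.
        x, y = batch
        # Step 3: forward, backward and optimizer update done by hand.
        loss, grads = self.loss_fn(self.model, x, y)
        updates, self.opt_state = self.optim.update(grads, self.opt_state)
        self.model = eqx.apply_updates(self.model, updates)
        self.log("train_loss", float(loss))

    def configure_optimizers(self):
        return None  # no torch optimizer; updates happen in training_step
```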

Pitch

@lantiga was kind enough to approach me about potentially integrating this idea into Lightning itself, and I'd be more than happy to contribute the idea and implementation to Lightning.

Since the way I got Jax to run within Lightning is quite lightweight, I'm hopeful that an integration of Jax could be possible.
Fortunately, from my rudimentary, non-large-scale experience with Jax, it takes care of a lot of device placement and parallelization under the hood.

One thing to keep in mind is that the Jax ecosystem is less coalesced: there are several neural network and optimization frameworks built on top of Jax, which itself only provides a tool set. This is in contrast to PyTorch, which already includes the batteries with torch.optim. In that sense Lightning could either support a predefined list of packages, or leave a bit more of the basic implementation to the user and provide higher-level functions such as gradient accumulation.
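
To make the contrast concrete, here is what an optimizer step looks like with Optax (one of several optimizer libraries for Jax, chosen here purely as an example): the update is an explicit, pure function over pytrees, rather than an in-place `optimizer.step()` as in torch.optim.

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}       # any pytree of arrays
grads = {"w": jnp.full((3,), 0.1), "b": jnp.array(0.1)}  # gradients, same structure

optimizer = optax.sgd(learning_rate=1e-2)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)            # returns new params, no mutation
```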

An (incomplete) list of things I can think of that somebody with more knowledge of Lightning could perhaps answer:

  • Could Jax be detected automatically, or would there be an extra flag `framework=[jax|pytorch]`?
  • Would PyTorch and Jax be allowed to exist in the same module side by side? That could become messy.
  • What is the depth of support for individual jax packages (Equinox vs Flax vs Haiku)?
  • Checkpointing
  • (probably so much more that I haven't thought of)

Happy to discuss the motivation, interest, feasibility and possible implementation ideas.

Alternatives

No response

Additional context

No response

cc @Borda @tchaton @justusschock @awaelchli

@ludwigwinkler added the feature (Is an improvement or enhancement) and needs triage (Waiting to be triaged by maintainers) labels on Nov 28, 2024
@lantiga removed the needs triage (Waiting to be triaged by maintainers) label on Dec 4, 2024
@lantiga self-assigned this on Dec 4, 2024
@lantiga added the design (Includes a design discussion) label on Dec 5, 2024
lantiga (Collaborator) commented Dec 11, 2024

Thank you @ludwigwinkler, we're currently working on cutting 2.5; let's catch up on this one after the release in a few days.

ludwigwinkler (Author) commented

Just a couple of observations/notes:

  • Functions like pmap come natively with JAX, supporting multi-GPU placement, training and aggregation out of the box
  • Whereas in PyTorch one only has to pass around the (scalar) loss tensor (with the entire computational graph attached to it in the background), JAX defines grad as a function transformation. This implies that we would/could pass around functions instead of tensors.
  • JAX uses functions to compute gradients w.r.t. the inputs (instead of tracing the graph à la PyTorch). Computing both the loss and the gradients can be as easy as calling value_and_grad (see the sketch below).
  • By design JAX wants pure functions, so the `self` argument of class methods has to be circumvented if we want to use jax.jit and jax.grad.
  • Due to the non-homogeneity of the JAX NN ecosystem, there might be some trade-off in what can be supported. Do we let users write a bit more code, which allows us to support a wider range of JAX packages by not imposing constraints, or do we only support e.g. Equinox and risk betting everything on one horse?
  • I'd dare to say most JAX users are a bit more used to implementing fancy operations themselves (due to cool features, but also some more implementation work vis-à-vis PyTorch).
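
As a minimal sketch of the grad/value_and_grad and purity points above (plain jax.numpy, no NN library; all names here are illustrative):

```python
import jax
import jax.numpy as jnp


# The loss is a pure function of explicit inputs (params is a pytree passed in,
# not an attribute on `self`), so jax.jit and jax.value_and_grad apply directly.
@jax.jit
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (3, 1)), "b": jnp.zeros((1,))}
x = jax.random.normal(key, (8, 3))
y = jnp.ones((8, 1))

# Loss and gradients in one call; gradients are taken w.r.t. the first argument.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

# For multi-device runs, the same pure function could be replicated with jax.pmap
# (inputs would then need a leading device axis), e.g.:
# per_device_loss = jax.pmap(loss_fn)(replicated_params, sharded_x, sharded_y)
```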
