Description & Motivation

Hi,

A while back I did some work with Jax, but I didn't want to miss the nice logging and boilerplate-reduction features of Lightning, which is built around PyTorch.
After some thinking and tinkering I settled on a solution that required three steps:

1. Set `self.automatic_optimization=False` to do our own Jax-based optimization.
2. Load data in numpy mode to make it accessible to Jax.
3. With `self.automatic_optimization=False`, execute the forward, backward and gradient optimization steps myself.

A repo with examples can be found here: https://github.com/ludwigwinkler/JaxLightning (the name is somewhat unoriginal).

With those three changes I was able to run and train Jax models (the Equinox flavor) within Lightning, with the logging, code structure and general data and model flow of a Lightning module. A rough sketch of the setup is shown below.
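To make the three steps concrete, here is a minimal, hypothetical sketch (not the exact code from the linked repo) of a LightningModule that trains an Equinox model with optax; the model, shapes and hyperparameters are made up for illustration:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import lightning as L
import optax


class JaxLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Step 1: disable Lightning's PyTorch-based optimization loop.
        self.automatic_optimization = False
        key = jax.random.PRNGKey(0)
        # Hypothetical model/optimizer choices, purely for illustration.
        self.model = eqx.nn.MLP(in_size=28 * 28, out_size=10,
                                width_size=128, depth=2, key=key)
        self.optimizer = optax.adam(1e-3)
        self.opt_state = self.optimizer.init(eqx.filter(self.model, eqx.is_array))

    @staticmethod
    def loss_fn(model, x, y):
        # A pure function of its inputs, so it is compatible with JAX transforms.
        logits = jax.vmap(model)(x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    def training_step(self, batch, batch_idx):
        # Step 2: the dataloader is assumed to yield numpy arrays,
        # which JAX consumes directly.
        x, y = batch
        x = jnp.asarray(x).reshape(len(x), -1)
        y = jnp.asarray(y)
        # Step 3: forward, backward and optimizer update done manually with JAX.
        loss, grads = eqx.filter_value_and_grad(self.loss_fn)(self.model, x, y)
        updates, self.opt_state = self.optimizer.update(grads, self.opt_state)
        self.model = eqx.apply_updates(self.model, updates)
        self.log("train_loss", float(loss))

    def configure_optimizers(self):
        # The JAX/optax optimizer is handled above, so there is nothing to return.
        return None
```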
Pitch
@lantiga was kind enough to approach me about potentially integrating this idea into Lightning itself, and I'd be more than happy to contribute the idea and implementation to Lightning.
Since the way I got Jax to run within Lightning is quite lightweight I'm hopeful that an integration of Jax could be possible.
Fortunately, from my rudimentary non-large-scale experience with Jax, it takes care of a lot of placement and parallelization under the hood.
One thing to keep in mind is that the Jax ecosystem is less coalesced: there are several neural network and optimization frameworks built on top of Jax, which itself only provides a toolset. This is in contrast to PyTorch, which already includes the batteries with `torch.optim`. In that sense Lightning could either support a predefined list of packages, or outsource a bit more of the basic implementation to the user and provide higher-level functions such as gradient accumulation (see the sketch below).
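As an aside, the JAX toolset already offers building blocks for such higher-level features. Here is a made-up sketch of gradient accumulation using `optax.MultiSteps` (toy model, placeholder shapes, purely illustrative):

```python
import jax
import jax.numpy as jnp
import optax

# Toy linear model with placeholder shapes.
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}

def loss_fn(p, x, y):
    return jnp.mean((x @ p["w"] + p["b"] - y) ** 2)

# MultiSteps wraps an optimizer so the update is only applied every k calls,
# accumulating gradients in between, i.e. gradient accumulation.
optimizer = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
opt_state = optimizer.init(params)

for step in range(8):
    x, y = jnp.ones((2, 3)), jnp.zeros((2, 1))  # stand-in batch
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
```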
The (incomplete) list of things I can think of, which somebody with more knowledge of Lightning could maybe answer, is:
- Could Jax be detected automatically, or would there be an extra flag `framework=[ jax | pytorch ]`?
- Would PyTorch and Jax be allowed to exist in the same module side by side? That could become messy.
- What is the depth of support for individual Jax packages (Equinox vs Flax vs Haiku)?
- Checkpointing
- (probably so much more that I haven't thought of)
Happy to discuss the motivation, interest, feasibility and possible implementation ideas.
- Functions like `pmap` come natively with JAX, supporting multi-GPU placement, training and aggregation out of the box.
- Whereas in PyTorch one only has to pass around the (scalar) loss term (with the entire computational graph attached to it in the background), JAX defines `grad`. This implies that we would/could pass around functions instead of tensors.
- JAX uses functions to compute gradients w.r.t. the inputs (instead of tracing the graph à la PyTorch). Computing both the loss and the gradients can be as easy as calling `value_and_grad` (see the sketch below).
- By design JAX wants pure functions, so the `self` attribute of class methods has to be circumvented if we want to use `jax.jit` and `jax.grad`.
- Due to the non-homogeneity of the JAX NN ecosystem, there might be some trade-off in what can be supported. Do we let users write a bit more code, which allows us to support a wider range of JAX packages by not imposing constraints, or do we only support e.g. Equinox and risk betting everything on one horse?
- I'd dare to say most JAX users are a bit more used to implementing fancy operations themselves (due to cool features, but also some more implementation work vis-à-vis PyTorch).
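To illustrate the pure-function and `value_and_grad` points above, a tiny made-up sketch (parameters and shapes are placeholders): parameters are passed in explicitly instead of being read from `self`, and the loss plus its gradients come from a single call:

```python
import jax
import jax.numpy as jnp

# A pure function of its inputs: parameters are arguments, not attributes of
# `self`, so jax.jit / jax.grad / jax.value_and_grad can transform it freely.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.zeros((8, 1))

# Loss and gradients w.r.t. the first argument, computed in one call.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
```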
Alternatives
No response
Additional context
No response
cc @Borda @tchaton @justusschock @awaelchli