[core] Layerwise Upcasting #10347
Conversation
Co-Authored-By: Dhruv Nair <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice start 👍🏽 A few things to consider here:

It is difficult to maintain a large global list of supported ops, and doing so can lead to us either missing modules or not applying upcasting in cases where it could be used.
So any kind of layerwise casting on these modules runs into an error because the parameters remain in a lower memory dtype unless the entire module is upcast. The initial PR got around this by adding the
This implementation seems to do something similar using the global
An enumeration class that defines the granularity of the layerwise upcasting process.

Granularity can be one of the following:
- `DIFFUSERS_MODEL`:
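For context, only `DIFFUSERS_MODEL` survives in the quoted docstring above; judging from the granularity values in the benchmark filenames further down, the enum presumably looks roughly like this (a sketch, not the PR's exact definition; the class name is hypothetical):

```python
from enum import Enum

# Hypothetical reconstruction; member names inferred from the docstring above
# and the "granularity-*" values in the result filenames below.
class LayerwiseUpcastingGranularity(Enum):
    DIFFUSERS_MODEL = "diffusers_model"  # apply casting hooks once for the whole model
    DIFFUSERS_LAYER = "diffusers_layer"  # apply them per diffusers block/layer
    PYTORCH_LAYER = "pytorch_layer"      # apply them per PyTorch primitive (nn.Linear, etc.)
```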
Peak memory usage during inference would be the same with this approach, right? Memory footprint is only really lower when the model is not being used?
I suppose there is a case that you could keep the entire pipeline on GPU and save memory by upcasting individual components, but in practice that would only work well on GPUs with enough memory to store both the upcast model + the storage dtype versions of all the other components.
We can test to see if there's a good use case for this granularity level, perhaps by comparing inference speed and peak memory against CPU offloading (a rough measurement harness is sketched below). Although I think you need more than 24 GB of VRAM to see real benefits.
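A minimal way to run that comparison (hypothetical harness, not part of the PR; `pipe` and its inputs are assumed to be set up elsewhere):

```python
import time
import torch

def benchmark(pipe, **inputs):
    """Measure wall-clock time and peak VRAM for a single pipeline call."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    output = pipe(**inputs)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    return output, elapsed, peak_gb
```

Running this once with layerwise upcasting enabled and once with `enable_model_cpu_offload()` would show whether this granularity buys anything over offloading on a given GPU.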
[...continuation of #9177]
PyTorch has had support for `float8_e4m3fn` and `float8_e5m2` as storage dtypes for a while now. This allows one to store model weights in a lower-precision dtype and upcast them on-the-fly when a layer is required for computation.

Code
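The collapsed benchmark script is not reproduced here; as a minimal, standalone illustration of float8 as a storage dtype with just-in-time upcasting (not the PR's code):

```python
import torch

weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
weight_fp8 = weight.to(torch.float8_e4m3fn)   # storage dtype: 1 byte/element vs 2 for bf16

x = torch.randn(16, 4096, dtype=torch.bfloat16)
# float8 is used for storage only here; upcast right before the matmul
out = x @ weight_fp8.to(torch.bfloat16).t()

print(weight.element_size(), weight_fp8.element_size())  # 2 1
```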
Flux visual results
CogVideoX visual results
cogvideox-1.0---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Assumptions made so far:
- `compute_dtype`
- `storage_dtype`

Why is there no memory savings in the initial load memory?
We are first moving weights to VRAM and then performing the lower-dtype casting. We should maybe look into directly allowing loading of weights in the lower dtype (a rough illustration follows below).
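A sketch of the load-order difference in plain PyTorch (not a diffusers API; assumes the float8 cast kernels are available on CPU in your build):

```python
import torch
import torch.nn as nn

# Current flow: full-precision weights land in VRAM first, then get downcast,
# so the initial load peak is the full-precision footprint.
model = nn.Linear(4096, 4096)
model_a = model.to("cuda").to(torch.float8_e4m3fn)

# Alternative: cast (or directly load) in the storage dtype before moving,
# so only the low-precision weights ever reach VRAM.
model = nn.Linear(4096, 4096)
model_b = model.to(torch.float8_e4m3fn).to("cuda")
```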
Why different "granularities"?
This was mostly an experiment and we don't need to use everything in the PR. I wanted to understand the effect of typecasting all weights vs. some of them vs. only the PyTorch primitives. As usual, image models seem to be less affected by normalization casting (from the `DIFFUSERS_MODEL` granularity) compared to video models. However, the more granular we go, the more times weights are cast per inference step and the more synchronizations are introduced with the current implementation, leading to slowdowns in inference time. Allowing different levels of applying the typecasting hooks is akin to what we have for `model cpu offloading` vs. `sequential cpu offloading`, and allows for some tradeoffs that users can choose based on their use cases.

Is this compatible with `torch.compile`?

No, it isn't, because we overwrite the forward method of the underlying models to invoke a pre-hook and a post-hook. Both hooks change the state of the underlying model (downcast or upcast it) per forward pass, which does not fit the rules of `torch.compile` and makes the two incompatible. Using `@torch._dynamo.disable(recursive=False)` or similar does not seem to work. A rough sketch of the hook mechanism follows below.
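A minimal sketch of that mechanism (illustrative only, not the PR's actual hook API; `apply_layerwise_upcasting` is a hypothetical name):

```python
import torch
import torch.nn as nn

def apply_layerwise_upcasting(
    module: nn.Module,
    storage_dtype: torch.dtype = torch.float8_e4m3fn,
    compute_dtype: torch.dtype = torch.bfloat16,
) -> nn.Module:
    """Keep the module's weights in storage_dtype and upcast only around each forward."""
    module.to(storage_dtype)
    original_forward = module.forward

    def wrapped_forward(*args, **kwargs):
        module.to(compute_dtype)       # "pre-hook": upcast weights for computation
        try:
            return original_forward(*args, **kwargs)
        finally:
            module.to(storage_dtype)   # "post-hook": downcast back to the storage dtype

    module.forward = wrapped_forward   # overwriting forward is what breaks torch.compile
    return module
```

Applied per transformer block this corresponds to a layer-level granularity; applied once to the whole model, to the model-level granularity.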
Why a different approach from #9177?
While providing the API to use this via `ModelMixin` is okay, it puts a restriction that requires all implementations to derive from it. As this method can be generally applied to any modeling component, at any level of granularity, implementing it independently of `ModelMixin` allows it to be used in other modeling components like text encoders (which come from transformers), and any downstream research work or library can use it directly for their demos on Spaces without having to reinvent the wheel. I'm not opposed to the idea of having `enable_layerwise_upcasting` in `ModelMixin`, but let's do it in a way that does not impose any restrictions on how it can be used.

Also, the original PR typecast all leaf nodes to the storage dtype, but this may not be ideal for things like normalization and modulation, so supporting parameters like `skip_modules_pattern` and `skip_modules_classes` helps ignore a few layers (a rough sketch follows below). We can default to sensible values, while maintaining another parameter per class for layers that should not be upcast/downcast. This is also one of the places where it helps to follow a common naming convention across all our models.
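A rough sketch of how such skipping could look (the default values here are only examples, not the PR's defaults):

```python
import re
import torch.nn as nn

# Example defaults; the real defaults would live per model class.
DEFAULT_SKIP_MODULES_PATTERN = ("norm", "modulation")
DEFAULT_SKIP_MODULES_CLASSES = (nn.LayerNorm, nn.GroupNorm)

def should_skip(
    name: str,
    module: nn.Module,
    skip_modules_pattern=DEFAULT_SKIP_MODULES_PATTERN,
    skip_modules_classes=DEFAULT_SKIP_MODULES_CLASSES,
) -> bool:
    """Return True if this submodule should stay in the compute dtype."""
    if isinstance(module, skip_modules_classes):
        return True
    return any(re.search(pattern, name) for pattern in skip_modules_pattern)

# e.g.: for name, submodule in model.named_modules():
#           if not should_skip(name, submodule): attach the casting hook
```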
Fixes #9949
cc @vladmandic @asomoza
TODOs:
- Explore `non_blocking` and CUDA streams for overlapping weight casting with computation, without introducing many stream synchronizations on the default stream (a speculative sketch follows below)
- There can be tensors of different types (`Tensor`, `LongTensor`, `BoolTensor`, etc.), and we should not typecast all of them to `compute_dtype`, which would be incorrect
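One speculative direction for the first TODO (purely illustrative; none of these names exist in the PR): upcast the next layer's weights on a side stream and make the default stream wait on an event instead of a full synchronization.

```python
import torch

upcast_stream = torch.cuda.Stream()

def prefetch_upcast(module: torch.nn.Module, compute_dtype: torch.dtype) -> torch.cuda.Event:
    """Cast a module's weights on a side stream; returns an event to wait on."""
    event = torch.cuda.Event()
    with torch.cuda.stream(upcast_stream):
        for param in module.parameters():
            param.data = param.data.to(compute_dtype, non_blocking=True)
        event.record()
    return event

# Later, right before running the layer on the default stream:
#   torch.cuda.current_stream().wait_event(event)
#   out = module(x)
```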
Nice reading material for the interested: