
[Experimental] expose dynamic upcasting of layers as experimental APIs #9949

Open
sayakpaul opened this issue Nov 18, 2024 · 10 comments · May be fixed by #10347

@sayakpaul
Member

sayakpaul commented Nov 18, 2024

Functionalities like #9177 are immensely helpful for loading a checkpoint in, say, torch.float8_e5m2, performing the computation in, say, torch.float16, and then casting the weights back to torch.float8_e5m2 afterwards.
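A minimal sketch of the storage-vs-compute dtype idea in plain PyTorch (not the diffusers API; the module and dtypes are just for illustration):

import torch

# Keep the weights in a low-precision storage dtype, upcast only for the
# forward pass, then cast back down to the storage dtype.
linear = torch.nn.Linear(64, 64).to(torch.float8_e5m2)   # storage dtype
x = torch.randn(1, 64, dtype=torch.float16)

linear.to(torch.float16)                                  # compute dtype
y = linear(x)
linear.to(torch.float8_e5m2)                              # back to storage dtype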

Even though this feature isn't immediately compatible with torch.compile() and we're unsure of its repercussions, we think it's still better to expose it as an experimental API because the memory benefits are significant.

Cc: @vladmandic as you expressed interest in this.

Cc: @a-r-r-o-w @SunMarc as we discussed it in-person.

Cc: @DN6 because #9177 is his brainchild.

@vladmandic
Contributor

thanks @sayakpaul - yes, i'm very interested!
this solution would enable users with older gpus to run newer models - exactly the kind of thing that gets asked for frequently!
btw, note that torch.compile is not actually that popular due to its significant overhead.
and quantization is not always compatible either; both bitsandbytes and optimum-quanto have limited platform support.

@sayakpaul
Member Author

Yes, that is why we think exposing this API, even in an experimental capacity, makes a lot of sense!

@a-r-r-o-w
Member

a-r-r-o-w commented Nov 18, 2024

Just to add my two cents and to make others aware of our in-person discussion: I don't think we could ever make it fully compatible with torch.compile, due to the model's state changing at every forward pass, but it will undoubtedly help reduce VRAM requirements. As we're pushing for the hooks API that is in process of being introduced in #9562, the implementation could be something roughly as follows:

from typing import Any

import torch

# `ModelHook`, `set_args_dtype`, `set_kwargs_dtype`, and `apply_hook_to_module`
# refer to the hooks API being introduced in #9562 and are placeholders here.
class LayerwiseUpcastingHook(ModelHook):
    def __init__(self, compute_dtype: torch.dtype = torch.bfloat16, storage_dtype: torch.dtype = torch.float8_e5m2) -> None:
        super().__init__()

        self.compute_dtype = compute_dtype
        self.storage_dtype = storage_dtype

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        # Upcast the inputs and the module weights to the compute dtype right before the forward pass.
        set_args_dtype(args, dtype=self.compute_dtype)
        set_kwargs_dtype(kwargs, dtype=self.compute_dtype)
        module.to(dtype=self.compute_dtype)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        # Downcast the module weights back to the storage dtype right after the forward pass.
        module.to(dtype=self.storage_dtype)
        return output


# The below can be abstracted away to whatever makes sense
for block in pipe.transformer.transformer_blocks:
    hook = LayerwiseUpcastingHook(compute_dtype=torch.bfloat16, storage_dtype=torch.float8_e5m2)
    apply_hook_to_module(block, hook, append=True)

This is similar to how accelerate casts to the correct device under the hood, so a similar API to that makes a lot of sense to me. I think TorchAO already does things in this "layerwise" manner at the nn.Module level, but this would still be nice to have for layers that have parameter weights, but are not nn.Linear or nn.Conv2d
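For reference, this is roughly the accelerate pattern being alluded to; a minimal sketch using accelerate's public hook helpers (device alignment only, not the dtype logic proposed here):

import torch
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

# accelerate attaches a hook that aligns a module (and its inputs) with the
# execution device around the forward pass; the dtype handling above follows
# the same pre/post-forward shape.
module = torch.nn.Linear(64, 64)
device = 0 if torch.cuda.is_available() else "cpu"
add_hook_to_module(module, AlignDevicesHook(execution_device=device))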

I would also like to point out, based on our discussion, that this approach gives more granular control and, for power users, can yield more memory savings than naively applying it at the transformer-block level. Let me give an example:

Our max memory requirement is bounded by the memory for weights plus intermediate activations. Quantizing weights, or storing them in a lower-precision dtype, typically only reduces the memory for weights; the peaks from activations remain the same.

Assume a very simple model consisting of transformer blocks (attention + a large feed forward). Applying layerwise upcasting from fp8 to fp16 at the transformer block level requires (fp8 model memory + [fp16 single block memory - fp8 single block memory] + fp16 single block activation peak). If we instead upcast more granularly, at each attention and each feed-forward layer, the max memory footprint is lower, because only the attention OR the feed-forward MLP weights are upcast at any given time rather than both at once, at the cost of slightly slower inference (more upcasts and downcasts).

The most granular case would be like what we do in sequential CPU offload, but here we would upcast and downcast at the nn.Module level instead of the ModelMixin level. I think these tradeoffs are still worth it for making larger and larger models accessible for inference, so thinking of layerwise upcasting at just the transformer block level is a bit restricting (just as we discussed).
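A rough sketch of the more granular variant described above, reusing the illustrative LayerwiseUpcastingHook / apply_hook_to_module names from the earlier snippet (the attn / ff attribute names are assumptions about the block layout):

# Hypothetical: attach the hook per attention / feed-forward submodule instead
# of per transformer block, so only one of the two is in the compute dtype at a time.
for block in pipe.transformer.transformer_blocks:
    for submodule in (block.attn, block.ff):  # assumed attribute names
        hook = LayerwiseUpcastingHook(compute_dtype=torch.bfloat16, storage_dtype=torch.float8_e5m2)
        apply_hook_to_module(submodule, hook, append=True)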

@sayakpaul
Member Author

As we're pushing for the hooks API that is in process of being introduced in #9562, the implementation could be something roughly as follows:

I don't think we're quite settled on the hooks-based approach yet, though, no? For caching, for example, we're still debating whether a mixin class would make more sense. OTOH, this experimental API is simple and easy to use and covers most of the use cases I can imagine.

I think TorchAO already does things in this "layerwise" manner at the nn.Module level, but this would still be nice to have for layers that have parameter weights, but are not nn.Linear or nn.Conv2d

As mentioned in the description, having the experimental API is still meaningful for cases where the weights are stored in a lower precision without any quantization stats, so, unlike quantization, we don't need any bitwise packing and unpacking.
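To illustrate the distinction in plain PyTorch (purely for illustration): a pure dtype cast carries no quantization state, so the round trip is just two .to() calls, whereas int8/int4 schemes need scales and packed layouts.

import torch

# Lossy but stat-free storage: no scales, zero-points, or packing involved.
w_fp16 = torch.randn(1024, 1024, dtype=torch.float16)
w_fp8 = w_fp16.to(torch.float8_e5m2)    # storage
w_back = w_fp8.to(torch.float16)        # upcast for compute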

I would also like to point out, based on our discussion, that this approach gives more granular control and, for power users, can yield more memory savings than naively applying it at the transformer-block level.

The most granular case would be like what we do in sequential CPU offload, but here we would upcast and downcast at the nn.Module level instead of the ModelMixin level.

Well, ModelMixin is still a subclass of nn.Module, so I am unsure what you mean here. And the utility from #9177 is recursive, meaning it applies to all the children where applicable. So, I suggest simplifying your explanation a bit, as I am a bit lost in the details.

@a-r-r-o-w
Member

I don't think we're quite settled on the hooks-based approach yet, though, no? For caching, for example, we're still debating whether a mixin class would make more sense. OTOH, this experimental API is simple and easy to use and covers most of the use cases I can imagine.

I think this approach would be consistent with what accelerate does for cpu offloading and is very clean, since both device and dtype can be handled in the same way. It is also more easily usable with any kind of model because otherwise (as in @DN6's PR), it involves making many changes to each model by adding .to(...) logic. Any other implementation ideas should be good too, but I would like the modeling implementations to look like simple mathematical functions and not deal with any kind of device/dtype casting - everything related to this should be handled at the pre- and post-hook levels, which is what Dhruv's PR is doing as well.

Well, ModelMixin is still a subclass of nn.Module, so I am unsure what you mean here. And the utility from #9177 is recursive, meaning it applies to all the children where applicable. So, I suggest simplifying your explanation a bit, as I am a bit lost in the details.

If I understand the PR correctly, the entry point for enabling layerwise upcasting is at the ModelMixin level, yes? You can't arbitrarily apply it to any specific module that you want easily, unless it is already derived from ModelMixin. What would be really nice is keeping the API similar-ish to cpu_offload_with_hook from accelerate that works with any nn.Module
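For reference, a minimal usage sketch of the accelerate helper mentioned above; it takes any nn.Module (the module here is arbitrary, and the keyword argument is from accelerate's API as I recall it):

import torch
from accelerate import cpu_offload_with_hook

# Works on any nn.Module: returns the module plus a hook handle that can be
# used to offload it back to CPU later.
model = torch.nn.Linear(64, 64)
model, hook = cpu_offload_with_hook(model, execution_device="cuda")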

@sayakpaul
Member Author

Okay, I see merit in the hook-based approach, but that said, we probably shouldn't delay shipping this either; after all, it is a simple EXPERIMENTAL feature that will clearly enable a lot of memory savings very easily.

If I understand the PR correctly, the entry point for enabling layerwise upcasting is at the ModelMixin level, yes? You can't arbitrarily apply it to any specific module that you want easily, unless it is already derived from ModelMixin. What would be really nice is keeping the API similar-ish to cpu_offload_with_hook from accelerate that works with any nn.Module

Doesn't need to be. enable_layerwise_upcasting(transformer.transformer_blocks) should also work.

Any other implementation ideas should be good too, but I would like the modeling implementations to look like simple mathematical functions and not deal with any kind of device/dtype casting - everything related to this should be handled at the pre- and post-hook levels, which is what Dhruv's PR is doing as well.

Good point! From what I understand, those explicit to()s were needed to handle edge cases. I think even if we used something like LayerwiseUpcastingHook, those explicit to()s might still be needed. We can quickly check that by implementing the LayerwiseUpcastingHook approach and then running the tests. This should, IMO, allow us to quickly determine which approach is preferable for shipping an experimental version of the idea. WDYT?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Dec 18, 2024
@vladmandic
Contributor

ping?

@a-r-r-o-w
Member

a-r-r-o-w commented Dec 18, 2024

Sorry about the delay @vladmandic! It was not planned for this release, so I was going to pick it up later. Let me open a quick prototype with some benchmarks in some time.

@vladmandic
Contributor

thanks - it's not an urgent thing, but it's a very interesting one - i just wanted to make sure it doesn't drop off the radar since there was no update in a long time.

@sayakpaul sayakpaul removed the stale Issues that haven't received updates label Dec 19, 2024
@a-r-r-o-w a-r-r-o-w linked a pull request Dec 23, 2024 that will close this issue