-
Notifications
You must be signed in to change notification settings - Fork 999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable cpu offload with weights inside the module #2214
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -302,6 +302,7 @@ def dispatch_model( | |
offload_dir: Optional[Union[str, os.PathLike]] = None, | ||
offload_index: Optional[Dict[str, str]] = None, | ||
offload_buffers: bool = False, | ||
cpu_offload: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it can be confusing as CPU offloading is already indicated in the IMO ideally there should not be any argument added, and by default the weights of modules offloaded on RAM should be on |
||
skip_keys: Optional[Union[str, List[str]]] = None, | ||
preload_module_classes: Optional[List[str]] = None, | ||
force_hooks: bool = False, | ||
|
@@ -321,6 +322,8 @@ def dispatch_model( | |
`"disk"`. | ||
state_dict (`Dict[str, torch.Tensor]`, *optional*): | ||
The state dict of the part of the model that will be kept on CPU. | ||
cpu_offload (`bool`, *optional*, defaults to `False`): | ||
Whether the weights offloaded on the cpu should be kept in the module or not. | ||
offload_dir (`str` or `os.PathLike`): | ||
The folder in which to offload the model weights (or where the model weights are already offloaded). | ||
offload_index (`Dict`, *optional*): | ||
|
@@ -358,7 +361,7 @@ def dispatch_model( | |
else: | ||
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] | ||
|
||
if main_device != "cpu": | ||
if main_device != "cpu" and not cpu_offload: | ||
cpu_modules = [name for name, device in device_map.items() if device == "cpu"] | ||
if state_dict is None and len(cpu_modules) > 0: | ||
state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules) | ||
|
@@ -381,8 +384,12 @@ def dispatch_model( | |
name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items() | ||
} | ||
execution_device[""] = main_device | ||
offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"] | ||
offloaded_devices = ( | ||
["disk"] if cpu_offload or main_device == "cpu" or main_device == "mps" else ["cpu", "disk"] | ||
) | ||
offload = {name: device in offloaded_devices for name, device in device_map.items()} | ||
if cpu_offload: | ||
cpu_offload = {name: device == "cpu" for name, device in device_map.items()} | ||
save_folder = offload_dir if len(disk_modules) > 0 else None | ||
if state_dict is not None or save_folder is not None or offload_index is not None: | ||
device = main_device if offload_index is not None else None | ||
|
@@ -397,6 +404,7 @@ def dispatch_model( | |
model, | ||
execution_device=execution_device, | ||
offload=offload, | ||
cpu_offload=cpu_offload, | ||
offload_buffers=offload_buffers, | ||
weights_map=weights_map, | ||
skip_keys=skip_keys, | ||
|
@@ -405,7 +413,7 @@ def dispatch_model( | |
|
||
# warn if there is any params on the meta device | ||
offloaded_devices_str = " and ".join( | ||
[device for device in set(device_map.values()) if device in ("cpu", "disk")] | ||
[device for device in set(device_map.values()) if device in offloaded_devices] | ||
) | ||
if len(offloaded_devices_str) > 0: | ||
logging.warning( | ||
|
@@ -450,6 +458,7 @@ def load_checkpoint_and_dispatch( | |
no_split_module_classes: Optional[List[str]] = None, | ||
offload_folder: Optional[Union[str, os.PathLike]] = None, | ||
offload_buffers: bool = False, | ||
cpu_offload: bool = False, | ||
dtype: Optional[Union[str, torch.dtype]] = None, | ||
offload_state_dict: Optional[bool] = None, | ||
skip_keys: Optional[Union[str, List[str]]] = None, | ||
|
@@ -484,6 +493,8 @@ def load_checkpoint_and_dispatch( | |
offload_buffers (`bool`, *optional*, defaults to `False`): | ||
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as | ||
well as the parameters. | ||
cpu_offload (`bool`, *optional*, defaults to `False`): | ||
Whether the weights offloaded on the cpu should be kept in the module or not. | ||
dtype (`str` or `torch.dtype`, *optional*): | ||
If provided, the weights will be converted to that type when loaded. | ||
offload_state_dict (`bool`, *optional*): | ||
|
@@ -558,6 +569,7 @@ def load_checkpoint_and_dispatch( | |
device_map=device_map, | ||
offload_dir=offload_folder, | ||
offload_buffers=offload_buffers, | ||
cpu_offload=cpu_offload, | ||
skip_keys=skip_keys, | ||
preload_module_classes=preload_module_classes, | ||
force_hooks=force_hooks, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -208,6 +208,8 @@ class AlignDevicesHook(ModelHook): | |
The device on which inputs and model weights should be placed before the forward pass. | ||
offload (`bool`, *optional*, defaults to `False`): | ||
Whether or not the weights should be offloaded after the forward pass. | ||
cpu_offload (`bool`, *optional*, defaults to `False`): | ||
Whether the weights offloaded on the cpu should be kept in the module or not. | ||
io_same_device (`bool`, *optional*, defaults to `False`): | ||
Whether or not the output should be placed on the same device as the input was. | ||
weights_map (`Mapping[str, torch.Tensor]`, *optional*): | ||
|
@@ -222,6 +224,7 @@ def __init__( | |
self, | ||
execution_device: Optional[Union[int, str, torch.device]] = None, | ||
offload: bool = False, | ||
cpu_offload: bool = False, | ||
io_same_device: bool = False, | ||
weights_map: Optional[Mapping] = None, | ||
offload_buffers: bool = False, | ||
|
@@ -230,12 +233,12 @@ def __init__( | |
): | ||
self.execution_device = execution_device | ||
self.offload = offload | ||
self.cpu_offload = cpu_offload | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to pass both |
||
self.io_same_device = io_same_device | ||
self.weights_map = weights_map | ||
self.offload_buffers = offload_buffers | ||
self.place_submodules = place_submodules | ||
self.skip_keys = skip_keys | ||
|
||
# Will contain the input device when `io_same_device=True`. | ||
self.input_device = None | ||
self.param_original_devices = {} | ||
|
@@ -249,7 +252,7 @@ def __repr__(self): | |
) | ||
|
||
def init_hook(self, module): | ||
if not self.offload and self.execution_device is not None: | ||
if not self.offload and self.execution_device is not None and not self.cpu_offload: | ||
for name, _ in named_module_tensors(module, recurse=self.place_submodules): | ||
set_module_tensor_to_device(module, name, self.execution_device) | ||
elif self.offload: | ||
|
@@ -273,7 +276,9 @@ def init_hook(self, module): | |
elif self.offload_buffers and self.execution_device is not None: | ||
for name in get_non_persistent_buffers(module, recurse=self.place_submodules): | ||
set_module_tensor_to_device(module, name, self.execution_device) | ||
|
||
elif self.cpu_offload: | ||
for name, _ in named_module_tensors(module, recurse=self.place_submodules): | ||
set_module_tensor_to_device(module, name, "cpu") | ||
return module | ||
|
||
def pre_forward(self, module, *args, **kwargs): | ||
|
@@ -293,7 +298,9 @@ def pre_forward(self, module, *args, **kwargs): | |
set_module_tensor_to_device( | ||
module, name, self.execution_device, value=self.weights_map[name], fp16_statistics=fp16_statistics | ||
) | ||
|
||
elif self.cpu_offload: | ||
for name, _ in named_module_tensors(module, recurse=self.place_submodules): | ||
set_module_tensor_to_device(module, name, self.execution_device) | ||
return send_to_device(args, self.execution_device), send_to_device( | ||
kwargs, self.execution_device, skip_keys=self.skip_keys | ||
) | ||
|
@@ -310,7 +317,9 @@ def post_forward(self, module, output): | |
if type(module).__name__ == "Linear8bitLt": | ||
module.state.SCB = None | ||
module.state.CxB = None | ||
|
||
elif self.cpu_offload: | ||
for name, _ in named_module_tensors(module, recurse=self.place_submodules): | ||
set_module_tensor_to_device(module, name, "cpu") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is special handling for |
||
if self.io_same_device and self.input_device is not None: | ||
output = send_to_device(output, self.input_device, skip_keys=self.skip_keys) | ||
|
||
|
@@ -450,6 +459,7 @@ def attach_align_device_hook_on_blocks( | |
module: nn.Module, | ||
execution_device: Optional[Union[torch.device, Dict[str, torch.device]]] = None, | ||
offload: Union[bool, Dict[str, bool]] = False, | ||
cpu_offload: Union[bool, Dict[str, bool]] = False, | ||
weights_map: Mapping = None, | ||
offload_buffers: bool = False, | ||
module_name: str = "", | ||
|
@@ -468,6 +478,8 @@ def attach_align_device_hook_on_blocks( | |
offload (`bool`, *optional*, defaults to `False`): | ||
Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole | ||
module, or a dictionary mapping module name to boolean. | ||
cpu_offload (`Union[bool, Dict[str, bool]]`, *optional*, defaults to `False`): | ||
Whether the weights offloaded on the cpu should be kept in the module or not. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The docstring does not explain what passing a dict means for this option. |
||
weights_map (`Mapping[str, torch.Tensor]`, *optional*): | ||
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values. | ||
offload_buffers (`bool`, *optional*, defaults to `False`): | ||
|
@@ -505,10 +517,16 @@ def attach_align_device_hook_on_blocks( | |
execution_device = {key: execution_device for key in offload.keys()} | ||
if not isinstance(offload, Mapping): | ||
offload = {key: offload for key in execution_device.keys()} | ||
|
||
if module_name in execution_device and module_name in offload and not offload[module_name]: | ||
if not isinstance(cpu_offload, Mapping): | ||
cpu_offload = {key: cpu_offload for key in execution_device.keys()} | ||
if ( | ||
module_name in execution_device | ||
and module_name in offload | ||
and (not offload[module_name] or cpu_offload[module_name]) | ||
): | ||
hook = AlignDevicesHook( | ||
execution_device=execution_device[module_name], | ||
cpu_offload=cpu_offload[module_name], | ||
offload_buffers=offload_buffers, | ||
io_same_device=(module_name == ""), | ||
place_submodules=True, | ||
|
@@ -548,6 +566,7 @@ def attach_align_device_hook_on_blocks( | |
child, | ||
execution_device=execution_device, | ||
offload=offload, | ||
cpu_offload=cpu_offload, | ||
weights_map=weights_map, | ||
offload_buffers=offload_buffers, | ||
module_name=child_name, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -338,6 +338,24 @@ def test_dispatch_model_with_non_persistent_buffers(self): | |
|
||
with TemporaryDirectory() as tmp_dir: | ||
dispatch_model(model, device_map, offload_dir=tmp_dir, offload_buffers=True) | ||
|
||
output = model(x) | ||
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5)) | ||
|
||
def test_dispatch_model_with_cpu_offload(self): | ||
model = ModelForTest() | ||
device_map = {"linear1": "disk", "batchnorm": "cpu", "linear2": 0} | ||
|
||
x = torch.randn(2, 3) | ||
expected = model(x) | ||
|
||
with TemporaryDirectory() as tmp_dir: | ||
dispatch_model(model, device_map, offload_dir=tmp_dir, cpu_offload=True) | ||
|
||
self.assertEqual(model.linear1.weight.device, torch.device("meta")) | ||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu")) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The new behavior of getting "cpu" here instead of "meta" looks more intuitive to me. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, this is what we are aiming for in this PR! We want to let the module on |
||
self.assertEqual(model.linear2.weight.device, torch.device(0)) | ||
|
||
output = model(x) | ||
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5)) | ||
|
||
|
@@ -548,6 +566,31 @@ def test_load_checkpoint_and_dispatch(self): | |
output = new_model(x) | ||
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5)) | ||
|
||
@require_cuda | ||
def test_load_checkpoint_and_dispatch_with_cpu_offload(self): | ||
model = ModelForTest() | ||
device_map = {"linear1": "cpu", "batchnorm": "disk", "linear2": 0} | ||
|
||
x = torch.randn(2, 3) | ||
expected = model(x) | ||
|
||
with TemporaryDirectory() as tmp_dir: | ||
checkpoint = os.path.join(tmp_dir, "pt_model.bin") | ||
torch.save(model.state_dict(), checkpoint) | ||
|
||
new_model = ModelForTest() | ||
new_model = load_checkpoint_and_dispatch( | ||
new_model, checkpoint, device_map=device_map, cpu_offload=True, offload_folder=tmp_dir | ||
) | ||
|
||
# With cpu_offload=True, CPU-offloaded weights stay on the CPU; only disk-offloaded weights are on the meta device while waiting for the forward pass. | ||
self.assertEqual(new_model.linear1.weight.device, torch.device("cpu")) | ||
self.assertEqual(new_model.batchnorm.weight.device, torch.device("meta")) | ||
self.assertEqual(new_model.linear2.weight.device, torch.device(0)) | ||
|
||
output = new_model(x) | ||
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5)) | ||
|
||
@require_mps | ||
def test_load_checkpoint_and_dispatch_mps(self): | ||
model = ModelForTest() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, should newly added parameters be placed last in case someone calls this function with purely positional arguments?