diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index fbbbfcd2610..5f88e54e3c9 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1784,6 +1784,10 @@ def get_state_dict_from_offload( root = module_name[: module_name.rfind(".")] # module name without .weight or .bias + # do not move parameters if the module is not offloaded + if not has_offloaded_params(module): + device_to_put_offload = None + # assign the device to which the offloaded parameters will be sent with align_module_device(module, device_to_put_offload): for m_key, params in module.state_dict().items():