Fix regression loading dtype (#34409)
* fix regression

* add test for torchao

* expected output

* better fix
SunMarc authored and ArthurZucker committed Oct 29, 2024
1 parent 72c716d commit 94ed13c
Showing 2 changed files with 25 additions and 4 deletions.
9 changes: 5 additions & 4 deletions — src/transformers/modeling_utils.py

@@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
         old_param = model
         splits = param_name.split(".")
         for split in splits:
-            old_param = getattr(old_param, split)
-            # Not all the attributes of a module are Parameters/Tensor
-            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
-                old_param = None
+            # We shouldn't hit the default value except for quant methods like hqq that modify expected_keys.
+            old_param = getattr(old_param, split, None)
             if old_param is None:
                 break
 
+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+
         if old_param is not None:
             if dtype is None:
                 param = param.to(old_param.dtype)
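
For context, a minimal standalone sketch of the traversal pattern the fix adopts (the toy module and parameter name below are illustrative assumptions, not code from transformers):

    import torch
    import torch.nn as nn

    # Toy stand-ins for the real model and checkpoint key (illustrative only).
    model = nn.Sequential(nn.Linear(4, 4))
    param_name = "0.weight"

    # Walk the dotted name down the module tree. Passing a default to getattr
    # yields None instead of raising AttributeError when an attribute is
    # missing, e.g. after a quantizer such as hqq has modified expected_keys.
    old_param = model
    for split in param_name.split("."):
        old_param = getattr(old_param, split, None)
        if old_param is None:
            break

    # Type-check once, after the walk: only Parameters/Tensors carry a dtype.
    if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
        old_param = None

    if old_param is not None:
        print(old_param.dtype)  # torch.float32 for this toy module

The old code called getattr without a default inside the loop, so a parameter name whose intermediate attribute was missing raised AttributeError instead of falling through to the dtype handling below.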
20 changes: 20 additions & 0 deletions — tests/quantization/torchao_integration/test_torchao.py

@@ -208,6 +208,26 @@ def test_int4wo_offload(self):
 
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
 
+    def test_int8_dynamic_activation_int8_weight_quant(self):
+        """
+        Simple LLM model testing int8_dynamic_activation_int8_weight
+        """
+        quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+
+        # Note: we quantize the bfloat16 model on the fly to int8
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=torch_device,
+            quantization_config=quant_config,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
+
 
 if __name__ == "__main__":
     unittest.main()
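
For reference, a minimal sketch of the same load-and-generate flow outside the test harness (the checkpoint name, device, and prompt here are placeholder assumptions rather than values from the test suite, and torchao must be installed):

    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

    model_name = "facebook/opt-125m"  # placeholder; the test uses its own self.model_name
    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")

    # Quantize on the fly while loading; "cuda" is an assumption here, the
    # test uses the suite-wide torch_device instead.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cuda",
        quantization_config=quant_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")
    output = model.generate(**inputs, max_new_tokens=10)
    print(tokenizer.decode(output[0], skip_special_tokens=True))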
