diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a6fbd7b1a91453..8481fa7df9cd96 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model( old_param = model splits = param_name.split(".") for split in splits: - old_param = getattr(old_param, split) - # Not all the attributes of a module are Parameters/Tensor - if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): - old_param = None + # We shouldn't hit the default value except for quant methods like hqq that modify expected_keys. + old_param = getattr(old_param, split, None) if old_param is None: break + if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): + old_param = None + if old_param is not None: if dtype is None: param = param.to(old_param.dtype) diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index 8014f745d08688..c7c701e49aec14 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -208,6 +208,26 @@ def test_int4wo_offload(self): self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + def test_int8_dynamic_activation_int8_weight_quant(self): + """ + Simple LLM model testing int8_dynamic_activation_int8_weight + """ + quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight") + + # Note: we quantize the bfloat16 model on the fly to int8 + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + device_map=torch_device, + quantization_config=quant_config, + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + EXPECTED_OUTPUT = "What are we having for 
dinner?\n\nJessica: (smiling)" + self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + if __name__ == "__main__": unittest.main()