From f099b2fbd5c715f59baccdc0dd51ea29b95209a3 Mon Sep 17 00:00:00 2001 From: Maxim Kan Date: Thu, 26 Dec 2024 10:02:30 +0000 Subject: [PATCH 1/7] check for base_layer key in transformer state dict --- src/diffusers/loaders/lora_pipeline.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 351295e938ff..7e26c397a077 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2460,13 +2460,17 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): if unexpected_modules: logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") - is_peft_loaded = getattr(transformer, "peft_config", None) is not None + transformer_base_layer_keys = { + k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k + } for k in lora_module_names: if k in unexpected_modules: continue base_param_name = ( - f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight" + f"{k.replace(prefix, '')}.base_layer.weight" + if k in transformer_base_layer_keys + else f"{k.replace(prefix, '')}.weight" ) base_weight_param = transformer_state_dict[base_param_name] lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] From 3a4f8a4b0b55e30fb41dfbc659cc8a99c3912fec Mon Sep 17 00:00:00 2001 From: hlky Date: Sun, 29 Dec 2024 15:34:01 +0000 Subject: [PATCH 2/7] test_lora_expansion_works_for_absent_keys --- tests/lora/test_lora_layers_flux.py | 51 +++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 0861160de6aa..01c32b4d503d 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import gc import os import sys @@ -162,6 +163,56 @@ def test_with_alpha_in_state_dict(self): ) self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) + def test_lora_expansion_works_for_absent_keys(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == self.output_shape) + + # Modify the config to have a layer which won't be present in the second LoRA we will load. 
+ modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) + modified_denoiser_lora_config.target_modules.add("x_embedder") + + pipe.transformer.add_adapter(modified_denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertFalse( + np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), + "LoRA should lead to different results.", + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe.unload_lora_weights() + # Modify the state dict to exclude "x_embedder" related LoRA params. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") + + # Load state dict with `x_embedder`. + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") + + pipe.set_adapters(["one", "two"]) + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + + self.assertFalse( + np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + "Different LoRAs should lead to different results.", + ) + self.assertFalse( + np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + "LoRA should lead to different results.", + ) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass From a2cdcdafeaf6ad6ea68767063ecdd60a1a633dbf Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 30 Dec 2024 07:22:33 +0000 Subject: [PATCH 3/7] check --- src/diffusers/loaders/lora_pipeline.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7e26c397a077..97909b31ee44 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2460,18 +2460,15 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): if unexpected_modules: logger.debug(f"Found unexpected modules: {unexpected_modules}. 
These will be ignored.") - transformer_base_layer_keys = { - k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k - } + is_peft_loaded = getattr(transformer, "peft_config", None) is not None for k in lora_module_names: if k in unexpected_modules: continue - base_param_name = ( - f"{k.replace(prefix, '')}.base_layer.weight" - if k in transformer_base_layer_keys - else f"{k.replace(prefix, '')}.weight" - ) + base_param_name = f"{k.replace(prefix, '')}.weight" + base_layer_name = f"{k.replace(prefix, '')}.base_layer.weight" + if is_peft_loaded and base_layer_name in transformer_state_dict: + base_param_name = base_layer_name base_weight_param = transformer_state_dict[base_param_name] lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] From c8d4a1cb935401a681fd2cbe4defcce73cbc2f1a Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 30 Dec 2024 07:49:31 +0000 Subject: [PATCH 4/7] Update tests/lora/test_lora_layers_flux.py Co-authored-by: Sayak Paul --- tests/lora/test_lora_layers_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 01c32b4d503d..94784edf9750 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -195,10 +195,10 @@ def test_lora_expansion_works_for_absent_keys(self): # Modify the state dict to exclude "x_embedder" related LoRA params. lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} - pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") # Load state dict with `x_embedder`. 
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") pipe.set_adapters(["one", "two"]) self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") From 75268c077aecb1b82bc4aeb478944ee15f1142b3 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 30 Dec 2024 07:52:30 +0000 Subject: [PATCH 5/7] check --- src/diffusers/loaders/lora_pipeline.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 97909b31ee44..f55d9958e5c3 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2465,10 +2465,11 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): if k in unexpected_modules: continue - base_param_name = f"{k.replace(prefix, '')}.weight" - base_layer_name = f"{k.replace(prefix, '')}.base_layer.weight" - if is_peft_loaded and base_layer_name in transformer_state_dict: - base_param_name = base_layer_name + base_param_name = ( + f"{k.replace(prefix, '')}.base_layer.weight" + if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict + else f"{k.replace(prefix, '')}.weight" + ) base_weight_param = transformer_state_dict[base_param_name] lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] From 08ea124992921e4b4c493fd210c5fd5f59638b6d Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 30 Dec 2024 07:52:41 +0000 Subject: [PATCH 6/7] test_lora_expansion_works_for_absent_keys/test_lora_expansion_works_for_extra_keys --- tests/lora/test_lora_layers_flux.py | 49 +++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 94784edf9750..648a21aeaa03 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -186,6 +186,55 @@ def test_lora_expansion_works_for_absent_keys(self): "LoRA should lead to different results.", ) + with tempfile.TemporaryDirectory() as tmpdirname: + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") + + # Modify the state dict to exclude "x_embedder" related LoRA params. 
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} + + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") + pipe.set_adapters(["one", "two"]) + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + + self.assertFalse( + np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + "Different LoRAs should lead to different results.", + ) + self.assertFalse( + np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + "LoRA should lead to different results.", + ) + + def test_lora_expansion_works_for_extra_keys(self): + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == self.output_shape) + + # Modify the config to have a layer which won't be present in the first LoRA we will load. + modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) + modified_denoiser_lora_config.target_modules.add("x_embedder") + + pipe.transformer.add_adapter(modified_denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertFalse( + np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), + "LoRA should lead to different results.", + ) + with tempfile.TemporaryDirectory() as tmpdirname: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) From 5a7997b587d10e1ed74d29118ca107ffaf898b69 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 30 Dec 2024 07:54:49 +0000 Subject: [PATCH 7/7] absent->extra --- tests/lora/test_lora_layers_flux.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 648a21aeaa03..9fa968c47107 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -251,14 +251,14 @@ def test_lora_expansion_works_for_extra_keys(self): pipe.set_adapters(["one", "two"]) self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertFalse( - np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), "Different LoRAs should lead to different results.", ) self.assertFalse( - np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), + np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), "LoRA should lead to different results.", )
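
Taken together, PATCH 1, 3 and 5 converge on a per-module check inside `_maybe_expand_lora_state_dict`: the loader only reads `<module>.base_layer.weight` when a PEFT adapter is already attached *and* that key actually exists in the transformer state dict; otherwise it falls back to `<module>.weight`. The sketch below is a minimal, standalone illustration of that resolution rule (the helper name, module names, and toy state dict are examples only, not part of diffusers); it shows why the check matters when an already-loaded adapter and an incoming LoRA do not target the same set of modules, which is exactly the scenario exercised by the tests added in PATCH 2 and 6.

def resolve_base_param_name(k, prefix, transformer_state_dict, is_peft_loaded):
    # Mirrors the resolution rule PATCH 5 settles on: prefer the PEFT-wrapped
    # "<module>.base_layer.weight" key only when it actually exists for this
    # module; otherwise fall back to the plain "<module>.weight" key.
    base_layer_key = f"{k.replace(prefix, '')}.base_layer.weight"
    if is_peft_loaded and base_layer_key in transformer_state_dict:
        return base_layer_key
    return f"{k.replace(prefix, '')}.weight"


if __name__ == "__main__":
    prefix = "transformer."
    # Illustrative state dict of a transformer that already carries one adapter:
    # to_q was targeted by that adapter (so PEFT wrapped it and moved the base
    # weight under ".base_layer.weight"), while x_embedder was not targeted and
    # still exposes a plain ".weight" key.
    transformer_state_dict = {
        "transformer_blocks.0.attn.to_q.base_layer.weight": object(),
        "x_embedder.weight": object(),
    }

    print(resolve_base_param_name("transformer.transformer_blocks.0.attn.to_q", prefix, transformer_state_dict, True))
    # -> transformer_blocks.0.attn.to_q.base_layer.weight
    print(resolve_base_param_name("transformer.x_embedder", prefix, transformer_state_dict, True))
    # -> x_embedder.weight; before this fix, is_peft_loaded alone decided the
    #    key, so the loader would have looked up "x_embedder.base_layer.weight"
    #    and raised KeyError when a second LoRA brought x_embedder along.

As a design note, PATCH 1 derived the same information by collecting every key ending in ".base_layer.weight" into a set and dropping the `is_peft_loaded` check; PATCH 3 and 5 keep the behaviour but restore the `is_peft_loaded` short-circuit and fold the per-key membership test back into a single conditional expression.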