diff --git a/internal_controlnet/args.py b/internal_controlnet/args.py index 63e6c1fc2..c936c56fe 100644 --- a/internal_controlnet/args.py +++ b/internal_controlnet/args.py @@ -471,3 +471,7 @@ def parse(cls, text: str) -> ControlNetUnit: for (key, value) in (item.strip().split(": "),) }, ) + + def __copy__(self) -> ControlNetUnit: + """Override the behavior on `copy.copy` calls.""" + return self.copy() diff --git a/unit_tests/args_test.py b/unit_tests/args_test.py index 31d79eeaa..f93dc628f 100644 --- a/unit_tests/args_test.py +++ b/unit_tests/args_test.py @@ -1,6 +1,7 @@ import pytest import torch import numpy as np +from copy import copy from dataclasses import dataclass from internal_controlnet.args import ControlNetUnit @@ -249,3 +250,11 @@ def test_infotext_parsing(): def test_alias(): ControlNetUnit.from_dict({"lowvram": True}) + + +def test_copy(): + unit1 = ControlNetUnit(enabled=True, module="none") + unit2 = copy(unit1) + unit2.enabled = False + assert unit1.enabled + assert not unit2.enabled