Skip to content

Commit

Permalink
🐛 Fix update_cn_script (#2203)
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei authored Oct 29, 2023
1 parent 4ac716a commit f2aafcf
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions internal_controlnet/external_code.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from copy import copy
from typing import List, Any, Optional, Union, Tuple, Dict
import numpy as np
from modules import scripts, processing, shared
Expand Down Expand Up @@ -335,11 +336,47 @@ def update_cn_script_in_processing(
- ControlNet is not present in `p.scripts`
- `p.script_args` is not filled with script arguments for scripts that are processed before ControlNet
"""
p.script_args = update_cn_script(p.scripts, p.script_args_value, cn_units)

cn_units_type = type(cn_units) if type(cn_units) in (list, tuple) else list
script_args = list(p.script_args)
update_cn_script_in_place(p.scripts, script_args, cn_units)
p.script_args = cn_units_type(script_args)

def update_cn_script(
script_runner: scripts.ScriptRunner,
script_args: Union[Tuple[Any], List[Any]],
cn_units: List[ControlNetUnit],
) -> Union[Tuple[Any], List[Any]]:
"""
Returns: The updated `script_args` with given `cn_units` used as ControlNet
script args.
Does not update `script_args` if any of the folling is true:
- ControlNet is not present in `script_runner`
- `script_args` is not filled with script arguments for scripts that are
processed before ControlNet
"""
script_args_type = type(script_args)
assert script_args_type in (tuple, list), script_args_type
updated_script_args = list(copy(script_args))

cn_script = find_cn_script(script_runner)

if cn_script is None or len(script_args) < cn_script.args_from:
return script_args

# fill in remaining parameters to satisfy max models, just in case script needs it.
max_models = shared.opts.data.get("control_net_unit_count", 3)
cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0)

cn_script_args_diff = 0
for script in script_runner.alwayson_scripts:
if script is cn_script:
cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
updated_script_args[script.args_from:script.args_to] = cn_units
script.args_to = script.args_from + len(cn_units)
else:
script.args_from += cn_script_args_diff
script.args_to += cn_script_args_diff

return script_args_type(updated_script_args)


def update_cn_script_in_place(
Expand All @@ -349,13 +386,16 @@ def update_cn_script_in_place(
**_kwargs, # for backwards compatibility
):
"""
@Deprecated(Raises assertion error if script_args passed in is Tuple)
Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`.
`cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.
Does not update `script_args` if any of the folling is true:
- ControlNet is not present in `script_runner`
- `script_args` is not filled with script arguments for scripts that are processed before ControlNet
"""
assert isinstance(script_args, list), type(script_args)

cn_script = find_cn_script(script_runner)
if cn_script is None or len(script_args) < cn_script.args_from:
Expand Down

0 comments on commit f2aafcf

Please sign in to comment.