Bug fixes, Runtime & Performance optimization #332

Open · wants to merge 6 commits into main
182 changes: 66 additions & 116 deletions python_coreml_stable_diffusion/controlnet.py
@@ -12,54 +12,50 @@

 from .unet import Timesteps, TimestepEmbedding, get_down_block, UNetMidBlock2DCrossAttn, linear_to_conv2d_map


 class ControlNetConditioningEmbedding(nn.Module):
-    def __init__(
-        self,
-        conditioning_embedding_channels,
-        conditioning_channels=3,
-        block_out_channels=(16, 32, 96, 256),
-    ):
+    """
+    Embeds conditioning input into a feature space suitable for ControlNet.
+    """
+
+    def __init__(self, conditioning_embedding_channels, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)):
         super().__init__()

+        # Initial convolution
         self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

-        self.blocks = nn.ModuleList([])
-
-        for i in range(len(block_out_channels) - 1):
-            channel_in = block_out_channels[i]
-            channel_out = block_out_channels[i + 1]
-            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
-            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+        # Convolutional blocks for progressive embedding: for each channel
+        # transition, a channel-preserving conv followed by a stride-2 conv
+        self.blocks = nn.ModuleList(
+            [
+                conv
+                for in_channels, out_channels in zip(block_out_channels[:-1], block_out_channels[1:])
+                for conv in (
+                    nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+                    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2),
+                )
+            ]
+        )

+        # Final embedding convolution
         self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)

     def forward(self, conditioning):
-        embedding = self.conv_in(conditioning)
-        embedding = F.silu(embedding)
-
+        # Process the conditioning input through the embedding layers
+        embedding = F.silu(self.conv_in(conditioning))
         for block in self.blocks:
-            embedding = block(embedding)
-            embedding = F.silu(embedding)
-
-        embedding = self.conv_out(embedding)
-
-        return embedding
+            embedding = F.silu(block(embedding))
+        return self.conv_out(embedding)
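The rewritten block list keeps the original layout: each channel transition contributes a channel-preserving 3×3 convolution followed by a stride-2 downsampling convolution. With the default block_out_channels=(16, 32, 96, 256) that yields three stride-2 stages, so a 512×512 conditioning image is reduced 8× onto the 64×64 latent grid. A quick shape check, as a sketch (it assumes the class is importable as shown and that PyTorch is available):

```python
import torch

from python_coreml_stable_diffusion.controlnet import ControlNetConditioningEmbedding

# Default block_out_channels=(16, 32, 96, 256): six conv blocks, three with stride 2
embed = ControlNetConditioningEmbedding(conditioning_embedding_channels=320)

cond = torch.randn(1, 3, 512, 512)  # e.g. a Canny edge map as an RGB image
out = embed(cond)
print(out.shape)  # torch.Size([1, 320, 64, 64]) -- 512 / 2**3 = 64 per side
```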

 class ControlNetModel(ModelMixin, ConfigMixin):
+    """
+    Implements a ControlNet model with flexible configuration for conditioning, downsampling, and cross-attention blocks.
+    """
+
     @register_to_config
     def __init__(
         self,
         in_channels=4,
         flip_sin_to_cos=True,
         freq_shift=0,
-        down_block_types=(
-            "CrossAttnDownBlock2D",
-            "CrossAttnDownBlock2D",
-            "CrossAttnDownBlock2D",
-            "DownBlock2D",
-        ),
+        down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
         only_cross_attention=False,
         block_out_channels=(320, 640, 1280, 1280),
         layers_per_block=2,
@@ -79,66 +75,42 @@ def __init__(
     ):
         super().__init__()

-        # Check inputs
+        # Validate inputs
         if len(block_out_channels) != len(down_block_types):
             raise ValueError(
-                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+                f"`block_out_channels` length must match `down_block_types` length. Received {len(block_out_channels)} and {len(down_block_types)}."
             )

-        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
-            raise ValueError(
-                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
-            )
-
-        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
-            raise ValueError(
-                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
-            )
+        # Convert scalar parameters into lists if needed
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
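The three separate length checks are replaced by a broadcast step: any per-block hyperparameter passed as a scalar is expanded to one entry per down block. Note that the old ValueError for list-valued only_cross_attention or attention_head_dim of the wrong length is dropped in the process. A minimal sketch of the idiom with the length guard restored (broadcast_per_block is an illustrative helper, not part of this file):

```python
def broadcast_per_block(value, num_blocks):
    """Expand a scalar hyperparameter to one entry per down block."""
    if isinstance(value, (bool, int)):
        return [value] * num_blocks
    if len(value) != num_blocks:
        raise ValueError(f"Expected {num_blocks} entries, got {len(value)}")
    return list(value)

assert broadcast_per_block(8, 4) == [8, 8, 8, 8]        # scalar is broadcast
assert broadcast_per_block([8, 8, 16, 16], 4)[2] == 16  # lists pass through
```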

+        # Register pre-hook for state dict mapping
         self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

-        # input
-        conv_in_kernel = 3
-        conv_in_padding = (conv_in_kernel - 1) // 2
-        self.conv_in = nn.Conv2d(
-            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
-        )
+        # Initial convolution
+        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)

-        # time
+        # Time embedding
         time_embed_dim = block_out_channels[0] * 4

         self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
-        timestep_input_dim = block_out_channels[0]
-
-        self.time_embedding = TimestepEmbedding(
-            timestep_input_dim,
-            time_embed_dim,
-        )
+        self.time_embedding = TimestepEmbedding(block_out_channels[0], time_embed_dim)

-        # control net conditioning embedding
+        # ControlNet conditioning embedding
         self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
             conditioning_embedding_channels=block_out_channels[0],
             block_out_channels=conditioning_embedding_out_channels,
         )

-        self.down_blocks = nn.ModuleList([])
-        self.controlnet_down_blocks = nn.ModuleList([])
-
-        if isinstance(only_cross_attention, bool):
-            only_cross_attention = [only_cross_attention] * len(down_block_types)
-
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
-        if isinstance(transformer_layers_per_block, int):
-            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
-
-        # down
+        # Down blocks
+        self.down_blocks = nn.ModuleList()
+        self.controlnet_down_blocks = nn.ModuleList([nn.Conv2d(block_out_channels[0], block_out_channels[0], kernel_size=1)])
+
         output_channel = block_out_channels[0]

-        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
-        self.controlnet_down_blocks.append(controlnet_block)
-
         for i, down_block_type in enumerate(down_block_types):
             input_channel = output_channel
             output_channel = block_out_channels[i]
@@ -160,22 +132,14 @@ def __init__(
             )
             self.down_blocks.append(down_block)

-            for _ in range(layers_per_block):
-                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
-                self.controlnet_down_blocks.append(controlnet_block)
-
-            if not is_final_block:
-                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
-                self.controlnet_down_blocks.append(controlnet_block)
-
-        # mid
-        mid_block_channel = block_out_channels[-1]
-
-        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
-        self.controlnet_mid_block = controlnet_block
+            # Add corresponding ControlNet blocks
+            for _ in range(layers_per_block + (0 if is_final_block else 1)):
+                self.controlnet_down_blocks.append(nn.Conv2d(output_channel, output_channel, kernel_size=1))

+        # Mid block
+        self.controlnet_mid_block = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], kernel_size=1)
         self.mid_block = UNetMidBlock2DCrossAttn(
-            in_channels=mid_block_channel,
+            in_channels=block_out_channels[-1],
             temb_channels=time_embed_dim,
             resnet_eps=norm_eps,
             resnet_act_fn=act_fn,
@@ -189,62 +153,48 @@
         )

     def get_num_residuals(self):
-        num_res = 2 # initial sample + mid block
+        """
+        Returns the total number of residual connections.
+        """
+        num_res = 2  # Includes initial sample and mid block
         for down_block in self.down_blocks:
             num_res += len(down_block.resnets)
             if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None:
                 num_res += len(down_block.downsamplers)
         return num_res
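The collapsed loop layers_per_block + (0 if is_final_block else 1) must emit exactly the 1×1 projections the original two-step version did, and the total count (initial sample, per-block projections, mid block) has to agree with get_num_residuals(). A back-of-the-envelope check with the default configuration (plain arithmetic, a sketch only):

```python
layers_per_block = 2
num_blocks = 4  # three CrossAttnDownBlock2D + one final DownBlock2D

# One projection for the initial sample, then per block: one per resnet layer,
# plus one for the downsampler on every non-final block
down_projections = 1 + sum(
    layers_per_block + (0 if i == num_blocks - 1 else 1) for i in range(num_blocks)
)

total = down_projections + 1  # plus the mid-block projection
print(down_projections, total)  # 12 13 -- get_num_residuals() = 2 + 4*2 + 3 = 13
```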

-    def forward(
-        self,
-        sample,
-        timestep,
-        encoder_hidden_states,
-        controlnet_cond,
-    ):
-        # 1. time
+    def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond):
+        """
+        Forward pass through the ControlNet model.
+        """
+        # Time embedding
         t_emb = self.time_proj(timestep)
         emb = self.time_embedding(t_emb)

-        # 2. pre-process
+        # Input convolution and conditioning
         sample = self.conv_in(sample)

         controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

         sample += controlnet_cond

-        # 3. down
+        # Down blocks
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                 sample, res_samples = downsample_block(
-                    hidden_states=sample,
-                    temb=emb,
-                    encoder_hidden_states=encoder_hidden_states,
+                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

             down_block_res_samples += res_samples

-        # 4. mid
+        # Mid block
         if self.mid_block is not None:
-            sample = self.mid_block(
-                sample,
-                emb,
-                encoder_hidden_states=encoder_hidden_states,
-            )
+            sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

-        # 5. Control net blocks
+        # ControlNet-specific processing
         controlnet_down_block_res_samples = ()

         for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
-            down_block_res_sample = controlnet_block(down_block_res_sample)
-            controlnet_down_block_res_samples += (down_block_res_sample,)
-
-        down_block_res_samples = controlnet_down_block_res_samples
-
-        mid_block_res_sample = self.controlnet_mid_block(sample)
+            controlnet_down_block_res_samples += (controlnet_block(down_block_res_sample),)

-        return down_block_res_samples, mid_block_res_sample
+        # Return results
+        return controlnet_down_block_res_samples, self.controlnet_mid_block(sample)
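With these changes, forward returns the tuple of projected down-block residuals and the projected mid-block sample directly. A smoke-test sketch of the full pass; the tensor layouts are assumptions following this repo's 4D UNet conventions (e.g. encoder hidden states as (batch, dim, 1, seq)), so verify them against the actual configuration before relying on them:

```python
import torch

from python_coreml_stable_diffusion.controlnet import ControlNetModel

model = ControlNetModel().eval()

sample = torch.randn(1, 4, 64, 64)                  # noisy latents
timestep = torch.tensor([10.0])                     # diffusion timestep
encoder_hidden_states = torch.randn(1, 768, 1, 77)  # text embeddings (assumed layout)
controlnet_cond = torch.randn(1, 3, 512, 512)       # conditioning image

with torch.no_grad():
    down_res, mid_res = model(sample, timestep, encoder_hidden_states, controlnet_cond)

# Twelve projected down residuals plus one mid residual == get_num_residuals()
assert len(down_res) + 1 == model.get_num_residuals()
print(mid_res.shape)  # lowest-resolution feature map, e.g. (1, 1280, 8, 8)
```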