Skip to content

Commit

Permalink
[SD3] Fix mis-matched shape when num_images_per_prompt > 1 using with…
Browse files Browse the repository at this point in the history
…out T5 (text_encoder_3=None) (#8558)

* fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None

* style

* fix copies

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
  • Loading branch information
3 people committed Jun 20, 2024
1 parent dc74c7e commit a0a5427
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,11 @@ def _get_t5_prompt_embeds(

if self.text_encoder_3 is None:
return torch.zeros(
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
dtype=dtype,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,11 @@ def _get_t5_prompt_embeds(

if self.text_encoder_3 is None:
return torch.zeros(
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
dtype=dtype,
)
Expand Down

0 comments on commit a0a5427

Please sign in to comment.