Skip to content

Commit

Permalink
fix position_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Sep 17, 2024
1 parent ffd1585 commit acf31be
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,25 @@ def get_vae_dtype(state_dict=None, state_dict_dtype=None):
return None


def fix_position_ids(state_dict, force=False):
# for SD1.5 or some SDXL with position_ids
for prefix in ("cond_stage_models.", "conditioner.embedders.0."):
position_id_key = f"{prefix}transformer.text_model.embeddings.position_ids"
if position_id_key in state_dict:
original = state_dict[position_id_key]
if original.dtype == torch.int64:
return

if force:
# regenerate
fixed = torch.tensor([list(range(77))], dtype=torch.int64, device=original.device)
else:
fixed = state_dict[position_id_key].to(torch.int64)
print(f"Warning: Fixed position_ids dtype from {original.dtype} to {fixed.dtype}")

state_dict[position_id_key] = fixed


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
Expand All @@ -486,6 +505,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
set_model_type(model, state_dict)
set_model_fields(model)

fix_position_ids(state_dict)


if model.is_sdxl:
sd_models_xl.extend_sdxl(model)

Expand Down

0 comments on commit acf31be

Please sign in to comment.