Skip to content

Commit

Permalink
Support llama hunyuan video text encoder in scaled fp8 format.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Dec 17, 2024
1 parent f4cdede commit d6656b0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
10 changes: 9 additions & 1 deletion comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,14 @@ def t5xxl_detect(clip_data):

return {}

def llama_detect(clip_data):
weight_name = "model.layers.0.self_attn.k_proj.weight"

for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)

return {}

def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts
Expand Down Expand Up @@ -669,7 +677,7 @@ class EmptyClass:
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
elif clip_type == CLIPType.HUNYUAN_VIDEO:
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip() #TODO
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
Expand Down
6 changes: 3 additions & 3 deletions comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,9 +783,9 @@ def process_unet_state_dict_for_saving(self, state_dict):
return utils.state_dict_prefix_replace(state_dict, replace_prefix)

def clip_target(self, state_dict={}):
# pref = self.text_encoder_key_prefix[0]
# t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip()) #TODO
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))

models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]

Expand Down
13 changes: 13 additions & 0 deletions comfy/text_encoders/hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@
import os


def llama_detect(state_dict, prefix=""):
out = {}
t5_key = "{}model.norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype

scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype

return out


class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
Expand Down

0 comments on commit d6656b0

Please sign in to comment.