Update model_loading.py
kijai committed Nov 19, 2024
1 parent 67f2f6a commit 516655b
Showing 1 changed file with 1 addition and 30 deletions.
model_loading.py: 1 addition & 30 deletions
@@ -36,24 +36,13 @@ def patched_write_atomic(

import torch
import torch.nn as nn
from .utils import check_diffusers_version, remove_specific_blocks, log
check_diffusers_version()

from diffusers.models import AutoencoderKLCogVideoX
from diffusers.schedulers import CogVideoXDDIMScheduler
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .pipeline_cogvideox import CogVideoXPipeline
from contextlib import nullcontext

from .cogvideox_fun.transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFun
from .cogvideox_fun.fun_pab_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelFunPAB
from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun

from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control

from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

@@ -231,8 +220,6 @@ def loadmodel(self, model, precision, quantization="disabled", compile="disabled"

if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)



with open(scheduler_path) as f:
scheduler_config = json.load(f)
@@ -274,22 +261,6 @@ def loadmodel(self, model, precision, quantization="disabled", compile="disabled"
for l in lora:
pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
if fuse:
pipe.fuse_lora(lora_scale=lora[-1]["strength"] / lora_rank, components=["transformer"])

#fp8
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
for name, param in pipe.transformer.named_parameters():
params_to_keep = {"patch_embed", "lora", "pos_embedding"}
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)

if fp8_transformer == "fastmode":
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(pipe.transformer, dtype)

if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()

lora_scale = 1
dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
if any(item in lora[-1]["path"].lower() for item in dimension_loras):
@@ -1057,4 +1028,4 @@ def loadmodel(self, model):
"CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoXVAELoader": "CogVideoX VAE Loader",
"CogVideoXModelLoader": "CogVideoX Model Loader",
}
}
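
For reference, a minimal sketch of the fp8 weight-cast pattern visible in the hunk around old lines 274-295: transformer weights are cast to torch.float8_e4m3fn while a small keep-list of parameters stays in the original dtype. This assumes a PyTorch build that provides torch.float8_e4m3fn (2.1+); the helper name and its defaults below are illustrative, not the repository's API.

    import torch
    import torch.nn as nn

    # Illustrative helper, not part of the repo: cast all transformer weights to
    # float8_e4m3fn except parameters whose names match the keep-list (patch
    # embedding, positional embedding, LoRA weights stay in the original dtype).
    def cast_to_fp8(transformer: nn.Module,
                    params_to_keep=("patch_embed", "lora", "pos_embedding")) -> nn.Module:
        for name, param in transformer.named_parameters():
            if not any(keyword in name for keyword in params_to_keep):
                param.data = param.data.to(torch.float8_e4m3fn)
        return transformer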
