From 8e0881d9abfa4e139f2dbf3c9a072dd8d4b6d1ba Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 27 Jul 2024 21:10:20 +0900
Subject: [PATCH 01/13] fix image upscale on cpu

For some reason, upscaling on CPU fails with "RuntimeError: Inplace update
to inference tensor outside InferenceMode". Switching from no_grad to
inference_mode seems to have fixed it.
---
 modules/upscaler_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index 5ecbbed96fd..a8408f05bca 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -41,7 +41,7 @@ def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
     """
     param = torch_utils.get_param(model)
 
-    with torch.no_grad():
+    with torch.inference_mode():
         tensor = pil_image_to_torch_bgr(img).unsqueeze(0)  # add batch dimension
         tensor = tensor.to(device=param.device, dtype=param.dtype)
         with devices.without_autocast():
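An aside on the patch above: `torch.inference_mode()` is a stricter variant of `torch.no_grad()`. Tensors created under it carry no autograd metadata at all, and PyTorch refuses in-place updates to them once the mode has been exited, which is exactly the RuntimeError quoted in the commit message. A minimal sketch of the difference, assuming PyTorch >= 1.9 (the tensors below are illustrative stand-ins, not code from the patch):

```python
import torch

x = torch.ones(3)

with torch.no_grad():
    y = x * 2            # ordinary tensor; gradients are merely disabled
print(y.is_inference())  # False

with torch.inference_mode():
    z = x * 2            # flagged as an inference tensor
print(z.is_inference())  # True

try:
    z.add_(1)            # in-place update outside the mode
except RuntimeError as e:
    print(e)             # "Inplace update to inference tensor outside InferenceMode ..."
```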
From cbaaf0af0e0e3173edafaca83448ca90c5ac14e1 Mon Sep 17 00:00:00 2001
From: hello2564
Date: Wed, 31 Jul 2024 14:55:30 +0800
Subject: [PATCH 02/13] fix NGMS PR typo
---
 modules/shared_options.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/shared_options.py b/modules/shared_options.py
index 9f4520274b1..6d0b8ac5f25 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -231,7 +231,7 @@
 options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
-    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}, infotext='NGMS').link("PR", "https://github.com/AUTOMATIC1111/stablediffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}, infotext='NGMS').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
     "s_min_uncond_all": OptionInfo(False, "Negative Guidance minimum sigma all steps", infotext='NGMS all steps').info("By default, NGMS above skips every other step; this makes it skip all steps"),
     "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),

From 9677b09b7c9b06f51cfa6a997a32eac69bb63663 Mon Sep 17 00:00:00 2001
From: gutris1 <132797949+gutris1@users.noreply.github.com>
Date: Wed, 7 Aug 2024 17:37:23 +0700
Subject: [PATCH 03/13] add break-word for geninfo in pnginfo
---
 modules/extras.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/extras.py b/modules/extras.py
index 2a310ae3f25..adc88ca558b 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -23,7 +23,7 @@ def run_pnginfo(image):
     info = ''
     for key, text in items.items():
         info += f"""
-<div>
+<div style="word-break: break-word;">
 <p><b>{plaintext_to_html(str(key))}</b></p>
 <p>{plaintext_to_html(str(text))}</p>
 </div>
From f57ec2b53b2fd89672f5611dee3c5cb33738c30a Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Tue, 3 Sep 2024 19:58:29 -0600
Subject: [PATCH 04/13] Update stable diffusion 1.5 URL
---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 55bd9ca5e43..da84e326d7e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -159,7 +159,7 @@ def list_models():
         model_url = None
         expected_sha256 = None
     else:
-        model_url = f"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
+        model_url = f"{shared.hf_endpoint}/botp/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
         expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'
 
     model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"], hash_prefix=expected_sha256)

From c9a06d1093df828c7ff1dd356f38cf5ae41c1227 Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Tue, 8 Oct 2024 16:50:39 -0600
Subject: [PATCH 05/13] Use stable-diffusion-v1-5 repo instead
---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index da84e326d7e..f9f3f07310b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -159,7 +159,7 @@ def list_models():
         model_url = None
         expected_sha256 = None
     else:
-        model_url = f"{shared.hf_endpoint}/botp/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
+        model_url = f"{shared.hf_endpoint}/stable-diffusion-v1-5/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
         expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'
 
     model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"], hash_prefix=expected_sha256)
From 1ae073c052c2a1eb2b3e87464a576451b3197326 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Sat, 19 Oct 2024 06:53:19 -0400
Subject: [PATCH 06/13] Support SDXL v-pred models
---
 configs/sd_xl_v.yaml        | 98 +++++++++++++++++++++++++++++++++++++
 modules/sd_models.py        |  7 +--
 modules/sd_models_config.py |  4 ++
 3 files changed, 106 insertions(+), 3 deletions(-)
 create mode 100644 configs/sd_xl_v.yaml

diff --git a/configs/sd_xl_v.yaml b/configs/sd_xl_v.yaml
new file mode 100644
index 00000000000..c755dc74fda
--- /dev/null
+++ b/configs/sd_xl_v.yaml
@@ -0,0 +1,98 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        weighting_config:
+          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        adm_in_channels: 2816
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4]
+        num_head_channels: 64
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
+        context_dim: 2048
+        spatial_transformer_attn_type: softmax-xformers
+        legacy: False
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          # crossattn cond
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+            params:
+              layer: hidden
+              layer_idx: 11
+          # crossattn and vector cond
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+            params:
+              arch: ViT-bigG-14
+              version: laion2b_s39b_b160k
+              freeze: True
+              layer: penultimate
+              always_return_pooled: True
+              legacy: False
+          # vector cond
+          - is_trainable: False
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by two
+          # vector cond
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by two
+          # vector cond
+          - is_trainable: False
+            input_key: target_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by two
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 55bd9ca5e43..abe1c966c26 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -783,7 +783,7 @@ def get_obj_from_str(string, reload=False):
     return getattr(importlib.import_module(module, package=None), cls)
 
 
-def load_model(checkpoint_info=None, already_loaded_state_dict=None):
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_config=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
@@ -801,7 +801,8 @@
     else:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+    if not checkpoint_config:
+        checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
 
     clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
 
     timer.record("find config")
@@ -974,7 +975,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
         if sd_model is not None:
             send_model_to_trash(sd_model)
 
-        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, checkpoint_config=checkpoint_config)
         return model_data.sd_model
 
     try:

diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index fb44c5a8d98..3c1e4a1518f 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -14,6 +14,7 @@
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxlv = os.path.join(sd_configs_path, "sd_xl_v.yaml")
 config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
 config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
@@ -81,6 +82,9 @@ def guess_model_config_from_state_dict(sd, filename):
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
         else:
+            if ('v_pred' in sd):
+                del sd['v_pred']
+                return config_sdxlv
             return config_sdxl
 
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
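Judging from the hunk above, v-prediction SDXL checkpoints are identified purely by the presence of a `v_pred` key in the state_dict; the key is deleted before loading because only its existence matters, not its value. A hedged sketch of checking a checkpoint for that marker without loading any tensor data (the file name is illustrative):

```python
from safetensors import safe_open

# Illustrative only: list the checkpoint's keys lazily and look for the marker.
with safe_open("sdxl-vpred-checkpoint.safetensors", framework="pt") as f:
    is_v_pred = "v_pred" in f.keys()
print("v-prediction" if is_v_pred else "epsilon-prediction")
```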
From 907bfb5ef0f4b0bd3ff18313872c687c249b057e Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sat, 19 Oct 2024 17:33:58 +0300
Subject: [PATCH 07/13] add w-e-w and catboxanon to codeowners file
---
 CODEOWNERS | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/CODEOWNERS b/CODEOWNERS
index 7438c9bc69d..4eb946b009d 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,12 +1 @@
-* @AUTOMATIC1111
-
-# if you were managing a localization and were removed from this file, this is because
-# the intended way to do localizations now is via extensions. See:
-# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
-# Make a repo with your localization and since you are still listed as a collaborator
-# you can add it to the wiki page yourself. This change is because some people complained
-# the git commit log is cluttered with things unrelated to almost everyone and
-# because I believe this is the best overall for the project to handle localizations almost
-# entirely without my oversight.
-
-
+* @AUTOMATIC1111 @w-e-w @catboxanon
From c2ce1d3b9c9be2bef7efd1fe5b5a53424105a1c5 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Sat, 19 Oct 2024 19:58:13 -0400
Subject: [PATCH 08/13] Automatically enable ztSNR based on existence of key in state_dict
---
 modules/sd_models.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 167d4ff3615..1c7d370e97b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -423,6 +423,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     set_model_type(model, state_dict)
     set_model_fields(model)
+    if 'ztsnr' in state_dict:
+        model.ztsnr = True
+    else:
+        model.ztsnr = False
 
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
@@ -661,7 +665,7 @@ def apply_alpha_schedule_override(sd_model, p=None):
             p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
         sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
 
-    if opts.sd_noise_schedule == "Zero Terminal SNR":
+    if opts.sd_noise_schedule == "Zero Terminal SNR" or (hasattr(sd_model, 'ztsnr') and sd_model.ztsnr):
         if p is not None:
             p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
         sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
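`rescale_zero_terminal_snr_abar`, which the hunk above now applies whenever a model carries the `ztsnr` flag, follows the rescaling from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Lin et al.): the square root of the cumulative alphas is shifted and scaled so the final timestep has exactly zero SNR while the first is unchanged. A sketch of that transform under those assumptions (the webui's own implementation may differ in details):

```python
import torch

def rescale_zero_terminal_snr_abar_sketch(alphas_cumprod: torch.Tensor) -> torch.Tensor:
    alphas_bar_sqrt = alphas_cumprod.sqrt()
    first = alphas_bar_sqrt[0].clone()
    last = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt -= last                    # terminal value becomes 0 (zero SNR)
    alphas_bar_sqrt *= first / (first - last)  # first value restored to its original
    return alphas_bar_sqrt ** 2
```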
From c2bc187ce7e36bd83833058020cdfacc64de3436 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 20 Oct 2024 09:51:59 +0900
Subject: [PATCH 09/13] fix modalImageViewer preview/result flicker (#16426)
---
 javascript/imageviewer.js | 30 ++++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index 9b23f4700b3..979d05de5ba 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -13,6 +13,7 @@ function showModal(event) {
     if (modalImage.style.display === 'none') {
         lb.style.setProperty('background-image', 'url(' + source.src + ')');
     }
+    updateModalImage();
     lb.style.display = "flex";
     lb.focus();
 
@@ -31,21 +32,26 @@ function negmod(n, m) {
     return ((n % m) + m) % m;
 }
 
+function updateModalImage() {
+    const modalImage = gradioApp().getElementById("modalImage");
+    let currentButton = selected_gallery_button();
+    let preview = gradioApp().querySelectorAll('.livePreview > img');
+    if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
+        // show preview image if available
+        modalImage.src = preview[preview.length - 1].src;
+    } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
+        modalImage.src = currentButton.children[0].src;
+        if (modalImage.style.display === 'none') {
+            const modal = gradioApp().getElementById("lightboxModal");
+            modal.style.setProperty('background-image', `url(${modalImage.src})`);
+        }
+    }
+}
+
 function updateOnBackgroundChange() {
     const modalImage = gradioApp().getElementById("modalImage");
     if (modalImage && modalImage.offsetParent) {
-        let currentButton = selected_gallery_button();
-        let preview = gradioApp().querySelectorAll('.livePreview > img');
-        if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
-            // show preview image if available
-            modalImage.src = preview[preview.length - 1].src;
-        } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
-            modalImage.src = currentButton.children[0].src;
-            if (modalImage.style.display === 'none') {
-                const modal = gradioApp().getElementById("lightboxModal");
-                modal.style.setProperty('background-image', `url(${modalImage.src})`);
-            }
-        }
+        updateModalImage();
     }
 }

From 65423d2b33a2ea696e6d62186e12a93af272403d Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 20 Oct 2024 09:52:47 +0900
Subject: [PATCH 10/13] MIME type text/css (#16406)
---
 modules/ui.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/modules/ui.py b/modules/ui.py
index f48638f69ad..9a76b5fcd90 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -44,6 +44,9 @@
 mimetypes.add_type('image/webp', '.webp')
 mimetypes.add_type('image/avif', '.avif')
 
+# override potentially incorrect mimetypes
+mimetypes.add_type('text/css', '.css')
+
 if not cmd_opts.share and not cmd_opts.listen:
     # fix gradio phoning home
     gradio.utils.version_check = lambda: None
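For context on the MIME patch: on Windows, Python's `mimetypes` module consults the registry, which can map `.css` to `text/plain` and make browsers reject the stylesheet. Re-registering the type makes the lookup deterministic, as a quick illustration shows:

```python
import mimetypes

mimetypes.add_type('text/css', '.css')
print(mimetypes.guess_type('style.css'))  # ('text/css', None)
```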
From 6a5976631334745287c1e5a65b9358c88bee164b Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 20 Oct 2024 09:56:12 +0900
Subject: [PATCH 11/13] Add Skip Early CFG to XYZ (#16282)

Co-authored-by: Yevhenii Hurin
---
 modules/shared_options.py | 2 +-
 scripts/xyz_grid.py       | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/modules/shared_options.py b/modules/shared_options.py
index 6d0b8ac5f25..efede7067f2 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -404,7 +404,7 @@
     'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
     'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
     'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"),
-    'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"),
+    'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling; XYZ plot: Skip Early CFG"),
     'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'),
     'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'),
 }))

diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 6a42a04d9a3..c60dd6dda2f 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -259,6 +259,7 @@ def __init__(self, *args, **kwargs):
     AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
     AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
     AxisOption("Schedule rho", float, apply_override("rho")),
+    AxisOption("Skip Early CFG", float, apply_override('skip_early_cond')),
     AxisOption("Beta schedule alpha", float, apply_override("beta_dist_alpha")),
     AxisOption("Beta schedule beta", float, apply_override("beta_dist_beta")),
     AxisOption("Eta", float, apply_field("eta")),

From bb1f39196e6f47d200ddbe1a6b4be27157720546 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 20 Oct 2024 09:58:53 +0900
Subject: [PATCH 12/13] clarify readme: wget ... chmod +x webui.sh (#16251)
---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index bc62945c0c5..007f590d249 100644
--- a/README.md
+++ b/README.md
@@ -148,6 +148,7 @@ python_cmd="python3.11"
 2. Navigate to the directory you would like the webui to be installed and execute the following command:
 ```bash
 wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh
+chmod +x webui.sh
 ```
 Or just clone the repo wherever you want:
 ```bash
From 984b952eb30798ea478cb97b7ca79ff32196f4ce Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Thu, 24 Oct 2024 22:05:51 +0900
Subject: [PATCH 13/13] Fix DAT models download (#16302)
---
 modules/dat_model.py   | 20 +++++++++--
 modules/modelloader.py | 25 +-------------
 modules/upscaler.py    |  3 +-
 modules/util.py        | 77 ++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 97 insertions(+), 28 deletions(-)

diff --git a/modules/dat_model.py b/modules/dat_model.py
index 495d5f4937d..298d160d1e8 100644
--- a/modules/dat_model.py
+++ b/modules/dat_model.py
@@ -49,7 +49,18 @@ def load_model(self, path):
                     scaler.local_data_path = modelloader.load_file_from_url(
                         scaler.data_path,
                         model_dir=self.model_download_path,
+                        hash_prefix=scaler.sha256,
                     )
+
+                    if os.path.getsize(scaler.local_data_path) < 200:
+                        # Re-download if the file is too small, probably an LFS pointer
+                        scaler.local_data_path = modelloader.load_file_from_url(
+                            scaler.data_path,
+                            model_dir=self.model_download_path,
+                            hash_prefix=scaler.sha256,
+                            re_download=True,
+                        )
+
                 if not os.path.exists(scaler.local_data_path):
                     raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
                 return scaler
@@ -60,20 +71,23 @@ def get_dat_models(scaler):
     return [
         UpscalerData(
             name="DAT x2",
-            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
+            path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth",
             scale=2,
             upscaler=scaler,
+            sha256='7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51',
         ),
         UpscalerData(
            name="DAT x3",
-            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
+            path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x3.pth",
            scale=3,
            upscaler=scaler,
+            sha256='581973e02c06f90d4eb90acf743ec9604f56f3c2c6f9e1e2c2b38ded1f80d197',
        ),
        UpscalerData(
            name="DAT x4",
-            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
+            path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x4.pth",
            scale=4,
            upscaler=scaler,
+            sha256='391a6ce69899dff5ea3214557e9d585608254579217169faf3d4c353caff049e',
        ),
    ]

diff --git a/modules/modelloader.py b/modules/modelloader.py
index 36e7415af43..f5a2ff79c30 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -10,6 +10,7 @@
 from modules import shared
 from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
+from modules.util import load_file_from_url  # noqa, backwards compatibility
 
 if TYPE_CHECKING:
     import spandrel
@@ -17,30 +18,6 @@
 logger = logging.getLogger(__name__)
 
 
-def load_file_from_url(
-    url: str,
-    *,
-    model_dir: str,
-    progress: bool = True,
-    file_name: str | None = None,
-    hash_prefix: str | None = None,
-) -> str:
-    """Download a file from `url` into `model_dir`, using the file present if possible.
-
-    Returns the path to the downloaded file.
-    """
-    os.makedirs(model_dir, exist_ok=True)
-    if not file_name:
-        parts = urlparse(url)
-        file_name = os.path.basename(parts.path)
-    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
-    if not os.path.exists(cached_file):
-        print(f'Downloading: "{url}" to {cached_file}\n')
-        from torch.hub import download_url_to_file
-        download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)
-    return cached_file
-
-
 def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:
     """
     A one-and done loader to try finding the desired models in specified directories.

diff --git a/modules/upscaler.py b/modules/upscaler.py
index 507881fede2..12ab3547cf6 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -93,13 +93,14 @@ class UpscalerData:
     scaler: Upscaler = None
     model: None
 
-    def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
+    def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None, sha256: str = None):
         self.name = name
         self.data_path = path
         self.local_data_path = path
         self.scaler = upscaler
         self.scale = scale
         self.model = model
+        self.sha256 = sha256
 
     def __repr__(self):
         return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"

diff --git a/modules/util.py b/modules/util.py
index 7911b0db72c..baeba2fa271 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -211,3 +211,80 @@ def open_folder(path):
         subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])])
     else:
         subprocess.Popen(["xdg-open", path])
+
+
+def load_file_from_url(
+    url: str,
+    *,
+    model_dir: str,
+    progress: bool = True,
+    file_name: str | None = None,
+    hash_prefix: str | None = None,
+    re_download: bool = False,
+) -> str:
+    """Download a file from `url` into `model_dir`, using the file present if possible.
+    Returns the path to the downloaded file.
+
+    file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
+        file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
+    hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
+        if the hash does not match, the temporary file is deleted and a ValueError is raised.
+    re_download: forcibly re-download the file even if it already exists.
+ """ + from urllib.parse import urlparse + import requests + try: + from tqdm import tqdm + except ImportError: + class tqdm: + def __init__(self, *args, **kwargs): + pass + + def update(self, n=1, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + + if re_download or not os.path.exists(cached_file): + os.makedirs(model_dir, exist_ok=True) + temp_file = os.path.join(model_dir, f"{file_name}.tmp") + print(f'\nDownloading: "{url}" to {cached_file}') + response = requests.get(url, stream=True) + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar: + with open(temp_file, 'wb') as file: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + file.write(chunk) + progress_bar.update(len(chunk)) + + if hash_prefix and not compare_sha256(temp_file, hash_prefix): + print(f"Hash mismatch for {temp_file}. Deleting the temporary file.") + os.remove(temp_file) + raise ValueError(f"File hash does not match the expected hash prefix {hash_prefix}!") + + os.rename(temp_file, cached_file) + return cached_file + + +def compare_sha256(file_path: str, hash_prefix: str) -> bool: + """Check if the SHA256 hash of the file matches the given prefix.""" + import hashlib + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())