From 4498e7a273ddd645c2a83f06008a6982a1dee94f Mon Sep 17 00:00:00 2001 From: Haoming Date: Tue, 3 Dec 2024 22:17:56 +0800 Subject: [PATCH] implement hash prune --- modules/cache.py | 47 ++++++++++++++++++++++++++++++++++++++++++ modules/ui_settings.py | 4 ++++ 2 files changed, 51 insertions(+) diff --git a/modules/cache.py b/modules/cache.py index f4e5f702b42..3a44993fd1d 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -121,3 +121,50 @@ def cached_data_for_file(subsection, title, filename, func): dump_cache() return entry['value'] + + +def prune_unused_hash(): + import glob + + from modules.paths_internal import extensions_dir + + existing_cache = cache('extensions-git') + total_count = len(existing_cache) + with tqdm.tqdm(total=total_count, desc='pruning extensions') as progress: + for name in existing_cache: + if not os.path.isdir(os.path.join(extensions_dir, name)): + existing_cache.pop(name) + progress.update(1) + + def file_exists(parent_dir, filename): + matches = glob.glob(os.path.join(parent_dir, '**', f'{filename}*'), recursive=True) + return len(matches) > 0 + + from modules.paths_internal import models_path + from modules.shared import cmd_opts + + for db in ('hashes', 'hashes-addnet', 'safetensors-metadata'): + existing_cache = cache(db) + total_count = len(existing_cache) + with tqdm.tqdm(total=total_count, desc=f'pruning {db}') as progress: + for name in existing_cache: + if '/' not in name: + progress.update(1) + continue + + category, filename = name.split('/', 1) + if category.lower() == 'lora': + exists = file_exists(os.path.join(models_path, 'Lora'), filename) + elif category.lower() == 'checkpoint': + exists = file_exists(os.path.join(models_path, 'Stable-diffusion'), filename) + elif category.lower() == 'textual_inversion': + exists = file_exists(cmd_opts.embeddings_dir, filename) + else: + progress.update(1) + continue + + if not exists: + del existing_cache[name] + progress.update(1) + + print('Finish pruning hash') diff --git a/modules/ui_settings.py b/modules/ui_settings.py index e53ad50f8f4..e7c8e5707a2 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -1,6 +1,7 @@ import gradio as gr from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items +from modules.cache import prune_unused_hash from modules.call_queue import wrap_gradio_call_no_job from modules.options import options_section from modules.shared import opts @@ -190,6 +191,7 @@ def create_ui(self, loadsave, dummy_component): with gr.Row(): calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash") calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1) + prune_all_unused_hash = gr.Button(value='Prune all unused hash', elem_id="prune_all_unused_hash") with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): gr.HTML(shared.html("licenses.html"), elem_id="licenses") @@ -285,6 +287,8 @@ def calculate_all_checkpoint_hash_fn(max_thread): inputs=[calculate_all_checkpoint_hash_threads], ) + prune_all_unused_hash.click(fn=prune_unused_hash) + self.interface = settings_interface def add_quicksettings(self):