Skip to content

Commit

Permalink
implement hash prune
Browse files Browse the repository at this point in the history
  • Loading branch information
Haoming02 committed Dec 3, 2024
1 parent 0120768 commit 4498e7a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
47 changes: 47 additions & 0 deletions modules/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,50 @@ def cached_data_for_file(subsection, title, filename, func):
dump_cache()

return entry['value']


def prune_unused_hash():
import glob

from modules.paths_internal import extensions_dir

existing_cache = cache('extensions-git')
total_count = len(existing_cache)
with tqdm.tqdm(total=total_count, desc='pruning extensions') as progress:
for name in existing_cache:
if not os.path.isdir(os.path.join(extensions_dir, name)):
existing_cache.pop(name)
progress.update(1)

def file_exists(parent_dir, filename):
matches = glob.glob(os.path.join(parent_dir, '**', f'{filename}*'), recursive=True)
return len(matches) > 0

from modules.paths_internal import models_path
from modules.shared import cmd_opts

for db in ('hashes', 'hashes-addnet', 'safetensors-metadata'):
existing_cache = cache(db)
total_count = len(existing_cache)
with tqdm.tqdm(total=total_count, desc=f'pruning {db}') as progress:
for name in existing_cache:
if '/' not in name:
progress.update(1)
continue

category, filename = name.split('/', 1)
if category.lower() == 'lora':
exists = file_exists(os.path.join(models_path, 'Lora'), filename)
elif category.lower() == 'checkpoint':
exists = file_exists(os.path.join(models_path, 'Stable-diffusion'), filename)
elif category.lower() == 'textual_inversion':
exists = file_exists(cmd_opts.embeddings_dir, filename)
else:
progress.update(1)
continue

if not exists:
del existing_cache[name]
progress.update(1)

print('Finish pruning hash')
4 changes: 4 additions & 0 deletions modules/ui_settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gradio as gr

from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
from modules.cache import prune_unused_hash
from modules.call_queue import wrap_gradio_call_no_job
from modules.options import options_section
from modules.shared import opts
Expand Down Expand Up @@ -190,6 +191,7 @@ def create_ui(self, loadsave, dummy_component):
with gr.Row():
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
prune_all_unused_hash = gr.Button(value='Prune all unused hash', elem_id="prune_all_unused_hash")

with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
Expand Down Expand Up @@ -285,6 +287,8 @@ def calculate_all_checkpoint_hash_fn(max_thread):
inputs=[calculate_all_checkpoint_hash_threads],
)

prune_all_unused_hash.click(fn=prune_unused_hash)

self.interface = settings_interface

def add_quicksettings(self):
Expand Down

0 comments on commit 4498e7a

Please sign in to comment.