Skip to content

Commit

Permalink
Offer to confirm on download name
Browse files Browse the repository at this point in the history
  • Loading branch information
yoland68 committed May 6, 2024
1 parent a91fc21 commit 7df0e80
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
19 changes: 16 additions & 3 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class DownloadException(Exception):
pass


def potentially_strip_param_url(path_name: str) -> str:
path_name = path_name.split("?")[0]
return path_name


@app.command()
@tracking.track_command("model")
def download(
Expand All @@ -42,16 +47,24 @@ def download(
):
"""Download a model to a specified relative path if it is not already downloaded."""
# Convert relative path to absolute path based on the current working directory
local_filename = url.split("/")[-1]
local_filename = potentially_strip_param_url(url.split("/")[-1])
local_filename = ui.prompt_input(
"Enter filename to save model as", default=local_filename
)
if local_filename is None:
raise typer.Exit(code=1)
if local_filename == "":
raise DownloadException("Filename cannot be empty")

local_filepath = get_workspace() / relative_path / local_filename

# Check if the file already exists
if local_filepath.exists():
typer.echo(f"File already exists: {local_filepath}")
print(f"[bold red]File already exists: {local_filepath}[/bold red]")
return

# File does not exist, proceed with download
typer.echo(f"Start downloading URL: {url} into {local_filepath}")
print(f"Start downloading URL: {url} into {local_filepath}")
download_file(url, local_filepath)


Expand Down
17 changes: 17 additions & 0 deletions comfy_cli/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def prompt_select_enum(question: str, choices: list) -> str:
return choice_map[selected]


def prompt_input(question: str, default: str = "") -> str:
"""
Asks the user for an input using questionary.
Args:
question (str): The question to display to the user.
default (str): The default value for the input.
Returns:
str: The user's input.
Raises:
KeyboardInterrupt: If the user interrupts the input.
"""
return questionary.text(question, default=default).ask()


def prompt_multi_select(prompt: str, choices: List[str]) -> List[str]:
"""
Prompts the user to select multiple items from a list of choices.
Expand Down

0 comments on commit 7df0e80

Please sign in to comment.