From 9ebd7d8dbe6a5a667b01d5b3548936d43b6757b5 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Thu, 16 May 2024 07:15:32 -0700 Subject: [PATCH] Feat: Add CivitAI API token support for download (#62) --- comfy_cli/command/models/models.py | 48 +++++++++++++++++++++++++----- comfy_cli/constants.py | 2 ++ comfy_cli/file_utils.py | 36 +++++++++++++++++++--- 3 files changed, 75 insertions(+), 11 deletions(-) diff --git a/comfy_cli/command/models/models.py b/comfy_cli/command/models/models.py index cdc0b09..4a2f318 100644 --- a/comfy_cli/command/models/models.py +++ b/comfy_cli/command/models/models.py @@ -7,6 +7,8 @@ from typing_extensions import Annotated from comfy_cli import tracking, ui +from comfy_cli import constants +from comfy_cli.config_manager import ConfigManager from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH from comfy_cli.file_utils import download_file, DownloadException from comfy_cli.workspace_manager import WorkspaceManager @@ -14,6 +16,7 @@ app = typer.Typer() workspace_manager = WorkspaceManager() +config_manager = ConfigManager() def get_workspace() -> pathlib.Path: @@ -66,10 +69,12 @@ def check_civitai_url(url: str) -> Tuple[bool, bool, int, int]: return False, False, None, None -def request_civitai_model_version_api(version_id: int): +def request_civitai_model_version_api(version_id: int, headers: Optional[dict] = None): # Make a request to the Civitai API to get the model information response = requests.get( - f"https://civitai.com/api/v1/model-versions/{version_id}", timeout=10 + f"https://civitai.com/api/v1/model-versions/{version_id}", + headers=headers, + timeout=10, ) response.raise_for_status() # Raise an error for bad status codes @@ -81,9 +86,13 @@ def request_civitai_model_version_api(version_id: int): return model_name, download_url -def request_civitai_model_api(model_id: int, version_id: int = None): +def request_civitai_model_api( + model_id: int, version_id: int = None, headers: Optional[dict] = None +): # Make a request to the Civitai API to get the model information - response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10) + response = requests.get( + f"https://civitai.com/api/v1/models/{model_id}", headers=headers, timeout=10 + ) response.raise_for_status() # Raise an error for bad status codes model_data = response.json() @@ -123,18 +132,42 @@ def download( show_default=True, ), ] = DEFAULT_COMFY_MODEL_PATH, + set_civitai_api_token: Annotated[ + Optional[str], + typer.Option( + "--set-civitai-api-token", + help="Set the CivitAI API token to use for model listing.", + show_default=False, + ), + ] = None, ): local_filename = None + headers = None + civitai_api_token = None + + if set_civitai_api_token is not None: + config_manager.set(constants.CIVITAI_API_TOKEN_KEY, set_civitai_api_token) + civitai_api_token = set_civitai_api_token + + else: + civitai_api_token = config_manager.get(constants.CIVITAI_API_TOKEN_KEY) + + if civitai_api_token is not None: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {civitai_api_token}", + } is_civitai_model_url, is_civitai_api_url, model_id, version_id = check_civitai_url( url ) + is_huggingface = False if is_civitai_model_url: - local_filename, url = request_civitai_model_api(model_id, version_id) + local_filename, url = request_civitai_model_api(model_id, version_id, headers) elif is_civitai_api_url: - local_filename, url = request_civitai_model_version_api(version_id) + local_filename, url = request_civitai_model_version_api(version_id, headers) elif check_huggingface_url(url): is_huggingface = True local_filename = potentially_strip_param_url(url.split("/")[-1]) @@ -157,7 +190,7 @@ def download( # File does not exist, proceed with download print(f"Start downloading URL: {url} into {local_filepath}") - download_file(url, local_filepath) + download_file(url, local_filepath, headers) @app.command() @@ -236,6 +269,7 @@ def list( show_default=True, ), ): + """Display a list of all models currently downloaded in a table format.""" model_dir = get_workspace() / relative_path models = list_models(model_dir) diff --git a/comfy_cli/constants.py b/comfy_cli/constants.py index 2c77561..3a46cfd 100644 --- a/comfy_cli/constants.py +++ b/comfy_cli/constants.py @@ -38,6 +38,8 @@ class OS(Enum): CONFIG_KEY_INSTALL_EVENT_TRIGGERED = "install_event_triggered" CONFIG_KEY_BACKGROUND = "background" +CIVITAI_API_TOKEN_KEY = "civitai_api_token" + DEFAULT_TRACKING_VALUE = True COMFY_LOCK_YAML_FILE = "comfy.lock.yaml" diff --git a/comfy_cli/file_utils.py b/comfy_cli/file_utils.py index 6eee407..a7f9806 100644 --- a/comfy_cli/file_utils.py +++ b/comfy_cli/file_utils.py @@ -1,5 +1,6 @@ import os import pathlib +from typing import Optional import zipfile import requests @@ -12,8 +13,30 @@ class DownloadException(Exception): pass -def guess_status_code_reason(status_code: int) -> str: +def guess_status_code_reason(status_code: int, message: str) -> str: if status_code == 401: + import json + + def parse_json(input_data): + try: + # Check if the input is a byte string + if isinstance(input_data, bytes): + # Decode the byte string to a regular string + input_data = input_data.decode("utf-8") + + # Parse the string as JSON + json_object = json.loads(input_data) + + return json_object + + except json.JSONDecodeError as e: + # Handle JSON decoding error + print(f"JSON decoding error: {e}") + + msg_json = parse_json(message) + if msg_json is not None: + if "message" in msg_json: + return f"Unauthorized download ({status_code}).\n{msg_json['message']}\nor you can set civitai api token using `comfy model download --set-civitai-api-token `" return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one" elif status_code == 403: return f"Forbidden url ({status_code}), you might need to manually log into browser to download one" @@ -22,7 +45,10 @@ def guess_status_code_reason(status_code: int) -> str: return f"Unknown error occurred (status code: {status_code})" -def download_file(url: str, local_filepath: pathlib.Path): +def download_file( + url: str, local_filepath: pathlib.Path, headers: Optional[dict] = None +): + """Helper function to download a file.""" import httpx @@ -31,7 +57,7 @@ def download_file(url: str, local_filepath: pathlib.Path): parents=True, exist_ok=True ) # Ensure the directory exists - with httpx.stream("GET", url, follow_redirects=True) as response: + with httpx.stream("GET", url, follow_redirects=True, headers=headers) as response: if response.status_code == 200: total = int(response.headers["Content-Length"]) try: @@ -49,7 +75,9 @@ def download_file(url: str, local_filepath: pathlib.Path): if delete_eh: local_filepath.unlink() else: - status_reason = guess_status_code_reason(response.status_code) + status_reason = guess_status_code_reason( + response.status_code, response.read() + ) raise DownloadException(f"Failed to download file.\n{status_reason}")