Skip to content

Commit

Permalink
Feat: Add CivitAI API token support for download (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
yoland68 authored May 16, 2024
1 parent f74cedc commit 9ebd7d8
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 11 deletions.
48 changes: 41 additions & 7 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from typing_extensions import Annotated

from comfy_cli import tracking, ui
from comfy_cli import constants
from comfy_cli.config_manager import ConfigManager
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH
from comfy_cli.file_utils import download_file, DownloadException
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()

workspace_manager = WorkspaceManager()
config_manager = ConfigManager()


def get_workspace() -> pathlib.Path:
Expand Down Expand Up @@ -66,10 +69,12 @@ def check_civitai_url(url: str) -> Tuple[bool, bool, int, int]:
return False, False, None, None


def request_civitai_model_version_api(version_id: int):
def request_civitai_model_version_api(version_id: int, headers: Optional[dict] = None):
# Make a request to the Civitai API to get the model information
response = requests.get(
f"https://civitai.com/api/v1/model-versions/{version_id}", timeout=10
f"https://civitai.com/api/v1/model-versions/{version_id}",
headers=headers,
timeout=10,
)
response.raise_for_status() # Raise an error for bad status codes

Expand All @@ -81,9 +86,13 @@ def request_civitai_model_version_api(version_id: int):
return model_name, download_url


def request_civitai_model_api(model_id: int, version_id: int = None):
def request_civitai_model_api(
model_id: int, version_id: int = None, headers: Optional[dict] = None
):
# Make a request to the Civitai API to get the model information
response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10)
response = requests.get(
f"https://civitai.com/api/v1/models/{model_id}", headers=headers, timeout=10
)
response.raise_for_status() # Raise an error for bad status codes

model_data = response.json()
Expand Down Expand Up @@ -123,18 +132,42 @@ def download(
show_default=True,
),
] = DEFAULT_COMFY_MODEL_PATH,
set_civitai_api_token: Annotated[
Optional[str],
typer.Option(
"--set-civitai-api-token",
help="Set the CivitAI API token to use for model listing.",
show_default=False,
),
] = None,
):

local_filename = None
headers = None
civitai_api_token = None

if set_civitai_api_token is not None:
config_manager.set(constants.CIVITAI_API_TOKEN_KEY, set_civitai_api_token)
civitai_api_token = set_civitai_api_token

else:
civitai_api_token = config_manager.get(constants.CIVITAI_API_TOKEN_KEY)

if civitai_api_token is not None:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {civitai_api_token}",
}

is_civitai_model_url, is_civitai_api_url, model_id, version_id = check_civitai_url(
url
)

is_huggingface = False
if is_civitai_model_url:
local_filename, url = request_civitai_model_api(model_id, version_id)
local_filename, url = request_civitai_model_api(model_id, version_id, headers)
elif is_civitai_api_url:
local_filename, url = request_civitai_model_version_api(version_id)
local_filename, url = request_civitai_model_version_api(version_id, headers)
elif check_huggingface_url(url):
is_huggingface = True
local_filename = potentially_strip_param_url(url.split("/")[-1])
Expand All @@ -157,7 +190,7 @@ def download(

# File does not exist, proceed with download
print(f"Start downloading URL: {url} into {local_filepath}")
download_file(url, local_filepath)
download_file(url, local_filepath, headers)


@app.command()
Expand Down Expand Up @@ -236,6 +269,7 @@ def list(
show_default=True,
),
):

"""Display a list of all models currently downloaded in a table format."""
model_dir = get_workspace() / relative_path
models = list_models(model_dir)
Expand Down
2 changes: 2 additions & 0 deletions comfy_cli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class OS(Enum):
CONFIG_KEY_INSTALL_EVENT_TRIGGERED = "install_event_triggered"
CONFIG_KEY_BACKGROUND = "background"

CIVITAI_API_TOKEN_KEY = "civitai_api_token"

DEFAULT_TRACKING_VALUE = True

COMFY_LOCK_YAML_FILE = "comfy.lock.yaml"
Expand Down
36 changes: 32 additions & 4 deletions comfy_cli/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
from typing import Optional
import zipfile

import requests
Expand All @@ -12,8 +13,30 @@ class DownloadException(Exception):
pass


def guess_status_code_reason(status_code: int) -> str:
def guess_status_code_reason(status_code: int, message: str) -> str:
if status_code == 401:
import json

def parse_json(input_data):
try:
# Check if the input is a byte string
if isinstance(input_data, bytes):
# Decode the byte string to a regular string
input_data = input_data.decode("utf-8")

# Parse the string as JSON
json_object = json.loads(input_data)

return json_object

except json.JSONDecodeError as e:
# Handle JSON decoding error
print(f"JSON decoding error: {e}")

msg_json = parse_json(message)
if msg_json is not None:
if "message" in msg_json:
return f"Unauthorized download ({status_code}).\n{msg_json['message']}\nor you can set civitai api token using `comfy model download --set-civitai-api-token <token>`"
return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one"
elif status_code == 403:
return f"Forbidden url ({status_code}), you might need to manually log into browser to download one"
Expand All @@ -22,7 +45,10 @@ def guess_status_code_reason(status_code: int) -> str:
return f"Unknown error occurred (status code: {status_code})"


def download_file(url: str, local_filepath: pathlib.Path):
def download_file(
url: str, local_filepath: pathlib.Path, headers: Optional[dict] = None
):

"""Helper function to download a file."""

import httpx
Expand All @@ -31,7 +57,7 @@ def download_file(url: str, local_filepath: pathlib.Path):
parents=True, exist_ok=True
) # Ensure the directory exists

with httpx.stream("GET", url, follow_redirects=True) as response:
with httpx.stream("GET", url, follow_redirects=True, headers=headers) as response:
if response.status_code == 200:
total = int(response.headers["Content-Length"])
try:
Expand All @@ -49,7 +75,9 @@ def download_file(url: str, local_filepath: pathlib.Path):
if delete_eh:
local_filepath.unlink()
else:
status_reason = guess_status_code_reason(response.status_code)
status_reason = guess_status_code_reason(
response.status_code, response.read()
)
raise DownloadException(f"Failed to download file.\n{status_reason}")


Expand Down

0 comments on commit 9ebd7d8

Please sign in to comment.