Add optional webhook after job finished #5276

Open · wants to merge 3 commits into master
24 changes: 23 additions & 1 deletion main.py
@@ -7,7 +7,8 @@
import time
from comfy.cli_args import args
from app.logger import setup_logger

import aiohttp
import asyncio

setup_logger(log_level=args.verbose)

@@ -104,6 +105,18 @@ def cuda_malloc_warning():
        if cuda_malloc_warning:
            logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")

async def send_webhook(server, prompt_id, data):
    # Pop the URL so the webhook for a given prompt fires at most once
    webhook_url = server.webhooks.pop(prompt_id, None)
    if webhook_url:
        try:
            async with aiohttp.ClientSession() as session:
                logging.info(f"Sending webhook for prompt {prompt_id}")
                async with session.post(webhook_url, json=data) as response:
                    if response.status != 200:
                        logging.warning(f"Webhook delivery failed for prompt {prompt_id}. Status: {response.status}")
        except Exception as e:
            # Delivery failures are logged but never interrupt the worker
            logging.error(f"Error sending webhook for prompt {prompt_id}: {str(e)}")

def prompt_worker(q, server):
    e = execution.PromptExecutor(server, lru_size=args.cache_lru)
    last_gc_collect = 0
@@ -137,6 +150,15 @@ def prompt_worker(q, server):
            execution_time = current_time - execution_start_time
            logging.info("Prompt executed in {:.2f} seconds".format(execution_time))

            # Send webhook after execution is complete
            webhook_data = {
                "prompt_id": prompt_id,
                "execution_time": execution_time,
                "status": "success" if e.success else "error",
                "result": e.history_result
            }
            # prompt_worker runs in its own thread, so schedule the coroutine
            # on the server's asyncio event loop rather than awaiting it here
            asyncio.run_coroutine_threadsafe(send_webhook(server, prompt_id, webhook_data), server.loop)

        flags = q.get_flags()
        free_memory = flags.get("free_memory", False)

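The payload assembled in prompt_worker is a plain JSON object with prompt_id, execution_time, status, and result. For reference, a minimal receiver sketch using aiohttp; the route path and port here are illustrative, not part of this PR:

import logging
from aiohttp import web

logging.basicConfig(level=logging.INFO)

async def handle_webhook(request):
    data = await request.json()
    # Fields posted by send_webhook above
    logging.info("prompt %s finished: %s in %.2fs",
                 data["prompt_id"], data["status"], data["execution_time"])
    # Any non-200 status is logged by the sender as a delivery failure
    return web.Response(status=200)

app = web.Application()
app.add_routes([web.post("/comfy-webhook", handle_webhook)])
web.run_app(app, port=8189)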
4 changes: 4 additions & 0 deletions requirements.txt
@@ -14,7 +14,11 @@ scipy
tqdm
psutil

# webhook handling
aiohttp

#non essential dependencies:
kornia>=0.7.1
spandrel
soundfile

16 changes: 14 additions & 2 deletions server.py
@@ -32,6 +32,10 @@
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
import requests
import json
from urllib.parse import urlparse, parse_qs, urlencode


class BinaryEventTypes:
PREVIEW_IMAGE = 1
@@ -159,6 +163,7 @@ def __init__(self, loop):
        self.messages = asyncio.Queue()
        self.client_session:Optional[aiohttp.ClientSession] = None
        self.number = 0
        self.webhooks = {}  # maps prompt_id -> webhook URL registered via /prompt

        middlewares = [cache_control]
        if args.enable_cors_header:
@@ -623,8 +628,15 @@ async def post_prompt(request):
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
prompt_id = str(uuid.uuid4())
# allow to accept prompt_id from api caller to reference it in webhook handler if needed
prompt_id = json_data.get("prompt_id", str(uuid.uuid4()))
outputs_to_execute = valid[2]

# Add webhook URL to the webhooks dict if provided
webhook_url = json_data.get("webhook_url")
if webhook_url:
self.webhooks[prompt_id] = webhook_url

self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
@@ -864,4 +876,4 @@ def trigger_on_prompt(self, json_data):
logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
logging.warning(traceback.format_exc())

        return json_data
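With these changes a caller opts in per request by adding webhook_url, and optionally a caller-chosen prompt_id, to the /prompt payload. A hypothetical invocation against a default local ComfyUI instance; the workflow contents and receiver URL are placeholders:

import requests

payload = {
    "prompt": workflow,  # a normal ComfyUI workflow graph, defined elsewhere
    "prompt_id": "job-1234",  # optional; echoed back in the webhook payload
    "webhook_url": "http://localhost:8189/comfy-webhook",
}
r = requests.post("http://127.0.0.1:8188/prompt", json=payload)
print(r.json())  # {"prompt_id": "job-1234", "number": ..., "node_errors": ...}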