diff --git a/scripts/api.py b/scripts/api.py
new file mode 100644
index 0000000..d893248
--- /dev/null
+++ b/scripts/api.py
@@ -0,0 +1,68 @@
+from typing import List
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel, Field
+from scripts.promptgen import api_generate, return_available_models
+
+import traceback
+
+from enum import Enum
+
+class SamplingMode(str, Enum):
+    TopK = "Top K"
+    TopP = "Top P"
+
+
+class PromptRequest(BaseModel):
+    model_name: str = Field("AUTOMATIC/promptgen-lexart", description="Model name.")
+    batch_count: int = Field(1, ge=1, le=100, description="Batch count.")
+    batch_size: int = Field(20, ge=1, le=100, description="Batch size.")
+    text: str = Field("", description="Input text.")
+    min_length: int = Field(20, ge=1, le=400, description="Minimum length.")
+    max_length: int = Field(150, ge=1, le=400, description="Maximum length.")
+    num_beams: int = Field(1, ge=1, le=8, description="Number of beams.")
+    temperature: float = Field(1, ge=0, le=4, description="Temperature.")
+    repetition_penalty: float = Field(1, ge=1, le=4, description="Repetition penalty.")
+    length_preference: float = Field(1, ge=-10, le=10, description="Length preference.")
+    sampling_mode: SamplingMode = Field(SamplingMode.TopK, description="Sampling mode, either 'Top K' or 'Top P'.")
+    top_k: int = Field(12, ge=1, le=50, description="Top K.")
+    top_p: float = Field(0.15, ge=0, le=1, description="Top P.")
+
+def promptgen_api(_, app: FastAPI):
+    @app.get("/promptgen/list_models")
+    async def list_models():
+        try:
+            return {"available_models": return_available_models()}
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post("/promptgen/generate")
+    async def generate_prompts(prompt_request: PromptRequest):
+        try:
+            prompts = api_generate(
+                model_name=prompt_request.model_name,
+                batch_count=prompt_request.batch_count,
+                batch_size=prompt_request.batch_size,
+                text=prompt_request.text,
+                min_length=prompt_request.min_length,
+                max_length=prompt_request.max_length,
+                num_beams=prompt_request.num_beams,
+                temperature=prompt_request.temperature,
+                repetition_penalty=prompt_request.repetition_penalty,
+                length_penalty=prompt_request.length_preference,
+                sampling_mode=prompt_request.sampling_mode,
+                top_k=prompt_request.top_k,
+                top_p=prompt_request.top_p,
+            )
+            return {"prompts": prompts}
+        except Exception as e:
+            tb = traceback.format_exc()  # capture the full traceback as a string
+            detailed_error_message = f"{str(e)}\n\n{tb}"
+            raise HTTPException(status_code=500, detail=detailed_error_message)
+
+
+try:
+    import modules.script_callbacks as script_callbacks
+
+    script_callbacks.on_app_started(promptgen_api)
+except ImportError:
+    pass
diff --git a/scripts/initpromptgen.py b/scripts/initpromptgen.py
new file mode 100644
index 0000000..d9e04ee
--- /dev/null
+++ b/scripts/initpromptgen.py
@@ -0,0 +1,8 @@
+
+from modules import scripts, script_callbacks, devices, ui
+from scripts.promptgen import add_tab, on_ui_settings, on_unload
+
+
+script_callbacks.on_ui_tabs(add_tab)
+script_callbacks.on_ui_settings(on_ui_settings)
+script_callbacks.on_script_unloaded(on_unload)
diff --git a/scripts/promptgen.py b/scripts/promptgen.py
index 487c7fb..1736f59 100644
--- a/scripts/promptgen.py
+++ b/scripts/promptgen.py
@@ -1,6 +1,7 @@
 import html
 import os
 import time
+import re
 
 import torch
 import transformers
@@ -27,7 +28,11 @@ class Model:
 
 
 def device():
-    return devices.cpu if shared.opts.promptgen_device == 'cpu' else devices.device
+    if hasattr(shared.opts, "promptgen_device"):
+        return devices.cpu if shared.opts.promptgen_device == 'cpu' else devices.device
+    else:
+        os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+        return devices.cpu
 
 
 def list_available_models():
@@ -45,6 +50,10 @@ def list_available_models():
         available_models.append(name)
 
 
+def return_available_models():
+    list_available_models()
+    return available_models
+
 def get_model_path(name):
     dirname = os.path.join(models_dir, name)
     if not os.path.isdir(dirname):
@@ -84,27 +93,33 @@ def model_selection_changed(model_name):
     devices.torch_gc()
 
 
-def generate(id_task, model_name, batch_count, batch_size, text, *args):
-    shared.state.textinfo = "Loading model..."
-    shared.state.job_count = batch_count
-
-    if current.name != model_name:
-        current.tokenizer = None
-        current.model = None
-        current.name = None
+def api_generate(model_name, batch_count, batch_size, text, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p):
+    shared.state.job_count = batch_count
+    setup_model(model_name)
 
-    if model_name != 'None':
-        path = get_model_path(model_name)
-        current.tokenizer = transformers.AutoTokenizer.from_pretrained(path)
-        current.model = transformers.AutoModelForCausalLM.from_pretrained(path)
-        current.name = model_name
-
-    assert current.model, 'No model available'
-    assert current.tokenizer, 'No tokenizer available'
+    input_ids = current.tokenizer(text, return_tensors="pt").input_ids
+    if input_ids.shape[1] == 0:
+        input_ids = torch.asarray([[current.tokenizer.bos_token_id]], dtype=torch.long)
+    input_ids = input_ids.to(device())
+    input_ids = input_ids.repeat((batch_size, 1))
+
+    prompts = []
+    for i in range(batch_count):
+        texts = generate_batch(input_ids, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p)
+        shared.state.nextjob()
+        for idx, text in enumerate(texts):
+            text = html.escape(text)
+            text = text.replace('\n', ' ').replace('\r', ' ')  # replace newlines and carriage returns with spaces
+            text = re.sub('[^A-Za-z0-9 .,]+', '', text)  # keep only alphanumerics, spaces, periods, and commas
+            texts[idx] = ' '.join(text.split())  # collapse excess whitespace
+        prompts.extend(texts)
+    return prompts
+
 
-    current.model.to(device())
+def generate(id_task, model_name, batch_count, batch_size, text, *args):
+
+    setup_model(model_name)
 
-    shared.state.textinfo = ""
 
     input_ids = current.tokenizer(text, return_tensors="pt").input_ids
     if input_ids.shape[1] == 0:
@@ -138,6 +153,26 @@ def generate(id_task, model_name, batch_count, batch_size, text, *args):
     return markup, ''
 
 
+def setup_model(model_name):
+    shared.state.textinfo = "Loading model..."
+    if current.name != model_name:
+        current.tokenizer = None
+        current.model = None
+        current.name = None
+
+    if model_name != 'None':
+        path = get_model_path(model_name)
+        current.tokenizer = transformers.AutoTokenizer.from_pretrained(path)
+        current.model = transformers.AutoModelForCausalLM.from_pretrained(path)
+        current.name = model_name
+
+    assert current.model, 'No model available'
+    assert current.tokenizer, 'No tokenizer available'
+
+    current.model.to(device())
+    shared.state.textinfo = ""
+
+
 def find_prompts(fields):
     field_prompt = [x for x in fields if x[1] == "Prompt"][0]
     field_negative_prompt = [x for x in fields if x[1] == "Negative prompt"][0]
@@ -237,7 +272,3 @@ def on_unload():
     current.model = None
     current.tokenizer = None
 
-
-script_callbacks.on_ui_tabs(add_tab)
-script_callbacks.on_ui_settings(on_ui_settings)
-script_callbacks.on_script_unloaded(on_unload)
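
Usage sketch for the new endpoints (not part of the patch): the snippet below assumes a stable-diffusion-webui instance with this extension is serving on http://127.0.0.1:7860 and that the requests package is available; the base URL and the example payload values are placeholders. Field names mirror PromptRequest, and any omitted field falls back to its declared default.

    import requests

    BASE = "http://127.0.0.1:7860"  # assumed local webui address; adjust as needed

    # GET /promptgen/list_models -> {"available_models": [...]}
    models = requests.get(f"{BASE}/promptgen/list_models").json()["available_models"]
    print(models)

    # POST /promptgen/generate -> {"prompts": [...]}
    payload = {
        "model_name": "AUTOMATIC/promptgen-lexart",
        "text": "a photo of",
        "batch_size": 5,
        "sampling_mode": "Top K",  # must be a SamplingMode value: "Top K" or "Top P"
    }
    resp = requests.post(f"{BASE}/promptgen/generate", json=payload)
    resp.raise_for_status()
    print("\n".join(resp.json()["prompts"]))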
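
For reference, the cleanup loop in api_generate reduces each generated text to plain comma- and period-punctuated ASCII before it is returned over the API. A standalone illustration of the same steps on a made-up input:

    import html
    import re

    text = "castle,\n  dramatic lighting!! (8k)"
    text = html.escape(text)                            # escape HTML special characters
    text = text.replace('\n', ' ').replace('\r', ' ')   # newlines/carriage returns -> spaces
    text = re.sub('[^A-Za-z0-9 .,]+', '', text)         # drop everything but alphanumerics, space, period, comma
    text = ' '.join(text.split())                       # collapse runs of whitespace
    print(text)                                         # castle, dramatic lighting 8k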