From 2530fae69fdae3614045a5188c4a1fa3ed202bec Mon Sep 17 00:00:00 2001
From: Ngoc Ngo
Date: Mon, 8 May 2023 22:38:15 +0700
Subject: [PATCH] #7 add simple api

---
Review notes (free-form text in this section is ignored when the patch is
applied):

- repetition_penalty and length_penalty are declared as float instead of int,
  so fractional penalties such as 1.2 are neither truncated nor rejected.
- generate() accumulates the texts with list.extend() instead of rebuilding
  the list on every batch.
- api_requested() raises a descriptive RuntimeError when no model is
  available and returns a copy of the shared result object.
- NOTE(review): the two `markup` context lines carry empty string literals
  in this copy of the patch. If the target file has non-empty literals
  there, regenerate the patch against it; context must match to apply.
- NOTE(review): current_api_result is module-level state written by every
  generate() call, so overlapping requests can overwrite each other's
  results. Confirm that callers are serialized before relying on it.

 scripts/promptgen.py | 68 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/scripts/promptgen.py b/scripts/promptgen.py
index 487c7fb..197518f 100644
--- a/scripts/promptgen.py
+++ b/scripts/promptgen.py
@@ -1,6 +1,7 @@
 import html
 import os
 import time
+from pydantic import BaseModel, Field
 
 import torch
 import transformers
@@ -18,12 +19,43 @@ class Model:
     model = None
     tokenizer = None
 
+
+class ApiResult(BaseModel):
+    """Response body of the /promptgen API route."""
+
+    items: list = Field(title="Generated prompts")
+
+
+class GenerateRequest(BaseModel):
+    """Request body of the /promptgen API route.
+
+    Every field except ``prompt`` has a default, so a request may consist
+    of the beginning of the prompt alone.
+    """
+
+    prompt: str = Field(title="Beginning of the prompt")
+    sampling_mode: str = Field(title="The mode", default='Top K')
+    # Presumably forwarded to transformers' generate(), which documents both
+    # penalties as floats; an int field would truncate or reject 1.2.
+    repetition_penalty: float = Field(title="Repetition penalty", default=1)
+    length_penalty: float = Field(title="Length penalty", default=1)
+    min_length: int = Field(title="Min length", default=20)
+    max_length: int = Field(title="Max length", default=150)
+    num_beams: int = Field(title="Number of beams", default=1)
+    top_k: int = Field(title="Top K", default=12)
+    top_p: float = Field(title="Top P", default=0.15)
+    temperature: float = Field(title="Temperature", default=1)
+    batch_count: int = Field(title="How many batch", default=1)
+    batch_size: int = Field(title="How many to generate", default=10)
+
 
 available_models = []
 current = Model()
 
 base_dir = scripts.basedir()
 models_dir = os.path.join(base_dir, "models")
+# Texts of the most recent generate() call, read back by api_requested().
+current_api_result = ApiResult(items=[])
 
 
 def device():
@@ -115,8 +147,10 @@ def generate(id_task, model_name, batch_count, batch_size, text, *args):
     markup = ''
 
     index = 0
+    all_texts = []
     for i in range(batch_count):
         texts = generate_batch(input_ids, *args)
+        all_texts.extend(texts)
         shared.state.nextjob()
         for generated_text in texts:
             index += 1
@@ -135,6 +169,7 @@ def generate(id_task, model_name, batch_count, batch_size, text, *args):
 """
 
     markup += ''
+    current_api_result.items = all_texts
 
     return markup, ''
 
@@ -237,7 +272,40 @@ def on_unload():
     current.model = None
     current.tokenizer = None
 
+
+def api_requested(req: GenerateRequest):
+    """Generate prompts for an API request and return them as an ApiResult.
+
+    Generation always uses the first entry of ``available_models``.
+
+    Raises:
+        RuntimeError: if ``available_models`` is empty.
+    """
+    if not available_models:
+        raise RuntimeError("No promptgen model is available")
+
+    generate(
+        '', available_models[0], req.batch_count, req.batch_size, req.prompt,
+        req.min_length, req.max_length, req.num_beams, req.temperature,
+        req.repetition_penalty, req.length_penalty, req.sampling_mode,
+        req.top_k, req.top_p,
+    )
+    # generate() leaves its texts in the module-level current_api_result;
+    # answer with a copy so the response does not alias that shared object.
+    return ApiResult(items=list(current_api_result.items))
+
+
+def register_api(demo, app):
+    """Add the POST /promptgen route to ``app`` if ``shared.cmd_opts.api`` is set.
+
+    ``demo`` is not used here; presumably it is part of the signature
+    expected by ``script_callbacks.on_app_started``.
+    """
+    if shared.cmd_opts.api:
+        app.add_api_route("/promptgen", api_requested, methods=["POST"], response_model=ApiResult)
+
 
 script_callbacks.on_ui_tabs(add_tab)
 script_callbacks.on_ui_settings(on_ui_settings)
 script_callbacks.on_script_unloaded(on_unload)
+script_callbacks.on_app_started(register_api)