final #15963

Closed
wants to merge 3 commits
8 changes: 8 additions & 0 deletions launch.py
@@ -1,4 +1,5 @@
from modules import launch_utils
import nltk

args = launch_utils.args
python = launch_utils.python
@@ -41,6 +42,13 @@ def main():
if args.test_server:
configure_for_tests()


# Download nltk resources (only needs to run once)
print("Downloading nltk data...")
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('stopwords')

start()


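Note on the launch.py hunk above: the three nltk.download() calls run unconditionally on every launch. A minimal sketch of how the downloads could be guarded so they only run when the data is missing; the resource paths follow nltk's standard data layout, and the ensure_nltk_data helper is an illustration, not part of this PR:

import nltk

def ensure_nltk_data():
    # Map each download name to the path nltk.data.find() looks up.
    required = {
        "punkt": "tokenizers/punkt",
        "averaged_perceptron_tagger": "taggers/averaged_perceptron_tagger",
        "stopwords": "corpora/stopwords",
    }
    for name, path in required.items():
        try:
            nltk.data.find(path)  # raises LookupError when the resource is absent
        except LookupError:
            nltk.download(name)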
164 changes: 103 additions & 61 deletions modules/api/api.py
@@ -1,21 +1,29 @@
import base64
import io
import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import logging
from openai import AsyncOpenAI
import gradio as gr
from threading import Lock
from io import BytesIO
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from secrets import compare_digest

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from deep_translator import GoogleTranslator


import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
from modules.api import models
@@ -133,67 +141,39 @@ def encode_pil_to_base64(image):


def api_middleware(app: FastAPI):
rich_available = False
try:
if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
import anyio # importing just so it can be placed on silent list
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
rich_available = True
except Exception:
pass

@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
endpoint = req.scope.get('path', 'err')
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get('http_version', '0.0'),
cli=req.scope.get('client', ('0:0.0.0', 0))[0],
prot=req.scope.get('scheme', 'err'),
method=req.scope.get('method', 'err'),
endpoint=endpoint,
duration=duration,
))
return res

def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get('detail', ''),
"body": vars(e).get('body', ''),
"errors": str(e),
}
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
errors.report(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

@app.middleware("http")
async def exception_handling(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
return handle_exception(request, e)

@app.exception_handler(Exception)
async def fastapi_exception_handler(request: Request, e: Exception):
return handle_exception(request, e)
# try:
# if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
# import anyio # importing just so it can be placed on silent list
# import starlette # importing just so it can be placed on silent list
# from rich.console import Console
# console = Console()
# rich_available = True
# except Exception:
# pass
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, req: Request, call_next):
ts = time.time()
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
endpoint = req.scope.get('path', 'err')
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get('http_version', '0.0'),
cli=req.scope.get('client', ('0:0.0.0', 0))[0],
prot=req.scope.get('scheme', 'err'),
method=req.scope.get('method', 'err'),
endpoint=endpoint,
duration=duration,
))
return res

# FastAPI ships with its own exception handler,
# so a custom one is no longer needed
app.add_middleware(LoggingMiddleware)

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, e: HTTPException):
return handle_exception(request, e)
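The new http_exception_handler above still calls handle_exception, which this hunk appears to remove together with the old middleware; if so, the handler would raise a NameError at request time. A minimal sketch of an equivalent helper, reusing the JSON error shape from the removed code (this reconstruction is an assumption, not part of the diff):

from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder

def handle_exception(request: Request, e: Exception):
    # Same JSON error payload that the removed middleware produced.
    err = {
        "error": type(e).__name__,
        "detail": vars(e).get("detail", ""),
        "body": vars(e).get("body", ""),
        "errors": str(e),
    }
    return JSONResponse(status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err))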


class Api:
@@ -244,6 +224,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
self.add_api_route("/nlp/v1/nature2prompt", self.nature2prompt, methods=["POST"], response_model=models.Nature2PromptResponse)

if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -429,6 +410,67 @@ def get_field_value(field, params):

return params

def extract_nouns(self, sentence):
# Tokenize the sentence
words = word_tokenize(sentence)

# Part-of-speech tagging
pos_tagged = nltk.pos_tag(words)

# Extract nouns (NN, NNS, NNP, NNPS)
nouns = [word for word, pos in pos_tagged if pos in ['NN', 'NNS', 'NNP', 'NNPS']]

return nouns

async def nature2prompt(self, nature2promptreq: models.Nature2PromptRequest):
client = AsyncOpenAI()

n2p_logger = logging.getLogger(__name__)

input_to_bot = nature2promptreq
chat_completion = await client.chat.completions.create(
messages=[
{
"role": "system",
"content": "现在我需要你提取出几个关键词概括我给你的句子,例如 '海上生明月' 对应 '海上' 和 '明月'。注意,仅返回关键词,不要返回其他内容。",
},
# {"role": "assistant",
# "content":" ",},
{"role": "user",
"content":input_to_bot.text,}
# {"role": "assistant",
# "content":" ",},
],

model="gpt-3.5-turbo",
)
return_to_bot2 = chat_completion.choices[0].message.content


n2p_logger.info(f"返回: {return_to_bot2}")
translated = GoogleTranslator(source='auto', target='en').translate(return_to_bot2) # output -> Weiter so, du bist großartig
n2p_logger.info(f"翻译后:{translated}")
# 确保你已经安装了nltk库。如果没有安装,请运行以下命令:



# Moved to the launch.py script; only needs to run at startup
# Download nltk resources (only needs to run once)
# nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger')
# nltk.download('stopwords')

# Example sentence
sentence = translated
n2p_logger.info(f"sentence: {sentence}")
printnon = ''
# Extract the nouns
nouns = self.extract_nouns(sentence)
for item in nouns:
printnon += (item +", ")
print("过滤后的名词词表:",printnon)
return models.Nature2PromptResponse(filtered_nouns=printnon)

def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")

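A usage sketch for the new /nlp/v1/nature2prompt route, assuming the webui was started with --api on the default 127.0.0.1:7860 address and OPENAI_API_KEY is set for the server process (the address and port are assumptions, not part of the diff):

import requests

# POST a natural-language sentence; the route returns the comma-separated nouns
# left after keyword extraction (OpenAI), translation to English (GoogleTranslator),
# and POS filtering (nltk).
resp = requests.post(
    "http://127.0.0.1:7860/nlp/v1/nature2prompt",
    json={"text": "海上生明月"},
)
resp.raise_for_status()
print(resp.json()["filtered_nouns"])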
6 changes: 6 additions & 0 deletions modules/api/models.py
@@ -131,6 +131,12 @@ def generate_model(self):
]
).generate_model()

class Nature2PromptRequest(BaseModel):
text: str

class Nature2PromptResponse(BaseModel):
filtered_nouns: str

class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
6 changes: 5 additions & 1 deletion requirements.txt
@@ -30,4 +30,8 @@ torch
torchdiffeq
torchsde
transformers==4.30.2
pillow-avif-plugin==1.4.3
pillow-avif-plugin==1.4.3

openai
nltk
deep_translator