From 9b85c08c3c61fc7216bf49e56fe74d7705f730ae Mon Sep 17 00:00:00 2001 From: jtydhr88 Date: Thu, 25 May 2023 12:11:37 -0400 Subject: [PATCH] webui-plugin: lora/hypernetworks/embeddings support --- .../src/Utilities.ts | 54 ++++++++++++- .../stablestudio-plugin-webui/src/index.ts | 75 ++++++++++++++++++- 2 files changed, 126 insertions(+), 3 deletions(-) diff --git a/packages/stablestudio-plugin-webui/src/Utilities.ts b/packages/stablestudio-plugin-webui/src/Utilities.ts index 88e163b5..812f95d8 100644 --- a/packages/stablestudio-plugin-webui/src/Utilities.ts +++ b/packages/stablestudio-plugin-webui/src/Utilities.ts @@ -1,5 +1,8 @@ import { StableDiffusionInput } from "@stability/stablestudio-plugin"; +const LORA_PATTERN = ""; +const HYPERNETWORK_PATTERN = ""; + export function base64ToBlob(base64: string, contentType = ""): Promise { return fetch(`data:${contentType};base64,${base64}`).then((res) => res.blob() @@ -28,6 +31,40 @@ export async function fetchOptions(baseUrl: string | undefined) { return await optionsResponse.json(); } +export async function fetchEmbeddings(baseUrl: string | undefined) { + const embeddingsResponse = await fetch(`${baseUrl}/sdapi/v1/embeddings`, { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + }); + + return await embeddingsResponse.json(); +} + +export async function fetchHypernetworks(baseUrl: string | undefined) { + const hypernetworksResponse = await fetch(`${baseUrl}/sdapi/v1/hypernetworks`, { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + }); + + return await hypernetworksResponse.json(); +} + +export async function fetchExtraModels(baseUrl: string | undefined, modelType: string) { + const extraModelsResponse = await fetch(`${baseUrl}/StableStudio/get-extra-models`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ type: modelType }), + }); + + return await extraModelsResponse.json(); +} + export async function setOptions(baseUrl: string | undefined, options: any) { const optionsResponse = await fetch(`${baseUrl}/sdapi/v1/options`, { method: "POST", @@ -114,9 +151,10 @@ export async function constructPayload( count?: number | undefined; }, isUpscale = false, - upscaler: string | undefined + upscaler: string | undefined, + styleType: string | undefined ) { - const { sampler, prompts, initialImage, maskImage, width, height, steps } = + const { sampler, prompts, initialImage, maskImage, width, height, steps, style } = options?.input ?? {}; // Construct payload @@ -151,6 +189,18 @@ export async function constructPayload( prompts?.find((p) => (p.text && (p.weight ?? 0) < 0) ?? 0 < 0)?.text ?? ""; + if ((styleType != "none") && style && (style != "none") && (style != "enhance")) { + let extraStyle = style; + + if (styleType == "lora") { + extraStyle = LORA_PATTERN.replace("#LORA#", style); + } else if (styleType == "hypernetworks") { + extraStyle = HYPERNETWORK_PATTERN.replace("#HYPERNETWORK#", style); + } + + data.prompt = data.prompt + " " + extraStyle; + } + data.steps = steps ?? 20; data.batch_size = options?.count; data.save_images = true; diff --git a/packages/stablestudio-plugin-webui/src/index.ts b/packages/stablestudio-plugin-webui/src/index.ts index 3ad06c27..f890bc30 100644 --- a/packages/stablestudio-plugin-webui/src/index.ts +++ b/packages/stablestudio-plugin-webui/src/index.ts @@ -8,6 +8,9 @@ import { getImageInfo, setOptions, testForHistoryPlugin, + fetchHypernetworks, + fetchEmbeddings, + fetchExtraModels } from "./Utilities"; const manifest = { @@ -64,6 +67,25 @@ const webuiUpscalers = [ }, ]; +const webuiStyleTypes = [ + { + label: "None", + value: "none", + }, + { + label: "LoRA", + value: "lora", + }, + { + label: "Hypernetworks", + value: "hypernetworks", + }, + { + label: "Textual Inversions", + value: "embeddings", + }, +]; + const getNumber = (strValue: string | null, defaultValue: number) => { let retValue = defaultValue; @@ -80,6 +102,7 @@ export const createPlugin = StableStudio.createPlugin<{ baseUrl: StableStudio.PluginSettingString; upscaler: StableStudio.PluginSettingString; historyImagesCount: StableStudio.PluginSettingNumber; + styleType: StableStudio.PluginSettingString; }; }>(({ set, get }) => { const webuiLoad = ( @@ -93,6 +116,7 @@ export const createPlugin = StableStudio.createPlugin<{ | "getStableDiffusionDefaultCount" | "getStableDiffusionDefaultInput" | "getStableDiffusionExistingImages" + | "getStableDiffusionStyles" > => { webuiHostUrl = webuiHostUrl ?? "http://127.0.0.1:7861"; @@ -138,7 +162,8 @@ export const createPlugin = StableStudio.createPlugin<{ const data = await constructPayload( options, isUpscale, - get().settings.upscaler.value + get().settings.upscaler.value, + get().settings.styleType.value ); // Send payload to webui @@ -273,6 +298,45 @@ export const createPlugin = StableStudio.createPlugin<{ })); }, + getStableDiffusionStyles: async () => { + const styleType = localStorage.getItem("styleType") ?? "None"; + + const styles: any = []; + + if (styleType === "lora") { + const loras = await fetchExtraModels(webuiHostUrl, "Lora"); + + loras.forEach((lora: any) => { + styles.push({ + id: lora["name"], + name: lora["name"], + }); + }) + } + if (styleType === "hypernetworks") { + const hypernetworks = await fetchHypernetworks(webuiHostUrl); + + hypernetworks.forEach((hypernetwork: any) => { + styles.push({ + id: hypernetwork["name"], + name: hypernetwork["name"], + }); + }) + } + if (styleType === "embeddings") { + const embeddings = await fetchEmbeddings(webuiHostUrl); + + for (const key in embeddings.loaded) { + styles.push({ + id: key, + name: key, + }); + } + } + + return styles; + }, + getStableDiffusionExistingImages: async () => { const existingImagesResponse = await fetch( `${webuiHostUrl}/StableStudio/get-generated-images`, @@ -379,6 +443,13 @@ export const createPlugin = StableStudio.createPlugin<{ variant: "slider", value: getNumber(localStorage.getItem("historyImagesCount"), 20), }, + + styleType: { + type: "string", + title: "Style Type", + options: webuiStyleTypes, + value: localStorage.getItem("styleType") ?? "None" + } }, setSetting: (key, value) => { @@ -396,6 +467,8 @@ export const createPlugin = StableStudio.createPlugin<{ localStorage.setItem("upscaler1", value); } else if (key === "historyImagesCount" && typeof value === "number") { localStorage.setItem("historyImagesCount", value.toString()); + } else if (key === "styleType" && typeof value === "string") { + localStorage.setItem("styleType", value); } },