Skip to content

Commit

Permalink
webui-plugin: lora/hypernetworks/embeddings support
Browse files Browse the repository at this point in the history
  • Loading branch information
jtydhr88 committed May 25, 2023
1 parent 02dda52 commit 9b85c08
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 3 deletions.
54 changes: 52 additions & 2 deletions packages/stablestudio-plugin-webui/src/Utilities.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { StableDiffusionInput } from "@stability/stablestudio-plugin";

const LORA_PATTERN = "<lora:#LORA#:1>";
const HYPERNETWORK_PATTERN = "<hypernet:#HYPERNETWORK#:1>";

export function base64ToBlob(base64: string, contentType = ""): Promise<Blob> {
return fetch(`data:${contentType};base64,${base64}`).then((res) =>
res.blob()
Expand Down Expand Up @@ -28,6 +31,40 @@ export async function fetchOptions(baseUrl: string | undefined) {
return await optionsResponse.json();
}

export async function fetchEmbeddings(baseUrl: string | undefined) {
const embeddingsResponse = await fetch(`${baseUrl}/sdapi/v1/embeddings`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});

return await embeddingsResponse.json();
}

export async function fetchHypernetworks(baseUrl: string | undefined) {
const hypernetworksResponse = await fetch(`${baseUrl}/sdapi/v1/hypernetworks`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});

return await hypernetworksResponse.json();
}

export async function fetchExtraModels(baseUrl: string | undefined, modelType: string) {
const extraModelsResponse = await fetch(`${baseUrl}/StableStudio/get-extra-models`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ type: modelType }),
});

return await extraModelsResponse.json();
}

export async function setOptions(baseUrl: string | undefined, options: any) {
const optionsResponse = await fetch(`${baseUrl}/sdapi/v1/options`, {
method: "POST",
Expand Down Expand Up @@ -114,9 +151,10 @@ export async function constructPayload(
count?: number | undefined;
},
isUpscale = false,
upscaler: string | undefined
upscaler: string | undefined,
styleType: string | undefined
) {
const { sampler, prompts, initialImage, maskImage, width, height, steps } =
const { sampler, prompts, initialImage, maskImage, width, height, steps, style } =
options?.input ?? {};

// Construct payload
Expand Down Expand Up @@ -151,6 +189,18 @@ export async function constructPayload(
prompts?.find((p) => (p.text && (p.weight ?? 0) < 0) ?? 0 < 0)?.text ??
"";

if ((styleType != "none") && style && (style != "none") && (style != "enhance")) {
let extraStyle = style;

if (styleType == "lora") {
extraStyle = LORA_PATTERN.replace("#LORA#", style);
} else if (styleType == "hypernetworks") {
extraStyle = HYPERNETWORK_PATTERN.replace("#HYPERNETWORK#", style);
}

data.prompt = data.prompt + " " + extraStyle;
}

data.steps = steps ?? 20;
data.batch_size = options?.count;
data.save_images = true;
Expand Down
75 changes: 74 additions & 1 deletion packages/stablestudio-plugin-webui/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import {
getImageInfo,
setOptions,
testForHistoryPlugin,
fetchHypernetworks,
fetchEmbeddings,
fetchExtraModels
} from "./Utilities";

const manifest = {
Expand Down Expand Up @@ -64,6 +67,25 @@ const webuiUpscalers = [
},
];

const webuiStyleTypes = [
{
label: "None",
value: "none",
},
{
label: "LoRA",
value: "lora",
},
{
label: "Hypernetworks",
value: "hypernetworks",
},
{
label: "Textual Inversions",
value: "embeddings",
},
];

const getNumber = (strValue: string | null, defaultValue: number) => {
let retValue = defaultValue;

Expand All @@ -80,6 +102,7 @@ export const createPlugin = StableStudio.createPlugin<{
baseUrl: StableStudio.PluginSettingString;
upscaler: StableStudio.PluginSettingString;
historyImagesCount: StableStudio.PluginSettingNumber;
styleType: StableStudio.PluginSettingString;
};
}>(({ set, get }) => {
const webuiLoad = (
Expand All @@ -93,6 +116,7 @@ export const createPlugin = StableStudio.createPlugin<{
| "getStableDiffusionDefaultCount"
| "getStableDiffusionDefaultInput"
| "getStableDiffusionExistingImages"
| "getStableDiffusionStyles"
> => {
webuiHostUrl = webuiHostUrl ?? "http://127.0.0.1:7861";

Expand Down Expand Up @@ -138,7 +162,8 @@ export const createPlugin = StableStudio.createPlugin<{
const data = await constructPayload(
options,
isUpscale,
get().settings.upscaler.value
get().settings.upscaler.value,
get().settings.styleType.value
);

// Send payload to webui
Expand Down Expand Up @@ -273,6 +298,45 @@ export const createPlugin = StableStudio.createPlugin<{
}));
},

getStableDiffusionStyles: async () => {
const styleType = localStorage.getItem("styleType") ?? "None";

const styles: any = [];

if (styleType === "lora") {
const loras = await fetchExtraModels(webuiHostUrl, "Lora");

loras.forEach((lora: any) => {
styles.push({
id: lora["name"],
name: lora["name"],
});
})
}
if (styleType === "hypernetworks") {
const hypernetworks = await fetchHypernetworks(webuiHostUrl);

hypernetworks.forEach((hypernetwork: any) => {
styles.push({
id: hypernetwork["name"],
name: hypernetwork["name"],
});
})
}
if (styleType === "embeddings") {
const embeddings = await fetchEmbeddings(webuiHostUrl);

for (const key in embeddings.loaded) {
styles.push({
id: key,
name: key,
});
}
}

return styles;
},

getStableDiffusionExistingImages: async () => {
const existingImagesResponse = await fetch(
`${webuiHostUrl}/StableStudio/get-generated-images`,
Expand Down Expand Up @@ -379,6 +443,13 @@ export const createPlugin = StableStudio.createPlugin<{
variant: "slider",
value: getNumber(localStorage.getItem("historyImagesCount"), 20),
},

styleType: {
type: "string",
title: "Style Type",
options: webuiStyleTypes,
value: localStorage.getItem("styleType") ?? "None"
}
},

setSetting: (key, value) => {
Expand All @@ -396,6 +467,8 @@ export const createPlugin = StableStudio.createPlugin<{
localStorage.setItem("upscaler1", value);
} else if (key === "historyImagesCount" && typeof value === "number") {
localStorage.setItem("historyImagesCount", value.toString());
} else if (key === "styleType" && typeof value === "string") {
localStorage.setItem("styleType", value);
}
},

Expand Down

0 comments on commit 9b85c08

Please sign in to comment.