Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webui-plugin improvment: group generated images by prompt, lora/hypernetworks/embedding support #70

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions packages/stablestudio-plugin-webui/src/Utilities.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { StableDiffusionInput } from "@stability/stablestudio-plugin";

const LORA_PATTERN = "<lora:#LORA#:1>";
const HYPERNETWORK_PATTERN = "<hypernet:#HYPERNETWORK#:1>";

export function base64ToBlob(base64: string, contentType = ""): Promise<Blob> {
return fetch(`data:${contentType};base64,${base64}`).then((res) =>
res.blob()
Expand Down Expand Up @@ -28,6 +31,39 @@ export async function fetchOptions(baseUrl: string | undefined) {
return await optionsResponse.json();
}

export async function fetchEmbeddings(baseUrl: string | undefined) {
const embeddingsResponse = await fetch(`${baseUrl}/sdapi/v1/embeddings`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});

return await embeddingsResponse.json();
}

export async function fetchHypernetworks(baseUrl: string | undefined) {
const hypernetworksResponse = await fetch(`${baseUrl}/sdapi/v1/hypernetworks`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});

return await hypernetworksResponse.json();
}

export async function fetchLoras(baseUrl: string | undefined) {
const lorasResponse = await fetch(`${baseUrl}/sdapi/v1/loras`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});

return await lorasResponse.json();
}

export async function setOptions(baseUrl: string | undefined, options: any) {
const optionsResponse = await fetch(`${baseUrl}/sdapi/v1/options`, {
method: "POST",
Expand Down Expand Up @@ -114,9 +150,10 @@ export async function constructPayload(
count?: number | undefined;
},
isUpscale = false,
upscaler: string | undefined
upscaler: string | undefined,
styleType: string | undefined
) {
const { sampler, prompts, initialImage, maskImage, width, height, steps } =
const { sampler, prompts, initialImage, maskImage, width, height, steps, style } =
options?.input ?? {};

// Construct payload
Expand Down Expand Up @@ -151,6 +188,18 @@ export async function constructPayload(
prompts?.find((p) => (p.text && (p.weight ?? 0) < 0) ?? 0 < 0)?.text ??
"";

if ((styleType != "none") && style && (style != "none") && (style != "enhance")) {
let extraStyle = style;

if (styleType == "lora") {
extraStyle = LORA_PATTERN.replace("#LORA#", style);
} else if (styleType == "hypernetworks") {
extraStyle = HYPERNETWORK_PATTERN.replace("#HYPERNETWORK#", style);
}

data.prompt = data.prompt + " " + extraStyle;
}

data.steps = steps ?? 20;
data.batch_size = options?.count;
data.save_images = true;
Expand Down
99 changes: 92 additions & 7 deletions packages/stablestudio-plugin-webui/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import {
getImageInfo,
setOptions,
testForHistoryPlugin,
fetchHypernetworks,
fetchEmbeddings,
fetchLoras
} from "./Utilities";

const manifest = {
Expand Down Expand Up @@ -64,6 +67,25 @@ const webuiUpscalers = [
},
];

const webuiStyleTypes = [
{
label: "None",
value: "none",
},
{
label: "LoRA",
value: "lora",
},
{
label: "Hypernetworks",
value: "hypernetworks",
},
{
label: "Textual Inversions",
value: "embeddings",
},
];

const getNumber = (strValue: string | null, defaultValue: number) => {
let retValue = defaultValue;

Expand All @@ -80,6 +102,7 @@ export const createPlugin = StableStudio.createPlugin<{
baseUrl: StableStudio.PluginSettingString;
upscaler: StableStudio.PluginSettingString;
historyImagesCount: StableStudio.PluginSettingNumber;
styleType: StableStudio.PluginSettingString;
};
}>(({ set, get }) => {
const webuiLoad = (
Expand All @@ -93,6 +116,7 @@ export const createPlugin = StableStudio.createPlugin<{
| "getStableDiffusionDefaultCount"
| "getStableDiffusionDefaultInput"
| "getStableDiffusionExistingImages"
| "getStableDiffusionStyles"
> => {
webuiHostUrl = webuiHostUrl ?? "http://127.0.0.1:7861";

Expand Down Expand Up @@ -138,7 +162,8 @@ export const createPlugin = StableStudio.createPlugin<{
const data = await constructPayload(
options,
isUpscale,
get().settings.upscaler.value
get().settings.upscaler.value,
get().settings.styleType.value
);

// Send payload to webui
Expand Down Expand Up @@ -273,6 +298,45 @@ export const createPlugin = StableStudio.createPlugin<{
}));
},

getStableDiffusionStyles: async () => {
const styleType = localStorage.getItem("styleType") ?? "None";

const styles: any = [];

if (styleType === "lora") {
const loras = await fetchLoras(webuiHostUrl);

loras.forEach((lora: any) => {
styles.push({
id: lora["name"],
name: lora["name"],
});
})
}
if (styleType === "hypernetworks") {
const hypernetworks = await fetchHypernetworks(webuiHostUrl);

hypernetworks.forEach((hypernetwork: any) => {
styles.push({
id: hypernetwork["name"],
name: hypernetwork["name"],
});
})
}
if (styleType === "embeddings") {
const embeddings = await fetchEmbeddings(webuiHostUrl);

for (const key in embeddings.loaded) {
styles.push({
id: key,
name: key,
});
}
}

return styles;
},

getStableDiffusionExistingImages: async () => {
const existingImagesResponse = await fetch(
`${webuiHostUrl}/StableStudio/get-generated-images`,
Expand All @@ -293,14 +357,22 @@ export const createPlugin = StableStudio.createPlugin<{

const responseData = await existingImagesResponse.json();

const images = [];
const promptedImages: any = {};

for (let i = 0; i < responseData.length; i++) {
const imageInfo = await getImageInfo(
webuiHostUrl,
responseData[i].content
);

let images = promptedImages[imageInfo["prompt"]];

if (!images) {
images = [];

promptedImages[imageInfo["prompt"]] = images;
}

const blob = await base64ToBlob(responseData[i].content, "image/jpeg");

const timestampInSeconds = responseData[i].create_date;
Expand Down Expand Up @@ -329,13 +401,17 @@ export const createPlugin = StableStudio.createPlugin<{

images.push(stableDiffusionImage);
}

const ret = [];

return [
{
for (const key in promptedImages) {
ret.push({
id: `${Math.random() * 10000000}`,
images: images,
},
];
images: promptedImages[key],
});
}

return ret;
},

settings: {
Expand Down Expand Up @@ -367,6 +443,13 @@ export const createPlugin = StableStudio.createPlugin<{
variant: "slider",
value: getNumber(localStorage.getItem("historyImagesCount"), 20),
},

styleType: {
type: "string",
title: "Style Type",
options: webuiStyleTypes,
value: localStorage.getItem("styleType") ?? "None"
}
},

setSetting: (key, value) => {
Expand All @@ -384,6 +467,8 @@ export const createPlugin = StableStudio.createPlugin<{
localStorage.setItem("upscaler1", value);
} else if (key === "historyImagesCount" && typeof value === "number") {
localStorage.setItem("historyImagesCount", value.toString());
} else if (key === "styleType" && typeof value === "string") {
localStorage.setItem("styleType", value);
}
},

Expand Down