Skip to content

Commit

Permalink
Rename SDVersion -> Arch
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Sep 10, 2024
1 parent 898a618 commit caf1e93
Show file tree
Hide file tree
Showing 25 changed files with 494 additions and 524 deletions.
2 changes: 1 addition & 1 deletion ai_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generative AI plugin for Krita using Stable Diffusion"""

__version__ = "1.23.0"
__version__ = "1.24.0"

import importlib.util

Expand Down
12 changes: 6 additions & 6 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math

from .image import Bounds, Extent, Image, ImageCollection, ImageFileFormat
from .resources import ControlMode, SDVersion
from .resources import ControlMode, Arch
from .util import ensure, clamp


Expand Down Expand Up @@ -55,7 +55,7 @@ def from_dict(data: dict[str, Any]):
@dataclass
class CheckpointInput:
checkpoint: str
version: SDVersion = SDVersion.sd15
version: Arch = Arch.sd15
vae: str = ""
loras: list[LoraInput] = field(default_factory=list)
clip_skip: int = 0
Expand Down Expand Up @@ -211,12 +211,12 @@ def cost_factor(batch: int, extent: Extent, steps: int):
return base + round((10 * cost) / unit)


def _base_cost(version: SDVersion):
if version is SDVersion.sd15:
def _base_cost(arch: Arch):
if arch is Arch.sd15:
return 1
if version is SDVersion.sdxl:
if arch is Arch.sdxl:
return 2
if version is SDVersion.flux:
if arch is Arch.flux:
return 4
return 1

Expand Down
84 changes: 41 additions & 43 deletions ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .files import FileLibrary, FileFormat
from .style import Style, Styles
from .settings import PerformanceSettings
from .resources import ControlMode, ResourceKind, SDVersion, UpscalerName
from .resources import ControlMode, ResourceKind, Arch, UpscalerName
from .resources import ResourceId, resource_id
from .localization import translate as _
from .util import client_logger as log
Expand Down Expand Up @@ -70,7 +70,7 @@ def parse(data: dict):

class CheckpointInfo(NamedTuple):
filename: str
sd_version: SDVersion
arch: Arch
format: FileFormat = FileFormat.checkpoint

@property
Expand All @@ -79,9 +79,7 @@ def name(self):

@staticmethod
def deduce_from_filename(filename: str):
return CheckpointInfo(
filename, SDVersion.from_checkpoint_name(filename), FileFormat.checkpoint
)
return CheckpointInfo(filename, Arch.from_checkpoint_name(filename), FileFormat.checkpoint)


class ClientModels:
Expand All @@ -99,9 +97,9 @@ def __init__(self) -> None:
self.resources = {}

def resource(
self, kind: ResourceKind, identifier: ControlMode | UpscalerName | str, version: SDVersion
self, kind: ResourceKind, identifier: ControlMode | UpscalerName | str, arch: Arch
):
id = ResourceId(kind, version, identifier)
id = ResourceId(kind, arch, identifier)
model = self.find(id)
if model is None:
raise Exception(f"{id.name} not found")
Expand All @@ -110,92 +108,92 @@ def resource(
def find(self, id: ResourceId):
if result := self.resources.get(id.string):
return result
return self.resources.get(id._replace(version=SDVersion.all).string)
return self.resources.get(id._replace(arch=Arch.all).string)

def version_of(self, checkpoint: str):
def arch_of(self, checkpoint: str):
if info := self.checkpoints.get(checkpoint):
return info.sd_version
return SDVersion.from_checkpoint_name(checkpoint)
return info.arch
return Arch.from_checkpoint_name(checkpoint)

def for_version(self, version: SDVersion):
return ModelDict(self, ResourceKind.upscaler, version)
def for_arch(self, arch: Arch):
return ModelDict(self, ResourceKind.upscaler, arch)

def for_checkpoint(self, checkpoint: str):
return self.for_version(self.version_of(checkpoint))
return self.for_arch(self.arch_of(checkpoint))

@property
def upscale(self):
return ModelDict(self, ResourceKind.upscaler, SDVersion.all)
return ModelDict(self, ResourceKind.upscaler, Arch.all)

@property
def default_upscaler(self):
return self.resource(ResourceKind.upscaler, UpscalerName.default, SDVersion.all)
return self.resource(ResourceKind.upscaler, UpscalerName.default, Arch.all)


class ModelDict:
"""Provides access to filtered list of models matching a certain SD version."""
"""Provides access to filtered list of models matching a certain Diffusion base model."""

_models: ClientModels
kind: ResourceKind
version: SDVersion
arch: Arch

def __init__(self, models: ClientModels, kind: ResourceKind, version: SDVersion):
def __init__(self, models: ClientModels, kind: ResourceKind, arch: Arch):
self._models = models
self.kind = kind
self.version = version
self.arch = arch

def __getitem__(self, key: ControlMode | UpscalerName | str):
return self._models.resource(self.kind, key, self.version)
return self._models.resource(self.kind, key, self.arch)

def find(self, key: ControlMode | UpscalerName | str, allow_universal=False) -> str | None:
if key in [ControlMode.style, ControlMode.composition]:
key = ControlMode.reference # Same model with different weight types
result = self._models.resources.get(resource_id(self.kind, self.version, key))
result = self._models.resources.get(resource_id(self.kind, self.arch, key))
if result is None and allow_universal and isinstance(key, ControlMode):
result = self.find(ControlMode.universal)
return result

def for_version(self, version: SDVersion):
return ModelDict(self._models, self.kind, version)
def for_version(self, arch: Arch):
return ModelDict(self._models, self.kind, arch)

@property
def text_encoder(self):
return ModelDict(self._models, ResourceKind.text_encoder, self.version)
return ModelDict(self._models, ResourceKind.text_encoder, self.arch)

@property
def clip_vision(self):
return self._models.resource(ResourceKind.clip_vision, "ip_adapter", SDVersion.all)
return self._models.resource(ResourceKind.clip_vision, "ip_adapter", Arch.all)

@property
def upscale(self):
return ModelDict(self._models, ResourceKind.upscaler, SDVersion.all)
return ModelDict(self._models, ResourceKind.upscaler, Arch.all)

@property
def control(self):
return ModelDict(self._models, ResourceKind.controlnet, self.version)
return ModelDict(self._models, ResourceKind.controlnet, self.arch)

@property
def ip_adapter(self):
return ModelDict(self._models, ResourceKind.ip_adapter, self.version)
return ModelDict(self._models, ResourceKind.ip_adapter, self.arch)

@property
def inpaint(self):
return ModelDict(self._models, ResourceKind.inpaint, SDVersion.all)
return ModelDict(self._models, ResourceKind.inpaint, Arch.all)

@property
def lora(self):
return ModelDict(self._models, ResourceKind.lora, self.version)
return ModelDict(self._models, ResourceKind.lora, self.arch)

@property
def vae(self):
return self._models.resource(ResourceKind.vae, "default", self.version)
return self._models.resource(ResourceKind.vae, "default", self.arch)

@property
def fooocus_inpaint(self):
assert self.version is SDVersion.sdxl
assert self.arch is Arch.sdxl
return dict(
head=self._models.resource(ResourceKind.inpaint, "fooocus_head", SDVersion.sdxl),
patch=self._models.resource(ResourceKind.inpaint, "fooocus_patch", SDVersion.sdxl),
head=self._models.resource(ResourceKind.inpaint, "fooocus_head", Arch.sdxl),
patch=self._models.resource(ResourceKind.inpaint, "fooocus_patch", Arch.sdxl),
)

@property
Expand All @@ -208,10 +206,10 @@ def node_inputs(self):

@property
def has_te_vae(self):
if self._models.find(ResourceId(ResourceKind.vae, self.version, "default")) is None:
if self._models.find(ResourceId(ResourceKind.vae, self.arch, "default")) is None:
return False
for te in self.version.text_encoders:
if self._models.find(ResourceId(ResourceKind.text_encoder, self.version, te)) is None:
for te in self.arch.text_encoders:
if self._models.find(ResourceId(ResourceKind.text_encoder, self.arch, te)) is None:
return False
return True

Expand Down Expand Up @@ -263,7 +261,7 @@ async def disconnect(self):
def user(self) -> User | None:
return None

def supports_version(self, version: SDVersion) -> bool:
def supports_arch(self, arch: Arch) -> bool:
return True

@property
Expand Down Expand Up @@ -292,10 +290,10 @@ async def __aexit__(self, exc_type, exc_value, traceback):
await self.disconnect()


def resolve_sd_version(style: Style, client: Client | None = None):
if style.sd_version is SDVersion.auto:
def resolve_arch(style: Style, client: Client | None = None):
if style.sd_version is Arch.auto:
if client and style.sd_checkpoint in client.models.checkpoints:
return client.models.version_of(style.sd_checkpoint)
return client.models.arch_of(style.sd_checkpoint)
return style.sd_version.resolve(style.sd_checkpoint)
return style.sd_version

Expand All @@ -305,7 +303,7 @@ def filter_supported_styles(styles: Iterable[Style], client: Client | None = Non
return [
style
for style in styles
if client.supports_version(resolve_sd_version(style, client))
if client.supports_arch(resolve_arch(style, client))
and style.sd_checkpoint in client.models.checkpoints
]
return list(styles)
Expand Down
80 changes: 39 additions & 41 deletions ai_diffusion/cloud_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .image import Extent, ImageCollection
from .network import RequestManager, NetworkError
from .files import FileLibrary, File
from .resources import SDVersion
from .resources import Arch
from .settings import PerformanceSettings, settings
from .localization import translate as _
from .util import clamp, ensure, client_logger as log
Expand Down Expand Up @@ -336,20 +336,18 @@ def _base64_size(size: int):

models = ClientModels()
models.checkpoints = {
"dreamshaper_8.safetensors": CheckpointInfo("dreamshaper_8.safetensors", SDVersion.sd15),
"dreamshaper_8.safetensors": CheckpointInfo("dreamshaper_8.safetensors", Arch.sd15),
"realisticVisionV51_v51VAE.safetensors": CheckpointInfo(
"realisticVisionV51_v51VAE.safetensors", SDVersion.sd15
"realisticVisionV51_v51VAE.safetensors", Arch.sd15
),
"flat2DAnimerge_v45Sharp.safetensors": CheckpointInfo(
"flat2DAnimerge_v45Sharp.safetensors", SDVersion.sd15
"flat2DAnimerge_v45Sharp.safetensors", Arch.sd15
),
"juggernautXL_version6Rundiffusion.safetensors": CheckpointInfo(
"juggernautXL_version6Rundiffusion.safetensors", SDVersion.sdxl
),
"zavychromaxl_v80.safetensors": CheckpointInfo("zavychromaxl_v80.safetensors", SDVersion.sdxl),
"flux1-schnell-fp8.safetensors": CheckpointInfo(
"flux1-schnell-fp8.safetensors", SDVersion.flux
"juggernautXL_version6Rundiffusion.safetensors", Arch.sdxl
),
"zavychromaxl_v80.safetensors": CheckpointInfo("zavychromaxl_v80.safetensors", Arch.sdxl),
"flux1-schnell-fp8.safetensors": CheckpointInfo("flux1-schnell-fp8.safetensors", Arch.flux),
}
models.vae = []
models.loras = [
Expand All @@ -366,37 +364,37 @@ def _base64_size(size: int):
# fmt: off
from ai_diffusion.resources import resource_id, ResourceKind, ControlMode, UpscalerName
models.resources = {
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.inpaint): "control_v11p_sd15_inpaint_fp16.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sdxl, ControlMode.universal): "xinsir-controlnet-union-sdxl-1.0-promax.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.scribble): "control_lora_rank128_v11p_sd15_scribble_fp16.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.line_art): "control_v11p_sd15_lineart_fp16.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.soft_edge): "control_v11p_sd15_softedge_fp16.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.canny_edge): "control_v11p_sd15_canny_fp16.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.depth): "control_lora_rank128_v11f1p_sd15_depth_fp16.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.normal): None,
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.pose): "control_lora_rank128_v11p_sd15_openpose_fp16.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.segmentation): None,
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.blur):"control_lora_rank128_v11f1e_sd15_tile_fp16.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.stencil): "control_v1p_sd15_qrcode_monster.safetensors",
resource_id(ResourceKind.controlnet, SDVersion.sd15, ControlMode.hands): None,
resource_id(ResourceKind.controlnet, SDVersion.sdxl, ControlMode.hands): None,
resource_id(ResourceKind.ip_adapter, SDVersion.sd15, ControlMode.reference): "ip-adapter_sd15.safetensors",
resource_id(ResourceKind.ip_adapter, SDVersion.sdxl, ControlMode.reference): "ip-adapter_sdxl_vit-h.safetensors",
resource_id(ResourceKind.ip_adapter, SDVersion.sd15, ControlMode.face): None,
resource_id(ResourceKind.ip_adapter, SDVersion.sdxl, ControlMode.face): None,
resource_id(ResourceKind.clip_vision, SDVersion.all, "ip_adapter"): "clip-vision_vit-h.safetensors",
resource_id(ResourceKind.lora, SDVersion.sd15, "lcm"): "lcm-lora-sdv1-5.safetensors",
resource_id(ResourceKind.lora, SDVersion.sdxl, "lcm"): "lcm-lora-sdxl.safetensors",
resource_id(ResourceKind.lora, SDVersion.sd15, "hyper"): "Hyper-SD15-8steps-CFG-lora.safetensors",
resource_id(ResourceKind.lora, SDVersion.sdxl, "hyper"): "Hyper-SDXL-8steps-CFG-lora.safetensors",
resource_id(ResourceKind.lora, SDVersion.sd15, ControlMode.face): None,
resource_id(ResourceKind.lora, SDVersion.sdxl, ControlMode.face): None,
resource_id(ResourceKind.upscaler, SDVersion.all, UpscalerName.default): UpscalerName.default.value,
resource_id(ResourceKind.upscaler, SDVersion.all, UpscalerName.fast_2x): UpscalerName.fast_2x.value,
resource_id(ResourceKind.upscaler, SDVersion.all, UpscalerName.fast_3x): UpscalerName.fast_3x.value,
resource_id(ResourceKind.upscaler, SDVersion.all, UpscalerName.fast_4x): UpscalerName.fast_4x.value,
resource_id(ResourceKind.inpaint, SDVersion.sdxl, "fooocus_head"): "fooocus_inpaint_head.pth",
resource_id(ResourceKind.inpaint, SDVersion.sdxl, "fooocus_patch"): "inpaint_v26.fooocus.patch",
resource_id(ResourceKind.inpaint, SDVersion.all, "default"): "MAT_Places512_G_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.inpaint): "control_v11p_sd15_inpaint_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.universal): "xinsir-controlnet-union-sdxl-1.0-promax.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.scribble): "control_lora_rank128_v11p_sd15_scribble_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.line_art): "control_v11p_sd15_lineart_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.soft_edge): "control_v11p_sd15_softedge_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.canny_edge): "control_v11p_sd15_canny_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.depth): "control_lora_rank128_v11f1p_sd15_depth_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.normal): None,
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.pose): "control_lora_rank128_v11p_sd15_openpose_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.segmentation): None,
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.blur):"control_lora_rank128_v11f1e_sd15_tile_fp16.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.stencil): "control_v1p_sd15_qrcode_monster.safetensors",
resource_id(ResourceKind.controlnet, Arch.sd15, ControlMode.hands): None,
resource_id(ResourceKind.controlnet, Arch.sdxl, ControlMode.hands): None,
resource_id(ResourceKind.ip_adapter, Arch.sd15, ControlMode.reference): "ip-adapter_sd15.safetensors",
resource_id(ResourceKind.ip_adapter, Arch.sdxl, ControlMode.reference): "ip-adapter_sdxl_vit-h.safetensors",
resource_id(ResourceKind.ip_adapter, Arch.sd15, ControlMode.face): None,
resource_id(ResourceKind.ip_adapter, Arch.sdxl, ControlMode.face): None,
resource_id(ResourceKind.clip_vision, Arch.all, "ip_adapter"): "clip-vision_vit-h.safetensors",
resource_id(ResourceKind.lora, Arch.sd15, "lcm"): "lcm-lora-sdv1-5.safetensors",
resource_id(ResourceKind.lora, Arch.sdxl, "lcm"): "lcm-lora-sdxl.safetensors",
resource_id(ResourceKind.lora, Arch.sd15, "hyper"): "Hyper-SD15-8steps-CFG-lora.safetensors",
resource_id(ResourceKind.lora, Arch.sdxl, "hyper"): "Hyper-SDXL-8steps-CFG-lora.safetensors",
resource_id(ResourceKind.lora, Arch.sd15, ControlMode.face): None,
resource_id(ResourceKind.lora, Arch.sdxl, ControlMode.face): None,
resource_id(ResourceKind.upscaler, Arch.all, UpscalerName.default): UpscalerName.default.value,
resource_id(ResourceKind.upscaler, Arch.all, UpscalerName.fast_2x): UpscalerName.fast_2x.value,
resource_id(ResourceKind.upscaler, Arch.all, UpscalerName.fast_3x): UpscalerName.fast_3x.value,
resource_id(ResourceKind.upscaler, Arch.all, UpscalerName.fast_4x): UpscalerName.fast_4x.value,
resource_id(ResourceKind.inpaint, Arch.sdxl, "fooocus_head"): "fooocus_inpaint_head.pth",
resource_id(ResourceKind.inpaint, Arch.sdxl, "fooocus_patch"): "inpaint_v26.fooocus.patch",
resource_id(ResourceKind.inpaint, Arch.all, "default"): "MAT_Places512_G_fp16.safetensors",
}
# fmt: on
Loading

0 comments on commit caf1e93

Please sign in to comment.