Skip to content

Commit

Permalink
feat: add flux support (#356)
Browse files Browse the repository at this point in the history
* add flux support

* avoid build failures in non-CUDA environments

* fix schnell support

* add k quants support

* add support for applying lora to quantized tensors

* add inplace conversion support for f8_e4m3 (#359)

in the same way it is done for bf16
like how bf16 converts losslessly to fp32,
f8_e4m3 converts losslessly to fp16

* add xlabs flux comfy converted lora support

* update docs

---------

Co-authored-by: Erik Scholz <[email protected]>
  • Loading branch information
leejet and Green-Sky authored Aug 24, 2024
1 parent 697d000 commit 64d231f
Show file tree
Hide file tree
Showing 25 changed files with 1,886 additions and 172 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- Super lightweight and without external dependencies
- SD1.x, SD2.x, SDXL and SD3 support
- !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
- [Flux-dev/Flux-schnell Support](./docs/flux.md)

- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
- 16-bit, 32-bit float support
- 4-bit, 5-bit and 8-bit integer quantization support
- 2-bit, 3-bit, 4-bit, 5-bit and 8-bit integer quantization support
- Accelerated memory-efficient CPU inference
- Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB.
- AVX, AVX2 and AVX512 support for x86 architectures
Expand Down Expand Up @@ -57,7 +58,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- The current implementation of ggml_conv_2d is slow and has high memory usage
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
- [ ] Implement Inpainting support
- [ ] k-quants support

## Usage

Expand Down Expand Up @@ -202,7 +202,7 @@ arguments:
--normalize-input normalize PHOTOMAKER input id images
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.
--upscale-repeats Run the ESRGAN upscaler this many times (default 1)
--type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
--type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)
If not specified, the default is the type of the weight file.
--lora-model-dir [DIR] lora model directory
-i, --init-img [IMAGE] path to the input image, required by img2img
Expand All @@ -229,7 +229,7 @@ arguments:
--vae-tiling process vae in tiles to reduce memory usage
--control-net-cpu keep controlnet in cpu (for low vram)
--canny apply canny preprocessor (edge detection)
--color colors the logging tags according to level
--color Colors the logging tags according to level
-v, --verbose print extra info
```
Expand All @@ -240,6 +240,7 @@ arguments:
# ./bin/sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat"
# ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v
# ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v
# ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v
```

Using formats of different precisions will yield results of varying quality.
Expand Down
Binary file added assets/flux/flux1-dev-q2_k.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux/flux1-dev-q3_k.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux/flux1-dev-q4_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux/flux1-dev-q8_0 with lora.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux/flux1-dev-q8_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux/flux1-schnell-q8_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class SpatialTransformer : public GGMLBlock {
int64_t n_head;
int64_t d_head;
int64_t depth = 1; // 1
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_2_x
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2

public:
SpatialTransformer(int64_t in_channels,
Expand Down
256 changes: 241 additions & 15 deletions conditioner.hpp

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
*/
class ControlNetBlock : public GGMLBlock {
protected:
SDVersion version = VERSION_1_x;
SDVersion version = VERSION_SD1;
// network hparams
int in_channels = 4;
int out_channels = 4;
Expand All @@ -26,19 +26,19 @@ class ControlNetBlock : public GGMLBlock {
int time_embed_dim = 1280; // model_channels*4
int num_heads = 8;
int num_head_channels = -1; // channels // num_heads
int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL

public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_XL
int adm_in_channels = 2816; // only for VERSION_SDXL

ControlNetBlock(SDVersion version = VERSION_1_x)
ControlNetBlock(SDVersion version = VERSION_SD1)
: version(version) {
if (version == VERSION_2_x) {
if (version == VERSION_SD2) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_XL) {
} else if (version == VERSION_SDXL) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
Expand All @@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock {
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

if (version == VERSION_XL || version == VERSION_SVD) {
if (version == VERSION_SDXL || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
Expand Down Expand Up @@ -307,7 +307,7 @@ class ControlNetBlock : public GGMLBlock {
};

struct ControlNet : public GGMLRunner {
SDVersion version = VERSION_1_x;
SDVersion version = VERSION_SD1;
ControlNetBlock control_net;

ggml_backend_buffer_t control_buffer = NULL; // keep control output tensors in backend memory
Expand All @@ -318,7 +318,7 @@ struct ControlNet : public GGMLRunner {

ControlNet(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_1_x)
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, wtype), control_net(version) {
control_net.init(params_ctx, wtype);
}
Expand Down
67 changes: 64 additions & 3 deletions denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py

#define TIMESTEPS 1000
#define FLUX_TIMESTEPS 1000

struct SigmaSchedule {
int version = 0;
Expand Down Expand Up @@ -144,13 +145,13 @@ struct AYSSchedule : SigmaSchedule {
std::vector<float> results(n + 1);

switch (version) {
case VERSION_2_x: /* fallthrough */
case VERSION_SD2: /* fallthrough */
LOG_WARN("AYS not designed for SD2.X models");
case VERSION_1_x:
case VERSION_SD1:
LOG_INFO("AYS using SD1.5 noise levels");
inputs = noise_levels[0];
break;
case VERSION_XL:
case VERSION_SDXL:
LOG_INFO("AYS using SDXL noise levels");
inputs = noise_levels[1];
break;
Expand Down Expand Up @@ -350,6 +351,66 @@ struct DiscreteFlowDenoiser : public Denoiser {
}
};


float flux_time_shift(float mu, float sigma, float t) {
return std::exp(mu) / (std::exp(mu) + std::pow((1.0 / t - 1.0), sigma));
}

struct FluxFlowDenoiser : public Denoiser {
float sigmas[TIMESTEPS];
float shift = 1.15f;

float sigma_data = 1.0f;

FluxFlowDenoiser(float shift = 1.15f) {
set_parameters(shift);
}

void set_parameters(float shift = 1.15f) {
this->shift = shift;
for (int i = 1; i < TIMESTEPS + 1; i++) {
sigmas[i - 1] = t_to_sigma(i/TIMESTEPS * TIMESTEPS);
}
}

float sigma_min() {
return sigmas[0];
}

float sigma_max() {
return sigmas[TIMESTEPS - 1];
}

float sigma_to_t(float sigma) {
return sigma;
}

float t_to_sigma(float t) {
t = t + 1;
return flux_time_shift(shift, 1.0f, t / TIMESTEPS);
}

std::vector<float> get_scalings(float sigma) {
float c_skip = 1.0f;
float c_out = -sigma;
float c_in = 1.0f;
return {c_skip, c_out, c_in};
}

// this function will modify noise/latent
ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) {
ggml_tensor_scale(noise, sigma);
ggml_tensor_scale(latent, 1.0f - sigma);
ggml_tensor_add(latent, noise);
return latent;
}

ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) {
ggml_tensor_scale(latent, 1.0f / (1.0f - sigma));
return latent;
}
};

typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;

// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
Expand Down
58 changes: 56 additions & 2 deletions diffusion_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "mmdit.hpp"
#include "unet.hpp"
#include "flux.hpp"

struct DiffusionModel {
virtual void compute(int n_threads,
Expand All @@ -11,6 +12,7 @@ struct DiffusionModel {
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
int num_video_frames = -1,
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
Expand All @@ -29,7 +31,7 @@ struct UNetModel : public DiffusionModel {

UNetModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_1_x)
SDVersion version = VERSION_SD1)
: unet(backend, wtype, version) {
}

Expand Down Expand Up @@ -63,6 +65,7 @@ struct UNetModel : public DiffusionModel {
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
int num_video_frames = -1,
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
Expand All @@ -77,7 +80,7 @@ struct MMDiTModel : public DiffusionModel {

MMDiTModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_3_2B)
SDVersion version = VERSION_SD3_2B)
: mmdit(backend, wtype, version) {
}

Expand Down Expand Up @@ -111,6 +114,7 @@ struct MMDiTModel : public DiffusionModel {
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
int num_video_frames = -1,
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
Expand All @@ -120,4 +124,54 @@ struct MMDiTModel : public DiffusionModel {
}
};


struct FluxModel : public DiffusionModel {
Flux::FluxRunner flux;

FluxModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_FLUX_DEV)
: flux(backend, wtype, version) {
}

void alloc_params_buffer() {
flux.alloc_params_buffer();
}

void free_params_buffer() {
flux.free_params_buffer();
}

void free_compute_buffer() {
flux.free_compute_buffer();
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
flux.get_param_tensors(tensors, "model.diffusion_model");
}

size_t get_params_buffer_size() {
return flux.get_params_buffer_size();
}

int64_t get_adm_in_channels() {
return 768;
}

void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
int num_video_frames = -1,
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
}
};

#endif
63 changes: 63 additions & 0 deletions docs/flux.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# How to Use

You can run Flux using stable-diffusion.cpp with a GPU that has 6GB or even 4GB of VRAM, without needing to offload to RAM.

## Download weights

- Download flux-dev from https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors
- Download flux-schnell from https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/flux1-schnell.safetensors
- Download vae from https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors
- Download clip_l from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/clip_l.safetensors
- Download t5xxl from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors

## Convert flux weights

Using fp16 will lead to overflow, but ggml's support for bf16 is not yet fully developed. Therefore, we need to convert flux to gguf format here, which also saves VRAM. For example:
```
.\bin\Release\sd.exe -M convert -m ..\..\ComfyUI\models\unet\flux1-dev.sft -o ..\models\flux1-dev-q8_0.gguf -v --type q8_0
```

## Run

- `--cfg-scale` is recommended to be set to 1.

### Flux-dev
For example:

```
.\bin\Release\sd.exe --diffusion-model ..\models\flux1-dev-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v
```

Using formats of different precisions will yield results of varying quality.

| Type | q8_0 | q4_0 | q3_k | q2_k |
|---- | ---- |---- |---- |---- |
| **Memory** | 12068.09 MB | 6394.53 MB | 4888.16 MB | 3735.73 MB |
| **Result** | ![](../assets/flux/flux1-dev-q8_0.png) |![](../assets/flux/flux1-dev-q4_0.png) |![](../assets/flux/flux1-dev-q3_k.png) |![](../assets/flux/flux1-dev-q2_k.png)|



### Flux-schnell


```
.\bin\Release\sd.exe --diffusion-model ..\models\flux1-schnell-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --steps 4
```

| q8_0 |
| ---- |
|![](../assets/flux/flux1-schnell-q8_0.png) |

## Run with LoRA

Since many flux LoRA training libraries have used various LoRA naming formats, it is possible that not all flux LoRA naming formats are supported. It is recommended to use LoRA with naming formats compatible with ComfyUI.

### Flux-dev q8_0 with LoRA

- LoRA model from https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main (using comfy converted version!!!)

```
.\bin\Release\sd.exe --diffusion-model ..\models\flux1-dev-q8_0.gguf --vae ...\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'<lora:realism_lora_comfy_converted:1>" --cfg-scale 1.0 --sampling-method euler -v --lora-model-dir ../models
```

![output](../assets/flux/flux1-dev-q8_0%20with%20lora.png)
Loading

0 comments on commit 64d231f

Please sign in to comment.