Skip to content

Commit

Permalink
add rocm support
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyberhan123 committed Jan 21, 2024
1 parent 66dbdf4 commit cbd618a
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .idea/deployment.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ See `deps` folder for dylib compatibility, push request is welcome.

Windows NVIDIA GPU users may need to check the [CUDA architecture](https://developer.nvidia.com/cuda-gpus) of their card for more information.

| platform | x32 | x64 | arm | AMD/ROCM | NVIDIA/CUDA |
|----------|-------------|-------------------------|-------------|-------------|----------------|
| windows | not support | support avx/avx2/avx512 | not support | not support | cuda12 support |
| linux | not support | support | not support | not support | not support |
| darwin | not support | support | support | not support | not support |
Windows AMD/ROCm GPU users may need to check the [system requirements](https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html) for more information.

| platform | x32 | x64 | arm | AMD/ROCM | NVIDIA/CUDA |
|----------|-------------|-------------------------|-------------|-----------------|----------------|
| windows | not supported | supports avx/avx2/avx512 | not supported | rocm5.5 supported | cuda12 supported |
| linux | not supported | supported | not supported | not supported | not supported |
| darwin | not supported | supported | supported | not supported | not supported |

## AutoModel Dynamic Libraries Disclaimer

Expand Down
Binary file added deps/windows/sd-abi_rocm5.5.dll
Binary file not shown.
16 changes: 13 additions & 3 deletions embed_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ var libStableDiffusionAvx512 []byte
// libStableDiffusionCuda12 holds the contents of deps/windows/sd-abi_cuda12.dll,
// embedded into the binary at build time.
//
//go:embed deps/windows/sd-abi_cuda12.dll
var libStableDiffusionCuda12 []byte

// libStableDiffusionRocm5 holds the contents of deps/windows/sd-abi_rocm5.5.dll,
// embedded into the binary at build time.
//
//go:embed deps/windows/sd-abi_rocm5.5.dll
var libStableDiffusionRocm5 []byte

// libName is a file name pattern containing a "*" wildcard.
// NOTE(review): presumably used as a temp-file pattern when the embedded DLL
// is written to disk before loading — confirm against the caller.
var libName = "stable-diffusion-*.dll"

func getDl(gpu bool) []byte {
Expand All @@ -36,14 +39,21 @@ func getDl(gpu bool) []byte {
if err != nil {
log.Println(err)
}
driver := info.Cuda()
log.Print("get gpu info: ", driver.Name)
cuda := info.Cuda()
rocm := info.ROCm()

if driver.Available() {
if cuda.Available() {
log.Print("get gpu info: ", cuda.Name)
log.Println("Use GPU CUDA instead.")
return libStableDiffusionCuda12
}

if rocm.Available() {
	// Report the ROCm device that was detected; this branch is only reached
	// when CUDA is unavailable, so logging cuda.Name here would be misleading.
	log.Print("get gpu info: ", rocm.Name)
	log.Println("Use GPU ROCm instead.")
	return libStableDiffusionRocm5
}

log.Println("GPU not support, use CPU instead.")
}

Expand Down
44 changes: 44 additions & 0 deletions sd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,50 @@ func TestNewStableDiffusionAutoModelPredict(t *testing.T) {
}
}

// TestModel_ROCm enables the GPU option, loads the miniSD checkpoint and runs
// one 256x256 prediction, handing ./assets/love_cat2.png to Predict as the
// output writer.
func TestModel_ROCm(t *testing.T) {
	opts := sd.DefaultOptions
	opts.GpuEnable = true
	t.Log(opts)

	model, err := sd.NewAutoModel(opts)
	if err != nil {
		t.Error(err)
		return
	}
	defer model.Close()

	// Route the library's log messages into the test output.
	model.SetLogCallback(func(level sd.LogLevel, msg string) {
		t.Log(msg)
	})

	if err = model.LoadFromFile("./models/miniSD.ckpt"); err != nil {
		t.Error(err)
		return
	}

	// One output file per generated image; each stays open until the test returns.
	outputs := []string{"./assets/love_cat2.png"}
	var writers []io.Writer
	for _, path := range outputs {
		f, createErr := os.Create(path)
		if createErr != nil {
			t.Error(createErr)
			return
		}
		defer f.Close()
		writers = append(writers, f)
	}

	params := sd.DefaultFullParams
	params.BatchCount = 1
	params.Width = 256
	params.Height = 256
	params.NegativePrompt = ""

	if err = model.Predict("british short hair cat, high quality", params, writers); err != nil {
		t.Error(err)
	}
}

func TestNewStableDiffusionAutoModelImagePredict(t *testing.T) {
options := sd.DefaultOptions
options.VaeDecodeOnly = false
Expand Down

0 comments on commit cbd618a

Please sign in to comment.