Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rembg dependency and add BiRefNet models in the list of available RemBG models #47

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__pycache__
.vscode
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ Extension for [webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui).

Find the UI for rembg in the Extras tab after installing the extension.

You'll be able to select the following models:
- "Classic" models: inference is fast, result is good.
- "BiRefNet" models: inference may take more time but result might be better.

# Installation

Install from webui's Extensions tab.
Expand Down
16 changes: 14 additions & 2 deletions install.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
import launch
from importlib import metadata


rembg_expected_version = "2.0.59"


try:
rembg_installed_version = metadata.version("rembg")
except Exception:
rembg_installed_version = None


if rembg_installed_version != rembg_expected_version:
launch.run_pip(f"install rembg=={rembg_expected_version} --no-deps", "rembg")

if not launch.is_installed("rembg"):
launch.run_pip("install rembg==2.0.50 --no-deps", "rembg")

for dep in ['onnxruntime', 'pymatting', 'pooch']:
if not launch.is_installed(dep):
Expand Down
12 changes: 10 additions & 2 deletions scripts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@
import rembg

# models = [
# "None",
# "u2net",
# "u2netp",
# "u2net_human_seg",
# "u2net_cloth_seg",
# "silueta",
# "isnet-general-use",
# "isnet-anime",
# "birefnet-general",
# "birefnet-general-lite",
# "birefnet-portrait",
# "birefnet-dis",
# "birefnet-hrsod",
# "birefnet-cod",
# "birefnet-massive",
# ]


Expand All @@ -34,7 +42,7 @@ async def rembg_remove(

image = rembg.remove(
input_image,
session=rembg.new_session(model),
session=rembg.new_session(model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']),
only_mask=return_mask,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
Expand Down
24 changes: 21 additions & 3 deletions scripts/postprocessing_rembg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from modules import scripts_postprocessing, ui_components
import gradio as gr

from modules.ui_components import FormRow
from modules.paths_internal import models_path
import rembg
import os

models = [
"None",
"isnet-general-use",
"u2net",
"u2netp",
"u2net_human_seg",
Expand All @@ -18,13 +16,27 @@
"isnet-anime",
]

birefnet_models = [
"None",
"birefnet-general",
"birefnet-general-lite",
"birefnet-portrait",
"birefnet-dis",
"birefnet-hrsod",
"birefnet-cod",
"birefnet-massive",
]

class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
name = "Rembg"
order = 20000
model = None

def ui(self):
with ui_components.InputAccordion(False, label="Remove background") as enable:
with gr.Row():
model_type = gr.Radio(["Classic", "BiRefNet"], value="Classic", label="Model type")

with gr.Row():
model = gr.Dropdown(label="Remove background", choices=models, value="None")
return_mask = gr.Checkbox(label="Return mask", value=False)
Expand All @@ -35,6 +47,12 @@ def ui(self):
alpha_matting_foreground_threshold = gr.Slider(label="Foreground threshold", minimum=0, maximum=255, step=1, value=240)
alpha_matting_background_threshold = gr.Slider(label="Background threshold", minimum=0, maximum=255, step=1, value=10)

model_type.change(
fn=lambda x: gr.update(value="None", choices=birefnet_models if x == "BiRefNet" else models),
inputs=[model_type],
outputs=[model],
)

alpha_matting.change(
fn=lambda x: gr.update(visible=x),
inputs=[alpha_matting],
Expand Down Expand Up @@ -63,7 +81,7 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, model,

pp.image = rembg.remove(
pp.image,
session=rembg.new_session(model),
session=rembg.new_session(model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']),
only_mask=return_mask,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
Expand Down