Skip to content

Commit

Permalink
Upgrade to latest version and include BiRefNet General-Lite-2K model
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitribarbot committed Sep 28, 2024
1 parent ba99d44 commit 9f02102
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The available models are:

- General: A pre-trained model for general use cases.
- General-Lite: A light pre-trained model for general use cases.
- General-Lite-2K: A light pre-trained model for general use cases in high resolution (2560x1440).
- Portrait: A pre-trained model for human portraits.
- DIS: A pre-trained model for dichotomous image segmentation (DIS).
- HRSOD: A pre-trained model for high-resolution salient object detection (HRSOD).
Expand All @@ -32,6 +33,7 @@ Model files go here (automatically downloaded if the folder is not present durin
If necessary, they can be downloaded from:
- [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors)`model.safetensors` must be renamed `General.safetensors`
- [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors)`model.safetensors` must be renamed `General-Lite.safetensors`
- [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors)`model.safetensors` must be renamed `General-Lite-2K.safetensors`
- [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors)`model.safetensors` must be renamed `Portrait.safetensors`
- [DIS](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K/resolve/main/model.safetensors)`model.safetensors` must be renamed `DIS.safetensors`
- [HRSOD](https://huggingface.co/ZhengPeng7/BiRefNet-HRSOD/resolve/main/model.safetensors)`model.safetensors` must be renamed `HRSOD.safetensors`
Expand All @@ -48,7 +50,7 @@ Both endpoints share these parameters:
- `return_mask`: whether to return mask (can be used for inpainting).
- `return_edge_mask`: whether to return edge mask (can be used to blend foreground image with another background).
- `edge_mask_width`: edge mask width in pixels. Default to 64.
- `model_name`: `General`, `General-Lite`, `Portrait`, `DIS`, `HRSOD`, `COD` or `DIS-TR_TEs`. BiRefNet model to be used. Default to `General`.
- `model_name`: `General`, `General-Lite`, `General-Lite-2K`, `Portrait`, `DIS`, `HRSOD`, `COD` or `DIS-TR_TEs`. BiRefNet model to be used. Default to `General`.
- `output_dir`: directory to save output images.
- `output_extension`: output image file extension (without leading dot, `png` by default).
- `device_id`: GPU device id.
Expand Down
19 changes: 14 additions & 5 deletions birefnet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,29 @@

class Config():
def __init__(self, bb_index: int = 6) -> None:
# PATH settings
# PATH settings
# Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
# self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][1] # Default, custom
# self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][0] # Default, custom
# self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')

# TASK settings
self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0]
self.validation_set = {
'DIS5K': [],
'COD': [],
'HRSOD': [],
'General': ['DIS-VD', 'TE-P3M-500-NP'],
'General-2K': ['DIS-VD', 'TE-P3M-500-NP'],
'Matting': ['TE-P3M-500-NP'],
}[self.task]
# datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.validation_set])
self.training_set = {
'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
'COD': 'TR-COD10K+TR-CAMO',
'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]), # leave DIS-VD,TE-P3M-500-NP for evaluation.
'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]),
'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646',
'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # datasets_all
'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # datasets_all
'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', # datasets_all
}[self.task]
self.prompt4loc = ['dense', 'sparse'][0]

Expand Down
5 changes: 3 additions & 2 deletions internal_birefnet/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
usage_to_weights_file = {
"General": "BiRefNet",
"General-Lite": "BiRefNet_T",
"General-Lite-2K": "BiRefNet_lite-2K",
"Portrait": "BiRefNet-portrait",
"DIS": "BiRefNet-DIS5K",
"HRSOD": "BiRefNet-HRSOD",
Expand All @@ -33,7 +34,7 @@
}

BiRefNetModelName = Literal[
"General", "General-Lite", "Portrait", "DIS", "HRSOD", "COD", "DIS-TR_TEs"
"General", "General-Lite", "General-Lite-2K", "Portrait", "DIS", "HRSOD", "COD", "DIS-TR_TEs"
]


Expand Down Expand Up @@ -109,7 +110,7 @@ def __init__(

state_dict = safetensors.torch.load_file(weight_path, device=self.device)

bb_index = 3 if model_name == "General-Lite" else 6
bb_index = 3 if model_name == "General-Lite" or model_name == "General-Lite-2K" else 6

self.birefnet = BiRefNet(bb_pretrained=False, bb_index=bb_index)
self.birefnet.load_state_dict(state_dict)
Expand Down
1 change: 1 addition & 0 deletions scripts/postprocessing_birefnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"None",
"General",
"General-Lite",
"General-Lite-2K",
"Portrait",
"DIS",
"HRSOD",
Expand Down

0 comments on commit 9f02102

Please sign in to comment.