Skip to content

Commit

Permalink
Sync BiRefNet with latest BiRefNet repository updates and add Matting…
Browse files Browse the repository at this point in the history
… model
  • Loading branch information
dimitribarbot committed Oct 9, 2024
1 parent 9f02102 commit 1c380a6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 14 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The available models are:
- General-Lite: A light pre-trained model for general use cases.
- General-Lite-2K: A light pre-trained model for general use cases in high resolution (2560x1440).
- Portrait: A pre-trained model for human portraits.
- Matting: A pre-trained model for general trimap-free matting use.
- DIS: A pre-trained model for dichotomous image segmentation (DIS).
- HRSOD: A pre-trained model for high-resolution salient object detection (HRSOD).
- COD: A pre-trained model for concealed object detection (COD).
Expand All @@ -35,6 +36,7 @@ If necessary, they can be downloaded from:
- [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors)`model.safetensors` must be renamed `General-Lite.safetensors`
- [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors)`model.safetensors` must be renamed `General-Lite-2K.safetensors`
- [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors)`model.safetensors` must be renamed `Portrait.safetensors`
- [Matting](https://huggingface.co/ZhengPeng7/BiRefNet-matting/resolve/main/model.safetensors)`model.safetensors` must be renamed `Matting.safetensors`
- [DIS](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K/resolve/main/model.safetensors)`model.safetensors` must be renamed `DIS.safetensors`
- [HRSOD](https://huggingface.co/ZhengPeng7/BiRefNet-HRSOD/resolve/main/model.safetensors)`model.safetensors` must be renamed `HRSOD.safetensors`
- [COD](https://huggingface.co/ZhengPeng7/BiRefNet-COD/resolve/main/model.safetensors)`model.safetensors` must be renamed `COD.safetensors`
Expand All @@ -50,7 +52,7 @@ Both endpoints share these parameters:
- `return_mask`: whether to return mask (can be used for inpainting).
- `return_edge_mask`: whether to return edge mask (can be used to blend foreground image with another background).
- `edge_mask_width`: edge mask width in pixels. Default to 64.
- `model_name`: `General`, `General-Lite`, `General-Lite-2K`, `Portrait`, `DIS`, `HRSOD`, `COD` or `DIS-TR_TEs`. BiRefNet model to be used. Default to `General`.
- `model_name`: `General`, `General-Lite`, `General-Lite-2K`, `Portrait`, `Matting`, `DIS`, `HRSOD`, `COD` or `DIS-TR_TEs`. BiRefNet model to be used. Default to `General`.
- `output_dir`: directory to save output images.
- `output_extension`: output image file extension (without leading dot, `png` by default).
- `device_id`: GPU device id.
Expand Down
31 changes: 19 additions & 12 deletions birefnet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ def __init__(self, bb_index: int = 6) -> None:

# TASK settings
self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0]
self.validation_set = {
'DIS5K': [],
'COD': [],
'HRSOD': [],
'General': ['DIS-VD', 'TE-P3M-500-NP'],
'General-2K': ['DIS-VD', 'TE-P3M-500-NP'],
'Matting': ['TE-P3M-500-NP'],
self.testsets = {
# Benchmarks
'DIS5K': ','.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']),
'COD': ','.join(['CHAMELEON', 'NC4K', 'TE-CAMO', 'TE-COD10K']),
'HRSOD': ','.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'DUT-OMRON', 'TE-DUTS']),
# Practical use
'General': ','.join(['DIS-VD', 'TE-P3M-500-NP']),
'General-2K': ','.join(['DIS-VD', 'TE-P3M-500-NP']),
'Matting': ','.join(['TE-P3M-500-NP', 'TE-AM-2k']),
}[self.task]
# datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.validation_set])
# datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.testsets.split(',')])
self.training_set = {
'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
'COD': 'TR-COD10K+TR-CAMO',
Expand Down Expand Up @@ -184,11 +186,16 @@ def __init__(self, bb_index: int = 6) -> None:
# self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
# self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])

def print_task(self) -> None:
# Return task for choosing settings in shell scripts.
print(self.task)

# Return task for choosing settings in shell scripts.
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Only choose one argument to activate.')
parser.add_argument('--print_task', action='store_true', help='print task name')
parser.add_argument('--print_testsets', action='store_true', help='print validation set')
args = parser.parse_args()
config = Config()
config.print_task()
for arg_name, arg_value in args._get_kwargs():
if arg_value:
print(config.__getattribute__(arg_name[len('print_'):]))

3 changes: 2 additions & 1 deletion internal_birefnet/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@
"General-Lite": "BiRefNet_T",
"General-Lite-2K": "BiRefNet_lite-2K",
"Portrait": "BiRefNet-portrait",
"Matting": "BiRefNet-matting",
"DIS": "BiRefNet-DIS5K",
"HRSOD": "BiRefNet-HRSOD",
"COD": "BiRefNet-COD",
"DIS-TR_TEs": "BiRefNet-DIS5K-TR_TEs",
}

BiRefNetModelName = Literal[
"General", "General-Lite", "General-Lite-2K", "Portrait", "DIS", "HRSOD", "COD", "DIS-TR_TEs"
"General", "General-Lite", "General-Lite-2K", "Portrait", "Matting", "DIS", "HRSOD", "COD", "DIS-TR_TEs"
]


Expand Down
1 change: 1 addition & 0 deletions scripts/postprocessing_birefnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"General-Lite",
"General-Lite-2K",
"Portrait",
"Matting",
"DIS",
"HRSOD",
"COD",
Expand Down

0 comments on commit 1c380a6

Please sign in to comment.