Skip to content

Commit

Permalink
Add mobile_sam with controlnet_aux (#3000)
Browse files Browse the repository at this point in the history
* Add mobile_sam with controlnet_aux for CNXL_Union
  • Loading branch information
sdbds authored Jul 15, 2024
1 parent 3ff69b9 commit 0baecb5
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 1 deletion.
49 changes: 49 additions & 0 deletions annotator/mobile_sam/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import print_function

import os
import numpy as np
from PIL import Image
from typing import Union

from modules import devices
from annotator.util import load_model
from annotator.annotator_path import models_path

from controlnet_aux import SamDetector
from controlnet_aux.segment_anything import sam_model_registry, SamAutomaticMaskGenerator

class SamDetector_Aux(SamDetector):

model_dir = os.path.join(models_path, "mobile_sam")

def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam):
super().__init__(mask_generator)
self.device = devices.device
self.model = sam.to(self.device).eval()

@classmethod
def from_pretrained(cls):
"""
Possible model_type : vit_h, vit_l, vit_b, vit_t
download weights from https://huggingface.co/dhkim2810/MobileSAM
"""
remote_url = os.environ.get(
"CONTROLNET_MOBILE_SAM_MODEL_URL",
"https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt",
)
model_path = load_model(
"mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir
)

sam = sam_model_registry["vit_t"](checkpoint=model_path)

cls.model = sam.to(devices.device).eval()

mask_generator = SamAutomaticMaskGenerator(cls.model)

return cls(mask_generator, sam)

def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray:
self.model.to(self.device)
image = super().__call__(input_image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs)
return np.array(image).astype(np.uint8)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ matplotlib
facexlib
timm<=0.9.5
pydantic<=1.10.17
controlnet_aux
3 changes: 2 additions & 1 deletion scripts/preprocessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from .ip_adapter_auto import *
from .normal_dsine import *
from .model_free_preprocessors import *
from .legacy.legacy_preprocessors import *
from .legacy.legacy_preprocessors import *
from .mobile_sam import *
25 changes: 25 additions & 0 deletions scripts/preprocessor/mobile_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from annotator.mobile_sam import SamDetector_Aux
from scripts.supported_preprocessor import Preprocessor

class PreprocessorMobileSam(Preprocessor):
def __init__(self):
super().__init__(name="mobile_sam")
self.tags = ["Segmentation"]
self.model = None

def __call__(
self,
input_image,
resolution,
slider_1=None,
slider_2=None,
slider_3=None,
**kwargs
):
if self.model is None:
self.model = SamDetector_Aux.from_pretrained()

result = self.model(input_image, detect_resolution=resolution, image_resolution=resolution, output_type="cv2")
return result

Preprocessor.add_supported_preprocessor(PreprocessorMobileSam())

0 comments on commit 0baecb5

Please sign in to comment.