Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds committed Jul 15, 2024
1 parent eb14e7a commit f46298f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
13 changes: 6 additions & 7 deletions annotator/mobile_sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ class SamDetector_Aux(SamDetector):
def __init__(self, mask_generator: SamAutomaticMaskGenerator):
super().__init__(mask_generator)

self.device = devices.device
self.model = SamDetector_Aux().to(self.device).eval()
self.from_pretrained(model_type="vit_t")

@classmethod
def from_pretrained(cls, model_type="vit_t"):
def from_pretrained(cls):
"""
Possible model_type : vit_h, vit_l, vit_b, vit_t
download weights from https://huggingface.co/dhkim2810/MobileSAM
Expand All @@ -35,9 +31,12 @@ def from_pretrained(cls, model_type="vit_t"):
)
model_path = load_model(
"mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir
)
)

sam = sam_model_registry["vit_t"](checkpoint=model_path)

sam = sam_model_registry[model_type](checkpoint=model_path)
cls.device = devices.device
cls.model = SamDetector_Aux().to(cls.device).eval()

mask_generator = SamAutomaticMaskGenerator(sam)

Expand Down
2 changes: 1 addition & 1 deletion scripts/preprocessor/mobile_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __call__(
):
img, remove_pad = resize_image_with_pad(input_image, resolution)
if self.model is None:
self.model = SamDetector_Aux()
self.model = SamDetector_Aux.from_pretrained()

result = self.model(img, detect_resolution=resolution, image_resolution=resolution)
return remove_pad(result)
Expand Down

0 comments on commit f46298f

Please sign in to comment.