Skip to content

Commit

Permalink
update way to get img shape
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Aug 16, 2024
1 parent 77e2b9b commit 37b39a3
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/otx/core/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ def __getitem__(self, index: int) -> T_OTXDataEntity:

def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, int]]:
key = img.path if isinstance(img, ImageFromFile) else id(img)
img_shape = img.size

if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None:
return img_data, img_data.shape[:2]
img_data, meta = self.mem_cache_handler.get(key=key)
if img_data is not None:
return img_data, img_shape or img_data.shape[:2]

with image_decode_context():
img_data = (
Expand All @@ -160,7 +162,7 @@ def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, in

img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8))

return img_data, img_data.shape[:2]
return img_data, img_shape or img_data.shape[:2]

def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
"""Cache an image after resizing.
Expand All @@ -181,11 +183,11 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
if self.mem_cache_handler.frozen:
return img_data

height, width = img_data.shape[:2]
if self.mem_cache_img_max_size is None:
self.mem_cache_handler.put(key=key, data=img_data, meta=None)
return img_data

height, width = img_data.shape[:2]
max_height, max_width = self.mem_cache_img_max_size

if height <= max_height and width <= max_width:
Expand Down

0 comments on commit 37b39a3

Please sign in to comment.