Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set mem_cache_img_min_size to input_size if it's none #3842

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/otx/core/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def as_int_tuple(*args) -> tuple[int, ...]:
YAML file example::

```yaml
mem_cache_img_max_size: ${as_int_tuple:500,500}
mem_cache_img_min_size: ${as_int_tuple:500,500}
```
"""
return tuple(int(arg) for arg in args)
Expand Down
4 changes: 2 additions & 2 deletions src/otx/core/data/dataset/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
dm_subset: DatasetSubset,
transforms: Transforms,
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
mem_cache_img_max_size: tuple[int, int] | None = None,
mem_cache_img_min_size: tuple[int, int] | None = None,
max_refetch: int = 1000,
image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
stack_images: bool = True,
Expand All @@ -42,7 +42,7 @@ def __init__(
dm_subset,
transforms,
mem_cache_handler,
mem_cache_img_max_size,
mem_cache_img_min_size,
max_refetch,
image_color_channel,
stack_images,
Expand Down
10 changes: 5 additions & 5 deletions src/otx/core/data/dataset/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
dm_subset: DmDataset,
transforms: Transforms,
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
mem_cache_img_max_size: tuple[int, int] | None = None,
mem_cache_img_min_size: tuple[int, int] | None = None,
max_refetch: int = 1000,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
Expand All @@ -51,7 +51,7 @@ def __init__(
dm_subset,
transforms,
mem_cache_handler,
mem_cache_img_max_size,
mem_cache_img_min_size,
max_refetch,
image_color_channel,
stack_images,
Expand Down Expand Up @@ -79,7 +79,7 @@ def _get_item_impl(
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
),
Expand All @@ -101,7 +101,7 @@ def _get_item_impl(
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
),
Expand All @@ -124,7 +124,7 @@ def _get_item_impl(
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
),
Expand Down
31 changes: 17 additions & 14 deletions src/otx/core/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class OTXDataset(Dataset, Generic[T_OTXDataEntity]):
dm_subset: Datumaro subset of a dataset
transforms: Transforms to apply on images
mem_cache_handler: Handler of the images cache
mem_cache_img_max_size: Max size of images to put in cache
mem_cache_img_min_size: Minimum size of images to put in cache
max_refetch: Maximum number of images to fetch in cache
image_color_channel: Color channel of images
stack_images: Whether or not to stack images in collate function in OTXBatchData entity.
Expand All @@ -78,7 +78,7 @@ def __init__(
dm_subset: DatasetSubset,
transforms: Transforms,
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
mem_cache_img_max_size: tuple[int, int] | None = None,
mem_cache_img_min_size: tuple[int, int] | None = None,
max_refetch: int = 1000,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
Expand All @@ -87,7 +87,7 @@ def __init__(
self.dm_subset = dm_subset
self.transforms = transforms
self.mem_cache_handler = mem_cache_handler
self.mem_cache_img_max_size = mem_cache_img_max_size
self.mem_cache_img_min_size = mem_cache_img_min_size
self.max_refetch = max_refetch
self.image_color_channel = image_color_channel
self.stack_images = stack_images
Expand Down Expand Up @@ -142,10 +142,13 @@ def __getitem__(self, index: int) -> T_OTXDataEntity:
raise RuntimeError(msg)

def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, int]]:
"""Get image and original image shape from the memory cache."""
key = img.path if isinstance(img, ImageFromFile) else id(img)
img_shape = img.size

if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None:
return img_data, img_data.shape[:2]
img_data, _ = self.mem_cache_handler.get(key=key)
if img_data is not None:
return img_data, img_shape or img_data.shape[:2]

with image_decode_context():
img_data = (
Expand All @@ -160,13 +163,13 @@ def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, in

img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8))

return img_data, img_data.shape[:2]
return img_data, img_shape or img_data.shape[:2]

def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
"""Cache an image after resizing.

If there is available space in the memory pool, the input image is cached.
Before caching, the input image is resized if it is larger than the maximum image size
Before caching, the input image is downscaled (preserving its aspect ratio) if both its height and width exceed the minimum image size
specified by the memory caching handler.
Otherwise, the input image is directly cached.
After caching, the processed image data is returned.
Expand All @@ -181,21 +184,21 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
if self.mem_cache_handler.frozen:
return img_data

if self.mem_cache_img_max_size is None:
if self.mem_cache_img_min_size is None:
self.mem_cache_handler.put(key=key, data=img_data, meta=None)
return img_data

height, width = img_data.shape[:2]
max_height, max_width = self.mem_cache_img_max_size
min_height, min_width = self.mem_cache_img_min_size

if height <= max_height and width <= max_width:
if height <= min_height or width <= min_width:
self.mem_cache_handler.put(key=key, data=img_data, meta=None)
return img_data

# Preserve the image size ratio and fit to max_height or max_width
# e.g. (1000 / 2000 = 0.5, 1000 / 1000 = 1.0) => 0.5
# h, w = 2000 * 0.5 => 1000, 1000 * 0.5 => 500, bounded by max_height
min_scale = min(max_height / height, max_width / width)
# Preserve the image size ratio and fit to min_height or min_width
# e.g. (1000 / 2000 = 0.5, 1000 / 4000 = 0.25) => 0.5
# h, w = 2000 * 0.5 => 1000, 4000 * 0.5 => 2000, bounded by min_height
min_scale = max(min_height / height, min_width / width)
new_height, new_width = int(min_scale * height), int(min_scale * width)
resized_img = cv2.resize(
src=img_data,
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
),
Expand Down Expand Up @@ -78,7 +78,7 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
ignored_labels=ignored_labels,
Expand Down Expand Up @@ -186,7 +186,7 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
ignored_labels=ignored_labels,
Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/data/dataset/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _get_item_impl(self, index: int) -> DetDataEntity | None:
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
ignored_labels=ignored_labels,
Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/data/dataset/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None:
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
ignored_labels=ignored_labels,
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
dm_subset: DatasetSubset,
transforms: Transforms,
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
mem_cache_img_max_size: tuple[int, int] | None = None,
mem_cache_img_min_size: tuple[int, int] | None = None,
max_refetch: int = 1000,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
Expand All @@ -44,7 +44,7 @@ def __init__(
dm_subset,
transforms,
mem_cache_handler,
mem_cache_img_max_size,
mem_cache_img_min_size,
max_refetch,
image_color_channel,
stack_images,
Expand Down Expand Up @@ -112,7 +112,7 @@ def _get_item_impl(self, index: int) -> KeypointDetDataEntity | None:
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
ignored_labels=ignored_labels,
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(
dm_subset: DmDataset,
transforms: Transforms,
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
mem_cache_img_max_size: tuple[int, int] | None = None,
mem_cache_img_min_size: tuple[int, int] | None = None,
max_refetch: int = 1000,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
Expand All @@ -172,7 +172,7 @@ def __init__(
dm_subset,
transforms,
mem_cache_handler,
mem_cache_img_max_size,
mem_cache_img_min_size,
max_refetch,
image_color_channel,
stack_images,
Expand Down Expand Up @@ -214,7 +214,7 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None:
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
ignored_labels=ignored_labels,
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
dataset.dm_subset,
dataset.transforms,
dataset.mem_cache_handler,
dataset.mem_cache_img_max_size,
dataset.mem_cache_img_min_size,
dataset.max_refetch,
)
self.tile_config = tile_config
Expand Down Expand Up @@ -354,7 +354,7 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr
tile_attr_list=tile_attrs,
ori_img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
),
ori_bboxes=tv_tensors.BoundingBoxes(
Expand Down Expand Up @@ -456,7 +456,7 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
tile_attr_list=tile_attrs,
ori_img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
),
ori_bboxes=tv_tensors.BoundingBoxes(
Expand Down
4 changes: 2 additions & 2 deletions src/otx/core/data/dataset/visual_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:
image=img_data,
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
),
masks=None,
Expand Down Expand Up @@ -270,7 +270,7 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None
image=to_image(img_data),
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
img_shape=img_data.shape[:2],
ori_shape=img_shape,
),
masks=masks,
Expand Down
4 changes: 2 additions & 2 deletions src/otx/core/data/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def create( # noqa: PLR0911
dm_subset: DmDataset,
cfg_subset: SubsetConfig,
mem_cache_handler: MemCacheHandlerBase,
mem_cache_img_max_size: tuple[int, int] | None = None,
mem_cache_img_min_size: tuple[int, int] | None = None,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
include_polygons: bool = False,
Expand All @@ -86,7 +86,7 @@ def create( # noqa: PLR0911
"dm_subset": dm_subset,
"transforms": transforms,
"mem_cache_handler": mem_cache_handler,
"mem_cache_img_max_size": mem_cache_img_max_size,
"mem_cache_img_min_size": mem_cache_img_min_size,
"image_color_channel": image_color_channel,
"stack_images": stack_images,
"to_tv_image": cfg_subset.to_tv_image,
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/data/mem_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ class MemCacheHandlerBase:
It will be combined with LoadImageFromOTXDataset to store/retrieve the samples in memory.
"""

def __init__(self, mem_size: int, mem_cache_img_max_size: tuple[int, int] | None = None):
def __init__(self, mem_size: int, mem_cache_img_min_size: tuple[int, int] | None = None):
self._mem_size = mem_size
self._mem_cache_img_max_size = mem_cache_img_max_size
self._mem_cache_img_max_size = mem_cache_img_min_size
self._init_data_structs(mem_size)

def _init_data_structs(self, mem_size: int) -> None:
Expand All @@ -110,7 +110,7 @@ def mem_size(self) -> int:
return len(self._arr)

@property
def mem_cache_img_max_size(self) -> tuple[int, int] | None:
def mem_cache_img_min_size(self) -> tuple[int, int] | None:
"""Get the image min size in mem cache."""
return self._mem_cache_img_max_size

Expand Down
Loading
Loading