-
Notifications
You must be signed in to change notification settings - Fork 1
/
action.py
277 lines (234 loc) · 9.64 KB
/
action.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
import os
from typing import Optional, Type
from omegaconf import DictConfig
from PIL import Image
from gbc.utils import instantiate_class
from .io_unit import QueryResult
from .action_io import ActionInput, ActionOutput, NodeInfo
# Public API of this module: the names exported by a star-import.
__all__ = [
"Action",
"ActionInputPair",
]
class Action(object):
    """
    Abstract base class for defining actions that can be queried.

    This class should be subclassed to implement specific actions. Subclasses must
    implement the :meth:`~query` method for single queries and can optionally
    override the :meth:`~batch_query` method for handling multiple queries
    more efficiently.
    """

    def query(
        self,
        action_input: ActionInput,
        queried_nodes: Optional[dict[str, list[NodeInfo]]] = None,
    ) -> ActionOutput:
        """
        Executes a query based on the provided action input.

        This method performs the core action logic using the information
        in ``action_input``. It may also leverage previously queried nodes
        stored in ``queried_nodes``.

        Parameters
        ----------
        action_input:
            The input required to perform the action.
        queried_nodes:
            A dictionary containing information about previously queried nodes,
            as organized by image path.
            This is used to avoid redundant queries (via node merging) and is updated
            with new nodes discovered during this query.

        Returns
        -------
        ActionOutput
            The output of the action, containing the following:

            - actions_to_complete: Additional action input pairs for recursive query.
            - query_result: The result of the query, which may be ``None``.
            - image: The image used to perform the query, which may be ``None``.

        Raises
        ------
        NotImplementedError
            Always, in this base class: concrete actions must override this method.
        """
        # Fail loudly instead of silently returning ``None``, which would violate
        # the ``ActionOutput`` return contract and only surface later downstream.
        raise NotImplementedError(
            f"{type(self).__name__} must implement the query() method"
        )

    def batch_query(
        self,
        action_inputs: list[ActionInput],
        queried_nodes: Optional[dict[str, list[NodeInfo]]] = None,
    ) -> list[ActionOutput]:
        """
        Executes multiple queries in parallel.

        By default, this method sequentially calls :meth:`~query` for each input.
        Subclasses can override this method to provide a more efficient implementation.

        Parameters
        ----------
        action_inputs
            A list of inputs for performing the queries.
        queried_nodes
            A dictionary containing information about previously queried nodes,
            as organized by image path.
            This is used to avoid redundant queries (via node merging) and is updated
            with new nodes discovered during this query.
            Defaults to ``None``, matching :meth:`~query`.

        Returns
        -------
        list[ActionOutput]
            A list of outputs, each corresponding to an action input.
        """
        # The same ``queried_nodes`` object is handed to every call, so nodes
        # recorded by an earlier query are visible to the later ones.
        return [
            self.query(action_input, queried_nodes) for action_input in action_inputs
        ]

    @staticmethod
    def _add_to_queried_nodes(
        query_result: QueryResult,
        action_input: ActionInput,
        queried_nodes: Optional[dict[str, list[NodeInfo]]],
    ) -> None:
        """
        Adds the result of a query to the queried nodes dictionary.

        Does nothing when ``queried_nodes`` is ``None``; otherwise the dictionary
        is modified in place.

        Parameters
        ----------
        query_result : QueryResult
            The result of the query to be added.
        action_input : ActionInput
            The input that produced the query result.
        queried_nodes : dict of str to list of NodeInfo, optional
            A dictionary containing information about previously queried nodes,
            as organized by image path.
            This is used to avoid redundant queries (via node merging) and is updated
            with new nodes discovered during this query.
        """
        if queried_nodes is None:
            return
        node_info = NodeInfo(
            action_input=action_input,
            query_result=query_result,
        )
        # Group nodes by image path, creating the per-image list on first use.
        queried_nodes.setdefault(node_info.img_path, []).append(node_info)
class ActionInputPair(object):
    """
    Encapsulates a pair of an :class:`~Action` class and its corresponding
    input for performing queries.

    This class allows for the storage and execution of queries.
    The storage is achieved through the :meth:`~model_dump`
    and :meth:`~model_validate` methods.
    Precisely, the conversion from and to a dictionary relies on representing
    each class with their module and name.
    To execute the query, the action class should be instantiated with
    a configuration dictionary, passed as argument to the :meth:`~query` method.

    Attributes
    ----------
    action_class : Type[Action]
        The class of the action to be performed.
    action_input : ActionInput
        The input required for the action.
    """

    def __init__(self, action_class: Type[Action], action_input: ActionInput):
        self.action_class = action_class
        self.action_input = action_input

    def __repr__(self):
        return f"ActionInputPair({self.action_class!r}, {self.action_input!r})"

    def model_dump(self) -> dict:
        """
        Converts the ``ActionInputPair`` object to a dictionary format.

        Both the action class and the class of the action input are stored as
        ``(module, name)`` pairs; the input itself is stored through its own
        ``model_dump`` method.

        Returns
        -------
        dict
            Dictionary representation of the ``ActionInputPair`` object, with
            keys ``action_type``, ``input_type`` and ``input_dict``.
        """
        input_class = self.action_input.__class__
        return {
            "action_type": (
                self.action_class.__module__,
                self.action_class.__name__,
            ),
            "input_type": (
                input_class.__module__,
                input_class.__name__,
            ),
            "input_dict": self.action_input.model_dump(),
        }

    @classmethod
    def model_validate(cls, obj: dict) -> "ActionInputPair":
        """
        Validates and constructs an ``ActionInputPair`` object from a dictionary.

        Parameters
        ----------
        obj : dict
            Dictionary containing the action type and input type information,
            in the layout produced by :meth:`~model_dump`.

        Returns
        -------
        ActionInputPair
            An instance of ``ActionInputPair`` constructed from the dictionary.

        Raises
        ------
        ValueError
            If the input is not a dictionary.
        """
        # Ensure obj is a dict
        if not isinstance(obj, dict):
            raise ValueError("Input should be a dictionary", cls)
        # Each type is stored as a (module, name) pair and resolved back to a class.
        action_class = instantiate_class(*obj["action_type"])
        input_class = instantiate_class(*obj["input_type"])
        action_input = input_class.model_validate(obj["input_dict"])
        return cls(action_class=action_class, action_input=action_input)

    def query(
        self,
        config: DictConfig,
        queried_nodes: Optional[dict[str, list[NodeInfo]]] = None,
    ) -> ActionOutput:
        """
        Executes the query for the action using the provided configuration.

        It first instantiates :attr:`~action_class` with the provided configuration
        and then calls the ``query`` method of the action on the :attr:`~action_input`.

        Parameters
        ----------
        config
            Configuration for the querying pipeline.
            It would generally define how the actions should be instantiated.
        queried_nodes
            A dictionary containing information about previously queried nodes,
            as organized by image path.
            This is used to avoid redundant queries (via node merging) and is updated
            with new nodes discovered during this query.

        Returns
        -------
        ActionOutput
            The output of the action, containing the following:

            - actions_to_complete: Additional action input pairs for recursive query.
            - query_result: The result of the query, which may be ``None``.
            - image: The image used to perform the query, which may be ``None``.
        """
        action = self.action_class(config)
        return action.query(self.action_input, queried_nodes=queried_nodes)

    @staticmethod
    def _flatten_path(path: str) -> str:
        """Replaces every path separator in ``path`` with a hyphen (-)."""
        flattened = path.replace(os.path.sep, "-")
        # Some platforms accept a second separator (e.g. "/" on Windows); flatten
        # it as well so the result is always a single path component.
        if os.path.altsep:
            flattened = flattened.replace(os.path.altsep, "-")
        return flattened

    def save_image(self, image: Image.Image, base_save_dir: str):
        """
        Saves the associated image to a subdirectory within a specified base directory.

        This method takes the image to be saved and a base directory path.
        It constructs a subdirectory structure within ``base_save_dir``
        based on the following logic:

        - ``images``: A subdirectory is created under ``base_save_dir`` to store images.
        - Image Path (modified): The image path is used but adjusted by:
            - Splitting filename and extension using ``os.path.splitext``
            - Replacing path separators with hyphens (-)
        - Entity ID: The entity ID from the query node information is retrieved.
        - Action Class Name: The name of the current action class is retrieved.

        The final save path is
        ``<base_save_dir>/images/<adjusted image path>/<entity id>_<action class>.jpg``.
        Missing directories are created.

        Parameters
        ----------
        image
            The image to be saved.
        base_save_dir
            The base directory path where the image will be saved.
            Subdirectories will be created within this path.
        """
        img_stem, _ = os.path.splitext(self.action_input.img_path)
        img_path_adjusted = self._flatten_path(img_stem)
        entity_info = self.action_input.entity_info
        if isinstance(entity_info, list):
            # When the input carries a list of entities, the first one names the file.
            entity_info = entity_info[0]
        # The entity id may itself contain separators, so flatten the name too.
        save_name = self._flatten_path(
            f"{entity_info.entity_id}_{self.action_class.__name__}"
        )
        save_path = os.path.join(
            base_save_dir, "images", img_path_adjusted, save_name + ".jpg"
        )
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        image.save(save_path)