Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apply formatting after iter_arrow to speed up format -> map, filter for iterable datasets #7207

Open
wants to merge 46 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
a73bb02
apply formatting after iter_arrow
alex-hh Oct 8, 2024
4a761a9
add support for formatting to map iteration
alex-hh Oct 8, 2024
3b65d99
formatted iterator for filter
alex-hh Oct 8, 2024
d906b9f
fix filtered formatting
alex-hh Oct 8, 2024
421917d
option to disable formatting for outputs of map
alex-hh Oct 8, 2024
e7b67c3
remove format_outputs kwarg
alex-hh Oct 9, 2024
a4f9700
rename batched_examples_iterator -> inputs_iterator
alex-hh Oct 9, 2024
a465abd
support arbitrary input formatting in filtered examples iterable iter…
alex-hh Oct 9, 2024
1863f8c
preserve formatting on filtered shuffle
alex-hh Oct 9, 2024
205e0d6
pass token_per_repo_id to python_feature_decoder in formatters
alex-hh Oct 9, 2024
42dc44f
implement FormattedExamplesIterator
alex-hh Oct 9, 2024
4a8fed5
fix formatted examples iterable
alex-hh Oct 9, 2024
8cdf6a6
Merge branch 'main' into iterable-map-with-format
alex-hh Oct 9, 2024
2ddaa7d
restore is_typed property
alex-hh Oct 9, 2024
dcd5017
pass formatting config to formatted examples iterable
alex-hh Oct 9, 2024
1ae947e
fix formatter init
alex-hh Oct 9, 2024
20330e8
Merge branch 'main' into iterable-map-with-format
alex-hh Oct 9, 2024
8f6845f
map examples iterable expects to receive rebatchedarrowexamplesiterab…
alex-hh Oct 9, 2024
3a91aac
only apply features if they exist
alex-hh Oct 9, 2024
84fcf74
fix shuffle and shard
alex-hh Oct 9, 2024
4fac60a
remove formatting from FilteredExamplesIterable
alex-hh Oct 10, 2024
afa78aa
run pre commit
alex-hh Oct 10, 2024
5a8389b
filtered iter_arrow always allowed if available
alex-hh Oct 10, 2024
c97f02e
filtered examples iterable needs formatting when iter_arrow enabled
alex-hh Oct 10, 2024
76e09a1
only iter arrow on filter if formatting is set
alex-hh Oct 10, 2024
ee45f7f
add features property to support feature inference
alex-hh Oct 10, 2024
b828575
fix features property
alex-hh Oct 10, 2024
f76701b
dont re-encode featuers
alex-hh Oct 10, 2024
15a8cfe
avoid re-encoding outputs of map
alex-hh Oct 10, 2024
884bba1
map should not preserve formatting
alex-hh Oct 10, 2024
d979672
update comment
alex-hh Oct 10, 2024
190d062
update map features property
alex-hh Oct 10, 2024
85b7d4d
return bool for mapped ex iterable is typed
alex-hh Oct 11, 2024
3129274
pass return features to mapped exampels iterable constructor
alex-hh Oct 11, 2024
45f55b4
don't iter arrow with formatted filter to avoid re formatting
alex-hh Oct 11, 2024
5e31fe0
avoid re-formatting data
alex-hh Oct 12, 2024
49a84fe
rename return features -> features
alex-hh Oct 14, 2024
002f5b4
update refs to return_features
alex-hh Oct 14, 2024
2479264
decode features in batched map
alex-hh Oct 14, 2024
68bfa39
preserve formatting in with_format
alex-hh Oct 15, 2024
38f78d2
fix features (mapped ex iterable
alex-hh Oct 16, 2024
f59a8e6
Merge branch 'main' into iterable-map-with-format
alex-hh Oct 31, 2024
ca2deb4
update shard
alex-hh Oct 31, 2024
4efcf11
remove formatted examples iterable from with_format
alex-hh Nov 2, 2024
f997f8c
avoid reapplying features when chaining filter, map
alex-hh Nov 2, 2024
bd8bbd3
preserve formatting in map
alex-hh Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,14 @@ def extract_batch(self, pa_table: pa.Table) -> pd.DataFrame:


class PythonFeaturesDecoder:
def __init__(self, features: Optional[Features]):
def __init__(
self, features: Optional[Features], token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None
):
self.features = features
self.token_per_repo_id = token_per_repo_id

def decode_row(self, row: dict) -> dict:
return self.features.decode_example(row) if self.features else row
return self.features.decode_example(row, token_per_repo_id=self.token_per_repo_id) if self.features else row

def decode_column(self, column: list, column_name: str) -> list:
return self.features.decode_column(column, column_name) if self.features else column
Expand Down Expand Up @@ -393,9 +396,14 @@ class Formatter(Generic[RowFormat, ColumnFormat, BatchFormat]):
numpy_arrow_extractor = NumpyArrowExtractor
pandas_arrow_extractor = PandasArrowExtractor

def __init__(self, features: Optional[Features] = None):
def __init__(
self,
features: Optional[Features] = None,
token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None,
):
self.features = features
self.python_features_decoder = PythonFeaturesDecoder(self.features)
self.token_per_repo_id = token_per_repo_id
self.python_features_decoder = PythonFeaturesDecoder(self.features, self.token_per_repo_id)
self.pandas_features_decoder = PandasFeaturesDecoder(self.features)

def __call__(self, pa_table: pa.Table, query_type: str) -> Union[RowFormat, ColumnFormat, BatchFormat]:
Expand Down Expand Up @@ -433,8 +441,8 @@ def format_batch(self, pa_table: pa.Table) -> pa.Table:


class PythonFormatter(Formatter[Mapping, list, Mapping]):
def __init__(self, features=None, lazy=False):
super().__init__(features)
def __init__(self, features=None, lazy=False, token_per_repo_id=None):
super().__init__(features, token_per_repo_id)
self.lazy = lazy

def format_row(self, pa_table: pa.Table) -> Mapping:
Expand Down Expand Up @@ -484,8 +492,8 @@ class CustomFormatter(Formatter[dict, ColumnFormat, dict]):
to return.
"""

def __init__(self, transform: Callable[[dict], dict], features=None, **kwargs):
super().__init__(features=features)
def __init__(self, transform: Callable[[dict], dict], features=None, token_per_repo_id=None, **kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.transform = transform

def format_row(self, pa_table: pa.Table) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/formatting/jax_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@


class JaxFormatter(TensorFormatter[Mapping, "jax.Array", Mapping]):
def __init__(self, features=None, device=None, **jnp_array_kwargs):
super().__init__(features=features)
def __init__(self, features=None, device=None, token_per_repo_id=None, **jnp_array_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
import jax
from jaxlib.xla_client import Device

Expand Down
4 changes: 2 additions & 2 deletions src/datasets/formatting/np_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@


class NumpyFormatter(TensorFormatter[Mapping, np.ndarray, Mapping]):
def __init__(self, features=None, **np_array_kwargs):
super().__init__(features=features)
def __init__(self, features=None, token_per_repo_id=None, **np_array_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.np_array_kwargs = np_array_kwargs

def _consolidate(self, column):
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/formatting/tf_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@


class TFFormatter(TensorFormatter[Mapping, "tf.Tensor", Mapping]):
def __init__(self, features=None, **tf_tensor_kwargs):
super().__init__(features=features)
def __init__(self, features=None, token_per_repo_id=None, **tf_tensor_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.tf_tensor_kwargs = tf_tensor_kwargs
import tensorflow as tf # noqa: F401 - import tf at initialization

Expand Down
4 changes: 2 additions & 2 deletions src/datasets/formatting/torch_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@


class TorchFormatter(TensorFormatter[Mapping, "torch.Tensor", Mapping]):
def __init__(self, features=None, **torch_tensor_kwargs):
super().__init__(features=features)
def __init__(self, features=None, token_per_repo_id=None, **torch_tensor_kwargs):
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
self.torch_tensor_kwargs = torch_tensor_kwargs
import torch # noqa import torch at initialization

Expand Down
Loading