diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index f92a1a8afda..89e054fabcf 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -784,6 +784,7 @@ def map( function: Optional[Callable] = None, with_indices: bool = False, with_rank: bool = False, + with_split: bool = False, input_columns: Optional[Union[str, List[str]]] = None, batched: bool = False, batch_size: Optional[int] = 1000, @@ -795,7 +796,7 @@ def map( writer_batch_size: Optional[int] = 1000, features: Optional[Features] = None, disable_nullable: bool = False, - fn_kwargs: Optional[dict] = None, + fn_kwargs: dict = {}, num_proc: Optional[int] = None, desc: Optional[str] = None, ) -> "DatasetDict": @@ -882,6 +883,7 @@ def map( self._check_values_type() if cache_file_names is None: cache_file_names = {k: None for k in self} + return DatasetDict( { k: dataset.map( @@ -899,7 +901,7 @@ def map( writer_batch_size=writer_batch_size, features=features, disable_nullable=disable_nullable, - fn_kwargs=fn_kwargs, + fn_kwargs={**fn_kwargs, "split": k} if with_split else fn_kwargs, num_proc=num_proc, desc=desc, )