Gracefully handle empty CSV (#206)
* Gracefully handle empty CSV

* addressed comments

* Add a new test case to verify cloudpickle fix

* Allow raising error on empty CSV

* ignore_reinit_error=True to avoid init twice

* Bump up the version
raghumdani authored Sep 1, 2023
1 parent 9ba2aba commit 2b213f8
Showing 7 changed files with 353 additions and 7 deletions.
2 changes: 1 addition & 1 deletion deltacat/__init__.py
@@ -44,7 +44,7 @@

deltacat.logs.configure_deltacat_logger(logging.getLogger(__name__))

__version__ = "0.1.18b16"
__version__ = "0.1.18b17"


__all__ = [
51 changes: 51 additions & 0 deletions deltacat/tests/io/test_cloudpickle_bug_fix.py
@@ -0,0 +1,51 @@
import unittest
from typing import Any
import pyarrow as pa
import ray


class CustomObj:
    # Structural typing stub; AnyObject below supplies the real implementation.
    def pickle_len(self, obj: Any):
        pass


class AnyObject:
def __init__(self) -> None:
pass

def pickle_len(self, obj):
return len(ray.cloudpickle.dumps(obj))


@ray.remote
def calculate_pickled_length(custom_obj: CustomObj):
    # Compare cloudpickle payload sizes of the full table vs. a one-row slice.
    table = pa.table({"a": pa.array(list(range(1000000)))})

full_len = custom_obj.pickle_len(table)
sliced_len = custom_obj.pickle_len(table[0:1])

return [sliced_len, full_len]


class TestCloudpickleBugFix(unittest.TestCase):
"""
This test is specifically to validate if nothing has
changed across Ray versions regarding the cloudpickle behavior.
If the tables are sliced, cloudpickle used to dump entire buffer.
However, Ray has added a custom serializer to get around this problem in
https://github.com/ray-project/ray/pull/29993 for the issue at
https://github.com/ray-project/ray/issues/29814.
Note: If this test fails, it indicates that you may need to address the cloudpickle
bug before upgrading ray version.
"""

def test_sanity(self):
ray.init(local_mode=True, ignore_reinit_error=True)

result = ray.get(calculate_pickled_length.remote(AnyObject()))

self.assertTrue(result[0] < 1000)
self.assertTrue(result[1] >= 5000000)
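For context, here is a standalone sketch of the behavior this test guards against. It is illustrative only, not part of the commit: depending on your pyarrow version, pickling a one-row slice of a large table can serialize the entire backing buffer, and rebuilding the slice from plain Python values forces right-sized buffers.

import pickle

import pyarrow as pa

table = pa.table({"a": pa.array(range(1_000_000))})
sliced = table[0:1]

# A one-row slice still references the million-row buffer, so without a
# slice-aware serializer its pickle can be nearly as large as the full table's.
print(len(pickle.dumps(table)))
print(len(pickle.dumps(sliced)))

# Copying the slice out of the shared buffer shrinks the payload to roughly
# one row's worth of data.
compact = pa.table(sliced.to_pydict())
print(len(pickle.dumps(compact)))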
deltacat/tests/utils/data/empty.csv
Empty file.
3 changes: 3 additions & 0 deletions deltacat/tests/utils/data/non_empty_valid.csv
@@ -0,0 +1,3 @@
Y 1754-08-30T22:43:41
"" 2022-08-10T22:43:21
N 9999-08-10T22:43:21.123456
212 changes: 212 additions & 0 deletions deltacat/tests/utils/test_pyarrow.py
@@ -2,13 +2,20 @@
from deltacat.utils.pyarrow import (
s3_parquet_file_to_table,
s3_partial_parquet_file_to_table,
pyarrow_read_csv,
content_type_to_reader_kwargs,
_add_column_kwargs,
ReadKwargsProviderPyArrowSchemaOverride,
RAISE_ON_EMPTY_CSV_KWARG,
)
from deltacat.types.media import ContentEncoding, ContentType
from deltacat.types.partial_download import PartialParquetParameters
from pyarrow.parquet import ParquetFile
import pyarrow as pa

PARQUET_FILE_PATH = "deltacat/tests/utils/data/test_file.parquet"
EMPTY_UTSV_PATH = "deltacat/tests/utils/data/empty.csv"
NON_EMPTY_VALID_UTSV_PATH = "deltacat/tests/utils/data/non_empty_valid.csv"


class TestS3ParquetFileToTable(TestCase):
@@ -131,3 +138,208 @@ def test_s3_partial_parquet_file_to_table_when_multiple_row_groups(self):

self.assertEqual(len(result), 6)
self.assertEqual(len(result.columns), 2)


class TestReadCSV(TestCase):
def test_read_csv_sanity(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(
ContentType.UNESCAPED_TSV.value,
["is_active", "ship_datetime_utc"],
None,
kwargs,
)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)

kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)

result = pyarrow_read_csv(NON_EMPTY_VALID_UTSV_PATH, **kwargs)

self.assertEqual(len(result), 3)
self.assertEqual(len(result.column_names), 2)
result_schema = result.schema
for index, field in enumerate(result_schema):
self.assertEqual(field.name, schema.field(index).name)

self.assertEqual(result.schema.field(0).type, "string")

def test_read_csv_when_column_order_changes(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(
ContentType.UNESCAPED_TSV.value,
["is_active", "ship_datetime_utc"],
["ship_datetime_utc", "is_active"],
kwargs,
)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)

kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)

result = pyarrow_read_csv(NON_EMPTY_VALID_UTSV_PATH, **kwargs)

self.assertEqual(len(result), 3)
self.assertEqual(len(result.column_names), 2)
result_schema = result.schema
self.assertEqual(result_schema.field(1).type, "string")
self.assertEqual(result_schema.field(0).type, "timestamp[us]")

def test_read_csv_when_partial_columns_included(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(
ContentType.UNESCAPED_TSV.value,
["is_active", "ship_datetime_utc"],
["is_active"],
kwargs,
)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)

kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)

result = pyarrow_read_csv(NON_EMPTY_VALID_UTSV_PATH, **kwargs)

self.assertEqual(len(result), 3)
self.assertEqual(len(result.column_names), 1)
result_schema = result.schema
self.assertEqual(result_schema.field(0).type, "string")

def test_read_csv_when_column_names_partial(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(ContentType.UNESCAPED_TSV.value, ["is_active"], None, kwargs)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)

kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)

self.assertRaises(
pa.lib.ArrowInvalid,
lambda: pyarrow_read_csv(NON_EMPTY_VALID_UTSV_PATH, **kwargs),
)

def test_read_csv_when_empty_csv_sanity(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(
ContentType.UNESCAPED_TSV.value,
["is_active", "ship_datetime_utc"],
None,
kwargs,
)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)
kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)
result = pyarrow_read_csv(EMPTY_UTSV_PATH, **kwargs)

self.assertEqual(len(result), 0)
self.assertEqual(len(result.column_names), 2)
result_schema = result.schema
self.assertEqual(result_schema.field(0).type, "string")
self.assertEqual(result_schema.field(1).type, "timestamp[us]")

def test_read_csv_when_empty_csv_include_columns(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(
ContentType.UNESCAPED_TSV.value,
["is_active", "ship_datetime_utc"],
["ship_datetime_utc", "is_active"],
kwargs,
)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)

kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)

result = pyarrow_read_csv(EMPTY_UTSV_PATH, **kwargs)

self.assertEqual(len(result), 0)
self.assertEqual(len(result.column_names), 2)
result_schema = result.schema
self.assertEqual(result_schema.field(1).type, "string")
self.assertEqual(result_schema.field(0).type, "timestamp[us]")

def test_read_csv_when_empty_csv_include_partial_columns(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(
ContentType.UNESCAPED_TSV.value,
["is_active", "ship_datetime_utc"],
["is_active"],
kwargs,
)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)

kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)

result = pyarrow_read_csv(EMPTY_UTSV_PATH, **kwargs)

self.assertEqual(len(result), 0)
self.assertEqual(len(result.column_names), 1)
result_schema = result.schema
self.assertEqual(result_schema.field(0).type, "string")

def test_read_csv_when_empty_csv_honors_column_names(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(ContentType.UNESCAPED_TSV.value, ["is_active"], None, kwargs)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)

kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)

result = pyarrow_read_csv(EMPTY_UTSV_PATH, **kwargs)

self.assertEqual(len(result), 0)
self.assertEqual(len(result.column_names), 1)
result_schema = result.schema
self.assertEqual(result_schema.field(0).type, "string")

def test_read_csv_when_empty_csv_and_raise_on_empty_passed(self):

schema = pa.schema(
[("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
)
kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
_add_column_kwargs(ContentType.UNESCAPED_TSV.value, ["is_active"], None, kwargs)

read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)

kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)

self.assertRaises(
pa.lib.ArrowInvalid,
lambda: pyarrow_read_csv(
EMPTY_UTSV_PATH, **{**kwargs, RAISE_ON_EMPTY_CSV_KWARG: True}
),
)
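The tests above pin down the intended contract: reading an empty CSV returns an empty table that still honors the requested schema and column projection, while passing RAISE_ON_EMPTY_CSV_KWARG surfaces the underlying pa.lib.ArrowInvalid instead. Below is a minimal sketch of a wrapper with that contract; it is an assumption for illustration, not the actual deltacat implementation — the kwarg's string value, the function name pyarrow_read_csv_sketch, and the way the schema is rebuilt from ConvertOptions are all hypothetical.

import pyarrow as pa
from pyarrow import csv as pacsv

RAISE_ON_EMPTY_CSV_KWARG = "raise_on_empty_csv"  # assumed key name


def pyarrow_read_csv_sketch(path: str, **kwargs) -> pa.Table:
    # Pop the custom flag so only real pyarrow kwargs reach read_csv.
    raise_on_empty = kwargs.pop(RAISE_ON_EMPTY_CSV_KWARG, False)
    try:
        return pacsv.read_csv(path, **kwargs)
    except pa.lib.ArrowInvalid:
        if raise_on_empty:
            raise
        # A real implementation would also confirm the file is actually
        # empty before falling back, so other parse errors still propagate.
        convert_options = kwargs.get("convert_options")
        # Assumes column_types is a {name: type} mapping and include_columns,
        # when set, selects and orders the projected columns.
        column_types = dict(convert_options.column_types)
        include = convert_options.include_columns or list(column_types)
        schema = pa.schema([(name, column_types[name]) for name in include])
        return schema.empty_table()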
21 changes: 20 additions & 1 deletion deltacat/utils/arguments.py
@@ -1,5 +1,5 @@
import inspect
from typing import Any, Dict
from typing import Any, Dict, List


def sanitize_kwargs_to_callable(callable: Any, kwargs: Dict) -> Dict:
@@ -23,3 +23,22 @@ def sanitize_kwargs_to_callable(callable: Any, kwargs: Dict) -> Dict:
new_kwargs.pop(key)

return new_kwargs


def sanitize_kwargs_by_supported_kwargs(
supported_kwargs: List[str], kwargs: Dict
) -> Dict:
"""
This method only keeps the kwargs in the list provided above and ignores any other kwargs passed.
This method will specifically be useful where signature cannot be automatically determined
(say the definition is part C++ implementation).
Returns: a sanitized dict of kwargs.
"""

new_kwargs = {}
for key in supported_kwargs:
if key in kwargs:
new_kwargs[key] = kwargs[key]

return new_kwargs
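A quick illustrative use, with hypothetical option names:

supported = ["delimiter", "quote_char"]
kwargs = {"delimiter": "\t", "quote_char": None, "not_a_real_option": 1}
sanitize_kwargs_by_supported_kwargs(supported, kwargs)
# -> {"delimiter": "\t", "quote_char": None}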
(Diff for the 1 remaining changed file did not load.)
