Feat/polars support #1357

Draft · wants to merge 4 commits into master
Changes from all commits
2 changes: 1 addition & 1 deletion clearml/backend_interface/metrics/reporter.py
@@ -654,7 +654,7 @@ def report_table(self, title, series, table, iteration, layout_config=None, data
         :param series: Series (AKA variant)
         :type series: str
         :param table: The table data
-        :type table: pandas.DataFrame
+        :type table: pandas.DataFrame or polars.DataFrame
         :param iteration: Iteration number
         :type iteration: int
         :param layout_config: optional dictionary for layout configuration, passed directly to plotly
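A minimal usage sketch for the updated docstring. `report_table` here is the reporter-level method that `Logger.report_table` eventually calls; the project and task names are hypothetical, and a configured ClearML server plus an installed polars package are assumed.

```python
import polars as pl
from clearml import Task

task = Task.init(project_name="examples", task_name="polars table")  # hypothetical names
df = pl.DataFrame({"epoch": [0, 1, 2], "loss": [0.90, 0.55, 0.31]})

# the table argument may now be a polars.DataFrame as well as a pandas.DataFrame
task.get_logger().report_table(title="training", series="loss", iteration=0, table_plot=df)
```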
78 changes: 56 additions & 22 deletions clearml/binding/artifacts.py
@@ -30,6 +30,10 @@
 from ..utilities.proxy_object import LazyEvalWrapper
 from ..config import deferred_config, config
 
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 try:
     import pandas as pd
     DataFrame = pd.DataFrame
@@ -152,6 +156,7 @@ def get(self, force_download=False, deserialization_function=None):
         Supported content types are:
         - dict - ``.json``, ``.yaml``
         - pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
+        - polars.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle``
         - numpy.ndarray - ``.npz``, ``.csv.gz``
         - PIL.Image - whatever content types PIL supports
         All other types will return a pathlib2.Path object pointing to a local copy of the artifacts file (or directory).
@@ -192,6 +197,16 @@ def get(self, force_download=False, deserialization_function=None):
                 self._object = pd.read_csv(local_file)
             else:
                 self._object = pd.read_csv(local_file, index_col=[0])
+        elif (self.type == Artifacts._pd_artifact_type or self.type == "polars") and pl:
+            if self._content_type == "application/parquet":
+                self._object = pl.read_parquet(local_file)
+            elif self._content_type == "application/feather":
+                self._object = pl.read_ipc(local_file)
+            elif self._content_type == "application/pickle":
+                with open(local_file, "rb") as f:
+                    self._object = pickle.load(f)
+            else:
+                self._object = pl.read_csv(local_file)
         elif self.type == "image":
             self._object = Image.open(local_file)
         elif self.type == "JSON" or self.type == "dict":
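With the branch above in place, an artifact stored with type `"polars"` is deserialized according to its stored content type. A sketch of reading one back, assuming the artifact was produced by the upload path later in this diff (the task id and artifact name are placeholders):

```python
from clearml import Task

task = Task.get_task(task_id="<task-id>")  # placeholder id
artifact = task.artifacts["table"]         # hypothetical artifact name

# parquet -> pl.read_parquet, feather -> pl.read_ipc,
# pickle -> pickle.load, anything else -> pl.read_csv
df = artifact.get()
```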
@@ -279,14 +296,14 @@ def __init__(self, artifacts_manager, *args, **kwargs):
             self.artifact_hash_columns = {}
 
         def __setitem__(self, key, value):
-            # check that value is of type pandas
-            if pd and isinstance(value, pd.DataFrame):
+            # check that value is of type pandas or polars
+            if (pd and isinstance(value, pd.DataFrame)) or (pl and isinstance(value, pl.DataFrame)):
                 super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
 
                 if self._artifacts_manager:
                     self._artifacts_manager.flush()
             else:
-                raise ValueError('Artifacts currently support pandas.DataFrame objects only')
+                raise ValueError('Artifacts currently support pandas.DataFrame and polars.DataFrame objects only')
 
         def unregister_artifact(self, name):
             self.artifact_metadata.pop(name, None)
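`_ProxyDictWrite` is what backs `Task.register_artifact`, so live (auto-updating) artifacts can now be polars frames as well. A sketch, assuming a configured server (names hypothetical):

```python
import polars as pl
from clearml import Task

task = Task.init(project_name="examples", task_name="live polars artifact")  # hypothetical names
df = pl.DataFrame({"id": [1, 2], "score": [0.7, 0.9]})

# anything that is not a pandas/polars DataFrame still raises ValueError
task.register_artifact(name="rows", artifact=df)
```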
@@ -471,8 +488,8 @@ def get_extension(extension_name_, valid_extensions, default_extension, artifact
             artifact_type_data.content_type = "text/csv"
             np.savetxt(local_filename, artifact_object, delimiter=",")
             delete_after_upload = True
-        elif pd and isinstance(artifact_object, pd.DataFrame):
-            artifact_type = "pandas"
+        elif (pd and isinstance(artifact_object, pd.DataFrame)) or (pl and isinstance(artifact_object, pl.DataFrame)):
+            artifact_type = "pandas" if (pd and isinstance(artifact_object, pd.DataFrame)) else "polars"
             artifact_type_data.preview = preview or str(artifact_object.__repr__())
             # we are making sure self._default_pandas_dataframe_extension_name is not deferred
             extension_name = extension_name or str(self._default_pandas_dataframe_extension_name or "")
@@ -483,17 +500,20 @@ def get_extension(extension_name_, valid_extensions, default_extension, artifact
             local_filename = self._push_temp_file(
                 prefix=quote(name, safe="") + ".", suffix=override_filename_ext_in_uri
             )
-            if (
+            if (pd and isinstance(artifact_object, pd.DataFrame) and (
                 isinstance(artifact_object.index, pd.MultiIndex) or isinstance(artifact_object.columns, pd.MultiIndex)
-            ) and not extension_name:
+            ) or (pl and isinstance(artifact_object, pl.DataFrame))) and not extension_name:
                 store_as_pickle = True
             elif override_filename_ext_in_uri == ".csv.gz":
                 artifact_type_data.content_type = "text/csv"
                 self._store_compressed_pd_csv(artifact_object, local_filename)
             elif override_filename_ext_in_uri == ".parquet":
                 try:
                     artifact_type_data.content_type = "application/parquet"
-                    artifact_object.to_parquet(local_filename)
+                    if pd and isinstance(artifact_object, pd.DataFrame):
+                        artifact_object.to_parquet(local_filename)
+                    else:
+                        artifact_object.write_parquet(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .parquet. Defaulting to .csv.gz".format(
@@ -505,7 +525,10 @@ def get_extension(extension_name_, valid_extensions, default_extension, artifact
             elif override_filename_ext_in_uri == ".feather":
                 try:
                     artifact_type_data.content_type = "application/feather"
-                    artifact_object.to_feather(local_filename)
+                    if pd and isinstance(artifact_object, pd.DataFrame):
+                        artifact_object.to_feather(local_filename)
+                    else:
+                        artifact_object.write_ipc(local_filename)
                 except Exception as e:
                     LoggerRoot.get_base_logger().warning(
                         "Exception '{}' encountered when uploading artifact as .feather. Defaulting to .csv.gz".format(
@@ -516,7 +539,11 @@ def get_extension(extension_name_, valid_extensions, default_extension, artifact
                 self._store_compressed_pd_csv(artifact_object, local_filename)
             elif override_filename_ext_in_uri == ".pickle":
                 artifact_type_data.content_type = "application/pickle"
-                artifact_object.to_pickle(local_filename)
+                if pd and isinstance(artifact_object, pd.DataFrame):
+                    artifact_object.to_pickle(local_filename)
+                else:
+                    with open(local_filename, "wb") as f:
+                        pickle.dump(artifact_object, f)
                 delete_after_upload = True
             elif isinstance(artifact_object, Image.Image):
                 artifact_type = "image"
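The three branches above are selected by the requested file extension: `.parquet` maps to `write_parquet`, `.feather` to `write_ipc`, and `.pickle` to a plain `pickle.dump` for polars objects. A sketch of driving them from user code, assuming a configured server (names hypothetical):

```python
import polars as pl
from clearml import Task

task = Task.init(project_name="examples", task_name="polars artifacts")  # hypothetical names
df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

task.upload_artifact(name="as_parquet", artifact_object=df, extension_name=".parquet")
task.upload_artifact(name="as_feather", artifact_object=df, extension_name=".feather")
task.upload_artifact(name="as_pickle", artifact_object=df, extension_name=".pickle")
# with no extension_name, a polars frame takes the store_as_pickle branch above
```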
@@ -1005,7 +1032,7 @@ def _get_statistics(self, artifacts_dict=None):
         artifacts_summary = []
         for a_name, a_df in artifacts_dict.items():
             hash_cols = self._artifacts_container.get_hash_columns(a_name)
-            if not pd or not isinstance(a_df, pd.DataFrame):
+            if not ((pd and isinstance(a_df, pd.DataFrame)) or (pl and isinstance(a_df, pl.DataFrame))):
                 continue
 
             if hash_cols is True:
@@ -1037,8 +1064,12 @@ def hash_row(r):
 
             a_shape = a_df.shape
             # parallelize
-            a_hash_cols = a_df.drop(columns=hash_col_drop)
-            thread_pool.map(hash_row, a_hash_cols.values)
+            if pd and isinstance(a_df, pd.DataFrame):
+                a_hash_cols = a_df.drop(columns=hash_col_drop)
+                thread_pool.map(hash_row, a_hash_cols.values)
+            else:
+                a_hash_cols = a_df.drop(hash_col_drop)
+                a_unique_hash.update(a_hash_cols.hash_rows().to_list())
             # add result
             artifacts_summary.append((a_name, a_shape, a_unique_hash,))
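In the polars branch there is no need for the per-row thread pool: `DataFrame.hash_rows` already returns one unsigned 64-bit hash per row. A standalone sketch of the uniqueness count this feeds (note that polars does not guarantee hash stability across versions):

```python
import polars as pl

df = pl.DataFrame({"user": ["a", "b", "a"], "value": [1, 2, 1]})

row_hashes = df.hash_rows()              # pl.Series of UInt64, one hash per row
unique_rows = set(row_hashes.to_list())  # 2 distinct rows: ("a", 1) and ("b", 2)
print(len(unique_rows), "unique rows out of", df.height)
```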

@@ -1082,16 +1113,19 @@ def _store_compressed_pd_csv(self, artifact_object, local_filename, **kwargs):
         # (otherwise it is encoded and creates new hash every time)
         if self._compression == "gzip":
             with gzip.GzipFile(local_filename, 'wb', mtime=0) as gzip_file:
-                try:
-                    pd_version = int(pd.__version__.split(".")[0])
-                except ValueError:
-                    pd_version = 0
-
-                if pd_version >= 2:
-                    artifact_object.to_csv(gzip_file, **kwargs)
+                if pl and isinstance(artifact_object, pl.DataFrame):
+                    artifact_object.write_csv(gzip_file)
                 else:
-                    # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
-                    artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
+                    try:
+                        pd_version = int(pd.__version__.split(".")[0])
+                    except ValueError:
+                        pd_version = 0
+
+                    if pd_version >= 2:
+                        artifact_object.to_csv(gzip_file, **kwargs)
+                    else:
+                        # old (pandas<2) versions of pandas cannot handle direct gzip stream, so we manually encode it
+                        artifact_object.to_csv(io.TextIOWrapper(gzip_file), **kwargs)
         else:
             artifact_object.to_csv(local_filename, compression=self._compression)

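`mtime=0` pins the gzip header timestamp so that identical table contents always produce identical compressed bytes, which keeps the artifact hash stable across re-uploads. A standalone sketch of the same trick with polars, outside ClearML (output path hypothetical):

```python
import gzip

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3]})

# identical data -> identical bytes -> identical hash, regardless of when it is written
with gzip.GzipFile("table.csv.gz", "wb", mtime=0) as gzip_file:  # hypothetical path
    df.write_csv(gzip_file)
```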
61 changes: 43 additions & 18 deletions clearml/datasets/dataset.py
@@ -48,6 +48,10 @@
 except Exception as e:
     logging.warning("ClearML Dataset failed importing pandas: {}".format(e))
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 
 try:
     import pyarrow  # noqa
@@ -850,7 +854,7 @@ def finalize(self, verbose=False, raise_on_error=True, auto_upload=False):
         return True
 
     def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
-        # type: (Union[numpy.array, pd.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
+        # type: (Union[numpy.array, pd.DataFrame, pl.DataFrame, Dict[str, Any]], str, bool) -> () # noqa: F821
         """
         Attach a user-defined metadata to the dataset. Check `Task.upload_artifact` for supported types.
         If type is Pandas Dataframes, optionally make it visible as a table in the UI.
@@ -859,7 +863,7 @@ def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
             raise ValueError("metadata_name can not start with '{}'".format(self.__data_entry_name_prefix))
         self._task.upload_artifact(name=metadata_name, artifact_object=metadata)
         if ui_visible:
-            if pd and isinstance(metadata, pd.DataFrame):
+            if (pd and isinstance(metadata, pd.DataFrame)) or (pl and isinstance(metadata, pl.DataFrame)):
                 self.get_logger().report_table(
                     title='Dataset Metadata',
                     series='Dataset Metadata',
@@ -872,7 +876,7 @@ def set_metadata(self, metadata, metadata_name='metadata', ui_visible=True):
             )
 
     def get_metadata(self, metadata_name='metadata'):
-        # type: (str) -> Optional[numpy.array, pd.DataFrame, dict, str, bool] # noqa: F821
+        # type: (str) -> Optional[numpy.array, pd.DataFrame, pl.DataFrame, dict, str, bool] # noqa: F821
        """
         Get attached metadata back in its original format. Will return None if none was found.
         """
@@ -3091,19 +3095,34 @@ def _report_dataset_preview(self):
         def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
             # noinspection PyBroadException
             try:
-                if file_extension_ == ".csv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
-                elif file_extension_ == ".tsv" and pd:
-                    return pd.read_csv(
-                        file_path_,
-                        sep='\t',
-                        nrows=self.__preview_tabular_row_count,
-                        compression=compression_.lstrip(".") if compression_ else None,
-                    )
+                if file_extension_ == ".csv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO: re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                        )
+                elif file_extension_ == ".tsv" and (pl or pd):
+                    if pd:
+                        return pd.read_csv(
+                            file_path_,
+                            sep='\t',
+                            nrows=self.__preview_tabular_row_count,
+                            compression=compression_.lstrip(".") if compression_ else None,
+                        )
+                    else:
+                        # TODO: re-implement compression after testing all extensions
+                        return pl.read_csv(
+                            file_path_,
+                            n_rows=self.__preview_tabular_row_count,
+                            separator='\t',
+                        )
                 elif file_extension_ == ".parquet" or file_extension_ == ".parq":
                     if pyarrow:
                         pf = pyarrow.parquet.ParquetFile(file_path_)
@@ -3112,7 +3131,10 @@ def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
                     elif fastparquet:
                         return fastparquet.ParquetFile(file_path_).head(self.__preview_tabular_row_count).to_pandas()
                 elif (file_extension_ == ".npz" or file_extension_ == ".npy") and np:
-                    return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    if pd:
+                        return pd.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
+                    else:
+                        return pl.DataFrame(np.loadtxt(file_path_, max_rows=self.__preview_tabular_row_count))
             except Exception:
                 pass
             return None
@@ -3144,7 +3166,10 @@ def convert_to_tabular_artifact(file_path_, file_extension_, compression_=None):
             # because it will not upload the sample to that destination.
             # use report_media instead to not leak data
             if (
-                isinstance(artifact, pd.DataFrame)
+                (
+                    (pd and isinstance(artifact, pd.DataFrame))
+                    or (pl and isinstance(artifact, pl.DataFrame))
+                )
                 and self._task.get_logger().get_default_upload_destination() == Session.get_files_server_host()
             ):
                 self._task.get_logger().report_table("Tables", "summary", table_plot=artifact)
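The preview path only reaches polars when pandas is missing, and the two readers spell their options differently (`nrows`/`sep` versus `n_rows`/`separator`, with polars compression handling still a TODO above). A standalone sketch of the equivalent calls (file paths hypothetical):

```python
import polars as pl

# pandas spelling: pd.read_csv(path, nrows=10, sep="\t")
csv_preview = pl.read_csv("data.csv", n_rows=10)                  # hypothetical path
tsv_preview = pl.read_csv("data.tsv", n_rows=10, separator="\t")  # hypothetical path
```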
20 changes: 13 additions & 7 deletions clearml/logger.py
@@ -10,6 +10,11 @@
 
 from .debugging.log import LoggerRoot
 
+try:
+    import polars as pl
+except ImportError:
+    pl = None
+
 try:
     import pandas as pd
 except ImportError:
@@ -326,7 +331,7 @@ def report_table(
         title,  # type: str
         series,  # type: str
         iteration=None,  # type: Optional[int]
-        table_plot=None,  # type: Optional[pd.DataFrame, Sequence[Sequence]]
+        table_plot=None,  # type: Optional[pd.DataFrame, pl.DataFrame, Sequence[Sequence]]
         csv=None,  # type: Optional[str]
         url=None,  # type: Optional[str]
         extra_layout=None,  # type: Optional[dict]
@@ -392,15 +397,15 @@ def report_table(
         mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
         table = table_plot
         if url or csv:
-            if not pd:
+            if not pd and not pl:
                 raise UsageError(
-                    "pandas is required in order to support reporting tables using CSV or a URL, "
-                    "please install the pandas python package"
+                    "pandas or polars is required in order to support reporting tables using CSV "
+                    "or a URL, please install the pandas or polars python package"
                 )
             if url:
-                table = pd.read_csv(url, index_col=[0])
+                table = pd.read_csv(url, index_col=[0]) if pd else pl.read_csv(url)
             elif csv:
-                table = pd.read_csv(csv, index_col=[0])
+                table = pd.read_csv(csv, index_col=[0]) if pd else pl.read_csv(csv)
 
         def replace(dst, *srcs):
             for src in srcs:
@@ -409,7 +414,8 @@ def replace(dst, *srcs):
         if isinstance(table, (list, tuple)):
             reporter_table = table
         else:
-            reporter_table = table.fillna(str(np.nan))
+            nan = str(np.nan)
+            reporter_table = table.fillna(nan) if pd and isinstance(table, pd.DataFrame) else table.fill_nan(nan)
         replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
         replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
         minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]
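With only polars installed, the CSV/URL path now works too, with one behavioral difference worth noting: the pandas reader treats the first column as the row index (`index_col=[0]`), while the polars reader keeps it as an ordinary column. A sketch, assuming a configured server (names and path hypothetical):

```python
from clearml import Task

task = Task.init(project_name="examples", task_name="csv table")  # hypothetical names

# works with either pandas or polars installed; with polars the first CSV
# column appears as a regular column rather than as the row index
task.get_logger().report_table(title="data", series="from csv", iteration=0, csv="table.csv")
```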
20 changes: 15 additions & 5 deletions clearml/model.py
@@ -14,6 +14,10 @@
     import pandas as pd
 except ImportError:
     pd = None
+try:
+    import polars as pl
+except ImportError:
+    pl = None
 
 from .backend_api import Session
 from .backend_api.services import models, projects
@@ -638,15 +642,21 @@ def report_table(
         )
         table = table_plot
         if url or csv:
-            if not pd:
+            if not pd and not pl:
                 raise UsageError(
-                    "pandas is required in order to support reporting tables using CSV or a URL, "
-                    "please install the pandas python package"
+                    "pandas or polars is required in order to support reporting tables using CSV or a URL, "
+                    "please install the pandas or polars python package"
                 )
             if url:
-                table = pd.read_csv(url, index_col=[0])
+                if pd:
+                    table = pd.read_csv(url, index_col=[0])
+                else:
+                    table = pl.read_csv(url)
             elif csv:
-                table = pd.read_csv(csv, index_col=[0])
+                if pd:
+                    table = pd.read_csv(csv, index_col=[0])
+                else:
+                    table = pl.read_csv(csv)
 
         def replace(dst, *srcs):
             for src in srcs:
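Model-level reporting follows the same dual-backend pattern as `Logger.report_table`; a sketch using the `report_table` entry point this hunk patches, assuming a configured server (names hypothetical):

```python
import polars as pl
from clearml import OutputModel, Task

task = Task.init(project_name="examples", task_name="model report")  # hypothetical names
model = OutputModel(task=task)

metrics = pl.DataFrame({"class": ["cat", "dog"], "precision": [0.91, 0.87]})
model.report_table(title="eval", series="per-class", iteration=0, table_plot=metrics)
```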