refactoring of file io
Amit Jaiswal committed Nov 1, 2024
1 parent 412746c commit 0944c98
Showing 3 changed files with 99 additions and 44 deletions.
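
The bulk of the changes below move DataFrame construction and iteration from tab-joined strings to explicit header/data field lists. As a rough sketch (illustrative only, not part of the diff; the column names and values are made up), the two construction paths after this commit look like:

    from omigo_core import dataframe

    # header is a list of column names; data is a list of field lists, with no tab-joining by the caller
    df1 = dataframe.new_with_cols(["col1", "col2"], data = [["a", "1"], ["b", "2"]])

    # tab-separated header and data strings can be converted with the new from_header_data helper
    df2 = dataframe.from_header_data("col1\tcol2", ["a\t1", "b\t2"])
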
103 changes: 61 additions & 42 deletions python-packages/core/src/omigo_core/dataframe.py
@@ -49,18 +49,19 @@ def __init__(self, header_fields, data_fields):

        # workaround
        if (len(self.header_fields) == 1):
-            if (len(data_fields) > 0):
-                if (isinstance(data_fields[0], (str))):
-                    data_fields = list([[t] for t in data_fields])
+            if (len(self.data_fields) > 0):
+                # utils.info("data_fields[0]: {}".format(self.data_fields[0]))
+                if (isinstance(self.data_fields[0], (str))):
+                    self.data_fields = list([[t] for t in self.data_fields])

        # basic validation
        if (len(self.data_fields) > 0):
            if (len(self.header_fields) > 1):
-                if (len(data_fields[0]) != len(self.header_fields)):
-                    raise Exception("Header length: {} is not matching with data length: {}".format(len(self.header_fields), len(data_fields[0])))
+                if (len(self.data_fields[0]) != len(self.header_fields)):
+                    raise Exception("Header length: {} is not matching with data length: {}".format(len(self.header_fields), len(self.data_fields[0])))
            else:
                if (len(data_fields[0]) != 1):
-                    raise Exception("Header length: {} is not matching with data length: {}".format(len(self.header_fields), len(data_fields[0])))
+                    raise Exception("Header length: {} is not matching with data length: {}".format(len(self.header_fields), len(self.data_fields[0])))

    # debugging
    def to_string(self):
@@ -3208,7 +3209,7 @@ def __join__(self, that, lkeys, rkeys = None, join_type = "inner", lsuffix = Non
        for fields in that.get_data_fields():
            # report progress
            counter = counter + 1
-            utils.report_progress("[2/3] building map for right side", dmsg, counter, len(that.get_data()))
+            utils.report_progress("[2/3] building map for right side", dmsg, counter, len(that.get_data_fields()))

            # parse data
            rvals1 = list([fields[i] for i in rkey_indexes])
@@ -3530,7 +3531,7 @@ def __map_join__(self, that, lkeys, rkeys = None, join_type = "inner", lsuffix =
        for fields in that.get_data_fields():
            # report progress
            counter = counter + 1
-            utils.report_progress("__map_join__: building map for right side", dmsg, counter, len(that.get_data()))
+            utils.report_progress("__map_join__: building map for right side", dmsg, counter, len(that.get_data_fields()))

            # parse data
            rvals1 = list([fields[i] for i in rkey_indexes])
@@ -3755,7 +3756,7 @@ def __split_batches_by_cols__(self, num_batches, cols, seed = 0, dmsg = ""):
        for fields in hashed_tsv2.get_data_fields():
            # report progress
            counter = counter + 1
-            utils.report_progress("__split_batches_by_cols__: [1/1] assigning batch index", dmsg, counter, len(hashed_tsv2.get_data()))
+            utils.report_progress("__split_batches_by_cols__: [1/1] assigning batch index", dmsg, counter, len(hashed_tsv2.get_data_fields()))

            batch_id = int(fields[batch_index])
            new_data_fields_list[batch_id].append(fields)
@@ -4369,11 +4370,6 @@ def explode_json(self, col, prefix = None, accepted_cols = None, excluded_cols =
        if (merge_list_method == "cogroup"):
            utils.print_code_todo_warning("{}: merge_list_method = cogroup is only meant for data exploration. Use merge_list_method = join for generating all combinations for multiple list values".format(dmsg))

-        # name prefix
-        if (prefix is None):
-            utils.warn("{}: prefix is None. Using col as the name prefix".format(dmsg))
-            prefix = col
-
        # check for explode function
        exp_func = self.__explode_json_transform_func__(col, accepted_cols = accepted_cols, excluded_cols = excluded_cols, single_value_list_cols = single_value_list_cols,
            transpose_col_groups = transpose_col_groups, merge_list_method = merge_list_method, url_encoded_cols = url_encoded_cols, nested_cols = nested_cols,
@@ -4385,18 +4381,22 @@ def explode_json(self, col, prefix = None, accepted_cols = None, excluded_cols =
        # use explode to do this parsing
        return self \
            .add_seq_num(prefix + ":__json_index__", dmsg = dmsg) \
-            .explode([col], exp_func, prefix = prefix, default_val = default_val, collapse = collapse, dmsg = dmsg) \
+            .explode([col], exp_func, prefix, default_val = default_val, collapse = collapse, dmsg = dmsg) \
            .validate()

    # newer version of explode_json
-    def explode_json_v2(self, col, dmsg = "", **kwargs):
+    def explode_json_v2(self, col, prefix = None, dmsg = "", **kwargs):
        dmsg = utils.extend_inherit_message(dmsg, "explode_json_v2")

        # input
        xinput = self \
            .select(col) \
            .is_nonempty_str(col)

+        # check for url encoding
+        if (utils.is_url_encoded_col(col)):
+            xinput = xinput.url_decode_inline(col)
+
        # get all original values
        vs = xinput.col_as_array_uniq(col)

@@ -4410,40 +4410,59 @@ def explode_json_v2(self, col, dmsg = "", **kwargs):
            hash_value = utils.compute_hash(v)

            # TODO: hack
-            if (v.startswith("{'") or v.startswith("\"{'")):
+            if (v.startswith("{'") or v.startswith("\"{'") or v.startswith("[{'") or v.startswith("\"[{'")):
                utils.trace("{}: Found non standard json with mix of single and double quotes. Transforming: {}".format(dmsg, v))
                v = v.replace("\"", "")
                v = v.replace("'", "\"")

            # load as json
            mp = json.loads(v) if (v != "") else {}

-            # convert to string values
-            for k in mp.keys():
-                mp[k] = str(mp[k])
+            # this could be array or a map
+            if (isinstance(mp, (list))):
+                for m in mp:
+                    # convert to string values
+                    for k in m.keys():
+                        m[k] = str(m[k])

-            # append hash
-            mp[temp_col] = hash_value
+                    # assign hash value
+                    m[temp_col] = hash_value

-            # append
-            json_arr.append(json.dumps(mp))
+                    # append
+                    json_arr.append(json.dumps(m))
+            else:
+                # convert to string values
+                for k in mp.keys():
+                    mp[k] = str(mp[k])
+
+                # append hash
+                mp[temp_col] = hash_value
+
+                # append
+                json_arr.append(json.dumps(mp))

        # create string
        json_str = "[" + ",".join(json_arr) + "]"

        # parse
        df = pd.read_json(StringIO(json_str))

+        # reassess prefix
+        if (prefix is None):
+            prefix = col
+
        # result
        result = from_df(df) \
-            .add_prefix(col, dmsg = dmsg)
+            .add_prefix(prefix, dmsg = dmsg) \
+            .show_transpose(3, title = "result")

        # return
        return self \
-            .transform(col, utils.compute_hash, "{}:{}".format(col, temp_col), dmsg = dmsg) \
+            .transform(col, lambda t: utils.compute_hash(utils.url_decode(t)) if (utils.is_url_encoded_col(col)) else utils.compute_hash(t),
+                "{}:{}".format(prefix, temp_col), dmsg = dmsg) \
            .drop_cols(col, dmsg = dmsg) \
-            .left_map_join(result, "{}:{}".format(col, temp_col), dmsg = dmsg) \
-            .drop_cols("{}:{}".format(col, temp_col), dmsg = dmsg)
+            .left_map_join(result, "{}:{}".format(prefix, temp_col), dmsg = dmsg) \
+            .drop_cols("{}:{}".format(prefix, temp_col), dmsg = dmsg)

    def transpose(self, n = 1, dmsg = ""):
        dmsg = utils.extend_inherit_message(dmsg, "transpose")
@@ -4541,12 +4560,12 @@ def to_tuples(self, cols, dmsg = ""):
        # progress counters
        counter = 0
        dmsg = utils.extend_inherit_message(dmsg, "to_tuples")
-        for line in self.select(cols, dmsg = dmsg).get_data():
+        for fields in self.select(cols, dmsg = dmsg).get_data_fields():
            # report progress
            counter = counter + 1
            utils.report_progress("to_tuples: [1/1] converting to tuples", dmsg, counter, self.num_rows())

-            fields = line.split("\t")
+            # append
            result.append(self.__expand_to_tuple__(fields))

        # return
@@ -4823,8 +4842,7 @@ def __get_matching_cols__(self, col_or_cols, ignore_if_missing = False):
            # append
            col_patterns_transformed.append(col_pattern)

-        # now iterate through all the column names, check if it is a regular expression and find
-        # all matching ones
+        # now iterate through all the column names, check if it is a regular expression and find all matching ones
        matching_cols = []
        for col_pattern in col_patterns_transformed:
            # check for matching columns for the pattern
@@ -5156,6 +5174,7 @@ def from_df(df, url_encoded_cols = []):
        tsv_lines.append(line)

    # number of columns to skip with empty column name
+    utils.warn_once("from_df: this skip count logic is hacky")
    skip_count = 0
    for h in header_fields:
        if (h == ""):
@@ -5165,14 +5184,14 @@
    header_fields = header_fields[skip_count:]

    # generate data
-    data = []
+    data_fields = []
    if (len(tsv_lines) > 1):
-        for line in tsv_lines[1:]:
+        for line in tsv_lines[0:]:
            fields = line.split("\t")[skip_count:]
-            data.append("\t".join(fields))
+            data_fields.append(fields)

    # return
-    return DataFrame("\t".join(header_fields), data).validate()
+    return DataFrame(header_fields, data_fields).validate()

def from_maps(mps, accepted_cols = None, excluded_cols = None, url_encoded_cols = None, dmsg = ""):
    dmsg = utils.extend_inherit_message(dmsg, "from_maps")
Expand Down Expand Up @@ -5210,6 +5229,11 @@ def from_tsv(xtsv):
# return
return DataFrame(header_fields, data_fields)

def from_header_data(header, data):
header_fields = header.split("\t")
data_fields = list([t.split("\t") for t in data])
return new_with_cols(header_fields, data = data_fields)

def enable_debug_mode():
utils.enable_debug_mode()

@@ -5222,13 +5246,8 @@ def set_report_progress_perc(perc):
def set_report_progress_min_thresh(thresh):
    utils.set_report_progress_min_thresh(thresh)

-# factory method
-def newWithCols(cols, data = []):
-    utils.warn("newWithCols is deprecated. Use new_with_cols instead")
-    return new_with_cols(cols, data = data)
-
def new_with_cols(cols, data = []):
-    return DataFrame("\t".join(cols), data)
+    return DataFrame(cols, data)

def create_empty():
    return new_with_cols([])
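
Beyond the field-list refactor, explode_json_v2 above gains an optional prefix parameter, handles JSON values that are either a single map or a list of maps, and url-decodes columns that follow the :url_encoded naming convention. A rough, hypothetical usage sketch (not part of the commit; the column names and data are made up):

    from omigo_core import dataframe

    xtsv = dataframe.new_with_cols(["id", "event:json"], data = [["1", '{"name": "click", "count": 2}']])

    # prefix defaults to the column name when omitted; here the exploded columns get an "event" prefix
    exploded = xtsv.explode_json_v2("event:json", prefix = "event")
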
5 changes: 5 additions & 0 deletions python-packages/core/src/omigo_core/utils.py
@@ -783,3 +783,8 @@ def get_argument_as_array(arg_or_args):
    else:
        return arg_or_args

+def is_url_encoded_col(col):
+    if (col is not None and col.endswith(":url_encoded")):
+        return True
+    else:
+        return False
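
For illustration, the new helper simply checks the :url_encoded suffix convention used elsewhere in the diff (these calls are examples, not part of the commit):

    from omigo_core import utils

    utils.is_url_encoded_col("payload:url_encoded")   # True
    utils.is_url_encoded_col("payload:json")          # False
    utils.is_url_encoded_col(None)                    # False, None is handled explicitly
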
35 changes: 33 additions & 2 deletions python-packages/hydra/src/omigo_hydra/hydra.py
@@ -1,5 +1,5 @@
-from omigo_core import utils
-from omigo_hydra import file_paths_data_reader, file_paths_util
+from omigo_core import utils, dataframe
+from omigo_hydra import file_paths_data_reader, file_paths_util, s3io_wrapper

def save_to_file(xtsv, output_file_name, s3_region = None, aws_profile = None):
    # do some validation
@@ -232,3 +232,34 @@ def load_from_files(filepaths, s3_region, aws_profile):

    return tsv.TSV(header, data)

+def read_json_files_from_directories_as_tsv(paths, s3_region = None, aws_profile = None):
+    # initialize fs
+    fs = s3io_wrapper.S3FSWrapper(s3_region = s3_region, aws_profile = aws_profile)
+
+    # result
+    result = []
+
+    # iterate through each directory
+    for path in paths:
+        # list all files
+        files = fs.list_leaf_dir(path)
+
+        # read file as set of lines
+        for f in files:
+            full_path = "{}/{}".format(path, f)
+            lines = fs.read_file_contents_as_text(full_path).split("\n")
+
+            # append to result
+            result = result + lines
+
+    # remove empty lines
+    result = list(filter(lambda t: t.strip() != "", result))
+
+    # create dataframe
+    header_fields = ["json"]
+    data_fields = list([t.split("\t") for t in result])
+    df = dataframe.new_with_cols(header_fields, data = data_fields)
+
+    # return
+    return df
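
A hypothetical end-to-end usage of the new reader, assuming the module is imported as omigo_hydra.hydra; the path, region, profile and prefix below are placeholders, not values from the commit:

    from omigo_hydra import hydra

    # every non-empty line from every file under the given directories lands in a single "json" column
    xtsv = hydra.read_json_files_from_directories_as_tsv(["s3://some-bucket/events/2024-11-01"], s3_region = "us-west-2", aws_profile = "default")

    # the json column can then be exploded into regular columns, e.g. with the updated explode_json_v2
    exploded = xtsv.explode_json_v2("json", prefix = "event")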
