generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path quantreader_utils.py
More file actions
90 lines (70 loc) · 2.9 KB
/
quantreader_utils.py
File metadata and controls
90 lines (70 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import contextlib
import logging
import pandas as pd
import pyarrow.parquet
LOGGER = logging.getLogger(__name__)
def filter_input(filter_dict, input):
    """Filter a DataFrame according to a dict of comparison rules.

    Each entry of ``filter_dict`` maps a filter name to a config dict with
    keys ``param`` (column name), ``comparator`` (one of ``>``, ``>=``,
    ``<``, ``<=``, ``==``, ``!=``) and ``value`` (the comparison operand).
    Filters are applied sequentially, each narrowing the result further.

    ``==`` compares the raw column against ``value`` without any casting.
    For all other comparators the column is first (best-effort) cast to
    float, then cast to ``type(value)`` for the comparison itself.

    Returns ``input`` unchanged when ``filter_dict`` is None.
    Raises TypeError for an unrecognized comparator.
    """
    if filter_dict is None:
        return input

    # Dispatch table for the ordering comparators; "==" is special-cased
    # below because it skips the float/type casting entirely.
    compare = {
        ">": lambda col, val: col > val,
        ">=": lambda col, val: col >= val,
        "<": lambda col, val: col < val,
        "<=": lambda col, val: col <= val,
        "!=": lambda col, val: col != val,
    }

    for name, conf in filter_dict.items():
        param = conf.get("param")
        comparator = conf.get("comparator")
        value = conf.get("value")

        if comparator not in (">", ">=", "<", "<=", "==", "!="):
            raise TypeError(
                f"cannot identify the filter comparator of {name} given in the longtable config yaml!"
            )

        if comparator == "==":
            # Equality is evaluated on the untouched column values.
            input = input[input[param] == value]
            continue

        # Best effort: make the column numeric; silently keep the original
        # dtype when the cast is impossible (e.g. free-text columns).
        with contextlib.suppress(Exception):
            input = input.astype({f"{param}": "float"})

        input = input[compare[comparator](input[param].astype(type(value)), value)]

    return input
def read_file(file_path, decimal=".", usecols=None, chunksize=None, sep=None, encoding="latin1"):
    """Read a tabular file (parquet, csv or tsv) into a pandas DataFrame.

    Parameters
    ----------
    file_path : str or path-like
        Path to the input file. Files ending in ``.parquet`` are read as
        parquet; otherwise the file is read as delimited text.
    decimal : str
        Decimal separator passed to ``pd.read_csv`` (text files only).
    usecols : list or None
        Subset of columns to load.
    chunksize : int or None
        When given, return an iterator of DataFrame chunks instead of a
        single DataFrame.
    sep : str or None
        Column separator for text files. When None it is inferred from the
        file name: "," if the name contains ".csv", "\\t" if it contains
        ".tsv", and "\\t" (with a logged hint) otherwise.
    encoding : str
        Text encoding for delimited files. Defaults to "latin1" to keep
        backward-compatible behavior.

    Returns
    -------
    pandas.DataFrame, or an iterator of DataFrames when ``chunksize`` is set.
    """
    file_path = str(file_path)
    if file_path.endswith(".parquet"):
        return _read_parquet_file(file_path, usecols=usecols, chunksize=chunksize)

    if sep is None:
        if ".csv" in file_path:
            sep = ","
        elif ".tsv" in file_path:
            sep = "\t"
        else:
            # Unknown extension: fall back to tab and tell the user.
            sep = "\t"
            LOGGER.info(
                "neither of the file extensions (.tsv, .csv) detected for file %s! "
                "Trying with tab separation. In the case that it fails, please "
                "provide the correct file extension",
                file_path,
            )
    return pd.read_csv(
        file_path,
        sep=sep,
        decimal=decimal,
        usecols=usecols,
        encoding=encoding,
        chunksize=chunksize,
    )
def _read_parquet_file(file_path, usecols=None, chunksize=None):
    """Read a parquet file, either in one go or chunk-wise.

    When ``chunksize`` is None the full table is materialised as a single
    DataFrame; otherwise a generator of DataFrame chunks is returned.
    """
    if chunksize is None:
        return pd.read_parquet(file_path, columns=usecols)
    return _read_parquet_file_chunkwise(file_path, usecols=usecols, chunksize=chunksize)
def _read_parquet_file_chunkwise(file_path, usecols=None, chunksize=None):
    """Yield a parquet file as pandas DataFrames of at most ``chunksize`` rows."""
    batches = pyarrow.parquet.ParquetFile(file_path).iter_batches(
        columns=usecols, batch_size=chunksize
    )
    for record_batch in batches:
        yield record_batch.to_pandas()
def read_columns_from_file(file, sep="\t"):
    """Return the column names of a tabular file without loading its data.

    Parameters
    ----------
    file : str
        Path to a ``.parquet`` file (columns taken from the schema) or a
        delimited text file (columns taken from the header line).
    sep : str
        Column separator for text files; defaults to tab.

    Returns
    -------
    list of str
        The column names in file order.
    """
    if file.endswith(".parquet"):
        parquet_file = pyarrow.parquet.ParquetFile(file)
        return parquet_file.schema.names
    # nrows=0 parses only the header row — the original nrows=1 needlessly
    # read and type-inferred one data row just to discard it.
    return pd.read_csv(file, sep=sep, nrows=0).columns.tolist()