-
Notifications
You must be signed in to change notification settings - Fork 2
/
calculate_performance_metrics.py
150 lines (133 loc) · 5.17 KB
/
calculate_performance_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3
"""
This script calculates performance metrics given a set of predictions and ground truth.
"""
import argparse
import h5py
import numpy as np
import pandas as pd
from scipy.spatial.distance import jensenshannon
from scipy.stats import pearsonr, spearmanr
def _load_predictions(path):
    """Read predicted tracks and quantities from an hdf5 file.

    Args:
        path: hdf5 file with datasets ``"track"`` and ``"quantity"``.

    Returns:
        ``(track, quantity)``: the full ``"track"`` array, and the
        ``"quantity"`` array -- or only its first column when that dataset
        is 2-D.
    """
    with h5py.File(path, "r") as hf:
        track = hf["track"][:]
        quantity_ds = hf["quantity"]
        if len(quantity_ds.shape) == 2:
            quantity = quantity_ds[:, 0]
        else:
            quantity = quantity_ds[:]
    return track, quantity


def _load_observed(path):
    """Load the observed track matrix from a .npy, .npz, .csv or .csv.gz file.

    For ``.npz`` archives the ``arr_0`` member is used. CSV files are read
    without a header, and their first column is used as the row index (and
    therefore dropped from the returned array).

    Raises:
        ValueError: if ``path`` has none of the supported extensions.
    """
    if path.endswith(".npz"):
        # np.load returns an NpzFile that keeps the archive open until it is
        # closed; the context manager releases it once arr_0 is in memory.
        with np.load(path) as archive:
            return archive["arr_0"]
    if path.endswith(".npy"):
        return np.load(path)
    if path.endswith((".csv.gz", ".csv")):
        return pd.read_csv(path, header=None, index_col=0).to_numpy()
    raise ValueError(
        f"File with observed PRO-cap data ({path}) must be numpy or csv format."
    )


def _validate_shapes(track, quantity, observed):
    """Check that predictions and observations can be compared column-wise.

    Raises:
        ValueError: on any dimension mismatch described by the messages below.
    """
    if track.ndim != 2 or observed.ndim != 2:
        raise ValueError(
            f"Predicted ({track.ndim}-D) and observed ({observed.ndim}-D) "
            "tracks must both be 2-D arrays."
        )
    if track.shape[0] != observed.shape[0]:
        raise ValueError(
            f"n predictions ({track.shape[0]}) and n observed ({observed.shape[0]}) do not match."
        )
    if track.shape[1] > observed.shape[1]:
        raise ValueError(
            f"Predicted tracks ({track.shape[1]}) are longer than observed ({observed.shape[1]})."
        )
    if (observed.shape[1] - track.shape[1]) % 4 != 0:
        raise ValueError(
            f"Padding around predicted tracks ({observed.shape[1] - track.shape[1]}) must be divisible by 4."
        )
    # Every downstream metric splits rows at width // 2 into two strands;
    # an odd width would silently misalign the two halves.
    if track.shape[1] % 2 != 0:
        raise ValueError(
            f"Predicted track width ({track.shape[1]}) must be even (two strands)."
        )
    if quantity.shape[0] != track.shape[0]:
        raise ValueError(
            f"n quantities ({quantity.shape[0]}) and n predictions ({track.shape[0]}) do not match."
        )


def _clip_observed(observed, track_width):
    """Trim padding from both halves (strands) of each observed row.

    Each half of ``observed`` is trimmed by
    ``(observed.shape[1] - track_width) // 4`` columns on each side, so the
    result has ``track_width`` columns laid out as two halves again.
    """
    half = observed.shape[1] // 2
    start = (observed.shape[1] - track_width) // 4
    end = half - start
    return observed[:, np.r_[start:end, half + start : half + end]]


def _log_strand_ratio(tracks):
    """Per row: log1p(sum of first half) - log1p(sum of second half)."""
    half = tracks.shape[1] // 2
    return np.log1p(tracks[:, :half].sum(axis=1)) - np.log1p(
        tracks[:, half:].sum(axis=1)
    )


def _strand_argmax(tracks, strand_break):
    """Argmax column within each half, first halves then second halves.

    Second-half positions are relative to ``strand_break``.
    """
    return np.concatenate(
        [
            tracks[:, :strand_break].argmax(axis=1),
            tracks[:, strand_break:].argmax(axis=1),
        ]
    )


def main():
    """Compare predicted tracks/quantities against observed tracks.

    Command-line arguments:
        predictions: hdf5 file with ``"track"`` and ``"quantity"`` datasets.
        observed: csv(.gz), npy or npz file with the observed tracks.
        output: hdf5 path the metrics are written to (overwritten).

    Prints a summary to stdout and writes the per-row track Pearson and JS
    distance arrays plus the (statistic, p-value) pairs of the summary
    correlations to ``output``.

    Raises:
        ValueError: for an unsupported observed-file format or incompatible
            array shapes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "predictions",
        type=str,
        help="An hdf5 file containing predictions (track, quantity).",
    )
    parser.add_argument(
        "observed",
        type=str,
        help="A csv(.gz), npy, or npz file containing the observed procap tracks.",
    )
    parser.add_argument(
        "output",
        type=str,
        help="An hdf5 file to write the performance metrics to.",
    )
    args = parser.parse_args()

    track, quantity = _load_predictions(args.predictions)
    observed = _load_observed(args.observed)
    _validate_shapes(track, quantity, observed)

    # Observed tracks carry extra flanking columns; cut them down to the
    # predicted window so both arrays align column-for-column.
    observed_clipped = _clip_observed(observed, track.shape[1])

    # Benchmark directionality (log ratio of the two halves' totals).
    directionality_pearson = pearsonr(
        _log_strand_ratio(track), _log_strand_ratio(observed_clipped)
    )

    # Benchmark TSS position (argmax within each half).
    strand_break = track.shape[1] // 2
    tss_pos_pearson = pearsonr(
        _strand_argmax(track, strand_break),
        _strand_argmax(observed_clipped, strand_break),
    )

    # Benchmark profile (per-row Pearson and Jensen-Shannon distance).
    track_pearson = pd.DataFrame(track).corrwith(
        pd.DataFrame(observed_clipped), axis=1
    )
    track_js_distance = jensenshannon(track, observed_clipped, axis=1)
    js_series = pd.Series(track_js_distance)

    # Benchmark quantity against total observed signal per row.
    observed_quantity = observed_clipped.sum(axis=1)
    quantity_log_pearson = pearsonr(np.log1p(quantity), np.log1p(observed_quantity))
    quantity_spearman = spearmanr(quantity, observed_quantity)

    # Print summary
    print(f"Median Track Pearson: {track_pearson.median():.4f}")
    print(
        f"Mean Track Pearson: {track_pearson.mean():.4f} "
        + f"+/- {track_pearson.std():.4f}"
    )
    print(f"Median Track JS Distance: {js_series.median():.4f} ")
    print(
        f"Mean Track JS Distance: {js_series.mean():.4f} "
        + f"+/- {js_series.std():.4f}"
    )
    print(f"Track Directionality Pearson: {directionality_pearson[0]:.4f}")
    print(f"TSS Position Pearson: {tss_pos_pearson[0]:.4f}")
    print(f"Quantity Log Pearson: {quantity_log_pearson[0]:.4f}")
    print(f"Quantity Spearman: {quantity_spearman[0]:.4f}")

    # Save metrics; scipy result objects are converted to plain
    # (statistic, pvalue) arrays so every summary dataset has the same form.
    with h5py.File(args.output, "w") as hf:
        hf.create_dataset(
            "track_pearson", data=track_pearson.to_numpy(), compression="gzip"
        )
        hf.create_dataset(
            "track_js_distance", data=track_js_distance, compression="gzip"
        )
        hf.create_dataset(
            "track_directionality",
            data=np.array(directionality_pearson),
            compression="gzip",
        )
        hf.create_dataset(
            "tss_pos_pearson", data=np.array(tss_pos_pearson), compression="gzip"
        )
        hf.create_dataset(
            "quantity_log_pearson",
            data=np.array(quantity_log_pearson),
            compression="gzip",
        )
        hf.create_dataset(
            "quantity_spearman", data=np.array(quantity_spearman), compression="gzip"
        )
if __name__ == "__main__":
    # Script entry point: importing this module does not run the CLI.
    main()