1 """@namespace IMP.nestor.compare_runs_v2_w_pyplot
2 Plotting script to compare NestOR runs"""
9 import matplotlib
as mpl
10 from matplotlib
import pyplot
as plt
# Run directories to compare, taken straight from the command line.
# NOTE(review): argv[1] is skipped — presumably consumed by the wrapper
# that launches this script; confirm against the caller.
runs_to_compare = sys.argv[2:]

# Global matplotlib styling shared by every comparison plot below.
plt.rcParams["font.family"] = "sans-serif"
mpl.rcParams["font.size"] = 12
def get_all_results(all_runs: list) -> dict:
    """Collect NestOR output for every run directory.

    Reads ``nestor_output.yaml`` from each directory in *all_runs* and
    maps the directory's basename onto the parsed YAML contents.

    @param all_runs: paths of the run directories to load.
    @return: dict mapping run basename -> parsed ``nestor_output.yaml``.
    """
    results = {}
    for run in all_runs:
        with open(os.path.join(run, "nestor_output.yaml"), "r") as resf:
            result = yaml.safe_load(resf)
        # Key on the final path component so plot labels stay short.
        # NOTE(review): a trailing "/" on a path would yield an empty
        # key here — callers should pass bare directory names.
        results[run.split("/")[-1]] = result
    return results
def mean_type_plotter(results: dict, key: str, ylabel: str):
    """Scatter-plot the per-run-set mean of *key* for every parent run.

    For each parent entry in *results*, the values of *key* are averaged
    across the runs of every run set; the means are plotted against the
    run-set representation (residues per bead) and the figure is saved
    to ``{ylabel}_comparison.png``.

    @param results: nested dict: parent -> run_set -> run -> metrics.
    @param key: metric name to average (values must cast to float).
    @param ylabel: y-axis label; also used as the output file prefix.
    """
    data = []
    for parent in results:
        x_vals = []
        y_vals = []
        for run_set in results[parent]:
            all_vals = []
            for run in results[parent][run_set]:
                try:
                    val = float(results[parent][run_set][run][key])
                except ValueError as err:
                    print(f"Terminating due to the following error...\n{err}")
                    # Reconstructed: the message implies a hard exit here
                    # — TODO confirm against the original script.
                    sys.exit(1)
                all_vals.append(val)
            x_vals.append(run_set)
            y_vals.append(np.mean(all_vals))
        data.append((x_vals, y_vals, parent))

    fig = plt.figure()
    for datum in data:
        # Run-set names look like "..._<residues-per-bead>"; plot only
        # the trailing number. A local list is used instead of assigning
        # into datum[0], which is a tuple element and not assignable.
        xticks = [str(x.split("_")[-1]) for x in datum[0]]
        # `transparency` is a module-level constant defined elsewhere
        # in this file.
        plt.scatter(xticks, datum[1], label=datum[2], alpha=transparency)

    plt.xlabel("Representation (number of residues per bead)")
    plt.ylabel(ylabel)
    # Legend placement mirrors errorbar_type_plotter — TODO confirm.
    fig.legend(bbox_to_anchor=(1.15, 1.0), loc="upper right")
    fig.savefig(f"{ylabel}_comparison.png", bbox_inches="tight", dpi=600)
def errorbar_type_plotter(results: dict, key: str, ylabel: str):
    """Plot per-run-set means of *key* with standard-error bars.

    For each parent entry in *results*, the mean and standard error of
    the mean of *key* are computed over the runs of every run set and
    plotted against the run-set representation; the figure is saved to
    ``{ylabel}_comparison.png``.

    @param results: nested dict: parent -> run_set -> run -> metrics.
    @param key: metric name to aggregate (values must cast to float).
    @param ylabel: y-axis label; also used as the output file prefix.
    """
    data = []
    for parent in results:
        xvals = []
        yvals = []
        yerr = []
        for run_set in results[parent]:
            all_vals = [
                float(results[parent][run_set][run][key])
                for run in results[parent][run_set]
            ]
            xvals.append(run_set)
            yvals.append(np.mean(all_vals))
            # Standard error of the mean: sigma / sqrt(N).
            yerr.append(np.std(all_vals) / (math.sqrt(len(all_vals))))
        data.append((xvals, yvals, yerr, parent))

    fig = plt.figure()
    for datum in data:
        # Show only the trailing residues-per-bead number on the x axis.
        xticks = [str(x.split("_")[-1]) for x in datum[0]]
        # NOTE(review): the original errorbar call is not visible in
        # this chunk; marker/format options reconstructed — TODO confirm.
        plt.errorbar(xticks, datum[1], yerr=datum[2], label=datum[3],
                     fmt="o", alpha=transparency)

    plt.xlabel("Representation (number of residues per bead)")
    # NOTE(review): the original conditionally relabels the evidence
    # plot; the guard is reconstructed from context — TODO confirm.
    if key == "log_estimated_evidence":
        ylabel = "Mean log$Z$"
    plt.ylabel(ylabel)
    fig.legend(bbox_to_anchor=(1.15, 1.0), loc="upper right")
    fig.savefig(f"{ylabel}_comparison.png", bbox_inches="tight", dpi=600)
def plot_sterr(results: dict):
    """Plots standard error comparison.

    For each parent run, computes the standard error of the mean of the
    log-evidence across the runs in every run set and scatter-plots it
    against the representation; saved as ``stderr_comparison.png``.

    @param results: nested dict: parent -> run_set -> run -> metrics.
    """
    data = []
    for parent in results:
        x_vals = []
        y_vals = []
        for run_set in results[parent]:
            log_evi = []
            for run in results[parent][run_set]:
                log_evi.append(float(
                    results[parent][run_set][run][
                        "log_estimated_evidence"]))
            # Standard error of the mean: sigma / sqrt(N).
            stderr_log_evi = np.std(log_evi) / (math.sqrt(len(log_evi)))
            x_vals.append(run_set)
            y_vals.append(stderr_log_evi)
        data.append((x_vals, y_vals, parent))

    fig = plt.figure()
    for datum in data:
        plt.scatter(
            # Show only the trailing residues-per-bead number.
            [x.split("_")[-1] for x in datum[0]],
            datum[1],
            label=datum[2],
            # `transparency` is a module-level constant defined
            # elsewhere in this file.
            alpha=transparency,
        )
    plt.xlabel("Representation (number of residues per bead)")
    plt.ylabel("Standard error on log(Evidence)")
    # Legend placement mirrors errorbar_type_plotter — TODO confirm.
    fig.legend(bbox_to_anchor=(1.15, 1.0), loc="upper right")
    fig.savefig("stderr_comparison.png", bbox_inches="tight", dpi=600)
# ---- Script entry: compare all runs given on the command line. ----
nestor_results = get_all_results(runs_to_compare)

# Metrics plotted as plain per-run-set means (scatter).
# NOTE(review): further entries may exist in the original dict on lines
# not visible in this chunk — confirm before relying on this list.
toPlot_meanType: dict = {
    "analytical_uncertainty": "Mean analytical uncertainties",
}

# Metrics plotted as means with standard-error bars.
toPlot_errorbarType: dict = {
    "last_iter": "Mean iterations",
    "log_estimated_evidence": "Mean log(Z)",
    "nestor_process_time": "Mean NestOR process time",
}

for key, y_lbl in toPlot_meanType.items():
    mean_type_plotter(nestor_results, key, y_lbl)

for key, y_lbl in toPlot_errorbarType.items():
    errorbar_type_plotter(nestor_results, key, y_lbl)

plot_sterr(nestor_results)