compare_runs_v2_w_pyplot.py
1 """@namespace IMP.nestor.compare_runs_v2_w_pyplot
2  Plotting script to compare NestOR runs"""
3 
4 import os
5 import sys
6 import math
7 import yaml
8 import numpy as np
9 import matplotlib as mpl
10 from matplotlib import pyplot as plt
11 
12 
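# Usage (illustrative):
#   python compare_runs_v2_w_pyplot.py "<plot title>" <run_dir_1> [<run_dir_2> ...]
# Each run directory is expected to contain a nestor_output.yaml file.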
TITLE = sys.argv[1]
runs_to_compare = sys.argv[2:]
transparency = 0.75

# Set default parameters for matplotlib
plt.rcParams["font.family"] = "sans-serif"
mpl.rcParams["font.size"] = 12


def get_all_results(all_runs: list) -> dict:
    # Read all nestor output files
    results = {}
    for run in all_runs:
        with open(os.path.join(run, "nestor_output.yaml"), "r") as resf:
            result = yaml.safe_load(resf)
        results[run.split("/")[-1]] = result
    return results


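# The loaded results are indexed as results[trial_set][run_set][run][key],
# e.g. results["x_fpi"]["res_01"]["run_0"]["log_estimated_evidence"]
# (the run and key names shown here are illustrative).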
def mean_type_plotter(results: dict, key: str, ylabel: str):
    # Plots the mean values for the key
    data = []
    for parent in results:  # parent is a trial set (init_x or x_fpi)
        x_vals = []
        y_vals = []
        for run_set in results[parent]:  # runset is res_01
            all_vals = []
            for run in results[parent][run_set]:
                try:
                    val = float(results[parent][run_set][run][key])
                except ValueError as err:
                    print(f"Terminating due to the following error...\n{err}")
                    return None
                all_vals.append(val)
            x_vals.append(run_set)
            y_vals.append(np.mean(all_vals))
        data.append((x_vals, y_vals, parent))

    fig = plt.figure()
    for datum in data:
        datum = list(datum)
        datum[0] = [str(x.split("_")[-1]) for x in datum[0]]
        plt.scatter(datum[0], datum[1], label=datum[2], alpha=transparency)

    plt.xlabel("Representation (number of residues per bead)")
    plt.ylabel(ylabel)
    plt.title(TITLE)
    # fig.legend(bbox_to_anchor=(1.15, 1.0), loc="upper right")
    fig.savefig(f"{ylabel}_comparison.png", bbox_inches="tight", dpi=600)
    plt.close()


def errorbar_type_plotter(results: dict, key: str, ylabel: str):
    # Plots mean values for the key with standard-error error bars
    data = []
    for parent in results:  # parent is a trial set (init_x or x_fpi)
        xvals = []
        yvals = []
        yerr = []
        for run_set in results[parent]:  # runset is res_01
            xvals.append(run_set)
            all_vals = [
                float(results[parent][run_set][run][key])
                for run in results[parent][run_set]
            ]
            yvals.append(np.mean(all_vals))
            # Standard error of the mean
            yerr.append(np.std(all_vals) / (math.sqrt(len(all_vals))))

        data.append((xvals, yvals, yerr, parent))

    fig = plt.figure()
    for idx, datum in enumerate(data):
        datum = list(datum)
        datum[0] = [str(x.split("_")[-1]) for x in datum[0]]
        plt.errorbar(
            datum[0],
            datum[1],
            yerr=datum[2],
            label=datum[3],
            fmt="o",
            alpha=transparency,
            c=f"C{idx}",
        )
    plt.xlabel("Representation (number of residues per bead)")
    if "log" in ylabel:
        # plt.rcParams["text.usetex"] = True
        ylabel = "Mean log$Z$"

    plt.ylabel(ylabel)
    plt.title(TITLE)
    fig.legend(bbox_to_anchor=(1.15, 1.0), loc="upper right")
    fig.savefig(f"{ylabel}_comparison.png", bbox_inches="tight", dpi=600)
    plt.close()


def plot_sterr(results: dict):
    """Plots standard error comparison"""
    data = []
    for parent in results:  # parent is a trial set (init_x or x_fpi)
        x_vals = []
        y_vals = []
        for run_set in results[parent]:  # runset is res_01
            log_evi = []
            for run in results[parent][run_set]:
                log_evi.append(float(
                    results[parent][run_set][run]["log_estimated_evidence"]))
            stderr_log_evi = np.std(log_evi) / (math.sqrt(len(log_evi)))
            x_vals.append(run_set)
            y_vals.append(stderr_log_evi)
        data.append((x_vals, y_vals, parent))

    fig = plt.figure()
    for datum in data:
        plt.scatter(
            [x.split("_")[-1] for x in datum[0]],  # datum[0].split("_")[-1],
            datum[1],
            label=datum[2],
            alpha=transparency,
        )

    plt.xlabel("Representation (number of residues per bead)")
    plt.ylabel("Standard error on log(Evidence)")
    plt.title(TITLE)
    # fig.legend() # bbox_to_anchor=(1.15, 1.0),loc="upper right"
    fig.savefig("stderr_comparison.png", bbox_inches="tight", dpi=600)
    plt.close()


# --------------------------------------------------------------------
# ---------------------------------------------- Main ----------------
# --------------------------------------------------------------------
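# The code below produces one "<ylabel>_comparison.png" file per metric
# listed in the dictionaries that follow, plus "stderr_comparison.png",
# all saved to the current working directory at 600 dpi.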

nestor_results = get_all_results(runs_to_compare)


toPlot_meanType: dict = {
    "analytical_uncertainty": "Mean analytical uncertainties",
}

toPlot_errorbarType: dict = {
    "last_iter": "Mean iterations",
    "log_estimated_evidence": "Mean log(Z)",
    "nestor_process_time": "Mean NestOR process time",
    # "mcmc_step_time": "Mean time per MCMC step",
}
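# Additional metrics from nestor_output.yaml can be compared by adding
# their key/axis-label pairs to the dictionaries above.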

for key, y_lbl in toPlot_meanType.items():
    mean_type_plotter(
        nestor_results,
        key=key,
        ylabel=y_lbl,
    )

for key, y_lbl in toPlot_errorbarType.items():
    errorbar_type_plotter(
        nestor_results,
        key=key,
        ylabel=y_lbl,
    )

plot_sterr(nestor_results)