IMP logo
IMP Reference Guide  develop.63b38c487d,2024/12/22
The Integrative Modeling Platform
compare_runs_v2_w_pyplot.py
1 #!/usr/bin/env python
2 """@namespace IMP.nestor.compare_runs_v2_w_pyplot
3  Plotting script to compare NestOR runs"""
4 
5 import os
6 import sys
7 import math
8 import yaml
9 import numpy as np
10 import matplotlib as mpl
11 from matplotlib import pyplot as plt
12 
13 
# Command line: <plot title> <nestor_run_dir> [<nestor_run_dir> ...]
# Fail early with a usage message instead of an IndexError traceback (no
# title) or a set of empty figures (no run directories).
if len(sys.argv) < 3:
    sys.exit(
        f"Usage: {os.path.basename(sys.argv[0])} <title> "
        "<nestor_run_dir> [<nestor_run_dir> ...]"
    )

TITLE = sys.argv[1]  # used as the title of every figure
runs_to_compare = sys.argv[2:]  # each must contain a nestor_output.yaml
transparency = 0.75  # marker alpha shared by all plots
plt.rcParams["font.family"] = "sans-serif"

# Start with setting default parameters for matplotlib
mpl.rcParams["font.size"] = 12
21 
22 
def get_all_results(all_runs: list) -> dict:
    """Read the ``nestor_output.yaml`` of every NestOR run directory.

    Each parsed file is stored under the base name of its directory
    (``runs/init_1/`` -> ``"init_1"``). The plotting functions use that name
    as the trial-set label. They index the parsed content as
    ``[run_set][run][metric]``.

    :param all_runs: paths to run directories, each containing a
        ``nestor_output.yaml`` file.
    :returns: mapping of run directory name -> parsed YAML content.
    :raises FileNotFoundError: if a directory has no ``nestor_output.yaml``.
    """
    results = {}
    for run in all_runs:
        # normpath drops trailing separators, so "runs/init_1/" still maps to
        # "init_1" (splitting on "/" would yield an empty key there), and
        # basename also copes with Windows-style separators.
        run_name = os.path.basename(os.path.normpath(run))
        with open(
            os.path.join(run, "nestor_output.yaml"), "r", encoding="utf-8"
        ) as resf:
            result = yaml.safe_load(resf)
        if run_name in results:
            # Two directories with the same base name would otherwise
            # overwrite each other silently; keep the last one but say so.
            print(
                f"Warning: more than one run directory is named "
                f"'{run_name}'; only {run} is kept.",
                file=sys.stderr,
            )
        results[run_name] = result
    return results
31 
32 
def mean_type_plotter(results: dict, key: str, ylabel: str):
    """Scatter-plot the mean of ``key`` per representation for each trial set.

    ``results`` is indexed as ``results[parent][run_set][run][key]``.
    ``parent`` is a trial set (e.g. ``init_x`` or ``x_fpi``) and ``run_set``
    a representation such as ``res_01``. For each run_set the values over
    all runs are averaged. The mean is plotted against the part of the
    run_set name after its last ``_``. The figure is saved as
    ``"{ylabel}_comparison.png"``.

    If any value is missing or not convertible to ``float``, a message is
    printed and no figure is written. The caller can carry on with the
    other plots.

    :param results: nested NestOR results (see above).
    :param key: metric name read from every run.
    :param ylabel: y-axis label, also used as the output file name prefix.
    :returns: None
    """
    data = []
    for parent, run_sets in results.items():  # parent: init_x or x_fpi
        x_vals = []
        y_vals = []
        for run_set, runs in run_sets.items():  # run_set: e.g. res_01
            all_vals = []
            for run, metrics in runs.items():
                try:
                    val = float(metrics[key])
                except (KeyError, TypeError, ValueError) as err:
                    # KeyError: metric absent; TypeError: null value;
                    # ValueError: non-numeric string.
                    print(
                        f"Skipping the '{ylabel}' plot due to the following "
                        f"error ({parent}/{run_set}/{run})...\n{err!r}"
                    )
                    return None
                all_vals.append(val)
            x_vals.append(run_set)
            y_vals.append(np.mean(all_vals))
        data.append((x_vals, y_vals, parent))

    fig = plt.figure()
    for x_vals, y_vals, parent in data:
        # Show only the representation suffix, e.g. "res_01" -> "01"; str()
        # also tolerates run_set keys that YAML parsed as numbers.
        x_labels = [str(x).split("_")[-1] for x in x_vals]
        plt.scatter(x_labels, y_vals, label=parent, alpha=transparency)

    plt.xlabel("Representation (number of residues per bead)")
    plt.ylabel(ylabel)
    plt.title(TITLE)
    # fig.legend(bbox_to_anchor=(1.15, 1.0), loc="upper right")
    fig.savefig(f"{ylabel}_comparison.png", bbox_inches="tight", dpi=600)
    plt.close()
64 
65 
def errorbar_type_plotter(results: dict, key: str, ylabel: str):
    """Plot mean +/- standard error of ``key`` per representation.

    ``results`` is indexed as ``results[parent][run_set][run][key]``.
    ``parent`` is a trial set and ``run_set`` a representation such as
    ``res_01``. Each point is the mean over the runs of one run_set. The
    error bar is ``np.std(values) / sqrt(n_runs)``. Each trial set gets its
    own colour (``C0``, ``C1``, ...) and a legend entry. The figure is saved
    as ``"{ylabel}_comparison.png"``.

    If any value is missing or not convertible to ``float``, a message is
    printed and no figure is written. The remaining plots of the script
    still run.

    :param results: nested NestOR results (see above).
    :param key: metric name read from every run.
    :param ylabel: y-axis label, also used as the output file name prefix.
    :returns: None
    """
    data = []
    for parent, run_sets in results.items():  # parent: init_x or x_fpi
        xvals = []
        yvals = []
        yerr = []
        for run_set, runs in run_sets.items():  # run_set: e.g. res_01
            try:
                all_vals = [float(metrics[key]) for metrics in runs.values()]
            except (KeyError, TypeError, ValueError) as err:
                # Same handling as mean_type_plotter: skip this figure
                # instead of aborting every plot that follows.
                print(
                    f"Skipping the '{ylabel}' plot due to the following "
                    f"error ({parent}/{run_set})...\n{err!r}"
                )
                return None
            xvals.append(run_set)
            yvals.append(np.mean(all_vals))
            # NOTE(review): np.std defaults to ddof=0 (population std); a
            # standard error of the mean usually uses ddof=1 - confirm intent.
            yerr.append(np.std(all_vals) / (math.sqrt(len(all_vals))))

        data.append((xvals, yvals, yerr, parent))

    fig = plt.figure()
    for idx, (xs, ys, errs, parent) in enumerate(data):
        plt.errorbar(
            # Representation suffix only, e.g. "res_01" -> "01".
            [str(x).split("_")[-1] for x in xs],
            ys,
            yerr=errs,
            label=parent,
            fmt="o",
            alpha=transparency,
            c=f"C{idx}",
        )
    plt.xlabel("Representation (number of residues per bead)")
    if "log" in ylabel:
        # plt.rcParams["text.usetex"] = True
        # NOTE(review): the mathtext label is also reused for the file name
        # below ("Mean log$Z$_comparison.png").
        ylabel = "Mean log$Z$"

    plt.ylabel(ylabel)
    plt.title(TITLE)
    fig.legend(bbox_to_anchor=(1.15, 1.0), loc="upper right")
    fig.savefig(f"{ylabel}_comparison.png", bbox_inches="tight", dpi=600)
    plt.close()
106 
107 
def plot_sterr(results: dict):
    """Plots standard error comparison.

    For every trial set (``parent``) and representation (``run_set``) the
    standard error of ``log_estimated_evidence`` over the runs is computed as
    ``np.std(values) / sqrt(n_runs)``. It is plotted against the part of the
    run_set name after its last ``_``. The figure is saved as
    ``stderr_comparison.png``.

    If any evidence value is missing or not convertible to ``float``, a
    message is printed and no figure is written.

    :param results: nested NestOR results indexed as
        ``results[parent][run_set][run]["log_estimated_evidence"]``.
    :returns: None
    """
    data = []
    for parent, run_sets in results.items():  # parent: init_x or x_fpi
        x_vals = []
        y_vals = []
        for run_set, runs in run_sets.items():  # run_set: e.g. res_01
            try:
                log_evi = [
                    float(r_temp["log_estimated_evidence"])
                    for r_temp in runs.values()
                ]
            except (KeyError, TypeError, ValueError) as err:
                # Same handling as mean_type_plotter: report and skip.
                print(
                    "Skipping the standard error plot due to the following "
                    f"error ({parent}/{run_set})...\n{err!r}"
                )
                return None
            # NOTE(review): np.std defaults to ddof=0 - confirm intent.
            stderr_log_evi = np.std(log_evi) / (math.sqrt(len(log_evi)))
            x_vals.append(run_set)
            y_vals.append(stderr_log_evi)
        data.append((x_vals, y_vals, parent))

    fig = plt.figure()
    for xs, ys, parent in data:
        plt.scatter(
            [str(x).split("_")[-1] for x in xs],  # e.g. "res_01" -> "01"
            ys,
            label=parent,
            alpha=transparency,
        )

    plt.xlabel("Representation (number of residues per bead)")
    plt.ylabel("Standard error on log(Evidence)")
    plt.title(TITLE)
    # fig.legend() # bbox_to_anchor=(1.15, 1.0),loc="upper right"
    fig.savefig("stderr_comparison.png", bbox_inches="tight", dpi=600)
    plt.close()
139 
140 
141 # --------------------------------------------------------------------
142 # ---------------------------------------------- Main ----------------
143 # --------------------------------------------------------------------
144 
nestor_results = get_all_results(runs_to_compare)

# Plotted as the plain mean over runs.
# Keys are metric names in nestor_output.yaml; values are the y-axis labels
# (which also prefix the output file names).
toPlot_meanType: dict = {
    "analytical_uncertainty": "Mean analytical uncertainties",
}

# Plotted as mean +/- standard error over runs (same key -> label layout).
toPlot_errorbarType: dict = {
    "last_iter": "Mean iterations",
    "log_estimated_evidence": "Mean log(Z)",
    "nestor_process_time": "Mean NestOR process time",
    # "mcmc_step_time": "Mean time per MCMC step",
}

# Dispatch table: every mean-type figure is drawn first, then every
# error-bar figure, and finally the standard-error figure.
for plot_func, metric_labels in (
    (mean_type_plotter, toPlot_meanType),
    (errorbar_type_plotter, toPlot_errorbarType),
):
    for key, y_lbl in metric_labels.items():
        plot_func(nestor_results, key=key, ylabel=y_lbl)

plot_sterr(nestor_results)
def plot_sterr
Plots standard error comparison.