IMP logo
IMP Reference Guide  2.21.0
The Integrative Modeling Platform
wrapper_v6.py
1 """@namespace IMP.nestor.wrapper_v6
2  Top-level NestOR script"""
3 
4 import os
5 import yaml
6 import math
7 import shutil
8 import argparse
9 import subprocess
10 import numpy as np
11 from mergedeep import merge
12 from ast import literal_eval
13 from matplotlib import pyplot as plt
14 
15 
16 ###################################################
17 # Functions
18 ###################################################
def parse_args():
    """Parse the command-line options of the NestOR wrapper script.

    Recognised options:
        -p  Required string, stored as ``paramf``.
        -t  Boolean flag, stored as ``topology``.
        -s  Boolean flag, stored as ``skip_calc``.

    Returns:
        argparse.Namespace: Namespace carrying ``paramf``, ``topology`` and
        ``skip_calc``.
    """
    paramf_help = "Absolute path to yaml file containing the input parameters"

    # One (flag, add_argument keyword arguments) entry per option, registered
    # in this order so the generated usage text keeps the same layout.
    option_table = (
        (
            "-p",
            {
                "dest": "paramf",
                "type": str,
                "required": True,
                "help": paramf_help,
            },
        ),
        (
            "-t",
            {
                "dest": "topology",
                "action": "store_true",
                "help": "Whether to use topology file",
            },
        ),
        (
            "-s",
            {
                "dest": "skip_calc",
                "action": "store_true",
                "help": "Running only the plotting functions?",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
42 
43 
def get_all_toruns(h_params, target_runs):
    """List every (resolution directory, run id) pair that has to be run.

    Args:
        h_params: Mapping providing ``parent_dir`` and ``resolutions``.
        target_runs: Iterable of run identifiers.

    Returns:
        list: Tuples ``(<parent_dir>/res_<resolution>, str(run_id))``, grouped
        by resolution in the order given, with the run ids in their own order
        inside each group.
    """
    base_dir = h_params["parent_dir"]
    return [
        (os.path.join(base_dir, f"res_{resolution}"), str(run_id))
        for resolution in h_params["resolutions"]
        for run_id in target_runs
    ]
52 
53 
def get_curr_processes_and_terminated_runs(processes: dict):
    """Harvest the processes that have exited and drop them from *processes*.

    Every entry of *processes* maps a ``(resolution_dir, run_id)`` key to a
    process handle exposing ``poll()``, ``wait()``, ``returncode`` and a
    readable ``stderr``. Entries whose process is still running are left
    untouched.

    Exit codes are handled as follows:
        0        -> reported in the successful list.
        11 / 12  -> reported in the faulty list; for 11 the directory
                    ``<resolution_dir>/run_<run_id>`` is also deleted.
        other    -> reported in neither list (the error is still printed).

    Args:
        processes: Mapping of run key to process handle. It is modified in
            place: finished entries are removed.

    Returns:
        tuple: ``(processes, faulty_runs, successful_runs)`` where the two
        lists hold ``(run_key, process)`` pairs.
    """
    relaunch_code = 11
    faulty_codes = (relaunch_code, 12)

    finished_keys = []
    failed = []
    succeeded = []

    # Pass 1: classify what has exited. The dict is not modified here because
    # it is being iterated.
    for run_key, handle in processes.items():
        if handle.poll() is None:
            continue

        handle.wait()
        finished_keys.append(run_key)

        exit_code = handle.returncode
        if exit_code in faulty_codes:
            failed.append((run_key, handle))
            if exit_code == relaunch_code:
                run_dir = os.path.join(run_key[0], f"run_{run_key[1]}")
                shutil.rmtree(run_dir)
        elif exit_code == 0:
            succeeded.append((run_key, handle))

    # Pass 2: report each finished run, then forget about it.
    for run_key in finished_keys:
        handle = processes[run_key]
        print(
            f"Terminated: {run_key[0].split('/')[-1]}, run_{run_key[1]} with "
            f"exit code: {handle.returncode}"
        )
        if handle.returncode != 0:
            print(f"Error:\n{handle.stderr.read()}")

        processes.pop(run_key)

    return processes, failed, succeeded
83 
84 
def plotter(results: dict, h_params):
    """Draw the summary plots for a set of NestOR results.

    Three figures are saved in ``h_params["parent_dir"]``, each named with the
    ``trial_<trial_name>_`` prefix:
        * ``..._evidence_errorbarplot.png``: mean log evidence per resolution
          with its standard error (``np.std / sqrt(n)``) as error bar.
        * ``..._proctime.png``: mean ``nestor_process_time`` per resolution.
        * ``..._persteptime.png``: mean ``mcmc_step_time`` per resolution.
    Finally ``plot_evi_proctime`` is called with the same arguments.

    Args:
        results: Mapping ``resolution key -> {run key -> run record}``; each
            run record provides ``log_estimated_evidence``,
            ``nestor_process_time`` and ``mcmc_step_time``. The first four
            characters of every resolution key are dropped to obtain the axis
            label (the keys written by this file look like ``res_<x>``).
        h_params: Mapping providing ``parent_dir`` and ``trial_name``.
    """

    def _trial_file(suffix):
        # Resolved lazily so the parameters are only read when a figure is
        # actually being written.
        return os.path.join(
            h_params["parent_dir"],
            f"trial_{h_params['trial_name']}_{suffix}",
        )

    def _scatter_figure(fig_num, values, y_label, suffix):
        # Points are plotted in the iteration order of `results`, unsorted.
        plt.figure(fig_num)
        plt.scatter(resolutions, values, c="C2", marker="o")
        plt.xlabel("Resolutions")
        plt.ylabel(y_label)
        plt.savefig(_trial_file(suffix))

    all_log_z = {}
    resolutions = []
    mean_proc_time = []
    mean_per_step_time = []

    plt.figure(1)
    for res_key, runs in results.items():
        res_label = res_key[4:]
        if res_label not in resolutions:
            resolutions.append(res_label)

        records = list(runs.values())
        log_z = [rec["log_estimated_evidence"] for rec in records]
        proc_time = [rec["nestor_process_time"] for rec in records]
        per_step_time = [rec["mcmc_step_time"] for rec in records]

        all_log_z[res_key] = log_z
        mean_proc_time.append(np.mean(proc_time))
        mean_per_step_time.append(np.mean(per_step_time))

        plt.errorbar(
            res_label,
            np.mean(log_z),
            yerr=np.std(log_z) / math.sqrt(len(log_z)),
            fmt="o",
            c="dodgerblue",
        )

    plt.xlabel("Resolutions")
    plt.ylabel("log(Evidence)")
    plt.savefig(_trial_file("evidence_errorbarplot.png"))

    _scatter_figure(
        2, mean_proc_time, "Nested sampling process time", "proctime.png"
    )
    _scatter_figure(
        3, mean_per_step_time, "Mean time per MCMC step", "persteptime.png"
    )

    plot_evi_proctime(results, h_params)
152 
153 
def _parse_target_runs(num_runs):
    """Turn the ``num_runs`` setting into a range of run ids.

    A plain count ``N`` gives ``range(0, N)``; a string ``"A-B"`` gives
    ``range(A, B)`` (``B`` itself is excluded).
    """
    spec = str(num_runs)
    if "-" not in spec:
        return range(0, int(spec))
    first, last = spec.split("-")[:2]
    return range(int(first), int(last))


def _launch_run(h_params, h_param_file, res, run_id, topology):
    """Create ``<res>/run_<run_id>`` and start one mpirun job inside it."""
    os.makedirs(res, exist_ok=True)
    run_dir = os.path.join(res, f"run_{run_id}")
    # os.mkdir (not makedirs) so that an already existing run directory
    # raises FileExistsError instead of being silently reused.
    os.mkdir(run_dir)

    res_label = res.split("/")[-1].split("_")[-1]
    topf = f"topology{res_label}.txt" if topology else res_label

    run_cmd = [
        "mpirun",
        "-n",
        str(h_params["num_cores"]),
        h_params["imp_path"],
        "python",
        h_params["modeling_script_path"],
        str(run_id),
        topf,
        h_param_file,
    ]

    # The child is started in the run directory through `cwd=` rather than by
    # os.chdir() in the parent, so the parent's working directory (and every
    # relative path it still has to resolve) is left alone.
    #
    # NOTE(review): both pipes are only read once the child has exited. A
    # child that writes more than the OS pipe buffer can hold will block and
    # never terminate. They are kept as pipes because
    # get_curr_processes_and_terminated_runs reads `proc.stderr`.
    return subprocess.Popen(
        run_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        cwd=run_dir,
    )


def _wait_for_any_exit(processes, poll_interval=0.5):
    """Sleep-poll until at least one process in *processes* has exited.

    Returns immediately when *processes* is empty.
    """
    import time  # local import: `time` is not a module-level import here

    while processes and all(p.poll() is None for p in processes.values()):
        time.sleep(poll_interval)


def run_nested_sampling(h_param_file, topology=True):
    """Run all requested NestOR runs and write ``nestor_output.yaml``.

    The YAML parameter file is read for ``max_usable_threads``,
    ``num_cores``, ``parent_dir``, ``num_runs``, ``imp_path`` and
    ``modeling_script_path`` (``get_all_toruns`` additionally reads
    ``resolutions``). At most ``max_usable_threads // num_cores`` mpirun jobs
    are kept alive at once; each job runs in
    ``<parent_dir>/res_<resolution>/run_<id>``.

    Finished jobs are collected through
    ``get_curr_processes_and_terminated_runs``. A job exiting with code 11 is
    queued again, one exiting with code 12 is reported and dropped, and the
    stdout of every job exiting with 0 is parsed into the results. The results
    are merged with an existing ``<parent_dir>/nestor_output.yaml`` (if any)
    and written back to that file.

    The scheduling loop keeps going until no run is queued and no process is
    alive, so runs that are still queued when all live processes have finished
    still get launched, and polling sleeps between checks.

    Args:
        h_param_file: Path of the YAML parameter file. It is also handed to
            every job (as an absolute path).
        topology: When true the job receives ``topology<label>.txt`` as its
            representation argument, otherwise just ``<label>``, where
            ``<label>`` is the part of the resolution directory name after the
            last underscore.

    Raises:
        ValueError: If ``max_usable_threads // num_cores`` is below 1, i.e. no
            job could ever be started.
        FileExistsError: If a run directory to be created already exists.
        SystemExit: With status 1 if a run recorded as completed turns out to
            have a non-zero exit code.
    """
    with open(h_param_file, "r") as paramf:
        h_params = yaml.safe_load(paramf)

    max_allowed_runs = h_params["max_usable_threads"] // h_params["num_cores"]
    if max_allowed_runs < 1:
        raise ValueError(
            "max_usable_threads must be at least num_cores to launch a run"
        )

    parent_path = h_params["parent_dir"]
    os.makedirs(parent_path, exist_ok=True)

    # Absolute, because every job runs with its run directory as cwd.
    job_param_file = os.path.abspath(h_param_file)

    torun = get_all_toruns(h_params, _parse_target_runs(h_params["num_runs"]))

    processes = {}
    completed_runs = []
    while torun or processes:
        # Fill every free slot, oldest queued run first.
        while torun and len(processes) < max_allowed_runs:
            res, run_id = torun.pop(0)
            processes[(res, run_id)] = _launch_run(
                h_params, job_param_file, res, run_id, topology
            )
            print(f"Launched: {res.split('/')[-1]}, run_{run_id}")

        if torun:
            print("Waiting for free threads...")
        else:
            print(f"Waiting for {len(processes)} processes to terminate...")

        _wait_for_any_exit(processes)

        (
            processes,
            curr_faulty_runs,
            successful_runs,
        ) = get_curr_processes_and_terminated_runs(processes)
        completed_runs.extend(successful_runs)

        for fr, p in curr_faulty_runs:
            res_name = fr[0].split("/")[-1]
            if p.returncode == 11:
                print(f"Will relaunch ({res_name}, run_{fr[1]})")
                torun.append(fr)
            elif p.returncode == 12:
                print(
                    f"Terminated: {res_name}, run_{fr[1]} "
                    f"with exit code: {p.returncode}"
                )
                print(
                    f"{res_name}, run_{fr[1]} ran out of "
                    f"maximum allowed iterations before converging. "
                    f"Will not relaunch it..."
                )

    # Preparing the output

    print("Performing housekeeping tasks")

    results = {}
    for run_deets, p in completed_runs:
        if p.returncode != 0:
            # Defensive: only runs that exited with 0 are expected here.
            _, err = p.communicate()
            print(err)
            raise SystemExit(1)

        out, _ = p.communicate()
        # NOTE(review): the first four characters of the job's stdout are
        # skipped before parsing; presumably a fixed prefix printed by the
        # modeling script - confirm against that script.
        result = literal_eval(out[4:])

        res_name = run_deets[0].split("/")[-1]
        results.setdefault(res_name, {})[f"run_{run_deets[1]}"] = result

    output_path = os.path.join(parent_path, "nestor_output.yaml")
    if os.path.isfile(output_path):
        with open(output_path, "r") as inf:
            old_results = yaml.safe_load(inf)
        # An empty file loads as None; there is nothing to merge then.
        # NOTE(review): the existing file's content is merged *into* the new
        # results, so with mergedeep's default strategy an entry already on
        # disk presumably wins over an identical key from this invocation -
        # confirm that this precedence is intended.
        if old_results:
            merge(results, old_results)

    with open(output_path, "w") as outf:
        yaml.dump(results, outf)
308 
309 
def plot_evi_proctime(nestor_results: dict, h_params: dict):
    """Plot mean log evidence and mean MCMC step time on twin y-axes.

    For every representation in *nestor_results* the mean and the standard
    error (``np.std / sqrt(n)``) of ``log_estimated_evidence`` and of
    ``mcmc_step_time`` are computed over its runs. The evidence is drawn as
    error bars on the left axis, the mean step time as a scatter on the right
    axis. The figure is written to
    ``<parent_dir>/sterr_evi_and_proctime.png`` at 1200 dpi and then closed.

    The file name carries a single ``.png`` extension.

    Args:
        nestor_results: Mapping ``representation key -> {run key -> run
            record}``; each record provides ``log_estimated_evidence`` and
            ``mcmc_step_time``. The x tick label of a representation is the
            part of its key after the last underscore.
        h_params: Mapping providing ``parent_dir``.

    Returns:
        None. Nothing is plotted when *nestor_results* is empty.
    """
    if not nestor_results:
        # Without rows the (n, 2) statistics arrays below cannot be indexed.
        return

    representations: list = []
    log_evi_mean_sterr: list = []
    proctime_mean_sterr: list = []
    for rep_key, runs in nestor_results.items():
        representations.append(rep_key.split("_")[-1])

        log_evi = [run["log_estimated_evidence"] for run in runs.values()]
        proctime = [run["mcmc_step_time"] for run in runs.values()]

        log_evi_mean_sterr.append(
            (np.mean(log_evi), np.std(log_evi) / math.sqrt(len(log_evi)))
        )
        proctime_mean_sterr.append(
            (np.mean(proctime), np.std(proctime) / math.sqrt(len(proctime)))
        )

    # Column 0 holds the means, column 1 the standard errors.
    log_evi_mean_sterr = np.array(log_evi_mean_sterr)
    proctime_mean_sterr = np.array(proctime_mean_sterr)

    fig, ax1 = plt.subplots()
    ax1.errorbar(
        x=representations,
        y=log_evi_mean_sterr[:, 0],
        yerr=log_evi_mean_sterr[:, 1],
        fmt="o",
        c="dodgerblue",
        label="Log(Evidence)",
    )
    ax1.set_ylabel("Mean log$Z$")
    ax1.set_xlabel("Representations")

    ax2 = ax1.twinx()
    # NOTE(review): the plotted values are `mcmc_step_time` means, which
    # matches the axis title below but not this label. No legend is drawn, so
    # the label is currently invisible - confirm which quantity is meant.
    ax2.scatter(
        x=representations,
        y=proctime_mean_sterr[:, 0],
        c="green",
        label="NestOR process time",
    )
    ax2.set_ylabel("Time per MCMC sampling step (sec)")

    output_fname: str = os.path.join(
        h_params["parent_dir"], "sterr_evi_and_proctime.png"
    )

    # Save and close through the figure handle so that an unrelated "current"
    # figure can never be written or closed by mistake.
    fig.savefig(output_fname, dpi=1200)
    plt.close(fig)
365 
366 
def main():
    """Command-line entry point: run NestOR (unless skipped) and plot.

    The parameter file given with ``-p`` is loaded, the nested sampling runs
    are executed unless ``-s`` was passed, and
    ``<parent_dir>/nestor_output.yaml`` is then read back. If it holds any
    results they are plotted, otherwise a failure notice is printed.
    """
    cli_args = parse_args()
    param_path = cli_args.paramf
    with_topology = cli_args.topology

    with open(param_path, "r") as param_stream:
        h_params = yaml.safe_load(param_stream)

    if not cli_args.skip_calc:
        run_nested_sampling(param_path, with_topology)

    output_path = os.path.join(h_params["parent_dir"], "nestor_output.yaml")
    with open(output_path, "r") as result_stream:
        results = yaml.safe_load(result_stream)

    if len(results.keys()) == 0:
        print("\nNone of the runs was successful...!")
    else:
        print("Plotting the results")
        plotter(results, h_params)
    print("Done...!\n\n")
389 
390 
391 ###################################################
392 # Main
393 ###################################################
394 
# Run the NestOR workflow only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()