IMP logo
IMP Reference Guide  develop.330bebda01,2025/01/21
The Integrative Modeling Platform
wrapper_v6.py
1 #!/usr/bin/env python
2 """@namespace IMP.nestor.wrapper_v6
3  Top-level NestOR script"""
4 
5 import os
6 import yaml
7 import math
8 import shutil
9 import argparse
10 import subprocess
11 import numpy as np
12 from mergedeep import merge
13 from ast import literal_eval
14 from matplotlib import pyplot as plt
15 
16 
17 ###################################################
18 # Functions
19 ###################################################
def parse_args():
    """Parse the NestOR command line.

    Returns:
        argparse.Namespace with ``paramf`` (path to the YAML parameter file),
        ``topology`` (bool, ``-t``) and ``skip_calc`` (bool, ``-s``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-p",
        dest="paramf",
        type=str,
        required=True,
        help="Absolute path to yaml file containing the input parameters",
    )
    # The two boolean switches share the same shape; register them in order.
    for flag, destination, description in (
        ("-t", "topology", "Whether to use topology file"),
        ("-s", "skip_calc", "Running only the plotting functions?"),
    ):
        cli.add_argument(
            flag, dest=destination, action="store_true", help=description
        )

    parsed = cli.parse_args()
    return parsed
43 
44 
def get_all_toruns(h_params, target_runs):
    """Enumerate every (resolution directory, run id) job to launch.

    Args:
        h_params: parsed parameter file; ``"parent_dir"`` and
            ``"resolutions"`` are read.
        target_runs: iterable of run identifiers.

    Returns:
        list of ``(<parent_dir>/res_<resolution>, str(run_id))`` tuples,
        ordered by resolution first and run id second.
    """
    base_dir = h_params["parent_dir"]
    return [
        (os.path.join(base_dir, f"res_{resolution}"), str(run_id))
        for resolution in h_params["resolutions"]
        for run_id in target_runs
    ]
53 
54 
def get_curr_processes_and_terminated_runs(processes: dict):
    """Reap finished jobs and classify them by exit code.

    Args:
        processes: mapping ``(res_dir, run_id) -> subprocess.Popen`` of the
            launched jobs. Finished entries are removed from it in place.

    Returns:
        ``(processes, faulty_runs, successful_runs)``:
        ``processes`` is the same (mutated) dict holding only still-running
        jobs; ``faulty_runs`` lists ``((res_dir, run_id), proc)`` for exit
        codes 11 and 12; ``successful_runs`` lists the same pairs for exit
        code 0. Other non-zero exit codes are reported but appear in
        neither list.

    Side effects:
        Exit code 11 deletes ``<res_dir>/run_<run_id>``. Every finished job
        is reported on stdout, followed by its stderr when the exit code is
        non-zero.
    """
    finished = [key for key, proc in processes.items() if proc.poll() is not None]

    faulty, succeeded = [], []
    for key in finished:
        proc = processes[key]
        proc.wait()
        code = proc.returncode
        if code in (11, 12):
            faulty.append((key, proc))
            if code == 11:
                # Clear the run directory so the job can be relaunched.
                shutil.rmtree(os.path.join(key[0], f"run_{key[1]}"))
        elif code == 0:
            succeeded.append((key, proc))

    for key in finished:
        proc = processes.pop(key)
        print(
            f"Terminated: {key[0].split('/')[-1]}, run_{key[1]} with "
            f"exit code: {proc.returncode}"
        )
        if proc.returncode != 0:
            print(f"Error:\n{proc.stderr.read()}")

    return processes, faulty, succeeded
85 
86 
def plotter(results: dict, h_params):
    """Plot per-resolution evidence and timing summaries.

    Args:
        results: mapping ``"res_<r>" -> {"run_<i>": run_result}``; every
            ``run_result`` must provide ``"log_estimated_evidence"``,
            ``"nestor_process_time"`` and ``"mcmc_step_time"``.
        h_params: parsed parameter file; ``"parent_dir"`` and
            ``"trial_name"`` name the output images.

    Writes ``trial_<trial_name>_evidence_errorbarplot.png``,
    ``trial_<trial_name>_proctime.png`` and
    ``trial_<trial_name>_persteptime.png`` into ``parent_dir``, then calls
    ``plot_evi_proctime`` for the combined twin-axis plot.
    """

    def _save(fig, suffix):
        # Close each figure after saving so a later call never draws onto a
        # stale, still-open figure (the old code reused global figures 1-3).
        fig.savefig(
            os.path.join(
                h_params["parent_dir"],
                f"trial_{h_params['trial_name']}_{suffix}.png",
            )
        )
        plt.close(fig)

    resolutions = []
    mean_proc_time = []
    mean_per_step_time = []

    evi_fig, evi_ax = plt.subplots()
    for resolution, runs in results.items():
        # Keys are "res_<r>"; drop the 4-character "res_" prefix for the axis.
        label = resolution[4:]
        log_z = [run["log_estimated_evidence"] for run in runs.values()]

        # Append to all three lists together so they stay index-aligned.
        resolutions.append(label)
        mean_proc_time.append(
            np.mean([run["nestor_process_time"] for run in runs.values()])
        )
        mean_per_step_time.append(
            np.mean([run["mcmc_step_time"] for run in runs.values()])
        )

        # Mean log-evidence with its standard error over the runs.
        evi_ax.errorbar(
            label,
            np.mean(log_z),
            yerr=np.std(log_z) / math.sqrt(len(log_z)),
            fmt="o",
            c="dodgerblue",
        )
    evi_ax.set_xlabel("Resolutions")
    evi_ax.set_ylabel("log(Evidence)")
    _save(evi_fig, "evidence_errorbarplot")

    for values, ylabel, suffix in (
        (mean_proc_time, "Nested sampling process time", "proctime"),
        (mean_per_step_time, "Mean time per MCMC step", "persteptime"),
    ):
        fig, ax = plt.subplots()
        ax.scatter(resolutions, values, c="C2", marker="o")
        ax.set_xlabel("Resolutions")
        ax.set_ylabel(ylabel)
        _save(fig, suffix)

    plot_evi_proctime(results, h_params)
158 
159 
# Seconds to sleep between polls of the running jobs (avoids busy-waiting).
_POLL_INTERVAL_SEC = 1.0

# Exit codes handled by the scheduler below: 11 -> remove run dir and
# relaunch, 12 -> ran out of allowed iterations, do not relaunch.
_RELAUNCH_EXIT_CODE = 11
_NOT_CONVERGED_EXIT_CODE = 12


def _topology_arg(res_dir, topology):
    """Return the topology argument given to the modeling script for res_dir."""
    suffix = os.path.basename(res_dir).split("_")[-1]
    return f"topology{suffix}.txt" if topology else suffix


def _launch_run(res_dir, run_id, h_params, h_param_file, topology):
    """Create <res_dir>/run_<run_id> and start the modeling script in it."""
    os.makedirs(res_dir, exist_ok=True)
    run_dir = os.path.join(res_dir, f"run_{run_id}")
    # Raises FileExistsError if the run directory is left over from before.
    os.mkdir(run_dir)

    run_cmd = [
        "mpirun",
        "-n",
        str(h_params["num_cores"]),
        h_params["imp_path"],
        "python",
        h_params["modeling_script_path"],
        str(run_id),
        _topology_arg(res_dir, topology),
        h_param_file,
    ]
    # Start the job inside its run directory via ``cwd`` rather than
    # os.chdir, so this process's working directory (and any relative
    # parent_dir) is left untouched.
    # NOTE(review): both streams are pipes drained only after the job exits;
    # a job writing more than the OS pipe buffer would block. Kept because
    # get_curr_processes_and_terminated_runs reads ``proc.stderr``.
    return subprocess.Popen(
        run_cmd,
        cwd=run_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )


def _wait_for_any(processes):
    """Sleep until at least one job in ``processes`` has exited."""
    import time

    while processes and all(p.poll() is None for p in processes.values()):
        time.sleep(_POLL_INTERVAL_SEC)


def run_nested_sampling(h_param_file, topology=True):
    """Run every NestOR job from the parameter file and store the results.

    Each (resolution, run) pair is started as
    ``mpirun -n <num_cores> <imp_path> python <modeling_script_path>
    <run_id> <topology> <param file>``, with at most
    ``max_usable_threads // num_cores`` jobs alive at once. ``num_runs`` is
    either a count ``N`` (runs ``0..N-1``) or a range ``"a-b"`` (runs
    ``a..b-1``). Jobs exiting with code 11 are relaunched; code 12 means the
    run hit the iteration limit before converging and is dropped. The stdout
    of each successful job is parsed with ``ast.literal_eval`` and merged
    into ``<parent_dir>/nestor_output.yaml``. Fresh results take precedence
    over entries already in that file.

    Args:
        h_param_file: path to the YAML parameter file; forwarded to the
            modeling script as an absolute path.
        topology: if true pass ``topology<r>.txt`` to the modeling script,
            otherwise only the resolution suffix ``<r>``.

    Raises:
        ValueError: if ``max_usable_threads`` is smaller than ``num_cores``
            (no job could ever be scheduled).
        FileExistsError: if a run directory already exists.
    """
    import sys

    with open(h_param_file, "r") as paramf:
        h_params = yaml.safe_load(paramf)

    max_allowed_runs = h_params["max_usable_threads"] // h_params["num_cores"]
    if max_allowed_runs < 1:
        raise ValueError(
            "max_usable_threads must be at least num_cores; got "
            f"{h_params['max_usable_threads']} < {h_params['num_cores']}"
        )

    parent_path = h_params["parent_dir"]
    os.makedirs(parent_path, exist_ok=True)

    # The jobs run in their own directories, so hand them an absolute path.
    param_file_arg = os.path.abspath(h_param_file)

    num_runs = str(h_params["num_runs"])
    if "-" in num_runs:
        first, last = num_runs.split("-")[:2]
        target_runs = range(int(first), int(last))
    else:
        target_runs = range(int(num_runs))

    torun = get_all_toruns(h_params, target_runs)

    processes = {}
    completed_runs = []
    draining_announced = False
    while torun or processes:
        # Fill every free slot from the front of the queue.
        while torun and len(processes) < max_allowed_runs:
            res, run_id = torun.pop(0)
            processes[(res, run_id)] = _launch_run(
                res, run_id, h_params, param_file_arg, topology
            )
            print(f"Launched: {os.path.basename(res)}, run_{run_id}")

        if torun:
            print("Waiting for free threads...")
        elif not draining_announced:
            print(f"Waiting for {len(processes)} processes to terminate...")
            draining_announced = True

        _wait_for_any(processes)
        (
            processes,
            curr_faulty_runs,
            successful_runs,
        ) = get_curr_processes_and_terminated_runs(processes)
        completed_runs.extend(successful_runs)

        for (res, run_id), proc in curr_faulty_runs:
            name = os.path.basename(res)
            if proc.returncode == _RELAUNCH_EXIT_CODE:
                print(f"Will relaunch ({name}, run_{run_id})")
                torun.append((res, run_id))
            elif proc.returncode == _NOT_CONVERGED_EXIT_CODE:
                print(
                    f"Terminated: {name}, run_{run_id} "
                    f"with exit code: {proc.returncode}"
                )
                print(
                    f"{name}, run_{run_id} ran out of "
                    f"maximum allowed iterations before converging. "
                    f"Will not relaunch it..."
                )

    # Preparing the output
    print("Performing housekeeping tasks")

    results = {}
    for (res, run_id), proc in completed_runs:
        out, err = proc.communicate()
        if proc.returncode != 0:
            # Not expected: completed_runs only receives exit-code-0 jobs.
            print(err)
            sys.exit(1)
        # NOTE(review): the first 4 characters of stdout are skipped before
        # parsing -- presumably a fixed prefix printed by the modeling
        # script ahead of the result dict; confirm.
        result = literal_eval(out[4:])
        results.setdefault(os.path.basename(res), {})[f"run_{run_id}"] = result

    output_path = os.path.join(parent_path, "nestor_output.yaml")
    if os.path.isfile(output_path):
        with open(output_path, "r") as inf:
            # safe_load returns None for an empty file.
            old_results = yaml.safe_load(inf) or {}
        # mergedeep.merge(dest, *sources) lets later sources win, so the
        # fresh results go last and replace stale entries for the same run.
        results = merge({}, old_results, results)

    with open(output_path, "w") as outf:
        yaml.dump(results, outf)
317 
318 
def plot_evi_proctime(nestor_results: dict, h_params: dict):
    """Plot mean log-evidence and mean MCMC step time per representation.

    Log-evidence (mean with standard-error bars) goes on the left y-axis.
    The mean time per MCMC step goes on a twin right y-axis.

    Args:
        nestor_results: mapping ``"<prefix>_<repr>" -> {"run_<i>": result}``;
            every ``result`` must provide ``"log_estimated_evidence"`` and
            ``"mcmc_step_time"``. The text after the last ``_`` in each key
            is used as the x-axis label.
        h_params: parsed parameter file; ``"parent_dir"`` is the output
            directory.

    Saves ``<parent_dir>/sterr_evi_and_proctime.png``. Does nothing when
    ``nestor_results`` is empty.
    """
    if not nestor_results:
        return

    representations: list = []
    log_evi_mean_sterr: list = []
    step_time_means: list = []
    for res_key, runs in nestor_results.items():
        representations.append(res_key.split("_")[-1])
        log_evi = [run["log_estimated_evidence"] for run in runs.values()]
        step_times = [run["mcmc_step_time"] for run in runs.values()]

        log_evi_mean_sterr.append(
            (np.mean(log_evi), np.std(log_evi) / math.sqrt(len(log_evi)))
        )
        step_time_means.append(np.mean(step_times))

    log_evi_mean_sterr = np.array(log_evi_mean_sterr)

    fig, ax1 = plt.subplots()
    ax1.errorbar(
        x=representations,
        y=log_evi_mean_sterr[:, 0],
        yerr=log_evi_mean_sterr[:, 1],
        fmt="o",
        c="dodgerblue",
        label="Log(Evidence)",
    )
    ax1.set_ylabel("Mean log$Z$")
    ax1.set_xlabel("Representations")

    ax2 = ax1.twinx()
    ax2.scatter(
        x=representations,
        y=step_time_means,
        c="green",
        # The series is the MCMC step time, not the overall process time.
        label="Time per MCMC step",
    )
    ax2.set_ylabel("Time per MCMC sampling step (sec)")

    # The file name already carries the extension; do not append another.
    output_fname: str = os.path.join(
        h_params["parent_dir"], "sterr_evi_and_proctime.png"
    )
    fig.savefig(output_fname, dpi=1200)
    plt.close(fig)
374 
375 
def main():
    """Command-line entry point: run NestOR (unless ``-s``) and plot results.

    Reads the parameter file given by ``-p``, runs ``run_nested_sampling``
    unless ``-s`` was passed, then loads
    ``<parent_dir>/nestor_output.yaml`` and plots it if it holds any
    results.
    """
    args = parse_args()
    h_param_file = args.paramf

    with open(h_param_file, "r") as h_paramf:
        h_params = yaml.safe_load(h_paramf)

    if not args.skip_calc:
        run_nested_sampling(h_param_file, args.topology)

    output_path = os.path.join(h_params["parent_dir"], "nestor_output.yaml")
    with open(output_path, "r") as outf:
        # safe_load returns None for an empty file; treat that as no results.
        results = yaml.safe_load(outf) or {}

    if results:
        print("Plotting the results")
        plotter(results, h_params)
    else:
        print("\nNone of the runs was successful...!")
    print("Done...!\n\n")
403 
404 
###################################################
# Script entry point
###################################################

if __name__ == "__main__":
    # Run the NestOR pipeline only when executed directly, not on import.
    main()