1 """@namespace IMP.nestor.wrapper_v6
2 Top-level NestOR script"""
11 from mergedeep
import merge
12 from ast
import literal_eval
13 from matplotlib
import pyplot
as plt
# NOTE(review): the enclosing `def` line for this argument-parser block is not
# visible in this chunk (the extraction dropped lines), so the function name
# and the add_argument flag spellings cannot be confirmed here.
20 parser = argparse.ArgumentParser()
# Option read later as `args.paramf`: path to the YAML file holding all
# NestOR input parameters (parent_dir, resolutions, num_runs, ...).
26 help=
"Absolute path to yaml file containing the input parameters",
# Option read later as `args.topology`: whether each run gets a topology file.
32 help=
"Whether to use topology file",
# Option read later as `args.skip_calc`: skip sampling and only re-plot
# an existing nestor_output.yaml.
38 help=
"Running only the plotting functions?",
41 return parser.parse_args()
def get_all_toruns(h_params, target_runs):
    """Build the full work list of (resolution_dir, run_id) pairs.

    Parameters
    ----------
    h_params : dict
        NestOR parameter mapping; must contain "parent_dir" and an iterable
        "resolutions".
    target_runs : iterable
        Run indices to launch for every resolution.

    Returns
    -------
    list[tuple[str, str]]
        One ``(os.path.join(parent_dir, "res_<res>"), str(run))`` tuple per
        resolution/run combination, in resolution-major order.
    """
    base = h_params["parent_dir"]
    return [
        (os.path.join(base, f"res_{res}"), str(run))
        for res in h_params["resolutions"]
        for run in target_runs
    ]
def get_curr_processes_and_terminated_runs(processes: dict):
    """Poll launched NestOR subprocesses and separate out the finished ones.

    Parameters
    ----------
    processes : dict
        Maps ``(resolution_dir, run_id)`` tuples to ``subprocess.Popen``
        handles (mutated in place: terminated entries are removed).

    Returns
    -------
    tuple
        ``(processes, faulty_runs, successful_runs)`` — ``processes`` with
        terminated entries dropped; ``faulty_runs`` as ``((res, run), proc)``
        pairs that exited with NestOR error codes 11 or 12;
        ``successful_runs`` as the pairs that exited with code 0.
    """
    terminated_runs = []
    faulty_runs = []
    successful_runs = []

    for run_deets, proc in processes.items():
        # poll() is None while the child is still running.
        if proc.poll() is None:
            continue
        terminated_runs.append(run_deets)
        # Codes 11/12 are NestOR failure codes (was `== 11 or == 12`).
        if proc.returncode in (11, 12):
            faulty_runs.append((run_deets, proc))
            # Code 11 marks a restartable failure: wipe the run directory so
            # the caller can relaunch the run from scratch.
            if proc.returncode == 11:
                shutil.rmtree(
                    os.path.join(run_deets[0], f"run_{run_deets[1]}")
                )
        elif proc.returncode == 0:
            successful_runs.append((run_deets, proc))

    for run in terminated_runs:
        print(
            f"Terminated: {run[0].split('/')[-1]}, run_{run[1]} with "
            f"exit code: {processes[run].returncode}"
        )
        if processes[run].returncode != 0:
            print(f"Error:\n{processes[run].stderr.read()}")

    # Remove terminated entries so the callers' `while len(processes) > 0`
    # loops can make progress.  NOTE(review): this removal was in a gap of
    # the reviewed listing and is reconstructed from the callers' behavior —
    # confirm against VCS.
    for run in terminated_runs:
        processes.pop(run)

    return processes, faulty_runs, successful_runs
def plotter(results: dict, h_params):
    """Generate per-resolution summary plots from collected NestOR results.

    Writes three figures under ``h_params["parent_dir"]`` (evidence errorbar
    plot, process-time scatter, per-MCMC-step-time scatter), then delegates
    to ``plot_evi_proctime`` for the combined twin-axis figure.

    Fix: membership test written idiomatically (`x not in xs` instead of
    `not x in xs`).

    NOTE(review): figure-management calls (plt.close and savefig keyword
    arguments) fell in gaps of the reviewed listing and are reconstructed
    conservatively — confirm against VCS.

    Parameters
    ----------
    results : dict
        ``{"res_<N>": {"run_<i>": per-run result dict}}`` as produced by
        ``run_nested_sampling``.
    h_params : dict
        NestOR parameters; uses "parent_dir" and "trial_name".
    """
    resolutions = []
    all_log_z = {}
    mean_proc_time = []
    mean_per_step_time = []

    for resolution in results:
        # Keys look like "res_<N>"; strip the "res_" prefix for axis labels.
        if resolution[4:] not in resolutions:
            resolutions.append(resolution[4:])

        log_z = []
        proc_time = []
        per_step_time = []
        for _, run in results[resolution].items():
            log_z.append(run["log_estimated_evidence"])
            proc_time.append(run["nestor_process_time"])
            per_step_time.append(run["mcmc_step_time"])

        all_log_z[resolution] = log_z
        mean_proc_time.append(np.mean(proc_time))
        mean_per_step_time.append(np.mean(per_step_time))

        # Mean log-evidence with the standard error of the mean across runs.
        avg_logz = np.mean(log_z)
        stderr_logz = np.std(log_z) / math.sqrt(len(log_z))
        plt.errorbar(
            resolution[4:], avg_logz, yerr=stderr_logz, fmt="o", c="dodgerblue"
        )

    plt.xlabel("Resolutions")
    plt.ylabel("log(Evidence)")
    plt.savefig(
        os.path.join(
            h_params["parent_dir"],
            f"trial_{h_params['trial_name']}_evidence_errorbarplot.png",
        )
    )
    plt.close()

    plt.scatter(resolutions, mean_proc_time, c="C2", marker="o")
    plt.xlabel("Resolutions")
    plt.ylabel("Nested sampling process time")
    plt.savefig(
        os.path.join(
            h_params["parent_dir"],
            f"trial_{h_params['trial_name']}_proctime.png"
        )
    )
    plt.close()

    plt.scatter(resolutions, mean_per_step_time, c="C2", marker="o")
    plt.xlabel("Resolutions")
    plt.ylabel("Mean time per MCMC step")
    plt.savefig(
        os.path.join(
            h_params["parent_dir"],
            f"trial_{h_params['trial_name']}_persteptime.png"
        )
    )
    plt.close()

    plot_evi_proctime(results, h_params)
# Orchestrates the whole sampling campaign: launches one subprocess per
# (resolution, run) pair, throttled to max_allowed_runs concurrent children,
# collects their printed results, and merges everything into
# parent_dir/nestor_output.yaml.
# NOTE(review): many lines of this function are missing from the reviewed
# listing (subprocess command assembly, chdir bookkeeping, the waiting loops'
# bodies); comments below are hedged accordingly — confirm against VCS.
154 def run_nested_sampling(h_param_file, topology=True):
# Load the YAML parameter file into a plain dict.
155 with open(h_param_file,
"r") as paramf:
156 h_params = yaml.safe_load(paramf)
# Concurrency budget: integer number of simultaneous runs that fit in the
# allowed thread count at num_cores threads per run.
158 max_allowed_runs = h_params["max_usable_threads"] // h_params[
"num_cores"]
159 parent_path = h_params[
"parent_dir"]
# Create the output tree root on first use.
161 if not os.path.isdir(parent_path):
162 os.mkdir(parent_path)
# num_runs is either a count ("5" -> runs 0..4) or a "start-end" range string.
164 target_runs = str(h_params[
"num_runs"])
165 if "-" not in target_runs:
166 target_runs = range(0, int(target_runs))
169 int(target_runs.split(
"-")[0]), int(target_runs.split(
"-")[1])
# Work list of (resolution_dir, run_id) pairs still to be launched.
172 torun = get_all_toruns(h_params, target_runs)
# Sentinel entry so the scheduling loop below is entered at least once.
176 processes = {
"Dummy":
"Dummy"}
178 while len(list(processes.keys())) > 0:
179 if "Dummy" in processes.keys():
180 processes.pop(
"Dummy")
# Snapshot of the work list: `torun` is mutated (remove) while iterating.
183 curr_iter_torun = [run
for run
in torun]
184 for res, run_id
in curr_iter_torun:
# Launch only while below the concurrency budget.
185 if len(processes) < max_allowed_runs:
186 if not os.path.isdir(res):
# Create and enter the per-run working directory.
190 os.mkdir(f
"run_{run_id}")
191 os.chdir(f
"run_{run_id}")
# Resolution-specific topology file name, e.g. "topology30.txt";
# presumably passed to the modeling script — confirm.
195 f
"topology{res.split('/')[-1].split('_')[-1]}.txt"
197 topf = res.split(
"/")[-1].split(
"_")[-1]
# Pieces of the child command line: core count, IMP setup script path,
# and the modeling script to execute.
202 str(h_params[
"num_cores"]),
203 h_params[
"imp_path"],
205 h_params[
"modeling_script_path"],
# Launch the run with captured stdout/stderr for later harvesting.
211 p = subprocess.Popen(
213 stdout=subprocess.PIPE,
214 stderr=subprocess.PIPE,
# Track the child and mark this pair as launched.
217 processes[(res, run_id)] = p
218 torun.remove((res, run_id))
219 print(f
"Launched: {res.split('/')[-1]}, run_{run_id}")
222 print(
"Waiting for free threads...")
# Poll the live children; any non-None poll() means one terminated.
226 for _, p
in processes.items():
227 if p.poll()
is not None:
# Harvest terminated children (also prunes them from `processes`).
234 ) = get_curr_processes_and_terminated_runs(processes)
236 for proc
in successful_runs:
237 completed_runs.append(proc)
238 if len(processes) == 0:
# Handle faulty exits: code 11 is relaunched (its run dir was wiped),
# code 12 hit the iteration cap and is abandoned.
241 if len(curr_faulty_runs) != 0:
242 for fr, p
in curr_faulty_runs:
243 if p.returncode == 11:
244 print(f
"Will relaunch ({fr[0].split('/')[-1]}, "
247 elif p.returncode == 12:
249 f
"Terminated: {fr[0].split('/')[-1]}, run_{fr[1]} "
250 f
"with exit code: {p.returncode}"
253 f
"{fr[0].split('/')[-1]}, run_{fr[1]} ran out of "
254 f
"maximum allowed iterations before converging. "
255 f
"Will not relaunch it..."
# Drain phase: everything is launched, wait for stragglers.
258 print(f
"Waiting for {len(processes.keys())} processes to terminate...")
260 while len(processes) > 0:
263 for _, p
in processes.items():
264 if p.poll()
is not None:
265 final_waiting =
False
271 ) = get_curr_processes_and_terminated_runs(processes)
273 for proc
in successful_runs:
274 completed_runs.append(proc)
# Harvest results from each successful child: its stdout carries a printed
# dict (skipping a 4-char prefix) parsed safely with ast.literal_eval.
278 print(
"Performing housekeeping tasks")
280 for proc
in completed_runs:
282 if p.returncode == 0:
283 out, _ = p.communicate()
285 result = literal_eval(out[4:])
# Nest the result as results["res_<N>"]["run_<i>"].
287 if run_deets[0].split(
"/")[-1]
not in results.keys():
288 results[f
"{run_deets[0].split('/')[-1]}"] = {
289 f
"run_{run_deets[1]}": result
292 results[f
"{run_deets[0].split('/')[-1]}"][
293 f
"run_{run_deets[1]}"
297 _, err = p.communicate()
# Merge with any previous campaign's output so reruns accumulate.
301 if "nestor_output.yaml" in os.listdir(parent_path):
302 with open(os.path.join(parent_path,
"nestor_output.yaml"),
"r") as inf:
303 old_results = yaml.safe_load(inf)
304 merge(results, old_results)
306 with open(f"{parent_path}/nestor_output.yaml",
"w")
as outf:
307 yaml.dump(results, outf)
def plot_evi_proctime(nestor_results: dict, h_params: dict):
    """Plot mean log-evidence and mean per-MCMC-step time per representation.

    Draws both series on one figure with twin y-axes, error bars showing the
    standard error of the mean, and saves the figure under
    ``h_params["parent_dir"]``.

    Fixes:
    - removed a leftover debug print ("Came here"),
    - output file no longer gets a doubled extension (``output_fname``
      already ended in ".png" and ``savefig`` appended another ".png").

    Parameters
    ----------
    nestor_results : dict
        ``{"res_<N>": {"run_<i>": result dict}}`` with keys
        "log_estimated_evidence" and "mcmc_step_time" per run.
    h_params : dict
        NestOR parameters; uses "parent_dir".
    """
    representations: list = []
    log_evi_mean_sterr: list = []
    proctime_mean_sterr: list = []

    for k in nestor_results:
        # Keys look like "res_<N>"; keep the trailing token as the label.
        representations.append(k.split("_")[-1])

        log_evi, proctime = [], []
        for k1 in nestor_results[k]:
            log_evi.append(nestor_results[k][k1]["log_estimated_evidence"])
            proctime.append(nestor_results[k][k1]["mcmc_step_time"])

        # (mean, standard error of the mean) per representation.
        log_evi_mean_sterr.append(
            (np.mean(log_evi), np.std(log_evi) / math.sqrt(len(log_evi)))
        )
        proctime_mean_sterr.append(
            (np.mean(proctime), np.std(proctime) / math.sqrt(len(proctime)))
        )

    log_evi_mean_sterr = np.array(log_evi_mean_sterr)
    proctime_mean_sterr = np.array(proctime_mean_sterr)

    fig, ax1 = plt.subplots()
    # NOTE(review): some errorbar keyword arguments (colors/markers) fell in
    # gaps of the reviewed listing and are reconstructed — confirm.
    ax1.errorbar(
        x=representations,
        y=log_evi_mean_sterr[:, 0],
        yerr=log_evi_mean_sterr[:, 1],
        fmt="o",
        label="Log(Evidence)",
    )
    ax1.set_ylabel("Mean log$Z$")
    ax1.set_xlabel("Representations")

    # Second y-axis sharing the same x-axis for the timing series.
    ax2 = plt.twinx(ax=ax1)
    ax2.errorbar(
        x=representations,
        y=proctime_mean_sterr[:, 0],
        yerr=proctime_mean_sterr[:, 1],
        fmt="o",
        label="NestOR process time",
    )
    ax2.set_ylabel("Time per MCMC sampling step (sec)")

    output_fname: str = os.path.join(
        h_params["parent_dir"], "sterr_evi_and_proctime.png"
    )
    plt.savefig(output_fname, dpi=1200)
# NOTE(review): the enclosing `def` for this driver block (and the line that
# builds `args` from the argument parser) is not visible in this chunk; the
# `__main__` guard's body also falls past the visible lines — confirm both
# against VCS.
# Resolve CLI options: parameter-file path and whether to pass topology files.
369 h_param_file = args.paramf
370 use_topology = args.topology
# Load the parameters up front so "parent_dir" is available for plotting even
# when the sampling step is skipped.
372 with open(h_param_file,
"r") as h_paramf:
373 h_params = yaml.safe_load(h_paramf)
# Run the full sampling campaign unless the user asked for plots only.
375 if not args.skip_calc:
376 run_nested_sampling(h_param_file, use_topology)
# Read back the merged results written by run_nested_sampling.
378 with open(os.path.join(h_params[
"parent_dir"],
379 "nestor_output.yaml"),
"r") as outf:
380 results = yaml.safe_load(outf)
# Plot only when at least one run produced a result.
382 if len(list(results.keys())) > 0:
383 print(
"Plotting the results")
384 plotter(results, h_params)
387 print(
"\nNone of the runs was successful...!")
388 print(
"Done...!\n\n")
# Script entry point guard.
395 if __name__ ==
"__main__":