"""@namespace IMP.nestor.wrapper_v6
Top-level NestOR script"""
import argparse
import math
import os
import shutil
import subprocess
from ast import literal_eval

import numpy as np
import yaml
from matplotlib import pyplot as plt
from mergedeep import merge
# NOTE(review): fragment of the CLI argument-parser helper — the enclosing
# `def` and the parser.add_argument(...) calls fell outside this chunk; only
# the help strings and the final parse_args() call survive. Judging by the
# __main__ section, the registered options are read back as args.paramf,
# args.topology and args.skip_calc.
21 parser = argparse.ArgumentParser()
# Path to the yaml hyper-parameter file.
27 help=
"Absolute path to yaml file containing the input parameters",
# Whether a topology file drives the modeling run.
33 help=
"Whether to use topology file",
# Plotting-only mode (skips the sampling stage).
39 help=
"Running only the plotting functions?",
42 return parser.parse_args()
def get_all_toruns(h_params, target_runs):
    """Build the list of (resolution_dir, run_id) pairs to launch.

    Args:
        h_params: hyper-parameter dict; uses "parent_dir" and "resolutions".
        target_runs: iterable of run indices to launch per resolution.

    Returns:
        List of (os.path.join(parent_dir, "res_<res>"), str(run)) tuples,
        one per (resolution, run) combination.
    """
    parent_dir = h_params["parent_dir"]
    # Fix: the visible fragment appended to `runs` without ever initializing
    # or returning it (those lines were lost); restore both.
    runs = []
    for res in h_params["resolutions"]:
        for run in target_runs:
            runs.append((os.path.join(parent_dir, f"res_{res}"), str(run)))
    return runs
def get_curr_processes_and_terminated_runs(processes: dict):
    """Poll all launched NestOR subprocesses and sort out the finished ones.

    Args:
        processes: maps (resolution_dir, run_id) -> Popen-like objects
            exposing poll(), returncode and stderr.

    Returns:
        (processes, faulty_runs, successful_runs): `processes` retains only
        the still-running entries; `faulty_runs` holds ((res, run_id), proc)
        pairs that exited with code 11 or 12; `successful_runs` those that
        exited with code 0.
    """
    terminated_runs = []
    faulty_runs = []
    successful_runs = []

    for run_deets, proc in processes.items():
        if proc.poll() is not None:
            # Process has terminated (successfully or otherwise).
            terminated_runs.append(run_deets)
            if proc.returncode == 11 or proc.returncode == 12:
                # Exit codes 11/12 mark NestOR-specific failures.
                faulty_runs.append((run_deets, proc))
                if proc.returncode == 11:
                    # NOTE(review): the lost original line appears to remove
                    # the crashed run's directory so it can be relaunched
                    # cleanly — confirm against the original source.
                    shutil.rmtree(
                        os.path.join(run_deets[0], f"run_{run_deets[1]}"),
                        ignore_errors=True,
                    )
            elif proc.returncode == 0:
                successful_runs.append((run_deets, proc))

    for run in terminated_runs:
        print(
            f"Terminated: {run[0].split('/')[-1]}, run_{run[1]} with "
            f"exit code: {processes[run].returncode}"
        )
        if processes[run].returncode != 0:
            print(f"Error:\n{processes[run].stderr.read()}")
        # Drop the finished run so only live processes remain in the table.
        processes.pop(run)

    return processes, faulty_runs, successful_runs
def plotter(results: dict, h_params):
    """Generate per-resolution summary plots from parsed NestOR results.

    Args:
        results: {"res_<X>": {"run_<i>": {metric: value}}} as assembled by
            run_nested_sampling; each run dict carries
            "log_estimated_evidence", "nestor_process_time" and
            "mcmc_step_time".
        h_params: hyper-parameter dict; "parent_dir" and "trial_name" are
            used to place and name the output images.
    """
    resolutions = []          # x-axis labels ("res_" prefix stripped)
    all_log_z = {}            # resolution key -> list of log-evidence values
    mean_proc_time = []       # mean nested-sampling process time per resolution
    mean_per_step_time = []   # mean per-MCMC-step time per resolution
    mean_log_z = []           # mean log-evidence per resolution
    stderr_log_z = []         # standard error of the mean log-evidence

    for resolution in results:
        if not resolution[4:] in resolutions:
            resolutions.append(resolution[4:])

        log_z, proc_time, per_step_time = [], [], []
        for _, run in results[resolution].items():
            log_z.append(run["log_estimated_evidence"])
            proc_time.append(run["nestor_process_time"])
            per_step_time.append(run["mcmc_step_time"])

        all_log_z[resolution] = log_z
        mean_proc_time.append(np.mean(proc_time))
        mean_per_step_time.append(np.mean(per_step_time))
        mean_log_z.append(np.mean(log_z))
        stderr_log_z.append(np.std(log_z) / math.sqrt(len(log_z)))

    # Evidence vs resolution with standard-error bars.
    # NOTE(review): the errorbar call itself was lost in the source chunk;
    # the fmt/color styling is a best guess — confirm against the original.
    plt.figure()
    plt.errorbar(resolutions, mean_log_z, yerr=stderr_log_z, fmt="o", c="C2")
    plt.xlabel("Resolutions")
    plt.ylabel("log(Evidence)")
    plt.savefig(
        os.path.join(
            h_params["parent_dir"],
            f"trial_{h_params['trial_name']}_evidence_errorbarplot.png",
        )
    )
    # Fix: close each figure so the three plots don't draw on one canvas.
    plt.close()

    # Total nested-sampling process time vs resolution.
    plt.figure()
    plt.scatter(resolutions, mean_proc_time, c="C2", marker="o")
    plt.xlabel("Resolutions")
    plt.ylabel("Nested sampling process time")
    plt.savefig(
        os.path.join(
            h_params["parent_dir"],
            f"trial_{h_params['trial_name']}_proctime.png",
        )
    )
    plt.close()

    # Mean per-MCMC-step time vs resolution.
    plt.figure()
    plt.scatter(resolutions, mean_per_step_time, c="C2", marker="o")
    plt.xlabel("Resolutions")
    plt.ylabel("Mean time per MCMC step")
    plt.savefig(
        os.path.join(
            h_params["parent_dir"],
            f"trial_{h_params['trial_name']}_persteptime.png",
        )
    )
    plt.close()

    # Combined twin-axis plot of evidence and per-step time.
    plot_evi_proctime(results, h_params)
# ---------------------------------------------------------------------------
# NOTE(review): this chunk is a garbled extraction — statements are split
# across lines, f-string prefixes are detached from their literals, and the
# original file's line numbers (e.g. "160 ") are fused into the text. The
# code is kept byte-for-byte below; comments only annotate what the visible
# fragments establish. Jumps in the fused line numbers mark lost lines.
# ---------------------------------------------------------------------------
# Launches NestOR nested-sampling runs as subprocesses, throttled to the
# available thread budget, then collects their results into
# <parent_dir>/nestor_output.yaml.
160 def run_nested_sampling(h_param_file, topology=True):
# Load the hyper-parameters from the user-supplied yaml file.
161 with open(h_param_file,
"r") as paramf:
162 h_params = yaml.safe_load(paramf)
# Concurrency budget: runs that can execute at once given cores per run.
164 max_allowed_runs = h_params["max_usable_threads"] // h_params[
"num_cores"]
165 parent_path = h_params[
"parent_dir"]
# Create the output tree on first use.
167 if not os.path.isdir(parent_path):
168 os.mkdir(parent_path)
# "num_runs" is either a count N (-> runs 0..N-1) or a "start-end" span.
170 target_runs = str(h_params[
"num_runs"])
171 if "-" not in target_runs:
172 target_runs = range(0, int(target_runs))
# (else-branch header lost) builds range(start, end) from the "a-b" form.
175 int(target_runs.split(
"-")[0]), int(target_runs.split(
"-")[1])
178 torun = get_all_toruns(h_params, target_runs)
# Seed the process table with a dummy entry so the while-loop is entered.
182 processes = {
"Dummy":
"Dummy"}
184 while len(list(processes.keys())) > 0:
185 if "Dummy" in processes.keys():
186 processes.pop(
"Dummy")
# Snapshot of pending runs; `torun` is mutated while launching.
189 curr_iter_torun = [run
for run
in torun]
190 for res, run_id
in curr_iter_torun:
191 if len(processes) < max_allowed_runs:
# Create res_<X>/run_<id> and launch from inside that directory.
192 if not os.path.isdir(res):
196 os.mkdir(f
"run_{run_id}")
197 os.chdir(f
"run_{run_id}")
# Topology mode: derive a "topology<res>..." file name; otherwise pass
# the bare resolution token. (Parts of both branches are lost.)
200 topf = f
"topology{res.split('/')[-1].split('_')[-1]}\
203 topf = res.split(
"/")[-1].split(
"_")[-1]
# Command assembly (mostly lost): uses num_cores, the IMP binary path
# and the modeling script path from the yaml parameters.
208 str(h_params[
"num_cores"]),
209 h_params[
"imp_path"],
211 h_params[
"modeling_script_path"],
# Launch the run, capturing both streams for later housekeeping.
217 p = subprocess.Popen(
219 stdout=subprocess.PIPE,
220 stderr=subprocess.PIPE,
223 processes[(res, run_id)] = p
224 torun.remove((res, run_id))
225 print(f
"Launched: {res.split('/')[-1]}, run_{run_id}")
# No free slot: poll until some running process terminates.
228 print(
"Waiting for free threads...")
232 for _, p
in processes.items():
233 if p.poll()
is not None:
# Reap finished runs (unpacking target lost; the helper returns
# processes, faulty_runs, successful_runs).
240 ) = get_curr_processes_and_terminated_runs(processes)
242 for proc
in successful_runs:
243 completed_runs.append(proc)
244 if len(processes) == 0:
# Faulty runs: exit code 11 triggers a relaunch message; exit code 12
# means the run exhausted its iteration budget and is NOT relaunched.
247 if len(curr_faulty_runs) != 0:
248 for fr, p
in curr_faulty_runs:
249 if p.returncode == 11:
252 ({fr[0].split('/')[-1]}, "
256 elif p.returncode == 12:
258 f
"Terminated: {fr[0].split('/')[-1]}, run_{fr[1]} "
259 f
"with exit code: {p.returncode}"
262 f
"{fr[0].split('/')[-1]}, run_{fr[1]} ran out of "
263 f
"maximum allowed iterations before converging. "
264 f
"Will not relaunch it..."
# Drain the remaining live processes once nothing is left to launch.
267 print(f
"Waiting for {len(processes.keys())} processes to terminate...")
269 while len(processes) > 0:
272 for _, p
in processes.items():
273 if p.poll()
is not None:
274 final_waiting =
False
280 ) = get_curr_processes_and_terminated_runs(processes)
282 for proc
in successful_runs:
283 completed_runs.append(proc)
# Housekeeping: parse each successful run's stdout into `results`.
287 print(
"Performing housekeeping tasks")
289 for proc
in completed_runs:
291 if p.returncode == 0:
292 out, _ = p.communicate()
# out[4:] skips a 4-character prefix before the literal dict —
# presumably a sentinel written by the modeling script; confirm.
294 result = literal_eval(out[4:])
296 if run_deets[0].split(
"/")[-1]
not in results.keys():
297 results[f
"{run_deets[0].split('/')[-1]}"] = {
298 f
"run_{run_deets[1]}": result
301 results[f
"{run_deets[0].split('/')[-1]}"][
302 f
"run_{run_deets[1]}"
306 _, err = p.communicate()
# Merge with any previous output so re-runs accumulate, not clobber.
310 if "nestor_output.yaml" in os.listdir(parent_path):
311 with open(os.path.join(parent_path,
"nestor_output.yaml"),
"r") as inf:
312 old_results = yaml.safe_load(inf)
313 merge(results, old_results)
315 with open(f"{parent_path}/nestor_output.yaml",
"w")
as outf:
316 yaml.dump(results, outf)
def plot_evi_proctime(nestor_results: dict, h_params: dict):
    """Plot mean log-evidence and mean per-MCMC-step time on twin y-axes.

    Args:
        nestor_results: {"res_<X>": {"run_<i>": {metric: value}}} parsed
            results; uses "log_estimated_evidence" and "mcmc_step_time".
        h_params: hyper-parameter dict; "parent_dir" locates the output.
    """
    representations: list = []
    log_evi_mean_sterr: list = []
    proctime_mean_sterr: list = []
    for k in nestor_results:
        # x-axis label is the trailing token of the resolution key.
        representations.append(k.split("_")[-1])
        log_evi, proctime = [], []
        for k1 in nestor_results[k]:
            log_evi.append(nestor_results[k][k1]["log_estimated_evidence"])
            proctime.append(nestor_results[k][k1]["mcmc_step_time"])
        # (mean, standard error of the mean) per representation.
        log_evi_mean_sterr.append(
            (np.mean(log_evi), np.std(log_evi) / math.sqrt(len(log_evi)))
        )
        proctime_mean_sterr.append(
            (np.mean(proctime), np.std(proctime) / math.sqrt(len(proctime)))
        )

    # Column 0 = mean, column 1 = standard error.
    log_evi_mean_sterr = np.array(log_evi_mean_sterr)
    proctime_mean_sterr = np.array(proctime_mean_sterr)

    fig, ax1 = plt.subplots()
    # NOTE(review): marker/color styling was lost in the source chunk; the
    # fmt below is a best guess — confirm against the original.
    ax1.errorbar(
        x=representations,
        y=log_evi_mean_sterr[:, 0],
        yerr=log_evi_mean_sterr[:, 1],
        fmt="o",
        label="Log(Evidence)",
    )
    ax1.set_ylabel("Mean log$Z$")
    ax1.set_xlabel("Representations")

    # Second y-axis sharing the same x-axis for the timing series.
    ax2 = plt.twinx(ax=ax1)
    ax2.errorbar(
        x=representations,
        y=proctime_mean_sterr[:, 0],
        yerr=proctime_mean_sterr[:, 1],
        fmt="o",
        label="NestOR process time",
    )
    ax2.set_ylabel("Time per MCMC sampling step (sec)")

    output_fname: str = os.path.join(
        h_params["parent_dir"], "sterr_evi_and_proctime.png"
    )
    # Fix: output_fname already ends in ".png"; the original saved
    # f"{output_fname}.png", producing "...png.png". Also dropped the
    # leftover debug print ("Came here").
    plt.savefig(output_fname, dpi=1200)
    plt.close(fig)
# NOTE(review): fragment of the script's main driver — the enclosing `def`
# (around original line 376, where `args` is presumably bound from the
# argument parser) and several statements fell outside this chunk; the
# original file's line numbers are fused into the text below.
378 h_param_file = args.paramf
379 use_topology = args.topology
# Re-read the hyper-parameters for the plotting stage.
381 with open(h_param_file,
"r") as h_paramf:
382 h_params = yaml.safe_load(h_paramf)
# Unless plotting-only mode was requested, run the sampling first.
384 if not args.skip_calc:
385 run_nested_sampling(h_param_file, use_topology)
# Load <parent_dir>/nestor_output.yaml (the open(...) call itself is lost;
# only its arguments and the `outf` handle survive).
389 h_params[
"parent_dir"],
390 "nestor_output.yaml",
394 results = yaml.safe_load(outf)
# Plot only if at least one run produced results.
396 if len(list(results.keys())) > 0:
397 print(
"Plotting the results")
398 plotter(results, h_params)
# (else-branch header lost) no successful run produced results.
401 print(
"\nNone of the runs was successful...!")
402 print(
"Done...!\n\n")
# Script entry point; the guarded call (original line 410) is lost.
409 if __name__ ==
"__main__":