1 from __future__
import print_function, division
12 def __init__(self, h_param_file, nestor_restraints, rex_macro, exit_code):
13 with open(h_param_file,
"r") as paramf:
14 self.h_params = yaml.safe_load(paramf)
18 self.mcmc_step_time =
None
19 self.num_init_frames = self.h_params[
"num_init_frames"]
20 self.num_frames_per_iter = self.h_params[
"num_frames_per_iter"]
21 self.nestor_niter = self.h_params[
"max_nestor_iter"]
22 self.rex_macro = rex_macro
24 for restraint
in nestor_restraints:
25 if restraint.weight != 0:
27 "Weight of all restraints in nestor_restraints "
31 self.rex_macro.nestor_restraints = nestor_restraints
32 self.rex_macro.nest =
True
34 self.max_plateau_hits = self.h_params[
"max_plateau_hits"]
36 self.max_failed_iter = self.h_params[
"max_failed_iterations"]
43 self.worst_li_list = []
44 self.worst_xi_list = []
45 self.log_worst_li = []
49 self.comm_obj = MPI.COMM_WORLD
51 self.termination_mode =
"None"
52 self.exit_code = exit_code
55 def sample_initial_frames(self):
56 self.rex_macro.vars[
"number_of_frames"] = self.num_init_frames
57 start_time = time.time()
58 self.rex_macro.execute_macro()
59 end_time = time.time()
60 per_frame_sampling_time = ((end_time - start_time)
61 / self.num_init_frames)
62 self.mcmc_step_time = (
63 per_frame_sampling_time / self.rex_macro.vars[
"monte_carlo_steps"]
66 def parse_likelihoods(self, iteration, fhead="likelihoods_"):
67 sampled_likelihoods = []
68 all_likelihood_binaries = glob.glob(f
"{fhead}*")
70 for binfile
in all_likelihood_binaries:
72 with open(binfile,
"rb")
as rlif:
73 likelihoods = pickle.load(rlif)
74 for li
in likelihoods:
75 sampled_likelihoods.append(li)
79 for li
in sampled_likelihoods:
83 self.termination_mode =
"Error: Nan found"
85 print(
"NaN found. Terminating...")
86 for li
in sampled_likelihoods:
90 plateau_hits=self.plateau_hits,
91 failed_iter=self.failed_iter,
94 return sampled_likelihoods
96 def check_plateau(self):
98 Check if Li/Xi is plateuing for consecutive samples, stop
101 previous_Li = self.worst_li_list[-2]
102 current_Li = self.worst_li_list[-1]
103 previous_Xi = self.worst_xi_list[-2]
104 current_Xi = self.worst_xi_list[-1]
106 if (current_Li / previous_Li) < (previous_Xi / current_Xi):
107 self.plateau_hits += 1
109 f
"{'---'*20}\nPlateau detector hits: "
110 f
"{self.plateau_hits}/{self.max_plateau_hits}"
113 self.plateau_hits = 0
115 if self.plateau_hits == self.max_plateau_hits:
116 self.termination_mode =
"MaxPlateauHits"
119 def terminator(self, iteration, plateau_hits, failed_iter):
122 self.toc = time.time()
124 if "error" not in self.termination_mode.lower():
125 print(f
"Estimated evidence sampled: {self.Z}")
128 ana_unc = math.sqrt(self.H / self.num_init_frames)
130 ana_unc =
"Did not compute. H was negative"
131 print(
"Math domain error")
134 from matplotlib
import pyplot
as plt
136 fig, ax = plt.subplots(1)
137 ax.set_xlabel(
"log(Xi)")
138 ax.set_ylabel(
"log(Li)")
140 ax.plot(self.log_xi, self.log_worst_li)
141 fig.savefig(
"log_lixi.png")
144 fig, ax = plt.subplots(1)
145 ax.plot(self.xi, self.log_worst_li)
147 ax.set_ylabel(
"log(Li)")
148 fig.savefig(
"lixi.png")
151 self.return_vals[
"last_iter"] = iteration
152 self.return_vals[
"plateau_hits"] = plateau_hits
153 self.return_vals[
"failed_iter"] = failed_iter
154 self.return_vals[
"obtained_information"] = self.H
155 self.return_vals[
"analytical_uncertainty"] = ana_unc
156 self.return_vals[
"nestor_process_time"] = self.toc - self.tic
157 self.return_vals[
"mcmc_step_time"] = self.mcmc_step_time
158 self.return_vals[
"log_estimated_evidence"] = log(self.Z)
161 self.return_vals[
"run_params"] = self.h_params
163 self.return_vals[
"termination_mode"] = self.termination_mode
164 self.return_vals[
"exit_code"] = self.exit_code
166 def compute_evidence_H(self, iteration, curr_li):
168 curr_xi = math.exp(-iteration / self.num_init_frames)
169 curr_wi = self.Xi - curr_xi
171 self.Z += curr_li * curr_wi
177 first_term = ((curr_li * curr_wi) / curr_zi) * math.log(curr_li)
178 second_term = (prev_zi / curr_zi) * (self.H + math.log(prev_zi))
179 self.H = first_term + second_term - math.log(curr_zi)
181 def execute_nested_sampling2(self):
182 self.tic = time.time()
183 import matplotlib.pyplot
as plt
187 base_process = self.comm_obj.Get_rank() == 0
188 self.comm_obj.Barrier()
190 if "shuffle_config.err" in os.listdir(
"./"):
193 self.comm_obj.Barrier()
196 f
"Exit code from the macros after communication: {self.exit_code} "
197 f
"at rank: {self.comm_obj.Get_rank()}"
200 if self.exit_code
is None:
202 self.comm_obj.Barrier()
205 f
"{'-'*50}\nTest run complete, no NaN found. "
206 f
"Continuing...\n{'-'*50}\n\n"
208 self.comm_obj.Barrier()
210 self.sample_initial_frames()
211 self.comm_obj.Barrier()
214 self.likelihoods = self.parse_likelihoods(iteration=true_iter)
215 self.comm_obj.Barrier()
217 self.rex_macro.vars[
"number_of_frames"] = self.num_frames_per_iter
218 self.rex_macro.vars[
"replica_exchange_swap"] =
True
220 while true_iter < self.nestor_niter:
221 self.comm_obj.Barrier()
222 self.finished = self.comm_obj.bcast(self.finished, root=0)
223 self.exit_code = self.comm_obj.bcast(self.exit_code, root=0)
225 if self.exit_code
is not None:
235 if not self.finished:
240 self.rex_macro.execute_macro()
241 self.comm_obj.Barrier()
244 if len(self.likelihoods) != 0:
245 Li = min(self.likelihoods)
246 if not self.finished:
247 newly_sampled_likelihoods = self.parse_likelihoods(
250 candidate_li = max(newly_sampled_likelihoods)
254 if candidate_li >= Li:
255 self.likelihoods.remove(Li)
257 if not self.finished:
258 self.likelihoods.append(candidate_li)
261 self.compute_evidence_H(iteration=i, curr_li=Li)
262 self.log_worst_li.append(math.log(Li))
264 self.log_xi.append(math.log(self.Xi))
265 self.xi.append(self.Xi)
266 self.worst_li_list.append(Li)
267 self.worst_xi_list.append(self.Xi)
269 if not self.finished:
276 self.failed_iter += 1
277 if self.failed_iter == self.max_failed_iter:
278 self.termination_mode =
"MaxFailedIterations"
283 f
'\n-----> True iteration: {true_iter} {" "*5} '
284 f
'Calculation iteration: {i} {" "*5} '
285 f
'Failed iteration: {self.failed_iter} {" "*5} '
286 f
'Evidence: {self.Z} {" "*5} '
287 f
'Terminating: {self.finished}\n'
289 if true_iter % 10 == 0:
293 "True iteration": true_iter,
294 "Calculation iteration": i,
295 "Failed iteration": self.failed_iter,
296 "Log Evidence": log(self.Z),
297 "Plateau hits": self.plateau_hits,
299 with open(
"temporary_output.yaml",
"w")
as tof:
300 yaml.dump(tempout, tof)
305 plateau_hits=self.plateau_hits,
306 failed_iter=self.failed_iter,
309 live_fig, live_ax = plt.subplots(1)
310 live_ax.set_xlabel(
"log(Xi)")
311 live_ax.set_ylabel(
"log(Li)")
312 live_ax.plot(self.log_xi, self.log_worst_li)
313 live_fig.savefig(
"live_loglixi.png")
316 self.comm_obj.Barrier()
317 true_iter = self.comm_obj.bcast(true_iter, root=0)
318 if true_iter == self.nestor_niter:
319 self.termination_mode = (
320 "Error: MaxIterations reached without convergence "
323 self.exit_code = self.comm_obj.bcast(self.exit_code,
327 plateau_hits=self.plateau_hits,
328 failed_iter=self.failed_iter,
333 self.termination_mode =
"Error: Shuffle configuration error"
335 self.exit_code = self.comm_obj.bcast(self.exit_code, root=0)
336 self.terminator(iteration=0, plateau_hits=0, failed_iter=0)
337 self.comm_obj.Barrier()
339 self.exit_code = self.comm_obj.bcast(self.exit_code, root=0)
342 return self.return_vals, self.exit_code
347 __version__ =
"2.21.0"
350 '''Return the version of this module, as a string'''
354 '''Return the fully-qualified name of this module'''
358 '''Return the full path to one of this module's data files'''
360 return IMP._get_module_data_path(
"nestor", fname)
363 '''Return the full path to one of this module's example files'''
365 return IMP._get_module_example_path(
"nestor", fname)
def get_module_version
Return the version of this module, as a string.
def get_example_path
Return the full path to one of this module's example files.
def get_data_path
Return the full path to one of this module's data files.
def get_module_name
Return the fully-qualified name of this module.