IMP logo
IMP Reference Guide  develop.5ce362c3b9,2024/09/16
The Integrative Modeling Platform
nestor/__init__.py
from __future__ import print_function, division

import glob
import math
import os
import pickle
import time

import yaml
from mpi4py import MPI
9 
10 
class NestedSampling:
    """Nested-sampling driver built around a replica-exchange macro.

    The object repeatedly calls ``rex_macro.execute_macro()``, reads the
    likelihoods the run leaves behind in ``likelihoods_*`` pickle files,
    and accumulates the evidence ``Z`` and the information ``H``.  MPI
    rank 0 does all of the bookkeeping; the other ranks only take part in
    sampling and in the collective barrier/broadcast calls.
    """

    def __init__(self, h_param_file, nestor_restraints, rex_macro, exit_code):
        """Load the run parameters and prepare ``rex_macro``.

        :param h_param_file: path of a YAML file providing the keys
            ``num_init_frames``, ``num_frames_per_iter``,
            ``max_nestor_iter``, ``max_plateau_hits`` and
            ``max_failed_iterations``.
        :param nestor_restraints: iterable of restraints; each one must
            have ``weight == 0``.
        :param rex_macro: macro object exposing ``vars`` and
            ``execute_macro()``; it receives the restraints and has its
            ``nest`` attribute set to True.
        :param exit_code: initial exit code; ``None`` means "no error so
            far" (any other value makes ``execute_nested_sampling2`` skip
            sampling).
        :raises ValueError: if any restraint has a non-zero weight.
        """
        with open(h_param_file, "r") as paramf:
            self.h_params = yaml.safe_load(paramf)

        # Timing bookkeeping; ``tic`` is overwritten when sampling starts.
        self.tic = -9999.0
        self.toc = None
        self.mcmc_step_time = None
        self.num_init_frames = self.h_params["num_init_frames"]
        self.num_frames_per_iter = self.h_params["num_frames_per_iter"]
        self.nestor_niter = self.h_params["max_nestor_iter"]
        self.rex_macro = rex_macro

        for restraint in nestor_restraints:
            if restraint.weight != 0:
                raise ValueError(
                    "Weight of all restraints in nestor_restraints "
                    "must be set to 0"
                )

        self.rex_macro.nestor_restraints = nestor_restraints
        self.rex_macro.nest = True

        # Convergence criteria and their running counters.
        self.max_plateau_hits = self.h_params["max_plateau_hits"]
        self.plateau_hits = 0
        self.max_failed_iter = self.h_params["max_failed_iterations"]
        self.failed_iter = 0

        # Prior volume (Xi), evidence (Z) and information (H) accumulators.
        self.Xi = 1
        self.Z = 0
        self.H = 0

        self.worst_li_list = []
        self.worst_xi_list = []
        self.log_worst_li = []
        self.log_xi = []
        self.xi = []
        # Live likelihoods; filled on rank 0 by execute_nested_sampling2.
        self.likelihoods = []

        self.comm_obj = MPI.COMM_WORLD
        self.finished = False
        self.termination_mode = "None"
        self.exit_code = exit_code
        self.return_vals = {}

    def _log_evidence(self):
        """Return ``log(Z)``, or ``-inf`` while ``Z`` is not positive.

        ``Z`` is still 0 after the first accepted iteration (its weight
        ``1 - exp(0)`` is 0) and whenever no iteration was accepted, in
        which case ``math.log`` would raise ``ValueError``.
        """
        if self.Z > 0:
            return math.log(self.Z)
        return float("-inf")

    def sample_initial_frames(self):
        """Sample the initial frames and time one Monte Carlo step.

        Runs the macro once for ``num_init_frames`` frames and stores the
        wall-clock time per Monte Carlo step in ``self.mcmc_step_time``.
        """
        self.rex_macro.vars["number_of_frames"] = self.num_init_frames
        start_time = time.time()
        self.rex_macro.execute_macro()
        end_time = time.time()
        per_frame_sampling_time = ((end_time - start_time)
                                   / self.num_init_frames)
        self.mcmc_step_time = (
            per_frame_sampling_time / self.rex_macro.vars["monte_carlo_steps"]
        )

    def parse_likelihoods(self, iteration, fhead="likelihoods_"):
        """Collect and delete every likelihood file matching ``fhead*``.

        :param iteration: iteration number, forwarded to ``terminator``
            if a NaN is found.
        :param fhead: filename prefix (may include a directory) of the
            pickle files to read.
        :return: flat list of all likelihoods found.  If any of them is
            NaN the run is flagged with exit code 11 and ``terminator``
            is called before returning.
        """
        sampled_likelihoods = []

        for binfile in glob.glob(f"{fhead}*"):
            # NOTE(security): pickle.load can execute arbitrary code; the
            # files read here must come from a trusted run.
            with open(binfile, "rb") as rlif:
                sampled_likelihoods.extend(pickle.load(rlif))
            os.remove(binfile)

        if any(math.isnan(li) for li in sampled_likelihoods):
            self.termination_mode = "Error: Nan found"
            self.exit_code = 11
            print("NaN found. Terminating...")
            for li in sampled_likelihoods:
                print(li)
            self.terminator(
                iteration=iteration,
                plateau_hits=self.plateau_hits,
                failed_iter=self.failed_iter,
            )

        return sampled_likelihoods

    def check_plateau(self):
        """
        Check if Li/Xi is plateauing for consecutive samples, stop

        Compares the last two entries of ``worst_li_list`` and
        ``worst_xi_list``; requires at least two entries in each.
        """
        previous_Li = self.worst_li_list[-2]
        current_Li = self.worst_li_list[-1]
        previous_Xi = self.worst_xi_list[-2]
        current_Xi = self.worst_xi_list[-1]

        if (current_Li / previous_Li) < (previous_Xi / current_Xi):
            self.plateau_hits += 1
            print(
                f"{'---'*20}\nPlateau detector hits: "
                f"{self.plateau_hits}/{self.max_plateau_hits}"
            )
        else:
            self.plateau_hits = 0

        # ">=" rather than "==" so the threshold cannot be stepped over.
        if self.plateau_hits >= self.max_plateau_hits:
            self.termination_mode = "MaxPlateauHits"
            self.finished = True

    def terminator(self, iteration, plateau_hits, failed_iter):
        """Finalise the run and fill ``self.return_vals``.

        When ``termination_mode`` does not contain "error" the exit code
        is set to 0 (13 if ``H`` is negative), the plots
        ``log_lixi.png`` and ``lixi.png`` are written and the result
        summary is stored.  Otherwise only the run parameters are
        stored.  The termination mode and exit code are recorded in both
        cases.

        :param iteration: stored as ``last_iter``.
        :param plateau_hits: stored as ``plateau_hits``.
        :param failed_iter: stored as ``failed_iter``.
        """
        self.toc = time.time()

        if "error" not in self.termination_mode.lower():
            print(f"Estimated evidence sampled: {self.Z}")
            self.exit_code = 0
            try:
                ana_unc = math.sqrt(self.H / self.num_init_frames)
            except ValueError:
                ana_unc = "Did not compute. H was negative"
                print("Math domain error")
                self.exit_code = 13

            from matplotlib import pyplot as plt

            # Labels are set after plotting: clearing the axes after
            # labelling them (as was done before) erased the labels.
            fig, ax = plt.subplots(1)
            ax.plot(self.log_xi, self.log_worst_li)
            ax.set_xlabel("log(Xi)")
            ax.set_ylabel("log(Li)")
            fig.savefig("log_lixi.png")
            plt.close(fig)

            fig, ax = plt.subplots(1)
            ax.plot(self.xi, self.log_worst_li)
            ax.set_xlabel("Xi")
            ax.set_ylabel("log(Li)")
            fig.savefig("lixi.png")
            plt.close(fig)

            self.return_vals["last_iter"] = iteration
            self.return_vals["plateau_hits"] = plateau_hits
            self.return_vals["failed_iter"] = failed_iter
            self.return_vals["obtained_information"] = self.H
            self.return_vals["analytical_uncertainty"] = ana_unc
            self.return_vals["nestor_process_time"] = self.toc - self.tic
            self.return_vals["mcmc_step_time"] = self.mcmc_step_time
            # Guarded: Z may still be 0 here, and log(0) would raise.
            self.return_vals["log_estimated_evidence"] = self._log_evidence()

        else:
            self.return_vals["run_params"] = self.h_params

        self.return_vals["termination_mode"] = self.termination_mode
        self.return_vals["exit_code"] = self.exit_code

    def compute_evidence_H(self, iteration, curr_li):
        """Add one likelihood shell to ``Z`` and update ``H``.

        :param iteration: index of the accepted iteration; the prior
            volume is ``exp(-iteration / num_init_frames)``.
        :param curr_li: likelihood of the discarded (worst) live point.
        """
        # compute Z
        curr_xi = math.exp(-iteration / self.num_init_frames)
        curr_wi = self.Xi - curr_xi
        prev_zi = self.Z
        self.Z += curr_li * curr_wi
        curr_zi = self.Z
        self.Xi = curr_xi

        # compute H
        # Skipped for iterations 0 and 1: after iteration 0 Z is still 0,
        # so log(prev_zi) would not be defined at iteration 1.
        if iteration > 1:
            first_term = ((curr_li * curr_wi) / curr_zi) * math.log(curr_li)
            second_term = (prev_zi / curr_zi) * (self.H + math.log(prev_zi))
            self.H = first_term + second_term - math.log(curr_zi)

    def execute_nested_sampling2(self):
        """Run the nested-sampling loop across all MPI ranks.

        Every rank must call this method: it contains collective
        ``Barrier``/``bcast`` calls whose order must match on all ranks.

        NOTE(review): the block nesting below was reconstructed from a
        listing that had lost its indentation; the two places where the
        nesting could not be derived from the control flow are marked.

        :return: ``(self.return_vals, self.exit_code)`` on rank 0 and
            ``(None, None)`` on every other rank.
        """
        self.tic = time.time()
        import matplotlib.pyplot as plt

        i = 0          # number of accepted ("calculation") iterations
        true_iter = 0  # number of sampling iterations actually run
        base_process = self.comm_obj.Get_rank() == 0
        self.comm_obj.Barrier()

        if "shuffle_config.err" in os.listdir("./"):
            self.exit_code = 11

        self.comm_obj.Barrier()

        print(
            f"Exit code from the macros after communication: {self.exit_code} "
            f"at rank: {self.comm_obj.Get_rank()}"
        )

        if self.exit_code is None:
            # Check for nan through small test run
            self.comm_obj.Barrier()
            if base_process:
                print(
                    f"{'-'*50}\nTest run complete, no NaN found. "
                    f"Continuing...\n{'-'*50}\n\n"
                )
            self.comm_obj.Barrier()

            self.sample_initial_frames()
            self.comm_obj.Barrier()

            if base_process:
                self.likelihoods = self.parse_likelihoods(iteration=true_iter)
            self.comm_obj.Barrier()

            self.rex_macro.vars["number_of_frames"] = self.num_frames_per_iter
            self.rex_macro.vars["replica_exchange_swap"] = True

            while true_iter < self.nestor_niter:
                self.comm_obj.Barrier()
                self.finished = self.comm_obj.bcast(self.finished, root=0)
                self.exit_code = self.comm_obj.bcast(self.exit_code, root=0)

                if self.exit_code is not None:
                    # run log will exist if
                    # a. parse_likelihoods had a nan error in the test iter,
                    #    called terminator.
                    # b. convergence criterion plateau reached, called
                    #    terminator.
                    # c. convergence criterion max_failed_iterations reached,
                    #    called terminator.
                    break

                if not self.finished:
                    # Other processes should not sample more models as
                    # convergence criteria i.e. a. max_failed_iterations
                    # or b. plateau triggered and the likelihoods list is
                    # unraveled to accumulate Z/H.
                    self.rex_macro.execute_macro()
                self.comm_obj.Barrier()

                if base_process:
                    if len(self.likelihoods) != 0:
                        Li = min(self.likelihoods)
                        if not self.finished:
                            newly_sampled_likelihoods = self.parse_likelihoods(
                                iteration=true_iter
                            )
                            candidate_li = max(newly_sampled_likelihoods)
                        else:  # unraveling
                            candidate_li = Li

                        if candidate_li >= Li:
                            self.likelihoods.remove(Li)

                            if not self.finished:
                                self.likelihoods.append(candidate_li)

                            self.compute_evidence_H(iteration=i, curr_li=Li)
                            self.log_worst_li.append(math.log(Li))

                            self.log_xi.append(math.log(self.Xi))
                            self.xi.append(self.Xi)
                            self.worst_li_list.append(Li)
                            self.worst_xi_list.append(self.Xi)

                            if not self.finished:
                                if i > 1:
                                    self.check_plateau()
                                # NOTE(review): assumed to sit inside
                                # "if not self.finished" - confirm.
                                self.failed_iter = 0
                            # Advanced while unraveling too, otherwise
                            # the prior volume would stop shrinking.
                            i += 1

                        else:
                            self.failed_iter += 1
                            if self.failed_iter >= self.max_failed_iter:
                                self.termination_mode = "MaxFailedIterations"
                                self.finished = True

                        true_iter += 1
                        print(
                            f'\n-----> True iteration: {true_iter} {" "*5} '
                            f'Calculation iteration: {i} {" "*5} '
                            f'Failed iteration: {self.failed_iter} {" "*5} '
                            f'Evidence: {self.Z} {" "*5} '
                            f'Terminating: {self.finished}\n'
                        )
                        if true_iter % 10 == 0:
                            tempout = {
                                "True iteration": true_iter,
                                "Calculation iteration": i,
                                "Failed iteration": self.failed_iter,
                                # Guarded: Z can still be 0 at this point.
                                "Log Evidence": self._log_evidence(),
                                "Plateau hits": self.plateau_hits,
                            }
                            with open("temporary_output.yaml", "w") as tof:
                                yaml.dump(tempout, tof)

                    else:
                        # Live set exhausted: unraveling is complete.
                        self.terminator(
                            iteration=true_iter,
                            plateau_hits=self.plateau_hits,
                            failed_iter=self.failed_iter,
                        )

                    # NOTE(review): assumed to run on rank 0 only - confirm.
                    live_fig, live_ax = plt.subplots(1)
                    live_ax.set_xlabel("log(Xi)")
                    live_ax.set_ylabel("log(Li)")
                    live_ax.plot(self.log_xi, self.log_worst_li)
                    live_fig.savefig("live_loglixi.png")
                    plt.close()

                self.comm_obj.Barrier()
                true_iter = self.comm_obj.bcast(true_iter, root=0)
                if true_iter == self.nestor_niter:
                    self.termination_mode = (
                        "Error: MaxIterations reached without convergence "
                        "criteria")
                    self.exit_code = 12
                    self.exit_code = self.comm_obj.bcast(self.exit_code,
                                                         root=0)
                    self.terminator(
                        iteration=true_iter,
                        plateau_hits=self.plateau_hits,
                        failed_iter=self.failed_iter,
                    )

        else:
            if base_process:
                self.termination_mode = "Error: Shuffle configuration error"
                self.exit_code = 11
            self.exit_code = self.comm_obj.bcast(self.exit_code, root=0)
            self.terminator(iteration=0, plateau_hits=0, failed_iter=0)
        self.comm_obj.Barrier()

        self.exit_code = self.comm_obj.bcast(self.exit_code, root=0)

        if base_process:
            return self.return_vals, self.exit_code
        else:
            return None, None
345 
346 
347 __version__ = "9709206"
348 
350  '''Return the version of this module, as a string'''
351  return "9709206"
352 
def get_module_name():
    '''Return the fully-qualified name of this module'''
    return "IMP::nestor"
356 
def get_data_path(fname):
    '''Return the full path to one of this module's data files

    The lookup is delegated to ``IMP._get_module_data_path`` with the
    module name "nestor"; ``fname`` is passed through unchanged.
    '''
    # Imported here rather than at module level, as in the original.
    import IMP
    return IMP._get_module_data_path("nestor", fname)
361 
def get_example_path(fname):
    '''Return the full path to one of this module's example files

    The lookup is delegated to ``IMP._get_module_example_path`` with the
    module name "nestor"; ``fname`` is passed through unchanged.
    '''
    # Imported here rather than at module level, as in the original.
    import IMP
    return IMP._get_module_example_path("nestor", fname)
def get_example_path
Return the full path to one of this module's example files.
def get_module_version
Return the version of this module, as a string.
log
Definition: log.py:1
def get_data_path
Return the full path to one of this module's data files.
def get_module_name
Return the fully-qualified name of this module.