IMP logo
IMP Reference Guide  develop.7f24702912,2026/04/21
The Integrative Modeling Platform
samplers.py
1 """@namespace IMP.pmi.samplers
2  Sampling of the system.
3 """
4 
5 import IMP
6 import IMP.core
7 from IMP.pmi.tools import get_restraint_set
8 
9 
10 class _SerialReplicaExchange:
11  """Dummy replica exchange class used in non-MPI builds.
12  It should act similarly to IMP.mpi.ReplicaExchange
13  on a single processor.
14  """
15  def __init__(self):
16  self.__params = {}
17 
18  def get_number_of_replicas(self):
19  return 1
20 
21  def create_temperatures(self, tmin, tmax, nrep):
22  return [tmin]
23 
24  def get_my_index(self):
25  return 0
26 
27  def set_my_parameter(self, key, val):
28  self.__params[key] = val
29 
30  def get_my_parameter(self, key):
31  return self.__params[key]
32 
33  def get_friend_index(self, step):
34  return 0
35 
36  def get_friend_parameter(self, key, findex):
37  return self.get_my_parameter(key)
38 
39  def do_exchange(self, myscore, fscore, findex):
40  return False
41 
42  def set_was_used(self, was_used):
43  self.was_used = was_used
44 
45 
46 class _SamplerBase:
47  def __init__(self, model):
48  self.model = model
49  # that is -1 because mc/md has not yet run
50  self.nframe = -1
51  self.simulated_annealing = False
52 
53  def set_simulated_annealing(self, min_temp, max_temp, min_temp_time,
54  max_temp_time):
55  self.simulated_annealing = True
56  self.tempmin = min_temp
57  self.tempmax = max_temp
58  self.timemin = min_temp_time
59  self.timemax = max_temp_time
60 
61  def temp_simulated_annealing(self):
62  if self.nframe % (self.timemin + self.timemax) < self.timemin:
63  value = 0.0
64  else:
65  value = 1.0
66  temp = self.tempmin + (self.tempmax - self.tempmin) * value
67  return temp
68 
69 
class MonteCarlo(_SamplerBase):
    """Sample using Monte Carlo"""

    # check that isd is installed; IMP.isd.WeightMover is only handled
    # (in apply_self_adaptive and get_output) when it is available
    try:
        import IMP.isd
        isd_available = True
    except ImportError:
        isd_available = False

    def __init__(self, model, objects=None, temp=1.0, filterbyname=None,
                 score_moved=False):
        """Setup Monte Carlo sampling
        @param model The IMP Model
        @param objects What to sample (a list of Movers)
        @param temp The MC temperature
        @param filterbyname Not used
        @param score_moved If True, attempt to speed up sampling by
               caching scoring function terms on particles that didn't move
        """
        super().__init__(model)
        # Particle categories handled by this sampler, returned by
        # get_particle_types()
        self.losp = [
            "Rigid_Bodies",
            "Floppy_Bodies",
            "Nuisances",
            "X_coord",
            "Weights",
            "Surfaces"]
        self.selfadaptive = False
        self.temp = temp
        self.mvs = objects
        self.mvslabels = []
        self.label = "None"
        # mover -> (accepted, proposed) counts seen at the previous
        # apply_self_adaptive() call, so acceptance is measured per interval
        self.movers_data = {}
        self.use_jax = False
        self._jax_optimizer = None
        self._jax_state = None

        # All movers are wrapped in one SerialMover, which applies them one
        # at a time (presumably why optimize() scales nstep by the number
        # of movers)
        self.smv = IMP.core.SerialMover(self.mvs)

        self.mc = IMP.core.MonteCarlo(self.model)
        self.mc.set_scoring_function(get_restraint_set(self.model))
        self.mc.set_return_best(False)
        self.mc.set_score_moved(score_moved)
        self.mc.set_kt(self.temp)
        self.mc.add_mover(self.smv)

    def set_use_jax(self, nstep):
        """Request that sampling of the scoring function is done using
           JAX instead of IMP's internal C++ implementation (requires
           that all PMI restraints used have a JAX implementation).
        @param nstep Steps per mover for each optimize() call. In JAX
               mode the step count is fixed here, and the nstep later
               passed to optimize() is not used.
        """
        self.use_jax = True
        self._jax_optimizer = self.mc._get_jax_optimizer(
            nstep * self.get_number_of_movers())
        self._jax_state = self._jax_optimizer.get_initial_state()

    def get_jax_model(self):
        """Get the current JAX Model used by the sampler.
        Only valid after set_use_jax() has been called."""
        return self._jax_state.jm

    def set_kt(self, temp):
        """Set the MC temperature, keeping the C++ optimizer and (if in
        use) the JAX state in sync."""
        self.temp = temp
        self.mc.set_kt(temp)
        if self._jax_state is not None:
            self._jax_state.temperature = temp

    def get_mc(self):
        """Return the underlying IMP.core.MonteCarlo optimizer."""
        return self.mc

    def set_scoring_function(self, objectlist):
        """Score MC moves using only the restraints of the given objects.
        @param objectlist Objects providing get_restraint()
        """
        rs = IMP.RestraintSet(self.model, 1.0, 'sfo')
        for ob in objectlist:
            rs.add_restraint(ob.get_restraint())
        sf = IMP.core.RestraintsScoringFunction([rs])
        self.mc.set_scoring_function(sf)

    def set_self_adaptive(self, isselfadaptive=True):
        """Enable/disable step-size adaptation after each optimize()."""
        self.selfadaptive = isselfadaptive

    def get_number_of_movers(self):
        return len(self.smv.get_movers())

    def get_particle_types(self):
        return self.losp

    def optimize(self, nstep):
        """Run one frame of MC (nstep proposals per mover) and return the
        score. Annealing and self-adaptation, if enabled, are applied
        afterwards for the next frame."""
        self.nframe += 1
        if self.use_jax:
            score, self._jax_state = self._jax_optimizer.optimize(
                self._jax_state)
        else:
            score = self.mc.optimize(nstep * self.get_number_of_movers())

        # apply simulated annealing protocol
        if self.simulated_annealing:
            self.set_kt(self.temp_simulated_annealing())

        # apply self adaptive protocol
        if self.selfadaptive:
            self.apply_self_adaptive()
        return score

    def apply_self_adaptive(self):
        """Modify parameters of individual movers to try to keep acceptance
           rate around 50%"""
        if self.use_jax:
            raise NotImplementedError(
                "Adaptive protocol is not yet implemented for JAX")
        for mv in self.mvs:
            mvacc = mv.get_number_of_accepted()
            mvprp = mv.get_number_of_proposed()
            # Acceptance over the interval since the previous call (or the
            # whole run, the first time this mover is seen)
            oldmvacc, oldmvprp = self.movers_data.get(mv, (0, 0))
            self.movers_data[mv] = (mvacc, mvprp)
            nproposed = mvprp - oldmvprp
            if nproposed <= 0:
                # No new proposals means no acceptance information; leave
                # the step size alone instead of dividing by zero
                continue
            accept = float(mvacc - oldmvacc) / float(nproposed)
            # Clamp so a near-zero rate cannot collapse the step size
            accept = min(max(accept, 0.05), 1.0)
            if 0.4 > accept or accept > 0.6:
                self._scale_mover_step(mv, accept)

    def _scale_mover_step(self, mv, accept):
        """Multiply the step-size parameter(s) of a known mover type by
        2*accept; movers of other types are left untouched."""
        if isinstance(mv, IMP.core.NormalMover):
            mv.set_sigma(mv.get_sigma() * 2 * accept)

        if self.isd_available and isinstance(mv, IMP.isd.WeightMover):
            mv.set_radius(mv.get_radius() * 2 * accept)

        if isinstance(mv, (IMP.core.RigidBodyMover, IMP.pmi.TransformMover)):
            mr = mv.get_maximum_rotation()
            mt = mv.get_maximum_translation()
            mv.set_maximum_rotation(mr * 2 * accept)
            mv.set_maximum_translation(mt * 2 * accept)

        if isinstance(mv, IMP.core.BallMover):
            mv.set_radius(mv.get_radius() * 2 * accept)

    def set_label(self, label):
        self.label = label

    def get_frame_number(self):
        return self.nframe

    def get_output(self):
        """Return a dict of string-valued statistics: per-mover acceptance
        (and step size for Nuisances/Weights movers), temperature and
        frame number."""
        output = {}
        for i, mv in enumerate(self.smv.get_movers()):
            mvname = mv.get_name()
            mvacc = mv.get_number_of_accepted()
            mvprp = mv.get_number_of_proposed()
            # Report 0 rather than failing when nothing was proposed yet
            mvacr = float(mvacc) / float(mvprp) if mvprp else 0.0
            suffix = mvname + "_" + str(i)
            output["MonteCarlo_Acceptance_" + suffix] = str(mvacr)
            if "Nuisances" in mvname:
                output["MonteCarlo_StepSize_" + suffix] = \
                    str(IMP.core.NormalMover.get_from(mv).get_sigma())
            if "Weights" in mvname and self.isd_available:
                output["MonteCarlo_StepSize_" + suffix] = \
                    str(IMP.isd.WeightMover.get_from(mv).get_radius())
        output["MonteCarlo_Temperature"] = str(self.mc.get_kt())
        output["MonteCarlo_Nframe"] = str(self.nframe)
        return output
253 
254 
class MolecularDynamics(_SamplerBase):
    """Sample using molecular dynamics"""

    # Boltzmann constant in kcal/(mol K); kT values passed to this class
    # are divided by it to get the thermostat temperature
    _KB = 0.0019872041

    def __init__(self, model, objects, kt, gamma=0.01, maximum_time_step=1.0,
                 sf=None, use_jax=False):
        """Setup MD
        @param model The IMP Model
        @param objects What to sample. Use flat list of particles
        @param kt Temperature (as kT, converted using the Boltzmann
               constant in kcal/(mol K))
        @param gamma Viscosity parameter
        @param maximum_time_step MD max time step
        @param sf Scoring function to use; if not given, the model's
               restraint set (get_restraint_set) is used
        @param use_jax Currently ignored; JAX is not supported for MD
               (see set_use_jax)
        """
        super().__init__(model)

        # check if using PMI1 objects dictionary, or just list of particles.
        # Default to the flat list so an empty `objects` still leaves
        # to_sample bound.
        to_sample = objects
        try:
            for obj in objects:
                psamp = obj.get_particles_to_sample()
                # NOTE(review): each PMI1 object overwrites the previous
                # one, so only the last object's particles are sampled
                to_sample = psamp['Floppy_Bodies_SimplifiedModel'][0]
        except (AttributeError, KeyError, IndexError, TypeError):
            # Not PMI1 objects: treat `objects` as a flat particle list
            to_sample = objects

        self.ltstate = IMP.atom.LangevinThermostatOptimizerState(
            self.model, to_sample, kt / self._KB, gamma)
        self.md = IMP.atom.MolecularDynamics(self.model)
        self.md.set_maximum_time_step(maximum_time_step)
        if sf:
            self.md.set_scoring_function(sf)
        else:
            self.md.set_scoring_function(get_restraint_set(self.model))
        self.md.add_optimizer_state(self.ltstate)

    def set_use_jax(self, nstep):
        """Request that sampling of the scoring function is done using
           JAX instead of IMP's internal C++ implementation (requires
           that all PMI restraints used have a JAX implementation)."""
        raise NotImplementedError("JAX currently only supported for MC")

    def set_kt(self, kt):
        """Set the thermostat temperature from kT and reassign velocities
        at that temperature."""
        temp = kt / self._KB
        self.ltstate.set_temperature(temp)
        self.md.assign_velocities(temp)

    def set_gamma(self, gamma):
        self.ltstate.set_gamma(gamma)

    def optimize(self, nsteps):
        """Run nsteps of MD and return the score. Unlike MonteCarlo, the
        annealing temperature is applied before this frame's run."""
        self.nframe += 1
        # apply simulated annealing protocol
        if self.simulated_annealing:
            self.set_kt(self.temp_simulated_annealing())
        return self.md.optimize(nsteps)

    def get_output(self):
        output = {}
        output["MolecularDynamics_KineticEnergy"] = \
            str(self.md.get_kinetic_energy())
        return output
313 
314 
class ConjugateGradients:
    """Sample using conjugate gradients"""

    def __init__(self, model, objects):
        """Set up conjugate gradient minimization.
        @param model The IMP Model
        @param objects Not used; the scoring function is built from
               get_restraint_set(model)
        """
        self.model = model
        # -1 because no optimization has run yet; optimize() increments it
        self.nframe = -1
        self.cg = IMP.core.ConjugateGradients(self.model)
        self.cg.set_scoring_function(get_restraint_set(self.model))

    def set_label(self, label):
        self.label = label

    def get_frame_number(self):
        return self.nframe

    def optimize(self, nstep):
        """Run nstep conjugate gradient steps and return the resulting
        score, consistent with the other samplers' optimize()."""
        self.nframe += 1
        return self.cg.optimize(nstep)

    def set_scoring_function(self, objectlist):
        """Minimize only the restraints of the given objects.
        @param objectlist Objects providing get_restraint()
        """
        rs = IMP.RestraintSet(self.model, 1.0, 'sfo')
        for ob in objectlist:
            rs.add_restraint(ob.get_restraint())
        sf = IMP.core.RestraintsScoringFunction([rs])
        self.cg.set_scoring_function(sf)

    def get_output(self):
        output = {}
        output["ConjugatedGradients_Nframe"] = str(self.nframe)
        return output
345 
346 
class ReplicaExchange:
    """Sample using replica exchange"""

    def __init__(self, model, tempmin, tempmax, samplerobjects, test=True,
                 replica_exchange_object=None):
        '''Set up temperature replica exchange.

        samplerobjects can be a list of MonteCarlo or MolecularDynamics

        @param model The IMP Model
        @param tempmin Lowest temperature of the ladder
        @param tempmax Highest temperature of the ladder
        @param samplerobjects Samplers whose temperature is set (via
               set_kt) to this replica's current temperature
        @param test If False, require an even number of replicas
        @param replica_exchange_object Existing replica exchange object;
               if None, IMP.mpi.ReplicaExchange is used when IMP.mpi can
               be imported, otherwise a serial single-replica stand-in
        @raise ValueError if test is False and the replica count is odd
        '''
        self.model = model
        self.samplerobjects = samplerobjects
        # min and max temperature
        self.TEMPMIN_ = tempmin
        self.TEMPMAX_ = tempmax

        if replica_exchange_object is None:
            # initialize Replica Exchange class
            try:
                import IMP.mpi
            except ImportError:
                print('ReplicaExchange: Could not find MPI. '
                      'Using Serial Replica Exchange')
                self.rem = _SerialReplicaExchange()
            else:
                print('ReplicaExchange: MPI was found. '
                      'Using Parallel Replica Exchange')
                self.rem = IMP.mpi.ReplicaExchange()
        else:
            # get the replica exchange class instance from elsewhere
            print('got existing rex object')
            self.rem = replica_exchange_object

        nproc = self.rem.get_number_of_replicas()
        if nproc % 2 != 0 and not test:
            raise ValueError(
                "number of replicas has to be even. "
                "set test=True to run with odd number of replicas.")
        # create array of temperatures, in geometric progression
        self.temperatures = self.rem.create_temperatures(
            self.TEMPMIN_, self.TEMPMAX_, nproc)

        # each replica starts at the ladder entry matching its index; the
        # temperature is the parameter exchanged between replicas
        myindex = self.rem.get_my_index()
        mytemp = self.temperatures[myindex]
        self.rem.set_my_parameter("temp", [mytemp])
        for so in self.samplerobjects:
            so.set_kt(mytemp)
        # counters reported by get_output()
        self.nattempts = 0
        self.nmintemp = 0
        self.nmaxtemp = 0
        self.nsuccess = 0

    def get_temperatures(self):
        return self.temperatures

    def get_my_temp(self):
        return self.rem.get_my_parameter("temp")[0]

    def get_my_index(self):
        return self.rem.get_my_index()

    def swap_temp(self, nframe, score=None):
        """Attempt a temperature exchange with this frame's partner.
        @param nframe Frame number, used to pick the partner replica
        @param score Current score; evaluated from the model if None
        """
        if score is None:
            score = self.model.evaluate(False)
        mytemp = self.rem.get_my_parameter("temp")[0]

        if mytemp == self.TEMPMIN_:
            self.nmintemp += 1

        if mytemp == self.TEMPMAX_:
            self.nmaxtemp += 1

        # score divided by kbt, at this replica's temperature
        myscore = score / mytemp

        # get my friend index and temperature
        findex = self.rem.get_friend_index(nframe)
        ftemp = self.rem.get_friend_parameter("temp", findex)[0]
        # this replica's score divided by the partner's kbt
        fscore = score / ftemp

        # try exchange
        flag = self.rem.do_exchange(myscore, fscore, findex)

        self.nattempts += 1
        # if accepted, take over the partner's temperature
        if flag:
            for so in self.samplerobjects:
                so.set_kt(ftemp)
            self.nsuccess += 1

    def get_output(self):
        """Return exchange statistics (as strings) and current temperature."""
        keys = ("ReplicaExchange_SwapSuccessRatio",
                "ReplicaExchange_MinTempFrequency",
                "ReplicaExchange_MaxTempFrequency")
        counts = (self.nsuccess, self.nmintemp, self.nmaxtemp)
        if self.nattempts != 0:
            ratios = [str(float(c) / self.nattempts) for c in counts]
        else:
            ratios = [str(0)] * len(keys)
        output = dict(zip(keys, ratios))
        output["ReplicaExchange_CurrentTemp"] = str(self.get_my_temp())
        return output
460 
461 
class MPI_values:
    """Share named per-replica values (e.g. scores) across replicas."""

    def __init__(self, replica_exchange_object=None):
        """Query values (ie score, and others)
        from a set of parallel jobs

        @param replica_exchange_object Existing replica exchange object to
               reuse; if None, one is created (MPI if importable, serial
               stand-in otherwise)
        """
        if replica_exchange_object is not None:
            # get the replica exchange class instance from elsewhere
            print('got existing rex object')
            self.rem = replica_exchange_object
            return

        # initialize Replica Exchange class
        try:
            import IMP.mpi
            print('MPI_values: MPI was found. '
                  'Using Parallel Replica Exchange')
            self.rem = IMP.mpi.ReplicaExchange()
        except ImportError:
            print('MPI_values: Could not find MPI. '
                  'Using Serial Replica Exchange')
            self.rem = _SerialReplicaExchange()

    def set_value(self, name, value):
        """Publish this replica's value under name."""
        self.rem.set_my_parameter(name, [value])

    def get_values(self, name):
        """Return the value stored under name for every replica, in
        replica-index order."""
        nreplicas = self.rem.get_number_of_replicas()
        return [self.rem.get_friend_parameter(name, idx)[0]
                for idx in range(nreplicas)]

    def get_percentile(self, name):
        """Fraction of replicas whose value ranks strictly below this
        replica's (first occurrence in sorted order)."""
        mine = self.rem.get_my_parameter(name)[0]
        ranked = sorted(self.get_values(name))
        return float(ranked.index(mine)) / len(ranked)
def __init__
samplerobjects can be a list of MonteCarlo or MolecularDynamics
Definition: samplers.py:350
A Monte Carlo optimizer.
Definition: MonteCarlo.h:44
A class to implement Hamiltonian Replica Exchange.
Maintains temperature during molecular dynamics.
Sample using molecular dynamics.
Definition: samplers.py:255
Miscellaneous utilities.
Definition: pmi/tools.py:1
Modify the transformation of a rigid body.
Simple conjugate gradients optimizer.
def set_use_jax
Request that sampling of the scoring function is done using JAX instead of IMP's internal C++ implementation.
Definition: samplers.py:120
Sample using conjugate gradients.
Definition: samplers.py:315
Create a scoring function on a list of restraints.
Move continuous particle variables by perturbing them within a ball.
Object used to hold a set of restraints.
Definition: RestraintSet.h:41
Simple molecular dynamics simulator.
Code that uses the MPI parallel library.
def set_use_jax
Request that sampling of the scoring function is done using JAX instead of IMP's internal C++ implementation.
Definition: samplers.py:287
A mover that perturbs a Weight particle.
Definition: WeightMover.h:20
Modify the transformation of a rigid body.
def __init__
Setup Monte Carlo sampling.
Definition: samplers.py:80
Modify a set of continuous variables using a normal distribution.
Definition: NormalMover.h:23
Basic functionality that is expected to be used by a wide variety of IMP users.
Sample using Monte Carlo.
Definition: samplers.py:70
The general base class for IMP exceptions.
Definition: exception.h:48
def apply_self_adaptive
Modify parameters of individual movers to try to keep acceptance rate around 50%. ...
Definition: samplers.py:175
Applies a list of movers one at a time.
Definition: SerialMover.h:26
def get_jax_model
Get the current JAX Model used by the sampler.
Definition: samplers.py:129
Sample using replica exchange.
Definition: samplers.py:347
Inferential scoring building on methods developed as part of the Inferential Structure Determination ...