1 """@namespace IMP.pmi.samplers
2 Sampling of the system.
10 class _SerialReplicaExchange:
11 """Dummy replica exchange class used in non-MPI builds.
12 It should act similarly to IMP.mpi.ReplicaExchange
13 on a single processor.
18 def get_number_of_replicas(self):
21 def create_temperatures(self, tmin, tmax, nrep):
24 def get_my_index(self):
def set_my_parameter(self, key, val):
    """Record `val` under `key` in this replica's private parameter table.

    An existing entry for `key` is overwritten.
    """
    table = self.__params
    table[key] = val
def get_my_parameter(self, key):
    """Look up `key` in this replica's private parameter table and
    return the stored value."""
    stored = self.__params[key]
    return stored
33 def get_friend_index(self, step):
def get_friend_parameter(self, key, findex):
    """Return the parameter `key` requested for replica `findex`.

    `findex` is not used: the lookup is delegated to
    get_my_parameter(), so the value returned is always this
    replica's own.
    """
    own_value = self.get_my_parameter(key)
    return own_value
39 def do_exchange(self, myscore, fscore, findex):
def set_was_used(self, was_used):
    """Store the given flag on this object as the `was_used` attribute."""
    flag = was_used
    self.was_used = flag
47 def __init__(self, model):
51 self.simulated_annealing =
False
53 def set_simulated_annealing(self, min_temp, max_temp, min_temp_time,
55 self.simulated_annealing =
True
56 self.tempmin = min_temp
57 self.tempmax = max_temp
58 self.timemin = min_temp_time
59 self.timemax = max_temp_time
61 def temp_simulated_annealing(self):
62 if self.nframe % (self.timemin + self.timemax) < self.timemin:
66 temp = self.tempmin + (self.tempmax - self.tempmin) * value
71 """Sample using Monte Carlo"""
80 def __init__(self, model, objects=None, temp=1.0, filterbyname=None,
82 """Setup Monte Carlo sampling
83 @param model The IMP Model
84 @param objects What to sample (a list of Movers)
85 @param temp The MC temperature
86 @param filterbyname Not used
87 @param score_moved If True, attempt to speed up sampling by
88 caching scoring function terms on particles that didn't move
98 self.selfadaptive =
False
103 self.movers_data = {}
105 self._jax_optimizer =
None
106 self._jax_state =
None
114 self.mc.set_scoring_function(get_restraint_set(self.model))
115 self.mc.set_return_best(
False)
116 self.mc.set_score_moved(score_moved)
117 self.mc.set_kt(self.temp)
118 self.mc.add_mover(self.smv)
121 """Request that sampling of the scoring function is done using
122 JAX instead of IMP's internal C++ implementation (requires
123 that all PMI restraints used have a JAX implementation)."""
125 self._jax_optimizer = self.mc._get_jax_optimizer(
126 nstep * self.get_number_of_movers())
127 self._jax_state = self._jax_optimizer.get_initial_state()
130 """Get the current JAX Model used by the sampler."""
131 return self._jax_state.jm
133 def set_kt(self, temp):
136 if self._jax_state
is not None:
137 self._jax_state.temperature = temp
142 def set_scoring_function(self, objectlist):
144 for ob
in objectlist:
145 rs.add_restraint(ob.get_restraint())
147 self.mc.set_scoring_function(sf)
def set_self_adaptive(self, isselfadaptive=True):
    """Switch the self-adaptive flag of this sampler on or off.

    @param isselfadaptive value stored in `self.selfadaptive`
           (defaults to True)
    """
    enabled = isselfadaptive
    self.selfadaptive = enabled
def get_number_of_movers(self):
    """Return how many movers `self.smv` currently reports."""
    movers = self.smv.get_movers()
    return len(movers)
155 def get_particle_types(self):
158 def optimize(self, nstep):
161 score, self._jax_state = self._jax_optimizer.optimize(
164 score = self.mc.optimize(nstep * self.get_number_of_movers())
167 if self.simulated_annealing:
168 self.set_kt(self.temp_simulated_annealing())
171 if self.selfadaptive:
176 """Modify parameters of individual movers to try to keep acceptance
179 raise NotImplementedError(
180 "Adaptive protocol is not yet implemented for JAX")
181 for i, mv
in enumerate(self.mvs):
183 mvacc = mv.get_number_of_accepted()
184 mvprp = mv.get_number_of_proposed()
185 if mv
not in self.movers_data:
186 accept = float(mvacc) / float(mvprp)
187 self.movers_data[mv] = (mvacc, mvprp)
189 oldmvacc, oldmvprp = self.movers_data[mv]
190 accept = float(mvacc-oldmvacc) / float(mvprp-oldmvprp)
191 self.movers_data[mv] = (mvacc, mvprp)
198 stepsize = mv.get_sigma()
199 if 0.4 > accept
or accept > 0.6:
200 mv.set_sigma(stepsize * 2 * accept)
203 stepsize = mv.get_radius()
204 if 0.4 > accept
or accept > 0.6:
205 mv.set_radius(stepsize * 2 * accept)
208 mr = mv.get_maximum_rotation()
209 mt = mv.get_maximum_translation()
210 if 0.4 > accept
or accept > 0.6:
211 mv.set_maximum_rotation(mr * 2 * accept)
212 mv.set_maximum_translation(mt * 2 * accept)
215 mr = mv.get_maximum_rotation()
216 mt = mv.get_maximum_translation()
217 if 0.4 > accept
or accept > 0.6:
218 mv.set_maximum_rotation(mr * 2 * accept)
219 mv.set_maximum_translation(mt * 2 * accept)
223 if 0.4 > accept
or accept > 0.6:
224 mv.set_radius(mr * 2 * accept)
226 def set_label(self, label):
229 def get_frame_number(self):
232 def get_output(self):
234 for i, mv
in enumerate(self.smv.get_movers()):
235 mvname = mv.get_name()
236 mvacc = mv.get_number_of_accepted()
237 mvprp = mv.get_number_of_proposed()
239 mvacr = float(mvacc) / float(mvprp)
242 output[
"MonteCarlo_Acceptance_" +
243 mvname +
"_" + str(i)] = str(mvacr)
244 if "Nuisances" in mvname:
245 output[
"MonteCarlo_StepSize_" + mvname +
"_" + str(i)] = \
246 str(IMP.core.NormalMover.get_from(mv).get_sigma())
247 if "Weights" in mvname:
248 output[
"MonteCarlo_StepSize_" + mvname +
"_" + str(i)] = \
249 str(IMP.isd.WeightMover.get_from(mv).get_radius())
250 output[
"MonteCarlo_Temperature"] = str(self.mc.get_kt())
251 output[
"MonteCarlo_Nframe"] = str(self.nframe)
256 """Sample using molecular dynamics"""
258 def __init__(self, model, objects, kt, gamma=0.01, maximum_time_step=1.0,
259 sf=
None, use_jax=
False):
261 @param model The IMP Model
262 @param objects What to sample. Use flat list of particles
263 @param kt Temperature
264 @param gamma Viscosity parameter
265 @param maximum_time_step MD max time step
272 psamp = obj.get_particles_to_sample()
273 to_sample = psamp[
'Floppy_Bodies_SimplifiedModel'][0]
278 self.model, to_sample, kt/0.0019872041, gamma)
280 self.md.set_maximum_time_step(maximum_time_step)
282 self.md.set_scoring_function(sf)
284 self.md.set_scoring_function(get_restraint_set(self.model))
285 self.md.add_optimizer_state(self.ltstate)
288 """Request that sampling of the scoring function is done using
289 JAX instead of IMP's internal C++ implementation (requires
290 that all PMI restraints used have a JAX implementation)."""
291 raise NotImplementedError(
"JAX currently only supported for MC")
def set_kt(self, kt):
    """Set the sampling temperature from `kt`.

    `kt` is divided by 0.0019872041 before use; the result is passed
    first to the thermostat state (`self.ltstate`) and then to the
    optimizer (`self.md`) to reassign velocities.
    """
    # NOTE(review): 0.0019872041 is presumably the gas constant in
    # kcal/(mol K), converting kT to Kelvin - confirm.
    scaled = kt / 0.0019872041
    thermostat = self.ltstate
    thermostat.set_temperature(scaled)
    self.md.assign_velocities(scaled)
def set_gamma(self, gamma):
    """Forward `gamma` to the thermostat state held in `self.ltstate`."""
    thermostat = self.ltstate
    thermostat.set_gamma(gamma)
301 def optimize(self, nsteps):
304 if self.simulated_annealing:
305 self.set_kt(self.temp_simulated_annealing())
306 return self.md.optimize(nsteps)
308 def get_output(self):
310 output[
"MolecularDynamics_KineticEnergy"] = \
311 str(self.md.get_kinetic_energy())
316 """Sample using conjugate gradients"""
318 def __init__(self, model, objects):
322 self.cg.set_scoring_function(get_restraint_set(self.model))
324 def set_label(self, label):
327 def get_frame_number(self):
330 def optimize(self, nstep):
332 self.cg.optimize(nstep)
334 def set_scoring_function(self, objectlist):
336 for ob
in objectlist:
337 rs.add_restraint(ob.get_restraint())
339 self.cg.set_scoring_function(sf)
341 def get_output(self):
343 output[
"ConjugatedGradients_Nframe"] = str(self.nframe)
348 """Sample using replica exchange"""
350 def __init__(self, model, tempmin, tempmax, samplerobjects, test=True,
351 replica_exchange_object=
None):
353 samplerobjects can be a list of MonteCarlo or MolecularDynamics
357 self.samplerobjects = samplerobjects
359 self.TEMPMIN_ = tempmin
360 self.TEMPMAX_ = tempmax
362 if replica_exchange_object
is None:
366 print(
'ReplicaExchange: MPI was found. '
367 'Using Parallel Replica Exchange')
370 print(
'ReplicaExchange: Could not find MPI. '
371 'Using Serial Replica Exchange')
372 self.rem = _SerialReplicaExchange()
376 print(
'got existing rex object')
377 self.rem = replica_exchange_object
380 nproc = self.rem.get_number_of_replicas()
382 if nproc % 2 != 0
and not test:
384 "number of replicas has to be even. "
385 "set test=True to run with odd number of replicas.")
387 temp = self.rem.create_temperatures(
392 self.temperatures = temp
394 myindex = self.rem.get_my_index()
396 self.rem.set_my_parameter(
"temp", [self.temperatures[myindex]])
397 for so
in self.samplerobjects:
398 so.set_kt(self.temperatures[myindex])
def get_temperatures(self):
    """Return the temperatures stored on this object.

    The stored object itself is returned, not a copy.
    """
    ladder = self.temperatures
    return ladder
def get_my_temp(self):
    """Return the first element of this replica's "temp" parameter,
    as reported by the replica exchange object `self.rem`."""
    temp_param = self.rem.get_my_parameter("temp")
    return temp_param[0]
def get_my_index(self):
    """Return this replica's index as reported by `self.rem`."""
    replica_index = self.rem.get_my_index()
    return replica_index
413 def swap_temp(self, nframe, score=None):
415 score = self.model.evaluate(
False)
417 _ = self.rem.get_my_index()
418 mytemp = self.rem.get_my_parameter(
"temp")[0]
420 if mytemp == self.TEMPMIN_:
423 if mytemp == self.TEMPMAX_:
427 myscore = score / mytemp
430 findex = self.rem.get_friend_index(nframe)
431 ftemp = self.rem.get_friend_parameter(
"temp", findex)[0]
433 fscore = score / ftemp
436 flag = self.rem.do_exchange(myscore, fscore, findex)
441 for so
in self.samplerobjects:
445 def get_output(self):
447 if self.nattempts != 0:
448 output[
"ReplicaExchange_SwapSuccessRatio"] = str(
449 float(self.nsuccess) / self.nattempts)
450 output[
"ReplicaExchange_MinTempFrequency"] = str(
451 float(self.nmintemp) / self.nattempts)
452 output[
"ReplicaExchange_MaxTempFrequency"] = str(
453 float(self.nmaxtemp) / self.nattempts)
455 output[
"ReplicaExchange_SwapSuccessRatio"] = str(0)
456 output[
"ReplicaExchange_MinTempFrequency"] = str(0)
457 output[
"ReplicaExchange_MaxTempFrequency"] = str(0)
458 output[
"ReplicaExchange_CurrentTemp"] = str(self.get_my_temp())
463 def __init__(self, replica_exchange_object=None):
464 """Query values (ie score, and others)
465 from a set of parallel jobs"""
467 if replica_exchange_object
is None:
471 print(
'MPI_values: MPI was found. '
472 'Using Parallel Replica Exchange')
475 print(
'MPI_values: Could not find MPI. '
476 'Using Serial Replica Exchange')
477 self.rem = _SerialReplicaExchange()
481 print(
'got existing rex object')
482 self.rem = replica_exchange_object
def set_value(self, name, value):
    """Publish `value` under `name` as a parameter of this replica.

    The value is wrapped in a one-element list before being handed
    to the replica exchange object `self.rem`.
    """
    wrapped = [value]
    self.rem.set_my_parameter(name, wrapped)
487 def get_values(self, name):
489 for i
in range(self.rem.get_number_of_replicas()):
490 v = self.rem.get_friend_parameter(name, i)[0]
494 def get_percentile(self, name):
495 value = self.rem.get_my_parameter(name)[0]
496 values = sorted(self.get_values(name))
497 ind = values.index(value)
498 percentile = float(ind)/len(values)
def __init__
samplerobjects can be a list of MonteCarlo or MolecularDynamics
A class to implement Hamiltonian Replica Exchange.
Maintains temperature during molecular dynamics.
Sample using molecular dynamics.
Modify the transformation of a rigid body.
Simple conjugate gradients optimizer.
def set_use_jax
Request that sampling of the scoring function is done using JAX instead of IMP's internal C++ implementation (requires that all PMI restraints used have a JAX implementation).
Sample using conjugate gradients.
Create a scoring function on a list of restraints.
Move continuous particle variables by perturbing them within a ball.
Object used to hold a set of restraints.
Simple molecular dynamics simulator.
Code that uses the MPI parallel library.
def set_use_jax
Request that sampling of the scoring function is done using JAX instead of IMP's internal C++ implementation (requires that all PMI restraints used have a JAX implementation).
A mover that perturbs a Weight particle.
def __init__
Setup Monte Carlo sampling.
Modify a set of continuous variables using a normal distribution.
Basic functionality that is expected to be used by a wide variety of IMP users.
Sample using Monte Carlo.
The general base class for IMP exceptions.
def apply_self_adaptive
Modify parameters of individual movers to try to keep acceptance rate around 50%. ...
Applies a list of movers one at a time.
def get_jax_model
Get the current JAX Model used by the sampler.
Sample using replica exchange.
Inferential scoring building on methods developed as part of the Inferential Structure Determination software (ISD).