Loading [MathJax]/extensions/tex2jax.js
IMP logo
IMP Reference Guide  2.20.1
The Integrative Modeling Platform
Replica.py
1 # replica exchange class
2 # inspired by ISD Replica.py
3 # Yannick
4 
5 from __future__ import print_function, division
6 import numpy as np
7 from numpy.random import random, randint
# Boltzmann constant in kcal/mol/K.
# 1.3806503 (x1e-23 J/K) times Avogadro's number 6.0221415 (x1e23 /mol) gives
# the molar gas constant in J/mol/K (the powers of ten cancel); dividing by
# 4184 J/kcal converts it to kcal/mol/K.
kB = 1.3806503 * 6.0221415 / 4184.0
9 
10 
class ReplicaTracker:
    """Bookkeeping and driver for temperature replica exchange.

    Two permutations are maintained:

    * ``replicanums[state]`` -- which replica currently occupies ``state``
    * ``statenums[replica]`` -- which state ``replica`` currently occupies

    A *state* indexes ``inv_temps`` (inverse temperatures); a *replica* indexes
    the per-replica lists returned by the grid.
    """

    def __init__(self, nreps, inv_temps, grid, sfo_id,
                 rexlog='replicanums.txt', scheme='standard', xchg='random',
                 convectivelog='stirred.txt', tune_temps=False,
                 tune_data=None, templog='temps.txt'):
        """Set up the tracker and append the initial state to the log(s).

        nreps         : number of replicas (and of states)
        inv_temps     : list of inverse temperatures, one per state
        grid          : object providing broadcast/scatter/gather
        sfo_id        : identifier passed to the grid calls
        rexlog        : file receiving replica numbers per state
        scheme        : 'standard' or 'convective'
        xchg          : 'gromacs' or 'random' (how pairs are drawn)
        convectivelog : file receiving the stirring status (convective only)
        tune_temps    : if True, replica_exchange() calls tune_rex()
        tune_data     : dict of tuning options, see tune_rex()
        templog       : file receiving inverse temperatures when tuning
        """
        self.nreps = nreps
        # replica number as a function of state no
        self.replicanums = list(range(nreps))
        # state no as a function of replica no
        self.statenums = list(range(nreps))
        self.grid = grid
        self.sfo_id = sfo_id
        # expect inverse temperatures
        self.inv_temps = inv_temps
        self.logfile = rexlog
        self.stepno = 1
        self.scheme = scheme
        self.xchg = xchg
        self.tune_temps = tune_temps
        # a fresh dict per instance: a `{}` default would be shared by all
        # trackers created without an explicit tune_data
        self.tune_data = {} if tune_data is None else tune_data
        self.rn_history = []
        self.templog = templog
        self.convectivelog = convectivelog
        # scheme is 'standard' or 'convective'; xchg is 'gromacs' or 'random'
        if scheme == "convective":
            # replicas are stirred from the last one down to the first
            order = list(range(self.nreps))
            order.reverse()
            first = order[0]
            self.stirred = {
                # which order the replicas should be chosen
                'order': order,
                # which replica is currently being stirred
                'replica': first,
                # which position in the order list this replica sits in
                'pos': 0,
                # direction the replica should go to: 1 is up, 0 is down
                'dir': 1 if first != self.nreps - 1 else 0,
                # how many steps remain until we change stirring replica
                'steps': 2 * (self.nreps - 1),
            }
        self.write_rex_stats()

    def sort_per_state(self, inplist):
        """Reorder a replica-indexed list so that it is indexed by state.

        Raises ValueError if the list does not have nreps entries.
        """
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.replicanums]

    def sort_per_replica(self, inplist):
        """Reorder a state-indexed list so that it is indexed by replica.

        Raises ValueError if the list does not have nreps entries.
        """
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.statenums]

    def get_energies(self):
        "return replica-sorted energies"
        return self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'm', 'evaluate', False))

    def gen_pairs_list_gromacs(self, direction):
        """generate ordered list of exchange pairs based on direction.
        direction == 0 : (0,1)(2,3)...
        direction == 1 : (1,2)(3,4)...
        returns only pairs (a state left without neighbor is omitted).
        With exactly two replicas [(0, 1)] is returned for both directions;
        with fewer than two the list is empty.
        Raises ValueError if direction is neither 0 nor 1.
        """
        nreps = self.nreps
        if direction != 0 and direction != 1:
            raise ValueError(direction)
        if nreps == 2:
            return [(0, 1)]
        # every second neighbor pair, starting at `direction`; stopping before
        # nreps - 1 guarantees that both members are valid states
        return [(low, low + 1) for low in range(direction, nreps - 1, 2)]

    def gen_pairs_list_rand(self, needed=None):
        """generate a random, sorted list of non-overlapping neighbor pairs.

        needed : optional list of (i, i+1) pairs that must be part of the
                 result.
        A virtual pair (0, nreps-1) takes part in the draw when there are
        more than two replicas, so that the end states may also be left
        unpaired; it is never returned.
        Raises ValueError if a needed pair is not of the form (i, i+1), is
        out of range, or overlaps a previously needed pair.
        """
        nreps = self.nreps
        pairslist = []
        # generate all possible pairs
        init = [(i, i + 1) for i in range(nreps - 1)]
        # add this pair to form a loop. With two replicas it would coincide
        # with the only real pair, which could then never be returned.
        loop_pair = None
        if nreps > 2:
            loop_pair = (0, nreps - 1)
            init.append(loop_pair)
        # add needed pairs and remove overlapping candidates
        for (i, j) in (needed or ()):
            if j - i != 1:
                raise ValueError("wrong format for 'needed' list")
            if (i, j) not in init:
                raise ValueError(
                    "needed pair %s is out of range or overlaps another "
                    "needed pair" % ((i, j),))
            pairslist.append((i, j))
            # filtering (rather than removing fixed neighbors) also works
            # when the pair touches state 0 or state nreps - 1, and takes
            # care of the loop pair
            init = [p for p in init if i not in p and j not in p]
        while init:
            # choose random pair and remove it from the candidates
            # (numpy randint is [a,b[)
            pair = init.pop(randint(0, len(init)))
            # remove overlapping candidates
            init = [(r, q) for (r, q) in init
                    if (r not in pair and q not in pair)]
            # the loop pair is only a placeholder, not a real exchange
            if pair != loop_pair:
                pairslist.append(pair)
        pairslist.sort()
        return pairslist

    def gen_pairs_list_conv(self):
        """generate the pairs list for the convective scheme.

        The pair made of the stirred replica's state and its neighbor in the
        stirring direction is stored in self.stirred['pair'] and is part of
        the returned list.
        Raises NotImplementedError for an unknown exchange method.
        """
        rep = self.stirred['replica']
        state = self.statenums[rep]
        # dir is 1 (up) or 0 (down), so the partner is state + 1 or state - 1
        partner = state + 2 * self.stirred['dir'] - 1
        self.stirred['pair'] = tuple(sorted([state, partner]))
        if self.xchg == 'gromacs':
            # pick the parity for which the stirred pair is in the list
            direction = (state + 1 + self.stirred['dir']) % 2
            return self.gen_pairs_list_gromacs(direction)
        elif self.xchg == 'random':
            return self.gen_pairs_list_rand(needed=[self.stirred['pair']])
        else:
            raise NotImplementedError(
                "Unknown exchange method: %s" %
                self.xchg)

    def gen_pairs_list(self):
        """generate the list of state pairs to try, according to the scheme
        and exchange method given at construction.
        Raises NotImplementedError for an unknown scheme or method.
        """
        if self.scheme == 'standard':
            if self.xchg == 'gromacs':
                # alternate between both parities
                return self.gen_pairs_list_gromacs(self.stepno % 2)
            elif self.xchg == 'random':
                return self.gen_pairs_list_rand()
            else:
                raise NotImplementedError(
                    "unknown exchange method: %s" %
                    self.xchg)
        elif self.scheme == 'convective':
            return self.gen_pairs_list_conv()
        else:
            raise NotImplementedError(
                "unknown exchange scheme: %s" %
                self.scheme)

    def get_cross_energies(self, pairslist):
        "get energies assuming all exchanges have succeeded"
        print("this is not needed for temperature replica-exchange")
        raise NotImplementedError

    def get_metropolis(self, pairslist, old_ene):
        """compute metropolis criterion for temperature replica exchange
        e.g. exp(Delta beta Delta E)
        input: list of pairs, list of state-sorted energies
        returns a dict mapping each pair to min(1, exp(...))
        """
        metrop = {}
        for (s1, s2) in pairslist:
            delta_ene = old_ene[s2] - old_ene[s1]
            delta_beta = self.inv_temps[s2] - self.inv_temps[s1]
            metrop[(s1, s2)] = min(1, np.exp(delta_ene * delta_beta))
        return metrop

    def try_exchanges(self, plist, metrop):
        """return the pairs of plist that pass the metropolis test.

        A random number is only drawn for pairs whose criterion is below 1.
        """
        return [couple for couple in plist
                if (metrop[couple] >= 1) or (random() < metrop[couple])]

    def perform_exchanges(self, accepted):
        "exchange given state couples both in local variables and on the grid"
        # locally
        for (i, j) in accepted:
            ri = self.replicanums[i]
            rj = self.replicanums[j]
            # states
            self.statenums[ri] = j
            self.statenums[rj] = i
            # replicas
            self.replicanums[i], self.replicanums[j] = rj, ri
        # on the grid (suboptimal but who cares)
        newtemps = self.sort_per_replica(self.inv_temps)
        states = self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'get_state'))
        for (i, j) in accepted:
            ri = self.replicanums[i]
            rj = self.replicanums[j]
            states[ri], states[rj] = states[rj], states[ri]
        for temp, state in zip(newtemps, states):
            state['inv_temp'] = temp
        self.grid.gather(
            self.grid.scatter(self.sfo_id, 'set_state', states))

    def write_rex_stats(self):
        """append the current status to the log files.

        Always writes the (1-based) replica numbers as a function of state.
        In the convective scheme the stirring status is written too. The
        inverse temperatures are written when tuning is enabled and exactly
        one entry has been accumulated in rn_history.
        """
        # `with` closes the files even if a write fails
        with open(self.logfile, 'a') as fl:
            fl.write('%8d ' % self.stepno)
            fl.write(' '.join(['%2d' % (i + 1) for i in self.replicanums]))
            fl.write('\n')
        if self.scheme == 'convective':
            with open(self.convectivelog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write('%2d ' % (self.stirred['replica'] + 1))
                fl.write('%2d ' % self.stirred['dir'])
                fl.write('%2d\n' % self.stirred['steps'])
        if self.tune_temps and len(self.rn_history) == 1:
            with open(self.templog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write(' '.join(['%.3f' % i for i in self.inv_temps]))
                fl.write('\n')

    def tune_rex(self):
        """use TuneRex to optimize temp set. Temps are optimized every
        'rate' steps and 'method' is used. Data is accumulated as long as
        the temperatures weren't optimized.
        td keys that should be passed to init:
            rate : the rate at which to try tuning temps
            method : flux or ar
            targetAR : for ar only, target acceptance rate
            alpha : Type I error to use.
            dumb_scale : for ar only, forwarded to TuneRex if present
        Raises NotImplementedError for an unknown method.
        """
        import TuneRex
        # update replicanum history with a copy of the current assignment
        self.rn_history.append(list(self.replicanums))
        td = self.tune_data
        if len(self.rn_history) % td['rate'] != 0:
            return
        # TuneRex is given temperatures, not inverse temperatures
        temps = [1 / (kB * la) for la in self.inv_temps]
        kwargs = {}
        if td['method'] == 'ar':
            for key in ('targetAR', 'alpha', 'dumb_scale'):
                if key in td:
                    kwargs[key] = td[key]
            indicators = TuneRex.compute_indicators(
                np.transpose(self.rn_history))
            changed, newparams = TuneRex.tune_params_ar(
                indicators, temps, **kwargs)
        elif td['method'] == 'flux':
            if 'alpha' in td:
                kwargs['alpha'] = td['alpha']
            changed, newparams = TuneRex.tune_params_flux(
                np.transpose(self.rn_history),
                temps, **kwargs)
        else:
            # without this branch `changed` would be unbound below
            raise NotImplementedError(
                "unknown tuning method: %s" % td['method'])
        if changed:
            # start accumulating from scratch with the new temperatures
            self.rn_history = []
            print(newparams)
            self.inv_temps = [1 / (kB * t) for t in newparams]

    def do_bookkeeping_before(self):
        """increment the step number and, in the convective scheme, update
        which replica is stirred and in which direction.
        """
        self.stepno += 1
        if self.scheme != 'convective':
            return
        st = self.stirred
        # check if we are done stirring this replica
        if st['steps'] == 0:
            st['pos'] = (st['pos'] + 1) % self.nreps
            st['replica'] = st['order'][st['pos']]
            st['dir'] = 1
            st['steps'] = 2 * (self.nreps - 1)
        state = self.statenums[st['replica']]
        # update endpoints: turn around at the highest and lowest state
        if state == self.nreps - 1:
            st['dir'] = 0
        elif state == 0:
            st['dir'] = 1

    def do_bookkeeping_after(self, accepted):
        """in the convective scheme, count down the remaining stirring steps
        if the stirred pair (set by gen_pairs_list_conv) was accepted.
        """
        if self.scheme == 'convective':
            if self.stirred['pair'] in accepted:
                self.stirred['steps'] -= 1

    def replica_exchange(self):
        "main entry point for replica-exchange"
        self.do_bookkeeping_before()
        # tune temperatures
        if self.tune_temps:
            self.tune_rex()
        # state-sorted energies
        energies = self.sort_per_state(self.get_energies())
        # pairs of states to try
        plist = self.gen_pairs_list()
        # acceptance probabilities
        metrop = self.get_metropolis(plist, energies)
        # draw the exchanges
        accepted = self.try_exchanges(plist, metrop)
        # propagate locally and on the grid
        self.perform_exchanges(accepted)
        self.do_bookkeeping_after(accepted)
def __init__
input a list of particles, the slope and theta of the sigmoid potential theta is the cutoff distance ...