IMP  2.3.1
The Integrative Modeling Platform
Replica.py
1 # replica exchange class
2 # inspired by ISD Replica.py
3 # Yannick
4 
5 from numpy import *
6 from numpy.random import random, randint
7 from numpy.random import shuffle
# Boltzmann constant in kcal/(mol*K): k_B (1.3806503e-23 J/K) times
# Avogadro's number (6.0221415e23 1/mol) -- the powers of ten cancel --
# divided by 4184 J/kcal.
kB = (1.3806503 * 6.0221415) / 4184.0
9 
10 
class ReplicaTracker(object):
    """Bookkeeping for temperature replica exchange.

    Keeps the two-way mapping between replica numbers and states
    (temperature slots), chooses which neighbouring states attempt an
    exchange, applies the Metropolis criterion and pushes the resulting
    assignments to the replicas through ``grid``.

    Exchange schemes (``scheme``):
      'standard'   -- ordinary neighbour exchanges.
      'convective' -- one replica at a time is "stirred": its exchange with
                      the neighbouring state in its current direction is
                      always attempted, the direction reverses at the ends
                      of the state ladder, and after 2*(nreps-1) accepted
                      moves the next replica in 'order' is stirred.
    Pair selection (``xchg``):
      'gromacs'    -- alternate (0,1)(2,3)... and (1,2)(3,4)... pairings.
      'random'     -- random set of non-overlapping neighbour pairs.
    """

    def __init__(self, nreps, inv_temps, grid, sfo_id,
                 rexlog='replicanums.txt', scheme='standard', xchg='random',
                 convectivelog='stirred.txt', tune_temps=False,
                 tune_data=None, templog='temps.txt'):
        """Set up the tracker and append the initial line to the logs.

        nreps         -- number of replicas (equal to the number of states).
        inv_temps     -- state-sorted inverse temperatures; tune_rex treats
                         them as 1/(kB*T), i.e. mol/kcal.
        grid, sfo_id  -- object used to reach the replicas, and the id of
                         the remote object on it (see get_energies and
                         perform_exchanges for the calls made).
        rexlog        -- file receiving the replica number at each state.
        scheme        -- 'standard' or 'convective'.
        xchg          -- 'gromacs' or 'random'.
        convectivelog -- stirring log (convective scheme only).
        tune_temps    -- if true, tune_rex is run at every exchange.
        tune_data     -- options for tune_rex (see there); defaults to {}.
        templog       -- file receiving the inverse temperatures when tuning.
        """
        self.nreps = nreps
        # replica number as a function of state number
        self.replicanums = list(range(nreps))
        # state number as a function of replica number
        self.statenums = list(range(nreps))
        self.grid = grid
        self.sfo_id = sfo_id
        # expect inverse temperatures
        self.inv_temps = inv_temps
        self.logfile = rexlog
        self.stepno = 1
        self.scheme = scheme
        self.xchg = xchg
        self.tune_temps = tune_temps
        # fresh dict per instance rather than a shared mutable default
        self.tune_data = {} if tune_data is None else tune_data
        # snapshots of replicanums accumulated since the last tuning
        self.rn_history = []
        self.templog = templog
        if scheme == "convective":
            order = list(range(self.nreps))
            order.reverse()
            self.stirred = {
                # order in which replicas take turns being stirred
                'order': order,
                # replica currently being stirred
                'replica': order[0],
                # index of that replica in 'order'
                'pos': 0,
                # 1 = move up the state ladder, 0 = move down
                'dir': 1 if order[0] != self.nreps - 1 else 0,
                # accepted moves left before the next replica is stirred
                'steps': 2 * (self.nreps - 1),
            }
            self.convectivelog = convectivelog
        self.write_rex_stats()

    def sort_per_state(self, inplist):
        "sorts a replica list per state"
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.replicanums]

    def sort_per_replica(self, inplist):
        "sorts a state list per replica number"
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.statenums]

    def get_energies(self):
        """Return replica-sorted energies.

        Broadcasts 'm', 'evaluate', False to the remote object sfo_id and
        gathers the results. The exact semantics depend on the grid
        implementation (not visible here) -- presumably each replica
        evaluates its scoring function.
        """
        return self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'm', 'evaluate', False))

    def gen_pairs_list_gromacs(self, direction):
        """Generate the ordered list of exchange pairs for a direction.

        direction == 0 : (0,1)(2,3)...
        direction == 1 : 0 (1,2)(3,4)...
        Returns only complete pairs (an unpaired end state is left out).
        With two replicas, (0,1) is returned for either direction.
        Raises ValueError if direction is neither 0 nor 1.
        """
        nreps = self.nreps
        if direction not in (0, 1):
            raise ValueError(direction)
        if nreps == 2:
            return [(0, 1)]
        ret = [(2 * i + direction, 2 * i + 1 + direction)
               for i in range(nreps // 2)]
        # drop a last pair that runs past the top state; ret is empty
        # when there are fewer than two replicas
        if ret and nreps in ret[-1]:
            ret.pop()
        return ret

    def gen_pairs_list_rand(self, needed=None):
        """Generate a random sorted list of non-overlapping neighbour pairs.

        Candidates are (i, i+1) for every state i. When there are more than
        two states, a wrap-around pair (0, nreps-1) is added as well. It is
        never returned; drawing it only blocks pairs at both ends, so the
        end states are treated like every other state.

        needed -- (i, i+1) pairs that must appear in the result.
        Raises ValueError if a needed pair is not of the form (i, i+1) or
        if two needed pairs share a state.
        """
        if needed is None:
            needed = []
        nreps = self.nreps
        loop_pair = (0, nreps - 1)
        # generate all possible neighbour pairs
        init = [(i, i + 1) for i in range(nreps - 1)]
        # for two states the loop pair would duplicate (0, 1)
        if nreps > 2:
            init.append(loop_pair)
        pairslist = []
        used = set()
        for (i, j) in needed:
            if j - i != 1:
                raise ValueError("wrong format for 'needed' list")
            if i in used or j in used:
                raise ValueError("overlapping pairs in 'needed' list")
            pairslist.append((i, i + 1))
            used.update((i, i + 1))
        # remove every candidate sharing a state with a needed pair
        init = [(r, q) for (r, q) in init if r not in used and q not in used]
        while init:
            # choose and remove a random pair (numpy randint is [a,b[)
            pair = init.pop(randint(0, len(init)))
            # remove candidates overlapping the chosen pair
            init = [(r, q) for (r, q) in init
                    if r not in pair and q not in pair]
            if nreps <= 2 or pair != loop_pair:
                pairslist.append(pair)
        pairslist.sort()
        return pairslist

    def gen_pairs_list_conv(self):
        """Generate exchange pairs for the convective scheme.

        The pair formed by the stirred replica's state and its neighbour in
        the stirring direction is stored in self.stirred['pair'], and the
        pairs are chosen so that this pair is attempted.
        """
        state = self.statenums[self.stirred['replica']]
        # dir 1 -> neighbour above (state + 1), dir 0 -> below (state - 1)
        neighbour = state + 2 * self.stirred['dir'] - 1
        self.stirred['pair'] = (min(state, neighbour), max(state, neighbour))
        if self.xchg == 'gromacs':
            # pick the alternation whose pairs start at the pair's lower state
            direction = (state + 1 + self.stirred['dir']) % 2
            return self.gen_pairs_list_gromacs(direction)
        elif self.xchg == 'random':
            return self.gen_pairs_list_rand(needed=[self.stirred['pair']])
        else:
            raise NotImplementedError(
                "Unknown exchange method: %s" %
                self.xchg)

    def gen_pairs_list(self):
        """Return the state pairs to attempt, according to scheme and xchg."""
        if self.scheme == 'standard':
            if self.xchg == 'gromacs':
                return self.gen_pairs_list_gromacs(self.stepno % 2)
            elif self.xchg == 'random':
                return self.gen_pairs_list_rand()
            else:
                raise NotImplementedError(
                    "unknown exchange method: %s" %
                    self.xchg)
        elif self.scheme == 'convective':
            return self.gen_pairs_list_conv()
        else:
            raise NotImplementedError(
                "unknown exchange scheme: %s" %
                self.scheme)

    def get_cross_energies(self, pairslist):
        "get energies assuming all exchanges have succeeded"
        print("this is not needed for temperature replica-exchange")
        raise NotImplementedError

    def get_metropolis(self, pairslist, old_ene):
        """Compute the Metropolis criterion for temperature replica exchange.

        For each pair (s1, s2), the result is
        min(1, exp((E[s2] - E[s1]) * (beta[s2] - beta[s1]))).
        input: list of state pairs, list of state-sorted energies.
        Returns a dict mapping each pair to its acceptance probability.
        """
        metrop = {}
        for (s1, s2) in pairslist:
            log_ratio = ((old_ene[s2] - old_ene[s1]) *
                         (self.inv_temps[s2] - self.inv_temps[s1]))
            # cap in log space so exp() cannot overflow for large gaps
            metrop[(s1, s2)] = exp(min(0.0, log_ratio))
        return metrop

    def try_exchanges(self, plist, metrop):
        """Return the pairs of plist accepted by a Metropolis draw."""
        accepted = []
        for couple in plist:
            if (metrop[couple] >= 1) or (random() < metrop[couple]):
                accepted.append(couple)
        return accepted

    def perform_exchanges(self, accepted):
        """Exchange given state couples both locally and on the grid.

        accepted -- list of (i, j) state pairs whose replicas swap states.
        Locally, replicanums and statenums are updated. Each replica's state
        dict is then fetched with 'get_state'. The dicts of the two replicas
        in every accepted pair are swapped -- presumably so that per-state
        settings follow the temperature (TODO confirm). Every dict's
        'inv_temp' is set to the inverse temperature of the replica's new
        state, and the dicts are sent back with 'set_state'.
        """
        for (i, j) in accepted:
            ri = self.replicanums[i]
            rj = self.replicanums[j]
            self.statenums[ri] = j
            self.statenums[rj] = i
            self.replicanums[i], self.replicanums[j] = rj, ri
        # on the grid (suboptimal: every replica is fetched and re-sent)
        newtemps = self.sort_per_replica(self.inv_temps)
        states = self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'get_state'))
        for (i, j) in accepted:
            ri = self.replicanums[i]
            rj = self.replicanums[j]
            states[ri], states[rj] = states[rj], states[ri]
        for temp, state in zip(newtemps, states):
            state['inv_temp'] = temp
        self.grid.gather(
            self.grid.scatter(self.sfo_id, 'set_state', states))

    def write_rex_stats(self):
        """Append the current bookkeeping to the log files.

        rexlog: step number, then the 1-based replica number at each state.
        convectivelog (convective scheme only): step, 1-based stirred
        replica, direction and remaining steps.
        templog (when tuning): step and inverse temperatures. It is written
        only while rn_history holds exactly one snapshot.
        """
        with open(self.logfile, 'a') as fl:
            fl.write('%8d ' % self.stepno)
            fl.write(' '.join(['%2d' % (i + 1) for i in self.replicanums]))
            fl.write('\n')
        if self.scheme == 'convective':
            with open(self.convectivelog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write('%2d ' % (self.stirred['replica'] + 1))
                fl.write('%2d ' % self.stirred['dir'])
                fl.write('%2d\n' % self.stirred['steps'])
        if self.tune_temps and len(self.rn_history) == 1:
            with open(self.templog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write(' '.join(['%.3f' % i for i in self.inv_temps]))
                fl.write('\n')

    def tune_rex(self):
        """Use TuneRex to optimize the temperature set.

        Temperatures are optimized every 'rate' steps with 'method'.
        Replica numbers are accumulated for as long as the temperatures
        were not changed.
        tune_data keys:
            rate       : the rate at which to try tuning temps
            method     : 'flux' or 'ar'
            targetAR   : for ar only, target acceptance rate
            alpha      : Type I error to use
            dumb_scale : for ar only, passed through to TuneRex
        Raises NotImplementedError for an unknown method.
        """
        import TuneRex
        # update replicanum history
        self.rn_history.append(list(self.replicanums))
        td = self.tune_data
        if len(self.rn_history) % td['rate'] != 0:
            return
        # convert inverse temperatures 1/(kB*T) back to temperatures
        temps = [1 / (kB * la) for la in self.inv_temps]
        kwargs = {}
        if 'alpha' in td:
            kwargs['alpha'] = td['alpha']
        if td['method'] == 'ar':
            if 'targetAR' in td:
                kwargs['targetAR'] = td['targetAR']
            if 'dumb_scale' in td:
                kwargs['dumb_scale'] = td['dumb_scale']
            indicators = TuneRex.compute_indicators(
                transpose(self.rn_history))
            changed, newparams = TuneRex.tune_params_ar(
                indicators, temps, **kwargs)
        elif td['method'] == 'flux':
            changed, newparams = TuneRex.tune_params_flux(
                transpose(self.rn_history),
                temps, **kwargs)
        else:
            raise NotImplementedError(
                "unknown tuning method: %s" % td['method'])
        if changed:
            # start collecting statistics afresh for the new temperatures
            self.rn_history = []
            print(newparams)
            self.inv_temps = [1 / (kB * t) for t in newparams]

    def do_bookkeeping_before(self):
        """Advance the step counter and update the convective stirring."""
        self.stepno += 1
        if self.scheme != 'convective':
            return
        st = self.stirred
        if st['steps'] == 0:
            # done stirring this replica: move to the next one in order
            st['pos'] = (st['pos'] + 1) % self.nreps
            st['replica'] = st['order'][st['pos']]
            st['dir'] = 1
            st['steps'] = 2 * (self.nreps - 1)
        state = self.statenums[st['replica']]
        # reverse direction at the ends of the state ladder
        if state == self.nreps - 1:
            st['dir'] = 0
        elif state == 0:
            st['dir'] = 1

    def do_bookkeeping_after(self, accepted):
        """Count an accepted move of the stirred replica (convective only)."""
        if self.scheme == 'convective' and self.stirred['pair'] in accepted:
            self.stirred['steps'] -= 1

    def replica_exchange(self):
        "main entry point for replica-exchange"
        self.do_bookkeeping_before()
        if self.tune_temps:
            self.tune_rex()
        energies = self.sort_per_state(self.get_energies())
        plist = self.gen_pairs_list()
        metrop = self.get_metropolis(plist, energies)
        accepted = self.try_exchanges(plist, metrop)
        self.perform_exchanges(accepted)
        self.do_bookkeeping_after(accepted)