Loading [MathJax]/extensions/tex2jax.js
IMP logo
IMP Reference Guide  2.18.0
The Integrative Modeling Platform
Replica.py
1 # replica exchange class
2 # inspired by ISD Replica.py
3 # Yannick
4 
5 from __future__ import print_function, division
6 from numpy import *
7 from numpy.random import random, randint
8 from numpy.random import shuffle
# Boltzmann constant in kcal/mol/K:
#   k_B [1e-23 J/K] * N_A [1e23 /mol] = R [J/mol/K], then / 4184 J/kcal
kB = (1.3806503 * 6.0221415) / 4184.0
10 
11 
class ReplicaTracker:
    """Bookkeeping and driver for temperature replica exchange.

    Two mutually inverse mappings are maintained:

    * ``replicanums[state]`` -> replica currently in that state
    * ``statenums[replica]`` -> state currently held by that replica

    Each call to :meth:`replica_exchange` proposes exchange pairs of
    neighbouring states, applies the Metropolis criterion and pushes the
    result to the replicas through ``grid``.
    """

    def __init__(self, nreps, inv_temps, grid, sfo_id,
                 rexlog='replicanums.txt', scheme='standard', xchg='random',
                 convectivelog='stirred.txt', tune_temps=False,
                 tune_data=None, templog='temps.txt'):
        """Set up the tracker and append the initial line of the rex log.

        :param nreps: number of replicas (equal to the number of states).
        :param inv_temps: inverse temperatures, one per state.
        :param grid: object offering ``broadcast``/``scatter``/``gather``
            (presumably an IMP grid -- not shown here).
        :param sfo_id: identifier of the remote objects, passed to the grid.
        :param rexlog: file to which replica numbers are appended.
        :param scheme: ``'standard'`` or ``'convective'``.
        :param xchg: pair generation method, ``'gromacs'`` or ``'random'``.
        :param convectivelog: stirring log (convective scheme only).
        :param tune_temps: whether to tune temperatures with TuneRex.
        :param tune_data: options for :meth:`tune_rex`; defaults to ``{}``.
        :param templog: file to which tuned inverse temperatures are
            appended.
        """
        self.nreps = nreps
        # replica number as a function of state no
        self.replicanums = list(range(nreps))
        # state no as a function of replica no
        self.statenums = list(range(nreps))
        self.grid = grid
        self.sfo_id = sfo_id
        # expect inverse temperatures
        self.inv_temps = inv_temps
        self.logfile = rexlog
        self.stepno = 1
        self.scheme = scheme
        self.xchg = xchg
        self.tune_temps = tune_temps
        # None sentinel: a literal {} default would be shared by instances
        self.tune_data = {} if tune_data is None else tune_data
        self.rn_history = []
        self.templog = templog
        # scheme is 'standard' or 'convective'; xchg ('gromacs' or
        # 'random') selects how exchange pairs are generated
        if scheme == "convective":
            self.stirred = {}
            # which order the replicas should be chosen
            self.stirred['order'] = list(range(self.nreps))
            self.stirred['order'].reverse()
            # which replica is currently being stirred
            self.stirred['replica'] = self.stirred['order'][0]
            # which position in the order list this replica sits in
            self.stirred['pos'] = 0
            # direction this replica should go: 1 = up, 0 = down
            if self.stirred['replica'] != self.nreps - 1:
                self.stirred['dir'] = 1
            else:
                self.stirred['dir'] = 0
            # how many steps remain until we change the stirred replica
            self.stirred['steps'] = 2 * (self.nreps - 1)
            self.convectivelog = convectivelog
        self.write_rex_stats()

    def sort_per_state(self, inplist):
        "sorts a replica list per state"
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.replicanums]

    def sort_per_replica(self, inplist):
        "sorts a state list per replica number"
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.statenums]

    def get_energies(self):
        "return replica-sorted energies"
        return self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'm', 'evaluate', False))

    def gen_pairs_list_gromacs(self, direction):
        """Generate the ordered list of exchange pairs for ``direction``.

        direction == 0 : (0,1)(2,3)...
        direction == 1 : 0(1,2)(3,4)...

        Pairs reaching past the last state are dropped. With two replicas
        the single pair (0,1) is always returned; with fewer than two there
        is nothing to exchange and an empty list is returned.

        :raises ValueError: if direction is neither 0 nor 1.
        """
        nreps = self.nreps
        if direction != 0 and direction != 1:
            raise ValueError(direction)
        if nreps < 2:
            return []
        if nreps == 2:
            return [(0, 1)]
        ret = [(2 * i + direction, 2 * i + 1 + direction)
               for i in range(nreps // 2)]
        # direction 1 with an even nreps yields a final (nreps-1, nreps)
        if nreps in ret[-1]:
            ret.pop()
        return ret

    def gen_pairs_list_rand(self, needed=()):
        """Generate a random set of non-overlapping neighbouring state pairs.

        :param needed: pairs (i, i+1) that must be included; any candidate
            sharing a state with one of them is discarded.
        :raises ValueError: if a needed pair is not of the form (i, i+1).
        :returns: sorted list of (i, i+1) tuples.
        """
        nreps = self.nreps
        pairslist = []
        # generate all possible neighbouring pairs
        init = [(i, i + 1) for i in range(nreps - 1)]
        # The wrap-around pair (0, nreps-1) closes the loop so that the end
        # states compete like any other; it is a placeholder and never
        # exchanged. With two replicas it would coincide with the real pair
        # (0, 1) and suppress every exchange, so it is only used for > 2.
        loop = (0, nreps - 1) if nreps > 2 else None
        if loop is not None:
            init.append(loop)
        # add needed pairs and remove every overlapping candidate; filtering
        # (rather than list.remove) also works for pairs at either end
        for (i, j) in needed:
            if j - i != 1:
                raise ValueError("wrong format for 'needed' list")
            pairslist.append((i, j))
            init = [(r, q) for (r, q) in init
                    if r not in (i, j) and q not in (i, j)]
        while init:
            # numpy randint samples the half-open interval [low, high)
            pair = init.pop(randint(0, len(init)))
            # remove candidates overlapping the chosen pair
            init = [(r, q) for (r, q) in init
                    if r not in pair and q not in pair]
            if pair != loop:
                pairslist.append(pair)
        pairslist.sort()
        return pairslist

    def gen_pairs_list_conv(self):
        """Generate exchange pairs for the convective scheme.

        The pair formed by the stirred replica's state and its neighbour in
        the current stirring direction is stored in ``stirred['pair']`` and
        is arranged to be part of the returned list.
        """
        rep = self.stirred['replica']
        state = self.statenums[rep]
        pair = sorted([state, state + 2 * self.stirred['dir'] - 1])
        self.stirred['pair'] = tuple(pair)
        if self.xchg == 'gromacs':
            # choose the parity whose pairs contain the stirred pair
            direction = (state + 1 + self.stirred['dir']) % 2
            return self.gen_pairs_list_gromacs(direction)
        elif self.xchg == 'random':
            return self.gen_pairs_list_rand(needed=[self.stirred['pair']])
        else:
            raise NotImplementedError(
                "Unknown exchange method: %s" %
                self.xchg)

    def gen_pairs_list(self):
        """Dispatch pair generation on ``scheme`` and ``xchg``.

        :raises NotImplementedError: for an unknown scheme or method.
        """
        if self.scheme == 'standard':
            if self.xchg == 'gromacs':
                return self.gen_pairs_list_gromacs(self.stepno % 2)
            elif self.xchg == 'random':
                return self.gen_pairs_list_rand()
            else:
                raise NotImplementedError(
                    "unknown exchange method: %s" %
                    self.xchg)
        elif self.scheme == 'convective':
            return self.gen_pairs_list_conv()
        else:
            raise NotImplementedError(
                "unknown exchange scheme: %s" %
                self.scheme)

    def get_cross_energies(self, pairslist):
        "get energies assuming all exchanges have succeeded"
        print("this is not needed for temperature replica-exchange")
        raise NotImplementedError

    def get_metropolis(self, pairslist, old_ene):
        """Compute the Metropolis criterion for temperature replica exchange.

        p = min(1, exp(Delta beta * Delta E))

        :param pairslist: list of (s1, s2) state pairs.
        :param old_ene: state-sorted energies.
        :returns: dict mapping each pair to its acceptance probability.
        """
        metrop = {}
        for (s1, s2) in pairslist:
            arg = ((old_ene[s2] - old_ene[s1]) *
                   (self.inv_temps[s2] - self.inv_temps[s1]))
            # min(1, exp(x)) == exp(min(0, x)), without overflowing exp
            metrop[(s1, s2)] = exp(min(0, arg))
        return metrop

    def try_exchanges(self, plist, metrop):
        """Return the pairs of ``plist`` accepted by a Metropolis draw."""
        accepted = []
        for couple in plist:
            if (metrop[couple] >= 1) or (random() < metrop[couple]):
                accepted.append(couple)
        return accepted

    def perform_exchanges(self, accepted):
        """Exchange the given state couples locally and on the grid.

        The local mappings are updated first. Then every replica's state is
        fetched, the states of the replicas of each accepted couple are
        swapped, each state's ``inv_temp`` is set to its replica's new
        inverse temperature, and the states are sent back.
        """
        # locally
        for (i, j) in accepted:
            ri = self.replicanums[i]
            rj = self.replicanums[j]
            self.statenums[ri] = j
            self.statenums[rj] = i
            self.replicanums[i], self.replicanums[j] = rj, ri
        # on the grid (suboptimal but who cares)
        newtemps = self.sort_per_replica(self.inv_temps)
        # assumes gather returns one state dict per replica, in replica
        # order -- TODO confirm against the grid implementation
        states = self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'get_state'))
        for (i, j) in accepted:
            ri = self.replicanums[i]
            rj = self.replicanums[j]
            states[ri], states[rj] = states[rj], states[ri]
        for temp, state in zip(newtemps, states):
            state['inv_temp'] = temp
        self.grid.gather(
            self.grid.scatter(self.sfo_id, 'set_state', states))

    def write_rex_stats(self):
        """Append replica numbers (as a function of state) to the rex log.

        For the convective scheme the stirring status is also logged, and
        when tuning with exactly one history entry the inverse temperatures
        are appended to ``templog``. Files are closed even if a write fails.
        """
        with open(self.logfile, 'a') as fl:
            fl.write('%8d ' % self.stepno)
            fl.write(' '.join(['%2d' % (i + 1) for i in self.replicanums]))
            fl.write('\n')
        if self.scheme == 'convective':
            with open(self.convectivelog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write('%2d ' % (self.stirred['replica'] + 1))
                fl.write('%2d ' % self.stirred['dir'])
                fl.write('%2d\n' % self.stirred['steps'])
        if self.tune_temps and len(self.rn_history) == 1:
            with open(self.templog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write(' '.join(['%.3f' % i for i in self.inv_temps]))
                fl.write('\n')

    def tune_rex(self):
        """Use TuneRex to optimize the temperature set.

        Tuning is attempted whenever the replica-number history length is a
        multiple of ``rate``. The history accumulates until the temperatures
        are actually changed, at which point it is reset.

        ``tune_data`` keys read here:

        rate : how often to try tuning temps
        method : 'flux' or 'ar'
        targetAR : for 'ar' only, target acceptance rate
        alpha : Type I error to use
        dumb_scale : for 'ar' only, forwarded to TuneRex.tune_params_ar

        :raises NotImplementedError: if ``method`` is neither 'ar' nor 'flux'.
        """
        import TuneRex
        # update replicanum history
        self.rn_history.append(list(self.replicanums))
        td = self.tune_data
        nhist = len(self.rn_history)
        if nhist % td['rate'] == 0 and nhist > 0:
            temps = [1 / (kB * la) for la in self.inv_temps]
            kwargs = {}
            if td['method'] == 'ar':
                for key in ('targetAR', 'alpha', 'dumb_scale'):
                    if key in td:
                        kwargs[key] = td[key]
                indicators = TuneRex.compute_indicators(
                    transpose(self.rn_history))
                changed, newparams = TuneRex.tune_params_ar(
                    indicators, temps, **kwargs)
            elif td['method'] == 'flux':
                if 'alpha' in td:
                    kwargs['alpha'] = td['alpha']
                changed, newparams = TuneRex.tune_params_flux(
                    transpose(self.rn_history),
                    temps, **kwargs)
            else:
                # previously surfaced as an UnboundLocalError on 'changed'
                raise NotImplementedError(
                    "unknown tuning method: %s" % td['method'])
            if changed:
                self.rn_history = []
                print(newparams)
                self.inv_temps = [1 / (kB * t) for t in newparams]

    def do_bookkeeping_before(self):
        """Advance the step counter and update the convective stirring state.

        When the stirred replica has used up its steps, the next replica in
        ``order`` becomes the stirred one. The direction is then forced to
        point inwards if the stirred replica sits at either end state.
        """
        self.stepno += 1
        if self.scheme == 'convective':
            st = self.stirred
            # check if we are done stirring this replica
            if st['steps'] == 0:
                st['pos'] = (st['pos'] + 1) % self.nreps
                st['replica'] = st['order'][st['pos']]
                st['dir'] = 1
                st['steps'] = 2 * (self.nreps - 1)
            state = self.statenums[st['replica']]
            # update endpoints
            if state == self.nreps - 1:
                st['dir'] = 0
            elif state == 0:
                st['dir'] = 1

    def do_bookkeeping_after(self, accepted):
        """Count down stirring steps if the stirred pair was exchanged."""
        if self.scheme == 'convective':
            if self.stirred['pair'] in accepted:
                self.stirred['steps'] -= 1

    def replica_exchange(self):
        "main entry point for replica-exchange"
        self.do_bookkeeping_before()
        # tune temperatures
        if self.tune_temps:
            self.tune_rex()
        energies = self.sort_per_state(self.get_energies())
        plist = self.gen_pairs_list()
        metrop = self.get_metropolis(plist, energies)
        accepted = self.try_exchanges(plist, metrop)
        self.perform_exchanges(accepted)
        self.do_bookkeeping_after(accepted)
def __init__
Input: a list of particles, plus the slope and theta of the sigmoid potential; theta is the cutoff distance ...