IMP logo
IMP Reference Guide develop.63b38c487d, 2024/12/22
The Integrative Modeling Platform
Replica.py
1 # replica exchange class
2 # inspired by ISD Replica.py
3 # Yannick
4 
5 import numpy as np
6 from numpy.random import random, randint
# Boltzmann constant expressed per mole, in kcal/mol/K:
#   k_B (1.3806503e-23 J/K) * N_A (6.0221415e23 /mol) / (4184 J/kcal)
# The powers of ten cancel, so only the mantissas are multiplied.
kB = (1.3806503 * 6.0221415) / 4184.0
8 
9 
class ReplicaTracker:
    """Bookkeeping and driver for temperature replica exchange.

    Tracks which replica currently sits at which state (index into
    ``inv_temps``), generates candidate exchange pairs, applies the
    Metropolis criterion and pushes accepted swaps to the replicas through
    ``grid``.

    Exchange schemes (``scheme``): ``'standard'`` or ``'convective'``.
    Pair generation (``xchg``): ``'gromacs'`` (alternating even/odd
    neighbor pairs) or ``'random'`` (random non-overlapping neighbor pairs).
    """

    def __init__(self, nreps, inv_temps, grid, sfo_id,
                 rexlog='replicanums.txt', scheme='standard', xchg='random',
                 convectivelog='stirred.txt', tune_temps=False,
                 tune_data=None, templog='temps.txt'):
        """Set up the tracker and append the initial line to ``rexlog``.

        :param nreps: number of replicas (equal to the number of states).
        :param inv_temps: inverse temperatures, one per state.
        :param grid: object exposing ``broadcast``/``scatter``/``gather``,
                     used to talk to the replicas.
        :param sfo_id: identifier passed along to the grid calls.
        :param rexlog: file receiving replica numbers per state.
        :param scheme: ``'standard'`` or ``'convective'``.
        :param xchg: ``'gromacs'`` or ``'random'``.
        :param convectivelog: file receiving the convective stirring status.
        :param tune_temps: whether to tune temperatures with TuneRex.
        :param tune_data: options for :meth:`tune_rex`; a fresh empty dict
                          is used when omitted (never shared between
                          instances).
        :param templog: file receiving inverse temperatures when tuning.
        """
        self.nreps = nreps
        # replica number as a function of state no
        self.replicanums = list(range(nreps))
        # state no as a function of replica no
        self.statenums = list(range(nreps))
        self.grid = grid
        self.sfo_id = sfo_id
        # expect inverse temperatures
        self.inv_temps = inv_temps
        self.logfile = rexlog
        self.stepno = 1
        self.scheme = scheme
        self.xchg = xchg
        self.tune_temps = tune_temps
        self.tune_data = {} if tune_data is None else tune_data
        self.rn_history = []
        self.templog = templog
        # scheme is one of 'standard' or 'convective'
        if scheme == "convective":
            self.stirred = {}
            # order in which the replicas are stirred
            self.stirred['order'] = list(range(self.nreps))
            self.stirred['order'].reverse()
            # replica currently being stirred
            self.stirred['replica'] = self.stirred['order'][0]
            # position of that replica in the 'order' list
            self.stirred['pos'] = 0
            # direction this replica should go to: 1 = up, 0 = down
            if self.stirred['replica'] != self.nreps - 1:
                self.stirred['dir'] = 1
            else:
                self.stirred['dir'] = 0
            # accepted moves remaining until the next replica is stirred
            self.stirred['steps'] = 2 * (self.nreps - 1)
            self.convectivelog = convectivelog
        self.write_rex_stats()

    def sort_per_state(self, inplist):
        "sorts a replica list per state"
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.replicanums]

    def sort_per_replica(self, inplist):
        "sorts a state list per replica number"
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.statenums]

    def get_energies(self):
        "return replica-sorted energies"
        return self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'm', 'evaluate', False))

    def gen_pairs_list_gromacs(self, direction):
        """Generate the ordered list of exchange pairs for a direction.

        direction == 0 : (0,1)(2,3)...
        direction == 1 : 0 (1,2)(3,4)...
        Only complete pairs are returned; a state left without a partner at
        either end is omitted. Fewer than two states yield an empty list.

        :raises ValueError: if ``direction`` is neither 0 nor 1.
        """
        nreps = self.nreps
        if direction not in (0, 1):
            raise ValueError(direction)
        if nreps == 2:
            return [(0, 1)]
        ret = [(2 * i + direction, 2 * i + 1 + direction)
               for i in range(nreps // 2)]
        # with direction == 1 and an even nreps the last pair references the
        # non-existent state nreps; drop it (guard against an empty list)
        if ret and nreps in ret[-1]:
            ret.pop()
        return ret

    def gen_pairs_list_rand(self, needed=()):
        """Generate a random set of non-overlapping neighboring state pairs.

        :param needed: pairs ``(i, i + 1)`` that must appear in the result;
                       they are assumed not to overlap each other.
        :returns: sorted list of pairs ``(i, i + 1)``.
        :raises ValueError: if a ``needed`` entry is not of the form
                            ``(i, i + 1)`` with both states in range.
        """
        nreps = self.nreps
        pairslist = []
        # all neighboring pairs
        init = [(i, i + 1) for i in range(nreps - 1)]
        # close the loop with (0, nreps - 1): it competes for selection but
        # is never returned. With two states it would duplicate (0, 1).
        if nreps > 2:
            init.append((0, nreps - 1))
        # add needed pairs and drop every candidate overlapping them,
        # including the loop pair when a needed pair touches an end state
        for (i, j) in needed:
            if j - i != 1:
                raise ValueError("wrong format for 'needed' list")
            if not 0 <= i < nreps - 1:
                raise ValueError("pair %s out of range in 'needed' list"
                                 % ((i, j),))
            pairslist.append((i, j))
            init = [(r, q) for (r, q) in init
                    if r not in (i, j) and q not in (i, j)]
        while init:
            # choose a random candidate (numpy randint excludes the upper bound)
            pair = init.pop(randint(0, len(init)))
            # remove candidates overlapping the chosen pair
            init = [(r, q) for (r, q) in init
                    if r not in pair and q not in pair]
            # the loop-closing pair is not a real neighbor pair
            if pair[1] - pair[0] == 1:
                pairslist.append(pair)
        pairslist.sort()
        return pairslist

    def gen_pairs_list_conv(self):
        """Generate exchange pairs for the convective scheme.

        The pair moving the stirred replica one state in its current
        direction is stored in ``self.stirred['pair']`` and is guaranteed to
        be part of the returned list.
        """
        rep = self.stirred['replica']
        state = self.statenums[rep]
        pair = sorted([state, state + 2 * self.stirred['dir'] - 1])
        self.stirred['pair'] = tuple(pair)
        if self.xchg == 'gromacs':
            # pick the parity whose pairs start at the stirred pair's first
            # state
            direction = (state + 1 + self.stirred['dir']) % 2
            return self.gen_pairs_list_gromacs(direction)
        elif self.xchg == 'random':
            return self.gen_pairs_list_rand(needed=[self.stirred['pair']])
        else:
            raise NotImplementedError(
                "Unknown exchange method: %s" %
                self.xchg)

    def gen_pairs_list(self):
        """Dispatch pair generation according to ``scheme`` and ``xchg``."""
        if self.scheme == 'standard':
            if self.xchg == 'gromacs':
                return self.gen_pairs_list_gromacs(self.stepno % 2)
            elif self.xchg == 'random':
                return self.gen_pairs_list_rand()
            else:
                raise NotImplementedError(
                    "unknown exchange method: %s" %
                    self.xchg)
        elif self.scheme == 'convective':
            return self.gen_pairs_list_conv()
        else:
            raise NotImplementedError(
                "unknown exchange scheme: %s" %
                self.scheme)

    def get_cross_energies(self, pairslist):
        "get energies assuming all exchanges have succeeded"
        print("this is not needed for temperature replica-exchange")
        raise NotImplementedError

    def get_metropolis(self, pairslist, old_ene):
        """compute metropolis criterion for temperature replica exchange
        e.g. min(1, exp(Delta beta Delta E))
        input: list of pairs, list of state-sorted energies
        returns: dict mapping each pair to its acceptance probability
        """
        metrop = {}
        for (s1, s2) in pairslist:
            delta = ((old_ene[s2] - old_ene[s1])
                     * (self.inv_temps[s2] - self.inv_temps[s1]))
            # exp(delta) >= 1 whenever delta >= 0; skipping exp there avoids
            # overflow warnings for large positive exponents
            metrop[(s1, s2)] = np.exp(delta) if delta < 0 else 1
        return metrop

    def try_exchanges(self, plist, metrop):
        """Return the pairs of ``plist`` whose exchange is accepted.

        A random number is drawn only when the acceptance probability is
        below 1.
        """
        accepted = []
        for couple in plist:
            if (metrop[couple] >= 1) or (random() < metrop[couple]):
                accepted.append(couple)
        return accepted

    def perform_exchanges(self, accepted):
        "exchange given state couples both in local variables and on the grid"
        # locally: swap which replica occupies states i and j
        for (i, j) in accepted:
            ri, rj = self.replicanums[i], self.replicanums[j]
            self.statenums[ri] = j
            self.statenums[rj] = i
            self.replicanums[i], self.replicanums[j] = rj, ri
        # on the grid (suboptimal but who cares): swap the replica states of
        # each accepted couple, then assign each replica its new inv_temp
        newtemps = self.sort_per_replica(self.inv_temps)
        states = self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'get_state'))
        for (i, j) in accepted:
            ri, rj = self.replicanums[i], self.replicanums[j]
            states[ri], states[rj] = states[rj], states[ri]
        for temp, state in zip(newtemps, states):
            state['inv_temp'] = temp
        self.grid.gather(
            self.grid.scatter(self.sfo_id, 'set_state', states))

    def write_rex_stats(self):
        "write replica numbers as a function of state"
        # 'with' guarantees the files are closed even if a write fails
        with open(self.logfile, 'a') as fl:
            fl.write('%8d ' % self.stepno)
            fl.write(' '.join(['%2d' % (i + 1) for i in self.replicanums]))
            fl.write('\n')
        if self.scheme == 'convective':
            with open(self.convectivelog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write('%2d ' % (self.stirred['replica'] + 1))
                fl.write('%2d ' % self.stirred['dir'])
                fl.write('%2d\n' % self.stirred['steps'])
        if self.tune_temps and len(self.rn_history) == 1:
            with open(self.templog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write(' '.join(['%.3f' % i for i in self.inv_temps]))
                fl.write('\n')

    def tune_rex(self):
        """use TuneRex to optimize temp set. Temps are optimized every
        'rate' steps and 'method' is used. Data is accumulated as long as
        the temperatures weren't optimized.
        td keys that should be passed to init:
        rate : the rate at which to try tuning temps
        method : flux or ar
        targetAR : for ar only, target acceptance rate
        alpha : Type I error to use.
        dumb_scale : for ar only, forwarded to TuneRex.tune_params_ar

        :raises NotImplementedError: if 'method' is neither 'ar' nor 'flux'.
        """
        import TuneRex
        # record the current replica numbers (state-sorted)
        self.rn_history.append(list(self.replicanums))
        td = self.tune_data
        # rn_history is never empty here, so only the rate matters
        if len(self.rn_history) % td['rate'] != 0:
            return
        method = td['method']
        temps = [1 / (kB * la) for la in self.inv_temps]
        kwargs = {}
        if method == 'ar':
            for key in ('targetAR', 'alpha', 'dumb_scale'):
                if key in td:
                    kwargs[key] = td[key]
            indicators = TuneRex.compute_indicators(
                np.transpose(self.rn_history))
            changed, newparams = TuneRex.tune_params_ar(
                indicators, temps, **kwargs)
        elif method == 'flux':
            if 'alpha' in td:
                kwargs['alpha'] = td['alpha']
            changed, newparams = TuneRex.tune_params_flux(
                np.transpose(self.rn_history),
                temps, **kwargs)
        else:
            raise NotImplementedError(
                "unknown tuning method: %s" % method)
        if changed:
            # restart data accumulation with the new temperature set
            self.rn_history = []
            print(newparams)
            self.inv_temps = [1 / (kB * t) for t in newparams]

    def do_bookkeeping_before(self):
        """Advance the step counter; for the convective scheme, switch to the
        next stirred replica when its steps are exhausted and reverse its
        direction at the end states."""
        self.stepno += 1
        if self.scheme == 'convective':
            st = self.stirred
            # check if we are done stirring this replica
            if st['steps'] == 0:
                st['pos'] = (st['pos'] + 1) % self.nreps
                st['replica'] = st['order'][st['pos']]
                st['dir'] = 1
                st['steps'] = 2 * (self.nreps - 1)
            rep = st['replica']
            state = self.statenums[rep]
            # update endpoints
            if state == self.nreps - 1:
                st['dir'] = 0
            elif state == 0:
                st['dir'] = 1

    def do_bookkeeping_after(self, accepted):
        """For the convective scheme, count a step of the stirred replica
        when its requested exchange was accepted."""
        if self.scheme == 'convective':
            if self.stirred['pair'] in accepted:
                self.stirred['steps'] -= 1

    def replica_exchange(self):
        "main entry point for replica-exchange"
        self.do_bookkeeping_before()
        # tune temperatures
        if self.tune_temps:
            self.tune_rex()
        energies = self.sort_per_state(self.get_energies())
        plist = self.gen_pairs_list()
        metrop = self.get_metropolis(plist, energies)
        accepted = self.try_exchanges(plist, metrop)
        self.perform_exchanges(accepted)
        self.do_bookkeeping_after(accepted)
def __init__
Input: a list of particles, the slope, and theta of the sigmoid potential; theta is the cutoff distance ...