IMP  2.0.1
The Integrative Modeling Platform
Replica.py
1 #replica exchange class
2 #inspired by ISD Replica.py
3 #Yannick
4 
5 from numpy import *
6 from numpy.random import random, randint
7 from numpy.random import shuffle
# Boltzmann constant in molar units (kcal/mol/K): k_B [x1e-23 J/K] times
# N_A [x1e23 /mol] yields J/mol/K (powers of ten cancel); 4184 J = 1 kcal.
kB = (1.3806503 * 6.0221415) / 4184.0
9 
class ReplicaTracker:
    """Bookkeeping and exchange logic for temperature replica exchange.

    Replicas live on ``grid``, which is driven through the simulation
    function object ``sfo_id`` via ``grid.broadcast``, ``grid.scatter`` and
    ``grid.gather``. This class tracks which replica sits at which state
    (inverse temperature), proposes exchange pairs, applies the Metropolis
    criterion and pushes accepted exchanges back to the grid.

    Two mappings are kept as inverses of each other:
      * ``replicanums[state]``  -> replica currently at that state
      * ``statenums[replica]``  -> state currently held by that replica
    """

    def __init__(self, nreps, inv_temps, grid, sfo_id,
                 rexlog='replicanums.txt', scheme='standard', xchg='random',
                 convectivelog='stirred.txt', tune_temps=False,
                 tune_data=None, templog='temps.txt'):
        """Set up the identity replica<->state mapping and write the first log line.

        nreps: number of replicas (and states).
        inv_temps: inverse temperatures, indexed by state.
        grid, sfo_id: grid object and simulation function object id used to
            talk to the replicas.
        rexlog: file to which replica numbers per state are appended.
        scheme: 'standard' or 'convective'.
        xchg: pair generation method, 'random' or 'gromacs'.
        convectivelog: stirring log (convective scheme only).
        tune_temps: if true, call tune_rex at every exchange.
        tune_data: settings for tune_rex (see its docstring); defaults to a
            fresh empty dict.
        templog: file to which inverse temperatures are appended after tuning.
        """
        self.nreps = nreps
        # replica number as a function of state number; lists because they
        # are mutated in place by perform_exchanges
        self.replicanums = list(range(nreps))
        # state number as a function of replica number
        self.statenums = list(range(nreps))
        self.grid = grid
        self.sfo_id = sfo_id
        # expects inverse temperatures
        self.inv_temps = inv_temps
        self.logfile = rexlog
        self.stepno = 1
        self.scheme = scheme
        self.xchg = xchg
        self.tune_temps = tune_temps
        # fresh dict per instance (no shared mutable default argument)
        self.tune_data = {} if tune_data is None else tune_data
        self.rn_history = []
        self.templog = templog
        # scheme is 'standard' or 'convective'; xchg is 'random' or 'gromacs'
        if scheme == "convective":
            self.stirred = {}
            # order in which the replicas are chosen for stirring
            self.stirred['order'] = list(range(self.nreps - 1, -1, -1))
            # replica currently being stirred
            self.stirred['replica'] = self.stirred['order'][0]
            # position of that replica in the order list
            self.stirred['pos'] = 0
            # direction the stirred replica should move: 1 up, 0 down
            if self.stirred['replica'] != self.nreps - 1:
                self.stirred['dir'] = 1
            else:
                self.stirred['dir'] = 0
            # successful moves remaining before switching stirred replica
            self.stirred['steps'] = 2 * (self.nreps - 1)
            self.convectivelog = convectivelog
        self.write_rex_stats()

    def sort_per_state(self, inplist):
        """Reorder a replica-indexed list so that it is indexed by state.

        Raises ValueError if inplist does not have nreps entries.
        """
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.replicanums]

    def sort_per_replica(self, inplist):
        """Reorder a state-indexed list so that it is indexed by replica.

        Raises ValueError if inplist does not have nreps entries.
        """
        if len(inplist) != self.nreps:
            raise ValueError('list has wrong size')
        return [inplist[i] for i in self.statenums]

    def get_energies(self):
        """Return replica-sorted energies gathered from the grid.

        NOTE(review): broadcasts method 'm' with args ('evaluate', False) to
        the sfo; the exact meaning is defined by the sfo implementation.
        """
        return self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'm', 'evaluate', False))

    def gen_pairs_list_gromacs(self, direction):
        """Generate the ordered list of exchange pairs for one direction.

        direction == 0 : (0,1)(2,3)...
        direction == 1 : (1,2)(3,4)...
        A pair reaching past the last state is dropped. With two replicas
        both directions give [(0,1)]; with fewer than two the list is empty.
        Returns only pairs. Raises ValueError if direction is not 0 or 1.
        """
        nreps = self.nreps
        if direction != 0 and direction != 1:
            raise ValueError(direction)
        if nreps == 2:
            return [(0, 1)]
        ret = [(2 * i + direction, 2 * i + 1 + direction)
               for i in range(nreps // 2)]
        # e.g. direction 1 with even nreps ends on (nreps-1, nreps): invalid
        if ret and nreps in ret[-1]:
            ret.pop()
        return ret

    def gen_pairs_list_rand(self, needed=None):
        """Generate a random sorted list of disjoint neighbouring state pairs.

        Candidates are all (i, i+1). When nreps > 2, the wrap-around pair
        (0, nreps-1) is also a candidate. It closes the chain into a ring so
        the end states compete like inner ones, but it is never returned.

        needed: optional list of (i, i+1) pairs that must be included; every
            candidate sharing a state with one of them is discarded first.
        Raises ValueError if a needed pair is not of the form (i, i+1).
        """
        nreps = self.nreps
        if needed is None:
            needed = []
        pairslist = []
        # all possible neighbouring pairs
        init = [(i, i + 1) for i in range(nreps - 1)]
        # with two replicas the wrap-around pair would duplicate (0, 1)
        loop = (0, nreps - 1)
        use_loop = nreps > 2
        if use_loop:
            init.append(loop)
        for (i, j) in needed:
            if j - i != 1:
                raise ValueError("wrong format for 'needed' list")
            pairslist.append((i, j))
            # remove every overlapping candidate, including the wrap-around
            # pair when (i, j) touches an end state
            init = [p for p in init if i not in p and j not in p]
        while init:
            # numpy randint draws from [low, high)
            pair = init.pop(randint(0, len(init)))
            # remove candidates overlapping the chosen pair
            init = [(r, q) for (r, q) in init
                    if r not in pair and q not in pair]
            if not (use_loop and pair == loop):
                pairslist.append(pair)
        pairslist.sort()
        return pairslist

    def gen_pairs_list_conv(self):
        """Generate exchange pairs for the convective scheme.

        The stirred replica's next pair is stored in stirred['pair']. If
        stirred['dir'] == 1 it is (state, state+1); if 0 it is
        (state-1, state). The returned list is built to contain that pair.
        Raises NotImplementedError for an unknown exchange method.
        """
        state = self.statenums[self.stirred['replica']]
        pair = sorted([state, state + 2 * self.stirred['dir'] - 1])
        self.stirred['pair'] = tuple(pair)
        if self.xchg == 'gromacs':
            # choose the gromacs direction whose pairs include stirred['pair']
            direction = (state + 1 + self.stirred['dir']) % 2
            return self.gen_pairs_list_gromacs(direction)
        elif self.xchg == 'random':
            return self.gen_pairs_list_rand(needed=[self.stirred['pair']])
        else:
            raise NotImplementedError(
                "Unknown exchange method: %s" % self.xchg)

    def gen_pairs_list(self):
        """Return the exchange pairs for this step, per scheme and xchg.

        Raises NotImplementedError for an unknown scheme or exchange method.
        """
        if self.scheme == 'standard':
            if self.xchg == 'gromacs':
                return self.gen_pairs_list_gromacs(self.stepno % 2)
            elif self.xchg == 'random':
                return self.gen_pairs_list_rand()
            else:
                raise NotImplementedError(
                    "unknown exchange method: %s" % self.xchg)
        elif self.scheme == 'convective':
            return self.gen_pairs_list_conv()
        else:
            raise NotImplementedError(
                "unknown exchange scheme: %s" % self.scheme)

    def get_cross_energies(self, pairslist):
        "get energies assuming all exchanges have succeeded"
        print("this is not needed for temperature replica-exchange")
        raise NotImplementedError

    def get_metropolis(self, pairslist, old_ene):
        """Compute the Metropolis criterion for temperature replica exchange.

        For each (s1, s2) in pairslist the acceptance probability is
        min(1, exp((E_s2 - E_s1) * (beta_s2 - beta_s1))).
        input: list of state pairs, list of state-sorted energies
        returns: dict mapping each pair to its acceptance probability
        """
        metrop = {}
        for (s1, s2) in pairslist:
            expo = ((old_ene[s2] - old_ene[s1]) *
                    (self.inv_temps[s2] - self.inv_temps[s1]))
            # cap the exponent at 0: same result, but exp() cannot overflow
            metrop[(s1, s2)] = min(1, exp(min(0, expo)))
        return metrop

    def try_exchanges(self, plist, metrop):
        """Return the pairs of plist that pass the Metropolis test.

        A pair is accepted outright when metrop[pair] >= 1. Otherwise it is
        accepted with probability metrop[pair], using one uniform draw.
        """
        return [couple for couple in plist
                if metrop[couple] >= 1 or random() < metrop[couple]]

    def perform_exchanges(self, accepted):
        """Exchange the given state couples locally and on the grid.

        accepted: list of disjoint (state_i, state_j) pairs.
        Locally, replicanums and statenums are swapped. On the grid, the
        per-replica objects returned by 'get_state' are swapped between the
        two replicas of each couple. Then each object's 'inv_temp' is set to
        its replica's new inverse temperature, and all objects are sent back
        via 'set_state'. NOTE(review): what else a state object holds is
        defined by the sfo.
        """
        # locally
        for (i, j) in accepted:
            ri = self.replicanums[i]
            rj = self.replicanums[j]
            self.statenums[ri] = j
            self.statenums[rj] = i
            self.replicanums[i], self.replicanums[j] = rj, ri
        # on the grid: one round trip for all replicas, even unaffected ones
        newtemps = self.sort_per_replica(self.inv_temps)
        states = self.grid.gather(
            self.grid.broadcast(self.sfo_id, 'get_state'))
        for (i, j) in accepted:
            ri = self.replicanums[i]
            rj = self.replicanums[j]
            states[ri], states[rj] = states[rj], states[ri]
        for temp, state in zip(newtemps, states):
            state['inv_temp'] = temp
        self.grid.gather(
            self.grid.scatter(self.sfo_id, 'set_state', states))

    def write_rex_stats(self):
        """Append the current step's statistics to the log files.

        Always appends the 1-based replica number per state to logfile.
        For the convective scheme it also appends the stirring status to
        convectivelog. When tuning is on and rn_history holds exactly one
        entry (just started or reset by tune_rex), it also appends the
        inverse temperatures to templog.
        """
        with open(self.logfile, 'a') as fl:
            fl.write('%8d ' % self.stepno)
            fl.write(' '.join(['%2d' % (i + 1) for i in self.replicanums]))
            fl.write('\n')
        if self.scheme == 'convective':
            with open(self.convectivelog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write('%2d ' % (self.stirred['replica'] + 1))
                fl.write('%2d ' % self.stirred['dir'])
                fl.write('%2d\n' % self.stirred['steps'])
        if self.tune_temps and len(self.rn_history) == 1:
            with open(self.templog, 'a') as fl:
                fl.write('%5d ' % self.stepno)
                fl.write(' '.join(['%.3f' % i for i in self.inv_temps]))
                fl.write('\n')

    def tune_rex(self):
        """Use TuneRex to optimize the temperature set.

        The current replica assignment is recorded on every call. Every
        'rate' recorded steps, tuning is attempted with 'method'. Data keeps
        accumulating until the temperatures are actually changed, at which
        point the history is reset.

        tune_data keys:
            rate : the rate at which to try tuning temps
            method : 'flux' or 'ar'
            targetAR : for ar only, target acceptance rate
            alpha : Type I error to use
            dumb_scale : for ar only, forwarded to TuneRex.tune_params_ar
        Raises NotImplementedError for an unknown method.
        """
        import TuneRex
        # record the current replica-per-state assignment
        self.rn_history.append(list(self.replicanums))
        td = self.tune_data
        # the history is never empty here, since we just appended to it
        if len(self.rn_history) % td['rate'] != 0:
            return
        temps = [1 / (kB * la) for la in self.inv_temps]
        kwargs = {}
        if td['method'] == 'ar':
            for key in ('targetAR', 'alpha', 'dumb_scale'):
                if key in td:
                    kwargs[key] = td[key]
            indicators = TuneRex.compute_indicators(
                transpose(self.rn_history))
            changed, newparams = TuneRex.tune_params_ar(
                indicators, temps, **kwargs)
        elif td['method'] == 'flux':
            if 'alpha' in td:
                kwargs['alpha'] = td['alpha']
            changed, newparams = TuneRex.tune_params_flux(
                transpose(self.rn_history), temps, **kwargs)
        else:
            raise NotImplementedError(
                "unknown tuning method: %s" % td['method'])
        if changed:
            self.rn_history = []
            print(newparams)
            self.inv_temps = [1 / (kB * t) for t in newparams]

    def do_bookkeeping_before(self):
        """Advance the step counter and update the convective stirring state.

        When the stirred replica has used up its moves, the next replica in
        'order' is stirred. The direction flips at the end states.
        """
        self.stepno += 1
        if self.scheme != 'convective':
            return
        st = self.stirred
        # done stirring this replica: move on to the next one
        if st['steps'] == 0:
            st['pos'] = (st['pos'] + 1) % self.nreps
            st['replica'] = st['order'][st['pos']]
            st['dir'] = 1
            st['steps'] = 2 * (self.nreps - 1)
        state = self.statenums[st['replica']]
        # update endpoints
        if state == self.nreps - 1:
            st['dir'] = 0
        elif state == 0:
            st['dir'] = 1

    def do_bookkeeping_after(self, accepted):
        """Count a successful move of the stirred replica (convective only)."""
        if self.scheme == 'convective' and self.stirred['pair'] in accepted:
            self.stirred['steps'] -= 1

    def replica_exchange(self):
        """Main entry point for replica exchange: perform one exchange step.

        Steps: update bookkeeping, optionally tune temperatures, gather
        energies, propose pairs, test them, and apply the accepted ones.
        NOTE(review): log files are not written here; presumably the caller
        calls write_rex_stats afterwards. Confirm this.
        """
        self.do_bookkeeping_before()
        if self.tune_temps:
            self.tune_rex()
        energies = self.sort_per_state(self.get_energies())
        plist = self.gen_pairs_list()
        metrop = self.get_metropolis(plist, energies)
        accepted = self.try_exchanges(plist, metrop)
        self.perform_exchanges(accepted)
        self.do_bookkeeping_after(accepted)