6 from numpy.random
import random, randint
7 from numpy.random
import shuffle
# Boltzmann constant in kcal/(mol K): k_B (1.3806503e-23 J/K) multiplied by
# Avogadro's number (6.0221415e23 1/mol) and divided by 4184 J/kcal. The 1e-23
# and 1e23 factors cancel, which is why only the mantissas appear below.
kB = 1.3806503 * 6.0221415 / 4184.0
# Constructor: stores the replica-exchange configuration on the instance.
# Visible behaviour: replicanums and statenums both start as range(nreps)
# (identity mapping between states and replicas); inv_temps, tune_temps,
# tune_data and templog are stored as passed in. When scheme == "convective"
# the `stirred` dict is filled: 'order' is range(self.nreps) reversed,
# 'replica' is the first entry of that order, and 'steps' is set to
# 2*(self.nreps-1); convectivelog is stored as well. The last visible
# statement calls self.write_rex_stats().
# NOTE(review): this excerpt has gaps (the embedded original line numbers are
# not contiguous). Assignments of self.nreps, self.grid, self.sfo_id,
# self.scheme, self.xchg, self.stepno and self.logfile (presumably from
# `rexlog`) are not visible here -- confirm against the full file.
# NOTE(review): tune_data={} is a mutable default that is stored on self, so
# every instance built without tune_data shares one dict; consider None.
# NOTE(review): .reverse() on range(...) and the later item assignment on
# replicanums/statenums rely on Python 2 range() returning a list.
12 def __init__(self,nreps,inv_temps,grid,sfo_id,
13 rexlog=
'replicanums.txt', scheme=
'standard', xchg=
'random',
14 convectivelog=
'stirred.txt', tune_temps =
False,
15 tune_data={}, templog =
'temps.txt'):
18 self.replicanums = range(nreps)
20 self.statenums = range(nreps)
24 self.inv_temps = inv_temps
29 self.tune_temps = tune_temps
30 self.tune_data = tune_data
32 self.templog = templog
34 if scheme ==
"convective":
37 self.stirred[
'order'] = range(self.nreps)
38 self.stirred[
'order'].reverse()
40 self.stirred[
'replica']=self.stirred[
'order'][0]
44 if self.stirred[
'replica'] != self.nreps-1:
49 self.stirred[
'steps']=2*(self.nreps-1)
50 self.convectivelog = convectivelog
51 self.write_rex_stats()
def sort_per_state(self, inplist):
    """Reorder a replica-indexed list so that it is indexed by state.

    Element ``k`` of the result is ``inplist[self.replicanums[k]]``, i.e. the
    entry that belongs to the replica listed at position ``k`` of
    ``self.replicanums``.

    Args:
        inplist: sequence holding exactly ``self.nreps`` entries, one per
            replica.

    Returns:
        A new list with the entries of ``inplist`` in state order.

    Raises:
        ValueError: if ``len(inplist)`` differs from ``self.nreps``.
    """
    if len(inplist) != self.nreps:
        # Call form of raise: valid on both Python 2 and Python 3.
        raise ValueError('list has wrong size')
    return [inplist[replica] for replica in self.replicanums]
def sort_per_replica(self, inplist):
    """Reorder a state-indexed list so that it is indexed by replica number.

    Element ``k`` of the result is ``inplist[self.statenums[k]]``, i.e. the
    entry that belongs to the state listed at position ``k`` of
    ``self.statenums``.

    Args:
        inplist: sequence holding exactly ``self.nreps`` entries, one per
            state.

    Returns:
        A new list with the entries of ``inplist`` in replica order.

    Raises:
        ValueError: if ``len(inplist)`` differs from ``self.nreps``.
    """
    if len(inplist) != self.nreps:
        # Call form of raise: valid on both Python 2 and Python 3.
        raise ValueError('list has wrong size')
    return [inplist[state] for state in self.statenums]
def get_energies(self):
    """Return the energies collected from the grid (replica-sorted).

    Broadcasts a request with the arguments ``(self.sfo_id, 'm', 'evaluate',
    False)`` through ``self.grid.broadcast`` and returns whatever
    ``self.grid.gather`` produces for that request.

    NOTE(review): the meaning of the ``'m'``, ``'evaluate'`` and ``False``
    arguments is defined by the grid object and is not visible here.
    """
    request = self.grid.broadcast(self.sfo_id, 'm', 'evaluate', False)
    return self.grid.gather(request)
# Builds the candidate exchange pairs (2*i+direction, 2*i+1+direction) for
# i in xrange(nreps/2); raises ValueError (carrying the offending value) when
# direction is neither 0 nor 1.
# NOTE(review): this excerpt has gaps -- the closing quotes of the docstring,
# the definition of the local `nreps` and the return statement are not
# visible, so any filtering of out-of-range pairs cannot be confirmed here.
# NOTE(review): `raise ValueError, direction`, xrange and integer `/` are
# Python 2 forms.
# NOTE(review): the docstring line "(0(1,2)(3,4)..." looks like a typo for
# "(0)(1,2)(3,4)..." -- confirm.
70 def gen_pairs_list_gromacs(self, direction):
71 """generate ordered list of exchange pairs based on direction.
72 direction == 0 : (0,1)(2,3)...
73 direction == 1 : (0(1,2)(3,4)...
77 if direction != 0
and direction != 1:
78 raise ValueError, direction
81 ret = [(2*i+direction,2*i+1+direction)
for i
in xrange(nreps/2)]
# Generates a list of neighbouring state pairs. Visible steps: the candidate
# list `init` holds (i, i+1) for i in xrange(nreps-1) plus the wrap-around
# pair (0, nreps-1); pairs derived from `needed` are appended to `pairslist`
# (a ValueError with "wrong format for 'needed' list" guards their format);
# a random integer is drawn with randint(0, len(init)); candidates sharing a
# member with the chosen pair are filtered out of `init`; the wrap-around
# pair is never appended to `pairslist`.
# NOTE(review): this excerpt has many gaps (loop headers, definition of
# `nreps`, `pairslist`, `pair` and the return are not visible).
# NOTE(review): randint is numpy.random.randint per the file's imports, whose
# upper bound is exclusive, so randint(0, len(init)) stays within the index
# range of `init` -- presumably it is used as an index; confirm.
# NOTE(review): needed=[] is a mutable default; harmless only if it is never
# mutated, which cannot be verified from the visible lines.
86 def gen_pairs_list_rand(self, needed = []):
87 "generate list of neighboring pairs of states"
91 init = [(i,i+1)
for i
in xrange(nreps-1)]
93 init.append((0,nreps-1))
97 raise ValueError,
"wrong format for 'needed' list"
98 pairslist.append((i,i+1))
101 init.remove((i+1,i+2))
104 i = randint(0,len(init))
109 init = [(r,q)
for (r,q)
in init
110 if (r
not in pair
and q
not in pair)]
113 if not pair == (0,nreps-1):
114 pairslist.append(pair)
# Convective scheme: takes the stirred replica's current state and its
# neighbour at state + 2*stirred['dir'] - 1 (i.e. +1 when 'dir' is 1, -1 when
# 'dir' is 0), stores that pair as a tuple in stirred['pair'], then delegates
# to the gromacs generator (with a parity derived from state and 'dir') or to
# the random generator (passing the pair as `needed`). Any other xchg value
# raises NotImplementedError.
# NOTE(review): this excerpt has gaps; one line between building `pair` and
# storing it is missing (possibly ordering the pair) -- confirm.
# NOTE(review): the local name `dir` shadows the builtin.
119 def gen_pairs_list_conv(self):
121 rep = self.stirred[
'replica']
122 state = self.statenums[rep]
123 pair = [state, state + 2*self.stirred[
'dir'] - 1]
125 self.stirred[
'pair'] = tuple(pair)
126 if self.xchg ==
'gromacs':
127 dir = (state + 1 + self.stirred[
'dir']) % 2
128 return self.gen_pairs_list_gromacs(dir)
129 elif self.xchg ==
'random':
130 return self.gen_pairs_list_rand(needed=[self.stirred[
'pair']])
132 raise NotImplementedError, \
133 "Unknown exchange method: %s" % self.xchg
def gen_pairs_list(self):
    """Return the candidate exchange pairs for the configured scheme.

    Dispatches on ``self.scheme`` and ``self.xchg``:

    * ``'standard'`` / ``'gromacs'``: ``gen_pairs_list_gromacs`` with the
      parity of ``self.stepno`` as direction.
    * ``'standard'`` / ``'random'``: ``gen_pairs_list_rand``.
    * ``'convective'``: ``gen_pairs_list_conv``.

    Raises:
        NotImplementedError: for an unknown exchange method under the
            standard scheme, or for an unknown scheme.
    """
    if self.scheme == 'standard':
        if self.xchg == 'gromacs':
            # Alternate between the two pairings on successive steps.
            return self.gen_pairs_list_gromacs(self.stepno % 2)
        if self.xchg == 'random':
            return self.gen_pairs_list_rand()
        # Call form of raise: valid on both Python 2 and Python 3.
        raise NotImplementedError(
            "unknown exchange method: %s" % self.xchg)
    if self.scheme == 'convective':
        return self.gen_pairs_list_conv()
    raise NotImplementedError(
        "unknown exchange scheme: %s" % self.scheme)
def get_cross_energies(self, pairslist):
    """Placeholder for energies assuming all exchanges have succeeded.

    Not needed for temperature replica-exchange: prints a notice and always
    raises.

    Args:
        pairslist: list of exchange pairs (unused).

    Raises:
        NotImplementedError: always.
    """
    # Parenthesised single-argument form prints the same text on Python 2
    # and Python 3.
    print("this is not needed for temperature replica-exchange")
    raise NotImplementedError
# For every (s1, s2) in pairslist evaluates
# min(1, exp((old_ene[s2]-old_ene[s1]) * (inv_temps[s2]-inv_temps[s1]))),
# with old_ene indexed by state and self.inv_temps supplying the inverse
# temperatures.
# NOTE(review): this excerpt has gaps -- the closing quotes of the docstring,
# how the values are collected and the return statement are not visible.
# try_exchanges indexes the result by pair, so presumably a dict keyed by
# (s1, s2) is returned -- confirm.
# NOTE(review): `exp` is not imported in the visible part of the file.
155 def get_metropolis(self, pairslist, old_ene):
156 """compute metropolis criterion for temperature replica exchange
157 e.g. exp(Delta beta Delta E)
158 input: list of pairs, list of state-sorted energies
161 for (s1,s2)
in pairslist:
163 min(1,exp((old_ene[s2]-old_ene[s1])*
164 (self.inv_temps[s2]-self.inv_temps[s1])))
# Accepts a pair when its Metropolis value is >= 1 or when a draw from
# random() (numpy.random.random per the file's imports) falls below that
# value; accepted pairs are appended to `accepted`.
# NOTE(review): this excerpt has gaps -- the loop over `plist`, the
# initialisation of `accepted` and the return statement are not visible.
167 def try_exchanges(self, plist, metrop):
170 if (metrop[couple] >= 1)
or (random() < metrop[couple]):
171 accepted.append(couple)
# Applies the accepted state pairs. First loop, per (i, j): the replicas
# currently in states i and j get their statenums entries set to j and i
# respectively, then replicanums[i] and replicanums[j] are swapped. Afterwards
# the per-replica states are fetched from the grid ('get_state'), reassigned
# between the exchanged replicas, each state's 'inv_temp' is overwritten with
# the replica-sorted inverse temperature, and the states are sent back with
# 'set_state'.
# NOTE(review): this excerpt has gaps. In the second loop only
# `states[ri] = states[rj]` is visible; the other half of the swap is
# presumably on a missing line -- confirm.
# NOTE(review): the last visible line has an unbalanced ')' because the line
# opening the enclosing call is missing from this excerpt.
174 def perform_exchanges(self, accepted):
175 "exchange given state couples both in local variables and on the grid"
177 for (i,j)
in accepted:
179 ri = self.replicanums[i]
180 rj = self.replicanums[j]
181 self.statenums[ri] = j
182 self.statenums[rj] = i
184 buf = self.replicanums[i]
185 self.replicanums[i] = self.replicanums[j]
186 self.replicanums[j] = buf
188 newtemps = self.sort_per_replica(self.inv_temps)
189 states = self.grid.gather(
190 self.grid.broadcast(self.sfo_id,
'get_state'))
191 for (i,j)
in accepted:
192 ri = self.replicanums[i]
193 rj = self.replicanums[j]
195 states[ri] = states[rj]
197 for temp,state
in zip(newtemps,states):
198 state[
'inv_temp'] = temp
200 self.grid.scatter(self.sfo_id,
'set_state', states))
# Appends one record for the current step to the log files:
#   * self.logfile: step number followed by the 1-based replica number
#     occupying each state;
#   * self.convectivelog (convective scheme only): step number, 1-based
#     stirred replica, stirred 'dir' and stirred 'steps';
#   * self.templog (only when tune_temps is set and rn_history holds exactly
#     one entry): step number followed by inv_temps with three decimals.
# NOTE(review): this excerpt has gaps; no close() call and no line terminator
# for the first and third records are visible. Each file handle is rebound to
# `fl`, so a `with open(...)` block would guarantee closing -- confirm against
# the full file.
202 def write_rex_stats(self):
203 "write replica numbers as a function of state"
204 fl=open(self.logfile,
'a')
205 fl.write(
'%8d ' % self.stepno)
206 fl.write(
' '.join([
'%2d' % (i+1)
for i
in self.replicanums]))
209 if self.scheme ==
'convective':
210 fl=open(self.convectivelog,
'a')
211 fl.write(
'%5d ' % self.stepno)
212 fl.write(
'%2d ' % (self.stirred[
'replica']+1))
213 fl.write(
'%2d ' % self.stirred[
'dir'])
214 fl.write(
'%2d\n' % self.stirred[
'steps'])
216 if self.tune_temps
and len(self.rn_history) == 1:
217 fl=open(self.templog,
'a')
218 fl.write(
'%5d ' % self.stepno)
219 fl.write(
' '.join([
'%.3f' % i
for i
in self.inv_temps]))
224 """use TuneRex to optimize temp set. Temps are optimized every
225 'rate' steps and 'method' is used. Data is accumulated as long as
226 the temperatures weren't optimized.
227 td keys that should be passed to init:
228 rate : the rate at which to try tuning temps
230 targetAR : for ar only, target acceptance rate
231 alpha : Type I error to use.
235 self.rn_history.append([i
for i
in self.replicanums])
237 if len(self.rn_history) % td[
'rate'] == 0\
238 and len(self.rn_history) > 0:
239 temps = [1/(kB*la)
for la
in self.inv_temps]
241 if td[
'method'] ==
'ar':
242 if 'targetAR' in td: kwargs[
'targetAR']=td[
'targetAR']
243 if 'alpha' in td: kwargs[
'alpha']=td[
'alpha']
244 if 'dumb_scale' in td: kwargs[
'dumb_scale']=td[
'dumb_scale']
245 indicators = TuneRex.compute_indicators(
246 transpose(self.rn_history))
247 changed, newparams = TuneRex.tune_params_ar(indicators, temps, **kwargs)
248 elif td[
'method'] ==
'flux':
249 if 'alpha' in td: kwargs[
'alpha']=td[
'alpha']
250 changed, newparams = TuneRex.tune_params_flux(
251 transpose(self.rn_history),
256 self.inv_temps = [1/(kB*t)
for t
in newparams]
# Pre-exchange bookkeeping, active for the convective scheme only. Visible
# steps: advance st['pos'] cyclically modulo self.nreps, select the replica at
# that position of st['order'], reset st['steps'] to 2*(self.nreps - 1), look
# up the current state of `rep` and compare it with the top state
# self.nreps - 1.
# NOTE(review): this excerpt has many gaps -- the conditions guarding these
# updates, the assignments of `st` (presumably self.stirred) and `rep`, and
# what happens when the state equals nreps - 1 are not visible.
258 def do_bookkeeping_before(self):
260 if self.scheme ==
'convective':
264 st[
'pos'] = (st[
'pos'] + 1) % self.nreps
265 st[
'replica'] = st[
'order'][st[
'pos']]
267 st[
'steps'] = 2*(self.nreps - 1)
269 state = self.statenums[rep]
271 if state == self.nreps - 1:
def do_bookkeeping_after(self, accepted):
    """Update the convective bookkeeping once the exchanges are done.

    For the ``'convective'`` scheme, decrements ``self.stirred['steps']``
    when the pair stored in ``self.stirred['pair']`` is among the accepted
    exchanges. Other schemes are left untouched.

    Args:
        accepted: collection of accepted state pairs (tuples).
    """
    if self.scheme != 'convective':
        return
    # The original also computed the stirred replica's state and an
    # "expected" ordered pair here but never used them; those dead locals
    # (one of which shadowed the builtin `dir`) are dropped.
    if self.stirred['pair'] in accepted:
        self.stirred['steps'] -= 1
# Runs one replica-exchange round. Visible order of operations:
# pre-exchange bookkeeping, gather energies and sort them per state, build
# the candidate pairs, compute their Metropolis values, draw the accepted
# pairs, perform the exchanges, post-exchange bookkeeping.
# NOTE(review): this excerpt has gaps between the visible statements (and
# possibly after the last one), so additional steps may exist in the full
# file -- confirm.
286 def replica_exchange(self):
287 "main entry point for replica-exchange"
289 self.do_bookkeeping_before()
294 energies = self.sort_per_state(self.get_energies())
296 plist = self.gen_pairs_list()
298 metrop = self.get_metropolis(plist,energies)
300 accepted = self.try_exchanges(plist, metrop)
302 self.perform_exchanges(accepted)
304 self.do_bookkeeping_after(accepted)