6 from numpy.random
import random, randint
7 from numpy.random
import shuffle
# Boltzmann constant in kcal/(mol*K): k_B (1.3806503e-23 J/K) times
# Avogadro's number (6.0221415e23 /mol), divided by 4184 J/kcal. The powers
# of ten cancel, so only the mantissas appear. Used as 1 / (kB * beta) to
# turn inverse temperatures into temperatures.
kB = (1.3806503 * 6.0221415) / 4184.0
# NOTE(review): excerpt only. The number fused to the start of each line looks
# like the original line number, and many lines of this constructor are absent
# (e.g. 17-18, 22-24, 26-29, 32, 34, 36-37, 40, 42, 44, 47, 49). Attributes used
# elsewhere but not assigned in the visible lines (self.nreps, self.grid,
# self.sfo_id, self.logfile, self.stepno, self.scheme, self.xchg,
# self.rn_history, self.stirred) are presumably set on those missing lines.
# __init__ sets up replica-exchange bookkeeping for `nreps` replicas:
#   - replicanums and statenums both start as range(nreps), the identity
#     mapping between replicas and states;
#   - inv_temps is stored as given (other methods index it by state);
#   - for scheme == "convective", self.stirred gets 'order' (range(nreps)
#     reversed, so 'replica' starts as nreps - 1), 'pos' = 0, 'dir' (1 unless
#     'replica' is nreps - 1; line 48 sets 0, presumably in a missing else
#     branch) and 'steps' = 2 * (nreps - 1);
#   - finally write_rex_stats() appends the initial state to the log file(s).
# NOTE(review): .reverse() on range(...) relies on Python 2, where range
# returns a list.
# NOTE(review): tune_data={} is a mutable default shared by every instance
# that omits it. Prefer None plus a fresh dict.
13 def __init__(self, nreps, inv_temps, grid, sfo_id,
14 rexlog=
'replicanums.txt', scheme=
'standard', xchg=
'random',
15 convectivelog=
'stirred.txt', tune_temps=
False,
16 tune_data={}, templog=
'temps.txt'):
19 self.replicanums = range(nreps)
21 self.statenums = range(nreps)
25 self.inv_temps = inv_temps
30 self.tune_temps = tune_temps
31 self.tune_data = tune_data
33 self.templog = templog
35 if scheme ==
"convective":
38 self.stirred[
'order'] = range(self.nreps)
39 self.stirred[
'order'].reverse()
41 self.stirred[
'replica'] = self.stirred[
'order'][0]
43 self.stirred[
'pos'] = 0
45 if self.stirred[
'replica'] != self.nreps - 1:
46 self.stirred[
'dir'] = 1
48 self.stirred[
'dir'] = 0
50 self.stirred[
'steps'] = 2 * (self.nreps - 1)
51 self.convectivelog = convectivelog
52 self.write_rex_stats()
def sort_per_state(self, inplist):
    """Reorder a replica-indexed list so that it is indexed by state.

    ``self.replicanums[s]`` is the replica currently sitting at state ``s``,
    so entry ``s`` of the result is ``inplist[self.replicanums[s]]``.

    Raises:
        ValueError: if ``inplist`` does not hold exactly ``self.nreps`` items.
    """
    size_matches = len(inplist) == self.nreps
    if not size_matches:
        raise ValueError('list has wrong size')
    return [inplist[replica] for replica in self.replicanums]
def sort_per_replica(self, inplist):
    """Reorder a state-indexed list so that it is indexed by replica.

    ``self.statenums[r]`` is the state currently occupied by replica ``r``,
    so entry ``r`` of the result is ``inplist[self.statenums[r]]``.

    Raises:
        ValueError: if ``inplist`` does not hold exactly ``self.nreps`` items.
    """
    if len(inplist) == self.nreps:
        return [inplist[state] for state in self.statenums]
    raise ValueError('list has wrong size')
def get_energies(self):
    """Return the replica energies, sorted by replica.

    Broadcasts ('m', 'evaluate', False) to the grid object ``self.sfo_id``
    (presumably a call of its ``evaluate`` method with argument False) and
    gathers the per-replica results.
    """
    grid = self.grid
    pending = grid.broadcast(self.sfo_id, 'm', 'evaluate', False)
    return grid.gather(pending)
# NOTE(review): excerpt only. Original lines 75-77, 80-81 and 84+ are absent,
# including the docstring's closing quotes, the binding of `nreps` and the
# return statement.
# gen_pairs_list_gromacs: builds `ret`, the fixed alternating pairing of
# neighbouring states. direction 0 gives (0,1),(2,3),... and direction 1 gives
# (1,2),(3,4),... Any other direction raises ValueError(direction).
# NOTE(review): `xrange` and `nreps / 2` rely on Python 2 (integer division).
# NOTE(review): for direction == 1 and even nreps, the last generated pair is
# (nreps - 1, nreps), one past the last state. The missing lines presumably
# drop it; confirm.
# NOTE(review): docstring line 74 reads "(0(1,2)(3,4)...". It most likely
# means state 0 stays unpaired, followed by (1,2)(3,4)...
71 def gen_pairs_list_gromacs(self, direction):
72 """generate ordered list of exchange pairs based on direction.
73 direction == 0 : (0,1)(2,3)...
74 direction == 1 : (0(1,2)(3,4)...
78 if direction != 0
and direction != 1:
79 raise ValueError(direction)
82 ret = [(2 * i + direction, 2 * i + 1 + direction)
83 for i
in xrange(nreps / 2)]
# NOTE(review): excerpt only. Original lines 90-92, 94, 96-98, 104-105,
# 107-110, 113-114 and 117+ are absent, including the bindings of `nreps`,
# `pairslist`, `i` and `pair`, the loops, and the return.
# gen_pairs_list_rand: builds a random set of disjoint neighbouring state pairs.
#   - `init` holds every neighbour pair (i, i + 1) plus the wrap-around pair
#     (0, nreps - 1). Any condition on that append would be on a missing line.
#   - Presumably for each pair in `needed` (the loop header is missing):
#     malformed entries raise ValueError, the pair is added to `pairslist`,
#     and it and its two neighbouring pairs are removed from `init`.
#   - Pairs are then drawn at random by index. numpy's randint excludes its
#     upper bound, so the index is valid. Every remaining pair sharing a state
#     with the drawn pair is discarded.
#   - The wrap-around pair (0, nreps - 1) is never added to `pairslist`
#     (line 115).
# NOTE(review): needed=[] is a mutable default. Nothing visible mutates it,
# but None is safer.
# NOTE(review): init.remove((i - 1, i)) raises ValueError for i == 0, and
# init.remove((i + 1, i + 2)) does for i + 1 == nreps - 1, unless the missing
# lines guard these cases. Confirm.
88 def gen_pairs_list_rand(self, needed=[]):
89 "generate list of neighboring pairs of states"
93 init = [(i, i + 1)
for i
in xrange(nreps - 1)]
95 init.append((0, nreps - 1))
99 raise ValueError(
"wrong format for 'needed' list")
100 pairslist.append((i, i + 1))
101 init.remove((i - 1, i))
102 init.remove((i, i + 1))
103 init.remove((i + 1, i + 2))
106 i = randint(0, len(init))
111 init = [(r, q)
for (r, q)
in init
112 if (r
not in pair
and q
not in pair)]
115 if not pair == (0, nreps - 1):
116 pairslist.append(pair)
# NOTE(review): excerpt only. Original lines 122, 132 and 135-136 are absent
# (presumably a docstring, an `else:` and the error-message format argument).
# gen_pairs_list_conv: exchange pairs for the "convective" scheme.
# The stirred replica's current state is paired with:
#   - the next state up when stirred['dir'] == 1;
#   - the next state down when stirred['dir'] == 0 (2 * dir - 1 is +1 or -1).
# That pair is stored in stirred['pair'] and forced into the pair list:
#   - 'gromacs': the parity (state + 1 + dir) % 2 is exactly the alternation
#     whose pattern contains stirred['pair'];
#   - 'random': stirred['pair'] is passed as `needed`.
# Any other xchg value raises NotImplementedError.
# NOTE(review): the local name `dir` shadows the builtin.
121 def gen_pairs_list_conv(self):
123 rep = self.stirred[
'replica']
124 state = self.statenums[rep]
125 pair = sorted([state, state + 2 * self.stirred[
'dir'] - 1])
126 self.stirred[
'pair'] = tuple(pair)
127 if self.xchg ==
'gromacs':
128 dir = (state + 1 + self.stirred[
'dir']) % 2
129 return self.gen_pairs_list_gromacs(dir)
130 elif self.xchg ==
'random':
131 return self.gen_pairs_list_rand(needed=[self.stirred[
'pair']])
133 raise NotImplementedError(
134 "Unknown exchange method: %s" %
# NOTE(review): excerpt only. Original lines 143, 146, 149 and 152 are absent
# (presumably `else:` lines and the format arguments of the error messages).
# gen_pairs_list: picks the pair generator from self.scheme and self.xchg:
#   'standard' + 'gromacs' -> alternate the pairing parity via self.stepno % 2
#   'standard' + 'random'  -> gen_pairs_list_rand()
#   'convective'           -> gen_pairs_list_conv()
# Unknown values raise NotImplementedError with a "%s"-formatted message.
137 def gen_pairs_list(self):
138 if self.scheme ==
'standard':
139 if self.xchg ==
'gromacs':
140 return self.gen_pairs_list_gromacs(self.stepno % 2)
141 elif self.xchg ==
'random':
142 return self.gen_pairs_list_rand()
144 raise NotImplementedError(
145 "unknown exchange method: %s" %
147 elif self.scheme ==
'convective':
148 return self.gen_pairs_list_conv()
150 raise NotImplementedError(
151 "unknown exchange scheme: %s" %
def get_cross_energies(self, pairslist):
    """Get energies assuming all exchanges have succeeded.

    Not implemented: the temperature-exchange acceptance test only needs each
    state's own energy and inverse temperature, so cross energies are never
    required.

    Args:
        pairslist: list of state pairs (unused).

    Raises:
        NotImplementedError: always. The reason travels in the exception
            message instead of being printed to stdout, so callers that catch
            it can still see why.
    """
    raise NotImplementedError(
        "this is not needed for temperature replica-exchange")
# NOTE(review): excerpt only. Original lines 163-164, 166 and 169+ are absent
# (the end of the docstring, the container being filled, and the return).
# get_metropolis: for each state pair (s1, s2), computes the
# temperature-exchange acceptance probability
#     min(1, exp((E[s2] - E[s1]) * (beta[s2] - beta[s1])))
# with E = old_ene (state-sorted) and beta = self.inv_temps.
# try_exchanges reads the result as metrop[couple], so it is presumably a
# dict keyed by pair. Confirm.
# NOTE(review): `exp` is not imported in the visible lines (presumably a numpy
# star import). With math.exp, a large positive argument would raise
# OverflowError.
159 def get_metropolis(self, pairslist, old_ene):
160 """compute metropolis criterion for temperature replica exchange
161 e.g. exp(Delta beta Delta E)
162 input: list of pairs, list of state-sorted energies
165 for (s1, s2)
in pairslist:
167 min(1, exp((old_ene[s2] - old_ene[s1]) *
168 (self.inv_temps[s2] - self.inv_temps[s1])))
# NOTE(review): excerpt only. Original lines 172-173 and 176+ are absent
# (presumably `accepted = []`, the loop over `plist`, and the return).
# try_exchanges: Metropolis draw. A couple is accepted outright when
# metrop[couple] >= 1. Otherwise it is accepted with probability
# metrop[couple], by comparing against numpy.random.random() in [0, 1).
171 def try_exchanges(self, plist, metrop):
174 if (metrop[couple] >= 1)
or (random() < metrop[couple]):
175 accepted.append(couple)
# NOTE(review): excerpt only. Original lines 180, 182, 187, 191, 198, 200 and
# 203 are absent, so the full rearrangement of `states` is not visible.
# perform_exchanges: applies the accepted state couples (i, j).
#   - Local bookkeeping: the replicas at states i and j get statenums j and
#     i, and replicanums[i] / replicanums[j] are swapped.
#   - Grid: every replica's state is fetched ('get_state'). Entries of the
#     exchanged replicas are rearranged (only line 199 is visible). Each
#     state's 'inv_temp' is set to the replica's new inverse temperature
#     (inv_temps re-sorted per replica). The states are sent back with
#     'set_state'.
# NOTE(review): line 204 closes one more parenthesis than it opens. Its
# opening call is presumably on missing line 203.
178 def perform_exchanges(self, accepted):
179 "exchange given state couples both in local variables and on the grid"
181 for (i, j)
in accepted:
183 ri = self.replicanums[i]
184 rj = self.replicanums[j]
185 self.statenums[ri] = j
186 self.statenums[rj] = i
188 buf = self.replicanums[i]
189 self.replicanums[i] = self.replicanums[j]
190 self.replicanums[j] = buf
192 newtemps = self.sort_per_replica(self.inv_temps)
193 states = self.grid.gather(
194 self.grid.broadcast(self.sfo_id,
'get_state'))
195 for (i, j)
in accepted:
196 ri = self.replicanums[i]
197 rj = self.replicanums[j]
199 states[ri] = states[rj]
201 for temp, state
in zip(newtemps, states):
202 state[
'inv_temp'] = temp
204 self.grid.scatter(self.sfo_id,
'set_state', states))
# NOTE(review): excerpt only. Original lines 211-212, 219 and 224+ are absent
# (presumably the trailing newline writes and close() calls).
# write_rex_stats: appends one record per call to each active log file:
#   - self.logfile: the step number ('%8d') followed by the 1-based replica
#     number at each state;
#   - self.convectivelog (convective scheme only): the step, the 1-based
#     stirred replica, stirred['dir'] and the remaining stirred['steps'];
#   - self.templog (only when tune_temps is set and rn_history holds exactly
#     one entry): the step and the inverse temperatures ('%.3f').
# NOTE(review): the files are opened without `with`. If a write raises before
# the (missing) close() calls, the handle leaks.
206 def write_rex_stats(self):
207 "write replica numbers as a function of state"
208 fl = open(self.logfile,
'a')
209 fl.write(
'%8d ' % self.stepno)
210 fl.write(
' '.join([
'%2d' % (i + 1)
for i
in self.replicanums]))
213 if self.scheme ==
'convective':
214 fl = open(self.convectivelog,
'a')
215 fl.write(
'%5d ' % self.stepno)
216 fl.write(
'%2d ' % (self.stirred[
'replica'] + 1))
217 fl.write(
'%2d ' % self.stirred[
'dir'])
218 fl.write(
'%2d\n' % self.stirred[
'steps'])
220 if self.tune_temps
and len(self.rn_history) == 1:
221 fl = open(self.templog,
'a')
222 fl.write(
'%5d ' % self.stepno)
223 fl.write(
' '.join([
'%.3f' % i
for i
in self.inv_temps]))
# NOTE(review): excerpt only. This fragment's `def` line (presumably original
# line 227) is absent, as are lines 233, 236-238, 240, 244, 246, 248, 257 and
# 261-264. The missing lines include the docstring's closing quotes and the
# creation of `td` (presumably self.tune_data) and `kwargs`.
# Visible behaviour:
#   - record the current replicanums in self.rn_history;
#   - every td['rate'] recorded steps, convert inverse temperatures to
#     temperatures (1 / (kB * beta));
#   - ask TuneRex for a new set, using the 'ar' method (targetAR, alpha,
#     optional dumb_scale) or the 'flux' method (alpha);
#   - store the result back as inverse temperatures. Any check of `changed`
#     would be on the missing lines.
# NOTE(review): right after the append, len(self.rn_history) > 0 always holds,
# so the second condition on line 242 is redundant.
# NOTE(review): TuneRex and transpose are not imported in the visible lines.
228 """use TuneRex to optimize temp set. Temps are optimized every
229 'rate' steps and 'method' is used. Data is accumulated as long as
230 the temperatures weren't optimized.
231 td keys that should be passed to init:
232 rate : the rate at which to try tuning temps
234 targetAR : for ar only, target acceptance rate
235 alpha : Type I error to use.
239 self.rn_history.append([i
for i
in self.replicanums])
241 if len(self.rn_history) % td[
'rate'] == 0\
242 and len(self.rn_history) > 0:
243 temps = [1 / (kB * la)
for la
in self.inv_temps]
245 if td[
'method'] ==
'ar':
247 kwargs[
'targetAR'] = td[
'targetAR']
249 kwargs[
'alpha'] = td[
'alpha']
250 if 'dumb_scale' in td:
251 kwargs[
'dumb_scale'] = td[
'dumb_scale']
252 indicators = TuneRex.compute_indicators(
253 transpose(self.rn_history))
254 changed, newparams = TuneRex.tune_params_ar(
255 indicators, temps, **kwargs)
256 elif td[
'method'] ==
'flux':
258 kwargs[
'alpha'] = td[
'alpha']
259 changed, newparams = TuneRex.tune_params_flux(
260 transpose(self.rn_history),
265 self.inv_temps = [1 / (kB * t)
for t
in newparams]
# NOTE(review): excerpt only. Original lines 268, 270-272, 275, 277, 279 and
# 281+ are absent, including the bindings of `st` and `rep` and the
# conditions and branches around these statements.
# Visible behaviour for the convective scheme:
#   - advance st['pos'] cyclically and make st['order'][pos] the stirred
#     replica;
#   - reset st['steps'] to 2 * (nreps - 1);
#   - look up that replica's state and test whether it is the last state
#     (nreps - 1).
# `st` presumably aliases self.stirred, and `rep` is presumably st['replica'].
# Confirm against the full source.
267 def do_bookkeeping_before(self):
269 if self.scheme ==
'convective':
273 st[
'pos'] = (st[
'pos'] + 1) % self.nreps
274 st[
'replica'] = st[
'order'][st[
'pos']]
276 st[
'steps'] = 2 * (self.nreps - 1)
278 state = self.statenums[rep]
280 if state == self.nreps - 1:
def do_bookkeeping_after(self, accepted):
    """Update the convective-stirring counter after an exchange round.

    For the 'convective' scheme, one unit of ``self.stirred['steps']`` is
    consumed when the pair chosen for the stirred replica
    (``self.stirred['pair']``) is among the accepted exchanges. Other schemes
    need no bookkeeping.

    Args:
        accepted: accepted state pairs (tuples), as produced by try_exchanges.
    """
    if self.scheme != 'convective':
        return
    # Only a successful exchange of the designated pair counts toward the
    # stirring budget.
    if self.stirred['pair'] in accepted:
        self.stirred['steps'] -= 1
# NOTE(review): excerpt only. Original lines 297, 299-302, 304, 306, 308, 310,
# 312 and everything after 313 are absent.
# replica_exchange: runs one exchange round, in this order:
#   1. pre-exchange bookkeeping;
#   2. fetch the replica energies and re-sort them per state;
#   3. generate the candidate state pairs;
#   4. compute the Metropolis acceptance probabilities;
#   5. draw the accepted exchanges;
#   6. apply them locally and on the grid;
#   7. post-exchange bookkeeping.
295 def replica_exchange(self):
296 "main entry point for replica-exchange"
298 self.do_bookkeeping_before()
303 energies = self.sort_per_state(self.get_energies())
305 plist = self.gen_pairs_list()
307 metrop = self.get_metropolis(plist, energies)
309 accepted = self.try_exchanges(plist, metrop)
311 self.perform_exchanges(accepted)
313 self.do_bookkeeping_after(accepted)