5 from __future__
import print_function, division
7 from numpy.random
import random, randint
8 from numpy.random
import shuffle
# Boltzmann constant expressed per mole in kcal/(mol*K):
# k_B (1.3806503e-23 J/K) * N_A (6.0221415e23 /mol) / 4184 J/kcal.
# The 1e-23 and 1e23 exponents cancel, leaving only the mantissas.
kB = (1.3806503 * 6.0221415) / 4184.0
# Set up replica/state bookkeeping for a replica-exchange run.
# NOTE(review): this copy is missing several original lines, so attributes
# read elsewhere (self.nreps, self.grid, self.sfo_id, self.scheme, self.xchg,
# self.logfile, self.stepno, self.stirred, self.rn_history) are presumably
# assigned in the missing lines -- confirm against the full file.
# NOTE(review): `tune_data={}` is a mutable default that is stored on self,
# so all instances built with the default would share one dict if it were
# ever mutated; `tune_data=None` plus a fresh dict would be safer.
14 def __init__(self, nreps, inv_temps, grid, sfo_id,
15 rexlog=
'replicanums.txt', scheme=
'standard', xchg=
'random',
16 convectivelog=
'stirred.txt', tune_temps=
False,
17 tune_data={}, templog=
'temps.txt'):
# replicanums[state] -> replica currently at that state;
# statenums[replica] -> state currently held by that replica.
# Both start as the identity mapping.
20 self.replicanums = list(range(nreps))
22 self.statenums = list(range(nreps))
26 self.inv_temps = inv_temps
31 self.tune_temps = tune_temps
32 self.tune_data = tune_data
34 self.templog = templog
# Convective scheme: replicas are visited in the order nreps-1, ..., 0;
# track the current one, its position in that order, its direction and
# its remaining step budget of 2 * (nreps - 1).
36 if scheme ==
"convective":
39 self.stirred[
'order'] = list(range(self.nreps))
40 self.stirred[
'order'].reverse()
42 self.stirred[
'replica'] = self.stirred[
'order'][0]
44 self.stirred[
'pos'] = 0
# dir == 1: partner state is state + 1; dir == 0: partner is state - 1
# (see gen_pairs_list_conv). The `dir = 0` assignment presumably sits in a
# missing `else:` branch.
46 if self.stirred[
'replica'] != self.nreps - 1:
47 self.stirred[
'dir'] = 1
49 self.stirred[
'dir'] = 0
51 self.stirred[
'steps'] = 2 * (self.nreps - 1)
52 self.convectivelog = convectivelog
# Log the initial bookkeeping state.
53 self.write_rex_stats()
def sort_per_state(self, inplist):
    """Reorder a per-replica list so that entry ``k`` belongs to state ``k``.

    ``self.replicanums[k]`` is the replica currently sitting at state ``k``,
    so the result is ``inplist`` indexed through that mapping.

    Raises:
        ValueError: if ``inplist`` does not have ``self.nreps`` entries.
    """
    if len(inplist) == self.nreps:
        ordered = []
        for replica in self.replicanums:
            ordered.append(inplist[replica])
        return ordered
    raise ValueError('list has wrong size')
def sort_per_replica(self, inplist):
    """Reorder a per-state list so that entry ``r`` belongs to replica ``r``.

    ``self.statenums[r]`` is the state currently held by replica ``r``.

    Raises:
        ValueError: if ``inplist`` does not have ``self.nreps`` entries.
    """
    expected_len = self.nreps
    if len(inplist) != expected_len:
        raise ValueError('list has wrong size')
    return [inplist[state] for state in self.statenums]
def get_energies(self):
    """Return the energies of all replicas, ordered by replica number.

    The request is broadcast to ``self.sfo_id`` on ``self.grid`` and the
    replies are gathered. Presumably the 'm'/'evaluate'/False arguments
    invoke ``evaluate(False)`` on each replica -- confirm against the grid API.
    """
    handles = self.grid.broadcast(self.sfo_id, 'm', 'evaluate', False)
    return self.grid.gather(handles)
# Build the fixed-parity exchange pairs (2*i + direction, 2*i + 1 + direction)
# for i in range(nreps // 2); direction must be 0 or 1 (else ValueError).
# NOTE(review): several lines are missing from this copy (end of docstring,
# the binding of `nreps`, and whatever follows the list). With direction == 1
# the last pair can reach index nreps, so the missing lines presumably trim
# out-of-range pairs before returning -- confirm.
72 def gen_pairs_list_gromacs(self, direction):
73 """generate ordered list of exchange pairs based on direction.
74 direction == 0 : (0,1)(2,3)...
75 direction == 1 : (0)(1,2)(3,4)...
79 if direction != 0
and direction != 1:
80 raise ValueError(direction)
83 ret = [(2 * i + direction, 2 * i + 1 + direction)
84 for i
in range(nreps // 2)]
# Build a random set of disjoint neighbouring state pairs, always including
# the pairs listed in `needed`.
# NOTE(review): `needed=[]` is a mutable default; the visible lines only read
# it, but `needed=None` would be the safer signature.
# NOTE(review): several lines (loop headers, `nreps`/`pairslist` setup) are
# missing from this copy.
89 def gen_pairs_list_rand(self, needed=[]):
90 "generate list of neighboring pairs of states"
# Candidate pairs: all (i, i + 1) neighbours, plus the wrap-around pair
# (0, nreps - 1) (possibly added only conditionally -- a line is missing).
94 init = [(i, i + 1)
for i
in range(nreps - 1)]
96 init.append((0, nreps - 1))
# Forced pairs: take (i, i + 1) and drop the overlapping candidates.
# NOTE(review): list.remove raises ValueError if the pair is absent,
# e.g. (i - 1, i) for i == 0 or (i + 1, i + 2) near the top; the missing
# lines may guard this -- confirm.
100 raise ValueError(
"wrong format for 'needed' list")
101 pairslist.append((i, i + 1))
102 init.remove((i - 1, i))
103 init.remove((i, i + 1))
104 init.remove((i + 1, i + 2))
# numpy.random.randint excludes the upper bound, so i is a valid index
# into init.
107 i = randint(0, len(init))
# Drop every candidate sharing a state with the chosen pair. The wrap-around
# pair still blocks its neighbours but is never emitted.
112 init = [(r, q)
for (r, q)
in init
113 if (r
not in pair
and q
not in pair)]
116 if not pair == (0, nreps - 1):
117 pairslist.append(pair)
# Convective scheme: build the exchange pairs so that they contain the pair
# joining the tracked replica's state to its neighbour in the current
# direction (state + 1 if dir == 1, state - 1 if dir == 0). That pair is
# also stored in self.stirred['pair'].
# NOTE(review): the local name `dir` shadows the builtin; the final
# `else:` and the format argument of the error message are missing from
# this copy.
122 def gen_pairs_list_conv(self):
124 rep = self.stirred[
'replica']
125 state = self.statenums[rep]
126 pair = sorted([state, state + 2 * self.stirred[
'dir'] - 1])
127 self.stirred[
'pair'] = tuple(pair)
# Choose the gromacs parity whose pairs (2i + d, 2i + 1 + d) include
# self.stirred['pair']: d = state % 2 when dir == 1, (state - 1) % 2 when
# dir == 0.
128 if self.xchg ==
'gromacs':
129 dir = (state + 1 + self.stirred[
'dir']) % 2
130 return self.gen_pairs_list_gromacs(dir)
131 elif self.xchg ==
'random':
132 return self.gen_pairs_list_rand(needed=[self.stirred[
'pair']])
134 raise NotImplementedError(
135 "Unknown exchange method: %s" %
# Dispatch to the pair generator for the configured scheme/exchange method.
# 'standard' + 'gromacs' alternates parity with self.stepno % 2;
# 'standard' + 'random' draws random disjoint pairs; 'convective' delegates
# to gen_pairs_list_conv. Unknown values raise NotImplementedError.
# NOTE(review): the `else:` lines and message format arguments are missing
# from this copy.
138 def gen_pairs_list(self):
139 if self.scheme ==
'standard':
140 if self.xchg ==
'gromacs':
141 return self.gen_pairs_list_gromacs(self.stepno % 2)
142 elif self.xchg ==
'random':
143 return self.gen_pairs_list_rand()
145 raise NotImplementedError(
146 "unknown exchange method: %s" %
148 elif self.scheme ==
'convective':
149 return self.gen_pairs_list_conv()
151 raise NotImplementedError(
152 "unknown exchange scheme: %s" %
def get_cross_energies(self, pairslist):
    """Not available for temperature replica-exchange.

    Cross energies (each replica evaluated in its partner's state) are not
    needed here: the acceptance test only uses the current state-sorted
    energies and the inverse temperatures.

    Raises:
        NotImplementedError: always. The explanation is carried in the
            exception message rather than printed to stdout, so callers
            that catch the error still get it.
    """
    raise NotImplementedError(
        "this is not needed for temperature replica-exchange")
# For each state pair (s1, s2) compute the acceptance probability
# min(1, exp((E[s2] - E[s1]) * (beta[s2] - beta[s1]))), where E is indexed
# by state and beta = self.inv_temps.
# NOTE(review): the result container and return lines are missing from this
# copy; try_exchanges indexes the result by pair (metrop[couple]), so it is
# presumably a dict keyed by pair.
# NOTE(review): `exp` is not imported in the visible lines. If it is
# math.exp, a large positive exponent raises OverflowError before min()
# can clamp it; clamping the exponent at 0 first would avoid that.
160 def get_metropolis(self, pairslist, old_ene):
161 """compute metropolis criterion for temperature replica exchange
162 e.g. exp(Delta beta Delta E)
163 input: list of pairs, list of state-sorted energies
166 for (s1, s2)
in pairslist:
168 min(1, exp((old_ene[s2] - old_ene[s1]) *
169 (self.inv_temps[s2] - self.inv_temps[s1])))
# Accept each couple whose acceptance probability is >= 1, or otherwise
# with probability metrop[couple] (numpy.random.random draws from [0, 1)).
# NOTE(review): the `accepted` initialisation, the loop over `plist` and the
# return are missing from this copy.
172 def try_exchanges(self, plist, metrop):
175 if (metrop[couple] >= 1)
or (random() < metrop[couple]):
176 accepted.append(couple)
# Apply accepted state swaps: first to the local replica<->state maps, then
# to the replicas on the grid (new inverse temperatures + states).
179 def perform_exchanges(self, accepted):
180 "exchange given state couples both in local variables and on the grid"
# Local update: each replica takes its partner's state...
182 for (i, j)
in accepted:
184 ri = self.replicanums[i]
185 rj = self.replicanums[j]
186 self.statenums[ri] = j
187 self.statenums[rj] = i
# ...and the two states swap the replicas they hold.
189 buf = self.replicanums[i]
190 self.replicanums[i] = self.replicanums[j]
191 self.replicanums[j] = buf
# Grid update: per-replica inverse temperature of its new state.
193 newtemps = self.sort_per_replica(self.inv_temps)
194 states = self.grid.gather(
195 self.grid.broadcast(self.sfo_id,
'get_state'))
196 for (i, j)
in accepted:
197 ri = self.replicanums[i]
198 rj = self.replicanums[j]
# NOTE(review): only this one-way assignment is visible. The missing
# neighbouring lines presumably save states[ri] and assign it to states[rj]
# to complete a swap; if not, replica rj's state is duplicated -- confirm.
200 states[ri] = states[rj]
202 for temp, state
in zip(newtemps, states):
203 state[
'inv_temp'] = temp
# NOTE(review): the trailing `)` is unbalanced here; a missing line
# presumably wraps this scatter call in another call.
205 self.grid.scatter(self.sfo_id,
'set_state', states))
# Append bookkeeping lines to the log files, all opened in append mode.
# NOTE(review): files are opened without `with`; the close calls (if any) are
# in lines missing from this copy. `with open(...) as fl:` would guarantee
# closing on error.
207 def write_rex_stats(self):
208 "write replica numbers as a function of state"
# Main log: step number, then the 1-based replica number at each state.
209 fl = open(self.logfile,
'a')
210 fl.write(
'%8d ' % self.stepno)
211 fl.write(
' '.join([
'%2d' % (i + 1)
for i
in self.replicanums]))
# Convective log: step, 1-based tracked replica, direction, steps left.
214 if self.scheme ==
'convective':
215 fl = open(self.convectivelog,
'a')
216 fl.write(
'%5d ' % self.stepno)
217 fl.write(
'%2d ' % (self.stirred[
'replica'] + 1))
218 fl.write(
'%2d ' % self.stirred[
'dir'])
219 fl.write(
'%2d\n' % self.stirred[
'steps'])
# Temperature log: step and current inverse temperatures. Written only when
# tuning is on and the replica history holds exactly one entry.
221 if self.tune_temps
and len(self.rn_history) == 1:
222 fl = open(self.templog,
'a')
223 fl.write(
'%5d ' % self.stepno)
224 fl.write(
' '.join([
'%.3f' % i
for i
in self.inv_temps]))
# NOTE(review): the `def` line of this temperature-tuning method and several
# body lines are missing from this copy. `td` is presumably bound to
# self.tune_data, and `kwargs` to a fresh dict, in the missing lines.
# Steps visible below:
#   - record the current replica-per-state list in self.rn_history;
#   - every td['rate'] entries, convert inverse temperatures to
#     temperatures (T = 1 / (kB * beta));
#   - tune them with TuneRex ('ar' or 'flux' method);
#   - store the result back as inverse temperatures.
# NOTE(review): `changed` is not checked in the visible lines; a missing
# line may guard the final assignment -- confirm.
229 """use TuneRex to optimize temp set. Temps are optimized every
230 'rate' steps and 'method' is used. Data is accumulated as long as
231 the temperatures weren't optimized.
232 td keys that should be passed to init:
233 rate : the rate at which to try tuning temps
235 targetAR : for ar only, target acceptance rate
236 alpha : Type I error to use.
240 self.rn_history.append([i
for i
in self.replicanums])
242 if len(self.rn_history) % td[
'rate'] == 0\
243 and len(self.rn_history) > 0:
244 temps = [1 / (kB * la)
for la
in self.inv_temps]
246 if td[
'method'] ==
'ar':
248 kwargs[
'targetAR'] = td[
'targetAR']
250 kwargs[
'alpha'] = td[
'alpha']
251 if 'dumb_scale' in td:
252 kwargs[
'dumb_scale'] = td[
'dumb_scale']
253 indicators = TuneRex.compute_indicators(
254 transpose(self.rn_history))
255 changed, newparams = TuneRex.tune_params_ar(
256 indicators, temps, **kwargs)
257 elif td[
'method'] ==
'flux':
259 kwargs[
'alpha'] = td[
'alpha']
260 changed, newparams = TuneRex.tune_params_flux(
261 transpose(self.rn_history),
266 self.inv_temps = [1 / (kB * t)
for t
in newparams]
# Pre-exchange bookkeeping for the convective scheme.
# NOTE(review): several lines are missing from this copy. `st` is presumably
# an alias of self.stirred, `rep` the tracked replica, and the advance below
# presumably happens only once the step budget is exhausted -- confirm.
268 def do_bookkeeping_before(self):
270 if self.scheme ==
'convective':
# Move to the next replica in the visiting order (cyclic) and reset its
# step budget to 2 * (nreps - 1).
274 st[
'pos'] = (st[
'pos'] + 1) % self.nreps
275 st[
'replica'] = st[
'order'][st[
'pos']]
277 st[
'steps'] = 2 * (self.nreps - 1)
# Look up the tracked replica's state; the top state (nreps - 1) is
# special-cased (branch bodies missing -- presumably the direction is set
# so the replica can only move down from there).
279 state = self.statenums[rep]
281 if state == self.nreps - 1:
# Post-exchange bookkeeping for the convective scheme: if the tracked
# replica's pair (self.stirred['pair']) was accepted, consume one step of
# its budget.
# NOTE(review): the local `dir` shadows the builtin, and `expected` is not
# used in the visible lines -- possibly leftover code.
287 def do_bookkeeping_after(self, accepted):
288 if self.scheme ==
'convective':
289 rep = self.stirred[
'replica']
290 state = self.statenums[rep]
291 dir = 2 * self.stirred[
'dir'] - 1
292 expected = (min(state, state + dir), max(state, state + dir))
293 if self.stirred[
'pair']
in accepted:
294 self.stirred[
'steps'] -= 1
# One replica-exchange step: bookkeeping, state-sorted energies, candidate
# pairs, acceptance probabilities, random acceptance, applying the swaps,
# bookkeeping.
# NOTE(review): several lines of this method are missing from this copy, and
# it may continue past the end of the visible text.
296 def replica_exchange(self):
297 "main entry point for replica-exchange"
299 self.do_bookkeeping_before()
# get_energies returns replica-sorted energies; get_metropolis expects them
# sorted by state.
304 energies = self.sort_per_state(self.get_energies())
306 plist = self.gen_pairs_list()
308 metrop = self.get_metropolis(plist, energies)
310 accepted = self.try_exchanges(plist, metrop)
312 self.perform_exchanges(accepted)
314 self.do_bookkeeping_after(accepted)