IMP logo
IMP Reference Guide  develop.330bebda01,2025/01/21
The Integrative Modeling Platform
demux_trajs.py
1 #!/usr/bin/env python
2 
3 import sys
4 import os
5 import re
6 import shutil
7 
8 
class LogStep:
    """Everything recorded for one time step of one replica.

    Stores the raw stats line for the step, the stats header it belongs to,
    and the dump files / trajectory references attached to this step.
    """

    def __init__(self, stepno, statline, header):
        self.stepno = stepno
        self.stats = statline
        self.header = header
        # category -> full path of the dump file for this step
        self.dumps = {}
        # category -> {'ftype', 'fullpath', 'stepno', 'tail'}
        self.trajs = {}

    def add(self, ftype, category, *data):
        """Attach a file entry under ``category``.

        ftype 'dump': data is (fullpath,).
        ftype 'traj': data is (file type, fullpath, stepno, tail).
        Any other ftype raises ValueError.
        """
        if ftype not in ('dump', 'traj'):
            raise ValueError("unknown file type")
        if ftype == 'dump':
            self.dumps[category] = data[0]
            return
        entry = {}
        for position, key in enumerate(('ftype', 'fullpath', 'stepno',
                                        'tail')):
            entry[key] = data[position]
        self.trajs[category] = entry

    def get_stats_header(self):
        """Return the stats-file header this step was read under."""
        return self.header

    def get_stats(self):
        """Return the raw stats line of this step."""
        return self.stats

    def get_dumps(self):
        """Return the category -> dump file path mapping."""
        return self.dumps

    def get_trajs(self):
        """Return the category -> trajectory entry mapping."""
        return self.trajs


class LogHolder:

    """Manages information on a given simulation.
    Assumes the existence of a stats file, and handles more files if
    available.
    folder : the folder which contains the stats file
    prefix : the stats file is supposed to be prefix+'stats.txt' (the prefix
             returned by get_prefix() already ends with '_')
    Other files in folder whose name starts with prefix are grouped by
    category, the part of the remaining name before the first '_'. A category
    with several files, or whose single file ends in '_<int>.<ext>', is a set
    of per-step dumps. Any other single file must be an '.rmf*' trajectory.
    Raises ValueError if the stats file is missing, if two dump files share an
    index, or if a trajectory has an unknown extension.
    """

    def __init__(self, folder, prefix):
        self.folder = folder
        self.prefix = prefix
        # verify that stats file exists
        self.stats_file = os.path.join(folder, prefix + 'stats.txt')
        if not os.path.isfile(self.stats_file):
            raise ValueError('cannot find stats file %s' % self.stats_file)
        # scan for other files and group them by category
        files = {}
        for fl in os.listdir(folder):
            # compare the prefix literally: it is a file name, not a regex,
            # so characters such as '.' must not act as wildcards
            if not fl.startswith(prefix):
                continue
            tail = fl[len(prefix):]
            if tail == 'stats.txt':
                continue
            category = tail.split('_')[0]
            files.setdefault(category, []).append(tail)
        # see if there are multiple files in the same category, and store them
        self.dumpfiles = {}  # category -> {step index: tail}
        self.trajfiles = {}  # category -> (extension without '.', tail)
        for cat, fnames in files.items():
            if len(fnames) > 1 \
                    or os.path.splitext(fnames[0].split('_')[-1])[0].isdigit():
                # there are multiple files, no need to understand their content
                indexed = {}
                for fname in fnames:
                    # parse tail and find index number
                    # NOTE(review): the test above reads the last '_' field
                    # while this reads the second one; both agree for names
                    # like '<cat>_<n>.<ext>'
                    indexno = int(os.path.splitext(fname.split('_')[1])[0])
                    # check before inserting: a dict would otherwise silently
                    # keep only one of two files with the same index
                    if indexno in indexed:
                        raise ValueError("found duplicates in %s %s %s"
                                         % (folder, prefix, fname))
                    indexed[indexno] = fname
                self.dumpfiles[cat] = indexed
            else:
                # this is a trajectory, need to be able to parse it
                fname = fnames[0]
                ext = os.path.splitext(fname)[1]
                if ext.startswith('.rmf'):
                    self.trajfiles[cat] = (ext[1:], fname)
                else:
                    raise ValueError("Unknown extension: %s in file %s"
                                     % (ext, fname))

    def get_stats_header(self):
        """Return the (single-line) header of the stats file.

        On the first call this opens the stats file, stores its second line
        as stats_first_line, and leaves stats_handle positioned right after
        the header so that iterating it yields the stats lines.
        Raises ValueError if the second line also starts with '#', i.e. the
        header spans more than one line.
        """
        if not hasattr(self, 'stats_handle'):
            handle = open(self.stats_file)
            # read the file and guess the number of lines
            # for now, be compatible with only one line
            handle.readline()
            first_line = handle.readline()
            if first_line.startswith('#'):
                handle.close()
                raise ValueError('stats file must be 1-line only')
            # rewind instead of reopening, so no file handle is leaked
            handle.seek(0)
            self.stats_first_line = first_line
            self.stats_header = handle.readline()
            self.stats_handle = handle
        return self.stats_header

    def get_first_stats_line(self):
        """Return the first line after the header of the stats file."""
        # make sure file is open, skip header
        self.get_stats_header()
        return self.stats_first_line

    def _get_next_stats(self):
        """Yield the remaining lines of the (shared) stats file handle."""
        # make sure file is open, skip header
        self.get_stats_header()
        for line in self.stats_handle:
            yield line

    def items(self):
        """iterate over all time steps

        Yields one LogStep per non-blank stats line. The step number is read
        from the second whitespace-separated field of the line. Dump files
        indexed by that number and every trajectory file are attached to it.
        The stats file handle is consumed, so this can only iterate once.
        """
        header = self.get_stats_header()
        for stat in self._get_next_stats():
            # skip blank lines (e.g. a trailing empty line) instead of
            # failing on the missing step field
            if not stat.strip():
                continue
            # extract step number and create LogStep
            stepno = int(stat.split()[1])
            step = LogStep(stepno, stat, header)
            # get other files' entries at this step if available
            for cat, df in self.dumpfiles.items():
                if stepno in df:
                    fullpath = os.path.join(self.folder,
                                            self.prefix + df[stepno])
                    step.add('dump', cat, fullpath)
            for cat, tf in self.trajfiles.items():
                fullpath = os.path.join(self.folder, self.prefix + tf[1])
                step.add('traj', cat, tf[0], fullpath, stepno, tf[1])
            # yield a LogStep containing these entries
            yield step


class Demuxer:

    """uses column to demux a replica trajectory. Assumes column points to a
    float or integer type, which is allowed to change over time. Attribution is
    based on order of float params. State 0 will be lowest param etc. Use
    reverse=True to start with highest.

    logs : non-empty sequence of per-replica log objects (e.g. LogHolder);
           all must return the same stats header
    outfolder : folder in which the p0, p1, ... state folders are created
    column : substring that must occur in exactly one token of the header
    """

    def __init__(self, logs, outfolder, column, reverse=False):
        if not logs:
            raise ValueError("no replica logs to demux")
        self.logs = logs
        self.reverse = reverse
        self.column = column
        self.outfolder = outfolder
        self.stat_handles = {}      # stateno -> open output stats file
        self.traj_handles_in = {}   # input path -> RMF file opened read-only
        self.traj_handles_out = {}  # output path -> RMF file being written
        self.folders = {}           # stateno -> output folder
        # create needed folders: one state per replica
        if not os.path.isdir(outfolder):
            os.mkdir(outfolder)
        for stateno in range(len(self.logs)):
            fname = os.path.join(outfolder, 'p%d' % stateno)
            if not os.path.isdir(fname):
                os.mkdir(fname)
            self.folders[stateno] = fname
        # make sure every log has the same header
        h0 = self.logs[0].get_stats_header()
        for log in self.logs[1:]:
            if h0 != log.get_stats_header():
                raise ValueError("headers must be identical!")
        # get column number from header. column is a string (it is matched
        # as a substring), so it must be formatted with %s: %d would raise
        # TypeError instead of the intended ValueError
        tokens = [idx for idx, i in enumerate(h0.split()) if self.column in i]
        if len(tokens) == 0:
            raise ValueError("column %s not found in this header\n%s"
                             % (column, h0))
        elif len(tokens) > 1:
            raise ValueError("column %s found multiple times!\n%s"
                             % (column, h0))
        self.colno = tokens[0]

    def get_param(self, statline):
        """Return the demuxing column of a stats line as a float."""
        return float(statline.split()[self.colno])

    def _write_step_stats(self, stateno, lstep):
        """Append the step's stats line to '<stateno>_stats.txt'.

        The file is created, and the header written, on first use.
        """
        if stateno not in self.stat_handles:
            self.stat_handles[stateno] = open(
                os.path.join(self.folders[stateno],
                             str(stateno) + '_stats.txt'), 'w')
            self.stat_handles[stateno].write(lstep.get_stats_header())
        self.stat_handles[stateno].write(lstep.get_stats())

    def _write_step_dump(self, stateno, lstep):
        """Copy the step's dump files into the folder of state stateno.

        Each copy is named '<stateno>_<category><rest>', where <rest> is what
        follows the first occurrence of the category in the file's base name.
        """
        for cat, fname in lstep.get_dumps().items():
            # split the base name only, and only once: the category can also
            # occur in the directory path or in the extension (e.g.
            # 'pdb_10.pdb'), which used to truncate the output name.
            # NOTE(review): still ambiguous if the replica prefix itself
            # contains the category string.
            rest = os.path.basename(fname).split(cat, 1)[1]
            shutil.copyfile(fname,
                            os.path.join(self.folders[stateno],
                                         str(stateno) + '_' + cat + rest))

    def _write_traj_rmf(self, infile, instep, outfile, stateno, cat):
        """Append frame number instep of RMF file infile to outfile.

        Input and output files are opened once and cached. A new output file
        gets the file info, hierarchy and static frame of the input.
        """
        import RMF
        # make sure infile is open
        if infile not in self.traj_handles_in:
            src = RMF.open_rmf_file_read_only(infile)
            self.traj_handles_in[infile] = src
        src = self.traj_handles_in[infile]
        # make sure outfile is open
        if outfile not in self.traj_handles_out:
            dest = RMF.create_rmf_file(outfile)
            self.traj_handles_out[outfile] = dest
            RMF.clone_file_info(src, dest)
            RMF.clone_hierarchy(src, dest)
            RMF.clone_static_frame(src, dest)
        dest = self.traj_handles_out[outfile]
        # clone frame; assumes step numbers are 1-based frame positions
        # (TODO confirm against the writer of the stats files)
        frameid = src.get_frames()[instep - 1]
        src.set_current_frame(frameid)
        dest.add_frame(src.get_name(frameid), src.get_type(frameid))
        RMF.clone_loaded_frame(src, dest)

    def _write_step_traj(self, stateno, lstep):
        """Append the step's trajectory frames to the files of state stateno.

        Raises ValueError for trajectory types other than 'rmf*'.
        """
        for cat, data in lstep.get_trajs().items():
            destfile = os.path.join(self.folders[stateno],
                                    str(stateno) + '_' + data['tail'])
            if data['ftype'].startswith('rmf'):
                self._write_traj_rmf(data['fullpath'], data['stepno'],
                                     destfile, stateno, cat)
            else:
                raise ValueError("unknown trajectory file type")

    def _write_step(self, stateno, lstep):
        """Write stats, dumps and trajectory frames of lstep for stateno."""
        self._write_step_stats(stateno, lstep)
        self._write_step_dump(stateno, lstep)
        self._write_step_traj(stateno, lstep)

    def _close_handles(self):
        """Close the output stats files and drop the cached RMF handles."""
        for handle in self.stat_handles.values():
            handle.close()
        self.stat_handles.clear()
        # presumably RMF flushes/closes a file when its handle object is
        # released — TODO confirm against the RMF API
        self.traj_handles_in.clear()
        self.traj_handles_out.clear()

    def write(self):
        """Demux all replicas step by step, then close the output files.

        At each step, replicas are sorted by their column value (descending
        if reverse is set); the i-th one is written to state i. Replicas with
        equal values keep their order in self.logs. Iteration stops at the
        shortest log.
        """
        log_iterators = [list(log.items()) for log in self.logs]
        print("Demuxing", len(log_iterators), "replicas")
        try:
            for idx, steps in enumerate(zip(*log_iterators)):
                if idx % 10 == 0 and idx > 0:
                    print("step", idx, '\r', end=' ')
                    sys.stdout.flush()
                # sort on the parameter only: sorting (param, step) tuples
                # would compare step objects on ties and raise TypeError
                ordered = sorted(
                    steps, key=lambda lstep: self.get_param(lstep.get_stats()),
                    reverse=self.reverse)
                for stateno, lstep in enumerate(ordered):
                    self._write_step(stateno, lstep)
        finally:
            self._close_handles()
        print("Done")


def get_prefix(folder):
    """Return the prefix of the single '<prefix>stats.txt' file in folder.

    The prefix keeps its trailing underscore, e.g. 'r0_' for 'r0_stats.txt'.
    Only names ending exactly in '_stats.txt' count, so backups such as
    'r0_stats.txt.bak' are ignored.
    Raises ValueError unless exactly one such file exists.
    """
    # fullmatch + escaped dot: the old pattern also matched 'x_statsXtxt'
    # and anything merely starting with 'x_stats.txt'
    matches = [re.fullmatch(r'(.*_)stats\.txt', f) for f in os.listdir(folder)]
    matches = [m for m in matches if m]
    if len(matches) != 1:
        raise ValueError("stats file not unique, found %d" % len(matches))
    return matches[0].group(1)


if __name__ == '__main__':
    if len(sys.argv) == 1 or len(sys.argv) > 4:
        sys.exit("""demux_trajs.py column [infolder [outfolder]]
        expects r?? folders in infolder and will write p?? folders in
        outfolder. infolder must contain a _stats.txt file which will contain
        a header. column must be a substring matching to one of the columns in
        the _stats.txt files. It will typically be a temperature, or a state
        number. That column will be used for demuxing. Folders are optional
        and will be taken as ./ if not indicated.
        """)
    column = sys.argv[1]
    # optional positional folders, both defaulting to the current directory
    infolder = sys.argv[2] if len(sys.argv) >= 3 else './'
    outfolder = sys.argv[3] if len(sys.argv) == 4 else './'
    # replica folders are directories named exactly r<number>; sort them
    # numerically so the processing order does not depend on os.listdir()
    replica_names = sorted(
        (f for f in os.listdir(infolder)
         if re.fullmatch(r'r\d+', f)
         and os.path.isdir(os.path.join(infolder, f))),
        key=lambda f: int(f[1:]))
    if not replica_names:
        sys.exit("no replica folders (r0, r1, ...) found in %s" % infolder)
    folders = [os.path.join(infolder, f) for f in replica_names]
    # loop over replica folders and read their stats files
    replica_logs = [LogHolder(f, get_prefix(f)) for f in folders]
    demux = Demuxer(replica_logs, outfolder, column, reverse=True)
    demux.write()
def items
iterate over all time steps
Definition: demux_trajs.py:126
Manages information on a given simulation.
Definition: demux_trajs.py:42
uses column to demux a replica trajectory.
Definition: demux_trajs.py:147