IMP Reference Guide 2.20.1
The Integrative Modeling Platform
good_scoring_model_selector.py
from __future__ import print_function, division
import RMF
import subprocess
import os
import shutil
import random
import glob
import operator


# If we have new enough IMP/RMF, do our own RMF slicing with provenance
if hasattr(RMF.NodeHandle, 'replace_child'):
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        inr = RMF.open_rmf_file_read_only(infile)
        outr = RMF.create_rmf_file(outfile)
        cpf = RMF.CombineProvenanceFactory(outr)
        fpf = RMF.FilterProvenanceFactory(outr)
        RMF.clone_file_info(inr, outr)
        RMF.clone_hierarchy(inr, outr)
        RMF.clone_static_frame(inr, outr)
        inr.set_current_frame(RMF.FrameID(frameid))
        outr.add_frame("f0")
        RMF.clone_loaded_frame(inr, outr)
        rn = outr.get_root_node()
        children = rn.get_children()
        if len(children) == 0:
            return
        rn = children[0]  # Should be the top-level IMP node
        prov = [c for c in rn.get_children() if c.get_type() == RMF.PROVENANCE]
        if not prov:
            return
        prov = prov[0]
        # Add combine-provenance info
        newp = rn.replace_child(prov, "combine", RMF.PROVENANCE)
        cp = cpf.get(newp)
        cp.set_frames(total_num_frames)
        cp.set_runs(num_runs)
        # Add filter-provenance info
        newp = rn.replace_child(newp, "filter", RMF.PROVENANCE)
        fp = fpf.get(newp)
        fp.set_frames(num_good_scoring)
        fp.set_method("Best scoring")
        # todo: put in a more appropriate value
        fp.set_threshold(0.)

# Otherwise, fall back to the rmf_slice command line tool
else:
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        FNULL = open(os.devnull, 'w')
        subprocess.call(['rmf_slice', infile, "-f", str(frameid), outfile],
                        stdout=FNULL, stderr=subprocess.STDOUT)

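# Both branches above expose the same rmf_slice(infile, frameid, outfile,
# num_runs, total_num_frames, num_good_scoring) signature, so callers need
# not care which one is in use. A minimal sketch of a direct call, with
# hypothetical paths and counts (the class below normally calls this
# internally when extracting good-scoring frames):
#
#   rmf_slice('modeling_run1/output/rmfs/0.rmf3', frameid=25,
#             outfile='good_scoring_models/sample_A/0.rmf3',
#             num_runs=2, total_num_frames=20000, num_good_scoring=500)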

class GoodScoringModelSelector(object):
    # Authors: Shruthi Viswanath

    ''' Select good-scoring models based on scores and/or data satisfaction.
    Extract the corresponding RMFs and put them in a separate directory
    '''

    def __init__(self, run_directory, run_prefix):
        """Constructor.
        @param run_directory the directory containing subdirectories of runs
        @param run_prefix the prefix for each run directory. E.g. if the
               subdirectories are modeling_run1, modeling_run2, etc. the
               prefix is modeling_run
        """
        self.run_dir = run_directory
        self.run_prefix = run_prefix

        # list with each member as a tuple (run id, replica id, frame id)
        # corresponding to good-scoring models
        self.all_good_scoring_models = []

    def _all_run_dirs(self):
        """Yield (pathname, runid) for all run directories (unsorted)"""
        for x in os.listdir(self.run_dir):
            if x.startswith(self.run_prefix):
                runid = x[len(self.run_prefix):]
                fullpath = os.path.join(self.run_dir, x)
                if os.path.isdir(fullpath) and runid.isdigit():
                    yield (fullpath, runid)

    def _get_subfields_for_criteria(
            self, field_headers, selection_keywords_list,
            printing_keywords_list):
        '''Given the list of keywords, get all the stat file entries
        corresponding to each keyword.'''

        # list of dicts corresponding to field indices for each keyword
        selection_fields = [{} for kw in selection_keywords_list]
        # for ambiguous crosslink distances, it will store all restraints
        # corresponding to the ambiguity in one dict.

        # just a placeholder
        printing_fields = [-1 for j in range(len(printing_keywords_list))]

        for fh_index in field_headers:

            for ki, kw in enumerate(selection_keywords_list):
                # need exact name of field unless it is an xlink distance
                if kw == field_headers[fh_index]:
                    selection_fields[ki][kw] = fh_index

                elif kw in field_headers[fh_index] and \
                        field_headers[fh_index].startswith(
                            "CrossLinkingMassSpectrometry"
                            "Restraint_Distance_"):
                    # handle ambiguous restraints
                    (prot1, res1, prot2, res2) = \
                        field_headers[fh_index].split("|")[3:7]
                    prot1 = prot1.split('.')[0]
                    prot2 = prot2.split('.')[0]

                    if (prot1, res1, prot2, res2) in selection_fields[ki]:
                        selection_fields[ki][
                            (prot1, res1, prot2, res2)].append(fh_index)
                    else:
                        # list of indices corresponding to all
                        # combinations of protein copies
                        selection_fields[ki][(prot1, res1, prot2, res2)] = \
                            [fh_index]

            for ki, kw in enumerate(printing_keywords_list):
                if kw == field_headers[fh_index]:
                    printing_fields[ki] = fh_index

        return selection_fields, printing_fields

    def _get_crosslink_satisfaction(
            self, crosslink_distance_values,
            crosslink_percentage_lower_threshold,
            crosslink_percentage_upper_threshold,
            xlink_distance_lower_threshold, xlink_distance_upper_threshold):
        '''For crosslinks, we want models with at least x% (e.g. 90%)
        crosslink satisfaction. A crosslink is satisfied if the distance
        is between the lower and upper distance thresholds
        @param crosslink_distance_values values of distances in the
               current model
        @param crosslink_percentage_lower_threshold at least x% of crosslinks
               should be within the below distance thresholds
        @param crosslink_percentage_upper_threshold at most x% of crosslinks
               should be within the below distance thresholds
               (usually 100%: dummy parameter)
        @param xlink_distance_lower_threshold a crosslink should be at least
               this distance apart (usually 0) to be considered satisfied
        @param xlink_distance_upper_threshold a crosslink should be at most
               this distance apart to be considered satisfied
        '''
        satisfied_xlinks = 0.0
        for d in crosslink_distance_values:
            if d >= xlink_distance_lower_threshold \
                    and d <= xlink_distance_upper_threshold:
                satisfied_xlinks += 1.0

        percent_satisfied = \
            satisfied_xlinks/float(len(crosslink_distance_values))

        if percent_satisfied >= crosslink_percentage_lower_threshold \
                and percent_satisfied <= crosslink_percentage_upper_threshold:
            return percent_satisfied, True
        else:
            return percent_satisfied, False

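    # A worked example of the satisfaction test above, with hypothetical
    # values: distances [8.2, 12.5, 40.1] checked against distance thresholds
    # of 0 and 35 give 2 of 3 satisfied crosslinks, i.e. a fraction of about
    # 0.67; with percentage thresholds of 0.9 (lower) and 1.0 (upper) the
    # method returns (0.67, False), so the model is rejected.
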
    def _get_score_satisfaction(self, score, lower_threshold, upper_threshold):
        ''' Check if the score is within the thresholds
        '''
        if score <= upper_threshold and score >= lower_threshold:
            return True
        return False

    def _extract_models_from_trajectories(self, output_dir, num_runs,
                                          total_num_frames, sampleA_set):
        '''Given the list of all good-scoring model indices, extract
        their frames and store them ordered by the list index.
        Store the models in two sample subdirectories.'''
        sampleA_dir = os.path.join(output_dir, "sample_A")
        sampleB_dir = os.path.join(output_dir, "sample_B")
        os.mkdir(sampleA_dir)
        os.mkdir(sampleB_dir)

        num_gsm = sum(1 for e in self.all_good_scoring_models)
        print("Extracting", num_gsm, "good scoring models.")
        model_num = 1

        for i, gsm in enumerate(self.all_good_scoring_models):
            if model_num % (num_gsm/10) == 0:
                print(str(model_num / (num_gsm/10)*10)+"% Complete")
            model_num += 1

            (runid, replicaid, frameid) = gsm

            trajfile = os.path.join(self.run_dir, self.run_prefix+runid,
                                    'output', 'rmfs', replicaid+'.rmf3')

            sample_dir = sampleA_dir if i in sampleA_set else sampleB_dir
            rmf_slice(trajfile, frameid,
                      os.path.join(sample_dir, str(i)+'.rmf3'),
                      num_runs, total_num_frames,
                      len(self.all_good_scoring_models))

    def get_good_scoring_models(
            self, selection_keywords_list=[], printing_keywords_list=[],
            aggregate_lower_thresholds=[], aggregate_upper_thresholds=[],
            member_lower_thresholds=[], member_upper_thresholds=[],
            extract=False):
        ''' Loops over all stat files in the run directory and populates
        the list of good-scoring models.
        @param selection_keywords_list is the list of keywords in the PMI
               stat file that need to be checked for each datatype/score
               in the criteria list
        @param printing_keywords_list is the list of keywords in the PMI
               stat file whose values need to be printed for selected models
        @param aggregate_lower_thresholds The list of lower bounds on the
               values corresponding to fields in the criteria_list.
               Aggregates are used for terms like % of crosslink satisfaction
               and thresholds of score terms
        @param aggregate_upper_thresholds The list of upper bounds on the
               values corresponding to fields in the criteria_list.
               Aggregates are used for terms like % of crosslink satisfaction
               and thresholds of score terms
        @param member_lower_thresholds The list of lower bounds for values
               of subcomponents of an aggregate term. E.g. for crosslink
               satisfaction the thresholds are on distances for each
               individual crosslink. For score terms this can be ignored
               since thresholds are mentioned in the aggregate fields.
        @param member_upper_thresholds The list of upper bounds for values
               of subcomponents of an aggregate term. E.g. for crosslink
               satisfaction the thresholds are on distances for each
               individual crosslink. For score terms this can be ignored
               since thresholds are mentioned in the aggregate fields.
        '''
        if extract:
            output_dir = os.path.join(self.run_dir, "good_scoring_models")
        else:
            output_dir = os.path.join(self.run_dir, "filter")
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir, ignore_errors=True)
        os.mkdir(output_dir)

        outf = open(os.path.join(output_dir, "model_ids_scores.txt"), 'w')
        # Header line first
        print(' '.join(["Model_index", "Run_id", "Replica_id", "Frame_id"]
                       + selection_keywords_list + printing_keywords_list),
              file=outf)

        num_runs = 0
        total_num_frames = 0

        for each_run_dir, runid in sorted(self._all_run_dirs(),
                                          key=operator.itemgetter(1)):
            num_runs += 1

            print("Analyzing", runid)

            for each_replica_stat_file in sorted(
                    glob.glob(os.path.join(each_run_dir, "output",
                                           "stat.*.out")),
                    key=lambda x: int(x.strip('.out').split('.')[-1])):

                replicaid = each_replica_stat_file.strip(".out").split(".")[-1]

                rsf = open(each_replica_stat_file, 'r')

                for line_index, each_model_line in enumerate(rsf.readlines()):
                    # for each model in the current replica

                    if line_index == 0:
                        field_headers = eval(each_model_line.strip())
                        fields_for_selection, fields_for_printing = \
                            self._get_subfields_for_criteria(
                                field_headers, selection_keywords_list,
                                printing_keywords_list)
                        continue

                    frameid = line_index-1

                    dat = eval(each_model_line.strip())

                    total_num_frames += 1
                    model_satisfies = False
                    selection_criteria_values = []

                    for si, score_type in enumerate(selection_keywords_list):
                        if "crosslink" in score_type.lower() \
                                and "distance" in score_type.lower():

                            crosslink_distance_values = []

                            for xltype in fields_for_selection[si]:
                                crosslink_distance_values.append(
                                    min([float(dat[j]) for j in
                                         fields_for_selection[si][xltype]]))

                            satisfied_percent, model_satisfies = \
                                self._get_crosslink_satisfaction(
                                    crosslink_distance_values,
                                    aggregate_lower_thresholds[si],
                                    aggregate_upper_thresholds[si],
                                    member_lower_thresholds[si],
                                    member_upper_thresholds[si])

                            selection_criteria_values.append(satisfied_percent)

                        else:
                            score_value = float(
                                dat[fields_for_selection[si][score_type]])

                            model_satisfies = self._get_score_satisfaction(
                                score_value, aggregate_lower_thresholds[si],
                                aggregate_upper_thresholds[si])
                            selection_criteria_values.append(score_value)

                        if not model_satisfies:
                            break

                    if model_satisfies:

                        # Now get the printing criteria
                        printing_criteria_values = []
                        for si, score_type in enumerate(
                                printing_keywords_list):
                            if fields_for_printing[si] < 0:
                                raise KeyError(
                                    "Bad stat file key '%s': use `imp_sampcon "
                                    "show_stat <stat file>` to see all "
                                    "acceptable keys" % score_type)
                            score_value = float(dat[fields_for_printing[si]])
                            printing_criteria_values.append(score_value)

                        self.all_good_scoring_models.append(
                            (runid, replicaid, frameid))

                        # Print out the scores finally

                        print(' '.join(
                            [str(x) for x in
                             [len(self.all_good_scoring_models) - 1,
                              runid, replicaid, frameid]] +
                            ["%.2f" % s for s in selection_criteria_values] +
                            ["%.2f" % s for s in printing_criteria_values]),
                            file=outf)

                rsf.close()
        outf.close()

        if extract:
            sampA, sampB = self._split_good_scoring_models_into_two_subsets(
                num_runs,
                split_type="divide_by_run_ids" if num_runs > 1 else "random")
            self._extract_models_from_trajectories(
                output_dir, num_runs, total_num_frames, frozenset(sampA))
            return sampA, sampB

    def _split_good_scoring_models_into_two_subsets(
            self, num_runs, split_type="divide_by_run_ids"):
        ''' Get the list of good scoring models and split them into two
        samples. Return the two subsets.

        @param split_type how to split good scoring models into two samples.
               Current options are:
               (a) divide_by_run_ids : where the list of runids is divided
                   into 2. E.g. if we have runs from 1-50, good scoring
                   models from runs 1-25 are sample A and those from
                   runs 26-50 are sample B.
               (b) random : split the set of good scoring models into
                   two subsets at random.
        '''
        sampleA_indices = []
        sampleB_indices = []

        if split_type == "divide_by_run_ids":  # split based on run ids

            half_num_runs = num_runs/2
            for i, gsm in enumerate(self.all_good_scoring_models):
                if int(gsm[0]) <= half_num_runs:
                    sampleA_indices.append(i)
                else:
                    sampleB_indices.append(i)

        elif split_type == "random":
            sampleA_indices = random.sample(
                range(len(self.all_good_scoring_models)),
                len(self.all_good_scoring_models)//2)
            sampleB_indices = [
                i for i in range(len(self.all_good_scoring_models))
                if i not in sampleA_indices]

        # write model and sample IDs to a file
        with open(os.path.join(self.run_dir, 'good_scoring_models',
                               'model_sample_ids.txt'), 'w') as f:
            for i in sampleA_indices:
                print(i, "A", file=f)
            for i in sampleB_indices:
                print(i, "B", file=f)
        return sampleA_indices, sampleB_indices
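
A minimal usage sketch of the class defined above. The import path is assumed to be IMP.sampcon.good_scoring_model_selector; the run layout (modeling_run1/output/..., modeling_run2/output/...) and the stat-file keyword Total_Score are illustrative placeholders, so substitute the keywords that actually appear in your own stat files (the imp_sampcon show_stat command referenced in the error message above lists them):

    from IMP.sampcon.good_scoring_model_selector import \
        GoodScoringModelSelector

    # runs are expected in <run_directory>/<run_prefix><N>/output/
    gsms = GoodScoringModelSelector(run_directory='.',
                                    run_prefix='modeling_run')

    # keep models whose total score lies between 0 and 10000 (hypothetical
    # thresholds) and extract the corresponding frames into
    # ./good_scoring_models/sample_A and ./good_scoring_models/sample_B
    gsms.get_good_scoring_models(
        selection_keywords_list=['Total_Score'],
        printing_keywords_list=['Total_Score'],
        aggregate_lower_thresholds=[0],
        aggregate_upper_thresholds=[10000],
        member_lower_thresholds=[0],
        member_upper_thresholds=[0],
        extract=True)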