IMP Reference Guide  2.16.0
The Integrative Modeling Platform
good_scoring_model_selector.py
from __future__ import print_function, division
import RMF
import subprocess
import os
import shutil
import random
import glob


# If we have new enough IMP/RMF, do our own RMF slicing with provenance
if hasattr(RMF.NodeHandle, 'replace_child'):
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        inr = RMF.open_rmf_file_read_only(infile)
        outr = RMF.create_rmf_file(outfile)
        cpf = RMF.CombineProvenanceFactory(outr)
        fpf = RMF.FilterProvenanceFactory(outr)
        RMF.clone_file_info(inr, outr)
        RMF.clone_hierarchy(inr, outr)
        RMF.clone_static_frame(inr, outr)
        inr.set_current_frame(RMF.FrameID(frameid))
        outr.add_frame("f0")
        RMF.clone_loaded_frame(inr, outr)
        rn = outr.get_root_node()
        children = rn.get_children()
        if len(children) == 0:
            return
        rn = children[0]  # Should be the top-level IMP node
        prov = [c for c in rn.get_children()
                if c.get_type() == RMF.PROVENANCE]
        if not prov:
            return
        prov = prov[0]
        # Add combine-provenance info
        newp = rn.replace_child(prov, "combine", RMF.PROVENANCE)
        cp = cpf.get(newp)
        cp.set_frames(total_num_frames)
        cp.set_runs(num_runs)
        # Add filter-provenance info
        newp = rn.replace_child(newp, "filter", RMF.PROVENANCE)
        fp = fpf.get(newp)
        fp.set_frames(num_good_scoring)
        fp.set_method("Best scoring")
        # todo: put in a more appropriate value
        fp.set_threshold(0.)

# Otherwise, fall back to the rmf_slice command line tool
else:
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        FNULL = open(os.devnull, 'w')
        subprocess.call(['rmf_slice', infile, "-f", str(frameid), outfile],
                        stdout=FNULL, stderr=subprocess.STDOUT)

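# A usage sketch of rmf_slice (paths and numbers are hypothetical): extract
# frame 42 of a replica trajectory into a single-frame RMF, recording
# provenance for 2 runs and 10000 total frames, of which 100 were
# good-scoring:
#
#     rmf_slice('run1/output/rmfs/3.rmf3', 42, 'gsm/0.rmf3',
#               num_runs=2, total_num_frames=10000, num_good_scoring=100)
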
class GoodScoringModelSelector(object):
    # Authors: Shruthi Viswanath

    '''Select good-scoring models based on scores and/or data satisfaction.
    Extract the corresponding RMFs and put them in a separate directory.
    '''

    def __init__(self, run_directory, run_prefix):
        """Constructor.
        @param run_directory the directory containing subdirectories of runs
        @param run_prefix the prefix for each run directory, e.g. if the
               subdirectories are modeling_run1, modeling_run2, etc. the
               prefix is modeling_run
        """
        self.run_dir = run_directory
        self.run_prefix = run_prefix

        # list with each member as a tuple (run id, replica id, frame id)
        # corresponding to good-scoring models
        self.all_good_scoring_models = []

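    # A construction sketch (directory and prefix are hypothetical): runs
    # are stored as ./modeling/modeling_run1, ./modeling/modeling_run2, ...
    #
    #     gsms = GoodScoringModelSelector('./modeling', 'modeling_run')
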
    def _get_subfields_for_criteria(
            self, field_headers, selection_keywords_list,
            printing_keywords_list):
        '''Given the list of keywords, get all the stat file entries
        corresponding to each keyword.'''

        # list of dicts corresponding to field indices for each keyword;
        # for ambiguous crosslink distances, all restraints corresponding
        # to the ambiguity are stored in one dict entry
        selection_fields = [{} for kw in selection_keywords_list]

        # just a placeholder
        printing_fields = [-1 for j in range(len(printing_keywords_list))]

        for fh_index in field_headers:

            for ki, kw in enumerate(selection_keywords_list):
                # need exact name of field unless it is a crosslink distance
                if kw == field_headers[fh_index]:
                    selection_fields[ki][kw] = fh_index

                elif kw in field_headers[fh_index] and \
                        field_headers[fh_index].startswith(
                            "CrossLinkingMassSpectrometry"
                            "Restraint_Distance_"):
                    # handle ambiguous restraints
                    (prot1, res1, prot2, res2) = \
                        field_headers[fh_index].split("|")[3:7]
                    prot1 = prot1.split('.')[0]
                    prot2 = prot2.split('.')[0]

                    if (prot1, res1, prot2, res2) in selection_fields[ki]:
                        selection_fields[ki][
                            (prot1, res1, prot2, res2)].append(fh_index)
                    else:
                        # list of indices corresponding to all
                        # combinations of protein copies
                        selection_fields[ki][(prot1, res1, prot2, res2)] = \
                            [fh_index]

            for ki, kw in enumerate(printing_keywords_list):
                if kw == field_headers[fh_index]:
                    printing_fields[ki] = fh_index

        return selection_fields, printing_fields

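    # For a selection keyword naming an ambiguous crosslink distance, the
    # corresponding selection_fields dict maps each (prot1, res1, prot2,
    # res2) tuple to the stat-file indices of all protein-copy combinations,
    # e.g. (hypothetical values):
    #
    #     {('Rpb4', '100', 'Rpb7', '50'): [12, 13, 14, 15]}
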
    def _get_crosslink_satisfaction(
            self, crosslink_distance_values,
            crosslink_percentage_lower_threshold,
            crosslink_percentage_upper_threshold,
            xlink_distance_lower_threshold, xlink_distance_upper_threshold):
        '''For crosslinks, we want models with at least x% (e.g. 90%)
        crosslink satisfaction. A crosslink is satisfied if the distance
        is between the lower and upper distance thresholds.
        @param crosslink_distance_values values of distances in the
               current model
        @param crosslink_percentage_lower_threshold at least x% of
               crosslinks should be within the below distance thresholds
        @param crosslink_percentage_upper_threshold at most x% of
               crosslinks should be within the below distance thresholds
               (usually 100%: dummy parameter)
        @param xlink_distance_lower_threshold a crosslink should be at
               least this distance apart (usually 0) to be considered
               satisfied
        @param xlink_distance_upper_threshold a crosslink should be at
               most this distance apart to be considered satisfied
        '''
        satisfied_xlinks = 0.0
        for d in crosslink_distance_values:
            if xlink_distance_lower_threshold <= d \
                    <= xlink_distance_upper_threshold:
                satisfied_xlinks += 1.0

        percent_satisfied = \
            satisfied_xlinks/float(len(crosslink_distance_values))

        if crosslink_percentage_lower_threshold <= percent_satisfied \
                <= crosslink_percentage_upper_threshold:
            return percent_satisfied, True
        else:
            return percent_satisfied, False

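    # Worked example (hypothetical numbers): distances [8.0, 21.5, 40.2]
    # with member thresholds (0, 30) give 2/3 = 0.67 satisfaction, so with
    # aggregate thresholds (0.9, 1.0) the model is rejected (0.67 < 0.9).
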
    def _get_score_satisfaction(self, score, lower_threshold,
                                upper_threshold):
        '''Check if the score is within the thresholds'''
        return lower_threshold <= score <= upper_threshold

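    # e.g. (hypothetical score and thresholds)
    # _get_score_satisfaction(-1500.2, -2000.0, -1000.0) returns True.
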
    def _extract_models_from_trajectories(self, output_dir, num_runs,
                                          total_num_frames):
        '''Given the list of all good-scoring model indices, extract
        their frames and store them ordered by the list index.'''
        num_gsm = len(self.all_good_scoring_models)
        print("Extracting", num_gsm, "good scoring models.")
        model_num = 1
        # print progress roughly every 10%; guard against small model
        # counts to avoid a zero modulus
        step = max(1, num_gsm // 10)

        for i, gsm in enumerate(self.all_good_scoring_models):
            if model_num % step == 0:
                print(str(model_num * 100 // num_gsm) + "% Complete")
            model_num += 1

            (runid, replicaid, frameid) = gsm

            trajfile = os.path.join(self.run_dir, self.run_prefix+runid,
                                    'output', 'rmfs', replicaid+'.rmf3')

            rmf_slice(trajfile, frameid,
                      os.path.join(output_dir, str(i)+'.rmf3'),
                      num_runs, total_num_frames,
                      len(self.all_good_scoring_models))

    def get_good_scoring_models(
            self, selection_keywords_list=[], printing_keywords_list=[],
            aggregate_lower_thresholds=[], aggregate_upper_thresholds=[],
            member_lower_thresholds=[], member_upper_thresholds=[],
            extract=False):
        '''Loops over all stat files in the run directory and populates
        the list of good-scoring models.
        @param selection_keywords_list is the list of keywords in the PMI
               stat file that need to be checked for each datatype/score
               in the criteria list
        @param printing_keywords_list is the list of keywords in the PMI
               stat file whose values need to be printed for selected models
        @param aggregate_lower_thresholds The list of lower bounds on the
               values corresponding to fields in selection_keywords_list.
               Aggregates are used for terms like % of crosslink
               satisfaction and thresholds of score terms
        @param aggregate_upper_thresholds The list of upper bounds on the
               values corresponding to fields in selection_keywords_list.
               Aggregates are used for terms like % of crosslink
               satisfaction and thresholds of score terms
        @param member_lower_thresholds The list of lower bounds for values
               of subcomponents of an aggregate term, e.g. for crosslink
               satisfaction the thresholds are on distances for each
               individual crosslink. For score terms this can be ignored
               since thresholds are mentioned in the aggregate fields.
        @param member_upper_thresholds The list of upper bounds for values
               of subcomponents of an aggregate term, e.g. for crosslink
               satisfaction the thresholds are on distances for each
               individual crosslink. For score terms this can be ignored
               since thresholds are mentioned in the aggregate fields.
        '''
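        # An example call (the field names below are illustrative; actual
        # keywords depend on the stat files written by the modeling script):
        #
        #     gsms.get_good_scoring_models(
        #         selection_keywords_list=['Total_Score'],
        #         printing_keywords_list=['ExcludedVolumeSphere_None'],
        #         aggregate_lower_thresholds=[0.0],
        #         aggregate_upper_thresholds=[100.0],
        #         member_lower_thresholds=[0.0],
        #         member_upper_thresholds=[0.0],
        #         extract=True)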
        if extract:
            output_dir = os.path.join(self.run_dir, "good_scoring_models")
        else:
            output_dir = os.path.join(self.run_dir, "filter")
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir, ignore_errors=True)
        os.mkdir(output_dir)

        outf = open(os.path.join(output_dir, "model_ids_scores.txt"), 'w')
        # Header line first
        print(' '.join(["Model_index", "Run_id", "Replica_id", "Frame_id"]
                       + selection_keywords_list + printing_keywords_list),
              file=outf)

        num_runs = 0
        total_num_frames = 0

        for each_run_dir in sorted(
                glob.glob(os.path.join(self.run_dir, self.run_prefix+"*")),
                key=lambda x: int(x.split(self.run_prefix)[1])):

            runid = each_run_dir.split(self.run_prefix)[1]

            num_runs += 1

            print("Analyzing", runid)

            for each_replica_stat_file in sorted(
                    glob.glob(os.path.join(each_run_dir, "output",
                                           "stat.*.out")),
                    key=lambda x: int(os.path.splitext(x)[0].split('.')[-1])):

                # the replica id is the number in the stat file name, e.g.
                # 3 for stat.3.out (note str.strip('.out') would strip
                # characters, not the suffix, so use os.path.splitext)
                replicaid = os.path.splitext(
                    each_replica_stat_file)[0].split(".")[-1]

                rsf = open(each_replica_stat_file, 'r')

                for line_index, each_model_line in enumerate(rsf.readlines()):
                    # for each model in the current replica

                    if line_index == 0:
                        # the first line maps stat file field indices
                        # to field names
                        field_headers = eval(each_model_line.strip())
                        fields_for_selection, fields_for_printing = \
                            self._get_subfields_for_criteria(
                                field_headers, selection_keywords_list,
                                printing_keywords_list)
                        continue

                    frameid = line_index-1

                    dat = eval(each_model_line.strip())

                    total_num_frames += 1
                    model_satisfies = False
                    selection_criteria_values = []

                    for si, score_type in enumerate(selection_keywords_list):
                        if "crosslink" in score_type.lower() \
                                and "distance" in score_type.lower():

                            crosslink_distance_values = []

                            # for an ambiguous crosslink, take the minimum
                            # distance over all protein-copy combinations
                            for xltype in fields_for_selection[si]:
                                crosslink_distance_values.append(
                                    min([float(dat[j]) for j in
                                         fields_for_selection[si][xltype]]))

                            satisfied_percent, model_satisfies = \
                                self._get_crosslink_satisfaction(
                                    crosslink_distance_values,
                                    aggregate_lower_thresholds[si],
                                    aggregate_upper_thresholds[si],
                                    member_lower_thresholds[si],
                                    member_upper_thresholds[si])

                            selection_criteria_values.append(
                                satisfied_percent)

                        else:
                            score_value = float(
                                dat[fields_for_selection[si][score_type]])

                            model_satisfies = self._get_score_satisfaction(
                                score_value, aggregate_lower_thresholds[si],
                                aggregate_upper_thresholds[si])
                            selection_criteria_values.append(score_value)

                        if not model_satisfies:
                            break

                    if model_satisfies:

                        # Now get the printing criteria
                        printing_criteria_values = []
                        for si, score_type in enumerate(
                                printing_keywords_list):
                            score_value = float(dat[fields_for_printing[si]])
                            printing_criteria_values.append(score_value)

                        self.all_good_scoring_models.append(
                            (runid, replicaid, frameid))

                        # Print out the scores finally

                        print(' '.join(
                            [str(x) for x in
                             [len(self.all_good_scoring_models) - 1,
                              runid, replicaid, frameid]] +
                            ["%.2f" % s for s in selection_criteria_values] +
                            ["%.2f" % s for s in printing_criteria_values]),
                            file=outf)

                rsf.close()
        outf.close()

        if extract:
            self._extract_models_from_trajectories(
                output_dir, num_runs, total_num_frames)

        return self._split_good_scoring_models_into_two_subsets(
            output_dir, num_runs,
            split_type="divide_by_run_ids" if num_runs > 1 else "random")

    def _split_good_scoring_models_into_two_subsets(
            self, output_dir, num_runs, split_type="divide_by_run_ids"):
        '''Get the list of good-scoring models and split them into two
        samples, keeping the models in separate directories.
        Return the two subsets.

        @param split_type how to split good-scoring models into two samples.
               Current options are:
               (a) divide_by_run_ids : the list of run ids is divided
                   into two, e.g. if we have runs 1-50, good-scoring
                   models from runs 1-25 are sample A and those from
                   runs 26-50 are sample B.
               (b) random : split the set of good-scoring models into
                   two subsets at random.
        '''
        sampleA_indices = []
        sampleB_indices = []

        if split_type == "divide_by_run_ids":  # split based on run ids

            half_num_runs = num_runs // 2
            for i, gsm in enumerate(self.all_good_scoring_models):
                if int(gsm[0]) <= half_num_runs:
                    sampleA_indices.append(i)
                else:
                    sampleB_indices.append(i)

        elif split_type == "random":
            sampleA_indices = random.sample(
                range(len(self.all_good_scoring_models)),
                len(self.all_good_scoring_models)//2)
            sampleB_indices = [
                i for i in range(len(self.all_good_scoring_models))
                if i not in sampleA_indices]

        # write model and sample IDs to a file; write it alongside the
        # models in output_dir (a hardcoded 'good_scoring_models' path
        # would not exist when extract=False)
        f = open(os.path.join(output_dir, 'model_sample_ids.txt'), 'w')

        # move models to corresponding sample directory
        sampleA_dir = os.path.join(output_dir, "sample_A")
        sampleB_dir = os.path.join(output_dir, "sample_B")
        os.mkdir(sampleA_dir)
        os.mkdir(sampleB_dir)

        for i in sampleA_indices:
            print(i, "A", file=f)
            # RMF files are present only if they were extracted
            rmf = os.path.join(output_dir, str(i)+'.rmf3')
            if os.path.exists(rmf):
                shutil.move(rmf, os.path.join(sampleA_dir, str(i)+'.rmf3'))
        for i in sampleB_indices:
            print(i, "B", file=f)
            rmf = os.path.join(output_dir, str(i)+'.rmf3')
            if os.path.exists(rmf):
                shutil.move(rmf, os.path.join(sampleB_dir, str(i)+'.rmf3'))
        f.close()
        return sampleA_indices, sampleB_indices
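A minimal end-to-end sketch, assuming the module is importable as
IMP.sampcon.good_scoring_model_selector and that runs live under ./modeling
as modeling_run1, modeling_run2, ... with a Total_Score field in their stat
files (the directory names, keyword, and thresholds are illustrative):

    from IMP.sampcon.good_scoring_model_selector \
        import GoodScoringModelSelector

    gsms = GoodScoringModelSelector('./modeling', 'modeling_run')
    sampleA, sampleB = gsms.get_good_scoring_models(
        selection_keywords_list=['Total_Score'],
        printing_keywords_list=[],
        aggregate_lower_thresholds=[0.0],
        aggregate_upper_thresholds=[100.0],
        member_lower_thresholds=[],
        member_upper_thresholds=[],
        extract=True)

Good-scoring models are extracted to ./modeling/good_scoring_models, split
into sample_A/ and sample_B/ subdirectories, with scores listed in
model_ids_scores.txt and sample assignments in model_sample_ids.txt.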