IMP Reference Guide, develop.63b38c487d, 2024/12/22
The Integrative Modeling Platform
good_scoring_model_selector.py
1 """@namespace IMP.sampcon.good_scoring_model_selector
2  Select good-scoring models based on scores and/or data satisfaction."""
3 
4 from __future__ import print_function, division
5 import RMF
6 import subprocess
7 import os
8 import shutil
9 import random
10 import glob
11 import operator
12 

# If we have new enough IMP/RMF, do our own RMF slicing with provenance
if hasattr(RMF.NodeHandle, 'replace_child'):
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        inr = RMF.open_rmf_file_read_only(infile)
        outr = RMF.create_rmf_file(outfile)
        cpf = RMF.CombineProvenanceFactory(outr)
        fpf = RMF.FilterProvenanceFactory(outr)
        RMF.clone_file_info(inr, outr)
        RMF.clone_hierarchy(inr, outr)
        RMF.clone_static_frame(inr, outr)
        inr.set_current_frame(RMF.FrameID(frameid))
        outr.add_frame("f0")
        RMF.clone_loaded_frame(inr, outr)
        rn = outr.get_root_node()
        children = rn.get_children()
        if len(children) == 0:
            return
        rn = children[0]   # Should be the top-level IMP node
        prov = [c for c in rn.get_children() if c.get_type() == RMF.PROVENANCE]
        if not prov:
            return
        prov = prov[0]
        # Add combine-provenance info
        newp = rn.replace_child(prov, "combine", RMF.PROVENANCE)
        cp = cpf.get(newp)
        cp.set_frames(total_num_frames)
        cp.set_runs(num_runs)
        # Add filter-provenance info
        newp = rn.replace_child(newp, "filter", RMF.PROVENANCE)
        fp = fpf.get(newp)
        fp.set_frames(num_good_scoring)
        fp.set_method("Best scoring")
        # todo: put in a more appropriate value
        fp.set_threshold(0.)

# Otherwise, fall back to the rmf_slice command line tool
else:
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        FNULL = open(os.devnull, 'w')
        subprocess.call(['rmf_slice', infile, "-f", str(frameid), outfile],
                        stdout=FNULL, stderr=subprocess.STDOUT)


class GoodScoringModelSelector(object):
    # Authors: Shruthi Viswanath

    ''' Select good-scoring models based on scores and/or data satisfaction.
    Extract the corresponding RMFs and put them in a separate directory.
    '''

    def __init__(self, run_directory, run_prefix):
        """Constructor.
        @param run_directory the directory containing subdirectories of runs
        @param run_prefix the prefix for each run directory. E.g. if the
               subdirectories are modeling_run1, modeling_run2, etc. the
               prefix is modeling_run
        """
        self.run_dir = run_directory
        self.run_prefix = run_prefix

        # list with each member as a tuple (run id, replica id, frame id)
        # corresponding to good-scoring models
        self.all_good_scoring_models = []

    def _all_run_dirs(self):
        """Yield (pathname, runid) for all run directories (unsorted)"""
        for x in os.listdir(self.run_dir):
            if x.startswith(self.run_prefix):
                runid = x[len(self.run_prefix):]
                fullpath = os.path.join(self.run_dir, x)
                if os.path.isdir(fullpath) and runid.isdigit():
                    yield (fullpath, runid)

    def _get_subfields_for_criteria(
            self, field_headers, selection_keywords_list,
            printing_keywords_list):
        '''Given the list of keywords, get all the stat file entries
        corresponding to each keyword.'''

        # list of dicts corresponding to field indices for each keyword
        selection_fields = [{} for kw in selection_keywords_list]
        # for ambiguous crosslink distances, it will store all restraints
        # corresponding to the ambiguity in one dict.

        # just a placeholder
        printing_fields = [-1 for j in range(len(printing_keywords_list))]

        for fh_index in field_headers:

            for ki, kw in enumerate(selection_keywords_list):
                # need exact name of field unless it is a crosslink distance
                if kw == field_headers[fh_index]:
                    selection_fields[ki][kw] = fh_index

                elif kw in field_headers[fh_index] and \
                        field_headers[fh_index].startswith(
                            "CrossLinkingMassSpectrometry"
                            "Restraint_Distance_"):
                    # handle ambiguous restraints
                    (prot1, res1, prot2, res2) = \
                        field_headers[fh_index].split("|")[3:7]
                    prot1 = prot1.split('.')[0]
                    prot2 = prot2.split('.')[0]

                    if (prot1, res1, prot2, res2) in selection_fields[ki]:
                        selection_fields[ki][
                            (prot1, res1, prot2, res2)].append(fh_index)
                    else:
                        # list of indices corresponding to all
                        # combinations of protein copies
                        selection_fields[ki][(prot1, res1, prot2, res2)] = \
                            [fh_index]

            for ki, kw in enumerate(printing_keywords_list):
                if kw == field_headers[fh_index]:
                    printing_fields[ki] = fh_index

        return selection_fields, printing_fields

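    # Illustration of the ambiguous-crosslink handling above. Assuming
    # '|'-separated headers of the following form (the layout is
    # hypothetical beyond the [3:7] slice used above), two headers such as
    #   CrossLinkingMassSpectrometryRestraint_Distance_|xl|1|ProtA.0|34|ProtB.0|56|...
    #   CrossLinkingMassSpectrometryRestraint_Distance_|xl|1|ProtA.1|34|ProtB.0|56|...
    # both map to the key ('ProtA', '34', 'ProtB', '56'), since the copy
    # suffix after '.' is stripped; the minimum distance over all copy
    # combinations is then used in get_good_scoring_models.
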
    def _get_crosslink_satisfaction(
            self, crosslink_distance_values,
            crosslink_percentage_lower_threshold,
            crosslink_percentage_upper_threshold,
            xlink_distance_lower_threshold, xlink_distance_upper_threshold):
        '''For crosslinks, we want models with at least x% (e.g. 90%)
        crosslink satisfaction. A crosslink is satisfied if its distance
        is between the lower and upper distance thresholds.
        @param crosslink_distance_values values of distances in the
               current model
        @param crosslink_percentage_lower_threshold at least x% of crosslinks
               should be within the below distance thresholds
        @param crosslink_percentage_upper_threshold at most x% of crosslinks
               should be within the below distance thresholds
               (usually 100%: dummy parameter)
        @param xlink_distance_lower_threshold a crosslink should be at least
               this distance apart (usually 0) to be considered satisfied
        @param xlink_distance_upper_threshold a crosslink should be at most
               this distance apart to be considered satisfied
        '''
        satisfied_xlinks = 0.0
        for d in crosslink_distance_values:
            if d >= xlink_distance_lower_threshold \
                    and d <= xlink_distance_upper_threshold:
                satisfied_xlinks += 1.0

        percent_satisfied = \
            satisfied_xlinks/float(len(crosslink_distance_values))

        if percent_satisfied >= crosslink_percentage_lower_threshold \
                and percent_satisfied <= crosslink_percentage_upper_threshold:
            return percent_satisfied, True
        else:
            return percent_satisfied, False

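    # Worked example: for distances [5.0, 12.0, 28.0, 40.0] with distance
    # thresholds 0-35, 3 of 4 crosslinks are satisfied, so this returns
    # (0.75, True) for percentage thresholds of e.g. 0.7-1.0, and
    # (0.75, False) for e.g. 0.9-1.0.
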
    def _get_score_satisfaction(self, score, lower_threshold, upper_threshold):
        '''Check if the score is within the thresholds'''
        return lower_threshold <= score <= upper_threshold

    def _extract_models_from_trajectories(self, output_dir, num_runs,
                                          total_num_frames, sampleA_set):
        '''Given the list of all good-scoring model indices, extract
        their frames and store them ordered by the list index.
        Store the models in two sample subdirectories.'''
        sampleA_dir = os.path.join(output_dir, "sample_A")
        sampleB_dir = os.path.join(output_dir, "sample_B")
        os.mkdir(sampleA_dir)
        os.mkdir(sampleB_dir)

        num_gsm = len(self.all_good_scoring_models)
        print("Extracting", num_gsm, "good scoring models.")
        # Report progress in roughly 10% increments; the max() guards
        # against a division by zero when there are fewer than 10 models
        progress_step = max(1, num_gsm // 10)
        model_num = 1

        for i, gsm in enumerate(self.all_good_scoring_models):
            if model_num % progress_step == 0:
                print("%d%% Complete" % (model_num * 100 // num_gsm))
            model_num += 1

            (runid, replicaid, frameid) = gsm

            trajfile = os.path.join(self.run_dir, self.run_prefix+runid,
                                    'output', 'rmfs', replicaid+'.rmf3')

            sample_dir = sampleA_dir if i in sampleA_set else sampleB_dir
            rmf_slice(trajfile, frameid,
                      os.path.join(sample_dir, str(i)+'.rmf3'),
                      num_runs, total_num_frames,
                      len(self.all_good_scoring_models))

    def get_good_scoring_models(
            self, selection_keywords_list=[], printing_keywords_list=[],
            aggregate_lower_thresholds=[], aggregate_upper_thresholds=[],
            member_lower_thresholds=[], member_upper_thresholds=[],
            extract=False):
        '''Loops over all stat files in the run directory and populates
        the list of good-scoring models.
        @param selection_keywords_list is the list of keywords in the PMI
               stat file that need to be checked for each datatype/score
               in the criteria list
        @param printing_keywords_list is the list of keywords in the PMI
               stat file whose values need to be printed for selected models
        @param aggregate_lower_thresholds The list of lower bounds on the
               values corresponding to fields in the criteria_list.
               Aggregates are used for terms like % of crosslink satisfaction
               and thresholds of score terms
        @param aggregate_upper_thresholds The list of upper bounds on the
               values corresponding to fields in the criteria_list.
               Aggregates are used for terms like % of crosslink satisfaction
               and thresholds of score terms
        @param member_lower_thresholds The list of lower bounds for values
               of subcomponents of an aggregate term. E.g. for crosslink
               satisfaction the thresholds are on distances for each
               individual crosslink. For score terms this can be ignored
               since thresholds are mentioned in the aggregate fields.
        @param member_upper_thresholds The list of upper bounds for values
               of subcomponents of an aggregate term. E.g. for crosslink
               satisfaction the thresholds are on distances for each
               individual crosslink. For score terms this can be ignored
               since thresholds are mentioned in the aggregate fields.
        '''
        if extract:
            output_dir = os.path.join(self.run_dir, "good_scoring_models")
        else:
            output_dir = os.path.join(self.run_dir, "filter")
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir, ignore_errors=True)
        os.mkdir(output_dir)

        outf = open(os.path.join(output_dir, "model_ids_scores.txt"), 'w')
        # Header line first
        print(' '.join(["Model_index", "Run_id", "Replica_id", "Frame_id"]
                       + selection_keywords_list + printing_keywords_list),
              file=outf)

        num_runs = 0
        total_num_frames = 0

        for each_run_dir, runid in sorted(self._all_run_dirs(),
                                          key=operator.itemgetter(1)):
            num_runs += 1

            print("Analyzing", runid)

            for each_replica_stat_file in sorted(
                    glob.glob(os.path.join(each_run_dir, "output",
                                           "stat.*.out")),
                    key=lambda x: int(x.strip('.out').split('.')[-1])):

                replicaid = each_replica_stat_file.strip(".out").split(".")[-1]

                rsf = open(each_replica_stat_file, 'r')

                for line_index, each_model_line in enumerate(rsf.readlines()):
                    # for each model in the current replica

                    if line_index == 0:
                        field_headers = eval(each_model_line.strip())
                        fields_for_selection, fields_for_printing = \
                            self._get_subfields_for_criteria(
                                field_headers, selection_keywords_list,
                                printing_keywords_list)
                        continue

                    frameid = line_index-1

                    dat = eval(each_model_line.strip())

                    total_num_frames += 1
                    model_satisfies = False
                    selection_criteria_values = []

                    for si, score_type in enumerate(selection_keywords_list):
                        if "crosslink" in score_type.lower() \
                                and "distance" in score_type.lower():

                            crosslink_distance_values = []

                            for xltype in fields_for_selection[si]:
                                crosslink_distance_values.append(
                                    min([float(dat[j]) for j in
                                         fields_for_selection[si][xltype]]))

                            satisfied_percent, model_satisfies = \
                                self._get_crosslink_satisfaction(
                                    crosslink_distance_values,
                                    aggregate_lower_thresholds[si],
                                    aggregate_upper_thresholds[si],
                                    member_lower_thresholds[si],
                                    member_upper_thresholds[si])

                            selection_criteria_values.append(satisfied_percent)

                        else:
                            score_value = float(
                                dat[fields_for_selection[si][score_type]])

                            model_satisfies = self._get_score_satisfaction(
                                score_value, aggregate_lower_thresholds[si],
                                aggregate_upper_thresholds[si])
                            selection_criteria_values.append(score_value)

                        if not model_satisfies:
                            break

                    if model_satisfies:

                        # Now get the printing criteria
                        printing_criteria_values = []
                        for si, score_type in enumerate(
                                printing_keywords_list):
                            if fields_for_printing[si] < 0:
                                raise KeyError(
                                    "Bad stat file key '%s': use `imp_sampcon "
                                    "show_stat <stat file>` to see all "
                                    "acceptable keys" % score_type)
                            score_value = float(dat[fields_for_printing[si]])
                            printing_criteria_values.append(score_value)

                        self.all_good_scoring_models.append(
                            (runid, replicaid, frameid))

                        # Print out the scores finally

                        print(' '.join(
                            [str(x) for x in
                             [len(self.all_good_scoring_models) - 1,
                              runid, replicaid, frameid]] +
                            ["%.2f" % s for s in selection_criteria_values] +
                            ["%.2f" % s for s in printing_criteria_values]),
                            file=outf)

                rsf.close()
        outf.close()

        if extract:
            sampA, sampB = self._split_good_scoring_models_into_two_subsets(
                num_runs,
                split_type="divide_by_run_ids" if num_runs > 1 else "random")
            self._extract_models_from_trajectories(
                output_dir, num_runs, total_num_frames, frozenset(sampA))
            return sampA, sampB

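    # model_ids_scores.txt holds one row per selected model. E.g. with
    # selection_keywords_list=["Total_Score"] (the key and the values below
    # are illustrative):
    #   Model_index Run_id Replica_id Frame_id Total_Score
    #   0 1 3 122 245.18
    #   1 1 5 87 251.02
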
    def _split_good_scoring_models_into_two_subsets(
            self, num_runs, split_type="divide_by_run_ids"):
        '''Get the list of good scoring models and split them into two
        samples. Return the two subsets.

        @param split_type how to split good scoring models into two samples.
               Current options are:
               (a) divide_by_run_ids : the list of run ids is divided in
                   two. E.g. if we have runs 1-50, good scoring models
                   from runs 1-25 form sample A and those from runs 26-50
                   form sample B.
               (b) random : split the set of good scoring models into
                   two subsets at random.
        '''
        sampleA_indices = []
        sampleB_indices = []

        if split_type == "divide_by_run_ids":  # split based on run ids

            half_num_runs = num_runs/2
            for i, gsm in enumerate(self.all_good_scoring_models):
                if int(gsm[0]) <= half_num_runs:
                    sampleA_indices.append(i)
                else:
                    sampleB_indices.append(i)

        elif split_type == "random":
            sampleA_indices = random.sample(
                range(len(self.all_good_scoring_models)),
                len(self.all_good_scoring_models)//2)
            sampleB_indices = [
                i for i in range(len(self.all_good_scoring_models))
                if i not in sampleA_indices]

        # write model and sample IDs to a file
        with open(os.path.join(self.run_dir, 'good_scoring_models',
                               'model_sample_ids.txt'), 'w') as f:
            for i in sampleA_indices:
                print(i, "A", file=f)
            for i in sampleB_indices:
                print(i, "B", file=f)
        return sampleA_indices, sampleB_indices
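

# Example usage (a minimal sketch; the run layout and the "Total_Score"
# keyword are assumptions; run `imp_sampcon show_stat <stat file>` to list
# the keys actually present in your stat files):
#
#   gsms = GoodScoringModelSelector("./modeling", "run")
#   sampA, sampB = gsms.get_good_scoring_models(
#       selection_keywords_list=["Total_Score"],
#       printing_keywords_list=["Total_Score"],
#       aggregate_lower_thresholds=[0.0],
#       aggregate_upper_thresholds=[10000.0],
#       member_lower_thresholds=[0.0],
#       member_upper_thresholds=[0.0],
#       extract=True)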