IMP logo
IMP Reference Guide  2.13.0
The Integrative Modeling Platform
good_scoring_model_selector.py
1 from __future__ import print_function, division
2 import IMP
3 import IMP.atom
4 import IMP.rmf
5 import RMF
6 import subprocess
7 from subprocess import Popen
8 import os,sys,string,math
9 import shutil
10 import random
11 import glob
12 
13 
# If we have new enough IMP/RMF, do our own RMF slicing with provenance
if hasattr(RMF.NodeHandle, 'replace_child'):
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        """Copy frame `frameid` of RMF file `infile` into a new RMF file `outfile`.

        The file info, hierarchy, static frame and the requested frame are
        cloned. If the first child of the root node has a PROVENANCE child,
        combine provenance (num_runs runs, total_num_frames frames) and then
        filter provenance (num_good_scoring frames, method "Best scoring")
        are added via NodeHandle.replace_child. Presumably each new node is
        placed above the previous provenance node; confirm against the RMF
        docs. If no such node exists, the frame is written without extra
        provenance.
        """
        inr = RMF.open_rmf_file_read_only(infile)
        outr = RMF.create_rmf_file(outfile)
        cpf = RMF.CombineProvenanceFactory(outr)
        fpf = RMF.FilterProvenanceFactory(outr)
        RMF.clone_file_info(inr, outr)
        RMF.clone_hierarchy(inr, outr)
        RMF.clone_static_frame(inr, outr)
        inr.set_current_frame(RMF.FrameID(frameid))
        outr.add_frame("f0")
        RMF.clone_loaded_frame(inr, outr)
        rn = outr.get_root_node()
        children = rn.get_children()
        if len(children) == 0:
            return
        rn = children[0]  # Should be the top-level IMP node
        prov = [c for c in rn.get_children()
                if c.get_type() == RMF.PROVENANCE]
        if not prov:
            return
        prov = prov[0]
        # Add combine-provenance info
        newp = rn.replace_child(prov, "combine", RMF.PROVENANCE)
        cp = cpf.get(newp)
        cp.set_frames(total_num_frames)
        cp.set_runs(num_runs)
        # Add filter-provenance info
        newp = rn.replace_child(newp, "filter", RMF.PROVENANCE)
        fp = fpf.get(newp)
        fp.set_frames(num_good_scoring)
        fp.set_method("Best scoring")
        # todo: put in a more appropriate value
        fp.set_threshold(0.)

# Otherwise, fall back to the rmf_slice command line tool
else:
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        """Copy frame `frameid` of `infile` into `outfile` using the external
        `rmf_slice` command-line tool.

        No provenance is added on this path, so num_runs, total_num_frames
        and num_good_scoring are accepted only for signature compatibility
        and are unused. The tool's output is discarded and its exit status
        is not checked.
        """
        # Close the devnull handle after the call; the previous version
        # leaked one open file per extracted model.
        with open(os.devnull, 'w') as fnull:
            subprocess.call(['rmf_slice', infile, "-f", str(frameid), outfile],
                            stdout=fnull, stderr=subprocess.STDOUT)
57 
58 
# Authors: Shruthi Viswanath

''' Select good-scoring models based on scores and/or data satisfaction.
Extract the corresponding RMFs and put them in a separate directory.
'''

def __init__(self, run_directory, run_prefix):
    """Constructor.
    @param run_directory the directory that holds one subdirectory per
           modeling run
    @param run_prefix the common name prefix of the run subdirectories,
           e.g. "modeling_run" for modeling_run1, modeling_run2, ...
    """
    # One (run id, replica id, frame id) tuple per good-scoring model;
    # populated by get_good_scoring_models().
    self.all_good_scoring_models = []
    self.run_prefix = run_prefix
    self.run_dir = run_directory
74 
def _get_subfields_for_criteria(self, field_headers, selection_keywords_list,
                                printing_keywords_list):
    ''' Map selection and printing keywords to stat-file field indices.

    @param field_headers mapping of field index -> field name, as read from
           the first line of a stat file
    @param selection_keywords_list keywords used to select models
    @param printing_keywords_list keywords whose values are only printed
    @return (selection_fields, printing_fields):
            selection_fields has one dict per selection keyword. A field
            whose name equals the keyword is stored as keyword -> index.
            Crosslink distance fields (names starting with
            "CrossLinkingMassSpectrometryRestraint_Distance_") that contain
            the keyword are grouped under the key
            (prot1, res1, prot2, res2), with the copy suffix after '.'
            removed from the protein names. That key maps to the list of all
            matching indices, i.e. every copy combination of an ambiguous
            crosslink.
            printing_fields has one index per printing keyword, or -1 if no
            field name equals that keyword.
    '''
    xl_prefix = "CrossLinkingMassSpectrometryRestraint_Distance_"
    selection_fields = [dict() for _ in selection_keywords_list]
    printing_fields = [-1] * len(printing_keywords_list)

    for header_index in field_headers:
        header = field_headers[header_index]

        for pos, keyword in enumerate(selection_keywords_list):
            lookup = selection_fields[pos]
            if header == keyword:
                # Exact match: plain score term
                lookup[keyword] = header_index
                continue
            if keyword not in header or not header.startswith(xl_prefix):
                continue
            # Ambiguous crosslink: collect every protein-copy combination
            prot1, res1, prot2, res2 = header.split("|")[3:7]
            xl_key = (prot1.split('.')[0], res1, prot2.split('.')[0], res2)
            lookup.setdefault(xl_key, []).append(header_index)

        for pos, keyword in enumerate(printing_keywords_list):
            if header == keyword:
                printing_fields[pos] = header_index

    return selection_fields, printing_fields
109 
def _get_crosslink_satisfaction(self, crosslink_distance_values,
                                crosslink_percentage_lower_threshold,
                                crosslink_percentage_upper_threshold,
                                xlink_distance_lower_threshold,
                                xlink_distance_upper_threshold):
    ''' For crosslinks, we want models with at least x% (e.g. 90%) crosslink
    satisfaction. A crosslink is satisfied if its distance lies between the
    lower and upper distance thresholds (both inclusive).

    The satisfaction is computed as a fraction in [0, 1], so the percentage
    thresholds must be given as fractions too (e.g. 0.9 for 90%).
    @param crosslink_distance_values values of distances in the current model
    @param crosslink_percentage_lower_threshold at least this fraction of
           crosslinks must be within the distance thresholds
    @param crosslink_percentage_upper_threshold at most this fraction of
           crosslinks may be within the distance thresholds (usually 1.0:
           dummy parameter)
    @param xlink_distance_lower_threshold a crosslink should be at least this
           distance apart (usually 0) to be considered satisfied
    @param xlink_distance_upper_threshold a crosslink should be at most this
           distance apart to be considered satisfied
    @return (fraction_satisfied, passes) where passes is True when the
            fraction lies within the percentage thresholds (inclusive)
    @raise ValueError if crosslink_distance_values is empty (previously an
           unexplained ZeroDivisionError)
    '''
    if len(crosslink_distance_values) == 0:
        raise ValueError(
            "No crosslink distances to evaluate; check that the selection "
            "keyword matches CrossLinkingMassSpectrometryRestraint_Distance_ "
            "fields in the stat file")

    num_satisfied = sum(1 for d in crosslink_distance_values
                        if xlink_distance_lower_threshold <= d
                        <= xlink_distance_upper_threshold)
    fraction_satisfied = num_satisfied / float(len(crosslink_distance_values))

    passes = (crosslink_percentage_lower_threshold <= fraction_satisfied
              <= crosslink_percentage_upper_threshold)
    return fraction_satisfied, passes
130 
131 
def _get_score_satisfaction(self, score, lower_threshold, upper_threshold):
    ''' Tell whether a score lies in the closed interval
    [lower_threshold, upper_threshold].
    @return True if lower_threshold <= score <= upper_threshold, else False
            (also False for a NaN score)
    '''
    within_bounds = lower_threshold <= score <= upper_threshold
    return bool(within_bounds)
138 
def _extract_models_from_trajectories(self, output_dir, num_runs,
                                      total_num_frames):
    '''Given the list of all good-scoring model indices, extract
    their frames and store them ordered by the list index.

    Model i of self.all_good_scoring_models is a (run id, replica id,
    frame id) tuple. It is read from
    <run_dir>/<run_prefix><run id>/output/rmfs/<replica id>.rmf3 and
    written to <output_dir>/<i>.rmf3 via rmf_slice().
    @param output_dir directory that receives the extracted RMF files
    @param num_runs number of runs analyzed (passed on to rmf_slice)
    @param total_num_frames number of frames read from all stat files
           (passed on to rmf_slice)
    '''
    num_gsm = len(self.all_good_scoring_models)
    print("Extracting", num_gsm, "good scoring models.")
    # Report progress about every 10% with an integer step. The previous
    # float test `model_num % (num_gsm/10) == 0` often never became exactly
    # zero (e.g. for 3 or 7 models), so no progress was printed at all.
    progress_step = max(1, num_gsm // 10)

    for i, (runid, replicaid, frameid) in enumerate(
            self.all_good_scoring_models):
        model_num = i + 1
        if model_num % progress_step == 0:
            print("%d%% Complete" % (100 * model_num // num_gsm))

        trajfile = os.path.join(self.run_dir, self.run_prefix + runid,
                                'output', 'rmfs', replicaid + '.rmf3')
        rmf_slice(trajfile, frameid,
                  os.path.join(output_dir, str(i) + '.rmf3'),
                  num_runs, total_num_frames, num_gsm)
165 
166 
def get_good_scoring_models(self, selection_keywords_list=None,
                            printing_keywords_list=None,
                            aggregate_lower_thresholds=None,
                            aggregate_upper_thresholds=None,
                            member_lower_thresholds=None,
                            member_upper_thresholds=None, extract=False):
    ''' Loops over all stat files in the run directory and populates the list of good-scoring models.
    @param selection_keywords_list is the list of keywords in the PMI stat file that need to be checked for each datatype/score in the criteria list
    @param printing_keywords_list is the list of keywords in the PMI stat file whose values needs to be printed for selected models
    @param aggregate_lower_thresholds The list of lower bounds on the values corresponding to fields in the criteria_list. Aggregates are used for terms like % of crosslink satisfaction (given as a fraction, e.g. 0.9) and thresholds of score terms
    @param aggregate_upper_thresholds The list of upper bounds on the values corresponding to fields in the criteria_list. Aggregates are used for terms like % of crosslink satisfaction (given as a fraction, e.g. 1.0) and thresholds of score terms
    @param member_lower_thresholds The list of lower bounds for values of subcomponents of an aggregate term. E.g. for crosslink satisfaction the thresholds are on distances for each individual crosslink. For score terms this can be ignored since thresholds are mentioned in the aggregate fields.
    @param member_upper_thresholds The list of upper bounds for values of subcomponents of an aggregate term. E.g. for crosslink satisfaction the thresholds are on distances for each individual crosslink. For score terms this can be ignored since thresholds are mentioned in the aggregate fields.
    @param extract if True, the RMF frame of every good-scoring model is
           extracted into <run_directory>/good_scoring_models; otherwise
           only the score table is written, to <run_directory>/filter
    @return the (sample A, sample B) lists of good-scoring model indices
            returned by _split_good_scoring_models_into_two_subsets()

    Each list argument defaults to an empty list. A model is kept only if it
    satisfies every selection criterion, so with no selection keywords no
    model is kept. The output directory is deleted and recreated on each
    call.
    '''
    def _or_empty(value):
        # Replace a None default with a fresh empty list (avoids the shared
        # mutable [] defaults of the previous signature).
        return [] if value is None else value

    selection_keywords_list = list(_or_empty(selection_keywords_list))
    printing_keywords_list = list(_or_empty(printing_keywords_list))
    aggregate_lower_thresholds = _or_empty(aggregate_lower_thresholds)
    aggregate_upper_thresholds = _or_empty(aggregate_upper_thresholds)
    member_lower_thresholds = _or_empty(member_lower_thresholds)
    member_upper_thresholds = _or_empty(member_upper_thresholds)

    def _numbered_paths(prefix, suffix=""):
        # Return (id_string, path) pairs for paths matching
        # prefix + <integer> + suffix, sorted numerically. The id is the text
        # between prefix and suffix. Entries whose id is not an integer are
        # skipped rather than crashing int().
        found = []
        for path in glob.glob(prefix + "*" + suffix):
            id_string = path[len(prefix):len(path) - len(suffix)]
            try:
                found.append((int(id_string), id_string, path))
            except ValueError:
                continue
        found.sort()
        return [(id_string, path) for _, id_string, path in found]

    def _evaluate_model(dat, fields_for_selection):
        # Check one model (a parsed stat-file line) against every selection
        # criterion, stopping at the first failure. Returns
        # (satisfies_all, values of the criteria checked so far).
        values = []
        satisfies = False
        for si, score_type in enumerate(selection_keywords_list):
            if ("crosslink" in score_type.lower()
                    and "distance" in score_type.lower()):
                # An ambiguous crosslink is represented by the shortest
                # distance over all of its protein-copy combinations.
                distances = [min(float(dat[j]) for j in indices)
                             for indices in fields_for_selection[si].values()]
                fraction, satisfies = self._get_crosslink_satisfaction(
                    distances, aggregate_lower_thresholds[si],
                    aggregate_upper_thresholds[si],
                    member_lower_thresholds[si],
                    member_upper_thresholds[si])
                values.append(fraction)
            else:
                score_value = float(dat[fields_for_selection[si][score_type]])
                satisfies = self._get_score_satisfaction(
                    score_value, aggregate_lower_thresholds[si],
                    aggregate_upper_thresholds[si])
                values.append(score_value)
            if not satisfies:
                break
        return satisfies, values

    if extract:
        output_dir = os.path.join(self.run_dir, "good_scoring_models")
    else:
        output_dir = os.path.join(self.run_dir, "filter")
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir, ignore_errors=True)
    os.mkdir(output_dir)

    num_runs = 0
    total_num_frames = 0

    with open(os.path.join(output_dir, "model_ids_scores.txt"), 'w') as outf:
        # Header line first
        print(' '.join(["Model_index", "Run_id", "Replica_id", "Frame_id"]
                       + selection_keywords_list + printing_keywords_list),
              file=outf)

        # Run directories are <run_dir>/<run_prefix><integer>. The run id is
        # the text after that prefix. Splitting the whole path on run_prefix
        # (as before) broke when run_dir itself contained run_prefix.
        for runid, each_run_dir in _numbered_paths(
                os.path.join(self.run_dir, self.run_prefix)):
            if not os.path.isdir(each_run_dir):
                continue
            num_runs += 1
            print("Analyzing", runid)

            # Replica stat files are <run>/output/stat.<integer>.out. A suffix
            # slice is used because str.strip('.out') strips a set of
            # characters, not the suffix.
            for replicaid, stat_file in _numbered_paths(
                    os.path.join(each_run_dir, "output", "stat."), ".out"):
                with open(stat_file, 'r') as rsf:
                    for line_index, each_model_line in enumerate(rsf):
                        # NOTE(review): eval() executes whatever the stat
                        # file contains. If stat files only ever hold Python
                        # literals, ast.literal_eval would be a safer
                        # drop-in; confirm before switching.
                        if line_index == 0:
                            # Header line: presumably a dict of field
                            # index -> field name (see
                            # _get_subfields_for_criteria)
                            field_headers = eval(each_model_line.strip())
                            (fields_for_selection,
                             fields_for_printing) = \
                                self._get_subfields_for_criteria(
                                    field_headers, selection_keywords_list,
                                    printing_keywords_list)
                            continue

                        frameid = line_index - 1
                        dat = eval(each_model_line.strip())
                        total_num_frames += 1

                        model_satisfies, selection_criteria_values = \
                            _evaluate_model(dat, fields_for_selection)
                        if not model_satisfies:
                            continue

                        # Now get the printing criteria
                        printing_criteria_values = [
                            float(dat[j]) for j in fields_for_printing]

                        self.all_good_scoring_models.append(
                            (runid, replicaid, frameid))

                        # Print out the scores finally
                        print(' '.join(
                            [str(x) for x in
                             [len(self.all_good_scoring_models) - 1,
                              runid, replicaid, frameid]]
                            + ["%.2f" % s for s in selection_criteria_values]
                            + ["%.2f" % s for s in printing_criteria_values]),
                            file=outf)

    if extract:
        self._extract_models_from_trajectories(output_dir, num_runs,
                                               total_num_frames)

    return self._split_good_scoring_models_into_two_subsets(
        output_dir, num_runs,
        split_type="divide_by_run_ids" if num_runs > 1 else "random")
277 
278 
def _split_good_scoring_models_into_two_subsets(self, output_dir, num_runs,
                                                split_type="divide_by_run_ids"):
    ''' Get the list of good scoring models and split them into two samples,
    keeping the models in separate directories. Return the two subsets.

    Creates <output_dir>/sample_A and <output_dir>/sample_B. Any extracted
    <output_dir>/<index>.rmf3 files are moved into them. The index -> sample
    assignment is written to <output_dir>/model_sample_ids.txt.
    @param output_dir directory holding the (optionally) extracted RMFs
    @param num_runs number of runs analyzed
    @param split_type how to split good scoring models into two samples. Current options are:
    (a) divide_by_run_ids : where the list of runids is divided into 2. e.g. if we have runs from 1-50, good scoring models from runs
    1-25 is sample A and those from runs 26-50 is sample B.
    (b) random : split the set of good scoring models into two subsets at random.
    @return (sampleA_indices, sampleB_indices)
    @raise ValueError if split_type is not one of the options above
    '''
    num_models = len(self.all_good_scoring_models)

    if split_type == "divide_by_run_ids":  # split based on run ids
        half_num_runs = num_runs / 2
        sampleA_indices = []
        sampleB_indices = []
        for i, gsm in enumerate(self.all_good_scoring_models):
            if int(gsm[0]) <= half_num_runs:
                sampleA_indices.append(i)
            else:
                sampleB_indices.append(i)
    elif split_type == "random":
        sampleA_indices = random.sample(range(num_models), num_models // 2)
        # Set membership keeps the complement O(n) instead of O(n^2)
        sampleA_set = set(sampleA_indices)
        sampleB_indices = [i for i in range(num_models)
                           if i not in sampleA_set]
    else:
        raise ValueError("Unknown split_type %r; expected "
                         "'divide_by_run_ids' or 'random'" % (split_type,))

    # move models to corresponding sample directory
    sampleA_dir = os.path.join(output_dir, "sample_A")
    sampleB_dir = os.path.join(output_dir, "sample_B")
    os.mkdir(sampleA_dir)
    os.mkdir(sampleB_dir)

    # Write model and sample IDs next to the models. This used to be
    # hard-coded to <run_dir>/good_scoring_models, which broke (or wrote into
    # a stale directory) when output_dir was <run_dir>/filter.
    with open(os.path.join(output_dir, 'model_sample_ids.txt'), 'w') as f:
        for indices, label, sample_dir in ((sampleA_indices, "A", sampleA_dir),
                                           (sampleB_indices, "B", sampleB_dir)):
            for i in indices:
                print(i, label, file=f)
                rmf = os.path.join(output_dir, str(i) + '.rmf3')
                # RMFs exist only if they were extracted beforehand
                if os.path.exists(rmf):
                    shutil.move(rmf, os.path.join(sample_dir, str(i) + '.rmf3'))

    return sampleA_indices, sampleB_indices
def get_good_scoring_models
Loops over all stat files in the run directory and populates the list of good-scoring models...
Select good-scoring models based on scores and/or data satisfaction.
Functionality for loading, creating, manipulating and scoring atomic structures.
Support for the RMF file format for storing hierarchical molecular data and markup.