1 """@namespace IMP.sampcon.good_scoring_model_selector
2 Select good-scoring models based on scores and/or data satisfaction."""
from __future__ import print_function, division

import glob
import operator
import os
import random
import shutil
import subprocess

import RMF
if hasattr(RMF.NodeHandle, 'replace_child'):
    def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
                  num_good_scoring):
        """Copy a single frame from `infile` into a new RMF file `outfile`,
        recording combine- and filter-provenance.

        @param infile input RMF trajectory file name
        @param frameid index of the frame to extract
        @param outfile output RMF file containing just that frame
        @param num_runs number of sampling runs that were combined
               (stored as combine provenance)
        @param total_num_frames total number of frames over all runs
               (stored as combine provenance)
        @param num_good_scoring number of good-scoring models kept
               (stored as filter provenance)
        """
        inr = RMF.open_rmf_file_read_only(infile)
        outr = RMF.create_rmf_file(outfile)
        cpf = RMF.CombineProvenanceFactory(outr)
        fpf = RMF.FilterProvenanceFactory(outr)
        RMF.clone_file_info(inr, outr)
        RMF.clone_hierarchy(inr, outr)
        RMF.clone_static_frame(inr, outr)
        inr.set_current_frame(RMF.FrameID(frameid))
        outr.add_frame("f0")
        RMF.clone_loaded_frame(inr, outr)

        rn = outr.get_root_node()
        children = rn.get_children()
        if len(children) == 0:
            return  # no hierarchy under the root; nothing to annotate
        # provenance hangs off the top node of the hierarchy, not the root
        rn = children[0]
        prov = [c for c in rn.get_children()
                if c.get_type() == RMF.PROVENANCE]
        if not prov:
            return
        prov = prov[0]
        # Add combine-provenance describing the merged runs
        newp = rn.replace_child(prov, "combine", RMF.PROVENANCE)
        cp = cpf.get(newp)
        cp.set_frames(total_num_frames)
        cp.set_runs(num_runs)
        # Add filter-provenance describing the good-scoring selection
        newp = rn.replace_child(newp, "filter", RMF.PROVENANCE)
        fp = fpf.get(newp)
        fp.set_frames(num_good_scoring)
        fp.set_method("Best scoring")
52 def rmf_slice(infile, frameid, outfile, num_runs, total_num_frames,
54 FNULL = open(os.devnull,
'w')
55 subprocess.call([
'rmf_slice', infile,
"-f", str(frameid), outfile],
56 stdout=FNULL, stderr=subprocess.STDOUT)
''' Select good-scoring models based on scores and/or data satisfaction.
Extract the corresponding RMFs and put them in a separate directory
68 @param run_directory the directory containing subdirectories of runs
69 @param run_prefix the prefix for each run directory. For e.g. if the
70 subdirectories are modeling_run1, modeling_run2, etc. the
71 prefix is modeling_run
73 self.run_dir = run_directory
74 self.run_prefix = run_prefix
78 self.all_good_scoring_models = []
80 def _all_run_dirs(self):
81 """Yield (pathname, runid) for all run directories (unsorted)"""
82 for x
in os.listdir(self.run_dir):
83 if x.startswith(self.run_prefix):
84 runid = x[len(self.run_prefix):]
85 fullpath = os.path.join(self.run_dir, x)
86 if os.path.isdir(fullpath)
and runid.isdigit():
87 yield (fullpath, runid)
89 def _get_subfields_for_criteria(
90 self, field_headers, selection_keywords_list,
91 printing_keywords_list):
92 '''Given the list of keywords, get all the stat file entries
93 corresponding to each keyword.'''
96 selection_fields = [{}
for kw
in selection_keywords_list]
101 printing_fields = [-1
for j
in range(len(printing_keywords_list))]
103 for fh_index
in field_headers:
105 for ki, kw
in enumerate(selection_keywords_list):
107 if kw == field_headers[fh_index]:
108 selection_fields[ki][kw] = fh_index
110 elif kw
in field_headers[fh_index]
and \
111 field_headers[fh_index].startswith(
112 "CrossLinkingMassSpectrometry"
113 "Restraint_Distance_"):
115 (prot1, res1, prot2, res2) = \
116 field_headers[fh_index].split(
"|")[3:7]
117 prot1 = prot1.split(
'.')[0]
118 prot2 = prot2.split(
'.')[0]
120 if (prot1, res1, prot2, res2)
in selection_fields[ki]:
121 selection_fields[ki][
122 (prot1, res1, prot2, res2)].append(fh_index)
126 selection_fields[ki][(prot1, res1, prot2, res2)] = \
129 for ki, kw
in enumerate(printing_keywords_list):
130 if kw == field_headers[fh_index]:
131 printing_fields[ki] = fh_index
133 return selection_fields, printing_fields
135 def _get_crosslink_satisfaction(
136 self, crosslink_distance_values,
137 crosslink_percentage_lower_threshold,
138 crosslink_percentage_upper_threshold,
139 xlink_distance_lower_threshold, xlink_distance_upper_threshold):
140 '''For crosslinks, we want models with at least x% (e.g. 90%) or more
141 crosslink satisfaction. A crosslink is satisfied if the distance
142 is between the lower and upper distance thresholds
143 @param crosslink_distance_values values of distances in the
145 @param crosslink_percentage_lower_threshold at least x% of crosslinks
146 should be within the below distance thresholds
147 @param crosslink_percentage_upper_threshold at most x% of crosslinks
148 should be within the below distance thresholds
149 (usually 100%: dummy parameter)
150 @param xlink_distance_lower_threshold a crosslink should be at least
151 this distance apart (usually 0) to be considered satisfied
152 @param xlink_distance_upper_threshold a crosslink should be at most
153 this distance apart to be considered satisfied
155 satisfied_xlinks = 0.0
156 for d
in crosslink_distance_values:
157 if d >= xlink_distance_lower_threshold \
158 and d <= xlink_distance_upper_threshold:
159 satisfied_xlinks += 1.0
161 percent_satisfied = \
162 satisfied_xlinks/float(len(crosslink_distance_values))
164 if percent_satisfied >= crosslink_percentage_lower_threshold \
165 and percent_satisfied <= crosslink_percentage_upper_threshold:
166 return percent_satisfied,
True
168 return percent_satisfied,
False
170 def _get_score_satisfaction(self, score, lower_threshold, upper_threshold):
171 ''' Check if the score is within the thresholds
173 if score <= upper_threshold
and score >= lower_threshold:
177 def _extract_models_from_trajectories(self, output_dir, num_runs,
178 total_num_frames, sampleA_set):
179 '''Given the list of all good-scoring model indices, extract
180 their frames and store them ordered by the list index.
181 Store the models in two sample subdirectories.'''
182 sampleA_dir = os.path.join(output_dir,
"sample_A")
183 sampleB_dir = os.path.join(output_dir,
"sample_B")
184 os.mkdir(sampleA_dir)
185 os.mkdir(sampleB_dir)
187 num_gsm = sum(1
for e
in self.all_good_scoring_models)
188 print(
"Extracting", num_gsm,
"good scoring models.")
191 for i, gsm
in enumerate(self.all_good_scoring_models):
192 if model_num % (num_gsm/10) == 0:
193 print(str(model_num / (num_gsm/10)*10)+
"% Complete")
196 (runid, replicaid, frameid) = gsm
198 trajfile = os.path.join(self.run_dir, self.run_prefix+runid,
199 'output',
'rmfs', replicaid+
'.rmf3')
201 sample_dir = sampleA_dir
if i
in sampleA_set
else sampleB_dir
202 rmf_slice(trajfile, frameid,
203 os.path.join(sample_dir, str(i)+
'.rmf3'),
204 num_runs, total_num_frames,
205 len(self.all_good_scoring_models))
208 self, selection_keywords_list=[], printing_keywords_list=[],
209 aggregate_lower_thresholds=[], aggregate_upper_thresholds=[],
210 member_lower_thresholds=[], member_upper_thresholds=[],
212 ''' Loops over all stat files in the run directory and populates
213 the list of good-scoring models.
214 @param selection_keywords_list is the list of keywords in the PMI
215 stat file that need to be checked for each datatype/score
217 @param printing_keywords_list is the list of keywords in the PMI
218 stat file whose values needs to be printed for selected models
219 @param aggregate_lower_thresholds The list of lower bounds on the
220 values corresponding to fields in the criteria_list.
221 Aggregates are used for terms like % of crosslink satisfaction
222 and thresholds of score terms
223 @param aggregate_upper_thresholds The list of upper bounds on the
224 values corresponding to fields in the criteria_list.
225 Aggregates are used for terms like % of crosslink satisfaction
226 and thresholds of score terms
227 @param member_lower_thresholds The list of lower bounds for values
228 of subcomponents of an aggregate term. E.g. for crosslink
229 satisfaction the thresholds are on distances for each
230 individual crosslink. For score terms this can be ignored
231 since thresholds are mentioned in the aggregate fields.
232 @param member_upper_thresholds The list of upper bounds for values
233 of subcomponents of an aggregate term. E.g. for crosslink
234 satisfaction the thresholds are on distances for each
235 individual crosslink. For score terms this can be ignored
236 since thresholds are mentioned in the aggregate fields.
239 output_dir = os.path.join(self.run_dir,
"good_scoring_models")
241 output_dir = os.path.join(self.run_dir,
"filter")
242 if os.path.exists(output_dir):
243 shutil.rmtree(output_dir, ignore_errors=
True)
246 outf = open(os.path.join(output_dir,
"model_ids_scores.txt"),
'w')
248 print(
' '.join([
"Model_index",
"Run_id",
"Replica_id",
"Frame_id"]
249 + selection_keywords_list + printing_keywords_list),
255 for each_run_dir, runid
in sorted(self._all_run_dirs(),
256 key=operator.itemgetter(1)):
259 print(
"Analyzing", runid)
261 for each_replica_stat_file
in sorted(
262 glob.glob(os.path.join(each_run_dir,
"output",
264 key=
lambda x: int(x.strip(
'.out').split(
'.')[-1])):
266 replicaid = each_replica_stat_file.strip(
".out").split(
".")[-1]
268 rsf = open(each_replica_stat_file,
'r')
270 for line_index, each_model_line
in enumerate(rsf.readlines()):
274 field_headers = eval(each_model_line.strip())
275 fields_for_selection, fields_for_printing = \
276 self._get_subfields_for_criteria(
277 field_headers, selection_keywords_list,
278 printing_keywords_list)
281 frameid = line_index-1
283 dat = eval(each_model_line.strip())
285 total_num_frames += 1
286 model_satisfies =
False
287 selection_criteria_values = []
289 for si, score_type
in enumerate(selection_keywords_list):
290 if "crosslink" in score_type.lower() \
291 and "distance" in score_type.lower():
293 crosslink_distance_values = []
295 for xltype
in fields_for_selection[si]:
296 crosslink_distance_values.append(
297 min([float(dat[j])
for j
in
298 fields_for_selection[si][xltype]]))
300 satisfied_percent, model_satisfies = \
301 self._get_crosslink_satisfaction(
302 crosslink_distance_values,
303 aggregate_lower_thresholds[si],
304 aggregate_upper_thresholds[si],
305 member_lower_thresholds[si],
306 member_upper_thresholds[si])
308 selection_criteria_values.append(satisfied_percent)
312 dat[fields_for_selection[si][score_type]])
314 model_satisfies = self._get_score_satisfaction(
315 score_value, aggregate_lower_thresholds[si],
316 aggregate_upper_thresholds[si])
317 selection_criteria_values.append(score_value)
319 if not model_satisfies:
325 printing_criteria_values = []
326 for si, score_type
in enumerate(
327 printing_keywords_list):
328 if fields_for_printing[si] < 0:
330 "Bad stat file key '%s': use `imp_sampcon "
331 "show_stat <stat file>` to see all "
332 "acceptable keys" % score_type)
333 score_value = float(dat[fields_for_printing[si]])
334 printing_criteria_values.append(score_value)
336 self.all_good_scoring_models.append(
337 (runid, replicaid, frameid))
343 [len(self.all_good_scoring_models) - 1,
344 runid, replicaid, frameid]] +
345 [
"%.2f" % s
for s
in selection_criteria_values] +
346 [
"%.2f" % s
for s
in printing_criteria_values]),
353 sampA, sampB = self._split_good_scoring_models_into_two_subsets(
355 split_type=
"divide_by_run_ids" if num_runs > 1
else "random")
356 self._extract_models_from_trajectories(
357 output_dir, num_runs, total_num_frames, frozenset(sampA))
360 def _split_good_scoring_models_into_two_subsets(
361 self, num_runs, split_type=
"divide_by_run_ids"):
362 ''' Get the list of good scoring models and split them into two
363 samples. Return the two subsets.
365 @param split_type how to split good scoring models into two samples.
367 (a) divide_by_run_ids : where the list of runids is divided
368 into 2. e.g. if we have runs from 1-50, good scoring
369 models from runs 1-25 is sample A and those from
370 runs 26-50 is sample B.
371 (b) random : split the set of good scoring models into
372 two subsets at random.
377 if split_type ==
"divide_by_run_ids":
379 half_num_runs = num_runs/2
380 for i, gsm
in enumerate(self.all_good_scoring_models):
381 if int(gsm[0]) <= half_num_runs:
382 sampleA_indices.append(i)
384 sampleB_indices.append(i)
386 elif split_type ==
"random":
387 sampleA_indices = random.sample(
388 range(len(self.all_good_scoring_models)),
389 len(self.all_good_scoring_models)//2)
391 i
for i
in range(len(self.all_good_scoring_models))
392 if i
not in sampleA_indices]
395 with open(os.path.join(self.run_dir,
'good_scoring_models',
396 'model_sample_ids.txt'),
'w')
as f:
397 for i
in sampleA_indices:
398 print(i,
"A", file=f)
399 for i
in sampleB_indices:
400 print(i,
"B", file=f)
401 return sampleA_indices, sampleB_indices
def get_good_scoring_models
Loops over all stat files in the run directory and populates the list of good-scoring models...
Select good-scoring models based on scores and/or data satisfaction.