import IMP.em2d.imp_general.io as io

# The ``sets`` module was removed in Python 3, and the built-in ``set`` type
# has existed since Python 2.4. Only fall back to ``sets.Set`` when the
# builtin is missing, instead of importing the module unconditionally.
try:
    set = set
except NameError:
    from sets import Set as set

# Module-level logger used by every function in this module.
log = logging.getLogger("solutions_io")
class ClusterRecord(tuple):
    """Simple named tuple class

    Immutable record describing one cluster of solutions. Its fields, in
    order, are: cluster_id, n_elements, representative, elements,
    solutions_ids.
    """

    class _itemgetter(object):
        """Callable returning ``obj[ind]``; used to build the read-only
        properties that name the tuple positions."""

        def __init__(self, ind):
            self.__ind = ind

        def __call__(self, obj):
            return obj[self.__ind]

    # Number of items a record must contain (one per property below).
    __n_fields = 5

    def __new__(cls, iterable):
        """Build a record from any iterable with exactly five items.

        @param iterable Values for cluster_id, n_elements, representative,
                        elements and solutions_ids, in that order.
        @raise TypeError if the iterable does not hold exactly 5 items.
        """
        values = tuple(iterable)
        if len(values) != cls.__n_fields:
            raise TypeError("Expected %d arguments, got %d"
                            % (cls.__n_fields, len(values)))
        # A tuple's contents are fixed when it is created, so validation
        # belongs in __new__. Calling tuple.__init__(self, iterable) from an
        # overriding __init__ raises TypeError on Python 3.
        return tuple.__new__(cls, values)

    cluster_id = property(_itemgetter(0))
    n_elements = property(_itemgetter(1))
    representative = property(_itemgetter(2))
    elements = property(_itemgetter(3))
    solutions_ids = property(_itemgetter(4))
class HeapRecord(tuple):
    """
    The heapq algorithm is a min-heap. I want a max-heap, that pops the
    larger values out of the heap.
    For that I have to modify the comparison function and also set the
    index that is used for the comparison. The index corresponds to
    the restraint that we desired to order by
    """

    def __new__(cls, x, i):
        """
        Build from a tuple and the index used to compare
        @param x Values of the record
        @param i Position of the value used for ordering
        """
        record = tuple.__new__(cls, x)
        # Keep the index on the instance so that records built with
        # different indices cannot overwrite each other's ordering key.
        record.i = i
        return record

    def __lt__(self, other):
        """
        Compare. To convert the min-heap into a max-heap, the lower than
        comparison is transformed into a greater-than
        """
        i = self.i
        return self[i] > other[i]

    # Also defined so that <= agrees with the reversed ordering above.
    def __le__(self, other):
        i = self.i
        return self[i] >= other[i]
# gather_best_solution_results -- merge the best records of several database
# files into one new database file (fn_output).
# Visible flow:
#   - The column names and types of table `tbl` are read from the first
#     database and reordered with get_sorting_indices().
#   - The native row (assignment="native", LIMIT 1) is read once from that
#     first database.
#   - Each file in `fns` is queried for its non-native rows, ordered
#     ascending by `orderby` with at most `max_number` rows per file. The
#     rows go through a heap capped at `max_number` entries.
#   - Finally the native row is pushed and all rows are stored in fn_output.
# An IOError is raised when the fraction of failing files is too large. The
# threshold is presumably `raisef`; the comparison line is not part of this
# extraction.
# NOTE(review): this extraction is missing lines, including the creation of
#   db/out_db/tbl, the loop and try/except headers, the construction of `a`,
#   the n_problems counter and any close() calls. The comments here describe
#   only the visible statements.
# NOTE(review): native_data[0] raises IndexError if the first file has no
#   row with assignment="native".
# NOTE(review): `they_are_sorted` is built from `names`. The selected column
#   order matches sorted_names only if `names` is sorted on a missing line;
#   confirm.
# NOTE(review): connections are not visibly closed on error; consider
#   try/finally.
84 def gather_best_solution_results(fns, fn_output, max_number=50000,
85 raisef=0.1, orderby=
"em2d"):
87 Reads a set of database files and merge them into a single file.
89 @param fns List of files with databases
90 @param fn_output The database to create
91 @param max_number Maximum number of records to keep, sorted according
93 @param raisef Ratio of problematic database files tolerated before
94 raising an error. This option is to tolerate some files
95 of the databases being broken because the cluster fails,
97 @param orderby Criterium used to sort the the records
99 Makes sure to reorder all column names if neccesary before merging
100 The record for the native solution is only added once (from first file).
106 names = db.get_table_column_names(tbl)
107 types = db.get_table_types(tbl)
108 indices = get_sorting_indices(names)
109 sorted_names = [ names[i]
for i
in indices]
110 sorted_types = [ types[i]
for i
in indices]
113 ind = names.index(orderby)
114 they_are_sorted = field_delim.join(names)
116 sql_command =
"""SELECT %s FROM %s
117 WHERE assignment="native" LIMIT 1 """ % (they_are_sorted, tbl)
118 native_data = db.retrieve_data(sql_command)
120 log.info(
"Gathering results. Saving to %s", fn_output)
122 out_db.create(fn_output, overwrite=
True)
123 out_db.connect(fn_output)
124 out_db.create_table(tbl, sorted_names, sorted_types)
130 log.info(
"Reading %s",fn)
133 sql_command =
"""SELECT %s FROM %s
134 WHERE assignment<>"native"
135 ORDER BY %s ASC LIMIT %s """ % (
136 they_are_sorted, tbl,orderby, max_number)
137 data = db.retrieve_data(sql_command)
138 log.info(
"%s records read from %s",len(data), fn)
# Fill the heap until it holds max_number records. After that, the root is
# replaced only when `best_records[0] < a`. This assumes `a` is a HeapRecord
# (reversed ordering, built on a missing line), so the root is the record
# with the largest `orderby` value; confirm.
143 if(len(best_records) < max_number):
144 heapq.heappush(best_records, a)
148 if(best_records[0] < a):
149 heapq.heapreplace(best_records, a)
151 log.error(
"Error for %s: %s",fn, e)
# Fraction of input files that failed (n_problems is presumably incremented
# in the missing except branch).
157 ratio = float(n_problems)/float(len(fns))
159 raise IOError(
"There are %8.1f %s of the database "\
160 "files to merge with problems! " % (ratio*100,
"%"))
# The native row is added exactly once, after all files have been merged.
162 heapq.heappush(best_records, native_data[0])
163 out_db.store_data(tbl, best_records)
# gather_solution_results -- concatenate table `tbl` of every file in `fns`
# into a new database file (fn_output).
# The output table's column names and types come from the first file and are
# reordered with get_sorting_indices(). Each file is then read in full
# ("SELECT <cols> FROM tbl", with no WHERE or LIMIT) and appended with
# store_data().
# Failures are logged per file. An IOError is raised when the failure ratio
# is too high; the threshold is presumably `raisef`, but the comparison line
# is missing from this extraction.
# NOTE(review): unlike gather_best_solution_results, rows are not filtered on
#   `assignment`, so the native row of every file is copied as well.
# NOTE(review): rows are selected in `they_are_sorted` order. That matches
#   sorted_names only if `names` is sorted on a missing line (original line
#   195); confirm.
# NOTE(review): each file's whole table is retrieved before storing;
#   retrieve_data presumably returns all rows at once.
166 def gather_solution_results(fns, fn_output, raisef=0.1):
168 Reads a set of database files and puts them in a single file
169 Makes sure to reorder all column names if neccesary before merging
170 @param fns List of database files
171 @param fn_output Name of the output database
172 @param raisef See help for gather_best_solution_results()
178 names = db.get_table_column_names(tbl)
179 types = db.get_table_types(tbl)
180 indices = get_sorting_indices(names)
181 sorted_names = [ names[i]
for i
in indices]
182 sorted_types = [ types[i]
for i
in indices]
183 log.info(
"Gathering results. Saving to %s", fn_output)
185 out_db.create(fn_output, overwrite=
True)
186 out_db.connect(fn_output)
187 out_db.create_table(tbl, sorted_names, sorted_types)
192 log.info(
"Reading %s",fn)
194 names = db.get_table_column_names(tbl)
196 they_are_sorted = field_delim.join(names)
197 log.debug(
"Retrieving %s", they_are_sorted)
198 sql_command =
"SELECT %s FROM %s" % (they_are_sorted, tbl)
199 data = db.retrieve_data(sql_command)
200 out_db.store_data(tbl, data)
203 log.error(
"Error for file %s: %s",fn, e)
205 ratio = float(n_problems)/float(len(fns))
207 raise IOError(
"There are %8.1f %s of the database "\
208 "files to merge with problems! " % (ratio*100,
"%"))
def get_sorting_indices(l):
    """ Return indices that sort the list l

    @param l Sequence (or any iterable) of mutually comparable elements,
             e.g. the column names of a table.
    @return List of positions such that [l[i] for i in result] is in
            ascending order; equal elements keep their original order.
    """
    elements = list(l)
    # sorted() is stable, so ties are resolved by original position. This
    # gives the same result as sorting (element, index) pairs.
    return sorted(range(len(elements)), key=elements.__getitem__)
# get_best_solution -- return the record at rank `Nth` (0-based) of the
# solutions stored in the database file `fn_database`.
# Only the requested `fields` are selected. Per its docstring,
# get_fields_string() turns a falsy value into "*", i.e. all columns.
# A ValueError is raised when the requested rank does not exist; the
# emptiness check is on a line missing from this extraction.
# NOTE(review): the signature continuation (presumably a `tbl` parameter)
#   and the "ORDER BY %s" line of the query are missing here. As shown, the
#   query text has three placeholders for four arguments.
# NOTE(review): `orderby` and the table name are interpolated directly into
#   the SQL text; pass only trusted column names.
# NOTE(review): "stars at 0" in the docstring means the index starts at 0.
218 def get_best_solution(fn_database, Nth, fields=False, orderby=False,
221 Recover the reference frame of the n-th best solution from a database.
222 The index Nth stars at 0
224 f = get_fields_string(fields)
225 sql_command =
""" SELECT %s FROM %s
227 ASC LIMIT 1 OFFSET %d """ % (f, tbl, orderby, Nth)
228 data = Database.read_data(fn_database, sql_command)
230 raise ValueError(
"The requested %s-th best solution does not exist. "\
231 "Only %s solutions found" % (Nth, len(data) ))
def get_pca(string, delimiter="/"):
    """Parse a delimiter-separated string of numbers into a list of floats.

    @param string Text such as "1.0/2.5/3"
    @param delimiter Separator between the values (default "/")
    @return List of floats, one per field of the string
    @raise ValueError if a field is not a valid number
    """
    return [float(p) for p in string.split(delimiter)]
def get_fields_string(fields):
    """
    Get a list of fields and return a string with them. If there are no
    fields, return an *, indicating SQL that all the fields are requested
    @param fields A list of strings, or a falsy value (False, None, [])
                  to request every column
    @return The field names joined with field_delim, or "*"
    """
    # Callers use fields=False as the "all columns" default, so any falsy
    # value selects everything.
    if not fields:
        return "*"
    return field_delim.join(fields)
# Results database class (its class header is not part of this extraction):
# manages the tables that hold the results of the experiments.
# __init__ sets:
#   - the table names "native", "results", "placements" and "ccc";
#   - an empty buffer for cluster records;
#   - the column names and types describing results and comparison
#     measures.
# NOTE(review): original lines 258, 264-265 and 267 are missing here.
#   - The description column list shown is incomplete, while its type list
#     has three entries (int, str, str). add_record() stores
#     [solution_id, assignment, <reference-frame text>, ...], so the third
#     column presumably holds the reference frames; confirm.
#   - add_record() appends to self.records, which is not initialised in the
#     visible lines; confirm it is set on a missing line.
255 Class for managing the results of the experiments
257 def __init__(self, ):
259 self.native_table_name =
"native"
260 self.results_table =
"results"
261 self.placements_table =
"placements"
262 self.ccc_table_name =
"ccc"
263 self.cluster_records = []
266 self.results_description_columns = [
"solution_id",
"assignment",
268 self.results_description_types = [int, str, str]
270 self.results_measures_columns = [
"drms",
"cdrms",
"crmsd"]
271 self.results_measures_types = [float, float, float]
# Build the results table and create the native table with the same schema.
# Columns are:
#   - the description columns;
#   - "total_score";
#   - one float column per restraint name;
#   - the drms/cdrms/crmsd measure columns, presumably only when
#     add_measures is true.
# NOTE(review): the def line, the `if add_measures:` line and whatever
#   precedes the native create_table call are missing from this extraction.
275 Build the table of results
276 @param restraints_names The names given to the columns of the table
277 @param add_measures If True, add fields for comparing models
278 and native conformation
280 table_fields = self.results_description_columns + \
281 [
"total_score"] + restraints_names
282 table_types = self.results_description_types + \
283 [float] + [float
for r
in restraints_names]
286 table_fields += self.results_measures_columns
287 table_types += self.results_measures_types
288 log.debug(
"Creating table %s\n%s",table_fields,table_types)
289 self.
create_table(self.results_table, table_fields, table_types)
292 self.
create_table(self.native_table_name, table_fields, table_types)
# Return solutions from the results table only. The query is built as
# "SELECT <fields> FROM results". It is optionally ordered ascending by
# `orderby` (the condition line is missing here) and limited to
# `max_number` rows.
# NOTE(review): `orderby` is interpolated directly into the SQL text.
# NOTE(review): the start of the def line and the query execution/return are
#   missing from this extraction.
295 max_number=
None, orderby=
False):
298 @param fields Fields to recover from the table
299 @param max_number Maximum number of solutions to recover
300 @param orderby Name of the restraint used for sorting the states
303 log.info(
"Getting %s from solutions", fields)
304 f = self.get_fields_string(fields)
305 sql_command =
"SELECT %s FROM %s " % (f, self.results_table)
307 sql_command +=
" ORDER BY %s ASC" % orderby
# NOTE(review): the membership test uses ==. Because 0 == False,
#   max_number=0 also skips the LIMIT ("no limit"), rather than returning
#   zero rows.
308 if max_number
not in (
None,
False):
309 sql_command +=
" LIMIT %d" % (max_number)
310 log.debug(
"Using %s", sql_command )
# Return solutions whose requested `fields` may live in several tables.
# Every (field, table) pair is examined. The "native" table and the
# "solution_id" field are treated specially; the action taken is on lines
# missing from this extraction. Tables contributing a field are collected
# in `required_tables`:
#   - exactly one table, "results": per the docstring,
#     get_solutions_results_table() is used;
#   - more than one table: a join query is built (on missing lines) and
#     ordered ascending by `orderby`.
# A ValueError("Fields not found in the database") is raised on some path;
# its condition is not visible here.
# NOTE(review): the visible branches do not cover exactly one required table
#   other than "results"; confirm how that case is handled.
# NOTE(review): with the default fields=False the comprehension iterates
#   over False and raises TypeError, unless a missing line handles it.
# NOTE(review): required_tables comes from a set, so its order (and hence
#   the join order, if passed as-is) is not deterministic.
314 def get_solutions(self, fields=False, max_number=None, orderby=False):
316 Get solutions from the database.
317 @param fields Fields requested. If the fields are in different
318 tables, a left join is done. Otherwise get_solutions_results_table()
319 is called. See get_solutions_results_table() for the meaning
322 tables = self.get_tables_names()
323 log.debug(
"tables %s", tables)
324 required_tables = set()
325 pairs_table_field = []
329 for f,t
in [(f,t)
for f
in fields
for t
in tables]:
330 if t ==
"native" or f ==
"solution_id":
334 required_tables.add(t)
335 pairs_table_field.append((t,f))
336 required_tables = list(required_tables)
337 log.debug(
"required_tables %s", required_tables)
338 log.debug(
"pairs_table_field %s", pairs_table_field)
339 if len(required_tables) == 0:
343 elif len(required_tables) == 1
and required_tables[0] ==
"results":
347 elif len(required_tables) > 1:
351 sql_command +=
" ORDER BY %s ASC" % orderby
352 log.debug(
"Using %s", sql_command )
356 raise ValueError(
"Fields not found in the database")
# Recover the requested fields from the native table, using
# "SELECT <fields> FROM native".
# NOTE(review): the def line and the query execution/return are missing
#   from this extraction.
360 Recover data for the native solution
361 @param fields Fields to recover
364 f = self.get_fields_string(fields)
365 sql_command =
"SELECT %s FROM %s " % (f, self.native_table_name)
# Buffer one solution record in self.records; nothing is written to the
# database here.
# Reference frames are serialised with io.ReferenceFrameToText and joined
# with unit_delim. The record layout is:
#   [solution_id, assignment, RFs_txt, total_score]
#   + <line missing; presumably restraints_scores>
#   + measures
# NOTE(review): the buffer is presumably flushed later by save_records();
#   confirm.
369 def add_record(self, solution_id, assignment, RFs, total_score,
370 restraints_scores, measures):
372 Add a recorde to the database
373 @param solution_id The key for the solution
374 @param assignment The assigment for the solution provided by
376 @param RFs Reference frames of the rigid bodies of the components
377 of the assembly in the solution
378 @param total_score Total value of the scoring function
379 @param restraints_scores A list with all the values for the
381 @param measures A list with the values of all the measures for
384 words = [io.ReferenceFrameToText(ref).get_text()
for ref
in RFs]
385 RFs_txt = unit_delim.join(words)
386 record = [solution_id, assignment, RFs_txt, total_score] + \
389 record = record + measures
390 self.records.append(record)
# Same record layout as add_record(), but the record is written immediately
# to the native table with store_data() instead of being buffered.
# NOTE(review): the def line and part of the record expression are missing
#   from this extraction.
395 Add a record for the native structure to the database
396 see add_record() for the meaning of the parameters
398 words = [io.ReferenceFrameToText(ref).get_text()
for ref
in RFs]
399 RFs_txt = unit_delim.join(words)
401 record = [solution_id, assignment, RFs_txt, total_score] + \
404 record = record + measures
405 self.
store_data(self.native_table_name, [record])
# Only the signature of save_records() is part of this extraction; its body
# is missing. It presumably writes the buffered self.records into `table`
# (default "results"); confirm.
407 def save_records(self,table="results"):
# Build one row for the placements table: [solution_id] + distances + angles.
# NOTE(review): the def line is missing from this extraction.
411 """ both distances and angles are expected to be a list of floats """
412 return [solution_id] + distances + angles
# Create the placements table with columns solution_id, plus
# distance_<name> and angle_<name> for each component name. The same
# columns are then added to the native table, and the row whose assignment
# is "native" is updated with zeros (including solution_id). These calls are
# only partially visible in this extraction.
417 Creates a table to store the values of the placement scores for the
419 @param names Names of the components of the assembly
422 self.placement_table_name = self.placements_table
423 table_fields = [
"solution_id"]
424 table_fields += [
"distance_%s" % name
for name
in names]
425 table_fields += [
"angle_%s" % name
for name
in names]
# NOTE(review): BUG -- table_fields already contains "solution_id", so this
#   gives one more type than there are fields (an int, then a float for
#   every field). It should probably be
#   `[int] + [float for f in table_fields[1:]]`. Confirm how create_table()
#   handles the extra entry.
426 table_types = [int] + [float
for f
in table_fields]
428 self.
create_table(self.placement_table_name, table_fields, table_types)
430 table_fields, table_types,check=
True)
432 native_values = [0
for t
in table_fields]
433 log.debug(
"%s", self.native_table_name)
434 log.debug(
"table fields %s", table_fields)
436 table_fields, native_values,
437 [
"assignment"], [
"\"native\""])
# Return the column names that contain "distance" or "angle" (a substring
# match over `columns`, which is obtained on a line missing here).
# NOTE(review): any other column whose name contains these substrings, such
#   as a restraint name, would also be returned.
441 Return the names of the placement score fields in the database
444 fields = [col
for col
in columns
if "distance" in col
or "angle" in col]
# Create the "ccc" table (solution_id int, ccc float) and add the same
# columns to the native table. The native row is then updated with
# solution_id 0 and ccc 1.00. The add-columns and update calls are only
# partially visible in this extraction.
449 Add a table to the database for store the values of the cross
450 correlation coefficient between a model and the native configuration
454 table_fields = [
"solution_id",
"ccc"]
455 table_types = [int, float]
457 self.
create_table(self.ccc_table_name, table_fields, table_types)
460 table_fields, table_types,check=
True)
462 table_fields, [0,1.00], [
"assignment"], [
"\"native\""])
# Build one row for the ccc table: [solution_id, ccc].
# NOTE(review): the def line is missing from this extraction.
465 """ Format for the record to store in the ccc table """
466 return [solution_id, ccc]
# Return the cross-correlation coefficient stored for `solution_id`.
# The id is formatted with %d, so it must be numeric.
# NOTE(review): the rest of the statement, the query execution and the
#   return are missing from this extraction.
468 def get_ccc(self, solution_id):
470 Recover the cross-correlation coefficient for a solution
473 sql_command =
""" SELECT ccc FROM %s
474 WHERE solution_id=%d """ % (self.ccc_table_name,
def store_ccc_data(self, ccc_data):
    """Pass the rows in ``ccc_data`` to store_data() for the ccc table."""
    store = self.store_data
    store(self.ccc_table_name, ccc_data)
def store_placement_data(self, data):
    """Log the placement rows, then pass them to store_data() for the
    placements table."""
    log.debug("store placement table %s", data)
    store = self.store_data
    store(self.placement_table_name, data)
# Build a "SELECT <table.field, ...> FROM <first table> LEFT JOIN ..." query
# over several tables, joined on solution_id.
# NOTE(review): the def line and the assignment of `b` (presumably
#   tables_names[i]) are missing from this extraction.
# NOTE(review): each table is joined to the previous one
#   (a = tables_names[i-1]), not to the first table as the docstring example
#   shows.
# NOTE(review): log.debug("%s" % sql_command) formats eagerly; prefer
#   log.debug("%s", sql_command).
488 Format a left join SQL command that recovers all fileds from the
490 @param pairs_table_field Pairs of (table,field)
491 @param tables_names Names of the tables
493 E.g. If pairs_table_filed = ((table1,a), (table2,b), (table3,c),
494 (table2,d)) and tables_names = (table1, table2, table3)
497 SELECT table1.a, table2.b, table3.c, table2.d FROM table1
498 LEFT JOIN table2 ON table1.solution_id = table2.solution_id
499 LEFT JOIN table3 ON table1.solution_id = table3.solution_id
500 WHERE table1.solution_id IS NOT NULL AND
501 table2.solution_id IS NOT NULL AND
502 table3.solution_id IS NOT NULL
505 txt = [
"%s.%s" % (p[0],p[1])
for p
in pairs_table_field]
506 fields_requested = field_delim.join(txt)
507 sql_command =
" SELECT %s FROM %s " % (fields_requested,tables_names[0])
508 n_tables = len(tables_names)
509 for i
in range(1, n_tables):
510 a = tables_names[i-1]
512 sql_command +=
" LEFT JOIN %s " \
513 "ON %s.solution_id = %s.solution_id " % (b,a,b)
# NOTE(review): BUG -- "WHERE" is emitted on every iteration. With three or
#   more tables the query becomes "WHERE t1... AND WHERE t2... AND t3...",
#   which is invalid SQL. With a single table no "WHERE" is emitted at all.
#   The keyword should be added once, before this loop.
516 for i
in range(n_tables-1):
517 sql_command +=
"WHERE %s.solution_id " \
518 "IS NOT NULL AND " % tables_names[i]
519 sql_command +=
" %s.solution_id IS NOT NULL " % tables_names[n_tables-1]
520 log.debug(
"%s" %sql_command)
# Remember the clusters table name and define its columns:
# cluster_id, n_elements and representative (int), and elements and
# solutions_ids (str).
# NOTE(review): the def line and (presumably) the create_table call are
#   missing from this extraction.
525 Add a table to store information about the clusters of structures
526 @param name Name of the table
528 self.cluster_table_name = name
530 table_fields = (
"cluster_id",
"n_elements",
531 "representative",
"elements",
"solutions_ids")
532 table_types = (int, int, int, str, str)
# Buffer one cluster record in self.cluster_records. Nothing is written to
# the database until store_cluster_data() is called.
# NOTE(review): the start of the def line and the end of the record tuple
#   (presumably solutions_ids) are missing from this extraction.
537 elements, solutions_ids):
539 Add a record to the cluster database. Actually, only stores it
540 in a list (that will be added later)
541 @param cluster_id Number with the id of the cluster
542 @param n_elements Number of elements in the cluster
543 @param representative Number with the id of the representative
545 @param elements List with the number of the elements of the cluster
546 @param solutions_ids The numbers above are provided by the
547 clustering algorithm. The solutions_ids are the ids of the models
551 record = (cluster_id, n_elements, representative, elements,
553 log.debug(
"Adding cluster record: %s", record)
554 self.cluster_records.append(record)
# Write all buffered cluster records to the clusters table.
# NOTE(review): self.cluster_records is not cleared in the visible lines, so
#   calling this twice would store the same records again.
558 Store the data for the clusters
560 log.info(
"Storing data of clusters. Number of records %s",
561 len(self.cluster_records) )
562 self.
store_data(self.cluster_table_name, self.cluster_records)
# Return `fields` from the results table for the solutions whose ids are in
# `solutions_ids`, using "WHERE solution_id IN (<comma-separated ids>)".
# NOTE(review): the ids are converted with str() and pasted into the SQL
#   text, so they should be integers from trusted code.
# NOTE(review): the def line and the return are missing from this
#   extraction.
566 Recover solutions for a specific list of results
567 @param fields Fields to recover fro the database
568 @param solutions_ids A list with the desired solutions. E.g. [0,3,6]
570 sql_command =
""" SELECT %s FROM %s WHERE solution_id IN (%s) """
571 f = self.get_fields_string(fields)
572 str_ids =
",".join(map(str,solutions_ids))
573 data = self.
retrieve_data( sql_command % (f, self.results_table, str_ids ) )
# Return the position of the native configuration among the solutions when
# they are ordered by `orderby`; the queries are on lines missing here.
# np.searchsorted(values, native_value) returns the left insertion index.
# That equals the number of values strictly smaller than the native value
# only if `values` is sorted ascending, presumably via ORDER BY ... ASC on a
# missing line; confirm.
578 Get the position of the native configuration
579 @param orderby Criterium used to sort the solutions
584 native_value = data[0][0]
587 values = [row[0]
for row
in data]
588 rank = np.searchsorted(values,native_value)
# Return the information for the cluster at 1-based `position`, with
# clusters ordered by n_elements descending. Only the query text is visible
# in this extraction.
593 Recover the the information about the n-th largest cluster
594 @param position Cluster position (by size) requested
595 (1 is the largest cluster)
596 @param table_name Table where the information about the
599 s =
""" SELECT * FROM %s ORDER BY n_elements DESC """ % table_name
# Per-component placement statistics over a set of solutions. The
# distance_* and angle_* columns of the placements table are selected for
# `solutions_ids`, and the column-wise (axis=0) numpy mean and standard
# deviation are returned.
# NOTE(review): ndarray.std uses ddof=0 (population standard deviation).
# NOTE(review): `fields` and the two retrieve_data calls are on lines missing
#   from this extraction.
# NOTE(review): on Python 3, filter() returns a single-use iterator; each
#   one is consumed exactly once below.
# NOTE(review): an empty `solutions_ids` presumably yields no rows, and the
#   mean of an empty array is nan (with a RuntimeWarning).
607 Recovers from the database the placement scores for a set of
608 solutions, and returns the mean and standard deviation of the
609 placement score for each of the components of the complex being
610 scored. This function will be typical used to compute the variation
611 of the placement of each component within a cluster of solutions
612 @param solutions_ids The ids of the solutions used to compute
614 @return The output are 4 numpy vectors:
615 placement_distances_mean - The mean placement distance for each
617 placement_distances_stddev - The standardd deviation of the
618 placement distance for each component
619 placement_angles_mean - The mean placement angle for each
621 placement_angles_stddev - The standard deviation of the placement
622 angle for each component,
626 table = self.placements_table
628 distance_fields = filter(
lambda x:
'distance' in x, fields)
629 angle_fields = filter(
lambda x:
'angle' in x, fields)
630 sql_command =
""" SELECT %s FROM %s WHERE solution_id IN (%s) """
632 str_ids =
",".join(map(str,solutions_ids))
633 log.debug(
"Solutions considered %s", solutions_ids)
634 s = sql_command % (
",".join(distance_fields), table, str_ids )
636 s = sql_command % (
",".join(angle_fields), table, str_ids )
638 D = np.array(data_distances)
639 placement_distances_mean = D.mean(axis=0)
640 placement_distances_stddev = D.std(axis=0)
641 A = np.array(data_angles)
642 placement_angles_mean = A.mean(axis=0)
643 placement_angles_stddev = A.std(axis=0)
644 return [placement_distances_mean,placement_distances_stddev,
645 placement_angles_mean, placement_angles_stddev]
# Complex-level placement statistics for a set of solutions. They are
# derived from the four per-component vectors produced by
# get_individual_placement_statistics(); that call is on a line missing
# here. Each returned value is the average over components of the
# per-component mean or standard deviation.
# NOTE(review): plcd_std and plca_std are averages of per-component standard
#   deviations, not the standard deviation of a per-solution complex score
#   as the docstring suggests.
# NOTE(review): the def line is missing from this extraction.
650 Calculate the placement score and its standard deviation for
651 the complexes in a set of solutions. The values returned are
652 averages, as the placement score for a complex is the average
653 of the placement scores of the components. This function is used
654 to obtain global placement for a cluster of solutions.
655 @param solutions_ids The ids of the solutions used to compute
657 @return The output are 4 values:
658 plcd_mean - Average of the placement distance for the entire
659 complex over all the solutions.
660 plcd_std - Standard deviation of the placement distance for
661 the entire complex over all the solutions.
662 plca_mean - Average of the placement angle for the entire
663 complex over all the solutions.
664 plca_std - Standard deviation of the placement angle for
665 the entire complex over all the solutions.
667 [placement_distances_mean,placement_distances_stddev,
668 placement_angles_mean, placement_angles_stddev] = \
670 plcd_mean = placement_distances_mean.mean(axis=0)
671 plcd_std = placement_distances_stddev.mean(axis=0)
672 plca_mean = placement_angles_mean.mean(axis=0)
673 plca_std = placement_angles_stddev.mean(axis=0)
674 return [plcd_mean, plcd_std, plca_mean, plca_std]