IMP  2.0.1
The Integrative Modeling Platform
solutions_io.py
1 import IMP.em2d.imp_general.io as io
2 import IMP.em2d.Database as Database
3 
4 import sys
5 import heapq
6 import math
7 import os
8 import csv
9 import time
10 import logging
11 import glob
12 import numpy as np
13 
# Python < 2.4 has no builtin set type; fall back to the (API compatible)
# sets.Set class there.
try:
    set = set
except NameError:
    from sets import Set as set

log = logging.getLogger("solutions_io")

# unit_delim separates units within one field (e.g. the reference frames of
# the components of a solution); field_delim separates fields, e.g. the
# column names of a SQL SELECT.
unit_delim, field_delim = "/", ","
23 
class ClusterRecord(tuple):
    """
    Simple named tuple describing a cluster of solutions.

    Fields, in order: cluster_id, n_elements, representative, elements,
    solutions_ids. Each field is also available as a read-only attribute.
    """

    class _itemgetter(object):
        """Callable returning the item at a fixed index of its argument."""
        def __init__(self, ind):
            self.__ind = ind

        def __call__(self, obj):
            return obj[self.__ind]

    __n_fields = 5

    def __new__(cls, iterable):
        """
        Build the record.
        @param iterable The 5 values of the fields, in order
        @raise TypeError if the number of values is not 5
        """
        # Validation has to happen in __new__: the contents of a tuple are
        # fixed at creation, and forwarding the values to tuple.__init__
        # (which is object.__init__) raises TypeError on Python 3.
        values = tuple(iterable)
        if len(values) != cls.__n_fields:
            raise TypeError("Expected %d arguments, got %d"
                            % (cls.__n_fields, len(values)))
        return tuple.__new__(cls, values)

    cluster_id = property(_itemgetter(0))
    n_elements = property(_itemgetter(1))
    representative = property(_itemgetter(2))
    elements = property(_itemgetter(3))
    solutions_ids = property(_itemgetter(4))
45 
46 
47 #################################
48 
49 # INPUT/OUTPUT OF SOLUTIONS OBTAINED WITH DominoModel
50 
51 #################################
52 
class HeapRecord(tuple):
    """
    Tuple that compares in reverse order on one of its fields.

    The heapq algorithm is a min-heap. Inverting the comparison makes it
    behave as a max-heap on field i: the record with the largest value at
    index i is at the top of the heap. The index corresponds to the
    restraint used to order the records.
    """
    def __new__(cls, x, i):
        """
        Build from a tuple and the index used to compare
        @param x Values of the record
        @param i Index of the field used for the comparisons
        """
        record = tuple.__new__(cls, x)
        # The index must be stored on the instance. Assigning it on the
        # class would make every record share the index of the last record
        # built.
        record.i = i
        return record

    def __lt__(self, other):
        """
        Compare. To convert the min-heap into a max-heap, the lower-than
        comparison is implemented as a greater-than on field i.
        """
        i = self.i
        return bool(self[i] > other[i])

    # Need __le__ as well for older Pythons
    def __le__(self, other):
        i = self.i
        return self[i] >= other[i]
82 
83 
def gather_best_solution_results(fns, fn_output, max_number=50000,
                                 raisef=0.1, orderby="em2d"):
    """
    Reads a set of database files and merge them into a single file,
    keeping only the best records.

    @param fns List of files with databases
    @param fn_output The database to create
    @param max_number Maximum number of records to keep, sorted according
                      to orderby (lower values are kept)
    @param raisef Ratio of problematic database files tolerated before
                  raising an error. This option is to tolerate some files
                  of the databases being broken because the cluster fails,
                  fill the disks, etc
    @param orderby Criterion used to sort the records
    @raise IOError if the ratio of problematic files is larger than raisef
    NOTE:
        Makes sure to reorder all column names if necessary before merging.
        The record for the native solution is only added once (from the
        first file), and only if the first file contains one.
    """
    tbl = "results"
    # Get names and types of the columns from first database file
    db = Database.Database2()
    db.connect(fns[0])
    names = db.get_table_column_names(tbl)
    types = db.get_table_types(tbl)
    indices = get_sorting_indices(names)
    sorted_names = [names[i] for i in indices]
    sorted_types = [types[i] for i in indices]

    # The records are always read with the columns in alphabetical order,
    # so the position of the sorting field is its index in the sorted names
    names.sort()
    ind = names.index(orderby)
    they_are_sorted = field_delim.join(names)
    # Get the native structure data from the first database
    sql_command = """SELECT %s FROM %s
    WHERE assignment="native" LIMIT 1 """ % (they_are_sorted, tbl)
    native_data = db.retrieve_data(sql_command)
    db.close()
    log.info("Gathering results. Saving to %s", fn_output)
    out_db = Database.Database2()
    out_db.create(fn_output, overwrite=True)
    out_db.connect(fn_output)
    try:
        out_db.create_table(tbl, sorted_names, sorted_types)

        best_records = []
        n_problems = 0
        for fn in fns:
            try:
                log.info("Reading %s", fn)
                db.connect(fn)
                sql_command = """SELECT %s FROM %s
                WHERE assignment<>"native"
                ORDER BY %s ASC LIMIT %s """ % (
                    they_are_sorted, tbl, orderby, max_number)
                data = db.retrieve_data(sql_command)
                log.info("%s records read from %s", len(data), fn)
                db.close()
                # Fill heap
                for d in data:
                    a = HeapRecord(d, ind)
                    if len(best_records) < max_number:
                        heapq.heappush(best_records, a)
                    # remember that < here compares for greater em2d value,
                    # as a HeapRecord is used: replace the worst record kept
                    elif best_records[0] < a:
                        heapq.heapreplace(best_records, a)
            except Exception:
                # sys.exc_info() works on every Python version ("except X, e"
                # is a syntax error on Python 3)
                e = sys.exc_info()[1]
                log.error("Error for %s: %s", fn, e)
                n_problems += 1

        # If the number of problematic files is too high, report that
        # something big is going on. Otherwise tolerate some errors from some
        # tasks that failed (memory errors, locks, writing errors ...)
        ratio = float(n_problems) / float(len(fns))
        if ratio > raisef:
            raise IOError("There are %8.1f %s of the database "
                          "files to merge with problems! " % (ratio * 100, "%"))
        # The native record is not part of the ranking; the heap order is not
        # needed anymore, so it is simply appended
        if native_data:
            best_records.append(native_data[0])
        else:
            log.warning("No native solution found in %s", fns[0])
        out_db.store_data(tbl, best_records)
    finally:
        out_db.close()
165 
def gather_solution_results(fns, fn_output, raisef=0.1):
    """
    Reads a set of database files and puts all their results in a single
    file. Makes sure to reorder all column names if necessary before merging
    (the columns are always read and written in alphabetical order).
    @param fns List of database files
    @param fn_output Name of the output database
    @param raisef See help for gather_best_solution_results()
    @raise IOError if the ratio of problematic files is larger than raisef
    """
    tbl = "results"
    # Get names and types of the columns from first database file
    db = Database.Database2()
    db.connect(fns[0])
    names = db.get_table_column_names(tbl)
    types = db.get_table_types(tbl)
    # Close this connection before the loop below reconnects the same object
    db.close()
    indices = get_sorting_indices(names)
    sorted_names = [names[i] for i in indices]
    sorted_types = [types[i] for i in indices]
    log.info("Gathering results. Saving to %s", fn_output)
    out_db = Database.Database2()
    out_db.create(fn_output, overwrite=True)
    out_db.connect(fn_output)
    try:
        out_db.create_table(tbl, sorted_names, sorted_types)

        n_problems = 0
        for fn in fns:
            try:
                log.info("Reading %s", fn)
                db.connect(fn)
                names = db.get_table_column_names(tbl)
                names.sort()
                they_are_sorted = field_delim.join(names)
                log.debug("Retrieving %s", they_are_sorted)
                sql_command = "SELECT %s FROM %s" % (they_are_sorted, tbl)
                data = db.retrieve_data(sql_command)
                out_db.store_data(tbl, data)
                db.close()
            except Exception:
                # sys.exc_info() works on every Python version ("except X, e"
                # is a syntax error on Python 3)
                e = sys.exc_info()[1]
                log.error("Error for file %s: %s", fn, e)
                n_problems += 1
        ratio = float(n_problems) / float(len(fns))
        if ratio > raisef:
            raise IOError("There are %8.1f %s of the database "
                          "files to merge with problems! " % (ratio * 100, "%"))
    finally:
        out_db.close()
210 
def get_sorting_indices(l):
    """
    Return the indices that sort the list l.

    Ties are broken by position, so equal elements keep their original
    relative order.
    @param l A list of mutually comparable elements
    @return list of indices such that [l[k] for k in result] is sorted
    """
    decorated = sorted((element, position)
                       for position, element in enumerate(l))
    return [position for _, position in decorated]
217 
def get_best_solution(fn_database, Nth, fields=False, orderby=False,
                      tbl="results"):
    """
    Recover the n-th best solution from a database.
    The index Nth starts at 0.
    @param fn_database Database file
    @param Nth Position of the requested solution after sorting
    @param fields Fields to recover (see get_fields_string())
    @param orderby Field used to sort the solutions in ascending order. If
                   False, no ORDER BY clause is added and the records come
                   in the order returned by the database
    @param tbl Table to read from
    @return The value of the first field requested for that solution
            (e.g. the reference frames when fields=["reference_frames"])
    @raise ValueError if there are fewer than Nth+1 solutions
    """
    f = get_fields_string(fields)
    sql_command = " SELECT %s FROM %s " % (f, tbl)
    # "ORDER BY False" would make the query fail; skip the clause as the
    # get_solutions_results_table() method does
    if orderby:
        sql_command += " ORDER BY %s ASC " % orderby
    sql_command += " LIMIT 1 OFFSET %d " % Nth
    data = Database.read_data(fn_database, sql_command)
    if len(data) == 0:
        # Count the solutions only on this error path, to report how many
        # there really are
        count_command = "SELECT COUNT(*) FROM %s" % tbl
        n_solutions = Database.read_data(fn_database, count_command)[0][0]
        raise ValueError("The requested %s-th best solution does not exist. "
                         "Only %s solutions found" % (Nth, n_solutions))
    # first field of the only record retrieved
    return data[0][0]
234 
def get_pca(string, delimiter="/"):
    """
    Parse a string of numbers separated by delimiter into a list of floats.
    @param string Text with the values separated by delimiter
    @param delimiter Separator between the values
    @return list of floats
    """
    return list(map(float, string.split(delimiter)))
239 
def get_fields_string(fields):
    """
    Join a list of field names into the column list of a SQL SELECT.
    @param fields A list of strings. If it is empty (or False/None), all
                  the fields are requested
    @return the names joined with field_delim, or "*" (all fields) when
            there are none
    """
    return field_delim.join(fields) if fields else "*"
251 
252 
254  """
255  Class for managing the results of the experiments
256  """
257  def __init__(self, ):
258  self.records = []
259  self.native_table_name = "native"
260  self.results_table = "results"
261  self.placements_table = "placements"
262  self.ccc_table_name = "ccc"
263  self.cluster_records = []
264 
265  # columns describing a solution in the results
266  self.results_description_columns = ["solution_id", "assignment",
267  "reference_frames"]
268  self.results_description_types = [int, str, str]
269  # columns describing measures for a result
270  self.results_measures_columns = ["drms", "cdrms", "crmsd"]
271  self.results_measures_types = [float, float, float]
272 
273  def add_results_table(self,restraints_names, add_measures=False):
274  """
275  Build the table of results
276  @param restraints_names The names given to the columns of the table
277  @param add_measures If True, add fields for comparing models
278  and native conformation
279  """
280  table_fields = self.results_description_columns + \
281  ["total_score"] + restraints_names
282  table_types = self.results_description_types + \
283  [float] + [float for r in restraints_names]
284  if add_measures:
285  # Add columns for measures
286  table_fields += self.results_measures_columns
287  table_types += self.results_measures_types
288  log.debug("Creating table %s\n%s",table_fields,table_types)
289  self.create_table(self.results_table, table_fields, table_types)
290  # create a table for the native assembly if we are benchmarking
291  if add_measures :
292  self.create_table(self.native_table_name, table_fields, table_types)
293 
294  def get_solutions_results_table(self, fields=False,
295  max_number=None, orderby=False):
296  """
297  Recovers solutions
298  @param fields Fields to recover from the table
299  @param max_number Maximum number of solutions to recover
300  @param orderby Name of the restraint used for sorting the states
301  """
302  self.check_if_is_connected()
303  log.info("Getting %s from solutions", fields)
304  f = self.get_fields_string(fields)
305  sql_command = "SELECT %s FROM %s " % (f, self.results_table)
306  if orderby:
307  sql_command += " ORDER BY %s ASC" % orderby
308  if max_number not in (None,False):
309  sql_command += " LIMIT %d" % (max_number)
310  log.debug("Using %s", sql_command )
311  data = self.retrieve_data(sql_command)
312  return data
313 
314  def get_solutions(self, fields=False, max_number=None, orderby=False):
315  """
316  Get solutions from the database.
317  @param fields Fields requested. If the fields are in different
318  tables, a left join is done. Otherwise get_solutions_results_table()
319  is called. See get_solutions_results_table() for the meaning
320  of the parameters.
321  """
322  tables = self.get_tables_names()
323  log.debug("tables %s", tables)
324  required_tables = set()
325  pairs_table_field = []
326 # fields_string = self.get_fields_string(fields)
327  if not fields:
328  fields = ["*",]
329  for f,t in [(f,t) for f in fields for t in tables]:
330  if t == "native" or f == "solution_id":
331  continue
332  columns = self.get_table_column_names(t)
333  if f in columns:
334  required_tables.add(t)
335  pairs_table_field.append((t,f))
336  required_tables = list(required_tables)
337  log.debug("required_tables %s", required_tables)
338  log.debug("pairs_table_field %s", pairs_table_field)
339  if len(required_tables) == 0:
340  data = self.get_solutions_results_table(fields,
341  max_number, orderby)
342  return data
343  elif len(required_tables) == 1 and required_tables[0] == "results":
344  data = self.get_solutions_results_table(fields,
345  max_number, orderby)
346  return data
347  elif len(required_tables) > 1:
348  sql_command = self.get_left_join_command( pairs_table_field,
349  required_tables)
350  if orderby:
351  sql_command += " ORDER BY %s ASC" % orderby
352  log.debug("Using %s", sql_command )
353  data = self.retrieve_data(sql_command)
354  return data
355  else:
356  raise ValueError("Fields not found in the database")
357 
358  def get_native_solution(self, fields=False):
359  """
360  Recover data for the native solution
361  @param fields Fields to recover
362  """
363 
364  f = self.get_fields_string(fields)
365  sql_command = "SELECT %s FROM %s " % (f, self.native_table_name)
366  data = self.retrieve_data(sql_command)
367  return data
368 
369  def add_record(self, solution_id, assignment, RFs, total_score,
370  restraints_scores, measures):
371  """
372  Add a recorde to the database
373  @param solution_id The key for the solution
374  @param assignment The assigment for the solution provided by
375  domino
376  @param RFs Reference frames of the rigid bodies of the components
377  of the assembly in the solution
378  @param total_score Total value of the scoring function
379  @param restraints_scores A list with all the values for the
380  restraints
381  @param measures A list with the values of all the measures for
382  benchmark
383  """
384  words = [io.ReferenceFrameToText(ref).get_text() for ref in RFs]
385  RFs_txt = unit_delim.join(words)
386  record = [solution_id, assignment, RFs_txt, total_score] + \
387  restraints_scores
388  if measures != None:
389  record = record + measures
390  self.records.append(record)
391 
392  def add_native_record(self, assignment, RFs, total_score,
393  restraints_scores):
394  """
395  Add a record for the native structure to the database
396  see add_record() for the meaning of the parameters
397  """
398  words = [io.ReferenceFrameToText(ref).get_text() for ref in RFs]
399  RFs_txt = unit_delim.join(words)
400  solution_id = 0
401  record = [solution_id, assignment, RFs_txt, total_score] + \
402  restraints_scores
403  measures = [0,0,0] # ["drms", "cdrms", "crmsd"]
404  record = record + measures
405  self.store_data(self.native_table_name, [record])
406 
407  def save_records(self,table="results"):
408  self.store_data(table, self.records)
409 
410  def format_placement_record(self, solution_id, distances, angles):
411  """ both distances and angles are expected to be a list of floats """
412  return [solution_id] + distances + angles
413 
414 
415  def add_placement_scores_table(self, names):
416  """
417  Creates a table to store the values of the placement scores for the
418  models.
419  @param names Names of the components of the assembly
420  """
421  self.check_if_is_connected()
422  self.placement_table_name = self.placements_table
423  table_fields = ["solution_id"]
424  table_fields += ["distance_%s" % name for name in names]
425  table_fields += ["angle_%s" % name for name in names]
426  table_types = [int] + [float for f in table_fields]
427  self.drop_table(self.placement_table_name)
428  self.create_table(self.placement_table_name, table_fields, table_types)
429  self.add_columns(self.native_table_name,
430  table_fields, table_types,check=True)
431  # update all placements scores to 0 for the native assembly
432  native_values = [0 for t in table_fields]
433  log.debug("%s", self.native_table_name)
434  log.debug("table fields %s", table_fields)
435  self.update_data(self.native_table_name,
436  table_fields, native_values,
437  ["assignment"], ["\"native\""])
438 
439  def get_placement_fields(self):
440  """
441  Return the names of the placement score fields in the database
442  """
443  columns = self.get_table_column_names(self.placements_table)
444  fields = [col for col in columns if "distance" in col or "angle" in col]
445  return fields
446 
447  def add_ccc_table(self):
448  """
449  Add a table to the database for store the values of the cross
450  correlation coefficient between a model and the native configuration
451  """
452 
453  self.check_if_is_connected()
454  table_fields = ["solution_id", "ccc"]
455  table_types = [int, float]
456  self.drop_table(self.ccc_table_name)
457  self.create_table(self.ccc_table_name, table_fields, table_types)
458  # update values for the native assembly
459  self.add_columns(self.native_table_name,
460  table_fields, table_types,check=True)
461  self.update_data(self.native_table_name,
462  table_fields, [0,1.00], ["assignment"], ["\"native\""])
463 
464  def format_ccc_record(self, solution_id, ccc):
465  """ Format for the record to store in the ccc table """
466  return [solution_id, ccc]
467 
468  def get_ccc(self, solution_id):
469  """
470  Recover the cross-correlation coefficient for a solution
471  @param solution_id
472  """
473  sql_command = """ SELECT ccc FROM %s
474  WHERE solution_id=%d """ % (self.ccc_table_name,
475  solution_id)
476  data = self.retrieve_data(sql_command)
477  return data[0][0]
478 
479  def store_ccc_data(self, ccc_data):
480  self.store_data(self.ccc_table_name, ccc_data)
481 
482  def store_placement_data(self, data):
483  log.debug("store placement table %s",data)
484  self.store_data(self.placement_table_name,data)
485 
486  def get_left_join_command(self, pairs_table_field, tables_names):
487  """
488  Format a left join SQL command that recovers all fileds from the
489  tables given
490  @param pairs_table_field Pairs of (table,field)
491  @param tables_names Names of the tables
492 
493  E.g. If pairs_table_filed = ((table1,a), (table2,b), (table3,c),
494  (table2,d)) and tables_names = (table1, table2, table3)
495 
496  The SQL command is:
497  SELECT table1.a, table2.b, table3.c, table2.d FROM table1
498  LEFT JOIN table2 ON table1.solution_id = table2.solution_id
499  LEFT JOIN table3 ON table1.solution_id = table3.solution_id
500  WHERE table1.solution_id IS NOT NULL AND
501  table2.solution_id IS NOT NULL AND
502  table3.solution_id IS NOT NULL
503  """
504 
505  txt = [ "%s.%s" % (p[0],p[1]) for p in pairs_table_field]
506  fields_requested = field_delim.join(txt)
507  sql_command = " SELECT %s FROM %s " % (fields_requested,tables_names[0])
508  n_tables = len(tables_names)
509  for i in range(1, n_tables):
510  a = tables_names[i-1]
511  b = tables_names[i]
512  sql_command += " LEFT JOIN %s " \
513  "ON %s.solution_id = %s.solution_id " % (b,a,b)
514  # add the condition of solution_id being not null, so there are not
515  # problems if some solutions are missing in one table
516  for i in range(n_tables-1):
517  sql_command += "WHERE %s.solution_id " \
518  "IS NOT NULL AND " % tables_names[i]
519  sql_command += " %s.solution_id IS NOT NULL " % tables_names[n_tables-1]
520  log.debug("%s" %sql_command)
521  return sql_command
522 
523  def add_clusters_table(self, name):
524  """
525  Add a table to store information about the clusters of structures
526  @param name Name of the table
527  """
528  self.cluster_table_name = name
529  self.check_if_is_connected()
530  table_fields = ("cluster_id","n_elements",
531  "representative","elements", "solutions_ids")
532  table_types = (int, int, int, str, str)
533  self.drop_table(name)
534  self.create_table(name, table_fields, table_types)
535 
536  def add_cluster_record(self, cluster_id, n_elements, representative,
537  elements, solutions_ids):
538  """
539  Add a record to the cluster database. Actually, only stores it
540  in a list (that will be added later)
541  @param cluster_id Number with the id of the cluster
542  @param n_elements Number of elements in the cluster
543  @param representative Number with the id of the representative
544  element
545  @param elements List with the number of the elements of the cluster
546  @param solutions_ids The numbers above are provided by the
547  clustering algorithm. The solutions_ids are the ids of the models
548  in "elements".
549  """
550 
551  record = (cluster_id, n_elements, representative, elements,
552  solutions_ids)
553  log.debug("Adding cluster record: %s", record)
554  self.cluster_records.append(record)
555 
556  def store_cluster_data(self):
557  """
558  Store the data for the clusters
559  """
560  log.info("Storing data of clusters. Number of records %s",
561  len(self.cluster_records) )
562  self.store_data(self.cluster_table_name, self.cluster_records)
563 
564  def get_solutions_from_list(self, fields=False, solutions_ids=[]):
565  """
566  Recover solutions for a specific list of results
567  @param fields Fields to recover fro the database
568  @param solutions_ids A list with the desired solutions. E.g. [0,3,6]
569  """
570  sql_command = """ SELECT %s FROM %s WHERE solution_id IN (%s) """
571  f = self.get_fields_string(fields)
572  str_ids = ",".join(map(str,solutions_ids))
573  data = self.retrieve_data( sql_command % (f, self.results_table, str_ids ) )
574  return data
575 
576  def get_native_rank(self, orderby):
577  """
578  Get the position of the native configuration
579  @param orderby Criterium used to sort the solutions
580  """
581  import numpy as np
582 
583  data = self.get_native_solution([orderby,])
584  native_value = data[0][0]
585  data = self.get_solutions_results_table(fields=[orderby,],
586  orderby=orderby)
587  values = [row[0] for row in data]
588  rank = np.searchsorted(values,native_value)
589  return rank
590 
591  def get_nth_largest_cluster(self, position, table_name="clusters"):
592  """
593  Recover the the information about the n-th largest cluster
594  @param position Cluster position (by size) requested
595  (1 is the largest cluster)
596  @param table_name Table where the information about the
597  clusters is stored
598  """
599  s = """ SELECT * FROM %s ORDER BY n_elements DESC """ % table_name
600  data = self.retrieve_data(s)
601  record = ClusterRecord(data[position-1])
602  return record
603 
604 
605  def get_individual_placement_statistics(self, solutions_ids):
606  """
607  Recovers from the database the placement scores for a set of
608  solutions, and returns the mean and standard deviation of the
609  placement score for each of the components of the complex being
610  scored. This function will be typical used to compute the variation
611  of the placement of each component within a cluster of solutions
612  @param solutions_ids The ids of the solutions used to compute
613  the statistics
614  @return The output are 4 numpy vectors:
615  placement_distances_mean - The mean placement distance for each
616  component
617  placement_distances_stddev - The standardd deviation of the
618  placement distance for each component
619  placement_angles_mean - The mean placement angle for each
620  component
621  placement_angles_stddev - The standard deviation of the placement
622  angle for each component,
623  """
624 
625  self.check_if_is_connected()
626  table = self.placements_table
627  fields = self.get_table_column_names(table)
628  distance_fields = filter(lambda x: 'distance' in x, fields)
629  angle_fields = filter(lambda x: 'angle' in x, fields)
630  sql_command = """ SELECT %s FROM %s WHERE solution_id IN (%s) """
631  # string with the solution ids to pass to the sql_command
632  str_ids = ",".join(map(str,solutions_ids))
633  log.debug("Solutions considered %s", solutions_ids)
634  s = sql_command % (",".join(distance_fields), table, str_ids )
635  data_distances = self.retrieve_data(s)
636  s = sql_command % (",".join(angle_fields), table, str_ids )
637  data_angles = self.retrieve_data(s)
638  D = np.array(data_distances)
639  placement_distances_mean = D.mean(axis=0)
640  placement_distances_stddev = D.std(axis=0)
641  A = np.array(data_angles)
642  placement_angles_mean = A.mean(axis=0)
643  placement_angles_stddev = A.std(axis=0)
644  return [placement_distances_mean,placement_distances_stddev,
645  placement_angles_mean, placement_angles_stddev]
646 
647 
648  def get_placement_statistics(self, solutions_ids):
649  """
650  Calculate the placement score and its standard deviation for
651  the complexes in a set of solutions. The values returned are
652  averages, as the placement score for a complex is the average
653  of the placement scores of the components. This function is used
654  to obtain global placement for a cluster of solutions.
655  @param solutions_ids The ids of the solutions used to compute
656  the statistics
657  @return The output are 4 values:
658  plcd_mean - Average of the placement distance for the entire
659  complex over all the solutions.
660  plcd_std - Standard deviation of the placement distance for
661  the entire complex over all the solutions.
662  plca_mean - Average of the placement angle for the entire
663  complex over all the solutions.
664  plca_std - Standard deviation of the placement angle for
665  the entire complex over all the solutions.
666  """
667  [placement_distances_mean,placement_distances_stddev,
668  placement_angles_mean, placement_angles_stddev] = \
669  self.get_individual_placement_statistics(solutions_ids)
670  plcd_mean = placement_distances_mean.mean(axis=0)
671  plcd_std = placement_distances_stddev.mean(axis=0)
672  plca_mean = placement_angles_mean.mean(axis=0)
673  plca_std = placement_angles_stddev.mean(axis=0)
674  return [plcd_mean, plcd_std, plca_mean, plca_std]