IMP  2.0.1
The Integrative Modeling Platform
Database.py
1 
2 import sqlite3 as sqlite
3 import os
4 import csv
5 import logging
6 
# Logger shared by every class and helper in this module
log = logging.getLogger(name="Database")


class Database2:
    """Manage a SQLite database through the standard ``sqlite3`` module.

    Call create() (optional, to make a new file) and then connect() before
    using the other methods; close() releases the connection.

    Table, column and view names -- and the values passed to update_data(),
    create_view() and get_condition_string() -- are pasted directly into the
    SQL text, so they must come from trusted code, never from user input.
    """

    def __init__(self):
        # Connection to the database (None while not connected)
        self.connection = None
        # Cursor used to execute every statement
        self.cursor = None

    def create(self, filename, overwrite=False):
        """Create the database file *filename*.

        If *overwrite* is True an existing file is deleted first. The
        connection used to create the file is closed again right away; call
        connect() afterwards to use the database.
        """
        log.info("Creating database")
        if overwrite and os.path.exists(filename):
            os.remove(filename)
        # Opening the file creates it; close the handle so it does not leak
        sqlite.connect(filename).close()

    def connect(self, filename):
        """Connect to the existing database file *filename*.

        Raises IOError if the file does not exist (sqlite3 would otherwise
        silently create an empty database).
        """
        if not os.path.isfile(filename):
            raise IOError("Database file not found: %s" % filename)
        self.connection = sqlite.connect(filename)
        self.cursor = self.connection.cursor()

    def check_if_is_connected(self):
        """Raise ValueError if there is no open connection."""
        if self.connection is None:
            raise ValueError("The database has not been created "
                             "or connection not established ")

    def create_table(self, table_name, column_names, column_types):
        """Create the table *table_name*.

        column_names and column_types are parallel sequences; each type is a
        Python type (int, float or str) translated by get_sql_type_name().
        """
        self.check_if_is_connected()
        log.info("Creating table %s", table_name)
        columns = ",".join("%s %s" % (name, get_sql_type_name(data_type))
                           for name, data_type in zip(column_names,
                                                      column_types))
        sql_command = "CREATE TABLE %s (%s)" % (table_name, columns)
        log.debug(sql_command)
        self.cursor.execute(sql_command)
        self.connection.commit()

    def drop_table(self, table_name):
        """Delete the table *table_name* if it exists."""
        self.check_if_is_connected()
        log.info("Deleting table %s", table_name)
        sql_command = "DROP TABLE IF EXISTS %s" % (table_name)
        log.debug(sql_command)
        self.cursor.execute(sql_command)
        self.connection.commit()

    def _insert_command(self, table_name, n_columns):
        """Return an INSERT statement with *n_columns* '?' placeholders."""
        return "INSERT INTO %s VALUES (%s) " % (table_name,
                                                ",".join("?" * n_columns))

    def _convert_row(self, row, types):
        """Return *row* as a new list with each converter applied to its value."""
        return [apply_type(value) for value, apply_type in zip(row, types)]

    def store_dataV1(self, table_name, data):
        """Insert the rows in *data* into *table_name*, one statement per row.

        Same contract as store_data(): every row holds one value per column,
        and values are converted with the types read by get_table_types().
        Empty *data* only logs a warning.
        """
        if len(data) == 0:
            log.warning("Inserting empty data")
            return
        sql_command = self._insert_command(table_name, len(data[0]))
        types = self.get_table_types(table_name)
        for row in data:
            self.cursor.execute(sql_command, self._convert_row(row, types))
        self.connection.commit()

    def store_data(self, table_name, data):
        """Insert the rows in *data* into *table_name*.

        Each row must contain as many values as the table has columns.
        Values are converted AUTOMATICALLY with the types returned by
        get_table_types(). *data* itself is not modified. Empty *data* only
        logs a warning.
        """
        if len(data) == 0:
            log.warning("Inserting empty data")
            return
        sql_command = self._insert_command(table_name, len(data[0]))
        types = self.get_table_types(table_name)
        # Build new rows rather than overwriting the caller's sequence
        rows = [self._convert_row(row, types) for row in data]
        self.cursor.executemany(sql_command, rows)
        self.connection.commit()

    def retrieve_data(self, sql_command):
        """Run *sql_command* and return all resulting records as a list of tuples."""
        self.check_if_is_connected()
        log.debug("Retrieving data: %s", sql_command)
        self.cursor.execute(sql_command)
        return self.cursor.fetchall()

    def update_data(self, table_name,
                    updated_fields,
                    updated_values,
                    condition_fields,
                    condition_values):
        """Set updated_fields to updated_values in the rows of *table_name*
        where every condition field equals its condition value.

        Values are inserted verbatim into the SQL text, so string values
        must already be quoted (e.g. "'abc'").
        """
        self.check_if_is_connected()
        assignments = ",".join("%s=%s" % (field, value)
                               for field, value in zip(updated_fields,
                                                       updated_values))
        condition = self.get_condition_string(condition_fields,
                                              condition_values)
        sql_command = "UPDATE %s SET %s WHERE %s" % (table_name, assignments,
                                                     condition)
        log.debug("Updating %s: %s", table_name, sql_command)
        self.cursor.execute(sql_command)
        self.connection.commit()

    def create_view(self, view_name, table_name,
                    condition_fields, condition_values):
        """(Re)create view *view_name* with the rows of *table_name* where
        every condition field equals its condition value (see update_data()).
        """
        self.check_if_is_connected()
        try:
            self.drop_view(view_name)
        except sqlite.OperationalError:
            pass  # the view did not exist yet
        sql_command = 'CREATE VIEW %s AS SELECT * FROM %s WHERE ' % (view_name,
                                                                     table_name)
        sql_command += self.get_condition_string(condition_fields,
                                                 condition_values)
        log.info("Creating view %s", sql_command)
        self.cursor.execute(sql_command)
        self.connection.commit()

    def create_view_of_best_records(self, view_name, table_name, orderby,
                                    n_records):
        """(Re)create view *view_name* holding the *n_records* rows of
        *table_name* with the smallest values of *orderby*.
        """
        self.check_if_is_connected()
        try:
            self.drop_view(view_name)
        except sqlite.OperationalError:
            pass  # the view did not exist yet
        sql_command = """CREATE VIEW %s AS SELECT * FROM %s
                    ORDER BY %s ASC LIMIT %d """ % (view_name, table_name,
                                                    orderby, n_records)
        log.info("Creating view %s", sql_command)
        self.cursor.execute(sql_command)
        self.connection.commit()

    def drop_view(self, view_name):
        """Remove a view; sqlite3.OperationalError if it does not exist."""
        self.check_if_is_connected()
        self.cursor.execute('DROP VIEW %s' % view_name)

    def get_table(self, table_name, fields=False, orderby=False):
        """Return the requested *fields* (all if falsy) of *table_name*,
        optionally sorted ascending by *orderby*.
        """
        fields = self.get_fields_string(fields)
        sql_command = "SELECT %s FROM %s " % (fields, table_name)
        if orderby:
            sql_command += " ORDER BY %s ASC" % orderby
        return self.retrieve_data(sql_command)

    def get_fields_string(self, fields, field_delim=","):
        """Join *fields* with *field_delim*, or return "*" when none are given."""
        if fields:
            return field_delim.join(fields)
        return "*"

    def close(self):
        """Close the cursor and the connection."""
        self.check_if_is_connected()
        self.cursor.close()
        self.connection.close()
        # Forget the handles so later calls fail with check_if_is_connected()
        # instead of using a closed connection
        self.cursor = None
        self.connection = None

    def get_condition_string(self, fields, values):
        """Return "f1=v1 AND f2=v2 ..." ("" when there are no fields)."""
        return " AND ".join("%s=%s" % (field, value)
                            for field, value in zip(fields, values))

    def get_table_types(self, name):
        """Return one converter per column of table *name*, in column order.

        INT -> int, DOUBLE -> float, VARCHAR(...) -> str; any other declared
        type (or none) gets a converter that returns the value unchanged, so
        the list always lines up with the table's columns.
        """
        self.check_if_is_connected()
        sql_command = "PRAGMA table_info(%s)" % name
        self.cursor.execute(sql_command)
        info = self.cursor.fetchall()

        def unchanged(value):
            return value

        types = []
        for row in info:
            declared = row[2]
            if declared == "INT":
                types.append(int)
            elif declared == "DOUBLE":
                types.append(float)
            elif declared[0:7] == "VARCHAR":
                types.append(str)
            else:
                types.append(unchanged)
        return types

    def get_table_column_names(self, name):
        """Return the column names of table *name*, in column order."""
        self.check_if_is_connected()
        sql_command = "PRAGMA table_info(%s)" % name
        self.cursor.execute(sql_command)
        info = self.cursor.fetchall()
        return [row[1] for row in info]

    def execute_sql_command(self, sql_command):
        """Execute *sql_command* and commit."""
        self.check_if_is_connected()
        self.cursor.execute(sql_command)
        self.connection.commit()

    def add_column(self, table, column, data_type):
        """Add a column to a table.

        column - the name of the column.
        data_type - the type: int, float, str
        """
        sql_typename = get_sql_type_name(data_type)
        sql_command = "ALTER TABLE %s ADD %s %s" % (table, column, sql_typename)
        self.execute_sql_command(sql_command)

    def add_columns(self, table, names, types, check=True):
        """Add columns to the table.

        If check=True, names already in the table (or repeated in *names*)
        are skipped. If check=False no check is done, and adding a column
        that already exists raises an exception.
        """
        existing = set(self.get_table_column_names(table))
        for name, dtype in zip(names, types):
            if check and name in existing:
                continue
            self.add_column(table, name, dtype)
            existing.add(name)

    def get_tables_names(self):
        """Return the names of all tables and views (each name once)."""
        # Index and trigger rows repeat their table's tbl_name; skip them
        sql_command = """ SELECT tbl_name FROM sqlite_master
                          WHERE type IN ('table', 'view') """
        return [d[0] for d in self.retrieve_data(sql_command)]

    def select_table(self):
        """Ask the user (y/n) about each table in turn.

        Returns (table_name, column_names) for the first table accepted, or
        ("", []) if every table is rejected.
        """
        self.check_if_is_connected()
        try:
            ask = raw_input  # Python 2
        except NameError:
            ask = input  # Python 3
        for t in self.get_tables_names():
            say = ''
            while say not in ('n', 'y'):
                say = ask("Use table %s (y/n) " % t)
            if say == 'y':
                return t, self.get_table_column_names(t)
        return "", []

    def drop_columns(self, table, columns):
        """Remove *columns* from *table* and commit.

        The table is copied through a temporary table and recreated, keeping
        the declared types of the remaining columns (other constraints are
        not preserved). Raises ValueError if a column is not in the table.
        """
        cnames = self.get_table_column_names(table)
        for name in columns:
            cnames.remove(name)  # ValueError for unknown columns
        self.cursor.execute("PRAGMA table_info(%s)" % table)
        declared = dict((row[1], row[2]) for row in self.cursor.fetchall())
        names_txt = ", ".join(cnames)
        definitions_txt = ", ".join(("%s %s" % (n, declared[n])).rstrip()
                                    for n in cnames)
        sql_command = [
            "CREATE TEMPORARY TABLE backup(%s);" % names_txt,
            "INSERT INTO backup SELECT %s FROM %s" % (names_txt, table),
            "DROP TABLE %s;" % table,
            "CREATE TABLE %s(%s);" % (table, definitions_txt),
            "INSERT INTO %s SELECT * FROM backup;" % table,
            "DROP TABLE backup;",
        ]
        for command in sql_command:
            log.debug(command)
            self.cursor.execute(command)
        # The INSERTs open a transaction; without this the rebuild is lost
        self.connection.commit()


def print_data(data, delimiter=" "):
    """Print each row of *data* on its own line.

    The values of a row are converted with str() and joined with
    *delimiter*. Uses the print-function call form, which behaves the same
    on Python 2 (one parenthesised argument) and Python 3.
    """
    for row in data:
        print(delimiter.join(str(x) for x in row))


def write_data(data, output_file, delimiter=" "):
    """Write every row of *data* as one CSV record to *output_file*.

    *output_file* is an already-open Python file object; fields are
    separated by *delimiter*.
    """
    csv.writer(output_file, delimiter=delimiter).writerows(data)


def get_sql_type_name(data_type):
    """Return the SQL column type used for the Python type *data_type*.

    int -> "INT", float -> "DOUBLE", str -> "VARCHAR(10)"; any other value
    gives None.
    """
    # The 10 in VARCHAR(10) is arbitrary: SQLite does not truncate strings
    for python_type, sql_name in ((int, "INT"),
                                  (float, "DOUBLE"),
                                  (str, "VARCHAR(10)")):
        if data_type == python_type:
            return sql_name
    return None


def open(fn_database):
    """Return a Database2 connected to the existing file *fn_database*.

    Note: this module-level name shadows the builtin open() inside this
    module; it is kept for callers that use it.
    """
    database = Database2()
    database.connect(fn_database)
    return database


def read_data(fn_database, sql_command):
    """Run *sql_command* on the database file *fn_database* and return all
    records as a list of tuples.

    The connection is closed even when the query raises.
    """
    db = Database2()
    db.connect(fn_database)
    try:
        return db.retrieve_data(sql_command)
    finally:
        db.close()


def get_sorting_indices(l):
    """Return the indices of *l* ordered by ascending value.

    Equal values are ordered by their index, so ties keep their original
    relative order.
    """
    order = sorted((value, index) for index, value in enumerate(l))
    return [index for _value, index in order]


def merge_databases(fns, fn_output, tbl):
    """Copy table *tbl* from every database file in *fns* into *fn_output*.

    The output table takes its column names and types from the first file,
    with the columns sorted by name. Each input file's columns are read in
    that same sorted order, so files whose columns were created in a
    different order still line up. *fn_output* is overwritten if it exists.
    Every connection is closed, including when an error occurs.
    """
    # Names and types of the columns, taken from the first database file;
    # close that connection before reading the files one by one below
    db = Database2()
    db.connect(fns[0])
    try:
        names = db.get_table_column_names(tbl)
        types = db.get_table_types(tbl)
    finally:
        db.close()
    indices = get_sorting_indices(names)
    sorted_names = [names[i] for i in indices]
    sorted_types = [types[i] for i in indices]
    log.info("Merging databases. Saving to %s", fn_output)
    out_db = Database2()
    out_db.create(fn_output, overwrite=True)
    out_db.connect(fn_output)
    try:
        out_db.create_table(tbl, sorted_names, sorted_types)
        for fn in fns:
            log.debug("Reading %s", fn)
            db.connect(fn)
            try:
                they_are_sorted = ",".join(
                    sorted(db.get_table_column_names(tbl)))
                log.debug("Retrieving %s", they_are_sorted)
                sql_command = "SELECT %s FROM %s" % (they_are_sorted, tbl)
                data = db.retrieve_data(sql_command)
            finally:
                db.close()
            out_db.store_data(tbl, data)
    finally:
        out_db.close()