# output.py -- part of IMP 2.3.1, the Integrative Modeling Platform
"""@namespace IMP.pmi.output
   Classes for writing output files and processing them.
"""
4 
5 import IMP
6 import IMP.atom
7 import IMP.core
8 import IMP.pmi
9 import IMP.pmi.tools
10 import os
11 import RMF
12 import numpy as np
13 import operator
14 try:
15  import cPickle as pickle
16 except ImportError:
17  import pickle
18 
19 class Output(object):
20  """Class for easy writing of PDBs, RMFs, and stat files"""
    def __init__(self, ascii=True,atomistic=False):
        """Set up empty registries for PDB, RMF and stat file output.

        @param ascii      write stat files as ASCII text if True
                          (pickled binary otherwise)
        @param atomistic  write atom-level PDB records when particles
                          carry atomic information
        """
        # output-file registries: file name -> data needed to write it
        self.dictionary_pdbs = {}
        self.dictionary_rmfs = {}
        self.dictionary_stats = {}
        self.dictionary_stats2 = {}
        # best-scoring bookkeeping, configured by init_pdb_best_scoring()
        self.best_score_list = None
        self.nbestscoring = None
        self.suffixes = []
        self.replica_exchange = False
        self.ascii = ascii
        # constant entries added to every stat record (see set_output_entry)
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # chain ids assigned to top-level components in registration order
        # NOTE(review): "...UVXYWZ..." is not alphabetical -- presumably
        # intentional; verify against previously written files before fixing
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVXYWZabcdefghijklmnopqrstuvxywz"
        self.dictchain = {}
        self.particle_infos_for_pdb = {}
        self.atomistic=atomistic
37 
38  def get_pdb_names(self):
39  return self.dictionary_pdbs.keys()
40 
41  def get_rmf_names(self):
42  return self.dictionary_rmfs.keys()
43 
44  def get_stat_names(self):
45  return self.dictionary_stats.keys()
46 
47  def init_pdb(self, name, prot):
48  flpdb = open(name, 'w')
49  flpdb.close()
50  self.dictionary_pdbs[name] = prot
51  self.dictchain[name] = {}
52 
53  for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
54  self.dictchain[name][i.get_name()] = self.chainids[n]
55 
56  def write_pdb(self,name,appendmode=True,
57  translate_to_geometric_center=False):
58  if appendmode:
59  flpdb = open(name, 'a')
60  else:
61  flpdb = open(name, 'w')
62 
63  (particle_infos_for_pdb,
64  geometric_center) = self.get_particle_infos_for_pdb_writing(name)
65 
66  if not translate_to_geometric_center:
67  geometric_center = (0, 0, 0)
68 
69  for tupl in particle_infos_for_pdb:
70 
71  (xyz, atom_index, atom_type, residue_type,
72  chain_id, residue_index,radius) = tupl
73 
74  flpdb.write(IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
75  xyz[1] - geometric_center[1],
76  xyz[2] - geometric_center[2]),
77  atom_index, atom_type, residue_type,
78  chain_id, residue_index,' ',1.00,radius))
79 
80  flpdb.write("ENDMDL\n")
81  flpdb.close()
82 
83  del particle_infos_for_pdb
84 
85  def get_particle_infos_for_pdb_writing(self, name):
86  # index_residue_pair_list={}
87 
88  # the resindexes dictionary keep track of residues that have been already
89  # added to avoid duplication
90  # highest resolution have highest priority
91  resindexes_dict = {}
92 
93  # this dictionary dill contain the sequence of tuples needed to
94  # write the pdb
95  particle_infos_for_pdb = []
96 
97  geometric_center = [0, 0, 0]
98  atom_count = 0
99  atom_index = 0
100 
101  for n, p in enumerate(IMP.atom.get_leaves(self.dictionary_pdbs[name])):
102 
103  # this loop gets the protein name from the
104  # particle leave by descending into the hierarchy
105 
106  (protname, is_a_bead) = IMP.pmi.tools.get_prot_name_from_particle(
107  p, self.dictchain[name])
108 
109  if protname not in resindexes_dict:
110  resindexes_dict[protname] = []
111 
112  if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
113  atom_index += 1
114  residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
115  rt = residue.get_residue_type()
116  resind = residue.get_index()
117  atomtype = IMP.atom.Atom(p).get_atom_type()
118  xyz = list(IMP.core.XYZ(p).get_coordinates())
119  radius = IMP.core.XYZR(p).get_radius()
120  geometric_center[0] += xyz[0]
121  geometric_center[1] += xyz[1]
122  geometric_center[2] += xyz[2]
123  atom_count += 1
124  particle_infos_for_pdb.append((xyz, atom_index,
125  atomtype, rt, self.dictchain[name][protname], resind,radius))
126  resindexes_dict[protname].append(resind)
127 
129 
130  residue = IMP.atom.Residue(p)
131  resind = residue.get_index()
132  # skip if the residue was already added by atomistic resolution
133  # 0
134  if resind in resindexes_dict[protname]:
135  continue
136  else:
137  resindexes_dict[protname].append(resind)
138  atom_index += 1
139  rt = residue.get_residue_type()
140  xyz = IMP.core.XYZ(p).get_coordinates()
141  radius = IMP.core.XYZR(p).get_radius()
142  geometric_center[0] += xyz[0]
143  geometric_center[1] += xyz[1]
144  geometric_center[2] += xyz[2]
145  atom_count += 1
146  particle_infos_for_pdb.append((xyz, atom_index,
147  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
148 
149  # if protname not in index_residue_pair_list:
150  # index_residue_pair_list[protname]=[(atom_index,resind)]
151  # else:
152  # index_residue_pair_list[protname].append((atom_index,resind))
153 
154  elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
155  resindexes = IMP.pmi.tools.get_residue_indexes(p)
156  resind = resindexes[len(resindexes) / 2]
157  if resind in resindexes_dict[protname]:
158  continue
159  else:
160  resindexes_dict[protname].append(resind)
161  atom_index += 1
162  rt = IMP.atom.ResidueType('BEA')
163  xyz = IMP.core.XYZ(p).get_coordinates()
164  radius = IMP.core.XYZR(p).get_radius()
165  geometric_center[0] += xyz[0]
166  geometric_center[1] += xyz[1]
167  geometric_center[2] += xyz[2]
168  atom_count += 1
169  particle_infos_for_pdb.append((xyz, atom_index,
170  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
171 
172  else:
173  if is_a_bead:
174  atom_index += 1
175  rt = IMP.atom.ResidueType('BEA')
176  resindexes = IMP.pmi.tools.get_residue_indexes(p)
177  resind = resindexes[len(resindexes) / 2]
178  xyz = IMP.core.XYZ(p).get_coordinates()
179  radius = IMP.core.XYZR(p).get_radius()
180  geometric_center[0] += xyz[0]
181  geometric_center[1] += xyz[1]
182  geometric_center[2] += xyz[2]
183  atom_count += 1
184  particle_infos_for_pdb.append((xyz, atom_index,
185  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
186  # if protname not in index_residue_pair_list:
187  # index_residue_pair_list[protname]=[(atom_index,resind)]
188  # else:
189  # index_residue_pair_list[protname].append((atom_index,resind))
190 
191  geometric_center = (geometric_center[0] / atom_count,
192  geometric_center[1] / atom_count,
193  geometric_center[2] / atom_count)
194 
195  return (particle_infos_for_pdb, geometric_center)
196  '''
197  #now write the connectivity
198  for protname in index_residue_pair_list:
199 
200  ls=index_residue_pair_list[protname]
201  #sort by residue
202  ls=sorted(ls, key=lambda tup: tup[1])
203  #get the index list
204  indexes=map(list, zip(*ls))[0]
205  # get the contiguous pairs
206  indexes_pairs=list(IMP.pmi.tools.sublist_iterator(indexes,lmin=2,lmax=2))
207  #write the connection record only if the residue gap is larger than 1
208 
209  for ip in indexes_pairs:
210  if abs(ip[1]-ip[0])>1:
211  flpdb.write('{:6s}{:5d}{:5d}'.format('CONECT',ip[0],ip[1]))
212  flpdb.write("\n")
213  '''
214 
215  def write_pdbs(self, appendmode=True):
216  for pdb in self.dictionary_pdbs.keys():
217  self.write_pdb(pdb, appendmode)
218 
219  def init_pdb_best_scoring(
220  self,
221  suffix,
222  prot,
223  nbestscoring,
224  replica_exchange=False):
225  # save only the nbestscoring conformations
226  # create as many pdbs as needed
227 
228  self.suffixes.append(suffix)
229  self.replica_exchange = replica_exchange
230  if not self.replica_exchange:
231  # common usage
232  # if you are not in replica exchange mode
233  # initialize the array of scores internally
234  self.best_score_list = []
235  else:
236  # otherwise the replicas must cominucate
237  # through a common file to know what are the best scores
238  self.best_score_file_name = "best.scores.rex.py"
239  self.best_score_list = []
240  best_score_file = open(self.best_score_file_name, "w")
241  best_score_file.write(
242  "self.best_score_list=" + str(self.best_score_list))
243  best_score_file.close()
244 
245  self.nbestscoring = nbestscoring
246  for i in range(self.nbestscoring):
247  name = suffix + "." + str(i) + ".pdb"
248  flpdb = open(name, 'w')
249  flpdb.close()
250  self.dictionary_pdbs[name] = prot
251  self.dictchain[name] = {}
252  for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
253  self.dictchain[name][i.get_name()] = self.chainids[n]
254 
255  def write_pdb_best_scoring(self, score):
256  if self.nbestscoring is None:
257  print "Output.write_pdb_best_scoring: init_pdb_best_scoring not run"
258 
259  # update the score list
260  if self.replica_exchange:
261  # read the self.best_score_list from the file
262  execfile(self.best_score_file_name)
263 
264  if len(self.best_score_list) < self.nbestscoring:
265  self.best_score_list.append(score)
266  self.best_score_list.sort()
267  index = self.best_score_list.index(score)
268  for suffix in self.suffixes:
269  for i in range(len(self.best_score_list) - 2, index - 1, -1):
270  oldname = suffix + "." + str(i) + ".pdb"
271  newname = suffix + "." + str(i + 1) + ".pdb"
272  # rename on Windows fails if newname already exists
273  if os.path.exists(newname):
274  os.unlink(newname)
275  os.rename(oldname, newname)
276  filetoadd = suffix + "." + str(index) + ".pdb"
277  self.write_pdb(filetoadd, appendmode=False)
278 
279  else:
280  if score < self.best_score_list[-1]:
281  self.best_score_list.append(score)
282  self.best_score_list.sort()
283  self.best_score_list.pop(-1)
284  index = self.best_score_list.index(score)
285  for suffix in self.suffixes:
286  for i in range(len(self.best_score_list) - 1, index - 1, -1):
287  oldname = suffix + "." + str(i) + ".pdb"
288  newname = suffix + "." + str(i + 1) + ".pdb"
289  os.rename(oldname, newname)
290  filenametoremove = suffix + \
291  "." + str(self.nbestscoring) + ".pdb"
292  os.remove(filenametoremove)
293  filetoadd = suffix + "." + str(index) + ".pdb"
294  self.write_pdb(filetoadd, appendmode=False)
295 
296  if self.replica_exchange:
297  # write the self.best_score_list to the file
298  best_score_file = open(self.best_score_file_name, "w")
299  best_score_file.write(
300  "self.best_score_list=" + str(self.best_score_list))
301  best_score_file.close()
302 
303  def init_rmf(self, name, hierarchies,rs=None):
304  rh = RMF.create_rmf_file(name)
305  IMP.rmf.add_hierarchies(rh, hierarchies)
306  if rs is not None:
308  self.dictionary_rmfs[name] = rh
309 
310  def add_restraints_to_rmf(self, name, objectlist):
311  for o in objectlist:
312  try:
313  rs = o.get_restraint_for_rmf()
314  except:
315  rs = o.get_restraint()
317  self.dictionary_rmfs[name],
318  rs.get_restraints())
319 
320  def add_geometries_to_rmf(self, name, objectlist):
321  for o in objectlist:
322  geos = o.get_geometries()
323  IMP.rmf.add_geometries(self.dictionary_rmfs[name], geos)
324 
325  def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
326  for o in objectlist:
327 
328  pps = o.get_particle_pairs()
329  for pp in pps:
330  IMP.rmf.add_geometry(
331  self.dictionary_rmfs[name],
333 
334  def write_rmf(self, name):
335  IMP.rmf.save_frame(self.dictionary_rmfs[name])
336  self.dictionary_rmfs[name].flush()
337 
338  def close_rmf(self, name):
339  del self.dictionary_rmfs[name]
340 
341  def write_rmfs(self):
342  for rmf in self.dictionary_rmfs.keys():
343  self.write_rmf(rmf)
344 
345  def init_stat(self, name, listofobjects):
346  if self.ascii:
347  flstat = open(name, 'w')
348  flstat.close()
349  else:
350  flstat = open(name, 'wb')
351  flstat.close()
352 
353  # check that all objects in listofobjects have a get_output method
354  for l in listofobjects:
355  if not "get_output" in dir(l):
356  raise ValueError("Output: object %s doesn't have get_output() method" % str(l))
357  self.dictionary_stats[name] = listofobjects
358 
359  def set_output_entry(self, key, value):
360  self.initoutput.update({key: value})
361 
362  def write_stat(self, name, appendmode=True):
363  output = self.initoutput
364  for obj in self.dictionary_stats[name]:
365  d = obj.get_output()
366  # remove all entries that begin with _ (private entries)
367  dfiltered = dict((k, v) for k, v in d.iteritems() if k[0] != "_")
368  output.update(dfiltered)
369 
370  if appendmode:
371  writeflag = 'a'
372  else:
373  writeflag = 'w'
374 
375  if self.ascii:
376  flstat = open(name, writeflag)
377  flstat.write("%s \n" % output)
378  flstat.close()
379  else:
380  flstat = open(name, writeflag + 'b')
381  cPickle.dump(output, flstat, 2)
382  flstat.close()
383 
384  def write_stats(self):
385  for stat in self.dictionary_stats.keys():
386  self.write_stat(stat)
387 
388  def get_stat(self, name):
389  output = {}
390  for obj in self.dictionary_stats[name]:
391  output.update(obj.get_output())
392  return output
393 
394  def write_test(self, name, listofobjects):
395  '''
396  write the test:
397  output=output.Output()
398  output.write_test("test_modeling11_models.rmf_45492_11Sep13_veena_imp-020713.dat",outputobjects)
399  run the test:
400  output=output.Output()
401  output.test("test_modeling11_models.rmf_45492_11Sep13_veena_imp-020713.dat",outputobjects)
402  '''
403  flstat = open(name, 'w')
404  output = self.initoutput
405  for l in listofobjects:
406  if not "get_test_output" in dir(l) and not "get_output" in dir(l):
407  raise ValueError("Output: object %s doesn't have get_output() or get_test_output() method" % str(l))
408  self.dictionary_stats[name] = listofobjects
409 
410  for obj in self.dictionary_stats[name]:
411  try:
412  d = obj.get_test_output()
413  except:
414  d = obj.get_output()
415  # remove all entries that begin with _ (private entries)
416  dfiltered = dict((k, v) for k, v in d.iteritems() if k[0] != "_")
417  output.update(dfiltered)
418  #output.update({"ENVIRONMENT": str(self.get_environment_variables())})
419  #output.update(
420  # {"IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
421  flstat.write("%s \n" % output)
422  flstat.close()
423 
424  def test(self, name, listofobjects):
425  from numpy.testing import assert_approx_equal as aae
426  output = self.initoutput
427  for l in listofobjects:
428  if not "get_test_output" in dir(l) and not "get_output" in dir(l):
429  raise ValueError("Output: object %s doesn't have get_output() or get_test_output() method" % str(l))
430  for obj in listofobjects:
431  try:
432  output.update(obj.get_test_output())
433  except:
434  output.update(obj.get_output())
435  #output.update({"ENVIRONMENT": str(self.get_environment_variables())})
436  #output.update(
437  # {"IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
438 
439  flstat = open(name, 'r')
440 
441  passed=True
442  for l in flstat:
443  test_dict = eval(l)
444  for k in test_dict:
445  if k in output:
446  old_value = str(test_dict[k])
447  new_value = str(output[k])
448 
449  if test_dict[k] != output[k]:
450  if len(old_value) < 50 and len(new_value) < 50:
451  print str(k) + ": test failed, old value: " + old_value + " new value " + new_value
452  passed=False
453  else:
454  print str(k) + ": test failed, omitting results (too long)"
455  passed=False
456 
457  else:
458  print str(k) + " from old objects (file " + str(name) + ") not in new objects"
459  return passed
460 
461  def get_environment_variables(self):
462  import os
463  return str(os.environ)
464 
465  def get_versions_of_relevant_modules(self):
466  import IMP
467  versions = {}
468  versions["IMP_VERSION"] = IMP.kernel.get_module_version()
469  try:
470  import IMP.pmi
471  versions["PMI_VERSION"] = IMP.pmi.get_module_version()
472  except (ImportError):
473  pass
474  try:
475  import IMP.isd2
476  versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
477  except (ImportError):
478  pass
479  try:
480  import IMP.isd_emxl
481  versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
482  except (ImportError):
483  pass
484  return versions
485 
486 #-------------------
487  def init_stat2(
488  self,
489  name,
490  listofobjects,
491  extralabels=None,
492  listofsummedobjects=None):
493  # this is a new stat file that should be less
494  # space greedy!
495  # listofsummedobjects must be in the form [([obj1,obj2,obj3,obj4...],label)]
496  # extralabels
497 
498  if listofsummedobjects is None:
499  listofsummedobjects = []
500  if extralabels is None:
501  extralabels = []
502  flstat = open(name, 'w')
503  output = {}
504  stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
505  stat2_keywords.update(
506  {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
507  stat2_keywords.update(
508  {"STAT2HEADER_IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
509  stat2_inverse = {}
510 
511  for l in listofobjects:
512  if not "get_output" in dir(l):
513  raise ValueError("Output: object %s doesn't have get_output() method" % str(l))
514  else:
515  d = l.get_output()
516  # remove all entries that begin with _ (private entries)
517  dfiltered = dict((k, v)
518  for k, v in d.iteritems() if k[0] != "_")
519  output.update(dfiltered)
520 
521  # check for customizable entries
522  for l in listofsummedobjects:
523  for t in l[0]:
524  if not "get_output" in dir(t):
525  raise ValueError("Output: object %s doesn't have get_output() method" % str(t))
526  else:
527  if "_TotalScore" not in t.get_output():
528  raise ValueError("Output: object %s doesn't have _TotalScore entry to be summed" % str(t))
529  else:
530  output.update({l[1]: 0.0})
531 
532  for k in extralabels:
533  output.update({k: 0.0})
534 
535  for n, k in enumerate(output):
536  stat2_keywords.update({n: k})
537  stat2_inverse.update({k: n})
538 
539  flstat.write("%s \n" % stat2_keywords)
540  flstat.close()
541  self.dictionary_stats2[name] = (
542  listofobjects,
543  stat2_inverse,
544  listofsummedobjects,
545  extralabels)
546 
547  def write_stat2(self, name, appendmode=True):
548  output = {}
549  (listofobjects, stat2_inverse, listofsummedobjects,
550  extralabels) = self.dictionary_stats2[name]
551 
552  # writing objects
553  for obj in listofobjects:
554  od = obj.get_output()
555  dfiltered = dict((k, v) for k, v in od.iteritems() if k[0] != "_")
556  for k in dfiltered:
557  output.update({stat2_inverse[k]: od[k]})
558 
559  # writing summedobjects
560  for l in listofsummedobjects:
561  partial_score = 0.0
562  for t in l[0]:
563  d = t.get_output()
564  partial_score += float(d["_TotalScore"])
565  output.update({stat2_inverse[l[1]]: str(partial_score)})
566 
567  # writing extralabels
568  for k in extralabels:
569  if k in self.initoutput:
570  output.update({stat2_inverse[k]: self.initoutput[k]})
571  else:
572  output.update({stat2_inverse[k]: "None"})
573 
574  if appendmode:
575  writeflag = 'a'
576  else:
577  writeflag = 'w'
578 
579  flstat = open(name, writeflag)
580  flstat.write("%s \n" % output)
581  flstat.close()
582 
583  def write_stats2(self):
584  for stat in self.dictionary_stats2.keys():
585  self.write_stat2(stat)
586 
587 
588 class ProcessOutput(object):
589  """A class for reading stat files"""
590  def __init__(self, filename):
591  self.filename = filename
592  self.isstat1 = False
593  self.isstat2 = False
594 
595  # open the file
596  if not self.filename is None:
597  f = open(self.filename, "r")
598  else:
599  raise ValueError("No file name provided. Use -h for help")
600 
601  # get the keys from the first line
602  for line in f.readlines():
603  d = eval(line)
604  self.klist = d.keys()
605  # check if it is a stat2 file
606  if "STAT2HEADER" in self.klist:
607  self.isstat2 = True
608  for k in self.klist:
609  if "STAT2HEADER" in str(k):
610  # if print_header: print k, d[k]
611  del d[k]
612  stat2_dict = d
613  # get the list of keys sorted by value
614  kkeys = [k[0]
615  for k in sorted(stat2_dict.iteritems(), key=operator.itemgetter(1))]
616  self.klist = [k[1]
617  for k in sorted(stat2_dict.iteritems(), key=operator.itemgetter(1))]
618  self.invstat2_dict = {}
619  for k in kkeys:
620  self.invstat2_dict.update({stat2_dict[k]: k})
621  else:
622  self.isstat1 = True
623  self.klist.sort()
624 
625  break
626  f.close()
627 
    def get_keys(self):
        """Return the keyword list read from the stat file header."""
        return self.klist
630 
    def show_keys(self, ncolumns=2, truncate=65):
        """Pretty-print the stat-file keys in `ncolumns` columns,
        truncating each key to `truncate` characters."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)
633 
635  self,
636  fields,
637  filtertuple=None,
638  filterout=None,
639  get_every=1):
640  '''
641  this function get the wished field names and return a dictionary
642  you can give the optional argument filterout if you want to "grep" out
643  something from the file, so that it is faster
644 
645  filtertuple a tuple that contains ("TheKeyToBeFiltered",relationship,value)
646  relationship = "<", "==", or ">"
647  '''
648 
649  outdict = {}
650  for field in fields:
651  outdict[field] = []
652 
653  # print fields values
654  f = open(self.filename, "r")
655  line_number = 0
656 
657  for line in f.readlines():
658  if not filterout is None:
659  if filterout in line:
660  continue
661  line_number += 1
662 
663  if line_number % get_every != 0:
664  continue
665  #if line_number % 1000 == 0:
666  # print "ProcessOutput.get_fields: read line %s from file %s" % (str(line_number), self.filename)
667  try:
668  d = eval(line)
669  except:
670  print "# Warning: skipped line number " + str(line_number) + " not a valid line"
671  continue
672 
673  if self.isstat1:
674 
675  if not filtertuple is None:
676  keytobefiltered = filtertuple[0]
677  relationship = filtertuple[1]
678  value = filtertuple[2]
679  if relationship == "<":
680  if float(d[keytobefiltered]) >= value:
681  continue
682  if relationship == ">":
683  if float(d[keytobefiltered]) <= value:
684  continue
685  if relationship == "==":
686  if float(d[keytobefiltered]) != value:
687  continue
688  [outdict[field].append(d[field]) for field in fields]
689 
690  elif self.isstat2:
691  if line_number == 1:
692  continue
693 
694  if not filtertuple is None:
695  keytobefiltered = filtertuple[0]
696  relationship = filtertuple[1]
697  value = filtertuple[2]
698  if relationship == "<":
699  if float(d[self.invstat2_dict[keytobefiltered]]) >= value:
700  continue
701  if relationship == ">":
702  if float(d[self.invstat2_dict[keytobefiltered]]) <= value:
703  continue
704  if relationship == "==":
705  if float(d[self.invstat2_dict[keytobefiltered]]) != value:
706  continue
707 
708  [outdict[field].append(d[self.invstat2_dict[field]])
709  for field in fields]
710  f.close()
711  return outdict
712 
713 
def plot_fields(fields, framemin=None, framemax=None):
    """Plot each field's time series in its own subplot and show them."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    fig, axs = plt.subplots(nrows=len(fields))
    fig.set_size_inches(10.5, 5.5 * len(fields))
    plt.rc('axes', color_cycle=['r'])

    for n, key in enumerate(fields):
        if framemin is None:
            framemin = 0
        if framemax is None:
            framemax = len(fields[key])
        xvals = range(framemin, framemax)
        yvals = [float(v) for v in fields[key][framemin:framemax]]
        # with a single field, plt.subplots returns the axis itself
        # rather than an array
        axis = axs[n] if len(fields) > 1 else axs
        axis.plot(xvals, yvals)
        axis.set_title(key, size="xx-large")
        axis.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.show()
745 
746 
def plot_field_histogram(
        name, values_lists, valuename=None, bins=40, colors=None,
        format="png", reference_xline=None, yplotrange=None,
        xplotrange=None, normalized=True, leg_names=None):
    '''This function is plotting a list of histograms from a value list.
    @param name the name of the plot
    @param value_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename=None the y-label
    @param bins=40 the number of bins
    @param colors=None. If None, will use rainbow. Else will use specific list
    @param format="png" output format
    @param reference_xline=None plot a reference line parallel to the y-axis
    @param yplotrange=None the range for the y-axis
    @param normalized=True whether the histogram is normalized or not
    @param leg_names names for the legend
    '''
    # NOTE(review): the `def` line of this function was lost in the
    # extraction of this file; the signature above matches the documented
    # definition (output.py:747).
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    fig = plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        if leg_names is not None:
            label = leg_names[nv]
        else:
            label = str(nv)
        plt.hist(
            [float(y) for y in values],
            bins=bins,
            color=col,
            normed=normalized, histtype='step', lw=4,
            label=label)

    # plt.title(name,size="xx-large")
    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        # BUG FIX: was `plt.ylim()` with no arguments, which never
        # applied the requested range
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
    plt.show()
810 
811 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None",
                          xlabels=None):
    '''
    This function plots time series as boxplots
    fields is a list of time series, positions are the x-values
    valuename is the y-label, positionname is the x-label
    '''
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.patches import Polygon

    bps = []
    fig = plt.figure(figsize=(float(len(positions)) / 2, 5.0))
    fig.canvas.set_window_title(name)
    ax1 = fig.add_subplot(111)
    plt.subplots_adjust(left=0.2, right=0.990, top=0.95, bottom=0.4)

    bps.append(plt.boxplot(values, notch=0, sym='', vert=1,
                           whis=1.5, positions=positions))
    # style the box outlines and whiskers
    plt.setp(bps[-1]['boxes'], color='black', lw=1.5)
    plt.setp(bps[-1]['whiskers'], color='black', ls=":", lw=1.5)

    # overlay the optional per-position frequencies as dots
    if frequencies is not None:
        ax1.plot(positions, frequencies, 'k.', alpha=0.5, markersize=20)

    if xlabels is not None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name, dpi=150)
    plt.show()
850 
851 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Plot y against x as a red line; optionally fix the y range, save
    to `out_fn`.pdf and/or display the figure."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.set_window_title(title)

    ax.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        # keep the automatic x range, override only y
        xlo, xhi, _, _ = plt.axis()
        plt.axis((xlo, xhi,
                  set_plot_yaxis_range[0], set_plot_yaxis_range[1]))
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
882 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot y against x; optionally clamp the axes and save the
    figure to `filename`."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import sys
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, axs = plt.subplots(1)
    axs0 = axs

    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    plot2 = [axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1,
                       alpha=alpha, c="w")]

    axs0.legend(loc=0, frameon=False, scatterpoints=1, numpoints=1,
                columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if ymin is not None and ymax is not None:
        axs0.set_ylim(ymin, ymax)
    if xmin is not None and xmax is not None:
        axs0.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
925 
926 
def get_graph_from_hierarchy(hier):
    """Draw the graph of parent/child links in an IMP hierarchy.

    Node labels are shown only for nodes shallower than depth 3.
    """
    (graph, depth, depth_dict) = recursive_graph(hier, [], 0, {})

    # filter node labels according to depth_dict
    node_labels_dict = {}
    node_size_dict = {}
    for key in depth_dict:
        # BUG FIX: was `node_size_dict = 10 / depth_dict[key]`, which
        # clobbered the dict with a scalar on every iteration.
        # (node_size_dict is currently not passed on to draw_graph.)
        node_size_dict[key] = 10 / depth_dict[key]
        if depth_dict[key] < 3:
            node_labels_dict[key] = key
        else:
            node_labels_dict[key] = ""
    draw_graph(graph, labels_dict=node_labels_dict)
944 
945 
def recursive_graph(hier, graph, depth, depth_dict):
    """Depth-first walk of an IMP hierarchy accumulating (parent, child)
    name edges in `graph` and each node's depth in `depth_dict`.

    Node names are "<hierarchy name>|#<particle index>".
    Returns the updated (graph, depth, depth_dict).
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # BUG FIX: was `len(children) == 1 or children is None`; the None
    # test came second, after len() would already have raised
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)

    for c in children:
        (graph, depth, depth_dict) = recursive_graph(
            c, graph, depth, depth_dict)
        namec = (IMP.atom.Hierarchy(c).get_name() + "|#"
                 + str(c.get_particle().get_index()))
        graph.append((name1, namec))

    depth = depth - 1
    return (graph, depth, depth_dict)
970 
971 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw an edge-list graph with networkx.

    `validation_edges` are highlighted green when present in `graph`
    (either direction) and added in red when absent.  The graph is always
    written to out.gml, and to `out_filename` when given.
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges, with per-edge weights when a thickness list is given
    if type(edge_thickness) is list:
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        node_color_hex = "000000"
    else:
        # NOTE(review): the statement creating `cc` was lost in the
        # extraction of this file; restored from the upstream IMP.pmi
        # sources (the hex<->rgb color converter in IMP.pmi.tools) --
        # confirm the class name against upstream.
        cc = IMP.pmi.tools.ColorChange()
        tmpcolor_rgb = []
        tmpcolor_hex = []
        for node in G.nodes():
            cctuple = cc.rgb(node_color[node])
            tmpcolor_rgb.append(
                (cctuple[0] / 255, cctuple[1] / 255, cctuple[2] / 255))
            tmpcolor_hex.append(node_color[node])
        node_color_rgb = tmpcolor_rgb
        node_color_hex = tmpcolor_hex

    # get node sizes if dictionary
    if type(node_size) is dict:
        tmpsize = []
        for node in G.nodes():
            size = sqrt(node_size[node]) / pi * 10.0
            tmpsize.append(size)
        node_size = tmpsize

    for n, node in enumerate(G.nodes()):
        color = node_color_hex[n]
        size = node_size[n]
        nx.set_node_attributes(
            G, "graphics",
            {node: {'type': 'ellipse', 'w': size, 'h': size,
                    'fill': '#' + color, 'label': node}})
        nx.set_node_attributes(
            G, "LabelGraphics",
            {node: {'type': 'text', 'text': node, 'color': '#000000',
                    'visible': 'true'}})

    for edge in G.edges():
        nx.set_edge_attributes(
            G, "graphics", {edge: {'width': 1, 'fill': '#000000'}})

    # BUG FIX: iterating the default None raised TypeError
    if validation_edges is None:
        validation_edges = []
    for ve in validation_edges:
        print(ve)
        if (ve[0], ve[1]) in G.edges():
            print("found forward")
            nx.set_edge_attributes(
                G, "graphics", {ve: {'width': 1, 'fill': '#00FF00'}})
        elif (ve[1], ve[0]) in G.edges():
            print("found backward")
            nx.set_edge_attributes(
                G, "graphics",
                {(ve[1], ve[0]): {'width': 1, 'fill': '#00FF00'}})
        else:
            G.add_edge(ve[0], ve[1])
            print("not found")
            nx.set_edge_attributes(
                G, "graphics", {ve: {'width': 1, 'fill': '#FF0000'}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0 / 8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
    nx.write_gml(G, 'out.gml')
    plt.show()
1068 
1069 
def draw_table():
    """Render an example d3js table (demo code; requires ipyD3)."""
    # still an example!
    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  # was "created created" (duplicated word)
                  desc='An example table created with d3js with data '
                       'generated with Python.')
    # one row per product, one column per month (transposed below)
    data = [
        [1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
         4441.0, 4670.0, 944.0, 110.0],
        [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
         7078.0, 6561.0, 2354.0, 710.0],
        [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
         5661.0, 4991.0, 2032.0, 680.0],
        [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
         4727.0, 3428.0, 1559.0, 620.0],
        [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
         9552.0, 13709.0, 2460.0, 720.0],
        [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
         32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
        [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
         5115.0, 3510.0, 1390.0, 170.0],
        [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
         360.0, 156.0, 100.0]]
    data = [list(i) for i in zip(*data)]
    sRows = [['January', 'February', 'March', 'April', 'May', 'June',
              'July', 'August', 'September', 'October', 'November',
              'December']]  # was misspelled 'Deecember'
    # range() replaces Python-2-only xrange()
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
A base class for Keys.
Definition: kernel/Key.h:46
void write_pdb(const Selection &mhd, base::TextOutput out, unsigned int model=1)
void save_frame(RMF::FileHandle file, unsigned int, std::string name="")
Definition: frames.h:42
void add_restraints(RMF::NodeHandle fh, const kernel::Restraints &hs)
A class for reading stat files.
Definition: output.py:588
Ints get_index(const kernel::ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def plot_field_histogram
This function is plotting a list of histograms from a value list.
Definition: output.py:747
def plot_fields_box_plots
This function plots time series as boxplots fields is a list of time series, positions are the x-valu...
Definition: output.py:812
Miscellaneous utilities.
Definition: tools.py:1
std::string get_module_version()
a class to change color code to hexadecimal to rgb
Definition: tools.py:1364
def get_prot_name_from_particle
this function returns the component name provided a particle and a list of names
Definition: tools.py:900
def get_fields
this function get the wished field names and return a dictionary you can give the optional argument f...
Definition: output.py:634
The standard decorator for manipulating molecular structures.
A decorator for a particle representing an atom.
Definition: atom/Atom.h:234
The type for a residue.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:19
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
std::string get_module_version()
Display a segment connecting a pair of particles.
Definition: XYZR.h:168
static bool get_is_setup(const IMP::kernel::ParticleAdaptor &p)
Definition: atom/Atom.h:241
A decorator for a residue.
Definition: Residue.h:134
Basic functionality that is expected to be used by a wide variety of IMP users.
static bool get_is_setup(const IMP::kernel::ParticleAdaptor &p)
Definition: Residue.h:155
static bool get_is_setup(kernel::Model *m, kernel::ParticleIndex pi)
Definition: Fragment.h:46
def write_test
write the test: output=output.Output() output.write_test("test_modeling11_models.rmf_45492_11Sep13_ve...
Definition: output.py:394
Python classes to represent, score, sample and analyze models.
Functionality for loading, creating, manipulating and scoring atomic structures.
Hierarchies get_leaves(const Selection &h)
def get_residue_indexes
This "overloaded" function retrieves the residue indexes for each particle which is an instance of Fr...
Definition: tools.py:920
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27