# IMP 2.3.0 — The Integrative Modeling Platform
# output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 import IMP
6 import IMP.atom
7 import IMP.core
8 import IMP.pmi
9 import IMP.pmi.tools
10 import os
11 import RMF
12 import numpy as np
13 try:
14  import cPickle as pickle
15 except ImportError:
16  import pickle
17 
class Output(object):
    """Class for easy writing of PDBs, RMFs, and stat files"""

    def __init__(self, ascii=True, atomistic=False):
        # ascii: write stat files as plain text; otherwise binary pickle
        # atomistic: include atomic-resolution particles in PDB output
        # One registry per output format, keyed by output file name.
        self.dictionary_pdbs = {}
        self.dictionary_rmfs = {}
        self.dictionary_stats = {}
        self.dictionary_stats2 = {}
        # Bookkeeping for best-scoring-model PDB output
        # (see init_pdb_best_scoring / write_pdb_best_scoring).
        self.best_score_list = None
        self.nbestscoring = None
        self.suffixes = []
        self.replica_exchange = False
        self.ascii = ascii
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # Pool of chain identifiers handed out to hierarchy children in
        # insertion order (order is intentional and must not be changed).
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVXYWZabcdefghijklmnopqrstuvxywz"
        # Per-PDB-file map: component name -> chain id.
        self.dictchain = {}
        self.particle_infos_for_pdb = {}
        self.atomistic = atomistic
36 
    def get_pdb_names(self):
        """Return the names of all registered PDB output files."""
        return self.dictionary_pdbs.keys()

    def get_rmf_names(self):
        """Return the names of all registered RMF output files."""
        return self.dictionary_rmfs.keys()

    def get_stat_names(self):
        """Return the names of all registered stat output files."""
        return self.dictionary_stats.keys()
45 
46  def init_pdb(self, name, prot):
47  flpdb = open(name, 'w')
48  flpdb.close()
49  self.dictionary_pdbs[name] = prot
50  self.dictchain[name] = {}
51 
52  for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
53  self.dictchain[name][i.get_name()] = self.chainids[n]
54 
55  def write_pdb(self,name,appendmode=True,
56  translate_to_geometric_center=False):
57  import resource
58  if appendmode:
59  flpdb = open(name, 'a')
60  else:
61  flpdb = open(name, 'w')
62 
63  (particle_infos_for_pdb,
64  geometric_center) = self.get_particle_infos_for_pdb_writing(name)
65 
66  if not translate_to_geometric_center:
67  geometric_center = (0, 0, 0)
68 
69  for tupl in particle_infos_for_pdb:
70 
71  (xyz, atom_index, atom_type, residue_type,
72  chain_id, residue_index,radius) = tupl
73 
74  flpdb.write(IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
75  xyz[1] - geometric_center[1],
76  xyz[2] - geometric_center[2]),
77  atom_index, atom_type, residue_type,
78  chain_id, residue_index,' ',1.00,radius))
79 
80  flpdb.write("ENDMOL\n")
81  flpdb.close()
82 
83  del particle_infos_for_pdb
84 
85  def get_particle_infos_for_pdb_writing(self, name):
86  # index_residue_pair_list={}
87 
88  # the resindexes dictionary keep track of residues that have been already
89  # added to avoid duplication
90  # highest resolution have highest priority
91  resindexes_dict = {}
92 
93  # this dictionary dill contain the sequence of tuples needed to
94  # write the pdb
95  particle_infos_for_pdb = []
96 
97  geometric_center = [0, 0, 0]
98  atom_count = 0
99  atom_index = 0
100 
101  for n, p in enumerate(IMP.atom.get_leaves(self.dictionary_pdbs[name])):
102 
103  # this loop gets the protein name from the
104  # particle leave by descending into the hierarchy
105 
106  (protname, is_a_bead) = IMP.pmi.tools.get_prot_name_from_particle(
107  p, self.dictchain[name])
108 
109  if protname not in resindexes_dict:
110  resindexes_dict[protname] = []
111 
112  if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
113  atom_index += 1
114  residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
115  rt = residue.get_residue_type()
116  resind = residue.get_index()
117  atomtype = IMP.atom.Atom(p).get_atom_type()
118  xyz = list(IMP.core.XYZ(p).get_coordinates())
119  radius = IMP.core.XYZR(p).get_radius()
120  geometric_center[0] += xyz[0]
121  geometric_center[1] += xyz[1]
122  geometric_center[2] += xyz[2]
123  atom_count += 1
124  particle_infos_for_pdb.append((xyz, atom_index,
125  atomtype, rt, self.dictchain[name][protname], resind,radius))
126  resindexes_dict[protname].append(resind)
127 
129 
130  residue = IMP.atom.Residue(p)
131  resind = residue.get_index()
132  # skip if the residue was already added by atomistic resolution
133  # 0
134  if resind in resindexes_dict[protname]:
135  continue
136  else:
137  resindexes_dict[protname].append(resind)
138  atom_index += 1
139  rt = residue.get_residue_type()
140  xyz = IMP.core.XYZ(p).get_coordinates()
141  radius = IMP.core.XYZR(p).get_radius()
142  geometric_center[0] += xyz[0]
143  geometric_center[1] += xyz[1]
144  geometric_center[2] += xyz[2]
145  atom_count += 1
146  particle_infos_for_pdb.append((xyz, atom_index,
147  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
148 
149  # if protname not in index_residue_pair_list:
150  # index_residue_pair_list[protname]=[(atom_index,resind)]
151  # else:
152  # index_residue_pair_list[protname].append((atom_index,resind))
153 
154  elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
155  resindexes = IMP.pmi.tools.get_residue_indexes(p)
156  resind = resindexes[len(resindexes) / 2]
157  if resind in resindexes_dict[protname]:
158  continue
159  else:
160  resindexes_dict[protname].append(resind)
161  atom_index += 1
162  rt = IMP.atom.ResidueType('BEA')
163  xyz = IMP.core.XYZ(p).get_coordinates()
164  radius = IMP.core.XYZR(p).get_radius()
165  geometric_center[0] += xyz[0]
166  geometric_center[1] += xyz[1]
167  geometric_center[2] += xyz[2]
168  atom_count += 1
169  particle_infos_for_pdb.append((xyz, atom_index,
170  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
171 
172  else:
173  if is_a_bead:
174  atom_index += 1
175  rt = IMP.atom.ResidueType('BEA')
176  resindexes = IMP.pmi.tools.get_residue_indexes(p)
177  resind = resindexes[len(resindexes) / 2]
178  xyz = IMP.core.XYZ(p).get_coordinates()
179  radius = IMP.core.XYZR(p).get_radius()
180  geometric_center[0] += xyz[0]
181  geometric_center[1] += xyz[1]
182  geometric_center[2] += xyz[2]
183  atom_count += 1
184  particle_infos_for_pdb.append((xyz, atom_index,
185  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
186  # if protname not in index_residue_pair_list:
187  # index_residue_pair_list[protname]=[(atom_index,resind)]
188  # else:
189  # index_residue_pair_list[protname].append((atom_index,resind))
190 
191  geometric_center = (geometric_center[0] / atom_count,
192  geometric_center[1] / atom_count,
193  geometric_center[2] / atom_count)
194 
195  return (particle_infos_for_pdb, geometric_center)
196  '''
197  #now write the connectivity
198  for protname in index_residue_pair_list:
199 
200  ls=index_residue_pair_list[protname]
201  #sort by residue
202  ls=sorted(ls, key=lambda tup: tup[1])
203  #get the index list
204  indexes=map(list, zip(*ls))[0]
205  # get the contiguous pairs
206  indexes_pairs=list(IMP.pmi.tools.sublist_iterator(indexes,lmin=2,lmax=2))
207  #write the connection record only if the residue gap is larger than 1
208 
209  for ip in indexes_pairs:
210  if abs(ip[1]-ip[0])>1:
211  flpdb.write('{:6s}{:5d}{:5d}'.format('CONECT',ip[0],ip[1]))
212  flpdb.write("\n")
213  '''
214 
215  def write_pdbs(self, appendmode=True):
216  for pdb in self.dictionary_pdbs.keys():
217  self.write_pdb(pdb, appendmode)
218 
219  def init_pdb_best_scoring(
220  self,
221  suffix,
222  prot,
223  nbestscoring,
224  replica_exchange=False):
225  # save only the nbestscoring conformations
226  # create as many pdbs as needed
227 
228  self.suffixes.append(suffix)
229  self.replica_exchange = replica_exchange
230  if not self.replica_exchange:
231  # common usage
232  # if you are not in replica exchange mode
233  # initialize the array of scores internally
234  self.best_score_list = []
235  else:
236  # otherwise the replicas must cominucate
237  # through a common file to know what are the best scores
238  self.best_score_file_name = "best.scores.rex.py"
239  self.best_score_list = []
240  best_score_file = open(self.best_score_file_name, "w")
241  best_score_file.write(
242  "self.best_score_list=" + str(self.best_score_list))
243  best_score_file.close()
244 
245  self.nbestscoring = nbestscoring
246  for i in range(self.nbestscoring):
247  name = suffix + "." + str(i) + ".pdb"
248  flpdb = open(name, 'w')
249  flpdb.close()
250  self.dictionary_pdbs[name] = prot
251  self.dictchain[name] = {}
252  for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
253  self.dictchain[name][i.get_name()] = self.chainids[n]
254 
255  def write_pdb_best_scoring(self, score):
256  if self.nbestscoring is None:
257  print "Output.write_pdb_best_scoring: init_pdb_best_scoring not run"
258 
259  # update the score list
260  if self.replica_exchange:
261  # read the self.best_score_list from the file
262  execfile(self.best_score_file_name)
263 
264  if len(self.best_score_list) < self.nbestscoring:
265  self.best_score_list.append(score)
266  self.best_score_list.sort()
267  index = self.best_score_list.index(score)
268  for suffix in self.suffixes:
269  for i in range(len(self.best_score_list) - 2, index - 1, -1):
270  oldname = suffix + "." + str(i) + ".pdb"
271  newname = suffix + "." + str(i + 1) + ".pdb"
272  os.rename(oldname, newname)
273  filetoadd = suffix + "." + str(index) + ".pdb"
274  self.write_pdb(filetoadd, appendmode=False)
275 
276  else:
277  if score < self.best_score_list[-1]:
278  self.best_score_list.append(score)
279  self.best_score_list.sort()
280  self.best_score_list.pop(-1)
281  index = self.best_score_list.index(score)
282  for suffix in self.suffixes:
283  for i in range(len(self.best_score_list) - 1, index - 1, -1):
284  oldname = suffix + "." + str(i) + ".pdb"
285  newname = suffix + "." + str(i + 1) + ".pdb"
286  os.rename(oldname, newname)
287  filenametoremove = suffix + \
288  "." + str(self.nbestscoring) + ".pdb"
289  os.remove(filenametoremove)
290  filetoadd = suffix + "." + str(index) + ".pdb"
291  self.write_pdb(filetoadd, appendmode=False)
292 
293  if self.replica_exchange:
294  # write the self.best_score_list to the file
295  best_score_file = open(self.best_score_file_name, "w")
296  best_score_file.write(
297  "self.best_score_list=" + str(self.best_score_list))
298  best_score_file.close()
299 
300  def init_rmf(self, name, hierarchies,rs=None):
301  rh = RMF.create_rmf_file(name)
302  IMP.rmf.add_hierarchies(rh, hierarchies)
303  if rs is not None:
305  self.dictionary_rmfs[name] = rh
306 
307  def add_restraints_to_rmf(self, name, objectlist):
308  for o in objectlist:
309  try:
310  rs = o.get_restraint_for_rmf()
311  except:
312  rs = o.get_restraint()
314  self.dictionary_rmfs[name],
315  rs.get_restraints())
316 
317  def add_geometries_to_rmf(self, name, objectlist):
318  for o in objectlist:
319  geos = o.get_geometries()
320  IMP.rmf.add_geometries(self.dictionary_rmfs[name], geos)
321 
322  def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
323  for o in objectlist:
324 
325  pps = o.get_particle_pairs()
326  for pp in pps:
327  IMP.rmf.add_geometry(
328  self.dictionary_rmfs[name],
330 
331  def write_rmf(self, name):
332  IMP.rmf.save_frame(self.dictionary_rmfs[name])
333  self.dictionary_rmfs[name].flush()
334 
335  def close_rmf(self, name):
336  del self.dictionary_rmfs[name]
337 
338  def write_rmfs(self):
339  for rmf in self.dictionary_rmfs.keys():
340  self.write_rmf(rmf)
341 
342  def init_stat(self, name, listofobjects):
343  if self.ascii:
344  flstat = open(name, 'w')
345  flstat.close()
346  else:
347  flstat = open(name, 'wb')
348  flstat.close()
349 
350  # check that all objects in listofobjects have a get_output method
351  for l in listofobjects:
352  if not "get_output" in dir(l):
353  print "Output: object", l,"doesn't have get_output() method"
354  exit()
355  self.dictionary_stats[name] = listofobjects
356 
357  def set_output_entry(self, key, value):
358  self.initoutput.update({key: value})
359 
360  def write_stat(self, name, appendmode=True):
361  output = self.initoutput
362  for obj in self.dictionary_stats[name]:
363  d = obj.get_output()
364  # remove all entries that begin with _ (private entries)
365  dfiltered = dict((k, v) for k, v in d.iteritems() if k[0] != "_")
366  output.update(dfiltered)
367 
368  if appendmode:
369  writeflag = 'a'
370  else:
371  writeflag = 'w'
372 
373  if self.ascii:
374  flstat = open(name, writeflag)
375  flstat.write("%s \n" % output)
376  flstat.close()
377  else:
378  flstat = open(name, writeflag + 'b')
379  cPickle.dump(output, flstat, 2)
380  flstat.close()
381 
382  def write_stats(self):
383  for stat in self.dictionary_stats.keys():
384  self.write_stat(stat)
385 
386  def get_stat(self, name):
387  output = {}
388  for obj in self.dictionary_stats[name]:
389  output.update(obj.get_output())
390  return output
391 
392  def write_test(self, name, listofobjects):
393  '''
394  write the test:
395  output=output.Output()
396  output.write_test("test_modeling11_models.rmf_45492_11Sep13_veena_imp-020713.dat",outputobjects)
397  run the test:
398  output=output.Output()
399  output.test("test_modeling11_models.rmf_45492_11Sep13_veena_imp-020713.dat",outputobjects)
400  '''
401  flstat = open(name, 'w')
402  output = self.initoutput
403  for l in listofobjects:
404  if not "get_test_output" in dir(l) and not "get_output" in dir(l):
405  print "Output: object ", l, " doesn't have get_output() or get_test_output() method"
406  exit()
407  self.dictionary_stats[name] = listofobjects
408 
409  for obj in self.dictionary_stats[name]:
410  try:
411  d = obj.get_test_output()
412  except:
413  d = obj.get_output()
414  # remove all entries that begin with _ (private entries)
415  dfiltered = dict((k, v) for k, v in d.iteritems() if k[0] != "_")
416  output.update(dfiltered)
417  #output.update({"ENVIRONMENT": str(self.get_environment_variables())})
418  #output.update(
419  # {"IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
420  flstat.write("%s \n" % output)
421  flstat.close()
422 
423  def test(self, name, listofobjects):
424  from numpy.testing import assert_approx_equal as aae
425  output = self.initoutput
426  for l in listofobjects:
427  if not "get_test_output" in dir(l) and not "get_output" in dir(l):
428  print "Output: object ", l, " doesn't have get_output() or get_test_output() method"
429  exit()
430  for obj in listofobjects:
431  try:
432  output.update(obj.get_test_output())
433  except:
434  output.update(obj.get_output())
435  #output.update({"ENVIRONMENT": str(self.get_environment_variables())})
436  #output.update(
437  # {"IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
438 
439  flstat = open(name, 'r')
440 
441  passed=True
442  for l in flstat:
443  test_dict = eval(l)
444  for k in test_dict:
445  if k in output:
446  old_value = str(test_dict[k])
447  new_value = str(output[k])
448 
449  if test_dict[k] != output[k]:
450  if len(old_value) < 50 and len(new_value) < 50:
451  print str(k) + ": test failed, old value: " + old_value + " new value " + new_value
452  passed=False
453  else:
454  print str(k) + ": test failed, omitting results (too long)"
455  passed=False
456 
457  else:
458  print str(k) + " from old objects (file " + str(name) + ") not in new objects"
459  return passed
460 
461  def get_environment_variables(self):
462  import os
463  return str(os.environ)
464 
465  def get_versions_of_relevant_modules(self):
466  import IMP
467  versions = {}
468  versions["IMP_VERSION"] = IMP.kernel.get_module_version()
469  try:
470  import IMP.pmi
471  versions["PMI_VERSION"] = IMP.pmi.get_module_version()
472  except (ImportError):
473  pass
474  try:
475  import IMP.isd2
476  versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
477  except (ImportError):
478  pass
479  try:
480  import IMP.isd_emxl
481  versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
482  except (ImportError):
483  pass
484  return versions
485 
486 #-------------------
487  def init_stat2(
488  self,
489  name,
490  listofobjects,
491  extralabels=None,
492  listofsummedobjects=None):
493  # this is a new stat file that should be less
494  # space greedy!
495  # listofsummedobjects must be in the form [([obj1,obj2,obj3,obj4...],label)]
496  # extralabels
497 
498  if listofsummedobjects is None:
499  listofsummedobjects = []
500  if extralabels is None:
501  extralabels = []
502  flstat = open(name, 'w')
503  output = {}
504  stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
505  stat2_keywords.update(
506  {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
507  stat2_keywords.update(
508  {"STAT2HEADER_IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
509  stat2_inverse = {}
510 
511  for l in listofobjects:
512  if not "get_output" in dir(l):
513  print "Output: object ", l, " doesn't have get_output() method"
514  exit()
515  else:
516  d = l.get_output()
517  # remove all entries that begin with _ (private entries)
518  dfiltered = dict((k, v)
519  for k, v in d.iteritems() if k[0] != "_")
520  output.update(dfiltered)
521 
522  # check for customizable entries
523  for l in listofsummedobjects:
524  for t in l[0]:
525  if not "get_output" in dir(t):
526  print "Output: object ", t, " doesn't have get_output() method"
527  exit()
528  else:
529  if "_TotalScore" not in t.get_output():
530  print "Output: object ", t, " doesn't have _TotalScore entry to be summed"
531  exit()
532  else:
533  output.update({l[1]: 0.0})
534 
535  for k in extralabels:
536  output.update({k: 0.0})
537 
538  for n, k in enumerate(output):
539  stat2_keywords.update({n: k})
540  stat2_inverse.update({k: n})
541 
542  flstat.write("%s \n" % stat2_keywords)
543  flstat.close()
544  self.dictionary_stats2[name] = (
545  listofobjects,
546  stat2_inverse,
547  listofsummedobjects,
548  extralabels)
549 
550  def write_stat2(self, name, appendmode=True):
551  output = {}
552  (listofobjects, stat2_inverse, listofsummedobjects,
553  extralabels) = self.dictionary_stats2[name]
554 
555  # writing objects
556  for obj in listofobjects:
557  od = obj.get_output()
558  dfiltered = dict((k, v) for k, v in od.iteritems() if k[0] != "_")
559  for k in dfiltered:
560  output.update({stat2_inverse[k]: od[k]})
561 
562  # writing summedobjects
563  for l in listofsummedobjects:
564  partial_score = 0.0
565  for t in l[0]:
566  d = t.get_output()
567  partial_score += float(d["_TotalScore"])
568  output.update({stat2_inverse[l[1]]: str(partial_score)})
569 
570  # writing extralabels
571  for k in extralabels:
572  if k in self.initoutput:
573  output.update({stat2_inverse[k]: self.initoutput[k]})
574  else:
575  output.update({stat2_inverse[k]: "None"})
576 
577  if appendmode:
578  writeflag = 'a'
579  else:
580  writeflag = 'w'
581 
582  flstat = open(name, writeflag)
583  flstat.write("%s \n" % output)
584  flstat.close()
585 
586  def write_stats2(self):
587  for stat in self.dictionary_stats2.keys():
588  self.write_stat2(stat)
589 
590 
class ProcessOutput(object):
    """A class for reading stat files"""

    def __init__(self, filename):
        """Open `filename` and read the key list from its first line,
        detecting whether it is a stat1 or a stat2 file."""
        self.filename = filename
        self.isstat1 = False
        self.isstat2 = False

        if self.filename is None:
            print("Error: No file name provided. Use -h for help")
            exit()
        f = open(self.filename, "r")

        # Inspect only the first line to learn the format and the keys.
        for line in f.readlines():
            # Stat files store one repr()'d dict per line; eval() is only
            # safe for trusted, locally written files.
            d = eval(line)
            # BUG FIX: materialize the keys — on Python 3 a dict view has
            # no .sort() and would also break index-based access later.
            self.klist = list(d.keys())
            # Check whether this is a stat2 file (header line present).
            if "STAT2HEADER" in self.klist:
                import operator
                self.isstat2 = True
                # Drop the header bookkeeping entries.
                for k in self.klist:
                    if "STAT2HEADER" in str(k):
                        del d[k]
                stat2_dict = d
                # stat2_dict maps integer index -> key name; collect the
                # indexes and the names, both sorted by name.
                kkeys = [k[0] for k in sorted(
                    stat2_dict.items(), key=operator.itemgetter(1))]
                self.klist = [k[1] for k in sorted(
                    stat2_dict.items(), key=operator.itemgetter(1))]
                # Map key name -> integer index used in the data lines.
                self.invstat2_dict = {}
                for k in kkeys:
                    self.invstat2_dict.update({stat2_dict[k]: k})
            else:
                self.isstat1 = True
                self.klist.sort()
            break
        f.close()

    def get_keys(self):
        """Return the list of field names available in the file."""
        return self.klist

    def show_keys(self, ncolumns=2, truncate=65):
        """Pretty-print the available field names in columns."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)

    # NOTE(review): the "def get_fields(" header line was lost in
    # extraction and has been reconstructed from the documentation index.
    def get_fields(
            self,
            fields,
            filtertuple=None,
            filterout=None,
            get_every=1):
        '''
        this function get the wished field names and return a dictionary
        you can give the optional argument filterout if you want to "grep" out
        something from the file, so that it is faster

        filtertuple a tuple that contains ("TheKeyToBeFiltered",relationship,value)
        relationship = "<", "==", or ">"
        '''
        outdict = {}
        for field in fields:
            outdict[field] = []

        f = open(self.filename, "r")
        line_number = 0

        for line in f.readlines():
            # "grep out" unwanted lines cheaply, before parsing them;
            # filtered-out lines do not advance the line counter.
            if filterout is not None and filterout in line:
                continue
            line_number += 1

            if line_number % get_every != 0:
                continue
            try:
                d = eval(line)
            except Exception:
                print("# Warning: skipped line number "
                      + str(line_number) + " not a valid line")
                continue

            if self.isstat1:
                if not self._record_passes(d, filtertuple,
                                           lambda key: key):
                    continue
                for field in fields:
                    outdict[field].append(d[field])
            elif self.isstat2:
                if line_number == 1:
                    # The first line is the header, not data.
                    continue
                if not self._record_passes(
                        d, filtertuple,
                        lambda key: self.invstat2_dict[key]):
                    continue
                for field in fields:
                    outdict[field].append(d[self.invstat2_dict[field]])
        f.close()
        return outdict

    def _record_passes(self, d, filtertuple, keymap):
        """Return False when `filtertuple` = (key, relationship, value)
        excludes record `d`; `keymap` translates the public key name into
        the key actually stored in `d` (identity for stat1, the integer
        index for stat2)."""
        if filtertuple is None:
            return True
        keytobefiltered, relationship, value = filtertuple
        current = float(d[keymap(keytobefiltered)])
        if relationship == "<" and current >= value:
            return False
        if relationship == ">" and current <= value:
            return False
        if relationship == "==" and current != value:
            return False
        return True
717 
718 
def plot_fields(fields, framemin=None, framemax=None):
    """Plot each entry of `fields` (dict: name -> list of values) as a red
    line in its own stacked subplot and show the figure.

    @param framemin first frame to plot (defaults to 0)
    @param framemax frame after the last one plotted (defaults to all)
    """
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    fig, axs = plt.subplots(nrows=len(fields))
    fig.set_size_inches(10.5, 5.5 * len(fields))
    plt.rc('axes', color_cycle=['r'])

    for n, key in enumerate(fields):
        # Defaults are filled in lazily from the first series encountered.
        if framemin is None:
            framemin = 0
        if framemax is None:
            framemax = len(fields[key])
        xs = range(framemin, framemax)
        ys = [float(v) for v in fields[key][framemin:framemax]]
        # With a single field, plt.subplots returns the axis itself
        # rather than an array of axes.
        axis = axs[n] if len(fields) > 1 else axs
        axis.plot(xs, ys)
        axis.set_title(key, size="xx-large")
        axis.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping.
    plt.subplots_adjust(hspace=0.3)
    plt.show()
748 
749 
# NOTE(review): the "def plot_field_histogram(" header line was lost in
# extraction and has been reconstructed from the documentation index.
def plot_field_histogram(
        name, values_lists, valuename=None, bins=40, colors=None,
        format="png", reference_xline=None, yplotrange=None,
        xplotrange=None, normalized=True, leg_names=None):
    '''This function is plotting a list of histograms from a value list.
    @param name the name of the plot
    @param value_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename=None the y-label
    @param bins=40 the number of bins
    @param colors=None. If None, will use rainbow. Else will use specific list
    @param format="png" output format
    @param reference_xline=None plot a reference line parallel to the y-axis
    @param yplotrange=None the range for the y-axis
    @param normalized=True whether the histogram is normalized or not
    @param leg_names names for the legend
    '''
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        label = leg_names[nv] if leg_names is not None else str(nv)
        plt.hist(
            [float(y) for y in values],
            bins=bins,
            color=col,
            normed=normalized, histtype='step', lw=4,
            label=label)

    plt.tick_params(labelsize=12, pad=10)
    plt.xlabel(name if valuename is None else valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        # BUG FIX: the original called plt.ylim() with no argument, so the
        # requested y-range was never applied.
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
    plt.show()
811 
812 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None", xlabels=None):
    '''
    This function plots time series as boxplots
    fields is a list of time series, positions are the x-values
    valuename is the y-label, positionname is the x-label
    '''
    import matplotlib.pyplot as plt
    from matplotlib.patches import Polygon  # NOTE(review): imported but unused here

    bps = []
    # Figure width scales with the number of boxes so labels stay readable.
    fig = plt.figure(figsize=(float(len(positions)) / 2, 5.0))
    fig.canvas.set_window_title(name)

    ax1 = fig.add_subplot(111)

    plt.subplots_adjust(left=0.2, right=0.990, top=0.95, bottom=0.4)

    bps.append(plt.boxplot(values, notch=0, sym='', vert=1,
                           whis=1.5, positions=positions))

    plt.setp(bps[-1]['boxes'], color='black', lw=1.5)
    plt.setp(bps[-1]['whiskers'], color='black', ls=":", lw=1.5)

    if frequencies is not None:
        # Overlay per-position frequencies as translucent dots.
        ax1.plot(positions, frequencies, 'k.', alpha=0.5, markersize=20)

    # print ax1.xaxis.get_majorticklocs()
    if not xlabels is None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name, dpi=150)
    plt.show()
849 
850 
def plot_xy_data(x, y, title=None, display=True, set_plot_yaxis_range=None):
    """Plot y versus x as a single red line; save to "<title>.pdf" when a
    title is given and optionally display the figure.

    @param set_plot_yaxis_range optional (ymin, ymax) tuple for the y-axis
    """
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.set_window_title(title)
    plt.rc('axes', color_cycle=['r'])
    # BUG FIX: the original tested an undefined name `yplotrange`
    # (NameError on every call) and only after show(); apply the requested
    # y-range before saving/showing instead.
    if set_plot_yaxis_range is not None:
        plt.ylim(set_plot_yaxis_range)
    ax.plot(x, y)
    if title is not None:
        plt.savefig(title + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
868 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot y versus x with Helvetica axis labels; optionally clamp
    the axis ranges and save the figure to `filename`."""
    import matplotlib.pyplot as plt
    import sys
    from matplotlib import rc
    #rc('font', **{'family':'serif','serif':['Palatino']})
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})
    #rc('text', usetex=True)

    fig, axs = plt.subplots(1)
    axs0 = axs

    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    plot2 = [axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1,
                       alpha=alpha, c="w")]

    axs0.legend(loc=0,
                frameon=False,
                scatterpoints=1,
                numpoints=1,
                columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    # Clamp an axis only when both ends of its range are given.
    if ymin is not None and ymax is not None:
        axs0.set_ylim(ymin, ymax)
    if xmin is not None and xmax is not None:
        axs0.set_xlim(xmin, xmax)

    #plt.show()
    if savefile:
        fig.savefig(filename, dpi=300)
909 
910 
def get_graph_from_hierarchy(hier):
    """Build an edge list from IMP hierarchy `hier` and draw it, labelling
    only the nodes within depth 2 of the root."""
    graph = []
    depth_dict = {}
    depth = 0
    (graph, depth, depth_dict) = recursive_graph(
        hier, graph, depth, depth_dict)

    # Filter node labels according to depth_dict.
    node_labels_dict = {}
    node_size_dict = {}
    for key in depth_dict:
        # BUG FIX: the original rebound node_size_dict to a scalar
        # (node_size_dict = 10 / ...), discarding the per-node sizes.
        node_size_dict[key] = 10 / depth_dict[key]
        if depth_dict[key] < 3:
            node_labels_dict[key] = key
        else:
            node_labels_dict[key] = ""
    draw_graph(graph, labels_dict=node_labels_dict)
928 
929 
def recursive_graph(hier, graph, depth, depth_dict):
    """Depth-first walk of an IMP hierarchy: append one
    (parent_label, child_label) edge per parent/child pair to `graph` and
    record each node's depth in `depth_dict`.

    Node labels have the form "<name>|#<particle index>". Nodes with
    exactly one child are treated as leaves and not descended into.
    """
    depth += 1
    label = (IMP.atom.Hierarchy(hier).get_name()
             + "|#" + str(hier.get_particle().get_index()))
    depth_dict[label] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    if len(children) == 1 or children is None:
        return (graph, depth - 1, depth_dict)

    for child in children:
        (graph, depth, depth_dict) = recursive_graph(
            child, graph, depth, depth_dict)
        child_label = (IMP.atom.Hierarchy(child).get_name()
                       + "|#" + str(child.get_particle().get_index()))
        graph.append((label, child_label))

    return (graph, depth - 1, depth_dict)
954 
955 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw the (node1, node2) edge list `graph` with networkx/matplotlib,
    color-code optional `validation_edges` (green = present in `graph`,
    red = missing), export the graph as out.gml, and show the figure.

    `node_size` and `node_color` may be scalars or per-node dictionaries.
    """
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges, optionally weighted per-edge
    if type(edge_thickness) is list:
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        node_color_hex = "000000"
    else:
        # NOTE(review): this line was lost in extraction and has been
        # reconstructed; ColorChange converts hex color codes to rgb
        # tuples — confirm against IMP.pmi.tools.
        cc = IMP.pmi.tools.ColorChange()
        tmpcolor_rgb = []
        tmpcolor_hex = []
        for node in G.nodes():
            cctuple = cc.rgb(node_color[node])
            tmpcolor_rgb.append(
                (cctuple[0] / 255, cctuple[1] / 255, cctuple[2] / 255))
            tmpcolor_hex.append(node_color[node])
        node_color_rgb = tmpcolor_rgb
        node_color_hex = tmpcolor_hex

    # get node sizes if dictionary: area proportional to the value
    if type(node_size) is dict:
        tmpsize = []
        for node in G.nodes():
            tmpsize.append(sqrt(node_size[node]) / pi * 10.0)
        node_size = tmpsize

    for n, node in enumerate(G.nodes()):
        # BUG FIX: with the scalar defaults the original subscripted an
        # int (node_size[n]) / indexed into the "000000" string.
        color = (node_color_hex[n] if type(node_color_hex) is list
                 else node_color_hex)
        size = node_size[n] if type(node_size) is list else node_size
        nx.set_node_attributes(
            G, "graphics",
            {node: {'type': 'ellipse', 'w': size, 'h': size,
                    'fill': '#' + color, 'label': node}})
        nx.set_node_attributes(
            G, "LabelGraphics",
            {node: {'type': 'text', 'text': node, 'color': '#000000',
                    'visible': 'true'}})

    for edge in G.edges():
        nx.set_edge_attributes(
            G, "graphics", {edge: {'width': 1, 'fill': '#000000'}})

    # BUG FIX: validation_edges defaults to None; iterating None raised.
    for ve in (validation_edges or []):
        print(ve)
        if (ve[0], ve[1]) in G.edges():
            print("found forward")
            nx.set_edge_attributes(
                G, "graphics", {ve: {'width': 1, 'fill': '#00FF00'}})
        elif (ve[1], ve[0]) in G.edges():
            print("found backward")
            nx.set_edge_attributes(
                G, "graphics",
                {(ve[1], ve[0]): {'width': 1, 'fill': '#00FF00'}})
        else:
            G.add_edge(ve[0], ve[1])
            print("not found")
            nx.set_edge_attributes(
                G, "graphics", {ve: {'width': 1, 'fill': '#FF0000'}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0 / 8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
    nx.write_gml(G, 'out.gml')
    plt.show()
1050 
1051 
def draw_table():
    """Render an example d3js table in IPython.

    Demo/example code only; requires the third-party ipyD3 package.
    """
    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  desc='An example table created created with d3js with data generated with Python.')
    # Demo data: one row per product, one value per month.
    data = [
        [1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
         4441.0, 4670.0, 944.0, 110.0],
        [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
         7078.0, 6561.0, 2354.0, 710.0],
        [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
         5661.0, 4991.0, 2032.0, 680.0],
        [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
         4727.0, 3428.0, 1559.0, 620.0],
        [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
         9552.0, 13709.0, 2460.0, 720.0],
        [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
         32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
        [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
         5115.0, 3510.0, 1390.0, 170.0],
        [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
         360.0, 156.0, 100.0]]
    # Transpose: rows become months, columns become products.
    data = [list(i) for i in zip(*data)]
    sRows = [['January', 'February', 'March', 'April', 'May', 'June',
              'July', 'August', 'September', 'October', 'November',
              'Deecember']]
    sColumns = [['Prod {0}'.format(i) for i in xrange(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
# --- doxygen cross-reference residue (extraction artifact), kept as comments ---
# A base class for Keys.                                 Definition: kernel/Key.h:46
# void write_pdb(const Selection &mhd, base::TextOutput out, unsigned int model=1)
# void save_frame(RMF::FileHandle file, unsigned int, std::string name="")
#                                                        Definition: frames.h:42
# void add_restraints(RMF::NodeHandle fh, const kernel::Restraints &hs)
# A class for reading stat files.                        Definition: output.py:591
# Ints get_index(const kernel::ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
# def plot_field_histogram
#   This function is plotting a list of histograms from a value list.
#                                                        Definition: output.py:750
# def plot_fields_box_plots
#   This function plots time series as boxplots fields is a list of time series,
#   positions are the x-valu...                          Definition: output.py:813
# Miscellaneous utilities.                               Definition: tools.py:1
# std::string get_module_version()
# a class to change color code to hexadecimal to rgb     Definition: tools.py:1352
# def get_prot_name_from_particle
#   this function returns the component name provided a particle and a list of names
#                                                        Definition: tools.py:909
# def get_fields
#   this function get the wished field names and return a dictionary you can give
#   the optional argument f...                           Definition: output.py:639
# The standard decorator for manipulating molecular structures.
# A decorator for a particle representing an atom.       Definition: atom/Atom.h:234
# The type for a residue.
# A decorator for a particle with x,y,z coordinates.     Definition: XYZ.h:30
# void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
# Class for easy writing of PDBs, RMFs, and stat files.  Definition: output.py:18
# void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
# std::string get_module_version()
# Display a segment connecting a pair of particles.      Definition: XYZR.h:168
# static bool get_is_setup(const IMP::kernel::ParticleAdaptor &p)
#                                                        Definition: atom/Atom.h:241
# A decorator for a residue.                             Definition: Residue.h:134
# Basic functionality that is expected to be used by a wide variety of IMP users.
# static bool get_is_setup(const IMP::kernel::ParticleAdaptor &p)
#                                                        Definition: Residue.h:155
static bool get_is_setup(kernel::Model *m, kernel::ParticleIndex pi)
Definition: Fragment.h:46
def write_test
write the test: output=output.Output() output.write_test("test_modeling11_models.rmf_45492_11Sep13_ve...
Definition: output.py:392
Python classes to represent, score, sample and analyze models.
Functionality for loading, creating, manipulating and scoring atomic structures.
Hierarchies get_leaves(const Selection &h)
def get_residue_indexes
This "overloaded" function retrieves the residue indexes for each particle which is an instance of Fr...
Definition: tools.py:929
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27