# IMP 2.4.0 — The Integrative Modeling Platform
# output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 from __future__ import print_function, division
6 import IMP
7 import IMP.atom
8 import IMP.core
9 import IMP.pmi
10 import IMP.pmi.tools
11 import os
12 import RMF
13 import numpy as np
14 import operator
15 try:
16  import cPickle as pickle
17 except ImportError:
18  import pickle
19 
class Output(object):
    """Class for easy writing of PDBs, RMFs, and stat files"""

    def __init__(self, ascii=True, atomistic=False):
        # Registries mapping an output file name to what gets written there.
        self.dictionary_pdbs = {}
        self.dictionary_rmfs = {}
        self.dictionary_stats = {}
        self.dictionary_stats2 = {}
        # Bookkeeping for the best-scoring-model machinery.
        self.best_score_list = None
        self.nbestscoring = None
        self.suffixes = []
        self.replica_exchange = False
        # Stat files are plain text when True, pickled binary when False.
        self.ascii = ascii
        # Constant entries written with every stat record.
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # Pool of chain identifiers handed out to components in order.
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVXYWZabcdefghijklmnopqrstuvxywz"
        self.dictchain = {}
        self.particle_infos_for_pdb = {}
        # When True, atomic-resolution particles are written as atoms.
        self.atomistic = atomistic
38 
39  def get_pdb_names(self):
40  return list(self.dictionary_pdbs.keys())
41 
42  def get_rmf_names(self):
43  return list(self.dictionary_rmfs.keys())
44 
45  def get_stat_names(self):
46  return list(self.dictionary_stats.keys())
47 
48  def init_pdb(self, name, prot):
49  flpdb = open(name, 'w')
50  flpdb.close()
51  self.dictionary_pdbs[name] = prot
52  self.dictchain[name] = {}
53 
54  for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
55  self.dictchain[name][i.get_name()] = self.chainids[n]
56 
57  def write_psf(self,filename,name):
58  flpsf=open(filename,'w')
59  flpsf.write("PSF CMAP CHEQ"+"\n")
60  index_residue_pair_list={}
61  (particle_infos_for_pdb, geometric_center)=self.get_particle_infos_for_pdb_writing(name)
62  nparticles=len(particle_infos_for_pdb)
63  flpsf.write(str(nparticles)+" !NATOM"+"\n")
64  for n,p in enumerate(particle_infos_for_pdb):
65  atom_index=n+1
66  atomtype=p[1]
67  residue_type=p[2]
68  chain=p[3]
69  resid=p[4]
70  flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}{14:14.6f}{15:14.6f}'.format(atom_index," ",chain," ",str(resid)," ",residue_type," ","C"," ","C",1.0,0.0,0,0.0,0.0))
71  flpsf.write('\n')
72  #flpsf.write(str(atom_index)+" "+str(chain)+" "+str(resid)+" "+str(residue_type).replace('"','')+" C C "+"1.0 0.0 0 0.0 0.0\n")
73  if chain not in index_residue_pair_list:
74  index_residue_pair_list[chain]=[(atom_index,resid)]
75  else:
76  index_residue_pair_list[chain].append((atom_index,resid))
77 
78 
79  #now write the connectivity
80  indexes_pairs=[]
81  for chain in sorted(index_residue_pair_list.keys()):
82 
83  ls=index_residue_pair_list[chain]
84  #sort by residue
85  ls=sorted(ls, key=lambda tup: tup[1])
86  #get the index list
87  indexes=[x[0] for x in ls]
88  # get the contiguous pairs
89  indexes_pairs+=list(IMP.pmi.tools.sublist_iterator(indexes,lmin=2,lmax=2))
90  nbonds=len(indexes_pairs)
91  flpsf.write(str(nbonds)+" !NBOND: bonds"+"\n")
92 
93  sublists=[indexes_pairs[i:i+4] for i in range(0,len(indexes_pairs),4)]
94 
95  # save bonds in fized column format
96  for ip in sublists:
97  if len(ip)==4:
98  flpsf.write('{0:8d}{1:8d}{2:8d}{3:8d}{4:8d}{5:8d}{6:8d}{7:8d}'.format(ip[0][0],ip[0][1],
99  ip[1][0],ip[1][1],ip[2][0],ip[2][1],ip[3][0],ip[3][1]))
100  elif len(ip)==3:
101  flpsf.write('{0:8d}{1:8d}{2:8d}{3:8d}{4:8d}{5:8d}'.format(ip[0][0],ip[0][1],ip[1][0],
102  ip[1][1],ip[2][0],ip[2][1]))
103  elif len(ip)==2:
104  flpsf.write('{0:8d}{1:8d}{2:8d}{3:8d}'.format(ip[0][0],ip[0][1],ip[1][0],ip[1][1]))
105  elif len(ip)==1:
106  flpsf.write('{0:8d}{1:8d}'.format(ip[0][0],ip[0][1]))
107  flpsf.write('\n')
108 
109  del particle_infos_for_pdb
110 
111  def write_pdb(self,name,appendmode=True,
112  translate_to_geometric_center=False):
113  if appendmode:
114  flpdb = open(name, 'a')
115  else:
116  flpdb = open(name, 'w')
117 
118  (particle_infos_for_pdb,
119  geometric_center) = self.get_particle_infos_for_pdb_writing(name)
120 
121  if not translate_to_geometric_center:
122  geometric_center = (0, 0, 0)
123 
124  for n,tupl in enumerate(particle_infos_for_pdb):
125 
126  (xyz, atom_type, residue_type,
127  chain_id, residue_index,radius) = tupl
128 
129  flpdb.write(IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
130  xyz[1] - geometric_center[1],
131  xyz[2] - geometric_center[2]),
132  n+1, atom_type, residue_type,
133  chain_id, residue_index,' ',1.00,radius))
134 
135  flpdb.write("ENDMDL\n")
136  flpdb.close()
137 
138  del particle_infos_for_pdb
139 
140  def get_particle_infos_for_pdb_writing(self, name):
141  # index_residue_pair_list={}
142 
143  # the resindexes dictionary keep track of residues that have been already
144  # added to avoid duplication
145  # highest resolution have highest priority
146  resindexes_dict = {}
147 
148  # this dictionary dill contain the sequence of tuples needed to
149  # write the pdb
150  particle_infos_for_pdb = []
151 
152  geometric_center = [0, 0, 0]
153  atom_count = 0
154  atom_index = 0
155 
156  for n, p in enumerate(IMP.atom.get_leaves(self.dictionary_pdbs[name])):
157 
158  # this loop gets the protein name from the
159  # particle leave by descending into the hierarchy
160 
161  (protname, is_a_bead) = IMP.pmi.tools.get_prot_name_from_particle(
162  p, self.dictchain[name])
163 
164  if protname not in resindexes_dict:
165  resindexes_dict[protname] = []
166 
167  if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
168  residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
169  rt = residue.get_residue_type()
170  resind = residue.get_index()
171  atomtype = IMP.atom.Atom(p).get_atom_type()
172  xyz = list(IMP.core.XYZ(p).get_coordinates())
173  radius = IMP.core.XYZR(p).get_radius()
174  geometric_center[0] += xyz[0]
175  geometric_center[1] += xyz[1]
176  geometric_center[2] += xyz[2]
177  atom_count += 1
178  particle_infos_for_pdb.append((xyz,
179  atomtype, rt, self.dictchain[name][protname], resind,radius))
180  resindexes_dict[protname].append(resind)
181 
183 
184  residue = IMP.atom.Residue(p)
185  resind = residue.get_index()
186  # skip if the residue was already added by atomistic resolution
187  # 0
188  if resind in resindexes_dict[protname]:
189  continue
190  else:
191  resindexes_dict[protname].append(resind)
192  rt = residue.get_residue_type()
193  xyz = IMP.core.XYZ(p).get_coordinates()
194  radius = IMP.core.XYZR(p).get_radius()
195  geometric_center[0] += xyz[0]
196  geometric_center[1] += xyz[1]
197  geometric_center[2] += xyz[2]
198  atom_count += 1
199  particle_infos_for_pdb.append((xyz,
200  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
201 
202  elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
203  resindexes = IMP.pmi.tools.get_residue_indexes(p)
204  resind = resindexes[len(resindexes) // 2]
205  if resind in resindexes_dict[protname]:
206  continue
207  else:
208  resindexes_dict[protname].append(resind)
209  rt = IMP.atom.ResidueType('BEA')
210  xyz = IMP.core.XYZ(p).get_coordinates()
211  radius = IMP.core.XYZR(p).get_radius()
212  geometric_center[0] += xyz[0]
213  geometric_center[1] += xyz[1]
214  geometric_center[2] += xyz[2]
215  atom_count += 1
216  particle_infos_for_pdb.append((xyz,
217  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
218 
219  else:
220  if is_a_bead:
221  rt = IMP.atom.ResidueType('BEA')
222  resindexes = IMP.pmi.tools.get_residue_indexes(p)
223  resind = resindexes[len(resindexes) // 2]
224  xyz = IMP.core.XYZ(p).get_coordinates()
225  radius = IMP.core.XYZR(p).get_radius()
226  geometric_center[0] += xyz[0]
227  geometric_center[1] += xyz[1]
228  geometric_center[2] += xyz[2]
229  atom_count += 1
230  particle_infos_for_pdb.append((xyz,
231  IMP.atom.AT_CA, rt, self.dictchain[name][protname], resind,radius))
232 
233  geometric_center = (geometric_center[0] / atom_count,
234  geometric_center[1] / atom_count,
235  geometric_center[2] / atom_count)
236 
237  particle_infos_for_pdb = sorted(particle_infos_for_pdb, key=operator.itemgetter(3, 4))
238 
239  return (particle_infos_for_pdb, geometric_center)
240 
241 
242  def write_pdbs(self, appendmode=True):
243  for pdb in self.dictionary_pdbs.keys():
244  self.write_pdb(pdb, appendmode)
245 
246  def init_pdb_best_scoring(
247  self,
248  suffix,
249  prot,
250  nbestscoring,
251  replica_exchange=False):
252  # save only the nbestscoring conformations
253  # create as many pdbs as needed
254 
255  self.suffixes.append(suffix)
256  self.replica_exchange = replica_exchange
257  if not self.replica_exchange:
258  # common usage
259  # if you are not in replica exchange mode
260  # initialize the array of scores internally
261  self.best_score_list = []
262  else:
263  # otherwise the replicas must cominucate
264  # through a common file to know what are the best scores
265  self.best_score_file_name = "best.scores.rex.py"
266  self.best_score_list = []
267  best_score_file = open(self.best_score_file_name, "w")
268  best_score_file.write(
269  "self.best_score_list=" + str(self.best_score_list))
270  best_score_file.close()
271 
272  self.nbestscoring = nbestscoring
273  for i in range(self.nbestscoring):
274  name = suffix + "." + str(i) + ".pdb"
275  flpdb = open(name, 'w')
276  flpdb.close()
277  self.dictionary_pdbs[name] = prot
278  self.dictchain[name] = {}
279  for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
280  self.dictchain[name][i.get_name()] = self.chainids[n]
281 
    def write_pdb_best_scoring(self, score):
        # Insert `score` into the ranked best-score list and rotate the PDB
        # files so that <suffix>.<i>.pdb always holds the (i+1)-th best
        # model.  Lower scores are better (list is kept ascending).
        if self.nbestscoring is None:
            print("Output.write_pdb_best_scoring: init_pdb_best_scoring not run")

        # update the score list
        if self.replica_exchange:
            # read the self.best_score_list from the shared file
            # NOTE(review): exec of a file written by sibling replicas;
            # only safe when all replicas are trusted.
            exec(open(self.best_score_file_name).read())

        if len(self.best_score_list) < self.nbestscoring:
            # still filling the list: every new score is kept
            self.best_score_list.append(score)
            self.best_score_list.sort()
            index = self.best_score_list.index(score)
            for suffix in self.suffixes:
                # shift worse-ranked files up by one to free slot `index`
                for i in range(len(self.best_score_list) - 2, index - 1, -1):
                    oldname = suffix + "." + str(i) + ".pdb"
                    newname = suffix + "." + str(i + 1) + ".pdb"
                    # rename on Windows fails if newname already exists
                    if os.path.exists(newname):
                        os.unlink(newname)
                    os.rename(oldname, newname)
                filetoadd = suffix + "." + str(index) + ".pdb"
                self.write_pdb(filetoadd, appendmode=False)

        else:
            # list full: only insert if better than the current worst
            if score < self.best_score_list[-1]:
                self.best_score_list.append(score)
                self.best_score_list.sort()
                self.best_score_list.pop(-1)
                index = self.best_score_list.index(score)
                for suffix in self.suffixes:
                    for i in range(len(self.best_score_list) - 1, index - 1, -1):
                        oldname = suffix + "." + str(i) + ".pdb"
                        newname = suffix + "." + str(i + 1) + ".pdb"
                        os.rename(oldname, newname)
                    # the previous worst model falls off the end
                    filenametoremove = suffix + \
                        "." + str(self.nbestscoring) + ".pdb"
                    os.remove(filenametoremove)
                    filetoadd = suffix + "." + str(index) + ".pdb"
                    self.write_pdb(filetoadd, appendmode=False)

        if self.replica_exchange:
            # write the self.best_score_list back to the shared file
            best_score_file = open(self.best_score_file_name, "w")
            best_score_file.write(
                "self.best_score_list=" + str(self.best_score_list))
            best_score_file.close()
329 
330  def init_rmf(self, name, hierarchies,rs=None):
331  rh = RMF.create_rmf_file(name)
332  IMP.rmf.add_hierarchies(rh, hierarchies)
333  if rs is not None:
335  self.dictionary_rmfs[name] = rh
336 
337  def add_restraints_to_rmf(self, name, objectlist):
338  for o in objectlist:
339  try:
340  rs = o.get_restraint_for_rmf()
341  except:
342  rs = o.get_restraint()
344  self.dictionary_rmfs[name],
345  rs.get_restraints())
346 
347  def add_geometries_to_rmf(self, name, objectlist):
348  for o in objectlist:
349  geos = o.get_geometries()
350  IMP.rmf.add_geometries(self.dictionary_rmfs[name], geos)
351 
352  def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
353  for o in objectlist:
354 
355  pps = o.get_particle_pairs()
356  for pp in pps:
357  IMP.rmf.add_geometry(
358  self.dictionary_rmfs[name],
360 
361  def write_rmf(self, name):
362  IMP.rmf.save_frame(self.dictionary_rmfs[name])
363  self.dictionary_rmfs[name].flush()
364 
365  def close_rmf(self, name):
366  del self.dictionary_rmfs[name]
367 
368  def write_rmfs(self):
369  for rmf in self.dictionary_rmfs.keys():
370  self.write_rmf(rmf)
371 
372  def init_stat(self, name, listofobjects):
373  if self.ascii:
374  flstat = open(name, 'w')
375  flstat.close()
376  else:
377  flstat = open(name, 'wb')
378  flstat.close()
379 
380  # check that all objects in listofobjects have a get_output method
381  for l in listofobjects:
382  if not "get_output" in dir(l):
383  raise ValueError("Output: object %s doesn't have get_output() method" % str(l))
384  self.dictionary_stats[name] = listofobjects
385 
386  def set_output_entry(self, key, value):
387  self.initoutput.update({key: value})
388 
389  def write_stat(self, name, appendmode=True):
390  output = self.initoutput
391  for obj in self.dictionary_stats[name]:
392  d = obj.get_output()
393  # remove all entries that begin with _ (private entries)
394  dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
395  output.update(dfiltered)
396 
397  if appendmode:
398  writeflag = 'a'
399  else:
400  writeflag = 'w'
401 
402  if self.ascii:
403  flstat = open(name, writeflag)
404  flstat.write("%s \n" % output)
405  flstat.close()
406  else:
407  flstat = open(name, writeflag + 'b')
408  cPickle.dump(output, flstat, 2)
409  flstat.close()
410 
411  def write_stats(self):
412  for stat in self.dictionary_stats.keys():
413  self.write_stat(stat)
414 
415  def get_stat(self, name):
416  output = {}
417  for obj in self.dictionary_stats[name]:
418  output.update(obj.get_output())
419  return output
420 
421  def write_test(self, name, listofobjects):
422 # write the test:
423 # output=output.Output()
424 # output.write_test("test_modeling11_models.rmf_45492_11Sep13_veena_imp-020713.dat",outputobjects)
425 # run the test:
426 # output=output.Output()
427 # output.test("test_modeling11_models.rmf_45492_11Sep13_veena_imp-020713.dat",outputobjects)
428  flstat = open(name, 'w')
429  output = self.initoutput
430  for l in listofobjects:
431  if not "get_test_output" in dir(l) and not "get_output" in dir(l):
432  raise ValueError("Output: object %s doesn't have get_output() or get_test_output() method" % str(l))
433  self.dictionary_stats[name] = listofobjects
434 
435  for obj in self.dictionary_stats[name]:
436  try:
437  d = obj.get_test_output()
438  except:
439  d = obj.get_output()
440  # remove all entries that begin with _ (private entries)
441  dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
442  output.update(dfiltered)
443  #output.update({"ENVIRONMENT": str(self.get_environment_variables())})
444  #output.update(
445  # {"IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
446  flstat.write("%s \n" % output)
447  flstat.close()
448 
449  def test(self, name, listofobjects):
450  from numpy.testing import assert_approx_equal as aae
451  output = self.initoutput
452  for l in listofobjects:
453  if not "get_test_output" in dir(l) and not "get_output" in dir(l):
454  raise ValueError("Output: object %s doesn't have get_output() or get_test_output() method" % str(l))
455  for obj in listofobjects:
456  try:
457  output.update(obj.get_test_output())
458  except:
459  output.update(obj.get_output())
460  #output.update({"ENVIRONMENT": str(self.get_environment_variables())})
461  #output.update(
462  # {"IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
463 
464  flstat = open(name, 'r')
465 
466  passed=True
467  for l in flstat:
468  test_dict = eval(l)
469  for k in test_dict:
470  if k in output:
471  old_value = str(test_dict[k])
472  new_value = str(output[k])
473 
474  if test_dict[k] != output[k]:
475  if len(old_value) < 50 and len(new_value) < 50:
476  print(str(k) + ": test failed, old value: " + old_value + " new value " + new_value)
477  passed=False
478  else:
479  print(str(k) + ": test failed, omitting results (too long)")
480  passed=False
481 
482  else:
483  print(str(k) + " from old objects (file " + str(name) + ") not in new objects")
484  return passed
485 
    def get_environment_variables(self):
        """Return the whole process environment rendered as one string."""
        import os
        return str(os.environ)
489 
    def get_versions_of_relevant_modules(self):
        """Return a dict of version strings for IMP and companion modules.

        Optional modules (IMP.pmi, IMP.isd2, IMP.isd_emxl) are silently
        skipped when not installed.
        """
        import IMP
        versions = {}
        versions["IMP_VERSION"] = IMP.kernel.get_module_version()
        try:
            import IMP.pmi
            versions["PMI_VERSION"] = IMP.pmi.get_module_version()
        except (ImportError):
            pass
        try:
            import IMP.isd2
            versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
        except (ImportError):
            pass
        try:
            import IMP.isd_emxl
            versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
        except (ImportError):
            pass
        return versions
510 
511 #-------------------
    def init_stat2(
        self,
        name,
        listofobjects,
        extralabels=None,
        listofsummedobjects=None):
        # This is a new stat-file format that should be less space greedy:
        # the header maps each field name to a small integer column id,
        # and later records (written by write_stat2) are {id: value}.
        # listofsummedobjects must be in the form
        # [([obj1,obj2,obj3,obj4...],label)]; each group's _TotalScore
        # values are summed into a single labelled column.
        # extralabels: additional constant columns.

        if listofsummedobjects is None:
            listofsummedobjects = []
        if extralabels is None:
            extralabels = []
        flstat = open(name, 'w')
        output = {}
        stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
        stat2_keywords.update(
            {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
        stat2_keywords.update(
            {"STAT2HEADER_IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
        stat2_inverse = {}

        for l in listofobjects:
            if not "get_output" in dir(l):
                raise ValueError("Output: object %s doesn't have get_output() method" % str(l))
            else:
                d = l.get_output()
                # remove all entries that begin with _ (private entries)
                dfiltered = dict((k, v)
                                 for k, v in d.items() if k[0] != "_")
                output.update(dfiltered)

        # check for customizable entries
        for l in listofsummedobjects:
            for t in l[0]:
                if not "get_output" in dir(t):
                    raise ValueError("Output: object %s doesn't have get_output() method" % str(t))
                else:
                    if "_TotalScore" not in t.get_output():
                        raise ValueError("Output: object %s doesn't have _TotalScore entry to be summed" % str(t))
                    else:
                        output.update({l[1]: 0.0})

        for k in extralabels:
            output.update({k: 0.0})

        # assign column ids; the inverse map (name -> id) is stored so
        # write_stat2 can emit compact records
        for n, k in enumerate(output):
            stat2_keywords.update({n: k})
            stat2_inverse.update({k: n})

        flstat.write("%s \n" % stat2_keywords)
        flstat.close()
        self.dictionary_stats2[name] = (
            listofobjects,
            stat2_inverse,
            listofsummedobjects,
            extralabels)
571 
572  def write_stat2(self, name, appendmode=True):
573  output = {}
574  (listofobjects, stat2_inverse, listofsummedobjects,
575  extralabels) = self.dictionary_stats2[name]
576 
577  # writing objects
578  for obj in listofobjects:
579  od = obj.get_output()
580  dfiltered = dict((k, v) for k, v in od.items() if k[0] != "_")
581  for k in dfiltered:
582  output.update({stat2_inverse[k]: od[k]})
583 
584  # writing summedobjects
585  for l in listofsummedobjects:
586  partial_score = 0.0
587  for t in l[0]:
588  d = t.get_output()
589  partial_score += float(d["_TotalScore"])
590  output.update({stat2_inverse[l[1]]: str(partial_score)})
591 
592  # writing extralabels
593  for k in extralabels:
594  if k in self.initoutput:
595  output.update({stat2_inverse[k]: self.initoutput[k]})
596  else:
597  output.update({stat2_inverse[k]: "None"})
598 
599  if appendmode:
600  writeflag = 'a'
601  else:
602  writeflag = 'w'
603 
604  flstat = open(name, writeflag)
605  flstat.write("%s \n" % output)
606  flstat.close()
607 
608  def write_stats2(self):
609  for stat in self.dictionary_stats2.keys():
610  self.write_stat2(stat)
611 
612 
class ProcessOutput(object):
    """A class for reading stat files"""

    def __init__(self, filename):
        """Open `filename` and detect its flavor (stat1 vs stat2).

        For stat2 files the first record is a header mapping field names
        to integer column ids; it is parsed into self.invstat2_dict
        (name -> id).  For stat1 files the keys of the first record are
        taken directly.

        @raise ValueError when `filename` is None.
        """
        self.filename = filename
        self.isstat1 = False
        self.isstat2 = False

        if self.filename is None:
            raise ValueError("No file name provided. Use -h for help")
        f = open(self.filename, "r")

        # Inspect only the first record to learn the available keys.
        for line in f.readlines():
            d = eval(line)
            self.klist = list(d.keys())
            if "STAT2HEADER" in self.klist:
                self.isstat2 = True
                # drop the header-metadata entries, keeping id -> name
                for k in self.klist:
                    if "STAT2HEADER" in str(k):
                        # if print_header: print k, d[k]
                        del d[k]
                stat2_dict = d
                # sort (id, name) pairs by field name
                ordered = sorted(stat2_dict.items(),
                                 key=operator.itemgetter(1))
                kkeys = [pair[0] for pair in ordered]
                self.klist = [pair[1] for pair in ordered]
                self.invstat2_dict = {}
                for k in kkeys:
                    self.invstat2_dict[stat2_dict[k]] = k
            else:
                self.isstat1 = True
                self.klist.sort()
            break
        f.close()

    def get_keys(self):
        """Return the field names available in this stat file."""
        return self.klist

    def show_keys(self, ncolumns=2, truncate=65):
        """Pretty-print the available field names in columns."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)

    def _passes_filter(self, d, filtertuple, keymap):
        # Apply the optional ("key", relationship, value) filter to a
        # parsed record `d`; `keymap` translates a field name into the
        # key actually used in the record (identity for stat1, column id
        # for stat2).
        if filtertuple is None:
            return True
        keytobefiltered = filtertuple[0]
        relationship = filtertuple[1]
        value = filtertuple[2]
        fieldval = float(d[keymap(keytobefiltered)])
        if relationship == "<" and fieldval >= value:
            return False
        if relationship == ">" and fieldval <= value:
            return False
        if relationship == "==" and fieldval != value:
            return False
        return True

    def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1):
        '''
        Get the desired field names, and return a dictionary.

        @param filterout specify if you want to "grep" out something from
               the file, so that it is faster
        @param filtertuple a tuple that contains
               ("TheKeyToBeFiltered",relationship,value)
               where relationship = "<", "==", or ">"
        '''
        outdict = dict((field, []) for field in fields)

        f = open(self.filename, "r")
        line_number = 0

        for line in f.readlines():
            # cheap textual pre-filter, applied before parsing
            if filterout is not None:
                if filterout in line:
                    continue
            line_number += 1

            if line_number % get_every != 0:
                continue
            try:
                d = eval(line)
            except:
                print("# Warning: skipped line number " + str(line_number) + " not a valid line")
                continue

            if self.isstat1:
                if not self._passes_filter(d, filtertuple, lambda key: key):
                    continue
                for field in fields:
                    outdict[field].append(d[field])
            elif self.isstat2:
                if line_number == 1:
                    # the first record is the header, not data
                    continue
                if not self._passes_filter(
                        d, filtertuple, lambda key: self.invstat2_dict[key]):
                    continue
                for field in fields:
                    outdict[field].append(d[self.invstat2_dict[field]])
        f.close()
        return outdict
733 
734 
def plot_fields(fields, framemin=None, framemax=None):
    """Plot each field as its own time-series subplot (Agg backend).

    @param fields dict mapping field name -> list of values per frame
    @param framemin first frame to plot (default 0)
    @param framemax one past the last frame to plot (default: all)
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    fig, axs = plt.subplots(nrows=len(fields))
    fig.set_size_inches(10.5, 5.5 * len(fields))
    # NOTE(review): 'color_cycle' is deprecated in newer matplotlib.
    plt.rc('axes', color_cycle=['r'])

    n = 0
    for key in fields:
        if framemin is None:
            framemin = 0
        if framemax is None:
            framemax = len(fields[key])
        x = list(range(framemin, framemax))
        y = [float(y) for y in fields[key][framemin:framemax]]
        # plt.subplots returns a bare Axes (not an array) when nrows == 1
        if len(fields) > 1:
            axs[n].plot(x, y)
            axs[n].set_title(key, size="xx-large")
            axs[n].tick_params(labelsize=18, pad=10)
        else:
            axs.plot(x, y)
            axs.set_title(key, size="xx-large")
            axs.tick_params(labelsize=18, pad=10)
        n += 1

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.show()
766 
767 
def plot_field_histogram(
        name, values_lists, valuename=None, bins=40, colors=None, format="png",
        reference_xline=None, yplotrange=None, xplotrange=None, normalized=True,
        leg_names=None):
    '''Plot a list of histograms from a value list.
    @param name the name of the plot (also the output file stem)
    @param values_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename the y-label
    @param bins the number of bins
    @param colors If None, will use rainbow. Else will use specific list
    @param format output format
    @param reference_xline plot a reference line parallel to the y-axis
    @param yplotrange the range for the y-axis
    @param xplotrange the range for the x-axis
    @param normalized whether the histogram is normalized or not
    @param leg_names names for the legend
    '''
    # NOTE(review): the `def` line itself was missing from the damaged
    # source and has been restored.
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    fig = plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        if leg_names is not None:
            label = leg_names[nv]
        else:
            label = str(nv)
        h = plt.hist(
            [float(y) for y in values],
            bins=bins,
            color=col,
            normed=normalized, histtype='step', lw=4,
            label=label)

    # plt.title(name,size="xx-large")
    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        # BUG FIX: was `plt.ylim()` — a no-op read; pass the range through.
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
    plt.show()
831 
832 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None", xlabels=None):
    '''
    Plot time series as boxplots.
    fields is a list of time series, positions are the x-values
    valuename is the y-label, positionname is the x-label
    '''
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.patches import Polygon

    bps = []
    # width scales with the number of positions so boxes stay readable
    fig = plt.figure(figsize=(float(len(positions)) / 2, 5.0))
    fig.canvas.set_window_title(name)

    ax1 = fig.add_subplot(111)

    plt.subplots_adjust(left=0.2, right=0.990, top=0.95, bottom=0.4)

    bps.append(plt.boxplot(values, notch=0, sym='', vert=1,
                           whis=1.5, positions=positions))

    plt.setp(bps[-1]['boxes'], color='black', lw=1.5)
    plt.setp(bps[-1]['whiskers'], color='black', ls=":", lw=1.5)

    if frequencies is not None:
        # overlay per-position frequencies as translucent dots
        ax1.plot(positions, frequencies, 'k.', alpha=0.5, markersize=20)

    # print ax1.xaxis.get_majorticklocs()
    if not xlabels is None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name, dpi=150)
    plt.show()
871 
872 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Plot y against x as a single red line (Agg backend).

    @param title optional window/axes title
    @param out_fn when given, the figure is saved as `<out_fn>.pdf`
    @param display call plt.show() when True
    @param set_plot_yaxis_range optional (ymin, ymax) pair
    @param xlabel,ylabel optional axis labels
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.set_window_title(title)

    # plt.rc('axes', color='r')
    ax.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        # keep the autoscaled x-range, override only y
        x1, x2, y1, y2 = plt.axis()
        y1 = set_plot_yaxis_range[0]
        y2 = set_plot_yaxis_range[1]
        plt.axis((x1, x2, y1, y2))
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
903 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot y against x (Agg backend).

    @param labelx,labely axis labels
    @param xmin,xmax,ymin,ymax optional axis limits (both ends of an axis
           must be given for the limit to be applied)
    @param savefile when True, save the figure to `filename` at 300 dpi
    @param alpha marker transparency
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    # FIX: removed an unused `import sys` here.
    # rc('font', **{'family':'serif','serif':['Palatino']})
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})
    # rc('text', usetex=True)

    fig, axs = plt.subplots(1)

    axs0 = axs

    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    plot2 = []

    plot2.append(axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1,
                           alpha=alpha, c="w"))

    axs0.legend(
        loc=0,
        frameon=False,
        scatterpoints=1,
        numpoints=1,
        columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if (not ymin is None) and (not ymax is None):
        axs0.set_ylim(ymin, ymax)
    if (not xmin is None) and (not xmax is None):
        axs0.set_xlim(xmin, xmax)

    # plt.show()
    if savefile:
        fig.savefig(filename, dpi=300)
946 
947 
def get_graph_from_hierarchy(hier):
    """Build an edge list from hierarchy `hier` and draw it.

    Node labels are shown only for nodes shallower than depth 3.
    """
    graph = []
    depth_dict = {}
    depth = 0
    (graph, depth, depth_dict) = recursive_graph(
        hier, graph, depth, depth_dict)

    # filter node labels according to depth_dict
    node_labels_dict = {}
    node_size_dict = {}
    for key in depth_dict:
        # BUG FIX: was `node_size_dict = 10 / depth_dict[key]`, which
        # replaced the dict with a scalar on every iteration instead of
        # storing a per-node size.
        node_size_dict[key] = 10 / depth_dict[key]
        if depth_dict[key] < 3:
            node_labels_dict[key] = key
        else:
            node_labels_dict[key] = ""
    draw_graph(graph, labels_dict=node_labels_dict)
965 
966 
def recursive_graph(hier, graph, depth, depth_dict):
    """Depth-first walk of `hier`, appending (parent, child) name pairs.

    Node names are "<hierarchy name>|#<particle index>".  Nodes with a
    single child are treated as leaves (their subtree is not expanded).
    Returns the updated (graph, depth, depth_dict).
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # BUG FIX: the None test was placed after len(), so a None return
    # would have raised TypeError before ever being checked; reordered.
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)

    else:
        for c in children:
            (graph, depth, depth_dict) = recursive_graph(
                c, graph, depth, depth_dict)
            nameh = IMP.atom.Hierarchy(c).get_name()
            index = str(c.get_particle().get_index())
            namec = nameh + "|#" + index
            graph.append((name1, namec))

        depth = depth - 1
        return (graph, depth, depth_dict)
991 
992 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw a graph with networkx/matplotlib and dump it as GML.

    @param graph            list of (node1, node2) edge tuples
    @param labels_dict      maps node name -> displayed label
    @param graph_layout     'spring', 'spectral', 'random' or shell (default)
    @param node_size        scalar, list, or dict keyed by node name
    @param node_color       dict mapping node name -> hex color string
                            (without '#'); black if None
    @param edge_thickness   scalar or per-edge list of weights
    @param validation_edges edges to highlight: green if present in the
                            graph (either direction), red if absent
    @param out_filename     if given, the figure is saved there; a GML
                            copy is always written to 'out.gml'
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges, with per-edge weights when a list is supplied
    if type(edge_thickness) is list:
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        node_color_hex = "000000"
    else:
        # bug fix: 'cc' was never defined (a line was lost); restore the
        # hex<->rgb converter from IMP.pmi.tools
        cc = IMP.pmi.tools.ColorChange()
        tmpcolor_rgb = []
        tmpcolor_hex = []
        for node in G.nodes():
            cctuple = cc.rgb(node_color[node])
            tmpcolor_rgb.append(
                (cctuple[0] / 255, cctuple[1] / 255, cctuple[2] / 255))
            tmpcolor_hex.append(node_color[node])
        node_color_rgb = tmpcolor_rgb
        node_color_hex = tmpcolor_hex

    # get node sizes if dictionary
    if type(node_size) is dict:
        tmpsize = []
        for node in G.nodes():
            size = sqrt(node_size[node]) / pi * 10.0
            tmpsize.append(size)
        node_size = tmpsize

    for n, node in enumerate(G.nodes()):
        # bug fix: only index when per-node lists were built; with the
        # scalar/str defaults the original indexing raised or picked a
        # single hex digit
        color = node_color_hex[n] if isinstance(node_color_hex, list) \
            else node_color_hex
        size = node_size[n] if isinstance(node_size, list) else node_size
        nx.set_node_attributes(G, "graphics", {node: {'type': 'ellipse', 'w': size, 'h': size, 'fill': '#' + color, 'label': node}})
        nx.set_node_attributes(G, "LabelGraphics", {node: {'type': 'text', 'text': node, 'color': '#000000', 'visible': 'true'}})

    for edge in G.edges():
        nx.set_edge_attributes(G, "graphics", {edge: {'width': 1, 'fill': '#000000'}})

    # bug fix: guard against the default validation_edges=None, which the
    # original iterated unconditionally
    if validation_edges is not None:
        for ve in validation_edges:
            print(ve)
            if (ve[0], ve[1]) in G.edges():
                print("found forward")
                nx.set_edge_attributes(G, "graphics", {ve: {'width': 1, 'fill': '#00FF00'}})
            elif (ve[1], ve[0]) in G.edges():
                print("found backward")
                nx.set_edge_attributes(G, "graphics", {(ve[1], ve[0]): {'width': 1, 'fill': '#00FF00'}})
            else:
                G.add_edge(ve[0], ve[1])
                print("not found")
                nx.set_edge_attributes(G, "graphics", {ve: {'width': 1, 'fill': '#FF0000'}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0 / 8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
        nx.write_gml(G, 'out.gml')
    plt.show()
1089 
1090 
def draw_table():
    """Render an example HTML table with d3js via ipyD3 (demo only).

    Builds a fixed 12-month x 8-product matrix of sample values and
    displays it inline in an IPython notebook. Requires the ipyD3
    package; the data is hard-coded.
    """
    # still an example!

    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  # typo fix: description said "created created"
                  desc='An example table created with d3js with data generated with Python.')
    data = [
        [1277.0,
         654.0,
         288.0,
         1976.0,
         3281.0,
         3089.0,
         10336.0,
         4650.0,
         4441.0,
         4670.0,
         944.0,
         110.0],
        [1318.0,
         664.0,
         418.0,
         1952.0,
         3581.0,
         4574.0,
         11457.0,
         6139.0,
         7078.0,
         6561.0,
         2354.0,
         710.0],
        [1783.0,
         774.0,
         564.0,
         1470.0,
         3571.0,
         3103.0,
         9392.0,
         5532.0,
         5661.0,
         4991.0,
         2032.0,
         680.0],
        [1301.0,
         604.0,
         286.0,
         2152.0,
         3282.0,
         3369.0,
         10490.0,
         5406.0,
         4727.0,
         3428.0,
         1559.0,
         620.0],
        [1537.0,
         1714.0,
         724.0,
         4824.0,
         5551.0,
         8096.0,
         16589.0,
         13650.0,
         9552.0,
         13709.0,
         2460.0,
         720.0],
        [5691.0,
         2995.0,
         1680.0,
         11741.0,
         16232.0,
         14731.0,
         43522.0,
         32794.0,
         26634.0,
         31400.0,
         7350.0,
         3010.0],
        [1650.0,
         2096.0,
         60.0,
         50.0,
         1180.0,
         5602.0,
         15728.0,
         6874.0,
         5115.0,
         3510.0,
         1390.0,
         170.0],
        [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0, 360.0, 156.0, 100.0]]
    # transpose: products-per-row -> months-per-row
    data = [list(i) for i in zip(*data)]
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              # typo fix: was 'Deecember'
              'December']]
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
A base class for Keys.
Definition: kernel/Key.h:46
void write_pdb(const Selection &mhd, base::TextOutput out, unsigned int model=1)
void save_frame(RMF::FileHandle file, unsigned int, std::string name="")
Definition: frames.h:42
void add_restraints(RMF::NodeHandle fh, const kernel::Restraints &hs)
A class for reading stat files.
Definition: output.py:613
Ints get_index(const kernel::ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:768
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:833
Miscellaneous utilities.
Definition: tools.py:1
std::string get_module_version()
Change color code to hexadecimal to rgb.
Definition: tools.py:1377
def get_prot_name_from_particle
Return the component name provided a particle and a list of names.
Definition: tools.py:910
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:659
The standard decorator for manipulating molecular structures.
A decorator for a particle representing an atom.
Definition: atom/Atom.h:234
The type for a residue.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:20
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
std::string get_module_version()
Display a segment connecting a pair of particles.
Definition: XYZR.h:168
static bool get_is_setup(const IMP::kernel::ParticleAdaptor &p)
Definition: atom/Atom.h:241
A decorator for a residue.
Definition: Residue.h:134
Basic functionality that is expected to be used by a wide variety of IMP users.
static bool get_is_setup(const IMP::kernel::ParticleAdaptor &p)
Definition: Residue.h:155
static bool get_is_setup(kernel::Model *m, kernel::ParticleIndex pi)
Definition: Fragment.h:46
Python classes to represent, score, sample and analyze models.
Functionality for loading, creating, manipulating and scoring atomic structures.
Hierarchies get_leaves(const Selection &h)
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: tools.py:930
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: tools.py:1026
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27