IMP logo
IMP Reference Guide  2.14.0
The Integrative Modeling Platform
output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 from __future__ import print_function, division
6 import IMP
7 import IMP.atom
8 import IMP.core
9 import IMP.pmi
10 import IMP.pmi.tools
11 import IMP.pmi.io
12 import os
13 import sys
14 import ast
15 import RMF
16 import numpy as np
17 import operator
18 import itertools
19 import warnings
20 import string
21 import ihm.format
22 import collections
23 try:
24  import cPickle as pickle
25 except ImportError:
26  import pickle
27 
28 class _ChainIDs(object):
29  """Map indices to multi-character chain IDs.
30  We label the first 26 chains A-Z, then we move to two-letter
31  chain IDs: AA through AZ, then BA through BZ, through to ZZ.
32  This continues with longer chain IDs."""
33  def __getitem__(self, ind):
34  chars = string.ascii_uppercase
35  lc = len(chars)
36  ids = []
37  while ind >= lc:
38  ids.append(chars[ind % lc])
39  ind = ind // lc - 1
40  ids.append(chars[ind])
41  return "".join(reversed(ids))
42 
43 
class ProtocolOutput(object):
    """Base class for capturing a modeling protocol.

    Beyond plain model coordinates, a full protocol records the input data,
    the restraints, the sampling and clustering steps, and the output models.
    Attach one via IMP.pmi.topology.System.add_protocol_output().

    @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
         mmCIF files.
    """
55 
56 def _flatten(seq):
57  for elt in seq:
58  if isinstance(elt, (tuple, list)):
59  for elt2 in _flatten(elt):
60  yield elt2
61  else:
62  yield elt
63 
64 
65 def _disambiguate_chain(chid, seen_chains):
66  """Make sure that the chain ID is unique; warn and correct if it isn't"""
67  # Handle null chain IDs
68  if chid == '\0':
69  chid = ' '
70 
71  if chid in seen_chains:
72  warnings.warn("Duplicate chain ID '%s' encountered" % chid,
74 
75  for suffix in itertools.count(1):
76  new_chid = chid + "%d" % suffix
77  if new_chid not in seen_chains:
78  seen_chains.add(new_chid)
79  return new_chid
80  seen_chains.add(chid)
81  return chid
82 
83 
84 def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
85  write_all_residues_per_bead):
86  for n,tupl in enumerate(particle_infos_for_pdb):
87  (xyz, atom_type, residue_type,
88  chain_id, residue_index, all_indexes, radius) = tupl
89  if atom_type is None:
90  atom_type = IMP.atom.AT_CA
91  if write_all_residues_per_bead and all_indexes is not None:
92  for residue_number in all_indexes:
93  flpdb.write(
94  IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
95  xyz[1] - geometric_center[1],
96  xyz[2] - geometric_center[2]),
97  n+1, atom_type, residue_type,
98  chain_id[:1], residue_number, ' ',
99  1.00, radius))
100  else:
101  flpdb.write(
102  IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
103  xyz[1] - geometric_center[1],
104  xyz[2] - geometric_center[2]),
105  n+1, atom_type, residue_type,
106  chain_id[:1], residue_index, ' ',
107  1.00, radius))
108  flpdb.write("ENDMDL\n")
109 
110 
111 _Entity = collections.namedtuple('_Entity', ('id', 'seq'))
112 _ChainInfo = collections.namedtuple('_ChainInfo', ('entity', 'name'))
113 
114 
def _get_chain_info(chains, root_hier):
    """Collect per-chain and per-entity information for mmCIF output.

    @param chains dict mapping molecule name (with copy number, as returned
           by IMP.pmi.get_molecule_name_and_copy) to chain ID
    @param root_hier hierarchy whose molecules are examined
    @return (chain_info, entities): chain_info maps chain ID to a _ChainInfo;
            entities is the list of unique _Entity records, numbered from 1
            in order of first appearance. Molecules with identical sequences
            share one entity.
    """
    seq_to_entity = {}
    entity_list = []
    info_by_chain = {}
    for mol in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
        molname = IMP.pmi.get_molecule_name_and_copy(mol)
        chain_id = chains[molname]
        sequence = IMP.atom.Chain(mol).get_sequence()
        entity = seq_to_entity.get(sequence)
        if entity is None:
            entity = _Entity(id=len(seq_to_entity) + 1, seq=sequence)
            seq_to_entity[sequence] = entity
            entity_list.append(entity)
        info_by_chain[chain_id] = _ChainInfo(entity=entity, name=molname)
    return info_by_chain, entity_list
131 
132 
def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
                          write_all_residues_per_bead, chains, root_hier):
    """Write one model in mmCIF format to flpdb.

    @param flpdb open, writable text file
    @param particle_infos_for_pdb sequence of (xyz, atom_type, residue_type,
           chain_id, residue_index, all_indexes, radius) tuples
    @param geometric_center point subtracted from every coordinate
    @param write_all_residues_per_bead if True, a particle carrying a list
           of residue indexes is written once per residue in that list,
           each record labeled with its own residue number
    @param chains dict mapping molecule name to chain ID
    @param root_hier hierarchy used to look up molecule sequences
    """
    # get dict with keys=chain IDs, values=chain info
    chain_info, entities = _get_chain_info(chains, root_hier)

    writer = ihm.format.CifWriter(flpdb)
    writer.start_block('model')
    with writer.category("_entry") as l:
        l.write(id='model')

    with writer.loop("_entity", ["id"]) as l:
        for e in entities:
            l.write(id=e.id)

    with writer.loop("_entity_poly",
                     ["entity_id", "pdbx_seq_one_letter_code"]) as l:
        for e in entities:
            l.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)

    with writer.loop("_struct_asym", ["id", "entity_id", "details"]) as l:
        for chid in sorted(chains.values()):
            ci = chain_info[chid]
            l.write(id=chid, entity_id=ci.entity.id, details=ci.name)

    with writer.loop("_atom_site",
                     ["group_PDB", "type_symbol", "label_atom_id",
                      "label_comp_id", "label_asym_id", "label_seq_id",
                      "auth_seq_id",
                      "Cartn_x", "Cartn_y", "Cartn_z", "label_entity_id",
                      "pdbx_pdb_model_num",
                      "id"]) as l:
        ordinal = 1
        for (xyz, atom_type, residue_type, chain_id, residue_index,
             all_indexes, radius) in particle_infos_for_pdb:
            ci = chain_info[chain_id]
            if atom_type is None:
                atom_type = IMP.atom.AT_CA
            # Subtract component-wise: xyz may be a plain list (atomistic
            # particles) and geometric_center a tuple, and '-' is not
            # defined between those types.
            c = [xyz[i] - geometric_center[i] for i in range(3)]
            if write_all_residues_per_bead and all_indexes is not None:
                residue_numbers = all_indexes
            else:
                residue_numbers = [residue_index]
            for residue_number in residue_numbers:
                # each record is labeled with its own residue number
                l.write(group_PDB='ATOM', type_symbol='C',
                        label_atom_id=atom_type.get_string(),
                        label_comp_id=residue_type.get_string(),
                        label_asym_id=chain_id,
                        label_seq_id=residue_number,
                        auth_seq_id=residue_number, Cartn_x=c[0],
                        Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                        pdbx_pdb_model_num=1,
                        label_entity_id=ci.entity.id)
                ordinal += 1
196 
197 
class Output(object):
    """Class for easy writing of PDBs, RMFs, and stat files

    @note Model should be updated prior to writing outputs.
    """
    def __init__(self, ascii=True, atomistic=False):
        """Set up empty output registries.

        @param ascii If True, stat files written by write_stat() are plain
               text (one dict repr per line); otherwise they are pickled.
        @param atomistic If True, atom particles are written individually
               to PDB/mmCIF output (switched on automatically for canonical
               PMI hierarchies).
        """
        self.dictionary_pdbs = {}
        self._pdb_mmcif = {}
        self.dictionary_rmfs = {}
        self.dictionary_stats = {}
        self.dictionary_stats2 = {}
        self.best_score_list = None
        self.nbestscoring = None
        self.suffixes = []
        self.replica_exchange = False
        self.ascii = ascii
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # 1-character chain IDs, suitable for PDB output
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
        # Multi-character chain IDs, suitable for mmCIF output
        self.multi_chainids = _ChainIDs()
        self.dictchain = {}  # keys are molecule names, values are chain ids
        self.particle_infos_for_pdb = {}
        self.atomistic = atomistic
        self.use_pmi2 = False

    def get_pdb_names(self):
        """Get a list of all PDB files being output by this instance"""
        return list(self.dictionary_pdbs.keys())

    def get_rmf_names(self):
        """Get a list of all RMF files registered with init_rmf()"""
        return list(self.dictionary_rmfs.keys())

    def get_stat_names(self):
        """Get a list of all stat files registered in dictionary_stats"""
        return list(self.dictionary_stats.keys())

    def _init_dictchain(self, name, prot, multichar_chain=False):
        """Build self.dictchain[name], mapping molecule names to chain IDs.

        For canonical PMI hierarchies the IDs come from each molecule's
        Chain decorator (duplicates are made unique) and atomistic output is
        switched on. Otherwise each top-level child of the hierarchy gets
        the next ID from self.chainids, or from self.multi_chainids if
        multichar_chain is True.
        """
        self.dictchain[name] = {}
        self.use_pmi2 = False
        seen_chains = set()

        # attempt to find PMI objects.
        if IMP.pmi.get_is_canonical(prot):
            self.use_pmi2 = True
            self.atomistic = True  # detects automatically
            for mol in IMP.atom.get_by_type(prot, IMP.atom.MOLECULE_TYPE):
                chid = _disambiguate_chain(IMP.atom.Chain(mol).get_id(),
                                           seen_chains)
                molname = IMP.pmi.get_molecule_name_and_copy(mol)
                self.dictchain[name][molname] = chid
        else:
            chainids = self.multi_chainids if multichar_chain else self.chainids
            for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
                self.dictchain[name][i.get_name()] = chainids[n]

    def init_pdb(self, name, prot, mmcif=False):
        """Init PDB Writing.
        @param name The PDB filename
        @param prot The hierarchy to write to this pdb file
        @param mmcif If True, write PDBs in mmCIF format
        @note if the PDB name is 'System' then will use Selection to get molecules
        """
        # create (or truncate) the output file
        with open(name, 'w'):
            pass
        self.dictionary_pdbs[name] = prot
        self._pdb_mmcif[name] = mmcif
        self._init_dictchain(name, prot)

    def write_psf(self, filename, name):
        """Write a PSF file for the particles of the PDB registered as name.

        Bonds are written between particles that are adjacent in residue
        order within each chain (pairs from IMP.pmi.tools.sublist_iterator).

        @param filename the PSF file to write
        @param name key into self.dictionary_pdbs
        """
        with open(filename, 'w') as flpsf:
            flpsf.write("PSF CMAP CHEQ" + "\n")
            index_residue_pair_list = {}
            (particle_infos_for_pdb,
             geometric_center) = self.get_particle_infos_for_pdb_writing(name)
            nparticles = len(particle_infos_for_pdb)
            flpsf.write(str(nparticles) + " !NATOM" + "\n")
            for n, p in enumerate(particle_infos_for_pdb):
                atom_index = n + 1
                residue_type = p[2]
                chain = p[3]
                resid = p[4]
                flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}{14:14.6f}{15:14.6f}'.format(atom_index," ",chain," ",str(resid)," ",'"'+residue_type.get_string()+'"'," ","C"," ","C",1.0,0.0,0,0.0,0.0))
                flpsf.write('\n')
                index_residue_pair_list.setdefault(chain, []).append(
                    (atom_index, resid))

            # now write the connectivity
            indexes_pairs = []
            for chain in sorted(index_residue_pair_list.keys()):
                # sort by residue
                ls = sorted(index_residue_pair_list[chain],
                            key=lambda tup: tup[1])
                # get the index list
                indexes = [x[0] for x in ls]
                # get the contiguous pairs
                indexes_pairs += list(IMP.pmi.tools.sublist_iterator(
                    indexes, lmin=2, lmax=2))
            nbonds = len(indexes_pairs)
            flpsf.write(str(nbonds) + " !NBOND: bonds" + "\n")

            # save bonds in fixed column format, four bonds per line
            for i in range(0, len(indexes_pairs), 4):
                for bond in indexes_pairs[i:i + 4]:
                    flpsf.write('{0:8d}{1:8d}'.format(*bond))
                flpsf.write('\n')

    def write_pdb(self, name,
                  appendmode=True,
                  translate_to_geometric_center=False,
                  write_all_residues_per_bead=False):
        """Write the current coordinates to the PDB (or mmCIF) file name.

        @param name a file previously registered with init_pdb()
        @param appendmode if True append a new model, else overwrite
        @param translate_to_geometric_center if True, center the model
        @param write_all_residues_per_bead if True, write a bead once for
               every residue it covers
        """
        (particle_infos_for_pdb,
         geometric_center) = self.get_particle_infos_for_pdb_writing(name)

        if not translate_to_geometric_center:
            geometric_center = (0, 0, 0)

        filemode = 'a' if appendmode else 'w'
        with open(name, filemode) as flpdb:
            if self._pdb_mmcif[name]:
                _write_mmcif_internal(flpdb, particle_infos_for_pdb,
                                      geometric_center,
                                      write_all_residues_per_bead,
                                      self.dictchain[name],
                                      self.dictionary_pdbs[name])
            else:
                _write_pdb_internal(flpdb, particle_infos_for_pdb,
                                    geometric_center,
                                    write_all_residues_per_bead)

    def get_prot_name_from_particle(self, name, p):
        """Get the protein name from the particle.
        This is done by traversing the hierarchy.

        @return (protein name, is_a_bead); for canonical PMI hierarchies
                is_a_bead is always True.
        """
        if self.use_pmi2:
            return IMP.pmi.get_molecule_name_and_copy(p), True
        else:
            return IMP.pmi.tools.get_prot_name_from_particle(
                p, self.dictchain[name])

    def get_particle_infos_for_pdb_writing(self, name):
        """Collect the per-particle records needed to write PDB/mmCIF/PSF.

        @param name key into self.dictionary_pdbs
        @return (particle_infos, geometric_center). Each record is a tuple
                (xyz, atom_type, residue_type, chain_id, residue_index,
                all_indexes, radius); atom_type is None for non-atomic
                particles and all_indexes is the residue index list of a
                fragment/bead (None otherwise). Records are sorted by chain
                ID (shorter IDs first) then residue index. geometric_center
                is the mean of all record coordinates, or [0, 0, 0] if
                there are no records.
        """
        # Residue indexes already output, per protein, so that a residue is
        # not written twice (highest resolution has highest priority).
        # Sets keep the membership tests O(1).
        resindexes_dict = {}
        particle_infos_for_pdb = []

        if self.use_pmi2:
            # select highest resolution
            ps = IMP.atom.Selection(self.dictionary_pdbs[name],
                                    resolution=0).get_selected_particles()
        else:
            ps = IMP.atom.get_leaves(self.dictionary_pdbs[name])

        for p in ps:
            protname, is_a_bead = self.get_prot_name_from_particle(name, p)
            seen = resindexes_dict.setdefault(protname, set())

            if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
                residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
                rt = residue.get_residue_type()
                resind = residue.get_index()
                atomtype = IMP.atom.Atom(p).get_atom_type()
                xyz = list(IMP.core.XYZ(p).get_coordinates())
                radius = IMP.core.XYZR(p).get_radius()
                particle_infos_for_pdb.append(
                    (xyz, atomtype, rt, self.dictchain[name][protname],
                     resind, None, radius))
                seen.add(resind)

            elif IMP.atom.Residue.get_is_setup(p):
                residue = IMP.atom.Residue(p)
                resind = residue.get_index()
                # skip if the residue was already added by atomistic
                # resolution 0
                if resind in seen:
                    continue
                seen.add(resind)
                rt = residue.get_residue_type()
                xyz = IMP.core.XYZ(p).get_coordinates()
                radius = IMP.core.XYZR(p).get_radius()
                particle_infos_for_pdb.append(
                    (xyz, None, rt, self.dictchain[name][protname],
                     resind, None, radius))

            elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
                resindexes = IMP.pmi.tools.get_residue_indexes(p)
                resind = resindexes[len(resindexes) // 2]
                if resind in seen:
                    continue
                seen.add(resind)
                rt = IMP.atom.ResidueType('BEA')
                xyz = IMP.core.XYZ(p).get_coordinates()
                radius = IMP.core.XYZR(p).get_radius()
                particle_infos_for_pdb.append(
                    (xyz, None, rt, self.dictchain[name][protname],
                     resind, resindexes, radius))

            elif is_a_bead:
                rt = IMP.atom.ResidueType('BEA')
                resindexes = IMP.pmi.tools.get_residue_indexes(p)
                if len(resindexes) > 0:
                    resind = resindexes[len(resindexes) // 2]
                    xyz = IMP.core.XYZ(p).get_coordinates()
                    radius = IMP.core.XYZR(p).get_radius()
                    particle_infos_for_pdb.append(
                        (xyz, None, rt, self.dictchain[name][protname],
                         resind, resindexes, radius))

        # mean position of every record (computed before sorting so the
        # summation order matches insertion order)
        geometric_center = [0, 0, 0]
        if particle_infos_for_pdb:
            count = len(particle_infos_for_pdb)
            geometric_center = tuple(
                sum(info[0][i] for info in particle_infos_for_pdb) / count
                for i in range(3))

        # sort by chain ID, then residue index. Longer chain IDs (e.g. AA)
        # should always come after shorter (e.g. Z)
        particle_infos_for_pdb = sorted(particle_infos_for_pdb,
                                        key=lambda x: (len(x[3]), x[3], x[4]))

        return (particle_infos_for_pdb, geometric_center)

    def write_pdbs(self, appendmode=True, mmcif=False):
        """Write every registered PDB/mmCIF file.

        @note mmcif is ignored; each file's format was chosen in init_pdb().
        """
        for pdb in self.dictionary_pdbs.keys():
            self.write_pdb(pdb, appendmode)

    def init_pdb_best_scoring(self,
                              suffix,
                              prot,
                              nbestscoring,
                              replica_exchange=False, mmcif=False):
        """Prepare to keep only the nbestscoring best-scoring models.

        Creates empty files <suffix>.0 ... <suffix>.<nbestscoring-1> with
        extension .pdb (or .cif if mmcif). In replica-exchange mode the
        best scores are shared between replicas through the file
        best.scores.rex.py.
        """
        self._pdb_best_scoring_mmcif = mmcif
        fileext = '.cif' if mmcif else '.pdb'
        self.suffixes.append(suffix)
        self.replica_exchange = replica_exchange
        # Either way the score list starts empty; in replica-exchange mode
        # it is also written to a common file the replicas read from.
        self.best_score_list = []
        if self.replica_exchange:
            self.best_score_file_name = "best.scores.rex.py"
            self._write_best_score_file()

        self.nbestscoring = nbestscoring
        for i in range(self.nbestscoring):
            name = suffix + "." + str(i) + fileext
            with open(name, 'w'):
                pass
            self.dictionary_pdbs[name] = prot
            self._pdb_mmcif[name] = mmcif
            self._init_dictchain(name, prot)

    def _write_best_score_file(self):
        """Store self.best_score_list in the shared replica-exchange file."""
        # plain floats so the file always holds a parseable Python literal
        scores = [float(s) for s in self.best_score_list]
        with open(self.best_score_file_name, "w") as best_score_file:
            best_score_file.write(
                "self.best_score_list=" + str(scores) + "\n")

    def _read_best_score_file(self):
        """Load self.best_score_list from the shared replica-exchange file.

        The file holds a single 'self.best_score_list=[...]' assignment; its
        right-hand side is parsed as a literal rather than executed. If the
        file has no assignment (e.g. it is empty), the list is unchanged.
        """
        with open(self.best_score_file_name) as fh:
            text = fh.read()
        _, sep, rhs = text.partition("=")
        if sep:
            self.best_score_list = ast.literal_eval(rhs.strip())

    def write_pdb_best_scoring(self, score):
        """Keep the model on disk if score is among the best seen so far.

        Files <suffix>.0, <suffix>.1, ... hold models in ascending score
        order (lowest first); existing files are shifted to make room.
        Does nothing (after printing a message) if init_pdb_best_scoring()
        was not called.
        """
        if self.nbestscoring is None:
            print("Output.write_pdb_best_scoring: init_pdb_best_scoring not run")
            return

        mmcif = self._pdb_best_scoring_mmcif
        fileext = '.cif' if mmcif else '.pdb'
        # update the score list
        if self.replica_exchange:
            # read the self.best_score_list from the file
            self._read_best_score_file()

        if len(self.best_score_list) < self.nbestscoring:
            self.best_score_list.append(score)
            self.best_score_list.sort()
            index = self.best_score_list.index(score)
            for suffix in self.suffixes:
                for i in range(len(self.best_score_list) - 2, index - 1, -1):
                    oldname = suffix + "." + str(i) + fileext
                    newname = suffix + "." + str(i + 1) + fileext
                    # rename on Windows fails if newname already exists
                    if os.path.exists(newname):
                        os.unlink(newname)
                    os.rename(oldname, newname)
                filetoadd = suffix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        else:
            if score < self.best_score_list[-1]:
                self.best_score_list.append(score)
                self.best_score_list.sort()
                self.best_score_list.pop(-1)
                index = self.best_score_list.index(score)
                for suffix in self.suffixes:
                    for i in range(len(self.best_score_list) - 1, index - 1, -1):
                        oldname = suffix + "." + str(i) + fileext
                        newname = suffix + "." + str(i + 1) + fileext
                        os.rename(oldname, newname)
                    # the worst model was shifted past the end; drop it
                    filenametoremove = suffix + \
                        "." + str(self.nbestscoring) + fileext
                    os.remove(filenametoremove)
                    filetoadd = suffix + "." + str(index) + fileext
                    self.write_pdb(filetoadd, appendmode=False)

        if self.replica_exchange:
            # write the self.best_score_list to the file
            self._write_best_score_file()

    def init_rmf(self, name, hierarchies, rs=None, geometries=None, listofobjects=None):
        """
        Initialize an RMF file

        @param name the name of the RMF file
        @param hierarchies the hierarchies to be included (it is a list)
        @param rs optional, the restraint sets (it is a list)
        @param geometries optional, the geometries (it is a list)
        @param listofobjects optional, the list of objects for the stat (it is a list)
        """
        rh = RMF.create_rmf_file(name)
        IMP.rmf.add_hierarchies(rh, hierarchies)
        cat = None
        outputkey_rmfkey = None

        if rs is not None:
            IMP.rmf.add_restraints(rh, rs)
        if geometries is not None:
            IMP.rmf.add_geometries(rh, geometries)
        if listofobjects is not None:
            cat = rh.get_category("stat")
            outputkey_rmfkey = {}
            for l in listofobjects:
                if not "get_output" in dir(l):
                    raise ValueError("Output: object %s doesn't have get_output() method" % str(l))
                output = l.get_output()
                for outputkey in output:
                    # pick the RMF key type from the current value's type
                    # (bool counts as int); anything else is stored as text
                    if isinstance(output[outputkey], float):
                        rmftag = RMF.float_tag
                    elif isinstance(output[outputkey], int):
                        rmftag = RMF.int_tag
                    else:
                        rmftag = RMF.string_tag
                    outputkey_rmfkey[outputkey] = rh.get_key(cat, outputkey, rmftag)
            outputkey_rmfkey["rmf_file"] = rh.get_key(cat, "rmf_file", RMF.string_tag)
            outputkey_rmfkey["rmf_frame_index"] = rh.get_key(cat, "rmf_frame_index", RMF.int_tag)

        self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey, listofobjects)

    def add_restraints_to_rmf(self, name, objectlist):
        """Add the restraints of each (possibly nested) object to RMF name.

        get_restraint_for_rmf() is preferred; objects lacking it fall back
        to get_restraint().
        """
        for o in _flatten(objectlist):
            try:
                rs = o.get_restraint_for_rmf()
                if not isinstance(rs, (list, tuple)):
                    rs = [rs]
            except Exception:
                rs = [o.get_restraint()]
            IMP.rmf.add_restraints(
                self.dictionary_rmfs[name][0], rs)

    def add_geometries_to_rmf(self, name, objectlist):
        """Add every object's get_geometries() to RMF name."""
        for o in objectlist:
            geos = o.get_geometries()
            IMP.rmf.add_geometries(self.dictionary_rmfs[name][0], geos)

    def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
        """Add an edge geometry to RMF name for every particle pair reported
        by each object's get_particle_pairs()."""
        for o in objectlist:
            pps = o.get_particle_pairs()
            for pp in pps:
                IMP.rmf.add_geometry(
                    self.dictionary_rmfs[name][0],
                    IMP.core.EdgePairGeometry(pp))

    def write_rmf(self, name):
        """Save a frame to RMF name, plus the stat values if any were set up."""
        rh, cat, outputkey_rmfkey, listofobjects = self.dictionary_rmfs[name]
        IMP.rmf.save_frame(rh)
        if cat is not None:
            for l in listofobjects:
                output = l.get_output()
                for outputkey in output:
                    rmfkey = outputkey_rmfkey[outputkey]
                    try:
                        rh.get_root_node().set_value(rmfkey, output[outputkey])
                    except NotImplementedError:
                        continue
            rmfkey = outputkey_rmfkey["rmf_file"]
            rh.get_root_node().set_value(rmfkey, name)
            rmfkey = outputkey_rmfkey["rmf_frame_index"]
            nframes = rh.get_number_of_frames()
            rh.get_root_node().set_value(rmfkey, nframes - 1)
        rh.flush()

    def close_rmf(self, name):
        """Forget RMF name and drop this instance's reference to its handle."""
        rh = self.dictionary_rmfs[name][0]
        del self.dictionary_rmfs[name]
        del rh

    def write_rmfs(self):
        """Save a frame to every registered RMF file."""
        # keys are the RMF file names themselves
        for rmfname in list(self.dictionary_rmfs.keys()):
            self.write_rmf(rmfname)

    def init_stat(self, name, listofobjects):
        """Create (truncate) stat file name and register its objects.

        @raise ValueError if an object has no get_output() method
        """
        with open(name, 'w' if self.ascii else 'wb'):
            pass

        # check that all objects in listofobjects have a get_output method
        for l in listofobjects:
            if not "get_output" in dir(l):
                raise ValueError("Output: object %s doesn't have get_output() method" % str(l))
        self.dictionary_stats[name] = listofobjects

    def set_output_entry(self, key, value):
        """Set an extra key/value pair included in every stat record."""
        self.initoutput.update({key: value})

    def write_stat(self, name, appendmode=True):
        """Write one record with the public outputs of all objects to name.

        Keys starting with '_' are omitted. ASCII mode writes the dict repr
        followed by " \\n"; otherwise the dict is pickled (protocol 2).
        """
        # NOTE: output aliases self.initoutput, so the object outputs are
        # also accumulated there.
        output = self.initoutput
        for obj in self.dictionary_stats[name]:
            d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
            output.update(dfiltered)

        writeflag = 'a' if appendmode else 'w'
        if self.ascii:
            with open(name, writeflag) as flstat:
                flstat.write("%s \n" % output)
        else:
            with open(name, writeflag + 'b') as flstat:
                pickle.dump(output, flstat, 2)

    def write_stats(self):
        """Append a record to every registered stat file."""
        for stat in self.dictionary_stats.keys():
            self.write_stat(stat)

    def get_stat(self, name):
        """Return the merged get_output() dicts of the objects for name."""
        output = {}
        for obj in self.dictionary_stats[name]:
            output.update(obj.get_output())
        return output

    def write_test(self, name, listofobjects):
        """Write reference values for a regression test to file name.

        Each object's get_test_output() is used if available, otherwise its
        get_output(); keys starting with '_' are omitted. Check later with
        test(name, listofobjects).

        @raise ValueError if an object has neither output method
        """
        for l in listofobjects:
            if not "get_test_output" in dir(l) and not "get_output" in dir(l):
                raise ValueError("Output: object %s doesn't have get_output() or get_test_output() method" % str(l))
        self.dictionary_stats[name] = listofobjects

        output = self.initoutput
        for obj in self.dictionary_stats[name]:
            try:
                d = obj.get_test_output()
            except Exception:
                d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
            output.update(dfiltered)
        with open(name, 'w') as flstat:
            flstat.write("%s \n" % output)

    def test(self, name, listofobjects, tolerance=1e-5):
        """Compare the objects' current outputs with the reference file name.

        Numeric values must agree within tolerance; other values must be
        equal. Mismatches are reported on stderr. Keys present in the file
        but missing from the objects are reported but do not fail the test.

        @return True if no compared value differs
        """
        output = self.initoutput
        for l in listofobjects:
            if not "get_test_output" in dir(l) and not "get_output" in dir(l):
                raise ValueError("Output: object %s doesn't have get_output() or get_test_output() method" % str(l))
        for obj in listofobjects:
            try:
                output.update(obj.get_test_output())
            except Exception:
                output.update(obj.get_output())

        passed = True
        with open(name, 'r') as flstat:
            for l in flstat:
                test_dict = ast.literal_eval(l)
                for k in test_dict:
                    if k in output:
                        old_value = str(test_dict[k])
                        new_value = str(output[k])
                        try:
                            float(old_value)
                            is_float = True
                        except ValueError:
                            is_float = False

                        if is_float:
                            fold = float(old_value)
                            fnew = float(new_value)
                            diff = abs(fold - fnew)
                            if diff > tolerance:
                                print("%s: test failed, old value: %s new value %s; "
                                      "diff %f > %f" % (str(k), str(old_value),
                                                        str(new_value), diff,
                                                        tolerance), file=sys.stderr)
                                passed = False
                        elif test_dict[k] != output[k]:
                            if len(old_value) < 50 and len(new_value) < 50:
                                print("%s: test failed, old value: %s new value %s"
                                      % (str(k), old_value, new_value), file=sys.stderr)
                                passed = False
                            else:
                                print("%s: test failed, omitting results (too long)"
                                      % str(k), file=sys.stderr)
                                passed = False

                    else:
                        print("%s from old objects (file %s) not in new objects"
                              % (str(k), str(name)), file=sys.stderr)
        return passed

    def get_environment_variables(self):
        """Return str(os.environ)."""
        return str(os.environ)

    def get_versions_of_relevant_modules(self):
        """Return a dict of IMP/PMI module versions (plus ISD2 and ISD_EMXL
        when they can be imported)."""
        versions = {}
        versions["IMP_VERSION"] = IMP.get_module_version()
        versions["PMI_VERSION"] = IMP.pmi.get_module_version()
        try:
            import IMP.isd2
            versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
        except ImportError:
            pass
        try:
            import IMP.isd_emxl
            versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
        except ImportError:
            pass
        return versions

#-------------------
    def init_stat2(
            self,
            name,
            listofobjects,
            extralabels=None,
            listofsummedobjects=None):
        """Create a compact (v2) stat file and register its contents.

        The header line maps integer column numbers to field names (plus
        STAT2HEADER* entries); later records written by write_stat2() use
        those integers as keys.

        @param listofsummedobjects list of ([obj1, obj2, ...], label); each
               label's column holds the sum of the objects' _TotalScore
        @param extralabels labels whose values come from set_output_entry()
        @raise ValueError if an object lacks get_output() or, for summed
               objects, a _TotalScore entry
        """
        if listofsummedobjects is None:
            listofsummedobjects = []
        if extralabels is None:
            extralabels = []
        output = {}
        stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
        stat2_keywords.update(
            {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
        stat2_keywords.update(
            {"STAT2HEADER_IMP_VERSIONS": str(self.get_versions_of_relevant_modules())})
        stat2_inverse = {}

        for l in listofobjects:
            if not "get_output" in dir(l):
                raise ValueError("Output: object %s doesn't have get_output() method" % str(l))
            d = l.get_output()
            # remove all entries that begin with _ (private entries)
            dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
            output.update(dfiltered)

        # check for customizable entries
        for l in listofsummedobjects:
            for t in l[0]:
                if not "get_output" in dir(t):
                    raise ValueError("Output: object %s doesn't have get_output() method" % str(t))
                if "_TotalScore" not in t.get_output():
                    raise ValueError("Output: object %s doesn't have _TotalScore entry to be summed" % str(t))
            # register the label even for an empty object list, so that
            # write_stat2() finds a column for it
            output.update({l[1]: 0.0})

        for k in extralabels:
            output.update({k: 0.0})

        for n, k in enumerate(output):
            stat2_keywords.update({n: k})
            stat2_inverse.update({k: n})

        with open(name, 'w') as flstat:
            flstat.write("%s \n" % stat2_keywords)
        self.dictionary_stats2[name] = (
            listofobjects,
            stat2_inverse,
            listofsummedobjects,
            extralabels)

    def write_stat2(self, name, appendmode=True):
        """Write one record to the v2 stat file name, keyed by column number."""
        output = {}
        (listofobjects, stat2_inverse, listofsummedobjects,
         extralabels) = self.dictionary_stats2[name]

        # writing objects
        for obj in listofobjects:
            od = obj.get_output()
            dfiltered = dict((k, v) for k, v in od.items() if k[0] != "_")
            for k in dfiltered:
                output.update({stat2_inverse[k]: od[k]})

        # writing summedobjects
        for l in listofsummedobjects:
            partial_score = 0.0
            for t in l[0]:
                d = t.get_output()
                partial_score += float(d["_TotalScore"])
            output.update({stat2_inverse[l[1]]: str(partial_score)})

        # writing extralabels
        for k in extralabels:
            if k in self.initoutput:
                output.update({stat2_inverse[k]: self.initoutput[k]})
            else:
                output.update({stat2_inverse[k]: "None"})

        writeflag = 'a' if appendmode else 'w'
        with open(name, writeflag) as flstat:
            flstat.write("%s \n" % output)

    def write_stats2(self):
        """Append a record to every registered v2 stat file."""
        for stat in self.dictionary_stats2.keys():
            self.write_stat2(stat)
891 
892 
class OutputStatistics(object):
    """Frame counters filled in by ProcessOutput.get_fields().

    `total` counts every frame read; `passed_get_every`,
    `passed_filterout` and `passed_filtertuple` count the frames that got
    past the corresponding get_fields() filter."""
    def __init__(self):
        self.total = self.passed_get_every = 0
        self.passed_filterout = self.passed_filtertuple = 0
902 
903 
class ProcessOutput(object):
    """A class for reading stat files (either rmf or ascii v1 and v2)"""
    def __init__(self, filename):
        """Open a stat file and read the list of keys it provides.

        The file is first tried as an RMF file. If RMF raises IOError, it is
        read as an ascii stat file. There, only the first line is parsed: a
        v2 header (a dict containing a "STAT2HEADER" key) or, for the
        deprecated v1 format, a full frame dict.

        @param filename path to the stat file
        @raises ValueError if filename is None
        """
        self.filename = filename
        self.isstat1 = False
        self.isstat2 = False
        self.isrmf = False

        if self.filename is None:
            raise ValueError("No file name provided. Use -h for help")

        try:
            # let's see if that is an rmf file
            rh = RMF.open_rmf_file_read_only(self.filename)
            self.isrmf = True
            cat = rh.get_category('stat')
            rmf_klist = rh.get_keys(cat)
            self.rmf_names_keys = dict([(rh.get_name(k), k) for k in rmf_klist])
            del rh

        except IOError:
            # Try with an ascii stat file. Only the first line is needed, so
            # iterate lazily rather than reading the whole (possibly very
            # large) file into memory; `with` closes it even on parse errors.
            with open(self.filename, "r") as f:
                for line in f:
                    d = ast.literal_eval(line)
                    self.klist = list(d.keys())
                    # check if it is a stat2 file
                    if "STAT2HEADER" in self.klist:
                        self.isstat2 = True
                        for k in self.klist:
                            if "STAT2HEADER" in str(k):
                                del d[k]
                        stat2_dict = d
                        # (column key, field name) pairs sorted by field name
                        ordered = sorted(stat2_dict.items(),
                                         key=operator.itemgetter(1))
                        self.klist = [item[1] for item in ordered]
                        # field name -> column key used in the data lines
                        self.invstat2_dict = {}
                        for k, _ in ordered:
                            self.invstat2_dict[stat2_dict[k]] = k
                    else:
                        IMP.handle_use_deprecated("statfile v1 is deprecated. "
                                                  "Please convert to statfile v2.\n")
                        self.isstat1 = True
                        self.klist.sort()

                    break

    def get_keys(self):
        """Return the sorted list of field names available in the file."""
        if self.isrmf:
            return sorted(self.rmf_names_keys.keys())
        else:
            return self.klist

    def show_keys(self, ncolumns=2, truncate=65):
        """Print the available keys in ``ncolumns`` columns."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)

    def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
                   statistics=None):
        '''
        Get the desired field names, and return a dictionary.
        Namely, "fields" are the queried keys in the stat file (eg. ["Total_Score",...])
        The returned data structure is a dictionary, where each key is a field and the value
        is the time series (ie, frame ordered series)
        of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,Score_3,...],....} )

        @param fields (list of strings) queried keys in the stat file (eg. "Total_Score"....)
        @param filterout specify if you want to "grep" out something from
               the file, so that it is faster
        @param filtertuple a tuple that contains
               ("TheKeyToBeFiltered",relationship,value)
               where relationship = "<", "==", or ">"
        @param get_every only read every Nth line from the file
        @param statistics if provided, accumulate statistics in an
               OutputStatistics object
        '''

        if statistics is None:
            statistics = OutputStatistics()
        outdict = {}
        for field in fields:
            outdict[field] = []

        # print fields values
        if self.isrmf:
            rh = RMF.open_rmf_file_read_only(self.filename)
            nframes = rh.get_number_of_frames()
            for i in range(nframes):
                statistics.total += 1
                # "get_every" and "filterout" not enforced for RMF
                statistics.passed_get_every += 1
                statistics.passed_filterout += 1
                IMP.rmf.load_frame(rh, RMF.FrameID(i))
                if filtertuple is not None:
                    keytobefiltered = filtertuple[0]
                    relationship = filtertuple[1]
                    value = filtertuple[2]
                    datavalue = rh.get_root_node().get_value(
                        self.rmf_names_keys[keytobefiltered])
                    if self.isfiltered(datavalue, relationship, value):
                        continue

                statistics.passed_filtertuple += 1
                for field in fields:
                    outdict[field].append(rh.get_root_node().get_value(
                        self.rmf_names_keys[field]))

        else:
            # `with` guarantees the file is closed even if a key lookup fails
            with open(self.filename, "r") as f:
                line_number = 0

                for line in f:
                    statistics.total += 1
                    if filterout is not None:
                        if filterout in line:
                            continue
                    statistics.passed_filterout += 1
                    line_number += 1

                    if line_number % get_every != 0:
                        # the stat2 header is not a frame: undo its counts
                        if line_number == 1 and self.isstat2:
                            statistics.total -= 1
                            statistics.passed_filterout -= 1
                        continue
                    statistics.passed_get_every += 1
                    try:
                        d = ast.literal_eval(line)
                    except Exception:
                        print("# Warning: skipped line number " + str(line_number) + " not a valid line")
                        continue

                    if self.isstat1:

                        if filtertuple is not None:
                            keytobefiltered = filtertuple[0]
                            relationship = filtertuple[1]
                            value = filtertuple[2]
                            datavalue = d[keytobefiltered]
                            if self.isfiltered(datavalue, relationship, value):
                                continue

                        statistics.passed_filtertuple += 1
                        for field in fields:
                            outdict[field].append(d[field])

                    elif self.isstat2:
                        if line_number == 1:
                            # the stat2 header is not a frame: undo its counts
                            statistics.total -= 1
                            statistics.passed_filterout -= 1
                            statistics.passed_get_every -= 1
                            continue

                        if filtertuple is not None:
                            keytobefiltered = filtertuple[0]
                            relationship = filtertuple[1]
                            value = filtertuple[2]
                            datavalue = d[self.invstat2_dict[keytobefiltered]]
                            if self.isfiltered(datavalue, relationship, value):
                                continue

                        statistics.passed_filtertuple += 1
                        for field in fields:
                            outdict[field].append(d[self.invstat2_dict[field]])

        return outdict

    def isfiltered(self, datavalue, relationship, refvalue):
        """Tell whether a value must be skipped by get_fields().

        @param datavalue value read from the file; must be convertible to float
        @param relationship "<", ">" or "==": the condition a value must
               satisfy against refvalue to be kept
        @param refvalue the reference value
        @return True if the value fails the condition, False if it passes or
                if relationship is not one of the three recognized strings
        @raises ValueError if datavalue cannot be converted into a float
        """
        try:
            fdatavalue = float(datavalue)
        except ValueError:
            raise ValueError("ProcessOutput.filter: datavalue cannot be converted into a float")

        if relationship == "<":
            return bool(fdatavalue >= refvalue)
        if relationship == ">":
            return bool(fdatavalue <= refvalue)
        if relationship == "==":
            return bool(fdatavalue != refvalue)
        return False
1089 
1090 
1092  """ class to allow more advanced handling of RMF files.
1093  It is both a container and a IMP.atom.Hierarchy.
1094  - it is iterable (while loading the corresponding frame)
1095  - Item brackets [] load the corresponding frame
1096  - slice create an iterator
1097  - can relink to another RMF file
1098  """
1099  def __init__(self,model,rmf_file_name):
1100  """
1101  @param model: the IMP.Model()
1102  @param rmf_file_name: str, path of the rmf file
1103  """
1104  self.model=model
1105  try:
1106  self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1107  except TypeError:
1108  raise TypeError("Wrong rmf file name or type: %s"% str(rmf_file_name))
1109  hs = IMP.rmf.create_hierarchies(self.rh_ref, self.model)
1110  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(0))
1111  self.root_hier_ref = hs[0]
1112  IMP.atom.Hierarchy.__init__(self, self.root_hier_ref)
1113  self.model.update()
1114  self.ColorHierarchy=None
1115 
1116 
1117  def link_to_rmf(self,rmf_file_name):
1118  """
1119  Link to another RMF file
1120  """
1121  self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1122  IMP.rmf.link_hierarchies(self.rh_ref, [self])
1123  if self.ColorHierarchy:
1124  self.ColorHierarchy.method()
1125  RMFHierarchyHandler.set_frame(self,0)
1126 
1127  def set_frame(self,index):
1128  try:
1129  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(index))
1130  except:
1131  print("skipping frame %s:%d\n"%(self.current_rmf, index))
1132  self.model.update()
1133 
1134  def get_number_of_frames(self):
1135  return self.rh_ref.get_number_of_frames()
1136 
1137  def __getitem__(self,int_slice_adaptor):
1138  if isinstance(int_slice_adaptor, int):
1139  self.set_frame(int_slice_adaptor)
1140  return int_slice_adaptor
1141  elif isinstance(int_slice_adaptor, slice):
1142  return self.__iter__(int_slice_adaptor)
1143  else:
1144  raise TypeError("Unknown Type")
1145 
1146  def __len__(self):
1147  return self.get_number_of_frames()
1148 
1149  def __iter__(self,slice_key=None):
1150  if slice_key is None:
1151  for nframe in range(len(self)):
1152  yield self[nframe]
1153  else:
1154  for nframe in list(range(len(self)))[slice_key]:
1155  yield self[nframe]
1156 
1157 class CacheHierarchyCoordinates(object):
 # Per-frame coordinate cache, filled by do_store(index) and replayed by
 # do_update(index). For each stored frame index it keeps the rigid-body
 # reference frames (rb_trans), the non-rigid-member internal coordinates
 # (nrm_coors) and the plain XYZ coordinates (xyz_coors).
1158  def __init__(self,StatHierarchyHandler):
 # StatHierarchyHandler: the handler whose hierarchy is cached
 # (StatHierarchyHandler passes itself here).
1159  self.xyzs=[]
1160  self.nrms=[]
1161  self.rbs=[]
1162  self.nrm_coors={}
1163  self.xyz_coors={}
1164  self.rb_trans={}
1165  self.current_index=None
1166  self.rmfh=StatHierarchyHandler
 # collect the rigid bodies and the beads of the hierarchy
1167  rbs,xyzs=IMP.pmi.tools.get_rbs_and_beads([self.rmfh])
1168  self.model=self.rmfh.get_model()
1169  self.rbs=rbs
1170  for xyz in xyzs:
 # NOTE(review): the `if` that pairs with the `else:` below (original line
 # 1171) is missing from this listing; presumably it tests whether xyz is
 # set up as an IMP.core.NonRigidMember. Confirm against the real source.
1172  nrm=IMP.core.NonRigidMember(xyz)
1173  self.nrms.append(nrm)
1174  else:
1175  fb=IMP.core.XYZ(xyz)
1176  self.xyzs.append(fb)
1177 
1178  def do_store(self,index):
1179  self.rb_trans[index]={}
1180  self.nrm_coors[index]={}
1181  self.xyz_coors[index]={}
1182  for rb in self.rbs:
1183  self.rb_trans[index][rb]=rb.get_reference_frame()
1184  for nrm in self.nrms:
1185  self.nrm_coors[index][nrm]=nrm.get_internal_coordinates()
1186  for xyz in self.xyzs:
1187  self.xyz_coors[index][xyz]=xyz.get_coordinates()
1188  self.current_index=index
1189 
1190  def do_update(self,index):
 # Restore the coordinates stored for frame `index` by do_store(index):
 # rigid-body reference frames, non-rigid-member internal coordinates and
 # XYZ coordinates. Nothing is restored if `index` is already the current
 # frame. For an index that was never stored, the dictionary lookups below
 # raise KeyError (whenever there is anything to restore).
 # NOTE(review): this listing lost its indentation, so it cannot be read
 # here whether the last two statements (current_index update and
 # model.update()) belong to the `if`. Presumably they do; confirm.
1191  if self.current_index!=index:
1192  for rb in self.rbs:
1193  rb.set_reference_frame(self.rb_trans[index][rb])
1194  for nrm in self.nrms:
1195  nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
1196  for xyz in self.xyzs:
1197  xyz.set_coordinates(self.xyz_coors[index][xyz])
1198  self.current_index=index
1199  self.model.update()
1200 
1201  def get_number_of_frames(self):
1202  return len(self.rb_trans.keys())
1203 
1204  def __getitem__(self,index):
1205  if isinstance(index, int):
1206  return index in self.rb_trans.keys()
1207  else:
1208  raise TypeError("Unknown Type")
1209 
1210  def __len__(self):
1211  return self.get_number_of_frames()
1212 
1213 
1214 
1215 
1217  """ class to link stat files to several rmf files """
1218  def __init__(self,model=None,stat_file=None,number_best_scoring_models=None,score_key=None,StatHierarchyHandler=None,cache=None):
1219  """
1220 
1221  @param model: IMP.Model()
1222  @param stat_file: either 1) a list or 2) a single stat file names (either rmfs or ascii, or pickled data or pickled cluster), 3) a dictionary containing an rmf/ascii
1223  stat file name as key and a list of frames as values
1224  @param number_best_scoring_models:
1225  @param StatHierarchyHandler: copy constructor input object
1226  @param cache: cache coordinates and rigid body transformations.
1227  """
1228 
1229  if not StatHierarchyHandler is None:
1230  #overrides all other arguments
1231  #copy constructor: create a copy with different RMFHierarchyHandler
1232  self.model=StatHierarchyHandler.model
1233  self.data=StatHierarchyHandler.data
1234  self.number_best_scoring_models=StatHierarchyHandler.number_best_scoring_models
1235  self.is_setup=True
1236  self.current_rmf=StatHierarchyHandler.current_rmf
1237  self.current_frame=None
1238  self.current_index=None
1239  self.score_threshold=StatHierarchyHandler.score_threshold
1240  self.score_key=StatHierarchyHandler.score_key
1241  self.cache=StatHierarchyHandler.cache
1242  RMFHierarchyHandler.__init__(self, self.model,self.current_rmf)
1243  if self.cache:
1244  self.cache=CacheHierarchyCoordinates(self)
1245  else:
1246  self.cache=None
1247  self.set_frame(0)
1248 
1249  else:
1250  #standard constructor
1251  self.model=model
1252  self.data=[]
1253  self.number_best_scoring_models=number_best_scoring_models
1254  self.cache=cache
1255 
1256  if score_key is None:
1257  self.score_key="Total_Score"
1258  else:
1259  self.score_key=score_key
1260  self.is_setup=None
1261  self.current_rmf=None
1262  self.current_frame=None
1263  self.current_index=None
1264  self.score_threshold=None
1265 
1266  if isinstance(stat_file, str):
1267  self.add_stat_file(stat_file)
1268  elif isinstance(stat_file, list):
1269  for f in stat_file:
1270  self.add_stat_file(f)
1271 
1272  def add_stat_file(self,stat_file):
1273  try:
1274  '''check that it is not a pickle file with saved data from a previous calculation'''
1275  self.load_data(stat_file)
1276 
1277  if self.number_best_scoring_models:
1278  scores = self.get_scores()
1279  max_score = sorted(scores)[0:min(len(self), self.number_best_scoring_models)][-1]
1280  self.do_filter_by_score(max_score)
1281 
1282  except pickle.UnpicklingError:
1283  '''alternatively read the ascii stat files'''
1284  try:
1285  scores,rmf_files,rmf_frame_indexes,features = self.get_info_from_stat_file(stat_file, self.score_threshold)
1286  except (KeyError, SyntaxError):
1287  # in this case check that is it an rmf file, probably without stat stored in
1288  try:
1289  # let's see if that is an rmf file
1290  rh = RMF.open_rmf_file_read_only(stat_file)
1291  nframes = rh.get_number_of_frames()
1292  scores=[0.0]*nframes
1293  rmf_files=[stat_file]*nframes
1294  rmf_frame_indexes=range(nframes)
1295  features={}
1296  except:
1297  return
1298 
1299 
1300  if len(set(rmf_files)) > 1:
1301  raise ("Multiple RMF files found")
1302 
1303  if not rmf_files:
1304  print("StatHierarchyHandler: Error: Trying to set none as rmf_file (probably empty stat file), aborting")
1305  return
1306 
1307  for n,index in enumerate(rmf_frame_indexes):
1308  featn_dict=dict([(k,features[k][n]) for k in features])
1309  self.data.append(IMP.pmi.output.DataEntry(stat_file,rmf_files[n],index,scores[n],featn_dict))
1310 
1311  if self.number_best_scoring_models:
1312  scores=self.get_scores()
1313  max_score=sorted(scores)[0:min(len(self),self.number_best_scoring_models)][-1]
1314  self.do_filter_by_score(max_score)
1315 
1316  if not self.is_setup:
1317  RMFHierarchyHandler.__init__(self, self.model,self.get_rmf_names()[0])
1318  if self.cache:
1319  self.cache=CacheHierarchyCoordinates(self)
1320  else:
1321  self.cache=None
1322  self.is_setup=True
1323  self.current_rmf=self.get_rmf_names()[0]
1324 
1325  self.set_frame(0)
1326 
1327  def save_data(self,filename='data.pkl'):
1328  with open(filename, 'wb') as fl:
1329  pickle.dump(self.data, fl)
1330 
1331  def load_data(self,filename='data.pkl'):
1332  with open(filename, 'rb') as fl:
1333  data_structure=pickle.load(fl)
1334  #first check that it is a list
1335  if not isinstance(data_structure, list):
1336  raise TypeError("%filename should contain a list of IMP.pmi.output.DataEntry or IMP.pmi.output.Cluster" % filename)
1337  # second check the types
1338  if all(isinstance(item, IMP.pmi.output.DataEntry) for item in data_structure):
1339  self.data=data_structure
1340  elif all(isinstance(item, IMP.pmi.output.Cluster) for item in data_structure):
1341  nmodels=0
1342  for cluster in data_structure:
1343  nmodels+=len(cluster)
1344  self.data=[None]*nmodels
1345  for cluster in data_structure:
1346  for n,data in enumerate(cluster):
1347  index=cluster.members[n]
1348  self.data[index]=data
1349  else:
1350  raise TypeError("%filename should contain a list of IMP.pmi.output.DataEntry or IMP.pmi.output.Cluster" % filename)
1351 
1352  def set_frame(self,index):
 # Make entry `index` of self.data the current model.
 # If the coordinate cache holds this index, the coordinates are restored
 # from it. Otherwise the entry's RMF file is linked (only if it differs
 # from self.current_rmf) and its frame is loaded (only if it differs from
 # self.current_frame). When caching is enabled, the coordinates are then
 # stored in the cache under `index`.
 # NOTE(review): this listing lost its indentation; the nesting described
 # above is inferred from the if/else structure. Confirm against the source.
1353  if self.cache is not None and self.cache[index]:
1354  self.cache.do_update(index)
1355  else:
1356  nm=self.data[index].rmf_name
1357  fidx=self.data[index].rmf_index
1358  if nm != self.current_rmf:
1359  self.link_to_rmf(nm)
1360  self.current_rmf=nm
 # reset so that the frame comparison below presumably always triggers a
 # load after switching files (assumes rmf_index is never -1)
1361  self.current_frame=-1
1362  if fidx!=self.current_frame:
1363  RMFHierarchyHandler.set_frame(self, fidx)
1364  self.current_frame=fidx
1365  if self.cache is not None:
1366  self.cache.do_store(index)
1367 
1368  self.current_index = index
1369 
1370  def __getitem__(self,int_slice_adaptor):
1371  if isinstance(int_slice_adaptor, int):
1372  self.set_frame(int_slice_adaptor)
1373  return self.data[int_slice_adaptor]
1374  elif isinstance(int_slice_adaptor, slice):
1375  return self.__iter__(int_slice_adaptor)
1376  else:
1377  raise TypeError("Unknown Type")
1378 
1379  def __len__(self):
1380  return len(self.data)
1381 
1382  def __iter__(self,slice_key=None):
1383  if slice_key is None:
1384  for i in range(len(self)):
1385  yield self[i]
1386  else:
1387  for i in range(len(self))[slice_key]:
1388  yield self[i]
1389 
1390  def do_filter_by_score(self,maximum_score):
1391  self.data=[d for d in self.data if d.score<=maximum_score]
1392 
1393  def get_scores(self):
1394  return [d.score for d in self.data]
1395 
1396  def get_feature_series(self,feature_name):
1397  return [d.features[feature_name] for d in self.data]
1398 
1399  def get_feature_names(self):
1400  return self.data[0].features.keys()
1401 
1402  def get_rmf_names(self):
1403  return [d.rmf_name for d in self.data]
1404 
1405  def get_stat_files_names(self):
1406  return [d.stat_file for d in self.data]
1407 
1408  def get_rmf_indexes(self):
1409  return [d.rmf_index for d in self.data]
1410 
1411  def get_info_from_stat_file(self, stat_file, score_threshold=None):
1412  po=ProcessOutput(stat_file)
1413  fs=po.get_keys()
1414  models = IMP.pmi.io.get_best_models([stat_file],
1415  score_key=self.score_key,
1416  feature_keys=fs,
1417  rmf_file_key="rmf_file",
1418  rmf_file_frame_key="rmf_frame_index",
1419  prefiltervalue=score_threshold,
1420  get_every=1)
1421 
1422 
1423 
1424  scores = [float(y) for y in models[2]]
1425  rmf_files = models[0]
1426  rmf_frame_indexes = models[1]
1427  features=models[3]
1428  return scores, rmf_files, rmf_frame_indexes,features
1429 
1430 
class DataEntry(object):
    '''
    A class to store data associated to a model
    '''
    def __init__(self, stat_file=None, rmf_name=None, rmf_index=None, score=None, features=None):
        """
        @param stat_file name of the stat file the model was read from
        @param rmf_name name of the RMF file holding the model
        @param rmf_index frame index of the model in the RMF file
        @param score score of the model
        @param features dictionary of extra per-model values
               (feature name -> value)
        """
        self.rmf_name = rmf_name
        self.rmf_index = rmf_index
        self.score = score
        self.features = features
        self.stat_file = stat_file

    def __repr__(self):
        # features defaults to None: report 0 features instead of raising
        if self.features is None:
            n_features = 0
        else:
            n_features = len(self.features.keys())
        s = "IMP.pmi.output.DataEntry\n"
        s += "---- stat file %s \n" % (self.stat_file)
        s += "---- rmf file %s \n" % (self.rmf_name)
        s += "---- rmf index %s \n" % (str(self.rmf_index))
        s += "---- score %s \n" % (str(self.score))
        s += "---- number of features %s \n" % (str(n_features))
        return s
1450 
1451 
class Cluster(object):
    '''
    A container for models organized into clusters
    '''
    def __init__(self, cid=None):
        """@param cid identifier of the cluster (stored as cluster_id)"""
        self.cluster_id = cid
        # indexes of the member models, in insertion order
        self.members = []
        self.precision = None
        self.center_index = None
        # member index -> data attached to that member (may be None)
        self.members_data = {}
        # set here so that __repr__ works before any member is added
        self.average_score = None

    def add_member(self, index, data=None):
        """Append model ``index`` (with optional ``data``) and refresh the
        average score."""
        self.members.append(index)
        self.members_data[index] = data
        self.average_score = self.compute_score()

    def compute_score(self):
        """Return the mean ``score`` attribute of the members' data.

        Returns None if the cluster is empty, or if some member's data has
        no ``score`` attribute (e.g. it is None).
        """
        if len(self) == 0:
            return None
        try:
            score = sum([d.score for d in self]) / len(self)
        except AttributeError:
            score = None
        return score

    def __repr__(self):
        s = "IMP.pmi.output.Cluster\n"
        s += "---- cluster_id %s \n" % str(self.cluster_id)
        s += "---- precision %s \n" % str(self.precision)
        s += "---- average score %s \n" % str(self.average_score)
        s += "---- number of members %s \n" % str(len(self.members))
        s += "---- center index %s \n" % str(self.center_index)
        return s

    def __getitem__(self, int_slice_adaptor):
        """Return the data of the n-th member, or an iterator over a slice.

        @raises TypeError for any other key type
        """
        if isinstance(int_slice_adaptor, int):
            index = self.members[int_slice_adaptor]
            return self.members_data[index]
        elif isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.members)

    def __iter__(self, slice_key=None):
        """Yield the data of each member (or of the members selected by
        ``slice_key``), in member order."""
        if slice_key is None:
            for i in range(len(self)):
                yield self[i]
        else:
            for i in range(len(self))[slice_key]:
                yield self[i]

    def __add__(self, other):
        """Merge ``other`` into this cluster in place and return this cluster.

        Note that ``a + b`` therefore modifies ``a``. Precision and center
        index are reset to None because they no longer describe the merged
        cluster.
        """
        self.members += other.members
        self.members_data.update(other.members_data)
        self.average_score = self.compute_score()
        self.precision = None
        self.center_index = None
        return self
1511 
1512 
def plot_clusters_populations(clusters):
    """Show a bar chart of the number of members of each cluster.

    @param clusters iterable of Cluster objects; each bar is placed at the
           cluster's cluster_id
    """
    id_and_size = [(cluster.cluster_id, len(cluster)) for cluster in clusters]
    indexes = [cid for cid, _ in id_and_size]
    populations = [size for _, size in id_and_size]

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, populations, 0.5, color='r')
    ax.set_ylabel('Population')
    ax.set_xlabel('Cluster index')
    plt.show()
1526 
def plot_clusters_precisions(clusters):
    """Print and show a bar chart of the precision of each cluster.

    Each cluster id and its precision are printed. Clusters whose precision
    is None are drawn with precision 0.0.
    """
    indexes = []
    precisions = []
    for cluster in clusters:
        cid = cluster.cluster_id
        indexes.append(cid)
        prec = cluster.precision
        print(cid, prec)
        precisions.append(0.0 if prec is None else prec)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, precisions, 0.5, color='r')
    ax.set_ylabel('Precision [A]')
    ax.set_xlabel('Cluster index')
    plt.show()
1545 
def plot_clusters_scores(clusters):
    """Save a box plot of the member scores of each cluster to "scores.pdf".

    @param clusters iterable of Cluster objects; each box is placed at the
           cluster's cluster_id
    """
    indexes = []
    values = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        values.append([data.score for data in cluster])

    # plot_fields_box_plots appends ".pdf" to the name itself; passing
    # "scores.pdf" produced a file called "scores.pdf.pdf".
    plot_fields_box_plots("scores", values, indexes, frequencies=None,
                          valuename="Scores", positionname="Cluster index",
                          xlabels=None, scale_plot_length=1.0)
1557 
class CrossLinkIdentifierDatabase(object):
    """Store named attributes of crosslinks, indexed by an arbitrary key.

    Each key maps to a dictionary of fields. The setters convert the value
    to a fixed type (str, int or float) before storing it. The whole
    database can be written to and read back from a pickle file.
    """
    def __init__(self):
        self.clidb = {}

    def check_key(self, key):
        """Create an empty field dictionary for ``key`` if it is missing."""
        self.clidb.setdefault(key, {})

    def _store(self, key, field, value, convert):
        # shared body of all the typed setters
        self.check_key(key)
        self.clidb[key][field] = convert(value)

    def set_unique_id(self, key, value):
        self._store(key, "XLUniqueID", value, str)

    def set_protein1(self, key, value):
        self._store(key, "Protein1", value, str)

    def set_protein2(self, key, value):
        self._store(key, "Protein2", value, str)

    def set_residue1(self, key, value):
        self._store(key, "Residue1", value, int)

    def set_residue2(self, key, value):
        self._store(key, "Residue2", value, int)

    def set_idscore(self, key, value):
        self._store(key, "IDScore", value, float)

    def set_state(self, key, value):
        self._store(key, "State", value, int)

    def set_sigma1(self, key, value):
        self._store(key, "Sigma1", value, str)

    def set_sigma2(self, key, value):
        self._store(key, "Sigma2", value, str)

    def set_psi(self, key, value):
        self._store(key, "Psi", value, str)

    def get_unique_id(self, key):
        return self.get_feature(key, "XLUniqueID")

    def get_protein1(self, key):
        return self.get_feature(key, "Protein1")

    def get_protein2(self, key):
        return self.get_feature(key, "Protein2")

    def get_residue1(self, key):
        return self.get_feature(key, "Residue1")

    def get_residue2(self, key):
        return self.get_feature(key, "Residue2")

    def get_idscore(self, key):
        return self.get_feature(key, "IDScore")

    def get_state(self, key):
        return self.get_feature(key, "State")

    def get_sigma1(self, key):
        return self.get_feature(key, "Sigma1")

    def get_sigma2(self, key):
        return self.get_feature(key, "Sigma2")

    def get_psi(self, key):
        return self.get_feature(key, "Psi")

    def set_float_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, float)

    def set_int_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, int)

    def set_string_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, str)

    def get_feature(self, key, feature_name):
        """Return field ``feature_name`` of ``key`` (KeyError if missing)."""
        return self.clidb[key][feature_name]

    def write(self, filename):
        """Pickle the whole database into ``filename``."""
        with open(filename, 'wb') as handle:
            pickle.dump(self.clidb, handle)

    def load(self, filename):
        """Replace the database with the one pickled in ``filename``."""
        with open(filename, 'rb') as handle:
            self.clidb = pickle.load(handle)
1658 
def plot_fields(fields, output, framemin=None, framemax=None):
    """Plot the given fields and save a figure as `output`.
    The fields generally are extracted from a stat file
    using ProcessOutput.get_fields().

    @param fields dictionary mapping a field name to its list of values;
           one subplot is drawn per field
    @param output file name of the saved figure
    @param framemin first frame to plot (default 0)
    @param framemax frame after the last one to plot (default: the number
           of values of each field)
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    fig, axs = plt.subplots(nrows=len(fields))
    fig.set_size_inches(10.5, 5.5 * len(fields))
    plt.rc('axes')

    # plt.subplots returns a bare Axes for a single row, an array otherwise
    axes = [axs] if len(fields) == 1 else list(axs)
    first = 0 if framemin is None else framemin
    for ax, key in zip(axes, fields):
        # Resolve the default end frame per field instead of fixing it from
        # the first field, so fields of different lengths are each plotted
        # in full; x is sized from the values actually sliced, so it can
        # never be longer than y.
        last = len(fields[key]) if framemax is None else framemax
        selected = fields[key][first:last]
        x = list(range(first, first + len(selected)))
        y = [float(v) for v in selected]
        ax.plot(x, y)
        ax.set_title(key, size="xx-large")
        ax.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(output)
    # release the figure; repeated calls would otherwise keep figures open
    plt.close(fig)
1693 
1694 
1696  name, values_lists, valuename=None, bins=40, colors=None, format="png",
1697  reference_xline=None, yplotrange=None, xplotrange=None, normalized=True,
1698  leg_names=None):
1699  '''Plot a list of histograms from a value list.
1700  @param name the name of the plot
1701  @param values_lists the list of lists of values eg: [[...],[...],[...]]
1702  @param valuename the x-axis label (if None, name is used)
1703  @param bins the number of bins
1704  @param colors If None, will use rainbow. Else will use specific list
1705  @param format output format; the figure is saved as name + "." + format
1706  @param reference_xline plot a reference line parallel to the y-axis
1707  @param yplotrange the range for the y-axis (currently has no effect, see the NOTE below)
 @param xplotrange the range for the x-axis
1708  @param normalized whether the histogram is normalized or not
1709  @param leg_names names for the legend (default: the list positions as strings)
1710  '''
1711 
1712  import matplotlib as mpl
1713  mpl.use('Agg')
1714  import matplotlib.pyplot as plt
1715  import matplotlib.cm as cm
1716  fig = plt.figure(figsize=(18.0, 9.0))
1717 
 # one color per list of values, evenly spread over the rainbow colormap
1718  if colors is None:
1719  colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
1720  for nv,values in enumerate(values_lists):
1721  col=colors[nv]
1722  if leg_names is not None:
1723  label=leg_names[nv]
1724  else:
1725  label=str(nv)
 # NOTE(review): the returned value h is not used
1726  h=plt.hist(
1727  [float(y) for y in values],
1728  bins=bins,
1729  color=col,
1730  density=normalized,histtype='step',lw=4,
1731  label=label)
1732 
1733  # plt.title(name,size="xx-large")
1734  plt.tick_params(labelsize=12, pad=10)
1735  if valuename is None:
1736  plt.xlabel(name, size="xx-large")
1737  else:
1738  plt.xlabel(valuename, size="xx-large")
1739  plt.ylabel("Frequency", size="xx-large")
1740 
 # NOTE(review): yplotrange is tested but never passed on. plt.ylim() with
 # no arguments only queries the current limits. Likely intended:
 # plt.ylim(yplotrange).
1741  if not yplotrange is None:
1742  plt.ylim()
1743  if not xplotrange is None:
1744  plt.xlim(xplotrange)
1745 
1746  plt.legend(loc=2)
1747 
1748  if not reference_xline is None:
1749  plt.axvline(
1750  reference_xline,
1751  color='red',
1752  linestyle='dashed',
1753  linewidth=1)
1754 
1755  plt.savefig(name + "." + format, dpi=150, transparent=True)
1756 
1757 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None", xlabels=None, scale_plot_length=1.0):
    '''
    Plot time series as boxplots and save them as ``name + ".pdf"``.

    @param name base name of the output file, also used as window title
    @param values list of time series (one list of numbers per box)
    @param positions x-values of the boxes
    @param frequencies if not None, every individual value is also drawn as
           a green cross (the content of this argument is otherwise unused)
    @param valuename the y-label
    @param positionname the x-label
    @param xlabels optional tick labels for the boxes
    @param scale_plot_length figure width per box, in inches
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    bps = []
    fig = plt.figure(figsize=(float(len(positions)) * scale_plot_length, 5.0))
    # FigureCanvasBase.set_window_title was deprecated in matplotlib 3.4 and
    # later removed; the figure manager provides the same call.
    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title(name)

    ax1 = fig.add_subplot(111)

    plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)

    bps.append(plt.boxplot(values, notch=0, sym='', vert=1,
                           whis=1.5, positions=positions))

    plt.setp(bps[-1]['boxes'], color='black', lw=1.5)
    plt.setp(bps[-1]['whiskers'], color='black', ls=":", lw=1.5)

    if frequencies is not None:
        for n, v in enumerate(values):
            plist = [positions[n]] * len(v)
            ax1.plot(plist, v, 'gx', alpha=0.7, markersize=7)

    if xlabels is not None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name + ".pdf", dpi=150)
    plt.show()
1799 
1800 
def plot_xy_data(x, y, title=None, out_fn=None, display=True, set_plot_yaxis_range=None,
                 xlabel=None, ylabel=None):
    """Plot y against x as a red line.

    @param x, y the data to plot
    @param title optional window and axes title
    @param out_fn if given, the figure is saved as out_fn + ".pdf"
    @param display if True, call plt.show()
    @param set_plot_yaxis_range optional (ymin, ymax) pair for the y-axis
    @param xlabel, ylabel optional axis labels
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        # FigureCanvasBase.set_window_title was deprecated in matplotlib 3.4
        # and later removed; the figure manager provides the same call.
        manager = getattr(fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title(title)

    ax.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        x1, x2, y1, y2 = plt.axis()
        y1 = set_plot_yaxis_range[0]
        y2 = set_plot_yaxis_range[1]
        plt.axis((x1, x2, y1, y2))
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
1831 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Draw y against x as small black dots.

    @param x, y the data points
    @param labelx, labely axis labels
    @param xmin, xmax x-axis limits, applied only if both are given
    @param ymin, ymax y-axis limits, applied only if both are given
    @param savefile if True, save the figure to ``filename`` at 300 dpi
    @param filename output file name
    @param alpha transparency of the points
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, axs0 = plt.subplots(1)

    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    # Pass a single color: giving both `color` and its alias `c` in one call
    # is rejected by current matplotlib with a TypeError.
    axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1, alpha=alpha)

    axs0.legend(
        loc=0,
        frameon=False,
        scatterpoints=1,
        numpoints=1,
        columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if (ymin is not None) and (ymax is not None):
        axs0.set_ylim(ymin, ymax)
    if (xmin is not None) and (xmax is not None):
        axs0.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
1874 
1875 
def get_graph_from_hierarchy(hier):
    """Build the parent/child graph below ``hier`` and draw it.

    The edges and node depths come from recursive_graph(). Only nodes at
    depth < 3 receive a visible label; deeper nodes get an empty label.
    """
    graph = []
    depth_dict = {}
    depth = 0
    (graph, depth, depth_dict) = recursive_graph(
        hier, graph, depth, depth_dict)

    # filters node labels according to depth_dict
    # (the former node_size_dict was overwritten by a scalar on every
    # iteration and never used, so it has been dropped)
    node_labels_dict = {}
    for key in depth_dict:
        if depth_dict[key] < 3:
            node_labels_dict[key] = key
        else:
            node_labels_dict[key] = ""
    draw_graph(graph, labels_dict=node_labels_dict)
1893 
1894 
def recursive_graph(hier, graph, depth, depth_dict):
    """Recursively collect the parent->child edges below ``hier``.

    Each node is named "<hierarchy name>|#<particle index>" and its depth
    (``depth`` + 1 on entry) is recorded in ``depth_dict``. Recursion stops
    at nodes whose children list is None or has exactly one child; such
    nodes contribute no edges.

    @return (graph, depth, depth_dict): ``graph`` is extended in place with
            (parent_name, child_name) tuples and ``depth`` is back to the
            value it had on entry
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # check for None first: len(None) would raise TypeError
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)

    for c in children:
        (graph, depth, depth_dict) = recursive_graph(
            c, graph, depth, depth_dict)
        nameh = IMP.atom.Hierarchy(c).get_name()
        index = str(c.get_particle().get_index())
        namec = nameh + "|#" + index
        graph.append((name1, namec))

    depth = depth - 1
    return (graph, depth, depth_dict)
1919 
1920 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw a graph given as a list of edges, using networkx and matplotlib.

    A copy of the graph is always written to ``out.gml`` in the current
    directory. It carries GML ``graphics``/``LabelGraphics`` node attributes
    and ``graphics`` edge attributes.

    @param graph list of (node1, node2) edges
    @param labels_dict optional dict mapping node -> label text to draw
    @param graph_layout 'spring', 'spectral' or 'random'; any other value
           selects the shell layout
    @param node_size one size used for every node, a sequence of sizes in
           graph node order, or a dict mapping node -> value (each value is
           drawn with size sqrt(value)/pi*10)
    @param node_color None (all nodes black) or a dict mapping node -> hex
           color string without a leading '#' (e.g. "ff0000")
    @param node_alpha transparency of the drawn nodes
    @param node_text_size font size of the node labels
    @param fixed passed to networkx.spring_layout (spring layout only)
    @param pos passed to networkx.spring_layout (spring layout only)
    @param edge_color color of the drawn edges
    @param edge_alpha transparency of the drawn edges
    @param edge_thickness one width for every edge, or a list of widths
           parallel to `graph` (stored as the 'weight' edge attribute)
    @param edge_text_pos currently unused
    @param validation_edges optional list of (node1, node2) edges to check
           against the graph. Edges already present are colored green in
           the GML output; missing ones are added and colored red.
    @param text_font font family of the node labels
    @param out_filename if given, the figure is saved to this file
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges, remembering per-edge widths as the 'weight' attribute
    if isinstance(edge_thickness, list):
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    # Snapshot the node order. Per-node colors and sizes are built in this
    # order and drawn with nodelist=nodes, so they stay aligned even if the
    # validation edges below add new nodes.
    nodes = list(G.nodes())

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        # One hex color per node (indexing a bare "000000" string per node
        # would yield single characters).
        node_color_hex = ["000000"] * len(nodes)
    else:
        cc = IMP.pmi.tools.ColorChange()
        node_color_hex = [node_color[node] for node in nodes]
        node_color_rgb = []
        for hexcolor in node_color_hex:
            cctuple = cc.rgb(hexcolor)
            node_color_rgb.append((cctuple[0] / 255, cctuple[1] / 255,
                                   cctuple[2] / 255))

    # get node sizes if dictionary
    if isinstance(node_size, dict):
        node_size = [sqrt(node_size[node]) / pi * 10.0 for node in nodes]
    # A scalar size (the default) applies to every node.
    if np.ndim(node_size) == 0:
        per_node_size = [node_size] * len(nodes)
    else:
        per_node_size = list(node_size)

    # GML annotations. Keyword arguments work with both the networkx 1.x
    # (G, name, values) and 2.x (G, values, name) signatures.
    node_graphics = {}
    label_graphics = {}
    for node, color, size in zip(nodes, node_color_hex, per_node_size):
        node_graphics[node] = {'type': 'ellipse', 'w': size, 'h': size,
                               'fill': '#' + color, 'label': node}
        label_graphics[node] = {'type': 'text', 'text': node,
                                'color': '#000000', 'visible': 'true'}
    nx.set_node_attributes(G, values=node_graphics, name="graphics")
    nx.set_node_attributes(G, values=label_graphics, name="LabelGraphics")

    nx.set_edge_attributes(
        G, values={edge: {'width': 1, 'fill': '#000000'}
                   for edge in G.edges()},
        name="graphics")

    for ve in (validation_edges or []):
        print(ve)
        u, v = ve[0], ve[1]
        if (u, v) in G.edges():
            print("found forward")
            key = (u, v)
            fill = '#00FF00'
        elif (v, u) in G.edges():
            print("found backward")
            key = (v, u)
            fill = '#00FF00'
        else:
            G.add_edge(u, v)
            print("not found")
            key = (u, v)
            fill = '#FF0000'
        nx.set_edge_attributes(
            G, values={key: {'width': 1, 'fill': fill}}, name="graphics")

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        graph_pos = nx.spring_layout(G, k=1.0 / 8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # networkx does not keep G.edges() in the order of `graph`, so per-edge
    # widths are read back from the stored weights in drawing order. Edges
    # without a weight (e.g. added validation edges) get width 1.
    if isinstance(edge_thickness, list):
        edge_width = [G[a][b].get('weight', 1) for a, b in G.edges()]
    else:
        edge_width = edge_thickness

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, nodelist=nodes,
                           node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_width,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
    nx.write_gml(G, 'out.gml')
    plt.show()
2017 
2018 
def draw_table():
    """Render an example d3js table in an IPython/Jupyter session.

    Still an example: the table content is hard-coded, with one row per
    month and one column per product ('Prod 1' to 'Prod 8'). Requires the
    optional ``ipyD3`` and ``IPython`` packages.
    """
    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  desc='An example table created with d3js with data '
                       'generated with Python.')

    # One inner list per product (12 monthly values each); transposed below
    # into one row per month.
    data = [
        [1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0,
         10336.0, 4650.0, 4441.0, 4670.0, 944.0, 110.0],
        [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0,
         11457.0, 6139.0, 7078.0, 6561.0, 2354.0, 710.0],
        [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0,
         9392.0, 5532.0, 5661.0, 4991.0, 2032.0, 680.0],
        [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0,
         10490.0, 5406.0, 4727.0, 3428.0, 1559.0, 620.0],
        [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0,
         16589.0, 13650.0, 9552.0, 13709.0, 2460.0, 720.0],
        [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0,
         43522.0, 32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
        [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0,
         15728.0, 6874.0, 5115.0, 3510.0, 1390.0, 170.0],
        [72.0, 60.0, 60.0, 10.0, 120.0, 172.0,
         1092.0, 675.0, 408.0, 360.0, 156.0, 100.0],
    ]
    data = [list(i) for i in zip(*data)]
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              'December']]
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: Residue.h:156
A container for models organized into clusters.
Definition: output.py:1452
A class for reading stat files (either rmf or ascii v1 and v2)
Definition: output.py:904
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: atom/Atom.h:241
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:1695
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:1758
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
Miscellaneous utilities.
Definition: tools.py:1
A class to store data associated to a model.
Definition: output.py:1431
void handle_use_deprecated(std::string message)
std::string get_module_version()
Return the version of this module, as a string.
Change color code to hexadecimal to rgb.
Definition: tools.py:787
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
Definition: output.py:893
def get_prot_name_from_particle
Return the component name provided a particle and a list of names.
Definition: tools.py:533
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:966
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
Definition: Fragment.h:46
def link_to_rmf
Link to another RMF file.
Definition: output.py:1117
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
Definition: pmi/utilities.h:85
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
Definition: output.py:253
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Definition: atom/Atom.h:234
Base class for capturing a modeling protocol.
Definition: output.py:44
The type for a residue.
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
A base class for Keys.
Definition: Key.h:44
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:198
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Definition: rigid_bodies.h:750
bool get_is_canonical(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and check if the root is named System.
Definition: pmi/utilities.h:91
Display a segment connecting a pair of particles.
Definition: XYZR.h:170
A decorator for a residue.
Definition: Residue.h:135
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
Definition: output.py:225
def get_prot_name_from_particle
Get the protein name from the particle.
Definition: output.py:334
class to link stat files to several rmf files
Definition: output.py:1216
class to allow more advanced handling of RMF files.
Definition: output.py:1091
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
Definition: output.py:1659
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Definition: Chain.h:61
Python classes to represent, score, sample and analyze models.
Functionality for loading, creating, manipulating and scoring atomic structures.
def get_rbs_and_beads
Returns unique objects in original order.
Definition: tools.py:1183
Hierarchies get_leaves(const Selection &h)
Select hierarchy particles identified by the biological name.
Definition: Selection.h:66
def init_rmf
Initialize an RMF file.
Definition: output.py:540
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: tools.py:553
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: rigid_bodies.h:752
std::string get_module_version()
Return the version of this module, as a string.
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: tools.py:626
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27