IMP logo
IMP Reference Guide  2.16.0
The Integrative Modeling Platform
output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 from __future__ import print_function, division
6 import IMP
7 import IMP.atom
8 import IMP.core
9 import IMP.pmi
10 import IMP.pmi.tools
11 import IMP.pmi.io
12 import os
13 import sys
14 import ast
15 import RMF
16 import numpy as np
17 import operator
18 import itertools
19 import warnings
20 import string
21 import ihm.format
22 import collections
23 try:
24  import cPickle as pickle
25 except ImportError:
26  import pickle
27 
28 
29 class _ChainIDs(object):
30  """Map indices to multi-character chain IDs.
31  We label the first 26 chains A-Z, then we move to two-letter
32  chain IDs: AA through AZ, then BA through BZ, through to ZZ.
33  This continues with longer chain IDs."""
34  def __getitem__(self, ind):
35  chars = string.ascii_uppercase
36  lc = len(chars)
37  ids = []
38  while ind >= lc:
39  ids.append(chars[ind % lc])
40  ind = ind // lc - 1
41  ids.append(chars[ind])
42  return "".join(reversed(ids))
43 
44 
class ProtocolOutput(object):
    """Base class for capturing a modeling protocol.
    Unlike simple output of model coordinates, a complete
    protocol includes the input data used, details on the restraints,
    sampling, and clustering, as well as output models.
    Use via IMP.pmi.topology.System.add_protocol_output().

    @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
    mmCIF files.
    """
    # Intentionally empty: concrete subclasses provide the capture hooks.
    pass
56 
57 
58 def _flatten(seq):
59  for elt in seq:
60  if isinstance(elt, (tuple, list)):
61  for elt2 in _flatten(elt):
62  yield elt2
63  else:
64  yield elt
65 
66 
67 def _disambiguate_chain(chid, seen_chains):
68  """Make sure that the chain ID is unique; warn and correct if it isn't"""
69  # Handle null chain IDs
70  if chid == '\0':
71  chid = ' '
72 
73  if chid in seen_chains:
74  warnings.warn("Duplicate chain ID '%s' encountered" % chid,
76 
77  for suffix in itertools.count(1):
78  new_chid = chid + "%d" % suffix
79  if new_chid not in seen_chains:
80  seen_chains.add(new_chid)
81  return new_chid
82  seen_chains.add(chid)
83  return chid
84 
85 
def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
                        write_all_residues_per_bead):
    """Write one model's particles to an open PDB file handle.

    Each particle becomes a single pseudo-atom record (CA if no atom
    type is given), translated so that `geometric_center` maps to the
    origin; an ENDMDL record closes the model.
    """
    for serial, info in enumerate(particle_infos_for_pdb):
        (xyz, atom_type, residue_type,
         chain_id, residue_index, all_indexes, radius) = info
        if atom_type is None:
            atom_type = IMP.atom.AT_CA
        coords = (xyz[0] - geometric_center[0],
                  xyz[1] - geometric_center[1],
                  xyz[2] - geometric_center[2])
        # Either expand a coarse bead into all residues it covers,
        # or emit a single record for this particle
        if write_all_residues_per_bead and all_indexes is not None:
            residue_numbers = all_indexes
        else:
            residue_numbers = [residue_index]
        for resnum in residue_numbers:
            flpdb.write(IMP.atom.get_pdb_string(
                coords, serial + 1, atom_type, residue_type,
                chain_id[:1], resnum, ' ', 1.00, radius))
    flpdb.write("ENDMDL\n")
111 
112 
# Lightweight records used when writing mmCIF output: one _Entity per
# unique sequence, one _ChainInfo per chain (entity + molecule name)
_Entity = collections.namedtuple('_Entity', ('id', 'seq'))
_ChainInfo = collections.namedtuple('_ChainInfo', ('entity', 'name'))
115 
116 
def _get_chain_info(chains, root_hier):
    """Collect per-chain info for mmCIF output.

    @param chains dict mapping molecule names to chain IDs
    @param root_hier hierarchy to scan for molecules
    @return tuple of (dict mapping chain ID to _ChainInfo,
            list of all unique _Entity objects, one per sequence)
    """
    chain_info = {}
    entities = {}
    all_entities = []
    for mol in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
        # Restore the molecule-name lookup that was lost from this loop;
        # without it `molname` is undefined
        molname = IMP.pmi.get_molecule_name_and_copy(mol)
        chain_id = chains[molname]
        chain = IMP.atom.Chain(mol)
        seq = chain.get_sequence()
        # One entity per unique sequence
        if seq not in entities:
            entities[seq] = e = _Entity(id=len(entities)+1, seq=seq)
            all_entities.append(e)
        entity = entities[seq]
        info = _ChainInfo(entity=entity, name=molname)
        chain_info[chain_id] = info
    return chain_info, all_entities
133 
134 
def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
                          write_all_residues_per_bead, chains, root_hier):
    """Write one model's particles to an open file in mmCIF format.

    Emits _entity, _entity_poly, _struct_asym and _atom_site categories;
    every particle is written as a single carbon pseudo-atom, translated
    so that `geometric_center` maps to the origin.
    """
    # get dict with keys=chain IDs, values=chain info
    chain_info, entities = _get_chain_info(chains, root_hier)

    writer = ihm.format.CifWriter(flpdb)
    writer.start_block('model')
    with writer.category("_entry") as lp:
        lp.write(id='model')

    # One row per unique sequence
    with writer.loop("_entity", ["id"]) as lp:
        for e in entities:
            lp.write(id=e.id)

    with writer.loop("_entity_poly",
                     ["entity_id", "pdbx_seq_one_letter_code"]) as lp:
        for e in entities:
            lp.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)

    with writer.loop("_struct_asym", ["id", "entity_id", "details"]) as lp:
        # Longer chain IDs (e.g. AA) should always come after shorter (e.g. Z)
        for chid in sorted(chains.values(), key=lambda x: (len(x.strip()), x)):
            ci = chain_info[chid]
            lp.write(id=chid, entity_id=ci.entity.id, details=ci.name)

    with writer.loop("_atom_site",
                     ["group_PDB", "type_symbol", "label_atom_id",
                      "label_comp_id", "label_asym_id", "label_seq_id",
                      "auth_seq_id",
                      "Cartn_x", "Cartn_y", "Cartn_z", "label_entity_id",
                      "pdbx_pdb_model_num",
                      "id"]) as lp:
        ordinal = 1
        for n, tupl in enumerate(particle_infos_for_pdb):
            (xyz, atom_type, residue_type,
             chain_id, residue_index, all_indexes, radius) = tupl
            ci = chain_info[chain_id]
            if atom_type is None:
                atom_type = IMP.atom.AT_CA
            # NOTE(review): elementwise subtraction assumes xyz supports
            # vector arithmetic (e.g. IMP Vector3D/numpy array) -- confirm
            c = xyz - geometric_center
            if write_all_residues_per_bead and all_indexes is not None:
                # Expand a coarse bead into one row per covered residue
                for residue_number in all_indexes:
                    lp.write(group_PDB='ATOM',
                             type_symbol='C',
                             label_atom_id=atom_type.get_string(),
                             label_comp_id=residue_type.get_string(),
                             label_asym_id=chain_id,
                             label_seq_id=residue_index,
                             auth_seq_id=residue_index, Cartn_x=c[0],
                             Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                             pdbx_pdb_model_num=1,
                             label_entity_id=ci.entity.id)
                    ordinal += 1
            else:
                lp.write(group_PDB='ATOM', type_symbol='C',
                         label_atom_id=atom_type.get_string(),
                         label_comp_id=residue_type.get_string(),
                         label_asym_id=chain_id,
                         label_seq_id=residue_index,
                         auth_seq_id=residue_index, Cartn_x=c[0],
                         Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                         pdbx_pdb_model_num=1,
                         label_entity_id=ci.entity.id)
                ordinal += 1
199 
200 
class Output(object):
    """Class for easy writing of PDBs, RMFs, and stat files

    @note Model should be updated prior to writing outputs.
    """
    def __init__(self, ascii=True, atomistic=False):
        # Map from PDB/mmCIF filename to the hierarchy written to it
        self.dictionary_pdbs = {}
        # Map from PDB filename to True if the file is in mmCIF format
        self._pdb_mmcif = {}
        # Map from RMF filename to (handle, category, key map, objects)
        self.dictionary_rmfs = {}
        # Maps from stat filename to the objects providing its output
        self.dictionary_stats = {}
        self.dictionary_stats2 = {}
        self.best_score_list = None
        self.nbestscoring = None
        self.suffixes = []
        self.replica_exchange = False
        # If True, stat files are ASCII; otherwise they are pickled
        self.ascii = ascii
        # Extra key/value pairs merged into every stat entry
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # 1-character chain IDs, suitable for PDB output
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
            "abcdefghijklmnopqrstuvwxyz0123456789"
        # Multi-character chain IDs, suitable for mmCIF output
        self.multi_chainids = _ChainIDs()
        self.dictchain = {}  # keys are molecule names, values are chain ids
        self.particle_infos_for_pdb = {}
        # If True, write individual atoms rather than one bead per residue
        self.atomistic = atomistic
        self.use_pmi2 = False
228 
229  def get_pdb_names(self):
230  """Get a list of all PDB files being output by this instance"""
231  return list(self.dictionary_pdbs.keys())
232 
233  def get_rmf_names(self):
234  return list(self.dictionary_rmfs.keys())
235 
236  def get_stat_names(self):
237  return list(self.dictionary_stats.keys())
238 
239  def _init_dictchain(self, name, prot, multichar_chain=False, mmcif=False):
240  self.dictchain[name] = {}
241  self.use_pmi2 = False
242  seen_chains = set()
243 
244  # attempt to find PMI objects.
245  if IMP.pmi.get_is_canonical(prot):
246  self.use_pmi2 = True
247  self.atomistic = True # detects automatically
248  for n, mol in enumerate(IMP.atom.get_by_type(
249  prot, IMP.atom.MOLECULE_TYPE)):
250  chid = IMP.atom.Chain(mol).get_id()
251  if not mmcif and len(chid) > 1:
252  raise ValueError(
253  "The system contains at least one chain ID (%s) that "
254  "is more than 1 character long; this cannot be "
255  "represented in PDB. Either write mmCIF files "
256  "instead, or assign 1-character IDs to all chains "
257  "(this can be done with the `chain_ids` argument to "
258  "BuildSystem.add_state())." % chid)
259  chid = _disambiguate_chain(chid, seen_chains)
261  self.dictchain[name][molname] = chid
262  else:
263  chainids = self.multi_chainids if multichar_chain \
264  else self.chainids
265  for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
266  self.dictchain[name][i.get_name()] = chainids[n]
267 
268  def init_pdb(self, name, prot, mmcif=False):
269  """Init PDB Writing.
270  @param name The PDB filename
271  @param prot The hierarchy to write to this pdb file
272  @param mmcif If True, write PDBs in mmCIF format
273  @note if the PDB name is 'System' then will use Selection
274  to get molecules
275  """
276  flpdb = open(name, 'w')
277  flpdb.close()
278  self.dictionary_pdbs[name] = prot
279  self._pdb_mmcif[name] = mmcif
280  self._init_dictchain(name, prot, mmcif=mmcif)
281 
    def write_psf(self, filename, name):
        """Write a PSF topology file for the system registered under `name`.

        One pseudo-atom record is written per particle; bonds connect
        residue-contiguous particles within each chain.

        @param filename the PSF file to create
        @param name the key of a previously init_pdb()'d output
        """
        flpsf = open(filename, 'w')
        flpsf.write("PSF CMAP CHEQ" + "\n")
        index_residue_pair_list = {}
        (particle_infos_for_pdb, geometric_center) = \
            self.get_particle_infos_for_pdb_writing(name)
        nparticles = len(particle_infos_for_pdb)
        flpsf.write(str(nparticles) + " !NATOM" + "\n")
        for n, p in enumerate(particle_infos_for_pdb):
            atom_index = n+1
            residue_type = p[2]
            chain = p[3]
            resid = p[4]
            # Fixed-column PSF atom record
            flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}'
                        '{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}'
                        '{14:14.6f}{15:14.6f}'.format(
                            atom_index, " ", chain, " ", str(resid), " ",
                            '"'+residue_type.get_string()+'"', " ", "C",
                            " ", "C", 1.0, 0.0, 0, 0.0, 0.0))
            flpsf.write('\n')
            # Remember (atom index, residue) per chain for bond generation
            if chain not in index_residue_pair_list:
                index_residue_pair_list[chain] = [(atom_index, resid)]
            else:
                index_residue_pair_list[chain].append((atom_index, resid))

        # now write the connectivity
        indexes_pairs = []
        for chain in sorted(index_residue_pair_list.keys()):

            ls = index_residue_pair_list[chain]
            # sort by residue
            ls = sorted(ls, key=lambda tup: tup[1])
            # get the index list
            indexes = [x[0] for x in ls]
            # get the contiguous pairs
            indexes_pairs += list(IMP.pmi.tools.sublist_iterator(
                indexes, lmin=2, lmax=2))
        nbonds = len(indexes_pairs)
        flpsf.write(str(nbonds)+" !NBOND: bonds"+"\n")

        # save bonds in fixed column format
        for i in range(0, len(indexes_pairs), 4):
            for bond in indexes_pairs[i:i+4]:
                flpsf.write('{0:8d}{1:8d}'.format(*bond))
            flpsf.write('\n')

        del particle_infos_for_pdb
        flpsf.close()
330 
    def write_pdb(self, name, appendmode=True,
                  translate_to_geometric_center=False,
                  write_all_residues_per_bead=False):
        """Write a model to the named PDB (or mmCIF) file.

        @param name the file to write to (must have been init_pdb()'d)
        @param appendmode if True, append a new model to the file;
               otherwise overwrite it
        @param translate_to_geometric_center if True, center coordinates
               on the system's geometric center
        @param write_all_residues_per_bead if True, write one record per
               residue covered by each coarse bead
        """
        (particle_infos_for_pdb,
         geometric_center) = self.get_particle_infos_for_pdb_writing(name)

        if not translate_to_geometric_center:
            geometric_center = (0, 0, 0)

        filemode = 'a' if appendmode else 'w'
        with open(name, filemode) as flpdb:
            # Dispatch on the format chosen at init_pdb() time
            if self._pdb_mmcif[name]:
                _write_mmcif_internal(flpdb, particle_infos_for_pdb,
                                      geometric_center,
                                      write_all_residues_per_bead,
                                      self.dictchain[name],
                                      self.dictionary_pdbs[name])
            else:
                _write_pdb_internal(flpdb, particle_infos_for_pdb,
                                    geometric_center,
                                    write_all_residues_per_bead)
353 
354  def get_prot_name_from_particle(self, name, p):
355  """Get the protein name from the particle.
356  This is done by traversing the hierarchy."""
357  if self.use_pmi2:
358  return IMP.pmi.get_molecule_name_and_copy(p), True
359  else:
361  p, self.dictchain[name])
362 
363  def get_particle_infos_for_pdb_writing(self, name):
364  # index_residue_pair_list={}
365 
366  # the resindexes dictionary keep track of residues that have
367  # been already added to avoid duplication
368  # highest resolution have highest priority
369  resindexes_dict = {}
370 
371  # this dictionary dill contain the sequence of tuples needed to
372  # write the pdb
373  particle_infos_for_pdb = []
374 
375  geometric_center = [0, 0, 0]
376  atom_count = 0
377 
378  if self.use_pmi2:
379  # select highest resolution
380  sel = IMP.atom.Selection(self.dictionary_pdbs[name], resolution=0)
381  ps = sel.get_selected_particles()
382  else:
383  ps = IMP.atom.get_leaves(self.dictionary_pdbs[name])
384 
385  for n, p in enumerate(ps):
386  protname, is_a_bead = self.get_prot_name_from_particle(name, p)
387 
388  if protname not in resindexes_dict:
389  resindexes_dict[protname] = []
390 
391  if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
392  residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
393  rt = residue.get_residue_type()
394  resind = residue.get_index()
395  atomtype = IMP.atom.Atom(p).get_atom_type()
396  xyz = list(IMP.core.XYZ(p).get_coordinates())
397  radius = IMP.core.XYZR(p).get_radius()
398  geometric_center[0] += xyz[0]
399  geometric_center[1] += xyz[1]
400  geometric_center[2] += xyz[2]
401  atom_count += 1
402  particle_infos_for_pdb.append(
403  (xyz, atomtype, rt, self.dictchain[name][protname],
404  resind, None, radius))
405  resindexes_dict[protname].append(resind)
406 
408 
409  residue = IMP.atom.Residue(p)
410  resind = residue.get_index()
411  # skip if the residue was already added by atomistic resolution
412  # 0
413  if resind in resindexes_dict[protname]:
414  continue
415  else:
416  resindexes_dict[protname].append(resind)
417  rt = residue.get_residue_type()
418  xyz = IMP.core.XYZ(p).get_coordinates()
419  radius = IMP.core.XYZR(p).get_radius()
420  geometric_center[0] += xyz[0]
421  geometric_center[1] += xyz[1]
422  geometric_center[2] += xyz[2]
423  atom_count += 1
424  particle_infos_for_pdb.append(
425  (xyz, None, rt, self.dictchain[name][protname], resind,
426  None, radius))
427 
428  elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
429  resindexes = list(IMP.pmi.tools.get_residue_indexes(p))
430  resind = resindexes[len(resindexes) // 2]
431  if resind in resindexes_dict[protname]:
432  continue
433  else:
434  resindexes_dict[protname].append(resind)
435  rt = IMP.atom.ResidueType('BEA')
436  xyz = IMP.core.XYZ(p).get_coordinates()
437  radius = IMP.core.XYZR(p).get_radius()
438  geometric_center[0] += xyz[0]
439  geometric_center[1] += xyz[1]
440  geometric_center[2] += xyz[2]
441  atom_count += 1
442  particle_infos_for_pdb.append(
443  (xyz, None, rt, self.dictchain[name][protname], resind,
444  resindexes, radius))
445 
446  else:
447  if is_a_bead:
448  rt = IMP.atom.ResidueType('BEA')
449  resindexes = list(IMP.pmi.tools.get_residue_indexes(p))
450  if len(resindexes) > 0:
451  resind = resindexes[len(resindexes) // 2]
452  xyz = IMP.core.XYZ(p).get_coordinates()
453  radius = IMP.core.XYZR(p).get_radius()
454  geometric_center[0] += xyz[0]
455  geometric_center[1] += xyz[1]
456  geometric_center[2] += xyz[2]
457  atom_count += 1
458  particle_infos_for_pdb.append(
459  (xyz, None, rt, self.dictchain[name][protname],
460  resind, resindexes, radius))
461 
462  if atom_count > 0:
463  geometric_center = (geometric_center[0] / atom_count,
464  geometric_center[1] / atom_count,
465  geometric_center[2] / atom_count)
466 
467  # sort by chain ID, then residue index. Longer chain IDs (e.g. AA)
468  # should always come after shorter (e.g. Z)
469  particle_infos_for_pdb = sorted(particle_infos_for_pdb,
470  key=lambda x: (len(x[3]), x[3], x[4]))
471 
472  return (particle_infos_for_pdb, geometric_center)
473 
474  def write_pdbs(self, appendmode=True, mmcif=False):
475  for pdb in self.dictionary_pdbs.keys():
476  self.write_pdb(pdb, appendmode)
477 
    def init_pdb_best_scoring(self,
                              suffix,
                              prot,
                              nbestscoring,
                              replica_exchange=False, mmcif=False):
        """Set up output of the `nbestscoring` best-scoring models.

        @param suffix filename prefix; files are named suffix.N.pdb/.cif
        @param prot the hierarchy to write
        @param nbestscoring number of best-scoring conformations to keep
        @param replica_exchange if True, share the best-score list
               across replicas through a common file
        @param mmcif if True, write mmCIF rather than PDB files
        """
        # save only the nbestscoring conformations
        # create as many pdbs as needed

        self._pdb_best_scoring_mmcif = mmcif
        fileext = '.cif' if mmcif else '.pdb'
        self.suffixes.append(suffix)
        self.replica_exchange = replica_exchange
        if not self.replica_exchange:
            # common usage
            # if you are not in replica exchange mode
            # initialize the array of scores internally
            self.best_score_list = []
        else:
            # otherwise the replicas must communicate
            # through a common file to know what are the best scores
            self.best_score_file_name = "best.scores.rex.py"
            self.best_score_list = []
            with open(self.best_score_file_name, "w") as best_score_file:
                best_score_file.write(
                    "self.best_score_list=" + str(self.best_score_list) + "\n")

        self.nbestscoring = nbestscoring
        # pre-create one (empty) output file per rank slot
        for i in range(self.nbestscoring):
            name = suffix + "." + str(i) + fileext
            flpdb = open(name, 'w')
            flpdb.close()
            self.dictionary_pdbs[name] = prot
            self._pdb_mmcif[name] = mmcif
            self._init_dictchain(name, prot, mmcif=mmcif)
512 
    def write_pdb_best_scoring(self, score):
        """Write the current model if its score ranks among the best.

        Files are kept sorted, suffix.0 holding the best score; existing
        files are shifted up by one rank to make room for the new model.

        @param score the score of the current model (lower is better)
        """
        if self.nbestscoring is None:
            print("Output.write_pdb_best_scoring: init_pdb_best_scoring "
                  "not run")

        mmcif = self._pdb_best_scoring_mmcif
        fileext = '.cif' if mmcif else '.pdb'
        # update the score list
        if self.replica_exchange:
            # read the self.best_score_list from the file
            # NOTE(review): exec of a shared on-disk file; only safe
            # because it is written exclusively by trusted replicas
            with open(self.best_score_file_name) as fh:
                exec(fh.read())

        if len(self.best_score_list) < self.nbestscoring:
            # list not yet full: always insert the new score
            self.best_score_list.append(score)
            self.best_score_list.sort()
            index = self.best_score_list.index(score)
            for suffix in self.suffixes:
                # shift worse-scoring files up to free slot `index`
                for i in range(len(self.best_score_list) - 2, index - 1, -1):
                    oldname = suffix + "." + str(i) + fileext
                    newname = suffix + "." + str(i + 1) + fileext
                    # rename on Windows fails if newname already exists
                    if os.path.exists(newname):
                        os.unlink(newname)
                    os.rename(oldname, newname)
                filetoadd = suffix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        else:
            # list full: insert only if better than the current worst
            if score < self.best_score_list[-1]:
                self.best_score_list.append(score)
                self.best_score_list.sort()
                self.best_score_list.pop(-1)
                index = self.best_score_list.index(score)
                for suffix in self.suffixes:
                    for i in range(len(self.best_score_list) - 1,
                                   index - 1, -1):
                        oldname = suffix + "." + str(i) + fileext
                        newname = suffix + "." + str(i + 1) + fileext
                        os.rename(oldname, newname)
                    # drop the file that fell off the end of the ranking
                    filenametoremove = suffix + \
                        "." + str(self.nbestscoring) + fileext
                    os.remove(filenametoremove)
                    filetoadd = suffix + "." + str(index) + fileext
                    self.write_pdb(filetoadd, appendmode=False)

        if self.replica_exchange:
            # write the self.best_score_list to the file
            with open(self.best_score_file_name, "w") as best_score_file:
                best_score_file.write(
                    "self.best_score_list=" + str(self.best_score_list) + '\n')
564 
    def init_rmf(self, name, hierarchies, rs=None, geometries=None,
                 listofobjects=None):
        """
        Initialize an RMF file

        @param name the name of the RMF file
        @param hierarchies the hierarchies to be included (it is a list)
        @param rs optional, the restraint sets (it is a list)
        @param geometries optional, the geometries (it is a list)
        @param listofobjects optional, the list of objects for the stat
               (it is a list)
        @raise ValueError if an object lacks a get_output() method
        """
        rh = RMF.create_rmf_file(name)
        IMP.rmf.add_hierarchies(rh, hierarchies)
        cat = None
        outputkey_rmfkey = None

        if rs is not None:
            IMP.rmf.add_restraints(rh, rs)
        if geometries is not None:
            IMP.rmf.add_geometries(rh, geometries)
        if listofobjects is not None:
            cat = rh.get_category("stat")
            outputkey_rmfkey = {}
            for o in listofobjects:
                if "get_output" not in dir(o):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(o))
                output = o.get_output()
                # pick an RMF key type matching each value's Python type
                for outputkey in output:
                    rmftag = RMF.string_tag
                    if isinstance(output[outputkey], float):
                        rmftag = RMF.float_tag
                    elif isinstance(output[outputkey], int):
                        rmftag = RMF.int_tag
                    elif isinstance(output[outputkey], str):
                        rmftag = RMF.string_tag
                    else:
                        # fall back to string for anything else
                        rmftag = RMF.string_tag
                    rmfkey = rh.get_key(cat, outputkey, rmftag)
                    outputkey_rmfkey[outputkey] = rmfkey
            # bookkeeping keys written with every frame
            outputkey_rmfkey["rmf_file"] = \
                rh.get_key(cat, "rmf_file", RMF.string_tag)
            outputkey_rmfkey["rmf_frame_index"] = \
                rh.get_key(cat, "rmf_frame_index", RMF.int_tag)

        self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey, listofobjects)
613 
614  def add_restraints_to_rmf(self, name, objectlist):
615  for o in _flatten(objectlist):
616  try:
617  rs = o.get_restraint_for_rmf()
618  if not isinstance(rs, (list, tuple)):
619  rs = [rs]
620  except: # noqa: E722
621  rs = [o.get_restraint()]
623  self.dictionary_rmfs[name][0], rs)
624 
625  def add_geometries_to_rmf(self, name, objectlist):
626  for o in objectlist:
627  geos = o.get_geometries()
628  IMP.rmf.add_geometries(self.dictionary_rmfs[name][0], geos)
629 
630  def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
631  for o in objectlist:
632 
633  pps = o.get_particle_pairs()
634  for pp in pps:
636  self.dictionary_rmfs[name][0],
638 
    def write_rmf(self, name):
        """Save the current model as a new frame in the named RMF file,
        together with any registered stat values.

        @param name the RMF filename (must have been init_rmf()'d)
        """
        IMP.rmf.save_frame(self.dictionary_rmfs[name][0])
        # Only write stat values if a stat category was set up
        if self.dictionary_rmfs[name][1] is not None:
            outputkey_rmfkey = self.dictionary_rmfs[name][2]
            listofobjects = self.dictionary_rmfs[name][3]
            for o in listofobjects:
                output = o.get_output()
                for outputkey in output:
                    rmfkey = outputkey_rmfkey[outputkey]
                    try:
                        n = self.dictionary_rmfs[name][0].get_root_node()
                        n.set_value(rmfkey, output[outputkey])
                    except NotImplementedError:
                        # value type unsupported for this key; skip it
                        continue
            rmfkey = outputkey_rmfkey["rmf_file"]
            self.dictionary_rmfs[name][0].get_root_node().set_value(
                rmfkey, name)
            rmfkey = outputkey_rmfkey["rmf_frame_index"]
            nframes = self.dictionary_rmfs[name][0].get_number_of_frames()
            self.dictionary_rmfs[name][0].get_root_node().set_value(
                rmfkey, nframes-1)
        self.dictionary_rmfs[name][0].flush()
661 
662  def close_rmf(self, name):
663  rh = self.dictionary_rmfs[name][0]
664  del self.dictionary_rmfs[name]
665  del rh
666 
667  def write_rmfs(self):
668  for rmfinfo in self.dictionary_rmfs.keys():
669  self.write_rmf(rmfinfo[0])
670 
671  def init_stat(self, name, listofobjects):
672  if self.ascii:
673  flstat = open(name, 'w')
674  flstat.close()
675  else:
676  flstat = open(name, 'wb')
677  flstat.close()
678 
679  # check that all objects in listofobjects have a get_output method
680  for o in listofobjects:
681  if "get_output" not in dir(o):
682  raise ValueError(
683  "Output: object %s doesn't have get_output() method"
684  % str(o))
685  self.dictionary_stats[name] = listofobjects
686 
687  def set_output_entry(self, key, value):
688  self.initoutput.update({key: value})
689 
690  def write_stat(self, name, appendmode=True):
691  output = self.initoutput
692  for obj in self.dictionary_stats[name]:
693  d = obj.get_output()
694  # remove all entries that begin with _ (private entries)
695  dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
696  output.update(dfiltered)
697 
698  if appendmode:
699  writeflag = 'a'
700  else:
701  writeflag = 'w'
702 
703  if self.ascii:
704  flstat = open(name, writeflag)
705  flstat.write("%s \n" % output)
706  flstat.close()
707  else:
708  flstat = open(name, writeflag + 'b')
709  pickle.dump(output, flstat, 2)
710  flstat.close()
711 
712  def write_stats(self):
713  for stat in self.dictionary_stats.keys():
714  self.write_stat(stat)
715 
716  def get_stat(self, name):
717  output = {}
718  for obj in self.dictionary_stats[name]:
719  output.update(obj.get_output())
720  return output
721 
    def write_test(self, name, listofobjects):
        """Write a one-line reference stat file for later use with test().

        @param name the reference file to create
        @param listofobjects objects providing get_test_output() or
               get_output()
        @raise ValueError if an object provides neither method
        """
        flstat = open(name, 'w')
        output = self.initoutput
        for o in listofobjects:
            if "get_test_output" not in dir(o) and "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        self.dictionary_stats[name] = listofobjects

        for obj in self.dictionary_stats[name]:
            try:
                d = obj.get_test_output()
            except:  # noqa: E722
                # fall back to full output if no test output is provided
                d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
            output.update(dfiltered)
        flstat.write("%s \n" % output)
        flstat.close()
742 
    def test(self, name, listofobjects, tolerance=1e-5):
        """Compare current output against a reference file from write_test().

        Numeric values are compared within `tolerance`; all other values
        must match exactly. Failures are reported on stderr.

        @param name the reference stat file to read
        @param listofobjects objects providing get_test_output() or
               get_output()
        @param tolerance maximum allowed absolute difference for floats
        @return True if every compared value matched
        """
        output = self.initoutput
        for o in listofobjects:
            if "get_test_output" not in dir(o) and "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        for obj in listofobjects:
            try:
                output.update(obj.get_test_output())
            except:  # noqa: E722
                output.update(obj.get_output())

        flstat = open(name, 'r')

        passed = True
        for fl in flstat:
            test_dict = ast.literal_eval(fl)
            for k in test_dict:
                if k in output:
                    old_value = str(test_dict[k])
                    new_value = str(output[k])
                    # decide whether to compare numerically
                    try:
                        float(old_value)
                        is_float = True
                    except ValueError:
                        is_float = False

                    if is_float:
                        fold = float(old_value)
                        fnew = float(new_value)
                        diff = abs(fold - fnew)
                        if diff > tolerance:
                            print("%s: test failed, old value: %s new value %s; "
                                  "diff %f > %f" % (str(k), str(old_value),
                                                    str(new_value), diff,
                                                    tolerance), file=sys.stderr)
                            passed = False
                    elif test_dict[k] != output[k]:
                        # non-numeric mismatch; avoid dumping huge values
                        if len(old_value) < 50 and len(new_value) < 50:
                            print("%s: test failed, old value: %s new value %s"
                                  % (str(k), old_value, new_value),
                                  file=sys.stderr)
                            passed = False
                        else:
                            print("%s: test failed, omitting results (too long)"
                                  % str(k), file=sys.stderr)
                            passed = False

                else:
                    print("%s from old objects (file %s) not in new objects"
                          % (str(k), str(name)), file=sys.stderr)
        flstat.close()
        return passed
797 
798  def get_environment_variables(self):
799  import os
800  return str(os.environ)
801 
802  def get_versions_of_relevant_modules(self):
803  import IMP
804  versions = {}
805  versions["IMP_VERSION"] = IMP.get_module_version()
806  versions["PMI_VERSION"] = IMP.pmi.get_module_version()
807  try:
808  import IMP.isd2
809  versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
810  except ImportError:
811  pass
812  try:
813  import IMP.isd_emxl
814  versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
815  except ImportError:
816  pass
817  return versions
818 
    def init_stat2(self, name, listofobjects, extralabels=None,
                   listofsummedobjects=None):
        """Initialize a v2 (space-efficient) stat file.

        The header line maps integer indexes to key names; entries
        written later by write_stat2() use only the indexes.

        @param name the stat file to create
        @param listofobjects objects providing get_output()
        @param extralabels extra keys to be set via set_output_entry()
        @param listofsummedobjects list of ([obj1, obj2, ...], label)
               tuples; the objects' _TotalScore values are summed and
               written under the given label
        @raise ValueError if an object lacks get_output(), or a summed
               object lacks a _TotalScore entry
        """
        # this is a new stat file that should be less
        # space greedy!
        # listofsummedobjects must be in the form
        # [([obj1,obj2,obj3,obj4...],label)]
        # extralabels

        if listofsummedobjects is None:
            listofsummedobjects = []
        if extralabels is None:
            extralabels = []
        flstat = open(name, 'w')
        output = {}
        stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
        stat2_keywords.update(
            {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
        stat2_keywords.update(
            {"STAT2HEADER_IMP_VERSIONS":
             str(self.get_versions_of_relevant_modules())})
        stat2_inverse = {}

        for obj in listofobjects:
            if "get_output" not in dir(obj):
                raise ValueError(
                    "Output: object %s doesn't have get_output() method"
                    % str(obj))
            else:
                d = obj.get_output()
                # remove all entries that begin with _ (private entries)
                dfiltered = dict((k, v)
                                 for k, v in d.items() if k[0] != "_")
                output.update(dfiltered)

        # check for customizable entries
        for obj in listofsummedobjects:
            for t in obj[0]:
                if "get_output" not in dir(t):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(t))
                else:
                    if "_TotalScore" not in t.get_output():
                        raise ValueError(
                            "Output: object %s doesn't have _TotalScore "
                            "entry to be summed" % str(t))
                    else:
                        output.update({obj[1]: 0.0})

        for k in extralabels:
            output.update({k: 0.0})

        # number every key and remember the inverse (key -> index) map
        for n, k in enumerate(output):
            stat2_keywords.update({n: k})
            stat2_inverse.update({k: n})

        flstat.write("%s \n" % stat2_keywords)
        flstat.close()
        self.dictionary_stats2[name] = (
            listofobjects,
            stat2_inverse,
            listofsummedobjects,
            extralabels)
882 
    def write_stat2(self, name, appendmode=True):
        """Append one entry to a v2 stat file, keyed by header indexes.

        @param name the stat file (must have been init_stat2()'d)
        @param appendmode if True, append; otherwise overwrite
        """
        output = {}
        (listofobjects, stat2_inverse, listofsummedobjects,
         extralabels) = self.dictionary_stats2[name]

        # writing objects
        for obj in listofobjects:
            od = obj.get_output()
            # skip private entries (keys starting with "_")
            dfiltered = dict((k, v) for k, v in od.items() if k[0] != "_")
            for k in dfiltered:
                output.update({stat2_inverse[k]: od[k]})

        # writing summedobjects
        for so in listofsummedobjects:
            partial_score = 0.0
            for t in so[0]:
                d = t.get_output()
                partial_score += float(d["_TotalScore"])
            output.update({stat2_inverse[so[1]]: str(partial_score)})

        # writing extralabels
        for k in extralabels:
            if k in self.initoutput:
                output.update({stat2_inverse[k]: self.initoutput[k]})
            else:
                output.update({stat2_inverse[k]: "None"})

        if appendmode:
            writeflag = 'a'
        else:
            writeflag = 'w'

        flstat = open(name, writeflag)
        flstat.write("%s \n" % output)
        flstat.close()
918 
919  def write_stats2(self):
920  for stat in self.dictionary_stats2.keys():
921  self.write_stat2(stat)
922 
923 
class OutputStatistics(object):
    """Collect statistics from ProcessOutput.get_fields().
    Counters of the total number of frames read, plus the models that
    passed the various filters used in get_fields(), are provided."""
    def __init__(self):
        # All counters start at zero; get_fields() increments them
        for counter in ('total', 'passed_get_every',
                        'passed_filterout', 'passed_filtertuple'):
            setattr(self, counter, 0)
933 
934 
class ProcessOutput(object):
    """A class for reading stat files (either rmf or ascii v1 and v2)"""
    def __init__(self, filename):
        """@param filename path of the stat file (an RMF file, or an
           ascii stat file in v1 or v2 format)
           @raise ValueError if filename is None"""
        self.filename = filename
        # format flags; at most one of these is set to True below
        self.isstat1 = False
        self.isstat2 = False
        self.isrmf = False

        if self.filename is None:
            raise ValueError("No file name provided. Use -h for help")

        try:
            # let's see if that is an rmf file
            rh = RMF.open_rmf_file_read_only(self.filename)
            self.isrmf = True
            cat = rh.get_category('stat')
            rmf_klist = rh.get_keys(cat)
            # map field name -> RMF key, used later by get_fields()
            self.rmf_names_keys = dict([(rh.get_name(k), k)
                                        for k in rmf_klist])
            del rh

        except IOError:
            # not an RMF file: fall back to the ascii stat formats
            f = open(self.filename, "r")
            # try with an ascii stat file
            # get the keys from the first line
            for line in f.readlines():
                d = ast.literal_eval(line)
                self.klist = list(d.keys())
                # check if it is a stat2 file
                if "STAT2HEADER" in self.klist:
                    self.isstat2 = True
                    # drop the STAT2HEADER bookkeeping entries from the
                    # header dictionary
                    for k in self.klist:
                        if "STAT2HEADER" in str(k):
                            # if print_header: print k, d[k]
                            del d[k]
                    # what is left maps column id -> field name
                    stat2_dict = d
                    # get the list of keys sorted by value
                    kkeys = [k[0]
                             for k in sorted(stat2_dict.items(),
                                             key=operator.itemgetter(1))]
                    self.klist = [k[1]
                                  for k in sorted(stat2_dict.items(),
                                                  key=operator.itemgetter(1))]
                    # inverse map: field name -> column id
                    self.invstat2_dict = {}
                    for k in kkeys:
                        self.invstat2_dict.update({stat2_dict[k]: k})
                else:
                    # NOTE(review): a line appears to be missing here in
                    # this copy (the call that emits the deprecation
                    # warning); the message text is kept unchanged.
                        "statfile v1 is deprecated. "
                        "Please convert to statfile v2.\n")
                    self.isstat1 = True
                    self.klist.sort()

                # only the first line is needed to extract the keys
                break
            f.close()
990 
991  def get_keys(self):
992  if self.isrmf:
993  return sorted(self.rmf_names_keys.keys())
994  else:
995  return self.klist
996 
997  def show_keys(self, ncolumns=2, truncate=65):
998  IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)
999 
    def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
                   statistics=None):
        '''
        Get the desired field names, and return a dictionary.
        Namely, "fields" are the queried keys in the stat file
        (eg. ["Total_Score",...])
        The returned data structure is a dictionary, where each key is
        a field and the value is the time series (ie, frame ordered series)
        of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,,...],....})

        @param fields (list of strings) queried keys in the stat file
               (eg. "Total_Score"....)
        @param filterout specify if you want to "grep" out something from
               the file, so that it is faster
        @param filtertuple a tuple that contains
               ("TheKeyToBeFiltered",relationship,value)
               where relationship = "<", "==", or ">"
        @param get_every only read every Nth line from the file
        @param statistics if provided, accumulate statistics in an
               OutputStatistics object
        '''

        if statistics is None:
            statistics = OutputStatistics()
        # one (initially empty) time series per requested field
        outdict = {}
        for field in fields:
            outdict[field] = []

        # print fields values
        if self.isrmf:
            rh = RMF.open_rmf_file_read_only(self.filename)
            nframes = rh.get_number_of_frames()
            for i in range(nframes):
                statistics.total += 1
                # "get_every" and "filterout" not enforced for RMF
                statistics.passed_get_every += 1
                statistics.passed_filterout += 1
                IMP.rmf.load_frame(rh, RMF.FrameID(i))
                if filtertuple is not None:
                    keytobefiltered = filtertuple[0]
                    relationship = filtertuple[1]
                    value = filtertuple[2]
                    datavalue = rh.get_root_node().get_value(
                        self.rmf_names_keys[keytobefiltered])
                    if self.isfiltered(datavalue, relationship, value):
                        continue

                statistics.passed_filtertuple += 1
                for field in fields:
                    outdict[field].append(rh.get_root_node().get_value(
                        self.rmf_names_keys[field]))

        else:
            f = open(self.filename, "r")
            line_number = 0

            for line in f.readlines():
                statistics.total += 1
                if filterout is not None:
                    if filterout in line:
                        continue
                statistics.passed_filterout += 1
                line_number += 1

                if line_number % get_every != 0:
                    # the first line of a stat2 file is the header, not
                    # a model: undo its contribution to the counters
                    if line_number == 1 and self.isstat2:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                    continue
                statistics.passed_get_every += 1
                try:
                    d = ast.literal_eval(line)
                except:  # noqa: E722
                    print("# Warning: skipped line number " + str(line_number)
                          + " not a valid line")
                    continue

                if self.isstat1:

                    if filtertuple is not None:
                        keytobefiltered = filtertuple[0]
                        relationship = filtertuple[1]
                        value = filtertuple[2]
                        datavalue = d[keytobefiltered]
                        if self.isfiltered(datavalue, relationship, value):
                            continue

                    statistics.passed_filtertuple += 1
                    [outdict[field].append(d[field]) for field in fields]

                elif self.isstat2:
                    # skip (and uncount) the header line
                    if line_number == 1:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                        statistics.passed_get_every -= 1
                        continue

                    if filtertuple is not None:
                        keytobefiltered = filtertuple[0]
                        relationship = filtertuple[1]
                        value = filtertuple[2]
                        # stat2 data lines are keyed by column id; map
                        # the field name through invstat2_dict
                        datavalue = d[self.invstat2_dict[keytobefiltered]]
                        if self.isfiltered(datavalue, relationship, value):
                            continue

                    statistics.passed_filtertuple += 1
                    [outdict[field].append(d[self.invstat2_dict[field]])
                     for field in fields]

            f.close()

        return outdict
1112 
1113  def isfiltered(self, datavalue, relationship, refvalue):
1114  dofilter = False
1115  try:
1116  _ = float(datavalue)
1117  except ValueError:
1118  raise ValueError("ProcessOutput.filter: datavalue cannot be "
1119  "converted into a float")
1120 
1121  if relationship == "<":
1122  if float(datavalue) >= refvalue:
1123  dofilter = True
1124  if relationship == ">":
1125  if float(datavalue) <= refvalue:
1126  dofilter = True
1127  if relationship == "==":
1128  if float(datavalue) != refvalue:
1129  dofilter = True
1130  return dofilter
1131 
1132 
1134  """ class to allow more advanced handling of RMF files.
1135  It is both a container and a IMP.atom.Hierarchy.
1136  - it is iterable (while loading the corresponding frame)
1137  - Item brackets [] load the corresponding frame
1138  - slice create an iterator
1139  - can relink to another RMF file
1140  """
    def __init__(self, model, rmf_file_name):
        """
        @param model: the IMP.Model()
        @param rmf_file_name: str, path of the rmf file
        @raise TypeError if the file cannot be opened as an RMF file
        """
        self.model = model
        try:
            self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
        except TypeError:
            raise TypeError("Wrong rmf file name or type: %s"
                            % str(rmf_file_name))
        hs = IMP.rmf.create_hierarchies(self.rh_ref, self.model)
        # load the first frame so that coordinates are defined
        IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(0))
        # only the first hierarchy in the file is wrapped by this object
        self.root_hier_ref = hs[0]
        IMP.atom.Hierarchy.__init__(self, self.root_hier_ref)
        self.model.update()
        # optional callback object; when set, its method() is invoked
        # on relink (see link_to_rmf)
        self.ColorHierarchy = None
1158 
1159  def link_to_rmf(self, rmf_file_name):
1160  """
1161  Link to another RMF file
1162  """
1163  self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1164  IMP.rmf.link_hierarchies(self.rh_ref, [self])
1165  if self.ColorHierarchy:
1166  self.ColorHierarchy.method()
1167  RMFHierarchyHandler.set_frame(self, 0)
1168 
1169  def set_frame(self, index):
1170  try:
1171  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(index))
1172  except: # noqa: E722
1173  print("skipping frame %s:%d\n" % (self.current_rmf, index))
1174  self.model.update()
1175 
1176  def get_number_of_frames(self):
1177  return self.rh_ref.get_number_of_frames()
1178 
1179  def __getitem__(self, int_slice_adaptor):
1180  if isinstance(int_slice_adaptor, int):
1181  self.set_frame(int_slice_adaptor)
1182  return int_slice_adaptor
1183  elif isinstance(int_slice_adaptor, slice):
1184  return self.__iter__(int_slice_adaptor)
1185  else:
1186  raise TypeError("Unknown Type")
1187 
1188  def __len__(self):
1189  return self.get_number_of_frames()
1190 
1191  def __iter__(self, slice_key=None):
1192  if slice_key is None:
1193  for nframe in range(len(self)):
1194  yield self[nframe]
1195  else:
1196  for nframe in list(range(len(self)))[slice_key]:
1197  yield self[nframe]
1198 
1199 
class CacheHierarchyCoordinates(object):
    """Cache bead coordinates and rigid-body reference frames for a
    StatHierarchyHandler, so previously visited frames can be restored
    without re-reading the RMF file."""
    def __init__(self, StatHierarchyHandler):
        # flexible-bead decorators (filled below)
        self.xyzs = []
        # non-rigid-member decorators (filled below)
        self.nrms = []
        self.rbs = []
        # per-frame stores, keyed by frame index
        self.nrm_coors = {}
        self.xyz_coors = {}
        self.rb_trans = {}
        self.current_index = None
        self.rmfh = StatHierarchyHandler
        rbs, xyzs = IMP.pmi.tools.get_rbs_and_beads([self.rmfh])
        self.model = self.rmfh.get_model()
        self.rbs = rbs
        for xyz in xyzs:
            # NOTE(review): a line appears to be missing here in this
            # copy (the condition separating non-rigid members from
            # plain beads); code kept as found.
                nrm = IMP.core.NonRigidMember(xyz)
                self.nrms.append(nrm)
            else:
                fb = IMP.core.XYZ(xyz)
                self.xyzs.append(fb)

    def do_store(self, index):
        # snapshot the current transforms/coordinates under `index`
        self.rb_trans[index] = {}
        self.nrm_coors[index] = {}
        self.xyz_coors[index] = {}
        for rb in self.rbs:
            self.rb_trans[index][rb] = rb.get_reference_frame()
        for nrm in self.nrms:
            self.nrm_coors[index][nrm] = nrm.get_internal_coordinates()
        for xyz in self.xyzs:
            self.xyz_coors[index][xyz] = xyz.get_coordinates()
        self.current_index = index

    def do_update(self, index):
        # restore a previously stored snapshot (no-op if already current)
        if self.current_index != index:
            for rb in self.rbs:
                rb.set_reference_frame(self.rb_trans[index][rb])
            for nrm in self.nrms:
                nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
            for xyz in self.xyzs:
                xyz.set_coordinates(self.xyz_coors[index][xyz])
            self.current_index = index
            self.model.update()

    def get_number_of_frames(self):
        # one stored rigid-body snapshot per cached frame
        return len(self.rb_trans.keys())

    def __getitem__(self, index):
        # membership test: True if frame `index` is cached
        if isinstance(index, int):
            return index in self.rb_trans.keys()
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return self.get_number_of_frames()
1255 
1256 
1258  """ class to link stat files to several rmf files """
    def __init__(self, model=None, stat_file=None,
                 number_best_scoring_models=None, score_key=None,
                 StatHierarchyHandler=None, cache=None):
        """
        @param model: IMP.Model()
        @param stat_file: either 1) a list or 2) a single stat file names
               (either rmfs or ascii, or pickled data or pickled cluster),
               3) a dictionary containing an rmf/ascii
               stat file name as key and a list of frames as values
        @param number_best_scoring_models: keep only this many
               best-scoring models (None keeps everything)
        @param score_key: stat-file key holding the score
               (defaults to "Total_Score")
        @param StatHierarchyHandler: copy constructor input object
        @param cache: cache coordinates and rigid body transformations.
        """

        if StatHierarchyHandler is not None:
            # overrides all other arguments
            # copy constructor: create a copy with
            # different RMFHierarchyHandler
            self.model = StatHierarchyHandler.model
            self.data = StatHierarchyHandler.data
            self.number_best_scoring_models = \
                StatHierarchyHandler.number_best_scoring_models
            self.is_setup = True
            self.current_rmf = StatHierarchyHandler.current_rmf
            self.current_frame = None
            self.current_index = None
            self.score_threshold = StatHierarchyHandler.score_threshold
            self.score_key = StatHierarchyHandler.score_key
            self.cache = StatHierarchyHandler.cache
            RMFHierarchyHandler.__init__(self, self.model,
                                         self.current_rmf)
            # the cache holds decorators bound to this hierarchy, so it
            # must be rebuilt rather than shared with the source object
            if self.cache:
                self.cache = CacheHierarchyCoordinates(self)
            else:
                self.cache = None
            self.set_frame(0)

        else:
            # standard constructor
            self.model = model
            self.data = []
            self.number_best_scoring_models = number_best_scoring_models
            self.cache = cache

            if score_key is None:
                self.score_key = "Total_Score"
            else:
                self.score_key = score_key
            self.is_setup = None
            self.current_rmf = None
            self.current_frame = None
            self.current_index = None
            self.score_threshold = None

            # accept a single file name or a list of them; each call to
            # add_stat_file appends its models to self.data
            if isinstance(stat_file, str):
                self.add_stat_file(stat_file)
            elif isinstance(stat_file, list):
                for f in stat_file:
                    self.add_stat_file(f)
1319 
1320  def add_stat_file(self, stat_file):
1321  try:
1322  '''check that it is not a pickle file with saved data
1323  from a previous calculation'''
1324  self.load_data(stat_file)
1325 
1326  if self.number_best_scoring_models:
1327  scores = self.get_scores()
1328  max_score = sorted(scores)[
1329  0:min(len(self), self.number_best_scoring_models)][-1]
1330  self.do_filter_by_score(max_score)
1331 
1332  except pickle.UnpicklingError:
1333  '''alternatively read the ascii stat files'''
1334  try:
1335  scores, rmf_files, rmf_frame_indexes, features = \
1336  self.get_info_from_stat_file(stat_file,
1337  self.score_threshold)
1338  except (KeyError, SyntaxError):
1339  # in this case check that is it an rmf file, probably
1340  # without stat stored in
1341  try:
1342  # let's see if that is an rmf file
1343  rh = RMF.open_rmf_file_read_only(stat_file)
1344  nframes = rh.get_number_of_frames()
1345  scores = [0.0]*nframes
1346  rmf_files = [stat_file]*nframes
1347  rmf_frame_indexes = range(nframes)
1348  features = {}
1349  except: # noqa: E722
1350  return
1351 
1352  if len(set(rmf_files)) > 1:
1353  raise ("Multiple RMF files found")
1354 
1355  if not rmf_files:
1356  print("StatHierarchyHandler: Error: Trying to set none as "
1357  "rmf_file (probably empty stat file), aborting")
1358  return
1359 
1360  for n, index in enumerate(rmf_frame_indexes):
1361  featn_dict = dict([(k, features[k][n]) for k in features])
1362  self.data.append(IMP.pmi.output.DataEntry(
1363  stat_file, rmf_files[n], index, scores[n], featn_dict))
1364 
1365  if self.number_best_scoring_models:
1366  scores = self.get_scores()
1367  max_score = sorted(scores)[
1368  0:min(len(self), self.number_best_scoring_models)][-1]
1369  self.do_filter_by_score(max_score)
1370 
1371  if not self.is_setup:
1372  RMFHierarchyHandler.__init__(
1373  self, self.model, self.get_rmf_names()[0])
1374  if self.cache:
1375  self.cache = CacheHierarchyCoordinates(self)
1376  else:
1377  self.cache = None
1378  self.is_setup = True
1379  self.current_rmf = self.get_rmf_names()[0]
1380 
1381  self.set_frame(0)
1382 
1383  def save_data(self, filename='data.pkl'):
1384  with open(filename, 'wb') as fl:
1385  pickle.dump(self.data, fl)
1386 
1387  def load_data(self, filename='data.pkl'):
1388  with open(filename, 'rb') as fl:
1389  data_structure = pickle.load(fl)
1390  # first check that it is a list
1391  if not isinstance(data_structure, list):
1392  raise TypeError(
1393  "%filename should contain a list of IMP.pmi.output.DataEntry "
1394  "or IMP.pmi.output.Cluster" % filename)
1395  # second check the types
1396  if all(isinstance(item, IMP.pmi.output.DataEntry)
1397  for item in data_structure):
1398  self.data = data_structure
1399  elif all(isinstance(item, IMP.pmi.output.Cluster)
1400  for item in data_structure):
1401  nmodels = 0
1402  for cluster in data_structure:
1403  nmodels += len(cluster)
1404  self.data = [None]*nmodels
1405  for cluster in data_structure:
1406  for n, data in enumerate(cluster):
1407  index = cluster.members[n]
1408  self.data[index] = data
1409  else:
1410  raise TypeError(
1411  "%filename should contain a list of IMP.pmi.output.DataEntry "
1412  "or IMP.pmi.output.Cluster" % filename)
1413 
    def set_frame(self, index):
        """Make model `index` current: restore it from the coordinate
        cache when available, otherwise load it from its RMF file (and
        store it in the cache)."""
        if self.cache is not None and self.cache[index]:
            self.cache.do_update(index)
        else:
            nm = self.data[index].rmf_name
            fidx = self.data[index].rmf_index
            # relink only when the entry lives in a different RMF file;
            # current_frame is invalidated so the frame is reloaded
            if nm != self.current_rmf:
                self.link_to_rmf(nm)
                self.current_rmf = nm
                self.current_frame = -1
            if fidx != self.current_frame:
                RMFHierarchyHandler.set_frame(self, fidx)
                self.current_frame = fidx
            if self.cache is not None:
                self.cache.do_store(index)

        self.current_index = index
1431 
1432  def __getitem__(self, int_slice_adaptor):
1433  if isinstance(int_slice_adaptor, int):
1434  self.set_frame(int_slice_adaptor)
1435  return self.data[int_slice_adaptor]
1436  elif isinstance(int_slice_adaptor, slice):
1437  return self.__iter__(int_slice_adaptor)
1438  else:
1439  raise TypeError("Unknown Type")
1440 
1441  def __len__(self):
1442  return len(self.data)
1443 
1444  def __iter__(self, slice_key=None):
1445  if slice_key is None:
1446  for i in range(len(self)):
1447  yield self[i]
1448  else:
1449  for i in range(len(self))[slice_key]:
1450  yield self[i]
1451 
1452  def do_filter_by_score(self, maximum_score):
1453  self.data = [d for d in self.data if d.score <= maximum_score]
1454 
1455  def get_scores(self):
1456  return [d.score for d in self.data]
1457 
1458  def get_feature_series(self, feature_name):
1459  return [d.features[feature_name] for d in self.data]
1460 
1461  def get_feature_names(self):
1462  return self.data[0].features.keys()
1463 
1464  def get_rmf_names(self):
1465  return [d.rmf_name for d in self.data]
1466 
1467  def get_stat_files_names(self):
1468  return [d.stat_file for d in self.data]
1469 
1470  def get_rmf_indexes(self):
1471  return [d.rmf_index for d in self.data]
1472 
1473  def get_info_from_stat_file(self, stat_file, score_threshold=None):
1474  po = ProcessOutput(stat_file)
1475  fs = po.get_keys()
1476  models = IMP.pmi.io.get_best_models(
1477  [stat_file], score_key=self.score_key, feature_keys=fs,
1478  rmf_file_key="rmf_file", rmf_file_frame_key="rmf_frame_index",
1479  prefiltervalue=score_threshold, get_every=1)
1480 
1481  scores = [float(y) for y in models[2]]
1482  rmf_files = models[0]
1483  rmf_frame_indexes = models[1]
1484  features = models[3]
1485  return scores, rmf_files, rmf_frame_indexes, features
1486 
1487 
class DataEntry(object):
    '''
    A class to store data associated to a model
    '''
    def __init__(self, stat_file=None, rmf_name=None, rmf_index=None,
                 score=None, features=None):
        self.stat_file = stat_file
        self.rmf_name = rmf_name
        self.rmf_index = rmf_index
        self.score = score
        self.features = features

    def __repr__(self):
        # multi-line human-readable summary of this entry
        parts = ["IMP.pmi.output.DataEntry\n",
                 "---- stat file %s \n" % (self.stat_file),
                 "---- rmf file %s \n" % (self.rmf_name),
                 "---- rmf index %s \n" % (str(self.rmf_index)),
                 "---- score %s \n" % (str(self.score)),
                 "---- number of features %s \n"
                 % (str(len(self.features.keys())))]
        return "".join(parts)
1508 
1509 
class Cluster(object):
    '''
    A container for models organized into clusters
    '''
    def __init__(self, cid=None):
        self.cluster_id = cid
        self.members = []
        self.precision = None
        self.center_index = None
        self.members_data = {}

    def add_member(self, index, data=None):
        # record the member and refresh the running average score
        self.members.append(index)
        self.members_data[index] = data
        self.average_score = self.compute_score()

    def compute_score(self):
        # average member score; None when members carry no score
        try:
            total = sum(d.score for d in self)
        except AttributeError:
            return None
        return total / len(self)

    def __repr__(self):
        lines = ["IMP.pmi.output.Cluster\n",
                 "---- cluster_id %s \n" % str(self.cluster_id),
                 "---- precision %s \n" % str(self.precision),
                 "---- average score %s \n" % str(self.average_score),
                 "---- number of members %s \n" % str(len(self.members)),
                 "---- center index %s \n" % str(self.center_index)]
        return "".join(lines)

    def __getitem__(self, int_slice_adaptor):
        """An integer returns that member's data; a slice returns an
        iterator over the selected members."""
        if isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        if isinstance(int_slice_adaptor, int):
            return self.members_data[self.members[int_slice_adaptor]]
        raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.members)

    def __iter__(self, slice_key=None):
        indices = range(len(self))
        if slice_key is not None:
            indices = indices[slice_key]
        for i in indices:
            yield self[i]

    def __add__(self, other):
        # merging mutates (and returns) self; precision and center
        # index become stale and are reset
        self.members += other.members
        self.members_data.update(other.members_data)
        self.average_score = self.compute_score()
        self.precision = None
        self.center_index = None
        return self
1569 
1570 
def plot_clusters_populations(clusters):
    """Bar-plot the number of members of each cluster."""
    import matplotlib.pyplot as plt

    indexes = [cluster.cluster_id for cluster in clusters]
    populations = [len(cluster) for cluster in clusters]

    fig, ax = plt.subplots()
    ax.bar(indexes, populations, 0.5, color='r')
    ax.set_ylabel('Population')
    ax.set_xlabel(('Cluster index'))
    plt.show()
1584 
1585 
def plot_clusters_precisions(clusters):
    """Bar-plot the precision of each cluster (None plotted as 0)."""
    import matplotlib.pyplot as plt

    indexes = []
    precisions = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        prec = cluster.precision
        print(cluster.cluster_id, prec)
        precisions.append(0.0 if prec is None else prec)

    fig, ax = plt.subplots()
    ax.bar(indexes, precisions, 0.5, color='r')
    ax.set_ylabel('Precision [A]')
    ax.set_xlabel(('Cluster index'))
    plt.show()
1604 
1605 
def plot_clusters_scores(clusters):
    """Box-plot the member scores of each cluster (saved as scores.pdf)."""
    indexes = []
    values = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        values.append([data.score for data in cluster])

    plot_fields_box_plots("scores.pdf", values, indexes, frequencies=None,
                          valuename="Scores", positionname="Cluster index",
                          xlabels=None, scale_plot_length=1.0)
1618 
1619 
class CrossLinkIdentifierDatabase(object):
    """Store per-crosslink metadata, keyed by a unique identifier.

    Each record is a plain dict of typed fields; the set_*/get_*
    methods provide a stable schema on top of it."""

    def __init__(self):
        self.clidb = dict()

    def check_key(self, key):
        """Create an empty record for `key` if it is not present."""
        self.clidb.setdefault(key, {})

    def _set(self, key, field, value):
        # generic setter used by the typed set_* methods below
        self.check_key(key)
        self.clidb[key][field] = value

    def set_unique_id(self, key, value):
        self._set(key, "XLUniqueID", str(value))

    def set_protein1(self, key, value):
        self._set(key, "Protein1", str(value))

    def set_protein2(self, key, value):
        self._set(key, "Protein2", str(value))

    def set_residue1(self, key, value):
        self._set(key, "Residue1", int(value))

    def set_residue2(self, key, value):
        self._set(key, "Residue2", int(value))

    def set_idscore(self, key, value):
        self._set(key, "IDScore", float(value))

    def set_state(self, key, value):
        self._set(key, "State", int(value))

    def set_sigma1(self, key, value):
        self._set(key, "Sigma1", str(value))

    def set_sigma2(self, key, value):
        self._set(key, "Sigma2", str(value))

    def set_psi(self, key, value):
        self._set(key, "Psi", str(value))

    def get_unique_id(self, key):
        return self.clidb[key]["XLUniqueID"]

    def get_protein1(self, key):
        return self.clidb[key]["Protein1"]

    def get_protein2(self, key):
        return self.clidb[key]["Protein2"]

    def get_residue1(self, key):
        return self.clidb[key]["Residue1"]

    def get_residue2(self, key):
        return self.clidb[key]["Residue2"]

    def get_idscore(self, key):
        return self.clidb[key]["IDScore"]

    def get_state(self, key):
        return self.clidb[key]["State"]

    def get_sigma1(self, key):
        return self.clidb[key]["Sigma1"]

    def get_sigma2(self, key):
        return self.clidb[key]["Sigma2"]

    def get_psi(self, key):
        return self.clidb[key]["Psi"]

    def set_float_feature(self, key, value, feature_name):
        self._set(key, feature_name, float(value))

    def set_int_feature(self, key, value, feature_name):
        self._set(key, feature_name, int(value))

    def set_string_feature(self, key, value, feature_name):
        self._set(key, feature_name, str(value))

    def get_feature(self, key, feature_name):
        return self.clidb[key][feature_name]

    def write(self, filename):
        """Pickle the whole database to `filename`."""
        with open(filename, 'wb') as handle:
            pickle.dump(self.clidb, handle)

    def load(self, filename):
        """Replace the database with the pickle stored in `filename`."""
        with open(filename, 'rb') as handle:
            self.clidb = pickle.load(handle)
1720 
1721 
def plot_fields(fields, output, framemin=None, framemax=None):
    """Plot the given fields and save a figure as `output`.
    The fields generally are extracted from a stat file
    using ProcessOutput.get_fields()."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    fig, axs = plt.subplots(nrows=len(fields))
    fig.set_size_inches(10.5, 5.5 * len(fields))
    plt.rc('axes')

    for n, key in enumerate(fields):
        if framemin is None:
            framemin = 0
        if framemax is None:
            framemax = len(fields[key])
        xvals = list(range(framemin, framemax))
        yvals = [float(y) for y in fields[key][framemin:framemax]]
        # with a single field, plt.subplots returns a bare Axes rather
        # than an array of them
        axis = axs[n] if len(fields) > 1 else axs
        axis.plot(xvals, yvals)
        axis.set_title(key, size="xx-large")
        axis.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(output)
1756 
1757 
def plot_field_histogram(name, values_lists, valuename=None, bins=40,
                         colors=None, format="png", reference_xline=None,
                         yplotrange=None, xplotrange=None, normalized=True,
                         leg_names=None):
    '''Plot a list of histograms from a value list.
    @param name the name of the plot
    @param values_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename the y-label
    @param bins the number of bins
    @param colors If None, will use rainbow. Else will use specific list
    @param format output format
    @param reference_xline plot a reference line parallel to the y-axis
    @param yplotrange the range for the y-axis
    @param xplotrange the range for the x-axis
    @param normalized whether the histogram is normalized or not
    @param leg_names names for the legend
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        if leg_names is not None:
            label = leg_names[nv]
        else:
            label = str(nv)
        plt.hist(
            [float(y) for y in values], bins=bins, color=col,
            density=normalized, histtype='step', lw=4, label=label)

    # plt.title(name,size="xx-large")
    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        # was plt.ylim() with no arguments, which leaves the limits
        # unchanged; actually apply the requested range
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
1816 
1817 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None",
                          xlabels=None, scale_plot_length=1.0):
    '''
    Plot time series as boxplots.
    fields is a list of time series, positions are the x-values
    valuename is the y-label, positionname is the x-label
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(float(len(positions)) * scale_plot_length, 5.0))
    fig.canvas.set_window_title(name)
    ax1 = fig.add_subplot(111)
    plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)

    bps = [plt.boxplot(values, notch=0, sym='', vert=1,
                       whis=1.5, positions=positions)]
    plt.setp(bps[-1]['boxes'], color='black', lw=1.5)
    plt.setp(bps[-1]['whiskers'], color='black', ls=":", lw=1.5)

    if frequencies is not None:
        # overlay the raw samples on top of each box
        for n, v in enumerate(values):
            plist = [positions[n]] * len(v)
            ax1.plot(plist, v, 'gx', alpha=0.7, markersize=7)

    if xlabels is not None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name + ".pdf", dpi=150)
    plt.show()
1859 
1860 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Plot y against x; optionally save to <out_fn>.pdf and/or show."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.set_window_title(title)

    ax.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        # keep the current x-range, override only the y-range
        x1, x2, _, _ = plt.axis()
        plt.axis((x1, x2,
                  set_plot_yaxis_range[0], set_plot_yaxis_range[1]))
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
1890 
1891 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot y against x; optionally save to `filename`."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, axs = plt.subplots(1)
    axs0 = axs

    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1, alpha=alpha, c="w")

    axs0.legend(
        loc=0,
        frameon=False,
        scatterpoints=1,
        numpoints=1,
        columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if (ymin is not None) and (ymax is not None):
        axs0.set_ylim(ymin, ymax)
    if (xmin is not None) and (xmax is not None):
        axs0.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
1931 
1932 
def get_graph_from_hierarchy(hier):
    """Build the parent/child graph of a hierarchy and draw it,
    labelling only nodes at depth < 3."""
    (graph, depth, depth_dict) = recursive_graph(hier, [], 0, {})

    # filters node labels according to depth_dict
    node_labels_dict = {}
    for key, key_depth in depth_dict.items():
        node_labels_dict[key] = key if key_depth < 3 else ""
    draw_graph(graph, labels_dict=node_labels_dict)
1948 
1949 
def recursive_graph(hier, graph, depth, depth_dict):
    """Recursively append (parent, child) name pairs to `graph` and
    record each node's depth in `depth_dict`.

    Nodes are named "<hierarchy name>|#<particle index>".  Nodes with a
    single child (or no child list at all) are treated as terminal.

    @return the updated (graph, depth, depth_dict) tuple
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # test for None before calling len(); the original order would
    # raise a TypeError on a None child list before the guard fired
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)

    else:
        for c in children:
            (graph, depth, depth_dict) = recursive_graph(
                c, graph, depth, depth_dict)
            nameh = IMP.atom.Hierarchy(c).get_name()
            index = str(c.get_particle().get_index())
            namec = nameh + "|#" + index
            graph.append((name1, namec))

        depth = depth - 1
        return (graph, depth, depth_dict)
1974 
1975 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw a graph of (node, node) edge tuples with networkx/matplotlib.

    @param graph list of (node1, node2) edge tuples
    @param labels_dict optional {node: label} mapping for node labels
    @param graph_layout one of 'spring', 'spectral', 'random', or shell
    @param node_size a scalar, or a {node: size} dict (sizes are scaled)
    @param node_color optional {node: hex color string} mapping
    @param edge_thickness a scalar, or a per-edge list of weights
    @param validation_edges optional edges to highlight: green if present
           in `graph` (either direction), red (and added) if missing
    @param out_filename if given, also save the figure to this file

    A GML copy of the graph is always written to 'out.gml'.
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges (optionally weighted, when a per-edge thickness list is given)
    if isinstance(edge_thickness, list):
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        node_color_hex = "000000"
    else:
        # Fix: `cc` was used without ever being defined; it is the
        # hex<->rgb color converter from IMP.pmi.tools.
        cc = IMP.pmi.tools.ColorChange()
        tmpcolor_rgb = []
        tmpcolor_hex = []
        for node in G.nodes():
            cctuple = cc.rgb(node_color[node])
            tmpcolor_rgb.append((cctuple[0]/255,
                                 cctuple[1]/255,
                                 cctuple[2]/255))
            tmpcolor_hex.append(node_color[node])
        node_color_rgb = tmpcolor_rgb
        node_color_hex = tmpcolor_hex

    # get node sizes if dictionary (area-like scaling via sqrt)
    if isinstance(node_size, dict):
        node_size = [sqrt(node_size[node]) / pi * 10.0
                     for node in G.nodes()]

    for n, node in enumerate(G.nodes()):
        # Fix: with the defaults, node_color_hex is a single hex string and
        # node_size a plain int; indexing them per node was a bug (it took
        # one character of the string / raised TypeError).
        color = (node_color_hex[n] if isinstance(node_color_hex, list)
                 else node_color_hex)
        size = node_size[n] if isinstance(node_size, list) else node_size
        nx.set_node_attributes(
            G, "graphics",
            {node: {'type': 'ellipse', 'w': size, 'h': size,
                    'fill': '#' + color, 'label': node}})
        nx.set_node_attributes(
            G, "LabelGraphics",
            {node: {'type': 'text', 'text': node, 'color': '#000000',
                    'visible': 'true'}})

    for edge in G.edges():
        nx.set_edge_attributes(
            G, "graphics",
            {edge: {'width': 1, 'fill': '#000000'}})

    # Fix: validation_edges defaults to None, so iterating it
    # unconditionally crashed; treat None as "no validation edges".
    for ve in (validation_edges or []):
        print(ve)
        if (ve[0], ve[1]) in G.edges():
            print("found forward")
            nx.set_edge_attributes(
                G, "graphics",
                {ve: {'width': 1, 'fill': '#00FF00'}})
        elif (ve[1], ve[0]) in G.edges():
            print("found backward")
            nx.set_edge_attributes(
                G, "graphics",
                {(ve[1], ve[0]): {'width': 1, 'fill': '#00FF00'}})
        else:
            G.add_edge(ve[0], ve[1])
            print("not found")
            nx.set_edge_attributes(
                G, "graphics",
                {ve: {'width': 1, 'fill': '#FF0000'}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0/8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
        nx.write_gml(G, 'out.gml')
    plt.show()
2087 
2088 
def draw_table():
    """Render an example d3js table in an IPython notebook.

    Still an example: uses hard-coded monthly data for 8 products and
    displays the resulting HTML table inline via ipyD3.
    """
    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  # fix: duplicated word "created created" in the description
                  desc='An example table created with d3js with '
                  'data generated with Python.')
    data = [[1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
             4441.0, 4670.0, 944.0, 110.0],
            [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
             7078.0, 6561.0, 2354.0, 710.0],
            [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
             5661.0, 4991.0, 2032.0, 680.0],
            [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
             4727.0, 3428.0, 1559.0, 620.0],
            [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
             9552.0, 13709.0, 2460.0, 720.0],
            [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
             32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
            [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
             5115.0, 3510.0, 1390.0, 170.0],
            [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
             360.0, 156.0, 100.0]]
    # transpose: rows become months, columns become products
    data = [list(i) for i in zip(*data)]
    # fix: typo "Deecember" in the displayed row labels
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              'December']]
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: Residue.h:158
A container for models organized into clusters.
Definition: output.py:1510
A class for reading stat files (either rmf or ascii v1 and v2)
Definition: output.py:935
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: atom/Atom.h:241
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:1758
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:1818
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
Miscellaneous utilities.
Definition: tools.py:1
A class to store data associated to a model.
Definition: output.py:1488
void handle_use_deprecated(std::string message)
Break in this method in gdb to find deprecated uses at runtime.
std::string get_module_version()
Return the version of this module, as a string.
Change a color code from hexadecimal to rgb.
Definition: tools.py:809
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
Definition: output.py:924
def get_prot_name_from_particle
Return the component name provided a particle and a list of names.
Definition: tools.py:549
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:1000
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
Definition: Fragment.h:46
def link_to_rmf
Link to another RMF file.
Definition: output.py:1159
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
Definition: pmi/utilities.h:85
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
Definition: output.py:268
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Definition: atom/Atom.h:234
Base class for capturing a modeling protocol.
Definition: output.py:45
The type for a residue.
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
A base class for Keys.
Definition: Key.h:44
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:201
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Definition: rigid_bodies.h:750
bool get_is_canonical(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and check if the root is named System.
Definition: pmi/utilities.h:91
Display a segment connecting a pair of particles.
Definition: XYZR.h:170
A decorator for a residue.
Definition: Residue.h:137
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
Definition: output.py:229
def get_prot_name_from_particle
Get the protein name from the particle.
Definition: output.py:354
class to link stat files to several rmf files
Definition: output.py:1257
class to allow more advanced handling of RMF files.
Definition: output.py:1133
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
Definition: output.py:1722
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Definition: Chain.h:61
Python classes to represent, score, sample and analyze models.
Functionality for loading, creating, manipulating and scoring atomic structures.
def get_rbs_and_beads
Returns unique objects in original order.
Definition: tools.py:1211
Hierarchies get_leaves(const Selection &h)
Select hierarchy particles identified by the biological name.
Definition: Selection.h:66
def init_rmf
Initialize an RMF file.
Definition: output.py:565
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: tools.py:569
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: rigid_bodies.h:752
std::string get_module_version()
Return the version of this module, as a string.
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: tools.py:644
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27