IMP logo
IMP Reference Guide  2.20.0
The Integrative Modeling Platform
output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 from __future__ import print_function, division
6 import IMP
7 import IMP.atom
8 import IMP.core
9 import IMP.pmi
10 import IMP.pmi.tools
11 import IMP.pmi.io
12 import os
13 import sys
14 import ast
15 import RMF
16 import numpy as np
17 import operator
18 import itertools
19 import warnings
20 import string
21 import ihm.format
22 import collections
23 try:
24  import cPickle as pickle
25 except ImportError:
26  import pickle
27 
28 
class _ChainIDs(object):
    """Translate an integer index into a chain ID string.

    Indices 0-25 give the single letters 'A'-'Z'. Index 26 onwards gives
    two-letter IDs 'AA'..'AZ', 'BA'..'BZ', ... up to 'ZZ', after which
    three-letter IDs follow, and so on.
    """
    def __getitem__(self, ind):
        alphabet = string.ascii_uppercase
        base = len(alphabet)
        if ind < base:
            # Final (leftmost) letter; also covers negative indices
            return alphabet[ind]
        # Everything but the last letter is the ID of (ind // base - 1)
        return self[ind // base - 1] + alphabet[ind % base]
43 
44 
class ProtocolOutput(object):
    """Base class for capturing a complete modeling protocol.

    A protocol goes beyond simple output of model coordinates: it also
    records the input data used, details of the restraints, and the
    sampling and clustering steps, together with the output models.
    Attach one to a system via
    IMP.pmi.topology.System.add_protocol_output().

    @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
         mmCIF files.
    """
56 
57 
def _flatten(seq):
    """Lazily yield the non-list, non-tuple leaves of a nested structure.

    Items that are tuples or lists are descended into recursively; every
    other item (including strings) is yielded as-is, in order.
    """
    for item in seq:
        if not isinstance(item, (tuple, list)):
            yield item
            continue
        for leaf in _flatten(item):
            yield leaf
65 
66 
def _disambiguate_chain(chid, seen_chains):
    """Make sure that the chain ID is unique; warn and correct if it isn't.

    A chain ID consisting of a single NUL character is treated as a blank
    (' ') ID. If the (possibly blanked) ID is already in `seen_chains`, a
    warning is issued and the smallest numeric suffix (1, 2, ...) giving an
    unused ID is appended.

    @param chid The requested chain ID.
    @param seen_chains Set of IDs already assigned; the returned ID is
           added to it.
    @return The chain ID to use.
    """
    # Handle null chain IDs
    if chid == '\0':
        chid = ' '

    if chid in seen_chains:
        warnings.warn("Duplicate chain ID '%s' encountered" % chid,
                      IMP.pmi.StructureWarning)

        for suffix in itertools.count(1):
            new_chid = chid + "%d" % suffix
            if new_chid not in seen_chains:
                seen_chains.add(new_chid)
                return new_chid
    seen_chains.add(chid)
    return chid
84 
85 
def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
                        write_all_residues_per_bead):
    """Write one PDB model (records followed by ENDMDL) to `flpdb`.

    Each entry of `particle_infos_for_pdb` is a tuple
    (xyz, atom_type, residue_type, chain_id, residue_index, all_indexes,
    radius). Coordinates are shifted by `geometric_center`, only the first
    character of the chain ID is written, and a missing atom type defaults
    to CA. The radius is passed to IMP.atom.get_pdb_string() right after
    the 1.00 occupancy (presumably landing in the temperature-factor
    column). When `write_all_residues_per_bead` is set and the entry has
    `all_indexes`, one record is written per listed residue, all sharing
    the entry's atom serial number.
    """
    for serial, info in enumerate(particle_infos_for_pdb, 1):
        (xyz, atom_type, residue_type,
         chain_id, residue_index, all_indexes, radius) = info
        if atom_type is None:
            atom_type = IMP.atom.AT_CA
        if write_all_residues_per_bead and all_indexes is not None:
            residue_numbers = all_indexes
        else:
            residue_numbers = (residue_index,)
        for residue_number in residue_numbers:
            shifted = (xyz[0] - geometric_center[0],
                       xyz[1] - geometric_center[1],
                       xyz[2] - geometric_center[2])
            record = IMP.atom.get_pdb_string(
                shifted, serial, atom_type, residue_type, chain_id[:1],
                residue_number, ' ', 1.00, radius)
            flpdb.write(record)
    flpdb.write("ENDMDL\n")
111 
112 
# Records used by the mmCIF writer: an entity is a numeric id plus the
# sequence string it was built from; a chain links an entity to the
# molecule name it came from.
_Entity = collections.namedtuple('_Entity', 'id seq')
_ChainInfo = collections.namedtuple('_ChainInfo', 'entity name')
115 
116 
def _get_chain_info(chains, root_hier):
    """Group the molecules of `root_hier` into chains and sequence entities.

    @param chains dict mapping molecule name (as produced by
           IMP.pmi.get_molecule_name_and_copy()) to chain ID.
    @param root_hier Hierarchy searched for molecules.
    @return (chain_info, entities): `chain_info` maps each chain ID to a
            _ChainInfo(entity, name); `entities` lists one _Entity per
            distinct sequence, numbered from 1 in order of first appearance.
    """
    chain_info = {}
    entities = {}
    all_entities = []
    for mol in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
        # Must match the keys Output._init_dictchain() used for `chains`
        molname = IMP.pmi.get_molecule_name_and_copy(mol)
        chain_id = chains[molname]
        seq = IMP.atom.Chain(mol).get_sequence()
        entity = entities.get(seq)
        if entity is None:
            # First molecule with this sequence defines a new entity
            entity = _Entity(id=len(entities) + 1, seq=seq)
            entities[seq] = entity
            all_entities.append(entity)
        chain_info[chain_id] = _ChainInfo(entity=entity, name=molname)
    return chain_info, all_entities
133 
134 
def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
                          write_all_residues_per_bead, chains, root_hier):
    """Write one model in mmCIF format to the open file `flpdb`.

    @param flpdb Writable file object.
    @param particle_infos_for_pdb Sequence of (xyz, atom_type,
           residue_type, chain_id, residue_index, all_indexes, radius)
           tuples; the radius is not written.
    @param geometric_center Subtracted from every coordinate.
    @param write_all_residues_per_bead If True, entries with `all_indexes`
           produce one _atom_site row per listed residue (each with its own
           sequence number); otherwise one row per entry at `residue_index`.
    @param chains dict mapping molecule name to chain ID.
    @param root_hier Hierarchy whose molecules define entities and chains.
    """
    # get dict with keys=chain IDs, values=chain info
    chain_info, entities = _get_chain_info(chains, root_hier)

    writer = ihm.format.CifWriter(flpdb)
    writer.start_block('model')
    with writer.category("_entry") as lp:
        lp.write(id='model')

    with writer.loop("_entity", ["id", "type"]) as lp:
        for e in entities:
            lp.write(id=e.id, type="polymer")

    with writer.loop("_entity_poly",
                     ["entity_id", "pdbx_seq_one_letter_code"]) as lp:
        for e in entities:
            lp.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)

    with writer.loop("_struct_asym", ["id", "entity_id", "details"]) as lp:
        # Longer chain IDs (e.g. AA) should always come after shorter (e.g. Z)
        for chid in sorted(chains.values(), key=lambda x: (len(x.strip()), x)):
            ci = chain_info[chid]
            lp.write(id=chid, entity_id=ci.entity.id, details=ci.name)

    with writer.loop("_atom_site",
                     ["group_PDB", "type_symbol", "label_atom_id",
                      "label_comp_id", "label_asym_id", "label_seq_id",
                      "auth_seq_id",
                      "Cartn_x", "Cartn_y", "Cartn_z", "label_entity_id",
                      "pdbx_pdb_model_num",
                      "id"]) as lp:
        ordinal = 1
        for (xyz, atom_type, residue_type, chain_id, residue_index,
             all_indexes, radius) in particle_infos_for_pdb:
            ci = chain_info[chain_id]
            if atom_type is None:
                atom_type = IMP.atom.AT_CA
            c = (xyz[0] - geometric_center[0],
                 xyz[1] - geometric_center[1],
                 xyz[2] - geometric_center[2])
            if write_all_residues_per_bead and all_indexes is not None:
                residue_numbers = all_indexes
            else:
                residue_numbers = [residue_index]
            for residue_number in residue_numbers:
                # Each row carries its own residue number (previously every
                # row of a bead repeated residue_index).
                lp.write(group_PDB='ATOM', type_symbol='C',
                         label_atom_id=atom_type.get_string(),
                         label_comp_id=residue_type.get_string(),
                         label_asym_id=chain_id,
                         label_seq_id=residue_number,
                         auth_seq_id=residue_number, Cartn_x=c[0],
                         Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                         pdbx_pdb_model_num=1,
                         label_entity_id=ci.entity.id)
                ordinal += 1
201 
202 
class Output(object):
    """Class for easy writing of PDBs, RMFs, and stat files

    @note Model should be updated prior to writing outputs.
    """
    def __init__(self, ascii=True, atomistic=False):
        """Create an Output object with no files registered.

        @param ascii If True (the default) stat files written by
               write_stat() are text, one dict literal per line;
               otherwise records are pickled.
        @param atomistic Initial value of the `atomistic` flag (note that
               registering a PDB via init_pdb()/init_pdb_best_scoring()
               sets it to True).
        """
        self.dictionary_pdbs = {}
        self._pdb_mmcif = {}
        self.dictionary_rmfs = {}
        self.dictionary_stats = {}
        self.dictionary_stats2 = {}
        self.best_score_list = None
        self.nbestscoring = None
        self.prefixes = []
        self.replica_exchange = False
        self.ascii = ascii
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # 1-character chain IDs, suitable for PDB output
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
                        "abcdefghijklmnopqrstuvwxyz0123456789"
        # Multi-character chain IDs, suitable for mmCIF output
        self.multi_chainids = _ChainIDs()
        # keys are PDB file names; values are dicts mapping molecule
        # names to chain IDs
        self.dictchain = {}
        self.particle_infos_for_pdb = {}
        self.atomistic = atomistic

    def get_pdb_names(self):
        """Get a list of all PDB files being output by this instance"""
        return list(self.dictionary_pdbs.keys())

    def get_rmf_names(self):
        """Get a list of all RMF files registered with init_rmf()"""
        return list(self.dictionary_rmfs.keys())

    def get_stat_names(self):
        """Get a list of all stat files registered with init_stat()"""
        return list(self.dictionary_stats.keys())

    def _init_dictchain(self, name, prot, multichar_chain=False, mmcif=False):
        """Assign a unique chain ID to every molecule of `prot` for file
        `name`, storing the mapping in self.dictchain[name].

        @param multichar_chain Unused; kept for backward compatibility.
        @param mmcif If False, chain IDs longer than one character are
               rejected with ValueError (they cannot be written to PDB).
        """
        self.dictchain[name] = {}
        seen_chains = set()

        # attempt to find PMI objects.
        self.atomistic = True  # detects automatically
        for mol in IMP.atom.get_by_type(prot, IMP.atom.MOLECULE_TYPE):
            chid = IMP.atom.Chain(mol).get_id()
            if not mmcif and len(chid) > 1:
                raise ValueError(
                    "The system contains at least one chain ID (%s) that "
                    "is more than 1 character long; this cannot be "
                    "represented in PDB. Either write mmCIF files "
                    "instead, or assign 1-character IDs to all chains "
                    "(this can be done with the `chain_ids` argument to "
                    "BuildSystem.add_state())." % chid)
            chid = _disambiguate_chain(chid, seen_chains)
            # Same naming as get_prot_name_from_particle(), so lookups by
            # particle find this entry
            molname = IMP.pmi.get_molecule_name_and_copy(mol)
            self.dictchain[name][molname] = chid

    def init_pdb(self, name, prot, mmcif=False):
        """Init PDB Writing.
        @param name The PDB filename
        @param prot The hierarchy to write to this pdb file
        @param mmcif If True, write PDBs in mmCIF format
        @note if the PDB name is 'System' then will use Selection
              to get molecules
        """
        # Create (or truncate) the file so later appends start empty
        with open(name, 'w'):
            pass
        self.dictionary_pdbs[name] = prot
        self._pdb_mmcif[name] = mmcif
        self._init_dictchain(name, prot, mmcif=mmcif)

    def write_psf(self, filename, name):
        """Write a PSF file describing the particles of PDB `name`.

        One atom record is written per particle returned by
        get_particle_infos_for_pdb_writing(), followed by bonds: within each
        chain the atoms are ordered by residue index and pairs are taken
        with IMP.pmi.tools.sublist_iterator(lmin=2, lmax=2).

        @param filename The PSF file to write.
        @param name The PDB name (as given to init_pdb()) to describe.
        """
        with open(filename, 'w') as flpsf:
            flpsf.write("PSF CMAP CHEQ" + "\n")
            index_residue_pair_list = {}
            (particle_infos_for_pdb, geometric_center) = \
                self.get_particle_infos_for_pdb_writing(name)
            nparticles = len(particle_infos_for_pdb)
            flpsf.write(str(nparticles) + " !NATOM" + "\n")
            for n, p in enumerate(particle_infos_for_pdb):
                atom_index = n + 1
                residue_type = p[2]
                chain = p[3]
                resid = p[4]
                flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}'
                            '{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}'
                            '{14:14.6f}{15:14.6f}'.format(
                                atom_index, " ", chain, " ", str(resid), " ",
                                '"' + residue_type.get_string() + '"', " ",
                                "C", " ", "C", 1.0, 0.0, 0, 0.0, 0.0))
                flpsf.write('\n')
                index_residue_pair_list.setdefault(chain, []).append(
                    (atom_index, resid))

            # now write the connectivity
            indexes_pairs = []
            for chain in sorted(index_residue_pair_list.keys()):
                # sort by residue, then keep only the atom indexes
                ls = sorted(index_residue_pair_list[chain],
                            key=lambda tup: tup[1])
                indexes = [x[0] for x in ls]
                # get the contiguous pairs
                indexes_pairs.extend(IMP.pmi.tools.sublist_iterator(
                    indexes, lmin=2, lmax=2))
            nbonds = len(indexes_pairs)
            flpsf.write(str(nbonds) + " !NBOND: bonds" + "\n")

            # save bonds in fixed column format, four bonds per line
            for i in range(0, len(indexes_pairs), 4):
                for bond in indexes_pairs[i:i + 4]:
                    flpsf.write('{0:8d}{1:8d}'.format(*bond))
                flpsf.write('\n')

    def write_pdb(self, name, appendmode=True,
                  translate_to_geometric_center=False,
                  write_all_residues_per_bead=False):
        """Write the current coordinates for PDB `name` (see init_pdb()).

        @param appendmode If True, append a model; otherwise overwrite.
        @param translate_to_geometric_center If True, subtract the mean of
               all written coordinates.
        @param write_all_residues_per_bead If True, write one record per
               residue covered by each multi-residue bead.
        """
        (particle_infos_for_pdb,
         geometric_center) = self.get_particle_infos_for_pdb_writing(name)

        if not translate_to_geometric_center:
            geometric_center = (0, 0, 0)

        filemode = 'a' if appendmode else 'w'
        with open(name, filemode) as flpdb:
            if self._pdb_mmcif[name]:
                _write_mmcif_internal(flpdb, particle_infos_for_pdb,
                                      geometric_center,
                                      write_all_residues_per_bead,
                                      self.dictchain[name],
                                      self.dictionary_pdbs[name])
            else:
                _write_pdb_internal(flpdb, particle_infos_for_pdb,
                                    geometric_center,
                                    write_all_residues_per_bead)

    def get_prot_name_from_particle(self, name, p):
        """Get the protein name from the particle.
        This is done by traversing the hierarchy.

        @return (molecule name and copy, True); the second element is the
                `is_a_bead` flag used by get_particle_infos_for_pdb_writing().
        """
        return IMP.pmi.get_molecule_name_and_copy(p), True

    def get_particle_infos_for_pdb_writing(self, name):
        """Collect the per-particle data needed to write PDB `name`.

        @return (particle_infos, geometric_center). Each entry of
                particle_infos is (xyz, atom_type or None, residue_type,
                chain_id, residue_index, all_residue_indexes or None,
                radius), sorted by chain ID (shorter IDs first) then residue
                index. geometric_center is the mean coordinate as a tuple,
                or [0, 0, 0] if nothing was selected.
        """
        # Residue indexes already output per molecule, to avoid duplication;
        # particles are taken from a resolution-0 selection so the highest
        # resolution has priority.
        resindexes_dict = {}
        particle_infos_for_pdb = []
        geometric_center = [0, 0, 0]
        atom_count = 0

        hier = self.dictionary_pdbs[name]
        # select highest resolution, if hierarchy is non-empty
        if (not IMP.core.XYZR.get_is_setup(hier)
                and hier.get_number_of_children() == 0):
            ps = []
        else:
            sel = IMP.atom.Selection(hier, resolution=0)
            ps = sel.get_selected_particles()

        for p in ps:
            protname, is_a_bead = self.get_prot_name_from_particle(name, p)
            seen = resindexes_dict.setdefault(protname, set())

            if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
                residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
                rt = residue.get_residue_type()
                resind = residue.get_index()
                atomtype = IMP.atom.Atom(p).get_atom_type()
                xyz = list(IMP.core.XYZ(p).get_coordinates())
                all_resinds = None
                seen.add(resind)

            elif IMP.atom.Residue.get_is_setup(p):
                residue = IMP.atom.Residue(p)
                resind = residue.get_index()
                # skip if the residue was already added by atomistic
                # resolution 0
                if resind in seen:
                    continue
                seen.add(resind)
                rt = residue.get_residue_type()
                atomtype = None
                xyz = IMP.core.XYZ(p).get_coordinates()
                all_resinds = None

            elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
                all_resinds = list(IMP.pmi.tools.get_residue_indexes(p))
                resind = all_resinds[len(all_resinds) // 2]
                if resind in seen:
                    continue
                seen.add(resind)
                rt = IMP.atom.ResidueType('BEA')
                atomtype = None
                xyz = IMP.core.XYZ(p).get_coordinates()

            elif is_a_bead:
                rt = IMP.atom.ResidueType('BEA')
                all_resinds = list(IMP.pmi.tools.get_residue_indexes(p))
                if not all_resinds:
                    continue
                resind = all_resinds[len(all_resinds) // 2]
                atomtype = None
                xyz = IMP.core.XYZ(p).get_coordinates()

            else:
                continue

            radius = IMP.core.XYZR(p).get_radius()
            for i in range(3):
                geometric_center[i] += xyz[i]
            atom_count += 1
            particle_infos_for_pdb.append(
                (xyz, atomtype, rt, self.dictchain[name][protname],
                 resind, all_resinds, radius))

        if atom_count > 0:
            geometric_center = tuple(c / atom_count for c in geometric_center)

        # sort by chain ID, then residue index. Longer chain IDs (e.g. AA)
        # should always come after shorter (e.g. Z)
        particle_infos_for_pdb = sorted(particle_infos_for_pdb,
                                        key=lambda x: (len(x[3]), x[3], x[4]))

        return (particle_infos_for_pdb, geometric_center)

    def write_pdbs(self, appendmode=True, mmcif=False):
        """Write every registered PDB file.

        @param mmcif Ignored; each file uses the format chosen when it was
               registered.
        """
        for pdb in self.dictionary_pdbs.keys():
            self.write_pdb(pdb, appendmode)

    def init_pdb_best_scoring(self, prefix, prot, nbestscoring,
                              replica_exchange=False, mmcif=False,
                              best_score_file='best.scores.rex.py'):
        """Prepare for writing best-scoring PDBs (or mmCIFs) for a
           sampling run.

        @param prefix Initial part of each PDB filename (e.g. 'model').
        @param prot The top-level Hierarchy to output.
        @param nbestscoring The number of best-scoring files to output.
        @param replica_exchange Whether to combine best scores from a
               replica exchange run.
        @param mmcif If True, output models in mmCIF format. If False
               (the default) output in legacy PDB format.
        @param best_score_file The filename to use for replica
               exchange scores.
        """
        self._pdb_best_scoring_mmcif = mmcif
        fileext = '.cif' if mmcif else '.pdb'
        self.prefixes.append(prefix)
        self.replica_exchange = replica_exchange
        self.best_score_list = []
        if self.replica_exchange:
            # the replicas communicate the best scores through a
            # common file
            self.best_score_file_name = best_score_file
            with open(self.best_score_file_name, "w") as fh:
                fh.write(
                    "self.best_score_list=" + str(self.best_score_list) + "\n")

        self.nbestscoring = nbestscoring
        for i in range(self.nbestscoring):
            name = prefix + "." + str(i) + fileext
            # create (or truncate) the output file
            with open(name, 'w'):
                pass
            self.dictionary_pdbs[name] = prot
            self._pdb_mmcif[name] = mmcif
            self._init_dictchain(name, prot, mmcif=mmcif)

    def write_pdb_best_scoring(self, score):
        """Record `score`; if it is among the `nbestscoring` best so far,
        write the current model as prefix.<rank><ext> for every prefix,
        shifting worse-ranked files up by one rank.

        In replica exchange mode the score list is read from and written
        back to the shared score file around the update. Does nothing
        (after printing a message) if init_pdb_best_scoring() was not run.
        """
        if self.nbestscoring is None:
            print("Output.write_pdb_best_scoring: init_pdb_best_scoring "
                  "not run")
            # Nothing was set up, so there is nothing to write
            return

        fileext = '.cif' if self._pdb_best_scoring_mmcif else '.pdb'
        # update the score list
        if self.replica_exchange:
            # read the self.best_score_list from the file
            with open(self.best_score_file_name) as fh:
                self.best_score_list = ast.literal_eval(
                    fh.read().split('=')[1])

        if len(self.best_score_list) < self.nbestscoring:
            self.best_score_list.append(score)
            self.best_score_list.sort()
            index = self.best_score_list.index(score)
            for prefix in self.prefixes:
                for i in range(len(self.best_score_list) - 2, index - 1, -1):
                    oldname = prefix + "." + str(i) + fileext
                    newname = prefix + "." + str(i + 1) + fileext
                    # rename on Windows fails if newname already exists
                    if os.path.exists(newname):
                        os.unlink(newname)
                    os.rename(oldname, newname)
                filetoadd = prefix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        elif score < self.best_score_list[-1]:
            self.best_score_list.append(score)
            self.best_score_list.sort()
            self.best_score_list.pop(-1)
            index = self.best_score_list.index(score)
            for prefix in self.prefixes:
                for i in range(len(self.best_score_list) - 1,
                               index - 1, -1):
                    oldname = prefix + "." + str(i) + fileext
                    newname = prefix + "." + str(i + 1) + fileext
                    os.rename(oldname, newname)
                # the previously worst model was pushed past the last rank
                filenametoremove = prefix + \
                    "." + str(self.nbestscoring) + fileext
                os.remove(filenametoremove)
                filetoadd = prefix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        if self.replica_exchange:
            # write the self.best_score_list to the file
            with open(self.best_score_file_name, "w") as fh:
                fh.write(
                    "self.best_score_list=" + str(self.best_score_list) + '\n')

    def init_rmf(self, name, hierarchies, rs=None, geometries=None,
                 listofobjects=None):
        """
        Initialize an RMF file

        @param name the name of the RMF file
        @param hierarchies the hierarchies to be included (it is a list)
        @param rs optional, the restraint sets (it is a list)
        @param geometries optional, the geometries (it is a list)
        @param listofobjects optional, the list of objects for the stat
               (it is a list)
        """
        rh = RMF.create_rmf_file(name)
        IMP.rmf.add_hierarchies(rh, hierarchies)
        cat = None
        outputkey_rmfkey = None

        if rs is not None:
            IMP.rmf.add_restraints(rh, rs)
        if geometries is not None:
            IMP.rmf.add_geometries(rh, geometries)
        if listofobjects is not None:
            cat = rh.get_category("stat")
            outputkey_rmfkey = {}
            for o in listofobjects:
                if "get_output" not in dir(o):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(o))
                output = o.get_output()
                for outputkey, value in output.items():
                    # RMF key type follows the Python type of the value
                    if isinstance(value, float):
                        rmftag = RMF.float_tag
                    elif isinstance(value, int):
                        rmftag = RMF.int_tag
                    else:
                        rmftag = RMF.string_tag
                    outputkey_rmfkey[outputkey] = \
                        rh.get_key(cat, outputkey, rmftag)
            outputkey_rmfkey["rmf_file"] = \
                rh.get_key(cat, "rmf_file", RMF.string_tag)
            outputkey_rmfkey["rmf_frame_index"] = \
                rh.get_key(cat, "rmf_frame_index", RMF.int_tag)

        self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey, listofobjects)

    def add_restraints_to_rmf(self, name, objectlist):
        """Add the restraints of each (possibly nested) object in
        `objectlist` to RMF `name`, preferring get_restraint_for_rmf() and
        falling back to get_restraint()."""
        for o in _flatten(objectlist):
            try:
                rs = o.get_restraint_for_rmf()
                if not isinstance(rs, (list, tuple)):
                    rs = [rs]
            except Exception:
                rs = [o.get_restraint()]
            IMP.rmf.add_restraints(
                self.dictionary_rmfs[name][0], rs)

    def add_geometries_to_rmf(self, name, objectlist):
        """Add the geometries of every object in `objectlist` to RMF
        `name`."""
        for o in objectlist:
            geos = o.get_geometries()
            IMP.rmf.add_geometries(self.dictionary_rmfs[name][0], geos)

    def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
        """Add an edge geometry to RMF `name` for every particle pair
        reported by get_particle_pairs() of each object in `objectlist`."""
        for o in objectlist:
            pps = o.get_particle_pairs()
            for pp in pps:
                IMP.rmf.add_geometry(
                    self.dictionary_rmfs[name][0],
                    IMP.core.EdgePairGeometry(pp))

    def write_rmf(self, name):
        """Save a frame to RMF `name`, plus stat values if registered."""
        rh, cat, outputkey_rmfkey, listofobjects = self.dictionary_rmfs[name]
        IMP.rmf.save_frame(rh)
        if cat is not None:
            root = rh.get_root_node()
            for o in listofobjects:
                output = o.get_output()
                for outputkey in output:
                    rmfkey = outputkey_rmfkey[outputkey]
                    try:
                        root.set_value(rmfkey, output[outputkey])
                    except NotImplementedError:
                        continue
            root.set_value(outputkey_rmfkey["rmf_file"], name)
            nframes = rh.get_number_of_frames()
            root.set_value(outputkey_rmfkey["rmf_frame_index"], nframes - 1)
        rh.flush()

    def close_rmf(self, name):
        """Stop tracking RMF `name`, dropping this object's reference to
        the file handle."""
        del self.dictionary_rmfs[name]

    def write_rmfs(self):
        """Save a frame to every registered RMF file."""
        for rmfname in self.dictionary_rmfs.keys():
            # keys are the RMF file names themselves
            self.write_rmf(rmfname)

    def init_stat(self, name, listofobjects):
        """Create (or truncate) stat file `name` and register the objects
        whose get_output() will be written to it.

        @raise ValueError if an object lacks a get_output() method.
        """
        with open(name, 'w' if self.ascii else 'wb'):
            pass

        # check that all objects in listofobjects have a get_output method
        for o in listofobjects:
            if "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() method"
                    % str(o))
        self.dictionary_stats[name] = listofobjects

    def set_output_entry(self, key, value):
        """Set an extra key/value included in every stat record."""
        self.initoutput[key] = value

    def write_stat(self, name, appendmode=True):
        """Write one record to stat file `name`.

        The record combines the set_output_entry() values with the
        get_output() results of the registered objects (keys starting with
        '_' omitted). It is written as a text line, or pickled when the
        object was created with ascii=False.
        """
        # Copy, so entries from this file do not leak into initoutput (and
        # from there into other stat files).
        output = dict(self.initoutput)
        for obj in self.dictionary_stats[name]:
            d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")

        writeflag = 'a' if appendmode else 'w'
        if self.ascii:
            with open(name, writeflag) as flstat:
                flstat.write("%s \n" % output)
        else:
            with open(name, writeflag + 'b') as flstat:
                pickle.dump(output, flstat, 2)

    def write_stats(self):
        """Append a record to every registered stat file."""
        for stat in self.dictionary_stats.keys():
            self.write_stat(stat)

    def get_stat(self, name):
        """Return the merged get_output() dicts of the objects registered
        for stat file `name` (private keys included)."""
        output = {}
        for obj in self.dictionary_stats[name]:
            output.update(obj.get_output())
        return output

    def write_test(self, name, listofobjects):
        """Write a reference record for test() to file `name`.

        Each object's get_test_output() is used if it succeeds, otherwise
        get_output(); keys starting with '_' are omitted.

        @raise ValueError if an object has neither method.
        """
        for o in listofobjects:
            if "get_test_output" not in dir(o) and "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        self.dictionary_stats[name] = listofobjects

        output = dict(self.initoutput)
        for obj in self.dictionary_stats[name]:
            try:
                d = obj.get_test_output()
            except Exception:
                d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")
        with open(name, 'w') as flstat:
            flstat.write("%s \n" % output)

    def test(self, name, listofobjects, tolerance=1e-5):
        """Compare current outputs with the records stored in file `name`.

        Values parseable as floats are compared within `tolerance`, others
        for equality; mismatches are reported on stderr. Keys missing from
        the current output are reported but do not fail the test.

        @return True if every compared value matched.
        """
        output = dict(self.initoutput)
        for o in listofobjects:
            if "get_test_output" not in dir(o) and "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        for obj in listofobjects:
            try:
                output.update(obj.get_test_output())
            except Exception:
                output.update(obj.get_output())

        passed = True
        with open(name, 'r') as flstat:
            for fl in flstat:
                test_dict = ast.literal_eval(fl)
                for k in test_dict:
                    if k not in output:
                        print("%s from old objects (file %s) not in new "
                              "objects" % (str(k), str(name)),
                              file=sys.stderr)
                        continue
                    old_value = str(test_dict[k])
                    new_value = str(output[k])
                    try:
                        float(old_value)
                        is_float = True
                    except ValueError:
                        is_float = False

                    if is_float:
                        fold = float(old_value)
                        fnew = float(new_value)
                        diff = abs(fold - fnew)
                        if diff > tolerance:
                            print("%s: test failed, old value: %s new value "
                                  "%s; diff %f > %f"
                                  % (str(k), str(old_value), str(new_value),
                                     diff, tolerance), file=sys.stderr)
                            passed = False
                    elif test_dict[k] != output[k]:
                        if len(old_value) < 50 and len(new_value) < 50:
                            print("%s: test failed, old value: %s new value %s"
                                  % (str(k), old_value, new_value),
                                  file=sys.stderr)
                        else:
                            print("%s: test failed, omitting results (too "
                                  "long)" % str(k), file=sys.stderr)
                        passed = False
        return passed

    def get_environment_variables(self):
        """Return the process environment as a string."""
        return str(os.environ)

    def get_versions_of_relevant_modules(self):
        """Return a dict of IMP/PMI (and, if importable, ISD2 and
        ISD_EMXL) module versions."""
        versions = {}
        versions["IMP_VERSION"] = IMP.get_module_version()
        versions["PMI_VERSION"] = IMP.pmi.get_module_version()
        try:
            import IMP.isd2
            versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
        except ImportError:
            pass
        try:
            import IMP.isd_emxl
            versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
        except ImportError:
            pass
        return versions

    def init_stat2(self, name, listofobjects, extralabels=None,
                   listofsummedobjects=None):
        """Create a stat file v2 (a compact format keyed by integers).

        The first line is a header dict mapping each integer column index
        to its field name, plus STAT2HEADER, environment and version
        entries.

        @param listofobjects Objects whose get_output() is recorded
               (keys starting with '_' omitted).
        @param extralabels Labels whose values come from set_output_entry().
        @param listofsummedobjects List of ([obj1, obj2, ...], label)
               pairs; the label records the sum of the objects'
               _TotalScore outputs.
        @raise ValueError if an object lacks get_output(), or a summed
               object lacks a _TotalScore entry.
        """
        if listofsummedobjects is None:
            listofsummedobjects = []
        if extralabels is None:
            extralabels = []
        output = {}
        stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
        stat2_keywords.update(
            {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
        stat2_keywords.update(
            {"STAT2HEADER_IMP_VERSIONS":
             str(self.get_versions_of_relevant_modules())})
        stat2_inverse = {}

        for obj in listofobjects:
            if "get_output" not in dir(obj):
                raise ValueError(
                    "Output: object %s doesn't have get_output() method"
                    % str(obj))
            d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")

        # check for customizable entries
        for obj in listofsummedobjects:
            for t in obj[0]:
                if "get_output" not in dir(t):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(t))
                if "_TotalScore" not in t.get_output():
                    raise ValueError(
                        "Output: object %s doesn't have _TotalScore "
                        "entry to be summed" % str(t))
                output.update({obj[1]: 0.0})

        for k in extralabels:
            output.update({k: 0.0})

        for n, k in enumerate(output):
            stat2_keywords.update({n: k})
            stat2_inverse.update({k: n})

        with open(name, 'w') as flstat:
            flstat.write("%s \n" % stat2_keywords)
        self.dictionary_stats2[name] = (
            listofobjects,
            stat2_inverse,
            listofsummedobjects,
            extralabels)

    def write_stat2(self, name, appendmode=True):
        """Write one record to stat file v2 `name`, keyed by the integer
        column indexes defined in its header."""
        output = {}
        (listofobjects, stat2_inverse, listofsummedobjects,
         extralabels) = self.dictionary_stats2[name]

        # writing objects
        for obj in listofobjects:
            od = obj.get_output()
            for k, v in od.items():
                if k[0] != "_":
                    output[stat2_inverse[k]] = v

        # writing summedobjects
        for so in listofsummedobjects:
            partial_score = 0.0
            for t in so[0]:
                d = t.get_output()
                partial_score += float(d["_TotalScore"])
            output[stat2_inverse[so[1]]] = str(partial_score)

        # writing extralabels
        for k in extralabels:
            output[stat2_inverse[k]] = self.initoutput.get(k, "None")

        writeflag = 'a' if appendmode else 'w'
        with open(name, writeflag) as flstat:
            flstat.write("%s \n" % output)

    def write_stats2(self):
        """Append a record to every registered stat file v2."""
        for stat in self.dictionary_stats2.keys():
            self.write_stat2(stat)
922 
923 
class OutputStatistics(object):
    """Frame counters filled in by ProcessOutput.get_fields().

    `total` counts every frame read; each `passed_*` counter records how
    many frames got through the corresponding get_fields() filter.
    """
    def __init__(self):
        self.total = self.passed_get_every = 0
        self.passed_filterout = self.passed_filtertuple = 0
933 
934 
class ProcessOutput(object):
    """A class for reading stat files (either rmf or ascii v1 and v2)"""
    def __init__(self, filename):
        """Detect the format of ``filename`` and read its list of keys.

        The file is first opened as an RMF file; if RMF raises IOError it
        is read as an ascii stat file (v1 or v2) instead.

        @param filename path of the RMF or ascii stat file
        @raise ValueError if no file name is given
        """
        self.filename = filename
        self.isstat1 = False
        self.isstat2 = False
        self.isrmf = False

        if self.filename is None:
            raise ValueError("No file name provided. Use -h for help")

        try:
            # let's see if that is an rmf file
            rh = RMF.open_rmf_file_read_only(self.filename)
            self.isrmf = True
            cat = rh.get_category('stat')
            rmf_klist = rh.get_keys(cat)
            self.rmf_names_keys = dict([(rh.get_name(k), k)
                                        for k in rmf_klist])
            del rh

        except IOError:
            self._read_ascii_keys()

    def _read_ascii_keys(self):
        """Read the keys of an ascii stat file from its first line.

        Sets ``klist`` and ``isstat1``/``isstat2``. For v2 files it also
        builds ``invstat2_dict`` (field name -> column key)."""
        # Only the first line is needed, so iterate the file lazily
        # instead of loading the whole (possibly huge) file in memory.
        with open(self.filename, "r") as f:
            for line in f:
                d = ast.literal_eval(line)
                self.klist = list(d.keys())
                # check if it is a stat2 file
                if "STAT2HEADER" in self.klist:
                    self.isstat2 = True
                    # drop every header-only entry (keys containing
                    # "STAT2HEADER")
                    for k in self.klist:
                        if "STAT2HEADER" in str(k):
                            del d[k]
                    # (column key, field name) pairs sorted by field name
                    items = sorted(d.items(), key=operator.itemgetter(1))
                    self.klist = [name for _, name in items]
                    self.invstat2_dict = {}
                    for column, name in items:
                        self.invstat2_dict[name] = column
                else:
                    IMP.handle_use_deprecated(
                        "statfile v1 is deprecated. "
                        "Please convert to statfile v2.\n")
                    self.isstat1 = True
                    self.klist.sort()
                break

    def get_keys(self):
        """Return the field names available in the file."""
        if self.isrmf:
            return sorted(self.rmf_names_keys.keys())
        else:
            return self.klist

    def show_keys(self, ncolumns=2, truncate=65):
        """Print the field names in ``ncolumns`` columns."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)

    def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
                   statistics=None):
        '''
        Get the desired field names, and return a dictionary.
        Namely, "fields" are the queried keys in the stat file
        (eg. ["Total_Score",...])
        The returned data structure is a dictionary, where each key is
        a field and the value is the time series (ie, frame ordered series)
        of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,,...],....})

        @param fields (list of strings) queried keys in the stat file
               (eg. "Total_Score"....)
        @param filterout specify if you want to "grep" out something from
               the file, so that it is faster (ascii files only)
        @param filtertuple a tuple that contains
               ("TheKeyToBeFiltered",relationship,value)
               where relationship = "<", "==", or ">"; frames whose value
               does not satisfy the relationship are dropped
        @param get_every only read every Nth line from the file
               (ascii files only)
        @param statistics if provided, accumulate statistics in an
               OutputStatistics object
        '''

        if statistics is None:
            statistics = OutputStatistics()
        outdict = {field: [] for field in fields}

        if self.isrmf:
            rh = RMF.open_rmf_file_read_only(self.filename)
            nframes = rh.get_number_of_frames()
            for i in range(nframes):
                statistics.total += 1
                # "get_every" and "filterout" not enforced for RMF
                statistics.passed_get_every += 1
                statistics.passed_filterout += 1
                rh.set_current_frame(RMF.FrameID(i))
                if filtertuple is not None:
                    keytobefiltered = filtertuple[0]
                    relationship = filtertuple[1]
                    value = filtertuple[2]
                    datavalue = rh.get_root_node().get_value(
                        self.rmf_names_keys[keytobefiltered])
                    if self.isfiltered(datavalue, relationship, value):
                        continue

                statistics.passed_filtertuple += 1
                for field in fields:
                    outdict[field].append(rh.get_root_node().get_value(
                        self.rmf_names_keys[field]))

        else:
            with open(self.filename, "r") as f:
                line_number = 0

                # iterate lazily: stat files can be very large
                for line in f:
                    statistics.total += 1
                    if filterout is not None:
                        if filterout in line:
                            continue
                    statistics.passed_filterout += 1
                    line_number += 1

                    if line_number % get_every != 0:
                        if line_number == 1 and self.isstat2:
                            # the stat2 header line is not a frame
                            statistics.total -= 1
                            statistics.passed_filterout -= 1
                        continue
                    statistics.passed_get_every += 1
                    try:
                        d = ast.literal_eval(line)
                    except Exception:
                        print("# Warning: skipped line number "
                              + str(line_number) + " not a valid line")
                        continue

                    if self.isstat1:

                        if filtertuple is not None:
                            keytobefiltered = filtertuple[0]
                            relationship = filtertuple[1]
                            value = filtertuple[2]
                            datavalue = d[keytobefiltered]
                            if self.isfiltered(datavalue, relationship,
                                               value):
                                continue

                        statistics.passed_filtertuple += 1
                        for field in fields:
                            outdict[field].append(d[field])

                    elif self.isstat2:
                        if line_number == 1:
                            # the stat2 header line is not a frame
                            statistics.total -= 1
                            statistics.passed_filterout -= 1
                            statistics.passed_get_every -= 1
                            continue

                        if filtertuple is not None:
                            keytobefiltered = filtertuple[0]
                            relationship = filtertuple[1]
                            value = filtertuple[2]
                            datavalue = d[self.invstat2_dict[keytobefiltered]]
                            if self.isfiltered(datavalue, relationship,
                                               value):
                                continue

                        statistics.passed_filtertuple += 1
                        for field in fields:
                            outdict[field].append(
                                d[self.invstat2_dict[field]])

        return outdict

    def isfiltered(self, datavalue, relationship, refvalue):
        """Return True if ``datavalue`` should be filtered OUT.

        That is the case when ``float(datavalue) <relationship> refvalue``
        does not hold, for relationship "<", ">" or "==". Any other
        relationship never filters.

        @raise ValueError if ``datavalue`` cannot be converted to float
        """
        try:
            fvalue = float(datavalue)
        except ValueError:
            raise ValueError("ProcessOutput.filter: datavalue cannot be "
                             "converted into a float")

        if relationship == "<":
            return fvalue >= refvalue
        if relationship == ">":
            return fvalue <= refvalue
        if relationship == "==":
            return fvalue != refvalue
        return False
1131 
1132 
class RMFHierarchyHandler(IMP.atom.Hierarchy):
    """ class to allow more advanced handling of RMF files.
    It is both a container and an IMP.atom.Hierarchy.
     - it is iterable (while loading the corresponding frame)
     - Item brackets [] load the corresponding frame
     - slicing returns an iterator over the selected frames
     - can relink to another RMF file
    """
    def __init__(self, model, rmf_file_name):
        """
        @param model: the IMP.Model()
        @param rmf_file_name: str, path of the rmf file
        @raise TypeError if the rmf file name has the wrong type
        """
        self.model = model
        # name of the currently linked rmf file, used in messages
        self.rmf_file_name = rmf_file_name
        try:
            self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
        except TypeError:
            raise TypeError("Wrong rmf file name or type: %s"
                            % str(rmf_file_name))
        hs = IMP.rmf.create_hierarchies(self.rh_ref, self.model)
        IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(0))
        self.root_hier_ref = hs[0]
        IMP.atom.Hierarchy.__init__(self, self.root_hier_ref)
        self.model.update()
        self.ColorHierarchy = None

    def link_to_rmf(self, rmf_file_name):
        """
        Link to another RMF file and load its first frame.
        """
        self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
        self.rmf_file_name = rmf_file_name
        IMP.rmf.link_hierarchies(self.rh_ref, [self])
        if self.ColorHierarchy:
            self.ColorHierarchy.method()
        RMFHierarchyHandler.set_frame(self, 0)

    def set_frame(self, index):
        """Load frame ``index`` of the linked rmf file.

        If loading fails, a message is printed instead of raising. The
        model is updated in either case."""
        try:
            IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(index))
        except Exception:
            # self.rmf_file_name always exists, unlike current_rmf, which
            # only subclasses define
            print("skipping frame %s:%d\n" % (self.rmf_file_name, index))
        self.model.update()

    def get_number_of_frames(self):
        """Return the number of frames in the linked rmf file."""
        return self.rh_ref.get_number_of_frames()

    def __getitem__(self, int_slice_adaptor):
        """An int loads that frame and returns it; a slice returns an
        iterator over the selected frames."""
        if isinstance(int_slice_adaptor, int):
            self.set_frame(int_slice_adaptor)
            return int_slice_adaptor
        elif isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return self.get_number_of_frames()

    def __iter__(self, slice_key=None):
        """Load and yield each frame index (optionally restricted to
        ``slice_key``)."""
        if slice_key is None:
            for nframe in range(len(self)):
                yield self[nframe]
        else:
            for nframe in list(range(len(self)))[slice_key]:
                yield self[nframe]
1198 
1199 
class CacheHierarchyCoordinates(object):
    """Per-index cache of the coordinates of a StatHierarchyHandler.

    For each stored index it keeps the rigid body reference frames, the
    internal coordinates of non-rigid members and the coordinates of the
    remaining beads. A stored state can then be restored without reading
    the rmf file again."""
    def __init__(self, StatHierarchyHandler):
        """@param StatHierarchyHandler the hierarchy whose coordinates
               are cached"""
        self.xyzs = []
        self.nrms = []
        self.rbs = []
        self.nrm_coors = {}
        self.xyz_coors = {}
        self.rb_trans = {}
        self.current_index = None
        self.rmfh = StatHierarchyHandler
        rbs, xyzs = IMP.pmi.tools.get_rbs_and_beads([self.rmfh])
        self.model = self.rmfh.get_model()
        self.rbs = rbs
        for xyz in xyzs:
            # non-rigid members are cached via their internal coordinates,
            # all other beads via their plain coordinates
            if IMP.core.NonRigidMember.get_is_setup(xyz):
                self.nrms.append(IMP.core.NonRigidMember(xyz))
            else:
                self.xyzs.append(IMP.core.XYZ(xyz))

    def do_store(self, index):
        """Snapshot the current coordinates under ``index``."""
        self.rb_trans[index] = {}
        self.nrm_coors[index] = {}
        self.xyz_coors[index] = {}
        for rb in self.rbs:
            self.rb_trans[index][rb] = rb.get_reference_frame()
        for nrm in self.nrms:
            self.nrm_coors[index][nrm] = nrm.get_internal_coordinates()
        for xyz in self.xyzs:
            self.xyz_coors[index][xyz] = xyz.get_coordinates()
        self.current_index = index

    def do_update(self, index):
        """Restore the snapshot stored under ``index``.

        Nothing is done if it is already the current one.
        @raise KeyError if ``index`` was never stored
        """
        if self.current_index != index:
            for rb in self.rbs:
                rb.set_reference_frame(self.rb_trans[index][rb])
            for nrm in self.nrms:
                nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
            for xyz in self.xyzs:
                xyz.set_coordinates(self.xyz_coors[index][xyz])
            self.current_index = index
            self.model.update()

    def get_number_of_frames(self):
        """Return the number of stored snapshots."""
        return len(self.rb_trans)

    def __getitem__(self, index):
        """Return True if a snapshot is stored for integer ``index``."""
        if isinstance(index, int):
            return index in self.rb_trans
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return self.get_number_of_frames()
1255 
1256 
class StatHierarchyHandler(RMFHierarchyHandler):
    """ class to link stat files to several rmf files """
    def __init__(self, model=None, stat_file=None,
                 number_best_scoring_models=None, score_key=None,
                 StatHierarchyHandler=None, cache=None):
        """

        @param model: IMP.Model()
        @param stat_file: a single file name (str) or a list of file names.
               Each file may be an ascii stat file, an rmf file, or a
               pickle written by save_data() (a list of
               IMP.pmi.output.DataEntry or IMP.pmi.output.Cluster).
               Values of any other type are ignored.
        @param number_best_scoring_models: if set, only models scoring at
               or below the N-th lowest score are kept
        @param score_key: stat file key holding the score
               (default "Total_Score")
        @param StatHierarchyHandler: copy constructor input object; when
               given, all other arguments are ignored
        @param cache: cache coordinates and rigid body transformations.
        """

        if StatHierarchyHandler is not None:
            # overrides all other arguments
            # copy constructor: create a copy with
            # different RMFHierarchyHandler
            self.model = StatHierarchyHandler.model
            # shallow copy, so adding files to the copy does not also
            # append to the source's list
            self.data = list(StatHierarchyHandler.data)
            self.number_best_scoring_models = \
                StatHierarchyHandler.number_best_scoring_models
            self.is_setup = True
            self.current_rmf = StatHierarchyHandler.current_rmf
            self.current_frame = None
            self.current_index = None
            self.score_threshold = StatHierarchyHandler.score_threshold
            self.score_key = StatHierarchyHandler.score_key
            self.cache = StatHierarchyHandler.cache
            RMFHierarchyHandler.__init__(self, self.model,
                                         self.current_rmf)
            # the source's cache is tied to the source hierarchy:
            # build a fresh one for this copy
            if self.cache:
                self.cache = CacheHierarchyCoordinates(self)
            else:
                self.cache = None
            self.set_frame(0)

        else:
            # standard constructor
            self.model = model
            self.data = []
            self.number_best_scoring_models = number_best_scoring_models
            self.cache = cache

            if score_key is None:
                self.score_key = "Total_Score"
            else:
                self.score_key = score_key
            self.is_setup = None
            self.current_rmf = None
            self.current_frame = None
            self.current_index = None
            self.score_threshold = None

            if isinstance(stat_file, str):
                self.add_stat_file(stat_file)
            elif isinstance(stat_file, list):
                for f in stat_file:
                    self.add_stat_file(f)

    def _keep_best_scoring_models(self):
        """Drop models scoring worse than the N-th best score, where N is
        number_best_scoring_models. No-op if N is unset or there is no
        data."""
        if self.number_best_scoring_models and self.data:
            nbest = min(len(self), self.number_best_scoring_models)
            max_score = sorted(self.get_scores())[nbest - 1]
            self.do_filter_by_score(max_score)

    def add_stat_file(self, stat_file):
        """Add the models described by ``stat_file``, then load model 0.

        The file is tried first as a pickle (see load_data(), which
        REPLACES the current data), then as an ascii stat file, then as an
        rmf file without stat information (all scores 0.0). Unreadable
        files and empty stat files are skipped.

        @raise ValueError if an ascii stat file refers to more than one
               rmf file
        """
        try:
            # check that it is not a pickle file with saved data
            # from a previous calculation
            self.load_data(stat_file)
            self._keep_best_scoring_models()

        except pickle.UnpicklingError:
            # alternatively read the ascii stat files
            try:
                scores, rmf_files, rmf_frame_indexes, features = \
                    self.get_info_from_stat_file(stat_file,
                                                 self.score_threshold)
            except (KeyError, SyntaxError):
                # in this case check that is it an rmf file, probably
                # without stat stored in
                try:
                    # let's see if that is an rmf file
                    rh = RMF.open_rmf_file_read_only(stat_file)
                    nframes = rh.get_number_of_frames()
                    scores = [0.0]*nframes
                    rmf_files = [stat_file]*nframes
                    rmf_frame_indexes = range(nframes)
                    features = {}
                except Exception:
                    # not readable as an rmf file either: ignore it
                    return

            if len(set(rmf_files)) > 1:
                raise ValueError("Multiple RMF files found")

            if not rmf_files:
                print("StatHierarchyHandler: Error: Trying to set none as "
                      "rmf_file (probably empty stat file), aborting")
                return

            for n, index in enumerate(rmf_frame_indexes):
                featn_dict = {k: features[k][n] for k in features}
                self.data.append(IMP.pmi.output.DataEntry(
                    stat_file, rmf_files[n], index, scores[n], featn_dict))

            self._keep_best_scoring_models()

        if not self.is_setup:
            RMFHierarchyHandler.__init__(
                self, self.model, self.get_rmf_names()[0])
            if self.cache:
                self.cache = CacheHierarchyCoordinates(self)
            else:
                self.cache = None
            # current_rmf must name the rmf file actually linked. After
            # setup, set_frame() keeps it in sync, so it is not reset here
            # (resetting it would skip relinking and read the wrong file).
            self.current_rmf = self.get_rmf_names()[0]
        self.is_setup = True

        self.set_frame(0)

    def save_data(self, filename='data.pkl'):
        """Pickle self.data (the list of DataEntry objects) to
        ``filename``."""
        with open(filename, 'wb') as fl:
            pickle.dump(self.data, fl)

    def load_data(self, filename='data.pkl'):
        """Replace self.data with the content of the pickle ``filename``.

        The pickle must hold a list of IMP.pmi.output.DataEntry, or a list
        of IMP.pmi.output.Cluster. In the latter case each member's data
        is placed at its member index.

        @raise TypeError if the pickle holds anything else
        """
        # NOTE: unpickling can execute arbitrary code; only load files
        # from trusted sources.
        with open(filename, 'rb') as fl:
            data_structure = pickle.load(fl)
        message = ("%s should contain a list of IMP.pmi.output.DataEntry "
                   "or IMP.pmi.output.Cluster" % filename)
        # first check that it is a list
        if not isinstance(data_structure, list):
            raise TypeError(message)
        # second check the types
        if all(isinstance(item, IMP.pmi.output.DataEntry)
               for item in data_structure):
            self.data = data_structure
        elif all(isinstance(item, IMP.pmi.output.Cluster)
                 for item in data_structure):
            nmodels = 0
            for cluster in data_structure:
                nmodels += len(cluster)
            self.data = [None]*nmodels
            for cluster in data_structure:
                for n, data in enumerate(cluster):
                    index = cluster.members[n]
                    self.data[index] = data
        else:
            raise TypeError(message)

    def set_frame(self, index):
        """Load model ``index`` of self.data into the hierarchy.

        A cached snapshot is used when available. Otherwise the model's rmf
        file is linked (if it differs from the current one), its frame is
        loaded (if different from the current one), and the coordinates
        are stored in the cache when caching is enabled."""
        if self.cache is not None and self.cache[index]:
            self.cache.do_update(index)
        else:
            nm = self.data[index].rmf_name
            fidx = self.data[index].rmf_index
            if nm != self.current_rmf:
                self.link_to_rmf(nm)
                self.current_rmf = nm
                self.current_frame = -1
            if fidx != self.current_frame:
                RMFHierarchyHandler.set_frame(self, fidx)
                self.current_frame = fidx
            if self.cache is not None:
                self.cache.do_store(index)

        self.current_index = index

    def __getitem__(self, int_slice_adaptor):
        """An int loads that model and returns its DataEntry; a slice
        returns an iterator over the selected models."""
        if isinstance(int_slice_adaptor, int):
            self.set_frame(int_slice_adaptor)
            return self.data[int_slice_adaptor]
        elif isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.data)

    def __iter__(self, slice_key=None):
        """Load and yield each model's DataEntry (optionally restricted
        to ``slice_key``)."""
        if slice_key is None:
            for i in range(len(self)):
                yield self[i]
        else:
            for i in range(len(self))[slice_key]:
                yield self[i]

    def do_filter_by_score(self, maximum_score):
        """Keep only the models with score <= ``maximum_score``."""
        self.data = [d for d in self.data if d.score <= maximum_score]

    def get_scores(self):
        return [d.score for d in self.data]

    def get_feature_series(self, feature_name):
        return [d.features[feature_name] for d in self.data]

    def get_feature_names(self):
        return self.data[0].features.keys()

    def get_rmf_names(self):
        return [d.rmf_name for d in self.data]

    def get_stat_files_names(self):
        return [d.stat_file for d in self.data]

    def get_rmf_indexes(self):
        return [d.rmf_index for d in self.data]

    def get_info_from_stat_file(self, stat_file, score_threshold=None):
        """Read scores, rmf file names, rmf frame indexes and all other
        features from the ascii stat file ``stat_file``."""
        po = ProcessOutput(stat_file)
        fs = po.get_keys()
        models = IMP.pmi.io.get_best_models(
            [stat_file], score_key=self.score_key, feature_keys=fs,
            rmf_file_key="rmf_file", rmf_file_frame_key="rmf_frame_index",
            prefiltervalue=score_threshold, get_every=1)

        scores = [float(y) for y in models[2]]
        rmf_files = models[0]
        rmf_frame_indexes = models[1]
        features = models[3]
        return scores, rmf_files, rmf_frame_indexes, features
1486 
1487 
class DataEntry(object):
    '''
    A class to store data associated to a model: the stat file it was
    read from, the rmf file and frame index holding its coordinates, its
    score and a dict of additional features.
    '''
    def __init__(self, stat_file=None, rmf_name=None, rmf_index=None,
                 score=None, features=None):
        self.rmf_name = rmf_name
        self.rmf_index = rmf_index
        self.score = score
        self.features = features
        self.stat_file = stat_file

    def __repr__(self):
        # features defaults to None; report 0 features instead of crashing
        nfeatures = 0 if self.features is None else len(self.features)
        s = "IMP.pmi.output.DataEntry\n"
        s += "---- stat file %s \n" % (self.stat_file)
        s += "---- rmf file %s \n" % (self.rmf_name)
        s += "---- rmf index %s \n" % (str(self.rmf_index))
        s += "---- score %s \n" % (str(self.score))
        s += "---- number of features %s \n" % (str(nfeatures))
        return s
1508 
1509 
class Cluster(object):
    '''
    A container for models organized into clusters.

    ``members`` holds model indexes in insertion order, and
    ``members_data`` maps each index to its data object (e.g. an
    IMP.pmi.output.DataEntry). Indexing and iteration yield the data
    objects, not the indexes.
    '''
    def __init__(self, cid=None):
        self.cluster_id = cid
        self.members = []
        self.precision = None
        self.center_index = None
        self.members_data = {}
        # defined up front so that repr() works on an empty cluster
        self.average_score = None

    def add_member(self, index, data=None):
        """Append model ``index`` (with optional ``data``) and update
        the average score."""
        self.members.append(index)
        self.members_data[index] = data
        self.average_score = self.compute_score()

    def compute_score(self):
        """Return the mean ``score`` of the members' data.

        Returns None if the cluster is empty or if any member lacks a
        numeric score."""
        if not self.members:
            return None
        try:
            return sum(d.score for d in self) / len(self)
        except (AttributeError, TypeError):
            # member data is None, or a score is not a number
            return None

    def __repr__(self):
        # getattr: instances restored without __init__ (e.g. unpickled)
        # may lack the attribute
        average_score = getattr(self, "average_score", None)
        s = "IMP.pmi.output.Cluster\n"
        s += "---- cluster_id %s \n" % str(self.cluster_id)
        s += "---- precision %s \n" % str(self.precision)
        s += "---- average score %s \n" % str(average_score)
        s += "---- number of members %s \n" % str(len(self.members))
        s += "---- center index %s \n" % str(self.center_index)
        return s

    def __getitem__(self, int_slice_adaptor):
        """An int returns the data of the n-th member; a slice returns an
        iterator over the selected members' data."""
        if isinstance(int_slice_adaptor, int):
            index = self.members[int_slice_adaptor]
            return self.members_data[index]
        elif isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.members)

    def __iter__(self, slice_key=None):
        if slice_key is None:
            for i in range(len(self)):
                yield self[i]
        else:
            for i in range(len(self))[slice_key]:
                yield self[i]

    def __add__(self, other):
        """Merge ``other`` into this cluster IN PLACE and return self.

        Note that ``a + b`` therefore modifies ``a``. Precision and center
        index are reset, since they no longer apply."""
        self.members += other.members
        self.members_data.update(other.members_data)
        self.average_score = self.compute_score()
        self.precision = None
        self.center_index = None
        return self
1569 
1570 
def plot_clusters_populations(clusters):
    """Show a bar chart of the number of members of each cluster.

    @param clusters iterable of Cluster objects; each bar is placed at
           the cluster's ``cluster_id``
    """
    cluster_ids, sizes = [], []
    for cl in clusters:
        cluster_ids.append(cl.cluster_id)
        sizes.append(len(cl))

    import matplotlib.pyplot as plt
    _, axis = plt.subplots()
    axis.bar(cluster_ids, sizes, 0.5, color='r')
    axis.set_ylabel('Population')
    axis.set_xlabel('Cluster index')
    plt.show()
1584 
1585 
def plot_clusters_precisions(clusters):
    """Show a bar chart of the precision of each cluster.

    The id and raw precision of every cluster are printed. Clusters whose
    precision is None are drawn with height 0.
    """
    cluster_ids = []
    heights = []
    for cl in clusters:
        cluster_ids.append(cl.cluster_id)
        raw = cl.precision
        print(cl.cluster_id, raw)
        heights.append(0.0 if raw is None else raw)

    import matplotlib.pyplot as plt
    _, axis = plt.subplots()
    axis.bar(cluster_ids, heights, 0.5, color='r')
    axis.set_ylabel('Precision [A]')
    axis.set_xlabel('Cluster index')
    plt.show()
1604 
1605 
def plot_clusters_scores(clusters):
    """Save box plots of the model scores in each cluster to scores.pdf.

    plot_fields_box_plots() appends the ".pdf" extension itself, so only
    the base name is passed here.
    """
    indexes = []
    values = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        values.append([data.score for data in cluster])

    plot_fields_box_plots("scores", values, indexes, frequencies=None,
                          valuename="Scores", positionname="Cluster index",
                          xlabels=None, scale_plot_length=1.0)
1618 
1619 
class CrossLinkIdentifierDatabase(object):
    """Table of crosslink attributes, indexed by a crosslink key.

    Each key maps to a dict of named fields. The setters convert their
    value (to str, int or float) before storing it. The getters return the
    stored value and raise KeyError if the key or field is missing. The
    whole table can be saved to, and restored from, a pickle file."""

    def __init__(self):
        self.clidb = {}

    def check_key(self, key):
        """Create an empty record for ``key`` if there is none yet."""
        self.clidb.setdefault(key, {})

    def _store(self, key, field, value, convert):
        # the record is created before conversion is attempted
        self.check_key(key)
        self.clidb[key][field] = convert(value)

    def _fetch(self, key, field):
        return self.clidb[key][field]

    def set_unique_id(self, key, value):
        self._store(key, "XLUniqueID", value, str)

    def set_protein1(self, key, value):
        self._store(key, "Protein1", value, str)

    def set_protein2(self, key, value):
        self._store(key, "Protein2", value, str)

    def set_residue1(self, key, value):
        self._store(key, "Residue1", value, int)

    def set_residue2(self, key, value):
        self._store(key, "Residue2", value, int)

    def set_idscore(self, key, value):
        self._store(key, "IDScore", value, float)

    def set_state(self, key, value):
        self._store(key, "State", value, int)

    def set_sigma1(self, key, value):
        self._store(key, "Sigma1", value, str)

    def set_sigma2(self, key, value):
        self._store(key, "Sigma2", value, str)

    def set_psi(self, key, value):
        self._store(key, "Psi", value, str)

    def get_unique_id(self, key):
        return self._fetch(key, "XLUniqueID")

    def get_protein1(self, key):
        return self._fetch(key, "Protein1")

    def get_protein2(self, key):
        return self._fetch(key, "Protein2")

    def get_residue1(self, key):
        return self._fetch(key, "Residue1")

    def get_residue2(self, key):
        return self._fetch(key, "Residue2")

    def get_idscore(self, key):
        return self._fetch(key, "IDScore")

    def get_state(self, key):
        return self._fetch(key, "State")

    def get_sigma1(self, key):
        return self._fetch(key, "Sigma1")

    def get_sigma2(self, key):
        return self._fetch(key, "Sigma2")

    def get_psi(self, key):
        return self._fetch(key, "Psi")

    def set_float_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, float)

    def set_int_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, int)

    def set_string_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, str)

    def get_feature(self, key, feature_name):
        return self._fetch(key, feature_name)

    def write(self, filename):
        """Pickle the whole table to ``filename``."""
        with open(filename, 'wb') as handle:
            pickle.dump(self.clidb, handle)

    def load(self, filename):
        """Replace the table with the one pickled in ``filename``."""
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # files.
        with open(filename, 'rb') as handle:
            self.clidb = pickle.load(handle)
1720 
1721 
def plot_fields(fields, output, framemin=None, framemax=None):
    """Plot the given fields and save a figure as `output`.
    The fields generally are extracted from a stat file
    using ProcessOutput.get_fields().

    @param fields dict mapping field name -> list of values (convertible
           to float); one subplot is drawn per field, in dict order
    @param output file name passed to matplotlib's savefig
    @param framemin first frame to plot (default 0)
    @param framemax frame at which plotting stops (exclusive); by default
           each field is plotted to the end of its own series
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    # squeeze=False always gives a 2D array of axes, so a single field and
    # several fields go through the same code path
    fig, axs = plt.subplots(nrows=len(fields), squeeze=False)
    fig.set_size_inches(10.5, 5.5 * len(fields))

    start = 0 if framemin is None else framemin
    for ax, key in zip(axs[:, 0], fields):
        # resolve the upper bound per field, so that fields of different
        # lengths are not all cut to the first field's length
        stop = len(fields[key]) if framemax is None else framemax
        y = [float(v) for v in fields[key][start:stop]]
        # x follows the data actually plotted, keeping x and y the same
        # length even when framemax exceeds the series length
        x = list(range(start, start + len(y)))
        ax.plot(x, y)
        ax.set_title(key, size="xx-large")
        ax.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(output)
1756 
1757 
def plot_field_histogram(name, values_lists, valuename=None, bins=40,
                         colors=None, format="png", reference_xline=None,
                         yplotrange=None, xplotrange=None, normalized=True,
                         leg_names=None):
    '''Plot a list of histograms from a value list.
    @param name the name of the plot; the figure is saved as
           ``name + "." + format``
    @param values_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename the x-label (defaults to ``name``)
    @param bins the number of bins
    @param colors If None, will use rainbow. Else will use specific list
    @param format output format
    @param reference_xline plot a reference line parallel to the y-axis
    @param yplotrange the range for the y-axis, as (ymin, ymax)
    @param xplotrange the range for the x-axis, as (xmin, xmax)
    @param normalized whether the histogram is normalized or not
    @param leg_names names for the legend
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        if leg_names is not None:
            label = leg_names[nv]
        else:
            label = str(nv)
        try:
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                density=normalized, histtype='step', lw=4, label=label)
        except AttributeError:
            # presumably for older matplotlib that only accepts `normed`
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                normed=normalized, histtype='step', lw=4, label=label)

    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
1821 
1822 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None",
                          xlabels=None, scale_plot_length=1.0):
    '''
    Save box plots of several series to ``name + ".pdf"`` and show them.

    @param name window title and base name of the output file
    @param values list of series (one box per series)
    @param positions x position of each box
    @param frequencies only used as a switch: when not None, every
           individual value is also drawn as a green cross
    @param valuename the y-label
    @param positionname the x-label
    @param xlabels optional tick labels for the boxes
    @param scale_plot_length figure width per box
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    width = float(len(positions)) * scale_plot_length
    fig = plt.figure(figsize=(width, 5.0))
    fig.canvas.manager.set_window_title(name)

    axis = fig.add_subplot(111)

    plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)

    box = plt.boxplot(values, notch=0, sym='', vert=1,
                      whis=1.5, positions=positions)
    plt.setp(box['boxes'], color='black', lw=1.5)
    plt.setp(box['whiskers'], color='black', ls=":", lw=1.5)

    if frequencies is not None:
        for pos, series in zip(positions, values):
            axis.plot([pos] * len(series), series, 'gx',
                      alpha=0.7, markersize=7)

    if xlabels is not None:
        axis.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name + ".pdf", dpi=150)
    plt.show()
1864 
1865 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Draw ``y`` against ``x`` as a red line.

    @param title window and axes title (optional)
    @param out_fn if given, the figure is saved to ``out_fn + ".pdf"``
    @param display call plt.show() when True
    @param set_plot_yaxis_range (ymin, ymax) for the y axis (optional)
    @param xlabel x-axis label (optional)
    @param ylabel y-axis label (optional)

    The figure is closed before returning.
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, axis = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.manager.set_window_title(title)

    axis.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        xlo, xhi, _, _ = plt.axis()
        plt.axis((xlo, xhi,
                  set_plot_yaxis_range[0], set_plot_yaxis_range[1]))
    for text, setter in ((title, axis.set_title),
                         (xlabel, axis.set_xlabel),
                         (ylabel, axis.set_ylabel)):
        if text is not None:
            setter(text)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
1895 
1896 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot ``y`` against ``x`` as small black circles.

    @param labelx x-axis label
    @param labely y-axis label
    @param xmin, xmax x-axis limits (applied only if both are given)
    @param ymin, ymax y-axis limits (applied only if both are given)
    @param savefile if True, save the figure to ``filename`` at 300 dpi
    @param filename output file name
    @param alpha marker transparency
    """

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, axs0 = plt.subplots(1)

    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    # Give a single colour: recent matplotlib versions raise TypeError
    # when both `color` and its alias `c` are passed.
    axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1, alpha=alpha)

    axs0.legend(
        loc=0,
        frameon=False,
        scatterpoints=1,
        numpoints=1,
        columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if (ymin is not None) and (ymax is not None):
        axs0.set_ylim(ymin, ymax)
    if (xmin is not None) and (xmax is not None):
        axs0.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
1936 
1937 
def get_graph_from_hierarchy(hier):
    """Draw the tree of ``hier`` with draw_graph().

    Edges and node depths are collected by recursive_graph(). Only nodes
    at depth < 3 keep a visible label; deeper nodes get an empty label.
    """
    graph, _, depth_dict = recursive_graph(hier, [], 0, {})

    # filters node labels according to depth_dict
    node_labels_dict = {node: (node if level < 3 else "")
                        for node, level in depth_dict.items()}
    draw_graph(graph, labels_dict=node_labels_dict)
1953 
1954 
def recursive_graph(hier, graph, depth, depth_dict):
    """Recursively collect the parent->child edges of a hierarchy.

    Each node is named "<name>|#<particle index>", and its depth (the
    input depth + 1) is stored in ``depth_dict``. A node with exactly one
    child is treated as a leaf: that child is neither visited nor
    connected.

    @return (graph, depth, depth_dict), with ``depth`` equal to the value
            passed in
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # test for None before calling len(), which would fail on None
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)

    for c in children:
        (graph, depth, depth_dict) = recursive_graph(
            c, graph, depth, depth_dict)
        nameh = IMP.atom.Hierarchy(c).get_name()
        index = str(c.get_particle().get_index())
        namec = nameh + "|#" + index
        graph.append((name1, namec))

    depth = depth - 1
    return (graph, depth, depth_dict)
1979 
1980 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw a graph given as a list of (node, node) edges.

    The graph is drawn with networkx/matplotlib (Agg backend). If
    out_filename is given, the figure is saved there. The graph, including
    GML "graphics"/"LabelGraphics" attributes, is always written to
    'out.gml' in the current directory.

    @param graph sequence of edges; edge[0] and edge[1] are the node names
    @param labels_dict optional dict node -> label passed to networkx
    @param graph_layout 'spring', 'spectral', 'random', or anything else
           for a shell layout
    @param node_size a single size for all nodes, a sequence of sizes in
           graph node order, or a dict node -> value (drawn with size
           sqrt(value)/pi*10)
    @param node_color None (black) or a dict node -> hex color string
           without the leading '#'
    @param edge_thickness a single width, or a list of per-edge weights
           matching graph
    @param validation_edges optional sequence of (node, node) pairs. Pairs
           already in the graph are colored green in the GML output; missing
           pairs are added and colored red.
    @param edge_text_pos currently unused
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges, optionally weighted
    if isinstance(edge_thickness, list):
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    nodes = list(G.nodes())

    # Node colors: RGB fractions for matplotlib, hex strings for GML.
    if node_color is None:
        node_color_rgb = (0, 0, 0)
        # one full hex code per node (indexing a bare "000000" string
        # per node would yield single characters)
        node_color_hex = ["000000"] * len(nodes)
    else:
        cc = IMP.pmi.tools.ColorChange()
        node_color_rgb = []
        node_color_hex = []
        for node in nodes:
            cctuple = cc.rgb(node_color[node])
            node_color_rgb.append((cctuple[0] / 255,
                                   cctuple[1] / 255,
                                   cctuple[2] / 255))
            node_color_hex.append(node_color[node])

    # get node sizes if dictionary
    if isinstance(node_size, dict):
        node_size = [sqrt(node_size[node]) / pi * 10.0 for node in nodes]
    # A scalar size applies to every node in the GML attributes.
    if np.ndim(node_size) == 0:
        node_sizes = [node_size] * len(nodes)
    else:
        node_sizes = list(node_size)

    # Keyword arguments work with both the networkx 1.x (G, name, values)
    # and the networkx >= 2 (G, values, name) signatures.
    nx.set_node_attributes(
        G, name="graphics",
        values={node: {'type': 'ellipse', 'w': size, 'h': size,
                       'fill': '#' + color, 'label': node}
                for node, color, size in zip(nodes, node_color_hex,
                                             node_sizes)})
    nx.set_node_attributes(
        G, name="LabelGraphics",
        values={node: {'type': 'text', 'text': node, 'color': '#000000',
                       'visible': 'true'}
                for node in nodes})
    nx.set_edge_attributes(
        G, name="graphics",
        values={edge: {'width': 1, 'fill': '#000000'}
                for edge in G.edges()})

    for ve in validation_edges or ():
        print(ve)
        # tuples, so that the edge can be used as a dict key even if the
        # caller passed lists
        forward = (ve[0], ve[1])
        backward = (ve[1], ve[0])
        if forward in G.edges():
            print("found forward")
            key, fill = forward, '#00FF00'
        elif backward in G.edges():
            print("found backward")
            key, fill = backward, '#00FF00'
        else:
            G.add_edge(ve[0], ve[1])
            print("not found")
            key, fill = forward, '#FF0000'
        nx.set_edge_attributes(
            G, name="graphics", values={key: {'width': 1, 'fill': fill}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0/8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
    nx.write_gml(G, 'out.gml')
    plt.show()
2092 
2093 
def draw_table():
    """Render an example d3js table in IPython (demonstration only).

    Builds a table of example numbers with twelve month rows and eight
    product columns (with two column groups), renders it with ipyD3, and
    shows the result with IPython.display. This requires the optional
    ipyD3 and IPython packages.
    """

    # still an example!

    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  desc='An example table created with d3js with '
                  'data generated with Python.')
    data = [[1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
             4441.0, 4670.0, 944.0, 110.0],
            [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
             7078.0, 6561.0, 2354.0, 710.0],
            [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
             5661.0, 4991.0, 2032.0, 680.0],
            [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
             4727.0, 3428.0, 1559.0, 620.0],
            [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
             9552.0, 13709.0, 2460.0, 720.0],
            [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
             32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
            [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
             5115.0, 3510.0, 1390.0, 170.0],
            [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
             360.0, 156.0, 100.0]]
    # Transpose the 8 lists of 12 values into 12 rows (months) of
    # 8 columns (products).
    data = [list(i) for i in zip(*data)]
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              'December']]
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: Residue.h:158
A container for models organized into clusters.
Definition: output.py:1510
A class for reading stat files (either rmf or ascii v1 and v2)
Definition: output.py:935
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: atom/Atom.h:245
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:1758
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:1823
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
Miscellaneous utilities.
Definition: tools.py:1
A class to store data associated to a model.
Definition: output.py:1488
void handle_use_deprecated(std::string message)
Break in this method in gdb to find deprecated uses at runtime.
std::string get_module_version()
Return the version of this module, as a string.
Change a color code from hexadecimal to RGB.
Definition: tools.py:749
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
Definition: output.py:924
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: XYZR.h:47
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:1000
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
Definition: Fragment.h:46
def link_to_rmf
Link to another RMF file.
Definition: output.py:1159
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
Definition: pmi/utilities.h:85
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
Definition: output.py:261
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Definition: atom/Atom.h:238
Base class for capturing a modeling protocol.
Definition: output.py:45
The type for a residue.
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
A base class for Keys.
Definition: Key.h:45
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:203
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Definition: rigid_bodies.h:768
Display a segment connecting a pair of particles.
Definition: XYZR.h:170
A decorator for a residue.
Definition: Residue.h:137
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
Definition: output.py:230
def get_prot_name_from_particle
Get the protein name from the particle.
Definition: output.py:347
Class to link stat files to several RMF files.
Definition: output.py:1257
Class to allow more advanced handling of RMF files.
Definition: output.py:1133
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
Definition: output.py:1722
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Definition: Chain.h:61
Python classes to represent, score, sample and analyze models.
def init_pdb_best_scoring
Prepare for writing best-scoring PDBs (or mmCIFs) for a sampling run.
Definition: output.py:468
Functionality for loading, creating, manipulating and scoring atomic structures.
def get_rbs_and_beads
Returns unique objects in original order.
Definition: tools.py:1151
Select hierarchy particles identified by the biological name.
Definition: Selection.h:70
def init_rmf
Initialize an RMF file.
Definition: output.py:565
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: tools.py:512
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: rigid_bodies.h:770
std::string get_module_version()
Return the version of this module, as a string.
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: tools.py:587
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27