IMP logo
IMP Reference Guide  develop.23016263b1,2026/04/24
The Integrative Modeling Platform
output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 import IMP
6 import IMP.atom
7 import IMP.core
8 import IMP.pmi
9 import IMP.pmi.tools
10 import IMP.pmi.io
11 import os
12 import sys
13 import ast
14 import RMF
15 import numpy as np
16 import operator
17 import itertools
18 import warnings
19 import string
20 import ihm.format
21 import collections
22 import pickle
23 
24 
25 class _ChainIDs:
26  """Map indices to multi-character chain IDs.
27  We label the first 26 chains A-Z, then we move to two-letter
28  chain IDs: AA through AZ, then BA through BZ, through to ZZ.
29  This continues with longer chain IDs."""
30  def __getitem__(self, ind):
31  chars = string.ascii_uppercase
32  lc = len(chars)
33  ids = []
34  while ind >= lc:
35  ids.append(chars[ind % lc])
36  ind = ind // lc - 1
37  ids.append(chars[ind])
38  return "".join(reversed(ids))
39 
40 
class ProtocolOutput:
    """Base class for capturing a modeling protocol.
    Unlike simple output of model coordinates, a complete
    protocol includes the input data used, details on the restraints,
    sampling, and clustering, as well as output models.
    Use via IMP.pmi.topology.System.add_protocol_output().

    @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
         mmCIF files.
    """
    # NOTE(review): the `class` header line was lost in extraction; the
    # name ProtocolOutput is reconstructed from the @see reference above
    # — confirm against upstream.
    pass
52 
53 
54 def _flatten(seq):
55  for elt in seq:
56  if isinstance(elt, (tuple, list)):
57  for elt2 in _flatten(elt):
58  yield elt2
59  else:
60  yield elt
61 
62 
63 def _disambiguate_chain(chid, seen_chains):
64  """Make sure that the chain ID is unique; warn and correct if it isn't"""
65  # Handle null chain IDs
66  if chid == '\0':
67  chid = ' '
68 
69  if chid in seen_chains:
70  warnings.warn("Duplicate chain ID '%s' encountered" % chid,
72 
73  for suffix in itertools.count(1):
74  new_chid = chid + "%d" % suffix
75  if new_chid not in seen_chains:
76  seen_chains.add(new_chid)
77  return new_chid
78  seen_chains.add(chid)
79  return chid
80 
81 
82 def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
83  write_all_residues_per_bead):
84  for n, tupl in enumerate(particle_infos_for_pdb):
85  (xyz, atom_type, residue_type,
86  chain_id, residue_index, all_indexes, radius) = tupl
87  if atom_type is None:
88  atom_type = IMP.atom.AT_CA
89  if write_all_residues_per_bead and all_indexes is not None:
90  for residue_number in all_indexes:
91  flpdb.write(
92  IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
93  xyz[1] - geometric_center[1],
94  xyz[2] - geometric_center[2]),
95  n+1, atom_type, residue_type,
96  chain_id[:1], residue_number, ' ',
97  1.00, radius))
98  else:
99  flpdb.write(
100  IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
101  xyz[1] - geometric_center[1],
102  xyz[2] - geometric_center[2]),
103  n+1, atom_type, residue_type,
104  chain_id[:1], residue_index, ' ',
105  1.00, radius))
106  flpdb.write("ENDMDL\n")
107 
108 
# Minimal bookkeeping records used when writing mmCIF output:
# one _Entity per unique sequence, one _ChainInfo per chain ID.
_Entity = collections.namedtuple('_Entity', ('id', 'seq'))
_ChainInfo = collections.namedtuple('_ChainInfo', ('entity', 'name'))
111 
112 
def _get_chain_info(chains, root_hier):
    """Collect chain and entity bookkeeping for mmCIF output.

    @param chains dict mapping molecule name to chain ID
    @param root_hier the top-level Hierarchy to inspect
    @return (dict mapping chain ID to _ChainInfo,
             list of all unique _Entity objects)
    """
    chain_info = {}
    entities = {}       # sequence -> _Entity (one entity per unique seq)
    all_entities = []
    for mol in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
        # Restored line (dropped in extraction): molname was undefined.
        # NOTE(review): reconstructed as the PMI molecule name used as key
        # into `chains` — confirm against upstream.
        molname = IMP.pmi.get_molecule_name_and_copy(mol)
        chain_id = chains[molname]
        chain = IMP.atom.Chain(mol)
        seq = chain.get_sequence()
        if seq not in entities:
            entities[seq] = e = _Entity(id=len(entities)+1, seq=seq)
            all_entities.append(e)
        entity = entities[seq]
        info = _ChainInfo(entity=entity, name=molname)
        chain_info[chain_id] = info
    return chain_info, all_entities
129 
130 
def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
                          write_all_residues_per_bead, chains, root_hier):
    """Write one model's particle records in mmCIF format to flpdb.

    Coordinates are shifted by -geometric_center. If
    write_all_residues_per_bead is set, one atom_site row is written per
    residue covered by each bead, otherwise one row per particle.
    """
    # get dict with keys=chain IDs, values=chain info
    chain_info, entities = _get_chain_info(chains, root_hier)

    writer = ihm.format.CifWriter(flpdb)
    writer.start_block('model')
    with writer.category("_entry") as lp:
        lp.write(id='model')

    with writer.loop("_entity", ["id", "type"]) as lp:
        for e in entities:
            lp.write(id=e.id, type="polymer")

    with writer.loop("_entity_poly",
                     ["entity_id", "pdbx_seq_one_letter_code"]) as lp:
        for e in entities:
            lp.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)

    with writer.loop("_struct_asym", ["id", "entity_id", "details"]) as lp:
        # Longer chain IDs (e.g. AA) should always come after shorter (e.g. Z)
        for chid in sorted(chains.values(), key=lambda x: (len(x.strip()), x)):
            ci = chain_info[chid]
            lp.write(id=chid, entity_id=ci.entity.id, details=ci.name)

    with writer.loop("_atom_site",
                     ["group_PDB", "type_symbol", "label_atom_id",
                      "label_comp_id", "label_asym_id", "label_seq_id",
                      "auth_seq_id",
                      "Cartn_x", "Cartn_y", "Cartn_z", "label_entity_id",
                      "pdbx_pdb_model_num",
                      "id"]) as lp:
        ordinal = 1
        for tupl in particle_infos_for_pdb:
            (xyz, atom_type, residue_type,
             chain_id, residue_index, all_indexes, radius) = tupl
            ci = chain_info[chain_id]
            if atom_type is None:
                atom_type = IMP.atom.AT_CA
            c = (xyz[0] - geometric_center[0],
                 xyz[1] - geometric_center[1],
                 xyz[2] - geometric_center[2])
            if write_all_residues_per_bead and all_indexes is not None:
                # Bugfix: the loop previously iterated all_indexes but
                # wrote residue_index for every row, so each bead residue
                # got the same (central) sequence number. Write the actual
                # residue number, matching _write_pdb_internal.
                seq_ids = all_indexes
            else:
                seq_ids = [residue_index]
            for seq_id in seq_ids:
                lp.write(group_PDB='ATOM', type_symbol='C',
                         label_atom_id=atom_type.get_string(),
                         label_comp_id=residue_type.get_string(),
                         label_asym_id=chain_id,
                         label_seq_id=seq_id,
                         auth_seq_id=seq_id, Cartn_x=c[0],
                         Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                         pdbx_pdb_model_num=1,
                         label_entity_id=ci.entity.id)
                ordinal += 1
197 
198 
class Output:
    """Class for easy writing of PDBs, RMFs, and stat files

    @note Model should be updated prior to writing outputs.
    """
    def __init__(self, ascii=True, atomistic=False):
        """@param ascii if True, write stat files as ASCII; otherwise
               pickled binary
        @param atomistic initial atomistic flag (reset automatically by
               _init_dictchain)"""
        self.dictionary_pdbs = {}    # PDB filename -> hierarchy
        self._pdb_mmcif = {}         # PDB filename -> mmCIF output flag
        self.dictionary_rmfs = {}    # RMF filename -> (handle, cat, keys, ...)
        self.dictionary_stats = {}   # v1 stat filename -> list of objects
        self.dictionary_stats2 = {}  # v2 stat filename -> bookkeeping tuple
        self.best_score_list = None
        self.nbestscoring = None
        self.prefixes = []
        self.replica_exchange = False
        self.ascii = ascii
        self.initoutput = {}         # constant entries added to stat output
        self.residuetypekey = IMP.StringKey("ResidueName")
        # 1-character chain IDs, suitable for PDB output
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
            "abcdefghijklmnopqrstuvwxyz0123456789"
        # Multi-character chain IDs, suitable for mmCIF output
        self.multi_chainids = _ChainIDs()
        self.dictchain = {}  # keys are molecule names, values are chain ids
        self.particle_infos_for_pdb = {}
        self.atomistic = atomistic
225 
226  def get_pdb_names(self):
227  """Get a list of all PDB files being output by this instance"""
228  return list(self.dictionary_pdbs.keys())
229 
230  def get_rmf_names(self):
231  return list(self.dictionary_rmfs.keys())
232 
233  def get_stat_names(self):
234  return list(self.dictionary_stats.keys())
235 
236  def _init_dictchain(self, name, prot, multichar_chain=False, mmcif=False):
237  self.dictchain[name] = {}
238  seen_chains = set()
239 
240  # attempt to find PMI objects.
241  self.atomistic = True # detects automatically
242  for n, mol in enumerate(IMP.atom.get_by_type(
243  prot, IMP.atom.MOLECULE_TYPE)):
244  chid = IMP.atom.Chain(mol).get_id()
245  if not mmcif and len(chid) > 1:
246  raise ValueError(
247  "The system contains at least one chain ID (%s) that "
248  "is more than 1 character long; this cannot be "
249  "represented in PDB. Either write mmCIF files "
250  "instead, or assign 1-character IDs to all chains "
251  "(this can be done with the `chain_ids` argument to "
252  "BuildSystem.add_state())." % chid)
253  chid = _disambiguate_chain(chid, seen_chains)
255  self.dictchain[name][molname] = chid
256 
257  def init_pdb(self, name, prot, mmcif=False):
258  """Init PDB Writing.
259  @param name The PDB filename
260  @param prot The hierarchy to write to this pdb file
261  @param mmcif If True, write PDBs in mmCIF format
262  @note if the PDB name is 'System' then will use Selection
263  to get molecules
264  """
265  flpdb = open(name, 'w')
266  flpdb.close()
267  self.dictionary_pdbs[name] = prot
268  self._pdb_mmcif[name] = mmcif
269  self._init_dictchain(name, prot, mmcif=mmcif)
270 
    def write_psf(self, filename, name):
        """Write a PSF-format topology file for the hierarchy registered
        under `name` with init_pdb().

        @param filename the PSF file to create
        @param name key of the hierarchy whose particles are described
        """
        flpsf = open(filename, 'w')
        flpsf.write("PSF CMAP CHEQ" + "\n")
        index_residue_pair_list = {}
        (particle_infos_for_pdb, geometric_center) = \
            self.get_particle_infos_for_pdb_writing(name)
        nparticles = len(particle_infos_for_pdb)
        flpsf.write(str(nparticles) + " !NATOM" + "\n")
        for n, p in enumerate(particle_infos_for_pdb):
            atom_index = n+1
            residue_type = p[2]
            chain = p[3]
            resid = p[4]
            # fixed-column PSF atom record
            flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}'
                        '{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}'
                        '{14:14.6f}{15:14.6f}'.format(
                            atom_index, " ", chain, " ", str(resid), " ",
                            '"'+residue_type.get_string()+'"', " ", "C",
                            " ", "C", 1.0, 0.0, 0, 0.0, 0.0))
            flpsf.write('\n')
            # remember (atom index, residue) per chain for the bond section
            if chain not in index_residue_pair_list:
                index_residue_pair_list[chain] = [(atom_index, resid)]
            else:
                index_residue_pair_list[chain].append((atom_index, resid))

        # now write the connectivity
        indexes_pairs = []
        for chain in sorted(index_residue_pair_list.keys()):

            ls = index_residue_pair_list[chain]
            # sort by residue
            ls = sorted(ls, key=lambda tup: tup[1])
            # get the index list
            indexes = [x[0] for x in ls]
            # get the contiguous pairs
            indexes_pairs.extend(IMP.pmi.tools.sublist_iterator(
                indexes, lmin=2, lmax=2))
        nbonds = len(indexes_pairs)
        flpsf.write(str(nbonds)+" !NBOND: bonds"+"\n")

        # save bonds in fixed column format, four bonds per line
        for i in range(0, len(indexes_pairs), 4):
            for bond in indexes_pairs[i:i+4]:
                flpsf.write('{0:8d}{1:8d}'.format(*bond))
            flpsf.write('\n')

        del particle_infos_for_pdb
        flpsf.close()
319 
320  def write_pdb(self, name, appendmode=True,
321  translate_to_geometric_center=False,
322  write_all_residues_per_bead=False):
323 
324  (particle_infos_for_pdb,
325  geometric_center) = self.get_particle_infos_for_pdb_writing(name)
326 
327  if not translate_to_geometric_center:
328  geometric_center = (0, 0, 0)
329 
330  filemode = 'a' if appendmode else 'w'
331  with open(name, filemode) as flpdb:
332  if self._pdb_mmcif[name]:
333  _write_mmcif_internal(flpdb, particle_infos_for_pdb,
334  geometric_center,
335  write_all_residues_per_bead,
336  self.dictchain[name],
337  self.dictionary_pdbs[name])
338  else:
339  _write_pdb_internal(flpdb, particle_infos_for_pdb,
340  geometric_center,
341  write_all_residues_per_bead)
342 
343  def get_prot_name_from_particle(self, name, p):
344  """Get the protein name from the particle.
345  This is done by traversing the hierarchy."""
346  return IMP.pmi.get_molecule_name_and_copy(p), True
347 
348  def get_particle_infos_for_pdb_writing(self, name):
349  # index_residue_pair_list={}
350 
351  # the resindexes dictionary keep track of residues that have
352  # been already added to avoid duplication
353  # highest resolution have highest priority
354  resindexes_dict = {}
355 
356  # this dictionary will contain the sequence of tuples needed to
357  # write the pdb
358  particle_infos_for_pdb = []
359 
360  geometric_center = [0, 0, 0]
361  atom_count = 0
362 
363  # select highest resolution, if hierarchy is non-empty
364  if (not IMP.core.XYZR.get_is_setup(self.dictionary_pdbs[name])
365  and self.dictionary_pdbs[name].get_number_of_children() == 0):
366  ps = []
367  else:
368  sel = IMP.atom.Selection(self.dictionary_pdbs[name], resolution=0)
369  ps = sel.get_selected_particles()
370 
371  for n, p in enumerate(ps):
372  protname, is_a_bead = self.get_prot_name_from_particle(name, p)
373 
374  if protname not in resindexes_dict:
375  resindexes_dict[protname] = []
376 
377  if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
378  residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
379  rt = residue.get_residue_type()
380  resind = residue.get_index()
381  atomtype = IMP.atom.Atom(p).get_atom_type()
382  xyz = list(IMP.core.XYZ(p).get_coordinates())
383  radius = IMP.core.XYZR(p).get_radius()
384  geometric_center[0] += xyz[0]
385  geometric_center[1] += xyz[1]
386  geometric_center[2] += xyz[2]
387  atom_count += 1
388  particle_infos_for_pdb.append(
389  (xyz, atomtype, rt, self.dictchain[name][protname],
390  resind, None, radius))
391  resindexes_dict[protname].append(resind)
392 
394 
395  residue = IMP.atom.Residue(p)
396  resind = residue.get_index()
397  # skip if the residue was already added by atomistic resolution
398  # 0
399  if resind in resindexes_dict[protname]:
400  continue
401  else:
402  resindexes_dict[protname].append(resind)
403  rt = residue.get_residue_type()
404  xyz = IMP.core.XYZ(p).get_coordinates()
405  radius = IMP.core.XYZR(p).get_radius()
406  geometric_center[0] += xyz[0]
407  geometric_center[1] += xyz[1]
408  geometric_center[2] += xyz[2]
409  atom_count += 1
410  particle_infos_for_pdb.append(
411  (xyz, None, rt, self.dictchain[name][protname], resind,
412  None, radius))
413 
414  elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
415  resindexes = list(IMP.pmi.tools.get_residue_indexes(p))
416  resind = resindexes[len(resindexes) // 2]
417  if resind in resindexes_dict[protname]:
418  continue
419  else:
420  resindexes_dict[protname].append(resind)
421  rt = IMP.atom.ResidueType('BEA')
422  xyz = IMP.core.XYZ(p).get_coordinates()
423  radius = IMP.core.XYZR(p).get_radius()
424  geometric_center[0] += xyz[0]
425  geometric_center[1] += xyz[1]
426  geometric_center[2] += xyz[2]
427  atom_count += 1
428  particle_infos_for_pdb.append(
429  (xyz, None, rt, self.dictchain[name][protname], resind,
430  resindexes, radius))
431 
432  else:
433  if is_a_bead:
434  rt = IMP.atom.ResidueType('BEA')
435  resindexes = list(IMP.pmi.tools.get_residue_indexes(p))
436  if len(resindexes) > 0:
437  resind = resindexes[len(resindexes) // 2]
438  xyz = IMP.core.XYZ(p).get_coordinates()
439  radius = IMP.core.XYZR(p).get_radius()
440  geometric_center[0] += xyz[0]
441  geometric_center[1] += xyz[1]
442  geometric_center[2] += xyz[2]
443  atom_count += 1
444  particle_infos_for_pdb.append(
445  (xyz, None, rt, self.dictchain[name][protname],
446  resind, resindexes, radius))
447 
448  if atom_count > 0:
449  geometric_center = (geometric_center[0] / atom_count,
450  geometric_center[1] / atom_count,
451  geometric_center[2] / atom_count)
452 
453  # sort by chain ID, then residue index. Longer chain IDs (e.g. AA)
454  # should always come after shorter (e.g. Z)
455  particle_infos_for_pdb = sorted(particle_infos_for_pdb,
456  key=lambda x: (len(x[3]), x[3], x[4]))
457 
458  return (particle_infos_for_pdb, geometric_center)
459 
460  def write_pdbs(self, appendmode=True, mmcif=False):
461  for pdb in self.dictionary_pdbs.keys():
462  self.write_pdb(pdb, appendmode)
463 
464  def init_pdb_best_scoring(self, prefix, prot, nbestscoring,
465  replica_exchange=False, mmcif=False,
466  best_score_file='best.scores.rex.py'):
467  """Prepare for writing best-scoring PDBs (or mmCIFs) for a
468  sampling run.
469 
470  @param prefix Initial part of each PDB filename (e.g. 'model').
471  @param prot The top-level Hierarchy to output.
472  @param nbestscoring The number of best-scoring files to output.
473  @param replica_exchange Whether to combine best scores from a
474  replica exchange run.
475  @param mmcif If True, output models in mmCIF format. If False
476  (the default) output in legacy PDB format.
477  @param best_score_file The filename to use for replica
478  exchange scores.
479  """
480 
481  self._pdb_best_scoring_mmcif = mmcif
482  fileext = '.cif' if mmcif else '.pdb'
483  self.prefixes.append(prefix)
484  self.replica_exchange = replica_exchange
485  if not self.replica_exchange:
486  # common usage
487  # if you are not in replica exchange mode
488  # initialize the array of scores internally
489  self.best_score_list = []
490  else:
491  # otherwise the replicas must communicate
492  # through a common file to know what are the best scores
493  self.best_score_file_name = best_score_file
494  self.best_score_list = []
495  with open(self.best_score_file_name, "w") as best_score_file:
496  best_score_file.write(
497  "self.best_score_list=" + str(self.best_score_list) + "\n")
498 
499  self.nbestscoring = nbestscoring
500  for i in range(self.nbestscoring):
501  name = prefix + "." + str(i) + fileext
502  flpdb = open(name, 'w')
503  flpdb.close()
504  self.dictionary_pdbs[name] = prot
505  self._pdb_mmcif[name] = mmcif
506  self._init_dictchain(name, prot, mmcif=mmcif)
507 
508  def write_pdb_best_scoring(self, score):
509  if self.nbestscoring is None:
510  print("Output.write_pdb_best_scoring: init_pdb_best_scoring "
511  "not run")
512 
513  mmcif = self._pdb_best_scoring_mmcif
514  fileext = '.cif' if mmcif else '.pdb'
515  # update the score list
516  if self.replica_exchange:
517  # read the self.best_score_list from the file
518  with open(self.best_score_file_name) as fh:
519  self.best_score_list = ast.literal_eval(
520  fh.read().split('=')[1])
521 
522  if len(self.best_score_list) < self.nbestscoring:
523  self.best_score_list.append(score)
524  self.best_score_list.sort()
525  index = self.best_score_list.index(score)
526  for prefix in self.prefixes:
527  for i in range(len(self.best_score_list) - 2, index - 1, -1):
528  oldname = prefix + "." + str(i) + fileext
529  newname = prefix + "." + str(i + 1) + fileext
530  # rename on Windows fails if newname already exists
531  if os.path.exists(newname):
532  os.unlink(newname)
533  os.rename(oldname, newname)
534  filetoadd = prefix + "." + str(index) + fileext
535  self.write_pdb(filetoadd, appendmode=False)
536 
537  else:
538  if score < self.best_score_list[-1]:
539  self.best_score_list.append(score)
540  self.best_score_list.sort()
541  self.best_score_list.pop(-1)
542  index = self.best_score_list.index(score)
543  for prefix in self.prefixes:
544  for i in range(len(self.best_score_list) - 1,
545  index - 1, -1):
546  oldname = prefix + "." + str(i) + fileext
547  newname = prefix + "." + str(i + 1) + fileext
548  os.rename(oldname, newname)
549  filenametoremove = prefix + \
550  "." + str(self.nbestscoring) + fileext
551  os.remove(filenametoremove)
552  filetoadd = prefix + "." + str(index) + fileext
553  self.write_pdb(filetoadd, appendmode=False)
554 
555  if self.replica_exchange:
556  # write the self.best_score_list to the file
557  with open(self.best_score_file_name, "w") as best_score_file:
558  best_score_file.write(
559  "self.best_score_list=" + str(self.best_score_list) + '\n')
560 
    def init_rmf(self, name, hierarchies, rs=None, geometries=None,
                 listofobjects=None):
        """
        Initialize an RMF file

        @param name the name of the RMF file
        @param hierarchies the hierarchies to be included (it is a list)
        @param rs optional, the restraint sets (it is a list)
        @param geometries optional, the geometries (it is a list)
        @param listofobjects optional, the list of objects for the stat
               (it is a list)
        """
        rh = RMF.create_rmf_file(name)
        IMP.rmf.add_hierarchies(rh, hierarchies)
        cat = None
        outputkey_rmfkey = None

        if rs is not None:
            IMP.rmf.add_restraints(rh, rs)
        if geometries is not None:
            IMP.rmf.add_geometries(rh, geometries)
        dict_objects = []
        callable_objects = []
        if listofobjects is not None:
            cat = rh.get_category("stat")
            outputkey_rmfkey = {}
            for o in listofobjects:
                if not hasattr(o, "get_output"):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(o))
                # get_output() can return either a dict or a callable;
                # store these in different lists
                output = o.get_output()
                if callable(output):
                    callable_objects.append(output)
                    # evaluate once to discover the keys it will produce
                    output = output(None)
                else:
                    dict_objects.append(o)
                for outputkey in output:
                    # pick the RMF key type matching the value; anything
                    # that is not float/int falls back to string
                    rmftag = RMF.string_tag
                    if isinstance(output[outputkey], float):
                        rmftag = RMF.float_tag
                    elif isinstance(output[outputkey], int):
                        rmftag = RMF.int_tag
                    elif isinstance(output[outputkey], str):
                        rmftag = RMF.string_tag
                    else:
                        rmftag = RMF.string_tag
                    rmfkey = rh.get_key(cat, outputkey, rmftag)
                    outputkey_rmfkey[outputkey] = rmfkey
            # bookkeeping keys recording the file and frame of each entry
            outputkey_rmfkey["rmf_file"] = \
                rh.get_key(cat, "rmf_file", RMF.string_tag)
            outputkey_rmfkey["rmf_frame_index"] = \
                rh.get_key(cat, "rmf_frame_index", RMF.int_tag)

        self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey,
                                      dict_objects, callable_objects)
619 
620  def add_restraints_to_rmf(self, name, objectlist):
621  for o in _flatten(objectlist):
622  try:
623  rs = o.get_restraint_for_rmf()
624  if not isinstance(rs, (list, tuple)):
625  rs = [rs]
626  except: # noqa: E722
627  rs = [o.get_restraint()]
629  self.dictionary_rmfs[name][0], rs)
630 
631  def add_geometries_to_rmf(self, name, objectlist):
632  for o in objectlist:
633  geos = o.get_geometries()
634  IMP.rmf.add_geometries(self.dictionary_rmfs[name][0], geos)
635 
636  def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
637  for o in objectlist:
638 
639  pps = o.get_particle_pairs()
640  for pp in pps:
642  self.dictionary_rmfs[name][0],
644 
    def write_rmf(self, name):
        """Save the current model as a new frame of the RMF file `name`,
        including any stat values registered via init_rmf()."""
        IMP.rmf.save_frame(self.dictionary_rmfs[name][0])
        if self.dictionary_rmfs[name][1] is not None:
            # stat output was requested for this file (category is set)
            outputkey_rmfkey = self.dictionary_rmfs[name][2]
            dict_objects = self.dictionary_rmfs[name][3]
            callable_objects = self.dictionary_rmfs[name][4]

            def all_output():
                # yield each registered output dict, evaluating callables
                # against the current model state
                for obj in dict_objects:
                    yield obj.get_output()
                for obj in callable_objects:
                    yield obj(None)

            for output in all_output():
                for outputkey in output:
                    rmfkey = outputkey_rmfkey[outputkey]
                    try:
                        n = self.dictionary_rmfs[name][0].get_root_node()
                        n.set_value(rmfkey, output[outputkey])
                    except NotImplementedError:
                        # value type not storable under this key; skip it
                        continue
            # record which file and frame this stat entry refers to
            rmfkey = outputkey_rmfkey["rmf_file"]
            self.dictionary_rmfs[name][0].get_root_node().set_value(
                rmfkey, name)
            rmfkey = outputkey_rmfkey["rmf_frame_index"]
            nframes = self.dictionary_rmfs[name][0].get_number_of_frames()
            self.dictionary_rmfs[name][0].get_root_node().set_value(
                rmfkey, nframes-1)
        self.dictionary_rmfs[name][0].flush()
674 
675  def close_rmf(self, name):
676  rh = self.dictionary_rmfs[name][0]
677  del self.dictionary_rmfs[name]
678  del rh
679 
680  def write_rmfs(self):
681  for rmfinfo in self.dictionary_rmfs.keys():
682  self.write_rmf(rmfinfo[0])
683 
684  @IMP.deprecated_method("2.25", "Use init_stat2() instead")
685  def init_stat(self, name, listofobjects):
686  if self.ascii:
687  flstat = open(name, 'w')
688  flstat.close()
689  else:
690  flstat = open(name, 'wb')
691  flstat.close()
692 
693  # check that all objects in listofobjects have a get_output method
694  for o in listofobjects:
695  if not hasattr(o, "get_output"):
696  raise ValueError(
697  "Output: object %s doesn't have get_output() method"
698  % str(o))
699  self.dictionary_stats[name] = listofobjects
700 
701  def set_output_entry(self, key, value):
702  self.initoutput.update({key: value})
703 
704  @IMP.deprecated_method("2.25", "Use write_stat2() instead")
705  def write_stat(self, name, appendmode=True):
706  output = self.initoutput
707  for obj in self.dictionary_stats[name]:
708  d = obj.get_output()
709  # remove all entries that begin with _ (private entries)
710  dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
711  output.update(dfiltered)
712 
713  if appendmode:
714  writeflag = 'a'
715  else:
716  writeflag = 'w'
717 
718  if self.ascii:
719  with open(name, writeflag) as flstat:
720  flstat.write("%s \n" % output)
721  else:
722  with open(name, writeflag + 'b') as flstat:
723  pickle.dump(output, flstat, 2)
724 
725  @IMP.deprecated_method("2.25", "Use write_stats2() instead")
726  def write_stats(self):
727  for stat in self.dictionary_stats.keys():
728  self.write_stat(stat)
729 
730  def get_stat(self, name):
731  output = {}
732  for obj in self.dictionary_stats[name]:
733  output.update(obj.get_output())
734  return output
735 
736  def write_test(self, name, listofobjects):
737  flstat = open(name, 'w')
738  output = self.initoutput
739  for o in listofobjects:
740  if (not hasattr(o, "get_test_output")
741  and not hasattr(o, "get_output")):
742  raise ValueError(
743  "Output: object %s doesn't have get_output() or "
744  "get_test_output() method" % str(o))
745  self.dictionary_stats[name] = listofobjects
746 
747  for obj in self.dictionary_stats[name]:
748  try:
749  d = obj.get_test_output()
750  except AttributeError:
751  d = obj.get_output()
752  if callable(d):
753  # Get any scores using the current IMP Model
754  d = d(None)
755  # remove all entries that begin with _ (private entries)
756  dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
757  output.update(dfiltered)
758  flstat.write("%s \n" % output)
759  flstat.close()
760 
    def test(self, name, listofobjects, tolerance=1e-5):
        """Compare the current output of the given objects against the
        reference file previously written by write_test().

        @param name the reference file to read
        @param listofobjects objects providing get_test_output() or
               get_output()
        @param tolerance maximum absolute difference allowed for
               numeric values
        @return True if all compared values matched within tolerance
        @raise ValueError if an object provides neither output method
        """
        output = self.initoutput
        for o in listofobjects:
            if (not hasattr(o, "get_test_output")
                    and not hasattr(o, "get_output")):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        for obj in listofobjects:
            try:
                out = obj.get_test_output()
            except AttributeError:
                out = obj.get_output()
            if callable(out):
                # Get any scores using the current IMP Model
                out = out(None)
            output.update(out)

        flstat = open(name, 'r')

        passed = True
        for fl in flstat:
            test_dict = ast.literal_eval(fl)
            for k in test_dict:
                if k in output:
                    old_value = str(test_dict[k])
                    new_value = str(output[k])
                    # decide whether to compare numerically or exactly
                    try:
                        float(old_value)
                        is_float = True
                    except ValueError:
                        is_float = False

                    if is_float:
                        fold = float(old_value)
                        fnew = float(new_value)
                        diff = abs(fold - fnew)
                        if diff > tolerance:
                            print("%s: test failed, old value: %s new value %s; "
                                  "diff %f > %f" % (str(k), str(old_value),
                                                    str(new_value), diff,
                                                    tolerance), file=sys.stderr)
                            passed = False
                    elif test_dict[k] != output[k]:
                        if len(old_value) < 50 and len(new_value) < 50:
                            print("%s: test failed, old value: %s new value %s"
                                  % (str(k), old_value, new_value),
                                  file=sys.stderr)
                            passed = False
                        else:
                            # values too long to usefully print
                            print("%s: test failed, omitting results (too long)"
                                  % str(k), file=sys.stderr)
                            passed = False

                else:
                    # key in the reference but missing now: reported, but
                    # does not fail the test
                    print("%s from old objects (file %s) not in new objects"
                          % (str(k), str(name)), file=sys.stderr)
        flstat.close()
        return passed
820 
821  def get_environment_variables(self):
822  import os
823  return str(os.environ)
824 
825  def get_versions_of_relevant_modules(self):
826  import IMP
827  versions = {}
828  versions["IMP_VERSION"] = IMP.get_module_version()
829  versions["PMI_VERSION"] = IMP.pmi.get_module_version()
830  try:
831  import IMP.isd2
832  versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
833  except ImportError:
834  pass
835  try:
836  import IMP.isd_emxl
837  versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
838  except ImportError:
839  pass
840  return versions
841 
    def init_stat2(self, name, listofobjects, extralabels=None,
                   listofsummedobjects=None, jax_model=None):
        """Write the header for a stat file in v2 format.
        Lines can then be written to the stat file by calling write_stat2()
        with the same file name.

        @param name The file name to write to.
        @param listofobjects PMI objects that will be reported in the file.
               Each object must implement the get_output() method.
               This can either return a dict containing data from the
               current state of the model, or a callable which returns
               a similar dict of data each time it is called.
        @param extralabels extra keys (set via set_output_entry()) to
               include in the header.
        @param listofsummedobjects list of ([obj1, obj2, ...], label)
               tuples; the _TotalScore of each group is summed under the
               given label.
        @param jax_model passed through to callable outputs when they are
               evaluated to discover their keys.
        """
        # this is a new stat file that should be less
        # space greedy!
        # listofsummedobjects must be in the form
        # [([obj1,obj2,obj3,obj4...],label)]
        # extralabels

        if listofsummedobjects is None:
            listofsummedobjects = []
        if extralabels is None:
            extralabels = []
        flstat = open(name, 'w')
        output = {}
        stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
        stat2_keywords.update(
            {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
        stat2_keywords.update(
            {"STAT2HEADER_IMP_VERSIONS":
             str(self.get_versions_of_relevant_modules())})
        stat2_inverse = {}

        dict_objects = []
        callable_objects = []
        for obj in listofobjects:
            if not hasattr(obj, "get_output"):
                raise ValueError(
                    "Output: object %s doesn't have get_output() method"
                    % str(obj))
            else:
                # get_output() can return either a dict or a callable;
                # store these in different lists
                d = obj.get_output()
                if callable(d):
                    callable_objects.append(d)
                    # evaluate once to discover the keys it will produce
                    d = d(jax_model)
                else:
                    dict_objects.append(obj)
                # remove all entries that begin with _ (private entries)
                dfiltered = dict((k, v)
                                 for k, v in d.items() if k[0] != "_")
                output.update(dfiltered)

        # check for customizable entries
        for obj in listofsummedobjects:
            for t in obj[0]:
                if not hasattr(t, "get_output"):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(t))
                else:
                    if "_TotalScore" not in t.get_output():
                        raise ValueError(
                            "Output: object %s doesn't have _TotalScore "
                            "entry to be summed" % str(t))
                    else:
                        output.update({obj[1]: 0.0})

        for k in extralabels:
            output.update({k: 0.0})

        # map each key to a compact integer column index; the header line
        # records the index -> key mapping used by write_stat2()
        for n, k in enumerate(output):
            stat2_keywords.update({n: k})
            stat2_inverse.update({k: n})

        flstat.write("%s \n" % stat2_keywords)
        flstat.close()
        self.dictionary_stats2[name] = (
            dict_objects, callable_objects,
            stat2_inverse,
            listofsummedobjects,
            extralabels)
925 
    def write_stat2(self, name, appendmode=True, jax_model=None):
        """Write a single line to a stat file previously created
        with init_stat2().

        @param name The file name to write to.
        @param appendmode append to the file rather than overwrite it
        @param jax_model passed through to callable outputs when they
               are evaluated
        """
        output = {}
        (dict_objects, callable_objects, stat2_inverse, listofsummedobjects,
         extralabels) = self.dictionary_stats2[name]

        def all_output():
            # yield each registered output dict, evaluating callables
            # against the current model state
            for obj in dict_objects:
                yield obj.get_output()
            for obj in callable_objects:
                yield obj(jax_model)

        # writing objects; keys are replaced by the compact integer
        # indices recorded in the header
        for od in all_output():
            dfiltered = dict((k, v) for k, v in od.items() if k[0] != "_")
            for k in dfiltered:
                output.update({stat2_inverse[k]: od[k]})

        # writing summedobjects
        for so in listofsummedobjects:
            partial_score = 0.0
            for t in so[0]:
                d = t.get_output()
                partial_score += float(d["_TotalScore"])
            output.update({stat2_inverse[so[1]]: str(partial_score)})

        # writing extralabels
        for k in extralabels:
            if k in self.initoutput:
                output.update({stat2_inverse[k]: self.initoutput[k]})
            else:
                output.update({stat2_inverse[k]: "None"})

        with open(name, 'a' if appendmode else 'w') as flstat:
            flstat.write("%s \n" % output)
965 
966  def write_stats2(self):
967  for stat in self.dictionary_stats2.keys():
968  self.write_stat2(stat)
969 
970 
972  """Collect statistics from ProcessOutput.get_fields().
973  Counters of the total number of frames read, plus the models that
974  passed the various filters used in get_fields(), are provided."""
    def __init__(self):
        # total number of frames read
        self.total = 0
        # frames that survived the get_every subsampling
        self.passed_get_every = 0
        # frames that survived the filterout substring filter
        self.passed_filterout = 0
        # frames that survived the filtertuple numeric filter
        self.passed_filtertuple = 0
980 
981 
983  """A class for reading stat files (either rmf or ascii v1 and v2)"""
984  def __init__(self, filename):
985  self.filename = filename
986  self.isstat1 = False
987  self.isstat2 = False
988  self.isrmf = False
989 
990  if self.filename is None:
991  raise ValueError("No file name provided. Use -h for help")
992 
993  try:
994  # let's see if that is an rmf file
995  rh = RMF.open_rmf_file_read_only(self.filename)
996  self.isrmf = True
997  cat = rh.get_category('stat')
998  rmf_klist = rh.get_keys(cat)
999  self.rmf_names_keys = dict([(rh.get_name(k), k)
1000  for k in rmf_klist])
1001  del rh
1002 
1003  except IOError:
1004  f = open(self.filename, "r")
1005  # try with an ascii stat file
1006  # get the keys from the first line
1007  for line in f.readlines():
1008  d = ast.literal_eval(line)
1009  self.klist = list(d.keys())
1010  # check if it is a stat2 file
1011  if "STAT2HEADER" in self.klist:
1012  self.isstat2 = True
1013  for k in self.klist:
1014  if "STAT2HEADER" in str(k):
1015  # if print_header: print k, d[k]
1016  del d[k]
1017  stat2_dict = d
1018  # get the list of keys sorted by value
1019  kkeys = [k[0]
1020  for k in sorted(stat2_dict.items(),
1021  key=operator.itemgetter(1))]
1022  self.klist = [k[1]
1023  for k in sorted(stat2_dict.items(),
1024  key=operator.itemgetter(1))]
1025  self.invstat2_dict = {}
1026  for k in kkeys:
1027  self.invstat2_dict.update({stat2_dict[k]: k})
1028  else:
1030  "statfile v1 is deprecated. "
1031  "Please convert to statfile v2.\n")
1032  self.isstat1 = True
1033  self.klist.sort()
1034 
1035  break
1036  f.close()
1037 
1038  def get_keys(self):
1039  if self.isrmf:
1040  return sorted(self.rmf_names_keys.keys())
1041  else:
1042  return self.klist
1043 
    def show_keys(self, ncolumns=2, truncate=65):
        """Print the available field names in `ncolumns` columns,
        truncating each name to `truncate` characters."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)
1046 
    def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
                   statistics=None):
        '''
        Get the desired field names, and return a dictionary.
        Namely, "fields" are the queried keys in the stat file
        (eg. ["Total_Score",...])
        The returned data structure is a dictionary, where each key is
        a field and the value is the time series (ie, frame ordered series)
        of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,,...],....})

        @param fields (list of strings) queried keys in the stat file
               (eg. "Total_Score"....)
        @param filterout specify if you want to "grep" out something from
               the file, so that it is faster
        @param filtertuple a tuple that contains
               ("TheKeyToBeFiltered",relationship,value)
               where relationship = "<", "==", or ">"
        @param get_every only read every Nth line from the file
        @param statistics if provided, accumulate statistics in an
               OutputStatistics object
        '''

        if statistics is None:
            statistics = OutputStatistics()
        outdict = {}
        for field in fields:
            outdict[field] = []

        # print fields values
        if self.isrmf:
            rh = RMF.open_rmf_file_read_only(self.filename)
            nframes = rh.get_number_of_frames()
            for i in range(nframes):
                statistics.total += 1
                # "get_every" and "filterout" not enforced for RMF
                statistics.passed_get_every += 1
                statistics.passed_filterout += 1
                rh.set_current_frame(RMF.FrameID(i))
                if filtertuple is not None:
                    keytobefiltered = filtertuple[0]
                    relationship = filtertuple[1]
                    value = filtertuple[2]
                    datavalue = rh.get_root_node().get_value(
                        self.rmf_names_keys[keytobefiltered])
                    if self.isfiltered(datavalue, relationship, value):
                        continue

                statistics.passed_filtertuple += 1
                for field in fields:
                    outdict[field].append(rh.get_root_node().get_value(
                        self.rmf_names_keys[field]))

        else:
            f = open(self.filename, "r")
            line_number = 0

            for line in f.readlines():
                statistics.total += 1
                if filterout is not None:
                    if filterout in line:
                        continue
                statistics.passed_filterout += 1
                line_number += 1

                if line_number % get_every != 0:
                    # the first line of a stat2 file is the header, not a
                    # frame; undo the counts added above for it
                    if line_number == 1 and self.isstat2:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                    continue
                statistics.passed_get_every += 1
                try:
                    d = ast.literal_eval(line)
                except:  # noqa: E722
                    print("# Warning: skipped line number " + str(line_number)
                          + " not a valid line")
                    continue

                if self.isstat1:

                    if filtertuple is not None:
                        keytobefiltered = filtertuple[0]
                        relationship = filtertuple[1]
                        value = filtertuple[2]
                        datavalue = d[keytobefiltered]
                        if self.isfiltered(datavalue, relationship, value):
                            continue

                    statistics.passed_filtertuple += 1
                    [outdict[field].append(d[field]) for field in fields]

                elif self.isstat2:
                    # skip (and un-count) the stat2 header line
                    if line_number == 1:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                        statistics.passed_get_every -= 1
                        continue

                    if filtertuple is not None:
                        keytobefiltered = filtertuple[0]
                        relationship = filtertuple[1]
                        value = filtertuple[2]
                        # stat2 lines are keyed by column index, not name
                        datavalue = d[self.invstat2_dict[keytobefiltered]]
                        if self.isfiltered(datavalue, relationship, value):
                            continue

                    statistics.passed_filtertuple += 1
                    [outdict[field].append(d[self.invstat2_dict[field]])
                     for field in fields]

            f.close()

        return outdict
1159 
1160  def isfiltered(self, datavalue, relationship, refvalue):
1161  dofilter = False
1162  try:
1163  _ = float(datavalue)
1164  except ValueError:
1165  raise ValueError("ProcessOutput.filter: datavalue cannot be "
1166  "converted into a float")
1167 
1168  if relationship == "<":
1169  if float(datavalue) >= refvalue:
1170  dofilter = True
1171  if relationship == ">":
1172  if float(datavalue) <= refvalue:
1173  dofilter = True
1174  if relationship == "==":
1175  if float(datavalue) != refvalue:
1176  dofilter = True
1177  return dofilter
1178 
1179 
1181  """ class to allow more advanced handling of RMF files.
1182  It is both a container and a IMP.atom.Hierarchy.
1183  - it is iterable (while loading the corresponding frame)
1184  - Item brackets [] load the corresponding frame
1185  - slice create an iterator
1186  - can relink to another RMF file
1187  """
    def __init__(self, model, rmf_file_name):
        """
        @param model: the IMP.Model()
        @param rmf_file_name: str, path of the rmf file
        @raise TypeError if the file cannot be opened as an RMF
        """
        self.model = model
        try:
            self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
        except TypeError:
            raise TypeError("Wrong rmf file name or type: %s"
                            % str(rmf_file_name))
        # build the IMP hierarchies from the RMF and load the first frame
        hs = IMP.rmf.create_hierarchies(self.rh_ref, self.model)
        IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(0))
        self.root_hier_ref = hs[0]
        super().__init__(self.root_hier_ref)
        self.model.update()
        # optional coloring object whose .method() is re-applied after
        # relinking (see link_to_rmf); set externally
        self.ColorHierarchy = None
1205 
    def link_to_rmf(self, rmf_file_name):
        """
        Link to another RMF file
        """
        self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
        IMP.rmf.link_hierarchies(self.rh_ref, [self])
        if self.ColorHierarchy:
            # re-apply the stored coloring after relinking
            self.ColorHierarchy.method()
        # load the first frame of the newly linked file
        RMFHierarchyHandler.set_frame(self, 0)
1215 
1216  def set_frame(self, index):
1217  try:
1218  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(index))
1219  except: # noqa: E722
1220  print("skipping frame %s:%d\n" % (self.current_rmf, index))
1221  self.model.update()
1222 
    def get_number_of_frames(self):
        """Return the number of frames in the linked RMF file."""
        return self.rh_ref.get_number_of_frames()
1225 
1226  def __getitem__(self, int_slice_adaptor):
1227  if isinstance(int_slice_adaptor, int):
1228  self.set_frame(int_slice_adaptor)
1229  return int_slice_adaptor
1230  elif isinstance(int_slice_adaptor, slice):
1231  return self.__iter__(int_slice_adaptor)
1232  else:
1233  raise TypeError("Unknown Type")
1234 
    def __len__(self):
        """Number of frames in the linked RMF file."""
        return self.get_number_of_frames()
1237 
1238  def __iter__(self, slice_key=None):
1239  if slice_key is None:
1240  for nframe in range(len(self)):
1241  yield self[nframe]
1242  else:
1243  for nframe in list(range(len(self)))[slice_key]:
1244  yield self[nframe]
1245 
1246 
1247 class CacheHierarchyCoordinates:
1248  def __init__(self, StatHierarchyHandler):
1249  self.xyzs = []
1250  self.nrms = []
1251  self.rbs = []
1252  self.nrm_coors = {}
1253  self.xyz_coors = {}
1254  self.rb_trans = {}
1255  self.current_index = None
1256  self.rmfh = StatHierarchyHandler
1257  rbs, xyzs = IMP.pmi.tools.get_rbs_and_beads([self.rmfh])
1258  self.model = self.rmfh.get_model()
1259  self.rbs = rbs
1260  for xyz in xyzs:
1262  nrm = IMP.core.NonRigidMember(xyz)
1263  self.nrms.append(nrm)
1264  else:
1265  fb = IMP.core.XYZ(xyz)
1266  self.xyzs.append(fb)
1267 
1268  def do_store(self, index):
1269  self.rb_trans[index] = {}
1270  self.nrm_coors[index] = {}
1271  self.xyz_coors[index] = {}
1272  for rb in self.rbs:
1273  self.rb_trans[index][rb] = rb.get_reference_frame()
1274  for nrm in self.nrms:
1275  self.nrm_coors[index][nrm] = nrm.get_internal_coordinates()
1276  for xyz in self.xyzs:
1277  self.xyz_coors[index][xyz] = xyz.get_coordinates()
1278  self.current_index = index
1279 
1280  def do_update(self, index):
1281  if self.current_index != index:
1282  for rb in self.rbs:
1283  rb.set_reference_frame(self.rb_trans[index][rb])
1284  for nrm in self.nrms:
1285  nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
1286  for xyz in self.xyzs:
1287  xyz.set_coordinates(self.xyz_coors[index][xyz])
1288  self.current_index = index
1289  self.model.update()
1290 
1291  def get_number_of_frames(self):
1292  return len(self.rb_trans.keys())
1293 
1294  def __getitem__(self, index):
1295  if isinstance(index, int):
1296  return index in self.rb_trans.keys()
1297  else:
1298  raise TypeError("Unknown Type")
1299 
    def __len__(self):
        """Number of stored snapshots."""
        return self.get_number_of_frames()
1302 
1303 
1305  """ class to link stat files to several rmf files """
    def __init__(self, model=None, stat_file=None,
                 number_best_scoring_models=None, score_key=None,
                 StatHierarchyHandler=None, cache=None):
        """
        @param model: IMP.Model()
        @param stat_file: either 1) a list or 2) a single stat file names
            (either rmfs or ascii, or pickled data or pickled cluster),
            3) a dictionary containing an rmf/ascii
            stat file name as key and a list of frames as values
        @param number_best_scoring_models: keep only this many
            best-scoring models (presumably lowest score; see
            add_stat_file)
        @param score_key: stat-file key holding the score
            (default "Total_Score")
        @param StatHierarchyHandler: copy constructor input object
        @param cache: cache coordinates and rigid body transformations.
        """

        if StatHierarchyHandler is not None:
            # overrides all other arguments
            # copy constructor: create a copy with
            # different RMFHierarchyHandler
            self.model = StatHierarchyHandler.model
            self.data = StatHierarchyHandler.data
            self.number_best_scoring_models = \
                StatHierarchyHandler.number_best_scoring_models
            self.is_setup = True
            self.current_rmf = StatHierarchyHandler.current_rmf
            self.current_frame = None
            self.current_index = None
            self.score_threshold = StatHierarchyHandler.score_threshold
            self.score_key = StatHierarchyHandler.score_key
            self.cache = StatHierarchyHandler.cache
            super().__init__(self.model, self.current_rmf)
            if self.cache:
                # build a fresh cache bound to this copy
                self.cache = CacheHierarchyCoordinates(self)
            else:
                self.cache = None
            self.set_frame(0)

        else:
            # standard constructor
            self.model = model
            self.data = []
            self.number_best_scoring_models = number_best_scoring_models
            self.cache = cache

            if score_key is None:
                self.score_key = "Total_Score"
            else:
                self.score_key = score_key
            self.is_setup = None
            self.current_rmf = None
            self.current_frame = None
            self.current_index = None
            self.score_threshold = None

            if isinstance(stat_file, str):
                self.add_stat_file(stat_file)
            elif isinstance(stat_file, list):
                for f in stat_file:
                    self.add_stat_file(f)
1365 
1366  def add_stat_file(self, stat_file):
1367  try:
1368  '''check that it is not a pickle file with saved data
1369  from a previous calculation'''
1370  self.load_data(stat_file)
1371 
1372  if self.number_best_scoring_models:
1373  scores = self.get_scores()
1374  max_score = sorted(scores)[
1375  0:min(len(self), self.number_best_scoring_models)][-1]
1376  self.do_filter_by_score(max_score)
1377 
1378  except pickle.UnpicklingError:
1379  '''alternatively read the ascii stat files'''
1380  try:
1381  scores, rmf_files, rmf_frame_indexes, features = \
1382  self.get_info_from_stat_file(stat_file,
1383  self.score_threshold)
1384  except (KeyError, SyntaxError):
1385  # in this case check that is it an rmf file, probably
1386  # without stat stored in
1387  try:
1388  # let's see if that is an rmf file
1389  rh = RMF.open_rmf_file_read_only(stat_file)
1390  nframes = rh.get_number_of_frames()
1391  scores = [0.0]*nframes
1392  rmf_files = [stat_file]*nframes
1393  rmf_frame_indexes = range(nframes)
1394  features = {}
1395  except: # noqa: E722
1396  return
1397 
1398  if len(set(rmf_files)) > 1:
1399  raise ("Multiple RMF files found")
1400 
1401  if not rmf_files:
1402  print("StatHierarchyHandler: Error: Trying to set none as "
1403  "rmf_file (probably empty stat file), aborting")
1404  return
1405 
1406  for n, index in enumerate(rmf_frame_indexes):
1407  featn_dict = dict([(k, features[k][n]) for k in features])
1408  self.data.append(IMP.pmi.output.DataEntry(
1409  stat_file, rmf_files[n], index, scores[n], featn_dict))
1410 
1411  if self.number_best_scoring_models:
1412  scores = self.get_scores()
1413  max_score = sorted(scores)[
1414  0:min(len(self), self.number_best_scoring_models)][-1]
1415  self.do_filter_by_score(max_score)
1416 
1417  if not self.is_setup:
1418  RMFHierarchyHandler.__init__(
1419  self, self.model, self.get_rmf_names()[0])
1420  if self.cache:
1421  self.cache = CacheHierarchyCoordinates(self)
1422  else:
1423  self.cache = None
1424  self.is_setup = True
1425  self.current_rmf = self.get_rmf_names()[0]
1426 
1427  self.set_frame(0)
1428 
    def save_data(self, filename='data.pkl'):
        """Pickle self.data (list of DataEntry objects) to `filename`,
        for later reloading with load_data()."""
        with open(filename, 'wb') as fl:
            pickle.dump(self.data, fl)
1432 
1433  def load_data(self, filename='data.pkl'):
1434  with open(filename, 'rb') as fl:
1435  data_structure = pickle.load(fl)
1436  # first check that it is a list
1437  if not isinstance(data_structure, list):
1438  raise TypeError(
1439  "%filename should contain a list of IMP.pmi.output.DataEntry "
1440  "or IMP.pmi.output.Cluster" % filename)
1441  # second check the types
1442  if all(isinstance(item, IMP.pmi.output.DataEntry)
1443  for item in data_structure):
1444  self.data = data_structure
1445  elif all(isinstance(item, IMP.pmi.output.Cluster)
1446  for item in data_structure):
1447  nmodels = 0
1448  for cluster in data_structure:
1449  nmodels += len(cluster)
1450  self.data = [None]*nmodels
1451  for cluster in data_structure:
1452  for n, data in enumerate(cluster):
1453  index = cluster.members[n]
1454  self.data[index] = data
1455  else:
1456  raise TypeError(
1457  "%filename should contain a list of IMP.pmi.output.DataEntry "
1458  "or IMP.pmi.output.Cluster" % filename)
1459 
    def set_frame(self, index):
        """Load the coordinates of model `index`, using the coordinate
        cache when a snapshot for that index was already stored."""
        if self.cache is not None and self.cache[index]:
            self.cache.do_update(index)
        else:
            nm = self.data[index].rmf_name
            fidx = self.data[index].rmf_index
            # relink only when the rmf file changes; invalidate
            # current_frame so the frame below is always reloaded
            if nm != self.current_rmf:
                self.link_to_rmf(nm)
                self.current_rmf = nm
                self.current_frame = -1
            if fidx != self.current_frame:
                RMFHierarchyHandler.set_frame(self, fidx)
                self.current_frame = fidx
            if self.cache is not None:
                self.cache.do_store(index)

        self.current_index = index
1477 
1478  def __getitem__(self, int_slice_adaptor):
1479  if isinstance(int_slice_adaptor, int):
1480  self.set_frame(int_slice_adaptor)
1481  return self.data[int_slice_adaptor]
1482  elif isinstance(int_slice_adaptor, slice):
1483  return self.__iter__(int_slice_adaptor)
1484  else:
1485  raise TypeError("Unknown Type")
1486 
    def __len__(self):
        """Number of stored models."""
        return len(self.data)
1489 
1490  def __iter__(self, slice_key=None):
1491  if slice_key is None:
1492  for i in range(len(self)):
1493  yield self[i]
1494  else:
1495  for i in range(len(self))[slice_key]:
1496  yield self[i]
1497 
1498  def do_filter_by_score(self, maximum_score):
1499  self.data = [d for d in self.data if d.score <= maximum_score]
1500 
1501  def get_scores(self):
1502  return [d.score for d in self.data]
1503 
1504  def get_feature_series(self, feature_name):
1505  return [d.features[feature_name] for d in self.data]
1506 
    def get_feature_names(self):
        # NOTE(review): returns a dict view of the FIRST entry's feature
        # keys; presumably all entries share the same features — confirm
        return self.data[0].features.keys()
1509 
1510  def get_rmf_names(self):
1511  return [d.rmf_name for d in self.data]
1512 
1513  def get_stat_files_names(self):
1514  return [d.stat_file for d in self.data]
1515 
1516  def get_rmf_indexes(self):
1517  return [d.rmf_index for d in self.data]
1518 
    def get_info_from_stat_file(self, stat_file, score_threshold=None):
        """Parse an ascii/rmf stat file and return per-model information.

        @param stat_file path of the stat file
        @param score_threshold pre-filter value passed to get_best_models
        @return (scores, rmf_files, rmf_frame_indexes, features)
        """
        po = ProcessOutput(stat_file)
        fs = po.get_keys()
        models = IMP.pmi.io.get_best_models(
            [stat_file], score_key=self.score_key, feature_keys=fs,
            rmf_file_key="rmf_file", rmf_file_frame_key="rmf_frame_index",
            prefiltervalue=score_threshold, get_every=1)

        # get_best_models returns (rmf_files, frame_indexes, scores, features)
        scores = [float(y) for y in models[2]]
        rmf_files = models[0]
        rmf_frame_indexes = models[1]
        features = models[3]
        return scores, rmf_files, rmf_frame_indexes, features
1532 
1533 
1535  '''
1536  A class to store data associated to a model
1537  '''
    def __init__(self, stat_file=None, rmf_name=None, rmf_index=None,
                 score=None, features=None):
        """
        @param stat_file stat file the model was read from
        @param rmf_name rmf file containing the model's coordinates
        @param rmf_index frame index within the rmf file
        @param score model score
        @param features dictionary of feature name -> value
        """
        self.rmf_name = rmf_name
        self.rmf_index = rmf_index
        self.score = score
        self.features = features
        self.stat_file = stat_file
1545 
1546  def __repr__(self):
1547  s = "IMP.pmi.output.DataEntry\n"
1548  s += "---- stat file %s \n" % (self.stat_file)
1549  s += "---- rmf file %s \n" % (self.rmf_name)
1550  s += "---- rmf index %s \n" % (str(self.rmf_index))
1551  s += "---- score %s \n" % (str(self.score))
1552  s += "---- number of features %s \n" % (str(len(self.features.keys())))
1553  return s
1554 
1555 
class Cluster:
    '''
    A container for models organized into clusters
    '''
    def __init__(self, cid=None):
        self.cluster_id = cid          # identifier of this cluster
        self.members = []              # model indexes belonging to it
        self.precision = None          # filled in by clustering code
        self.center_index = None       # index of the central model
        self.members_data = {}         # model index -> data object

    def add_member(self, index, data=None):
        """Add model `index` (with optional data) and refresh the
        average score."""
        self.members.append(index)
        self.members_data[index] = data
        self.average_score = self.compute_score()

    def compute_score(self):
        """Mean member score, or None when member data has no scores."""
        try:
            return sum(entry.score for entry in self) / len(self)
        except AttributeError:
            return None

    def __repr__(self):
        parts = [
            "IMP.pmi.output.Cluster\n",
            "---- cluster_id %s \n" % str(self.cluster_id),
            "---- precision %s \n" % str(self.precision),
            "---- average score %s \n" % str(self.average_score),
            "---- number of members %s \n" % str(len(self.members)),
            "---- center index %s \n" % str(self.center_index),
        ]
        return "".join(parts)

    def __getitem__(self, int_slice_adaptor):
        """Return the data of the nth member (int) or an iterator (slice)."""
        if isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        if isinstance(int_slice_adaptor, int):
            return self.members_data[self.members[int_slice_adaptor]]
        raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.members)

    def __iter__(self, slice_key=None):
        positions = range(len(self))
        if slice_key is not None:
            positions = positions[slice_key]
        for pos in positions:
            yield self[pos]

    def __add__(self, other):
        """Merge `other` into this cluster IN PLACE and return self;
        precision and center become stale and are reset."""
        self.members += other.members
        self.members_data.update(other.members_data)
        self.average_score = self.compute_score()
        self.precision = None
        self.center_index = None
        return self
1615 
1616 
def plot_clusters_populations(clusters):
    """Bar-chart the number of members of each cluster."""
    indexes = [cluster.cluster_id for cluster in clusters]
    populations = [len(cluster) for cluster in clusters]

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, populations, 0.5, color='r')
    ax.set_ylabel('Population')
    ax.set_xlabel(('Cluster index'))
    plt.show()
1630 
1631 
def plot_clusters_precisions(clusters):
    """Bar-chart the precision of each cluster (None plotted as 0),
    printing each (id, precision) pair as it is read."""
    indexes = []
    precisions = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        prec = cluster.precision
        print(cluster.cluster_id, prec)
        precisions.append(0.0 if prec is None else prec)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, precisions, 0.5, color='r')
    ax.set_ylabel('Precision [A]')
    ax.set_xlabel(('Cluster index'))
    plt.show()
1650 
1651 
def plot_clusters_scores(clusters):
    """Box-plot the member scores of each cluster into scores.pdf."""
    indexes = [cluster.cluster_id for cluster in clusters]
    values = [[data.score for data in cluster] for cluster in clusters]

    plot_fields_box_plots("scores.pdf", values, indexes, frequencies=None,
                          valuename="Scores", positionname="Cluster index",
                          xlabels=None, scale_plot_length=1.0)
1664 
1665 
class CrossLinkIdentifierDatabase:
    """In-memory store of per-crosslink attributes, keyed by a crosslink
    identifier; persisted via pickle (write/load)."""

    def __init__(self):
        self.clidb = dict()

    def check_key(self, key):
        """Create an empty record for `key` if one does not exist yet."""
        if key not in self.clidb:
            self.clidb[key] = {}

    def _set(self, key, field, value):
        # shared setter: ensure the record exists, then store the field
        self.check_key(key)
        self.clidb[key][field] = value

    def set_unique_id(self, key, value):
        self._set(key, "XLUniqueID", str(value))

    def set_protein1(self, key, value):
        self._set(key, "Protein1", str(value))

    def set_protein2(self, key, value):
        self._set(key, "Protein2", str(value))

    def set_residue1(self, key, value):
        self._set(key, "Residue1", int(value))

    def set_residue2(self, key, value):
        self._set(key, "Residue2", int(value))

    def set_idscore(self, key, value):
        self._set(key, "IDScore", float(value))

    def set_state(self, key, value):
        self._set(key, "State", int(value))

    def set_sigma1(self, key, value):
        self._set(key, "Sigma1", str(value))

    def set_sigma2(self, key, value):
        self._set(key, "Sigma2", str(value))

    def set_psi(self, key, value):
        self._set(key, "Psi", str(value))

    def get_unique_id(self, key):
        return self.clidb[key]["XLUniqueID"]

    def get_protein1(self, key):
        return self.clidb[key]["Protein1"]

    def get_protein2(self, key):
        return self.clidb[key]["Protein2"]

    def get_residue1(self, key):
        return self.clidb[key]["Residue1"]

    def get_residue2(self, key):
        return self.clidb[key]["Residue2"]

    def get_idscore(self, key):
        return self.clidb[key]["IDScore"]

    def get_state(self, key):
        return self.clidb[key]["State"]

    def get_sigma1(self, key):
        return self.clidb[key]["Sigma1"]

    def get_sigma2(self, key):
        return self.clidb[key]["Sigma2"]

    def get_psi(self, key):
        return self.clidb[key]["Psi"]

    def set_float_feature(self, key, value, feature_name):
        self._set(key, feature_name, float(value))

    def set_int_feature(self, key, value, feature_name):
        self._set(key, feature_name, int(value))

    def set_string_feature(self, key, value, feature_name):
        self._set(key, feature_name, str(value))

    def get_feature(self, key, feature_name):
        return self.clidb[key][feature_name]

    def write(self, filename):
        """Pickle the whole database to `filename`."""
        with open(filename, 'wb') as handle:
            pickle.dump(self.clidb, handle)

    def load(self, filename):
        """Replace the database with the pickle stored in `filename`."""
        with open(filename, 'rb') as handle:
            self.clidb = pickle.load(handle)
1766 
1767 
def plot_fields(fields, output, framemin=None, framemax=None):
    """Plot the given fields and save a figure as `output`.
    The fields generally are extracted from a stat file
    using ProcessOutput.get_fields()."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    fig, axs = plt.subplots(nrows=len(fields))
    fig.set_size_inches(10.5, 5.5 * len(fields))
    plt.rc('axes')

    for n, key in enumerate(fields):
        # NOTE: default range is fixed from the first field encountered
        if framemin is None:
            framemin = 0
        if framemax is None:
            framemax = len(fields[key])
        xs = list(range(framemin, framemax))
        ys = [float(v) for v in fields[key][framemin:framemax]]
        # with a single field, subplots() returns one axis, not an array
        ax = axs[n] if len(fields) > 1 else axs
        ax.plot(xs, ys)
        ax.set_title(key, size="xx-large")
        ax.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(output)
1802 
1803 
def plot_field_histogram(name, values_lists, valuename=None, bins=40,
                         colors=None, format="png", reference_xline=None,
                         yplotrange=None, xplotrange=None, normalized=True,
                         leg_names=None):
    '''Plot a list of histograms from a value list.
    @param name the name of the plot
    @param values_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename the y-label
    @param bins the number of bins
    @param colors If None, will use rainbow. Else will use specific list
    @param format output format
    @param reference_xline plot a reference line parallel to the y-axis
    @param yplotrange the range for the y-axis
    @param xplotrange the range for the x-axis
    @param normalized whether the histogram is normalized or not
    @param leg_names names for the legend
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        label = leg_names[nv] if leg_names is not None else str(nv)
        try:
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                density=normalized, histtype='step', lw=4, label=label)
        except AttributeError:
            # fall back for old matplotlib without the `density` keyword
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                normed=normalized, histtype='step', lw=4, label=label)

    # plt.title(name,size="xx-large")
    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        # fix: previously called plt.ylim() with no argument, which only
        # reads the current limits and so silently ignored yplotrange
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
1867 
1868 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None",
                          xlabels=None, scale_plot_length=1.0):
    '''
    Plot time series as boxplots.
    fields is a list of time series, positions are the x-values
    valuename is the y-label, positionname is the x-label
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(float(len(positions)) * scale_plot_length, 5.0))
    fig.canvas.manager.set_window_title(name)

    ax1 = fig.add_subplot(111)

    plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)

    bp = plt.boxplot(values, notch=0, sym='', vert=1,
                     whis=1.5, positions=positions)
    bps = [bp]

    plt.setp(bps[-1]['boxes'], color='black', lw=1.5)
    plt.setp(bps[-1]['whiskers'], color='black', ls=":", lw=1.5)

    if frequencies is not None:
        # overlay the raw samples on top of each box
        for n, series in enumerate(values):
            ax1.plot([positions[n]] * len(series), series, 'gx',
                     alpha=0.7, markersize=7)

    if xlabels is not None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name + ".pdf", dpi=150)
    plt.show()
1910 
1911 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Plot y against x; optionally fix the y-range, save to
    `out_fn`.pdf, and display the figure."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.manager.set_window_title(title)

    ax.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        # keep the automatic x-range, override only the y-range
        xlo, xhi, _, _ = plt.axis()
        plt.axis((xlo, xhi,
                  set_plot_yaxis_range[0], set_plot_yaxis_range[1]))
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
1941 
1942 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot y against x; optionally clamp the axis ranges and
    save the figure to `filename`."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, axs = plt.subplots(1)
    ax = axs

    ax.set_xlabel(labelx, size="xx-large")
    ax.set_ylabel(labely, size="xx-large")
    ax.tick_params(labelsize=18, pad=10)

    handles = [ax.plot(x, y, 'o', color='k', lw=2, ms=0.1, alpha=alpha,
                       c="w")]

    ax.legend(
        loc=0,
        frameon=False,
        scatterpoints=1,
        numpoints=1,
        columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if (ymin is not None) and (ymax is not None):
        ax.set_ylim(ymin, ymax)
    if (xmin is not None) and (xmax is not None):
        ax.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
1982 
1983 
def get_graph_from_hierarchy(hier):
    """Build an edge graph from an IMP hierarchy and draw it.

    Walks the hierarchy with recursive_graph(), then renders it with
    draw_graph(), labeling only nodes close to the root (depth < 3).

    @param hier the IMP hierarchy particle to walk
    """
    (graph, _, depth_dict) = recursive_graph(hier, [], 0, {})

    # show names only for shallow nodes; deeper ones get empty labels
    node_labels_dict = {name: (name if depth < 3 else "")
                        for name, depth in depth_dict.items()}

    draw_graph(graph, labels_dict=node_labels_dict)
1999 
2000 
def recursive_graph(hier, graph, depth, depth_dict):
    """Recursively collect parent->child edges from an IMP hierarchy.

    Each node is named "<hierarchy name>|#<particle index>". Edges are
    appended to `graph` as (parent_name, child_name) tuples, and each
    node's depth is recorded in `depth_dict`.

    @param hier current hierarchy node
    @param graph list of (parent, child) edge tuples, extended in place
    @param depth current recursion depth (incremented on entry,
           decremented on exit)
    @param depth_dict mapping from node name to its depth, updated in place
    @return the (graph, depth, depth_dict) triple
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # Bug fix: test for None *before* len(); the original order
    # (len(children) == 1 or children is None) made the None check
    # unreachable since len(None) raises TypeError first.
    # NOTE(review): stopping at exactly one child (== 1, not == 0) looks
    # deliberate but is worth confirming against the intended traversal.
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)
    else:
        for c in children:
            (graph, depth, depth_dict) = recursive_graph(
                c, graph, depth, depth_dict)
            nameh = IMP.atom.Hierarchy(c).get_name()
            index = str(c.get_particle().get_index())
            namec = nameh + "|#" + index
            graph.append((name1, namec))

        depth = depth - 1
        return (graph, depth, depth_dict)
2025 
2026 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw a graph given as a list of (node1, node2) edge tuples.

    The graph is rendered with matplotlib, always written to 'out.gml',
    and optionally saved to `out_filename`.

    @param graph list of (node, node) tuples defining the edges
    @param labels_dict optional mapping from node name to displayed label
    @param graph_layout 'spring', 'spectral', 'random', or anything else
           for a shell layout
    @param node_size scalar size, or a dict mapping node name to an area
           which is converted to a radius-like size
    @param node_color optional dict mapping node name to a hex color
           string (without the leading '#'); None draws all nodes black
    @param edge_thickness scalar width, or a list of per-edge weights
    @param validation_edges optional list of edges to highlight: green if
           found in the graph (either direction), red if absent
    @param out_filename if given, save the figure to this file
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges, carrying per-edge weights if a list was supplied
    if isinstance(edge_thickness, list):
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        # Bug fix: this must be one hex string per node — the original
        # scalar string "000000" was indexed per node below, yielding a
        # single character ('#0') as the fill color.
        node_color_hex = ["000000"] * G.number_of_nodes()
    else:
        # Bug fix: restore the ColorChange helper (hex <-> rgb converter
        # from IMP.pmi.tools) that `cc.rgb()` below depends on; without it
        # this branch raised NameError.
        cc = IMP.pmi.tools.ColorChange()
        tmpcolor_rgb = []
        tmpcolor_hex = []
        for node in G.nodes():
            cctuple = cc.rgb(node_color[node])
            tmpcolor_rgb.append((cctuple[0] / 255,
                                 cctuple[1] / 255,
                                 cctuple[2] / 255))
            tmpcolor_hex.append(node_color[node])
        node_color_rgb = tmpcolor_rgb
        node_color_hex = tmpcolor_hex

    # get node sizes if dictionary
    if isinstance(node_size, dict):
        node_size = [sqrt(node_size[node]) / pi * 10.0
                     for node in G.nodes()]

    # NOTE(review): set_node_attributes/set_edge_attributes are called with
    # the networkx 1.x argument order (G, name, values); networkx >= 2.0
    # expects (G, values, name) — confirm the supported networkx version.
    for n, node in enumerate(G.nodes()):
        color = node_color_hex[n]
        # node_size may still be the scalar default; only index lists
        size = node_size[n] if isinstance(node_size, list) else node_size
        nx.set_node_attributes(
            G, "graphics",
            {node: {'type': 'ellipse', 'w': size, 'h': size,
                    'fill': '#' + color, 'label': node}})
        nx.set_node_attributes(
            G, "LabelGraphics",
            {node: {'type': 'text', 'text': node, 'color': '#000000',
                    'visible': 'true'}})

    for edge in G.edges():
        nx.set_edge_attributes(
            G, "graphics",
            {edge: {'width': 1, 'fill': '#000000'}})

    # Bug fix: guard against the default validation_edges=None, which the
    # original iterated directly (TypeError).
    if validation_edges:
        for ve in validation_edges:
            print(ve)
            if (ve[0], ve[1]) in G.edges():
                print("found forward")
                nx.set_edge_attributes(
                    G, "graphics",
                    {ve: {'width': 1, 'fill': '#00FF00'}})
            elif (ve[1], ve[0]) in G.edges():
                print("found backward")
                nx.set_edge_attributes(
                    G, "graphics",
                    {(ve[1], ve[0]): {'width': 1, 'fill': '#00FF00'}})
            else:
                G.add_edge(ve[0], ve[1])
                print("not found")
                nx.set_edge_attributes(
                    G, "graphics",
                    {ve: {'width': 1, 'fill': '#FF0000'}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0/8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
        nx.write_gml(G, 'out.gml')
    plt.show()
2138 
2139 
def draw_table():
    """Render an example HTML table with d3js via ipyD3 (IPython only).

    This is still an example: the data, row and column headers are
    hard-coded demo values. Requires the third-party ipyD3 package and an
    IPython display environment.
    """
    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  # typo fix: was "created created with d3js"
                  desc='An example table created with d3js with '
                       'data generated with Python.')
    data = [[1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
             4441.0, 4670.0, 944.0, 110.0],
            [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
             7078.0, 6561.0, 2354.0, 710.0],
            [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
             5661.0, 4991.0, 2032.0, 680.0],
            [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
             4727.0, 3428.0, 1559.0, 620.0],
            [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
             9552.0, 13709.0, 2460.0, 720.0],
            [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
             32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
            [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
             5115.0, 3510.0, 1390.0, 170.0],
            [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
             360.0, 156.0, 100.0]]
    # transpose so each row of the table is one month across all products
    data = [list(i) for i in zip(*data)]
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              'December']]  # typo fix: was 'Deecember'
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: Residue.h:158
A container for models organized into clusters.
Definition: output.py:1556
A class for reading stat files (either rmf or ascii v1 and v2)
Definition: output.py:982
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: atom/Atom.h:245
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:1804
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:1869
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
Miscellaneous utilities.
Definition: pmi/tools.py:1
A class to store data associated to a model.
Definition: output.py:1534
void handle_use_deprecated(std::string message)
Break in this method in gdb to find deprecated uses at runtime.
std::string get_module_version()
Return the version of this module, as a string.
Change color code to hexadecimal to rgb.
Definition: pmi/tools.py:736
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
Definition: output.py:971
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: XYZR.h:47
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:1047
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
Definition: Fragment.h:46
def link_to_rmf
Link to another RMF file.
Definition: output.py:1206
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
Definition: pmi/utilities.h:85
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
Definition: output.py:257
def deprecated_method
Python decorator to mark a method as deprecated.
Definition: __init__.py:9902
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Definition: atom/Atom.h:238
Base class for capturing a modeling protocol.
Definition: output.py:41
The type for a residue.
def write_stat2
Write a single line to a stat file previously created with init_stat2().
Definition: output.py:926
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
A base class for Keys.
Definition: Key.h:45
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:199
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Definition: rigid_bodies.h:766
Display a segment connecting a pair of particles.
Definition: XYZR.h:170
A decorator for a residue.
Definition: Residue.h:137
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
Definition: output.py:226
def get_prot_name_from_particle
Get the protein name from the particle.
Definition: output.py:343
class to link stat files to several rmf files
Definition: output.py:1304
class to allow more advanced handling of RMF files.
Definition: output.py:1180
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
Definition: output.py:1768
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Definition: Chain.h:61
Python classes to represent, score, sample and analyze models.
def init_pdb_best_scoring
Prepare for writing best-scoring PDBs (or mmCIFs) for a sampling run.
Definition: output.py:464
Functionality for loading, creating, manipulating and scoring atomic structures.
def get_rbs_and_beads
Returns unique objects in original order.
Definition: pmi/tools.py:1135
Select hierarchy particles identified by the biological name.
Definition: Selection.h:70
def init_rmf
Initialize an RMF file.
Definition: output.py:561
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: pmi/tools.py:499
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: rigid_bodies.h:768
def init_stat2
Write the header for a stat file in v2 format.
Definition: output.py:842
std::string get_module_version()
Return the version of this module, as a string.
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: pmi/tools.py:574
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27