IMP logo
IMP Reference Guide  develop.1441b25730,2025/12/12
The Integrative Modeling Platform
output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 import IMP
6 import IMP.atom
7 import IMP.core
8 import IMP.pmi
9 import IMP.pmi.tools
10 import IMP.pmi.io
11 import os
12 import sys
13 import ast
14 import RMF
15 import numpy as np
16 import operator
17 import itertools
18 import warnings
19 import string
20 import ihm.format
21 import collections
22 import pickle
23 
24 
25 class _ChainIDs:
26  """Map indices to multi-character chain IDs.
27  We label the first 26 chains A-Z, then we move to two-letter
28  chain IDs: AA through AZ, then BA through BZ, through to ZZ.
29  This continues with longer chain IDs."""
30  def __getitem__(self, ind):
31  chars = string.ascii_uppercase
32  lc = len(chars)
33  ids = []
34  while ind >= lc:
35  ids.append(chars[ind % lc])
36  ind = ind // lc - 1
37  ids.append(chars[ind])
38  return "".join(reversed(ids))
39 
40 
42  """Base class for capturing a modeling protocol.
43  Unlike simple output of model coordinates, a complete
44  protocol includes the input data used, details on the restraints,
45  sampling, and clustering, as well as output models.
46  Use via IMP.pmi.topology.System.add_protocol_output().
47 
48  @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
49  mmCIF files.
50  """
51  pass
52 
53 
54 def _flatten(seq):
55  for elt in seq:
56  if isinstance(elt, (tuple, list)):
57  for elt2 in _flatten(elt):
58  yield elt2
59  else:
60  yield elt
61 
62 
63 def _disambiguate_chain(chid, seen_chains):
64  """Make sure that the chain ID is unique; warn and correct if it isn't"""
65  # Handle null chain IDs
66  if chid == '\0':
67  chid = ' '
68 
69  if chid in seen_chains:
70  warnings.warn("Duplicate chain ID '%s' encountered" % chid,
72 
73  for suffix in itertools.count(1):
74  new_chid = chid + "%d" % suffix
75  if new_chid not in seen_chains:
76  seen_chains.add(new_chid)
77  return new_chid
78  seen_chains.add(chid)
79  return chid
80 
81 
82 def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
83  write_all_residues_per_bead):
84  for n, tupl in enumerate(particle_infos_for_pdb):
85  (xyz, atom_type, residue_type,
86  chain_id, residue_index, all_indexes, radius) = tupl
87  if atom_type is None:
88  atom_type = IMP.atom.AT_CA
89  if write_all_residues_per_bead and all_indexes is not None:
90  for residue_number in all_indexes:
91  flpdb.write(
92  IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
93  xyz[1] - geometric_center[1],
94  xyz[2] - geometric_center[2]),
95  n+1, atom_type, residue_type,
96  chain_id[:1], residue_number, ' ',
97  1.00, radius))
98  else:
99  flpdb.write(
100  IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
101  xyz[1] - geometric_center[1],
102  xyz[2] - geometric_center[2]),
103  n+1, atom_type, residue_type,
104  chain_id[:1], residue_index, ' ',
105  1.00, radius))
106  flpdb.write("ENDMDL\n")
107 
108 
109 _Entity = collections.namedtuple('_Entity', ('id', 'seq'))
110 _ChainInfo = collections.namedtuple('_ChainInfo', ('entity', 'name'))
111 
112 
def _get_chain_info(chains, root_hier):
    """Build per-chain info for mmCIF output.

    @param chains    dict mapping molecule name -> chain ID
    @param root_hier top-level Hierarchy containing the molecules
    @return (dict of chain ID -> _ChainInfo, list of unique _Entity)
    """
    chain_info = {}
    entities = {}
    all_entities = []
    for mol in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
        # NOTE(review): this assignment was lost in extraction; restored
        # to use the same naming as Output._init_dictchain - confirm
        molname = IMP.pmi.get_molecule_name_and_copy(mol)
        chain_id = chains[molname]
        chain = IMP.atom.Chain(mol)
        seq = chain.get_sequence()
        # Molecules with identical sequences share one entity
        if seq not in entities:
            entities[seq] = e = _Entity(id=len(entities)+1, seq=seq)
            all_entities.append(e)
        entity = entities[seq]
        info = _ChainInfo(entity=entity, name=molname)
        chain_info[chain_id] = info
    return chain_info, all_entities
129 
130 
def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
                          write_all_residues_per_bead, chains, root_hier):
    """Write one model in mmCIF format via ihm.format.CifWriter.

    Emits _entry, _entity, _entity_poly, _struct_asym and _atom_site
    categories for the given particle infos.
    """
    # get dict with keys=chain IDs, values=chain info
    chain_info, entities = _get_chain_info(chains, root_hier)

    writer = ihm.format.CifWriter(flpdb)
    writer.start_block('model')
    with writer.category("_entry") as lp:
        lp.write(id='model')

    with writer.loop("_entity", ["id", "type"]) as lp:
        for e in entities:
            lp.write(id=e.id, type="polymer")

    with writer.loop("_entity_poly",
                     ["entity_id", "pdbx_seq_one_letter_code"]) as lp:
        for e in entities:
            lp.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)

    with writer.loop("_struct_asym", ["id", "entity_id", "details"]) as lp:
        # Longer chain IDs (e.g. AA) should always come after shorter (e.g. Z)
        for chid in sorted(chains.values(), key=lambda x: (len(x.strip()), x)):
            ci = chain_info[chid]
            lp.write(id=chid, entity_id=ci.entity.id, details=ci.name)

    with writer.loop("_atom_site",
                     ["group_PDB", "type_symbol", "label_atom_id",
                      "label_comp_id", "label_asym_id", "label_seq_id",
                      "auth_seq_id",
                      "Cartn_x", "Cartn_y", "Cartn_z", "label_entity_id",
                      "pdbx_pdb_model_num",
                      "id"]) as lp:
        ordinal = 1
        for n, tupl in enumerate(particle_infos_for_pdb):
            (xyz, atom_type, residue_type,
             chain_id, residue_index, all_indexes, radius) = tupl
            ci = chain_info[chain_id]
            if atom_type is None:
                atom_type = IMP.atom.AT_CA
            c = (xyz[0] - geometric_center[0],
                 xyz[1] - geometric_center[1],
                 xyz[2] - geometric_center[2])
            if write_all_residues_per_bead and all_indexes is not None:
                for residue_number in all_indexes:
                    # Fixed: write the per-residue number being iterated,
                    # not the bead's single representative residue_index
                    # (matches _write_pdb_internal's behavior)
                    lp.write(group_PDB='ATOM',
                             type_symbol='C',
                             label_atom_id=atom_type.get_string(),
                             label_comp_id=residue_type.get_string(),
                             label_asym_id=chain_id,
                             label_seq_id=residue_number,
                             auth_seq_id=residue_number, Cartn_x=c[0],
                             Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                             pdbx_pdb_model_num=1,
                             label_entity_id=ci.entity.id)
                    ordinal += 1
            else:
                lp.write(group_PDB='ATOM', type_symbol='C',
                         label_atom_id=atom_type.get_string(),
                         label_comp_id=residue_type.get_string(),
                         label_asym_id=chain_id,
                         label_seq_id=residue_index,
                         auth_seq_id=residue_index, Cartn_x=c[0],
                         Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                         pdbx_pdb_model_num=1,
                         label_entity_id=ci.entity.id)
                ordinal += 1
197 
198 
class Output:
    """Class for easy writing of PDBs, RMFs, and stat files

    @note Model should be updated prior to writing outputs.
    """
    def __init__(self, ascii=True, atomistic=False):
        # Per-filename registries for each kind of output
        self.dictionary_pdbs = {}
        self._pdb_mmcif = {}
        self.dictionary_rmfs = {}
        self.dictionary_stats = {}
        self.dictionary_stats2 = {}
        # Best-scoring bookkeeping (set up by init_pdb_best_scoring)
        self.best_score_list = None
        self.nbestscoring = None
        self.prefixes = []
        self.replica_exchange = False
        # True: write stat files as text; False: pickle them
        self.ascii = ascii
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # 1-character chain IDs, suitable for PDB output
        self.chainids = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                         "abcdefghijklmnopqrstuvwxyz0123456789")
        # Multi-character chain IDs, suitable for mmCIF output
        self.multi_chainids = _ChainIDs()
        # keys are molecule names, values are chain ids
        self.dictchain = {}
        self.particle_infos_for_pdb = {}
        self.atomistic = atomistic
225 
226  def get_pdb_names(self):
227  """Get a list of all PDB files being output by this instance"""
228  return list(self.dictionary_pdbs.keys())
229 
230  def get_rmf_names(self):
231  return list(self.dictionary_rmfs.keys())
232 
233  def get_stat_names(self):
234  return list(self.dictionary_stats.keys())
235 
236  def _init_dictchain(self, name, prot, multichar_chain=False, mmcif=False):
237  self.dictchain[name] = {}
238  seen_chains = set()
239 
240  # attempt to find PMI objects.
241  self.atomistic = True # detects automatically
242  for n, mol in enumerate(IMP.atom.get_by_type(
243  prot, IMP.atom.MOLECULE_TYPE)):
244  chid = IMP.atom.Chain(mol).get_id()
245  if not mmcif and len(chid) > 1:
246  raise ValueError(
247  "The system contains at least one chain ID (%s) that "
248  "is more than 1 character long; this cannot be "
249  "represented in PDB. Either write mmCIF files "
250  "instead, or assign 1-character IDs to all chains "
251  "(this can be done with the `chain_ids` argument to "
252  "BuildSystem.add_state())." % chid)
253  chid = _disambiguate_chain(chid, seen_chains)
255  self.dictchain[name][molname] = chid
256 
257  def init_pdb(self, name, prot, mmcif=False):
258  """Init PDB Writing.
259  @param name The PDB filename
260  @param prot The hierarchy to write to this pdb file
261  @param mmcif If True, write PDBs in mmCIF format
262  @note if the PDB name is 'System' then will use Selection
263  to get molecules
264  """
265  flpdb = open(name, 'w')
266  flpdb.close()
267  self.dictionary_pdbs[name] = prot
268  self._pdb_mmcif[name] = mmcif
269  self._init_dictchain(name, prot, mmcif=mmcif)
270 
271  def write_psf(self, filename, name):
272  flpsf = open(filename, 'w')
273  flpsf.write("PSF CMAP CHEQ" + "\n")
274  index_residue_pair_list = {}
275  (particle_infos_for_pdb, geometric_center) = \
276  self.get_particle_infos_for_pdb_writing(name)
277  nparticles = len(particle_infos_for_pdb)
278  flpsf.write(str(nparticles) + " !NATOM" + "\n")
279  for n, p in enumerate(particle_infos_for_pdb):
280  atom_index = n+1
281  residue_type = p[2]
282  chain = p[3]
283  resid = p[4]
284  flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}'
285  '{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}'
286  '{14:14.6f}{15:14.6f}'.format(
287  atom_index, " ", chain, " ", str(resid), " ",
288  '"'+residue_type.get_string()+'"', " ", "C",
289  " ", "C", 1.0, 0.0, 0, 0.0, 0.0))
290  flpsf.write('\n')
291  if chain not in index_residue_pair_list:
292  index_residue_pair_list[chain] = [(atom_index, resid)]
293  else:
294  index_residue_pair_list[chain].append((atom_index, resid))
295 
296  # now write the connectivity
297  indexes_pairs = []
298  for chain in sorted(index_residue_pair_list.keys()):
299 
300  ls = index_residue_pair_list[chain]
301  # sort by residue
302  ls = sorted(ls, key=lambda tup: tup[1])
303  # get the index list
304  indexes = [x[0] for x in ls]
305  # get the contiguous pairs
306  indexes_pairs.extend(IMP.pmi.tools.sublist_iterator(
307  indexes, lmin=2, lmax=2))
308  nbonds = len(indexes_pairs)
309  flpsf.write(str(nbonds)+" !NBOND: bonds"+"\n")
310 
311  # save bonds in fixed column format
312  for i in range(0, len(indexes_pairs), 4):
313  for bond in indexes_pairs[i:i+4]:
314  flpsf.write('{0:8d}{1:8d}'.format(*bond))
315  flpsf.write('\n')
316 
317  del particle_infos_for_pdb
318  flpsf.close()
319 
320  def write_pdb(self, name, appendmode=True,
321  translate_to_geometric_center=False,
322  write_all_residues_per_bead=False):
323 
324  (particle_infos_for_pdb,
325  geometric_center) = self.get_particle_infos_for_pdb_writing(name)
326 
327  if not translate_to_geometric_center:
328  geometric_center = (0, 0, 0)
329 
330  filemode = 'a' if appendmode else 'w'
331  with open(name, filemode) as flpdb:
332  if self._pdb_mmcif[name]:
333  _write_mmcif_internal(flpdb, particle_infos_for_pdb,
334  geometric_center,
335  write_all_residues_per_bead,
336  self.dictchain[name],
337  self.dictionary_pdbs[name])
338  else:
339  _write_pdb_internal(flpdb, particle_infos_for_pdb,
340  geometric_center,
341  write_all_residues_per_bead)
342 
343  def get_prot_name_from_particle(self, name, p):
344  """Get the protein name from the particle.
345  This is done by traversing the hierarchy."""
346  return IMP.pmi.get_molecule_name_and_copy(p), True
347 
348  def get_particle_infos_for_pdb_writing(self, name):
349  # index_residue_pair_list={}
350 
351  # the resindexes dictionary keep track of residues that have
352  # been already added to avoid duplication
353  # highest resolution have highest priority
354  resindexes_dict = {}
355 
356  # this dictionary will contain the sequence of tuples needed to
357  # write the pdb
358  particle_infos_for_pdb = []
359 
360  geometric_center = [0, 0, 0]
361  atom_count = 0
362 
363  # select highest resolution, if hierarchy is non-empty
364  if (not IMP.core.XYZR.get_is_setup(self.dictionary_pdbs[name])
365  and self.dictionary_pdbs[name].get_number_of_children() == 0):
366  ps = []
367  else:
368  sel = IMP.atom.Selection(self.dictionary_pdbs[name], resolution=0)
369  ps = sel.get_selected_particles()
370 
371  for n, p in enumerate(ps):
372  protname, is_a_bead = self.get_prot_name_from_particle(name, p)
373 
374  if protname not in resindexes_dict:
375  resindexes_dict[protname] = []
376 
377  if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
378  residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
379  rt = residue.get_residue_type()
380  resind = residue.get_index()
381  atomtype = IMP.atom.Atom(p).get_atom_type()
382  xyz = list(IMP.core.XYZ(p).get_coordinates())
383  radius = IMP.core.XYZR(p).get_radius()
384  geometric_center[0] += xyz[0]
385  geometric_center[1] += xyz[1]
386  geometric_center[2] += xyz[2]
387  atom_count += 1
388  particle_infos_for_pdb.append(
389  (xyz, atomtype, rt, self.dictchain[name][protname],
390  resind, None, radius))
391  resindexes_dict[protname].append(resind)
392 
394 
395  residue = IMP.atom.Residue(p)
396  resind = residue.get_index()
397  # skip if the residue was already added by atomistic resolution
398  # 0
399  if resind in resindexes_dict[protname]:
400  continue
401  else:
402  resindexes_dict[protname].append(resind)
403  rt = residue.get_residue_type()
404  xyz = IMP.core.XYZ(p).get_coordinates()
405  radius = IMP.core.XYZR(p).get_radius()
406  geometric_center[0] += xyz[0]
407  geometric_center[1] += xyz[1]
408  geometric_center[2] += xyz[2]
409  atom_count += 1
410  particle_infos_for_pdb.append(
411  (xyz, None, rt, self.dictchain[name][protname], resind,
412  None, radius))
413 
414  elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
415  resindexes = list(IMP.pmi.tools.get_residue_indexes(p))
416  resind = resindexes[len(resindexes) // 2]
417  if resind in resindexes_dict[protname]:
418  continue
419  else:
420  resindexes_dict[protname].append(resind)
421  rt = IMP.atom.ResidueType('BEA')
422  xyz = IMP.core.XYZ(p).get_coordinates()
423  radius = IMP.core.XYZR(p).get_radius()
424  geometric_center[0] += xyz[0]
425  geometric_center[1] += xyz[1]
426  geometric_center[2] += xyz[2]
427  atom_count += 1
428  particle_infos_for_pdb.append(
429  (xyz, None, rt, self.dictchain[name][protname], resind,
430  resindexes, radius))
431 
432  else:
433  if is_a_bead:
434  rt = IMP.atom.ResidueType('BEA')
435  resindexes = list(IMP.pmi.tools.get_residue_indexes(p))
436  if len(resindexes) > 0:
437  resind = resindexes[len(resindexes) // 2]
438  xyz = IMP.core.XYZ(p).get_coordinates()
439  radius = IMP.core.XYZR(p).get_radius()
440  geometric_center[0] += xyz[0]
441  geometric_center[1] += xyz[1]
442  geometric_center[2] += xyz[2]
443  atom_count += 1
444  particle_infos_for_pdb.append(
445  (xyz, None, rt, self.dictchain[name][protname],
446  resind, resindexes, radius))
447 
448  if atom_count > 0:
449  geometric_center = (geometric_center[0] / atom_count,
450  geometric_center[1] / atom_count,
451  geometric_center[2] / atom_count)
452 
453  # sort by chain ID, then residue index. Longer chain IDs (e.g. AA)
454  # should always come after shorter (e.g. Z)
455  particle_infos_for_pdb = sorted(particle_infos_for_pdb,
456  key=lambda x: (len(x[3]), x[3], x[4]))
457 
458  return (particle_infos_for_pdb, geometric_center)
459 
460  def write_pdbs(self, appendmode=True, mmcif=False):
461  for pdb in self.dictionary_pdbs.keys():
462  self.write_pdb(pdb, appendmode)
463 
464  def init_pdb_best_scoring(self, prefix, prot, nbestscoring,
465  replica_exchange=False, mmcif=False,
466  best_score_file='best.scores.rex.py'):
467  """Prepare for writing best-scoring PDBs (or mmCIFs) for a
468  sampling run.
469 
470  @param prefix Initial part of each PDB filename (e.g. 'model').
471  @param prot The top-level Hierarchy to output.
472  @param nbestscoring The number of best-scoring files to output.
473  @param replica_exchange Whether to combine best scores from a
474  replica exchange run.
475  @param mmcif If True, output models in mmCIF format. If False
476  (the default) output in legacy PDB format.
477  @param best_score_file The filename to use for replica
478  exchange scores.
479  """
480 
481  self._pdb_best_scoring_mmcif = mmcif
482  fileext = '.cif' if mmcif else '.pdb'
483  self.prefixes.append(prefix)
484  self.replica_exchange = replica_exchange
485  if not self.replica_exchange:
486  # common usage
487  # if you are not in replica exchange mode
488  # initialize the array of scores internally
489  self.best_score_list = []
490  else:
491  # otherwise the replicas must communicate
492  # through a common file to know what are the best scores
493  self.best_score_file_name = best_score_file
494  self.best_score_list = []
495  with open(self.best_score_file_name, "w") as best_score_file:
496  best_score_file.write(
497  "self.best_score_list=" + str(self.best_score_list) + "\n")
498 
499  self.nbestscoring = nbestscoring
500  for i in range(self.nbestscoring):
501  name = prefix + "." + str(i) + fileext
502  flpdb = open(name, 'w')
503  flpdb.close()
504  self.dictionary_pdbs[name] = prot
505  self._pdb_mmcif[name] = mmcif
506  self._init_dictchain(name, prot, mmcif=mmcif)
507 
508  def write_pdb_best_scoring(self, score):
509  if self.nbestscoring is None:
510  print("Output.write_pdb_best_scoring: init_pdb_best_scoring "
511  "not run")
512 
513  mmcif = self._pdb_best_scoring_mmcif
514  fileext = '.cif' if mmcif else '.pdb'
515  # update the score list
516  if self.replica_exchange:
517  # read the self.best_score_list from the file
518  with open(self.best_score_file_name) as fh:
519  self.best_score_list = ast.literal_eval(
520  fh.read().split('=')[1])
521 
522  if len(self.best_score_list) < self.nbestscoring:
523  self.best_score_list.append(score)
524  self.best_score_list.sort()
525  index = self.best_score_list.index(score)
526  for prefix in self.prefixes:
527  for i in range(len(self.best_score_list) - 2, index - 1, -1):
528  oldname = prefix + "." + str(i) + fileext
529  newname = prefix + "." + str(i + 1) + fileext
530  # rename on Windows fails if newname already exists
531  if os.path.exists(newname):
532  os.unlink(newname)
533  os.rename(oldname, newname)
534  filetoadd = prefix + "." + str(index) + fileext
535  self.write_pdb(filetoadd, appendmode=False)
536 
537  else:
538  if score < self.best_score_list[-1]:
539  self.best_score_list.append(score)
540  self.best_score_list.sort()
541  self.best_score_list.pop(-1)
542  index = self.best_score_list.index(score)
543  for prefix in self.prefixes:
544  for i in range(len(self.best_score_list) - 1,
545  index - 1, -1):
546  oldname = prefix + "." + str(i) + fileext
547  newname = prefix + "." + str(i + 1) + fileext
548  os.rename(oldname, newname)
549  filenametoremove = prefix + \
550  "." + str(self.nbestscoring) + fileext
551  os.remove(filenametoremove)
552  filetoadd = prefix + "." + str(index) + fileext
553  self.write_pdb(filetoadd, appendmode=False)
554 
555  if self.replica_exchange:
556  # write the self.best_score_list to the file
557  with open(self.best_score_file_name, "w") as best_score_file:
558  best_score_file.write(
559  "self.best_score_list=" + str(self.best_score_list) + '\n')
560 
561  def init_rmf(self, name, hierarchies, rs=None, geometries=None,
562  listofobjects=None):
563  """
564  Initialize an RMF file
565 
566  @param name the name of the RMF file
567  @param hierarchies the hierarchies to be included (it is a list)
568  @param rs optional, the restraint sets (it is a list)
569  @param geometries optional, the geometries (it is a list)
570  @param listofobjects optional, the list of objects for the stat
571  (it is a list)
572  """
573  rh = RMF.create_rmf_file(name)
574  IMP.rmf.add_hierarchies(rh, hierarchies)
575  cat = None
576  outputkey_rmfkey = None
577 
578  if rs is not None:
579  IMP.rmf.add_restraints(rh, rs)
580  if geometries is not None:
581  IMP.rmf.add_geometries(rh, geometries)
582  if listofobjects is not None:
583  cat = rh.get_category("stat")
584  outputkey_rmfkey = {}
585  for o in listofobjects:
586  if "get_output" not in dir(o):
587  raise ValueError(
588  "Output: object %s doesn't have get_output() method"
589  % str(o))
590  output = o.get_output()
591  for outputkey in output:
592  rmftag = RMF.string_tag
593  if isinstance(output[outputkey], float):
594  rmftag = RMF.float_tag
595  elif isinstance(output[outputkey], int):
596  rmftag = RMF.int_tag
597  elif isinstance(output[outputkey], str):
598  rmftag = RMF.string_tag
599  else:
600  rmftag = RMF.string_tag
601  rmfkey = rh.get_key(cat, outputkey, rmftag)
602  outputkey_rmfkey[outputkey] = rmfkey
603  outputkey_rmfkey["rmf_file"] = \
604  rh.get_key(cat, "rmf_file", RMF.string_tag)
605  outputkey_rmfkey["rmf_frame_index"] = \
606  rh.get_key(cat, "rmf_frame_index", RMF.int_tag)
607 
608  self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey, listofobjects)
609 
610  def add_restraints_to_rmf(self, name, objectlist):
611  for o in _flatten(objectlist):
612  try:
613  rs = o.get_restraint_for_rmf()
614  if not isinstance(rs, (list, tuple)):
615  rs = [rs]
616  except: # noqa: E722
617  rs = [o.get_restraint()]
619  self.dictionary_rmfs[name][0], rs)
620 
621  def add_geometries_to_rmf(self, name, objectlist):
622  for o in objectlist:
623  geos = o.get_geometries()
624  IMP.rmf.add_geometries(self.dictionary_rmfs[name][0], geos)
625 
626  def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
627  for o in objectlist:
628 
629  pps = o.get_particle_pairs()
630  for pp in pps:
632  self.dictionary_rmfs[name][0],
634 
635  def write_rmf(self, name):
636  IMP.rmf.save_frame(self.dictionary_rmfs[name][0])
637  if self.dictionary_rmfs[name][1] is not None:
638  outputkey_rmfkey = self.dictionary_rmfs[name][2]
639  listofobjects = self.dictionary_rmfs[name][3]
640  for o in listofobjects:
641  output = o.get_output()
642  for outputkey in output:
643  rmfkey = outputkey_rmfkey[outputkey]
644  try:
645  n = self.dictionary_rmfs[name][0].get_root_node()
646  n.set_value(rmfkey, output[outputkey])
647  except NotImplementedError:
648  continue
649  rmfkey = outputkey_rmfkey["rmf_file"]
650  self.dictionary_rmfs[name][0].get_root_node().set_value(
651  rmfkey, name)
652  rmfkey = outputkey_rmfkey["rmf_frame_index"]
653  nframes = self.dictionary_rmfs[name][0].get_number_of_frames()
654  self.dictionary_rmfs[name][0].get_root_node().set_value(
655  rmfkey, nframes-1)
656  self.dictionary_rmfs[name][0].flush()
657 
658  def close_rmf(self, name):
659  rh = self.dictionary_rmfs[name][0]
660  del self.dictionary_rmfs[name]
661  del rh
662 
663  def write_rmfs(self):
664  for rmfinfo in self.dictionary_rmfs.keys():
665  self.write_rmf(rmfinfo[0])
666 
667  def init_stat(self, name, listofobjects):
668  if self.ascii:
669  flstat = open(name, 'w')
670  flstat.close()
671  else:
672  flstat = open(name, 'wb')
673  flstat.close()
674 
675  # check that all objects in listofobjects have a get_output method
676  for o in listofobjects:
677  if "get_output" not in dir(o):
678  raise ValueError(
679  "Output: object %s doesn't have get_output() method"
680  % str(o))
681  self.dictionary_stats[name] = listofobjects
682 
683  def set_output_entry(self, key, value):
684  self.initoutput.update({key: value})
685 
686  def write_stat(self, name, appendmode=True):
687  output = self.initoutput
688  for obj in self.dictionary_stats[name]:
689  d = obj.get_output()
690  # remove all entries that begin with _ (private entries)
691  dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
692  output.update(dfiltered)
693 
694  if appendmode:
695  writeflag = 'a'
696  else:
697  writeflag = 'w'
698 
699  if self.ascii:
700  flstat = open(name, writeflag)
701  flstat.write("%s \n" % output)
702  flstat.close()
703  else:
704  flstat = open(name, writeflag + 'b')
705  pickle.dump(output, flstat, 2)
706  flstat.close()
707 
708  def write_stats(self):
709  for stat in self.dictionary_stats.keys():
710  self.write_stat(stat)
711 
712  def get_stat(self, name):
713  output = {}
714  for obj in self.dictionary_stats[name]:
715  output.update(obj.get_output())
716  return output
717 
718  def write_test(self, name, listofobjects):
719  flstat = open(name, 'w')
720  output = self.initoutput
721  for o in listofobjects:
722  if "get_test_output" not in dir(o) and "get_output" not in dir(o):
723  raise ValueError(
724  "Output: object %s doesn't have get_output() or "
725  "get_test_output() method" % str(o))
726  self.dictionary_stats[name] = listofobjects
727 
728  for obj in self.dictionary_stats[name]:
729  try:
730  d = obj.get_test_output()
731  except: # noqa: E722
732  d = obj.get_output()
733  # remove all entries that begin with _ (private entries)
734  dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
735  output.update(dfiltered)
736  flstat.write("%s \n" % output)
737  flstat.close()
738 
739  def test(self, name, listofobjects, tolerance=1e-5):
740  output = self.initoutput
741  for o in listofobjects:
742  if "get_test_output" not in dir(o) and "get_output" not in dir(o):
743  raise ValueError(
744  "Output: object %s doesn't have get_output() or "
745  "get_test_output() method" % str(o))
746  for obj in listofobjects:
747  try:
748  output.update(obj.get_test_output())
749  except: # noqa: E722
750  output.update(obj.get_output())
751 
752  flstat = open(name, 'r')
753 
754  passed = True
755  for fl in flstat:
756  test_dict = ast.literal_eval(fl)
757  for k in test_dict:
758  if k in output:
759  old_value = str(test_dict[k])
760  new_value = str(output[k])
761  try:
762  float(old_value)
763  is_float = True
764  except ValueError:
765  is_float = False
766 
767  if is_float:
768  fold = float(old_value)
769  fnew = float(new_value)
770  diff = abs(fold - fnew)
771  if diff > tolerance:
772  print("%s: test failed, old value: %s new value %s; "
773  "diff %f > %f" % (str(k), str(old_value),
774  str(new_value), diff,
775  tolerance), file=sys.stderr)
776  passed = False
777  elif test_dict[k] != output[k]:
778  if len(old_value) < 50 and len(new_value) < 50:
779  print("%s: test failed, old value: %s new value %s"
780  % (str(k), old_value, new_value),
781  file=sys.stderr)
782  passed = False
783  else:
784  print("%s: test failed, omitting results (too long)"
785  % str(k), file=sys.stderr)
786  passed = False
787 
788  else:
789  print("%s from old objects (file %s) not in new objects"
790  % (str(k), str(name)), file=sys.stderr)
791  flstat.close()
792  return passed
793 
794  def get_environment_variables(self):
795  import os
796  return str(os.environ)
797 
798  def get_versions_of_relevant_modules(self):
799  import IMP
800  versions = {}
801  versions["IMP_VERSION"] = IMP.get_module_version()
802  versions["PMI_VERSION"] = IMP.pmi.get_module_version()
803  try:
804  import IMP.isd2
805  versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
806  except ImportError:
807  pass
808  try:
809  import IMP.isd_emxl
810  versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
811  except ImportError:
812  pass
813  return versions
814 
815  def init_stat2(self, name, listofobjects, extralabels=None,
816  listofsummedobjects=None):
817  # this is a new stat file that should be less
818  # space greedy!
819  # listofsummedobjects must be in the form
820  # [([obj1,obj2,obj3,obj4...],label)]
821  # extralabels
822 
823  if listofsummedobjects is None:
824  listofsummedobjects = []
825  if extralabels is None:
826  extralabels = []
827  flstat = open(name, 'w')
828  output = {}
829  stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
830  stat2_keywords.update(
831  {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
832  stat2_keywords.update(
833  {"STAT2HEADER_IMP_VERSIONS":
834  str(self.get_versions_of_relevant_modules())})
835  stat2_inverse = {}
836 
837  for obj in listofobjects:
838  if "get_output" not in dir(obj):
839  raise ValueError(
840  "Output: object %s doesn't have get_output() method"
841  % str(obj))
842  else:
843  d = obj.get_output()
844  # remove all entries that begin with _ (private entries)
845  dfiltered = dict((k, v)
846  for k, v in d.items() if k[0] != "_")
847  output.update(dfiltered)
848 
849  # check for customizable entries
850  for obj in listofsummedobjects:
851  for t in obj[0]:
852  if "get_output" not in dir(t):
853  raise ValueError(
854  "Output: object %s doesn't have get_output() method"
855  % str(t))
856  else:
857  if "_TotalScore" not in t.get_output():
858  raise ValueError(
859  "Output: object %s doesn't have _TotalScore "
860  "entry to be summed" % str(t))
861  else:
862  output.update({obj[1]: 0.0})
863 
864  for k in extralabels:
865  output.update({k: 0.0})
866 
867  for n, k in enumerate(output):
868  stat2_keywords.update({n: k})
869  stat2_inverse.update({k: n})
870 
871  flstat.write("%s \n" % stat2_keywords)
872  flstat.close()
873  self.dictionary_stats2[name] = (
874  listofobjects,
875  stat2_inverse,
876  listofsummedobjects,
877  extralabels)
878 
879  def write_stat2(self, name, appendmode=True):
880  output = {}
881  (listofobjects, stat2_inverse, listofsummedobjects,
882  extralabels) = self.dictionary_stats2[name]
883 
884  # writing objects
885  for obj in listofobjects:
886  od = obj.get_output()
887  dfiltered = dict((k, v) for k, v in od.items() if k[0] != "_")
888  for k in dfiltered:
889  output.update({stat2_inverse[k]: od[k]})
890 
891  # writing summedobjects
892  for so in listofsummedobjects:
893  partial_score = 0.0
894  for t in so[0]:
895  d = t.get_output()
896  partial_score += float(d["_TotalScore"])
897  output.update({stat2_inverse[so[1]]: str(partial_score)})
898 
899  # writing extralabels
900  for k in extralabels:
901  if k in self.initoutput:
902  output.update({stat2_inverse[k]: self.initoutput[k]})
903  else:
904  output.update({stat2_inverse[k]: "None"})
905 
906  if appendmode:
907  writeflag = 'a'
908  else:
909  writeflag = 'w'
910 
911  flstat = open(name, writeflag)
912  flstat.write("%s \n" % output)
913  flstat.close()
914 
915  def write_stats2(self):
916  for stat in self.dictionary_stats2.keys():
917  self.write_stat2(stat)
918 
919 
class OutputStatistics:
    """Collect statistics from ProcessOutput.get_fields().
    Counters of the total number of frames read, plus the models that
    passed the various filters used in get_fields(), are provided."""
    # NOTE(review): the 'class OutputStatistics:' header line was lost
    # in extraction and has been restored from the docstring.
    def __init__(self):
        self.total = 0
        self.passed_get_every = 0
        self.passed_filterout = 0
        self.passed_filtertuple = 0
929 
930 
932  """A class for reading stat files (either rmf or ascii v1 and v2)"""
    def __init__(self, filename):
        """Open `filename` and detect its format (RMF, stat1 or stat2).

        Sets exactly one of self.isrmf / self.isstat1 / self.isstat2 and
        prepares the key lookup tables used by get_fields().
        @raise ValueError if filename is None.
        """
        self.filename = filename
        self.isstat1 = False
        self.isstat2 = False
        self.isrmf = False

        if self.filename is None:
            raise ValueError("No file name provided. Use -h for help")

        try:
            # let's see if that is an rmf file
            rh = RMF.open_rmf_file_read_only(self.filename)
            self.isrmf = True
            cat = rh.get_category('stat')
            rmf_klist = rh.get_keys(cat)
            # map human-readable key names to RMF key handles
            self.rmf_names_keys = dict([(rh.get_name(k), k)
                                        for k in rmf_klist])
            del rh

        except IOError:
            f = open(self.filename, "r")
            # try with an ascii stat file
            # get the keys from the first line
            for line in f.readlines():
                d = ast.literal_eval(line)
                self.klist = list(d.keys())
                # check if it is a stat2 file
                if "STAT2HEADER" in self.klist:
                    self.isstat2 = True
                    for k in self.klist:
                        if "STAT2HEADER" in str(k):
                            # if print_header: print k, d[k]
                            del d[k]
                    # stat2 header maps full key name -> column index
                    stat2_dict = d
                    # get the list of keys sorted by value
                    kkeys = [k[0]
                             for k in sorted(stat2_dict.items(),
                                             key=operator.itemgetter(1))]
                    self.klist = [k[1]
                                  for k in sorted(stat2_dict.items(),
                                                  key=operator.itemgetter(1))]
                    # inverse map: key name -> column index, used by
                    # get_fields() to decode stat2 data lines
                    self.invstat2_dict = {}
                    for k in kkeys:
                        self.invstat2_dict.update({stat2_dict[k]: k})
                else:
                    # NOTE(review): a "warnings.warn(" call appears to be
                    # truncated here in this view (orphan string arguments
                    # below) -- verify against the upstream source.
                        "statfile v1 is deprecated. "
                        "Please convert to statfile v2.\n")
                    self.isstat1 = True
                    self.klist.sort()

                # only the first line is needed to read the header
                break
            f.close()
986 
987  def get_keys(self):
988  if self.isrmf:
989  return sorted(self.rmf_names_keys.keys())
990  else:
991  return self.klist
992 
    def show_keys(self, ncolumns=2, truncate=65):
        """Print the available keys in `ncolumns` columns, truncating each
        name at `truncate` characters."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)
995 
    def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
                   statistics=None):
        '''
        Get the desired field names, and return a dictionary.
        Namely, "fields" are the queried keys in the stat file
        (eg. ["Total_Score",...])
        The returned data structure is a dictionary, where each key is
        a field and the value is the time series (ie, frame ordered series)
        of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,,...],....})

        @param fields (list of strings) queried keys in the stat file
               (eg. "Total_Score"....)
        @param filterout specify if you want to "grep" out something from
               the file, so that it is faster
        @param filtertuple a tuple that contains
               ("TheKeyToBeFiltered",relationship,value)
               where relationship = "<", "==", or ">"
        @param get_every only read every Nth line from the file
        @param statistics if provided, accumulate statistics in an
               OutputStatistics object
        '''

        if statistics is None:
            statistics = OutputStatistics()
        outdict = {}
        for field in fields:
            outdict[field] = []

        # print fields values
        if self.isrmf:
            rh = RMF.open_rmf_file_read_only(self.filename)
            nframes = rh.get_number_of_frames()
            for i in range(nframes):
                statistics.total += 1
                # "get_every" and "filterout" not enforced for RMF
                statistics.passed_get_every += 1
                statistics.passed_filterout += 1
                rh.set_current_frame(RMF.FrameID(i))
                if filtertuple is not None:
                    keytobefiltered = filtertuple[0]
                    relationship = filtertuple[1]
                    value = filtertuple[2]
                    datavalue = rh.get_root_node().get_value(
                        self.rmf_names_keys[keytobefiltered])
                    if self.isfiltered(datavalue, relationship, value):
                        continue

                statistics.passed_filtertuple += 1
                for field in fields:
                    outdict[field].append(rh.get_root_node().get_value(
                        self.rmf_names_keys[field]))

        else:
            f = open(self.filename, "r")
            line_number = 0

            for line in f.readlines():
                statistics.total += 1
                if filterout is not None:
                    if filterout in line:
                        continue
                statistics.passed_filterout += 1
                line_number += 1

                # honor the get_every stride; the stat2 header line (line 1)
                # is not a data frame, so back its counters out
                if line_number % get_every != 0:
                    if line_number == 1 and self.isstat2:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                    continue
                statistics.passed_get_every += 1
                try:
                    d = ast.literal_eval(line)
                except:  # noqa: E722
                    print("# Warning: skipped line number " + str(line_number)
                          + " not a valid line")
                    continue

                if self.isstat1:

                    if filtertuple is not None:
                        keytobefiltered = filtertuple[0]
                        relationship = filtertuple[1]
                        value = filtertuple[2]
                        datavalue = d[keytobefiltered]
                        if self.isfiltered(datavalue, relationship, value):
                            continue

                    statistics.passed_filtertuple += 1
                    [outdict[field].append(d[field]) for field in fields]

                elif self.isstat2:
                    # line 1 is the header: undo its counter contributions
                    if line_number == 1:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                        statistics.passed_get_every -= 1
                        continue

                    if filtertuple is not None:
                        keytobefiltered = filtertuple[0]
                        relationship = filtertuple[1]
                        value = filtertuple[2]
                        # stat2 data lines are keyed by column index
                        datavalue = d[self.invstat2_dict[keytobefiltered]]
                        if self.isfiltered(datavalue, relationship, value):
                            continue

                    statistics.passed_filtertuple += 1
                    [outdict[field].append(d[self.invstat2_dict[field]])
                     for field in fields]

            f.close()

        return outdict
1108 
1109  def isfiltered(self, datavalue, relationship, refvalue):
1110  dofilter = False
1111  try:
1112  _ = float(datavalue)
1113  except ValueError:
1114  raise ValueError("ProcessOutput.filter: datavalue cannot be "
1115  "converted into a float")
1116 
1117  if relationship == "<":
1118  if float(datavalue) >= refvalue:
1119  dofilter = True
1120  if relationship == ">":
1121  if float(datavalue) <= refvalue:
1122  dofilter = True
1123  if relationship == "==":
1124  if float(datavalue) != refvalue:
1125  dofilter = True
1126  return dofilter
1127 
1128 
1130  """ class to allow more advanced handling of RMF files.
1131  It is both a container and a IMP.atom.Hierarchy.
1132  - it is iterable (while loading the corresponding frame)
1133  - Item brackets [] load the corresponding frame
1134  - slice create an iterator
1135  - can relink to another RMF file
1136  """
    def __init__(self, model, rmf_file_name):
        """
        @param model: the IMP.Model()
        @param rmf_file_name: str, path of the rmf file
        """
        self.model = model
        try:
            self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
        except TypeError:
            raise TypeError("Wrong rmf file name or type: %s"
                            % str(rmf_file_name))
        # build the hierarchies before calling the Hierarchy constructor:
        # the first hierarchy in the file becomes this object's root
        hs = IMP.rmf.create_hierarchies(self.rh_ref, self.model)
        IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(0))
        self.root_hier_ref = hs[0]
        super().__init__(self.root_hier_ref)
        self.model.update()
        # optional recoloring hook; when set, its .method() is invoked
        # after each relink (see link_to_rmf)
        self.ColorHierarchy = None
1154 
    def link_to_rmf(self, rmf_file_name):
        """
        Link to another RMF file
        """
        self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
        IMP.rmf.link_hierarchies(self.rh_ref, [self])
        # re-apply the coloring scheme, if one was attached
        if self.ColorHierarchy:
            self.ColorHierarchy.method()
        # load the first frame of the newly linked file
        RMFHierarchyHandler.set_frame(self, 0)
1164 
1165  def set_frame(self, index):
1166  try:
1167  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(index))
1168  except: # noqa: E722
1169  print("skipping frame %s:%d\n" % (self.current_rmf, index))
1170  self.model.update()
1171 
    def get_number_of_frames(self):
        # number of frames in the currently linked RMF file
        return self.rh_ref.get_number_of_frames()
1174 
1175  def __getitem__(self, int_slice_adaptor):
1176  if isinstance(int_slice_adaptor, int):
1177  self.set_frame(int_slice_adaptor)
1178  return int_slice_adaptor
1179  elif isinstance(int_slice_adaptor, slice):
1180  return self.__iter__(int_slice_adaptor)
1181  else:
1182  raise TypeError("Unknown Type")
1183 
    def __len__(self):
        # len() == number of frames in the linked RMF file
        return self.get_number_of_frames()
1186 
1187  def __iter__(self, slice_key=None):
1188  if slice_key is None:
1189  for nframe in range(len(self)):
1190  yield self[nframe]
1191  else:
1192  for nframe in list(range(len(self)))[slice_key]:
1193  yield self[nframe]
1194 
1195 
class CacheHierarchyCoordinates:
    """Per-frame cache of rigid body reference frames and bead coordinates,
    so StatHierarchyHandler can revisit stored frames without re-reading
    the RMF file."""
    def __init__(self, StatHierarchyHandler):
        self.xyzs = []          # flexible beads (IMP.core.XYZ)
        self.nrms = []          # non-rigid members
        self.rbs = []           # rigid bodies
        self.nrm_coors = {}     # frame index -> {nrm: internal coordinates}
        self.xyz_coors = {}     # frame index -> {xyz: coordinates}
        self.rb_trans = {}      # frame index -> {rb: reference frame}
        self.current_index = None
        self.rmfh = StatHierarchyHandler
        rbs, xyzs = IMP.pmi.tools.get_rbs_and_beads([self.rmfh])
        self.model = self.rmfh.get_model()
        self.rbs = rbs
        for xyz in xyzs:
            # NOTE(review): an "if IMP.core.NonRigidMember.get_is_setup(...)"
            # line appears to be missing here in this view (dangling "else"
            # below) -- verify against the upstream source.
                nrm = IMP.core.NonRigidMember(xyz)
                self.nrms.append(nrm)
            else:
                fb = IMP.core.XYZ(xyz)
                self.xyzs.append(fb)

    def do_store(self, index):
        # snapshot the current transforms/coordinates under frame `index`
        self.rb_trans[index] = {}
        self.nrm_coors[index] = {}
        self.xyz_coors[index] = {}
        for rb in self.rbs:
            self.rb_trans[index][rb] = rb.get_reference_frame()
        for nrm in self.nrms:
            self.nrm_coors[index][nrm] = nrm.get_internal_coordinates()
        for xyz in self.xyzs:
            self.xyz_coors[index][xyz] = xyz.get_coordinates()
        self.current_index = index

    def do_update(self, index):
        # restore a previously stored frame (no-op if already current)
        if self.current_index != index:
            for rb in self.rbs:
                rb.set_reference_frame(self.rb_trans[index][rb])
            for nrm in self.nrms:
                nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
            for xyz in self.xyzs:
                xyz.set_coordinates(self.xyz_coors[index][xyz])
            self.current_index = index
            self.model.update()

    def get_number_of_frames(self):
        # number of frames stored so far
        return len(self.rb_trans.keys())

    def __getitem__(self, index):
        # membership test: True if frame `index` has been stored
        if isinstance(index, int):
            return index in self.rb_trans.keys()
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return self.get_number_of_frames()
1251 
1252 
1254  """ class to link stat files to several rmf files """
    def __init__(self, model=None, stat_file=None,
                 number_best_scoring_models=None, score_key=None,
                 StatHierarchyHandler=None, cache=None):
        """
        @param model: IMP.Model()
        @param stat_file: either 1) a list or 2) a single stat file names
            (either rmfs or ascii, or pickled data or pickled cluster),
            3) a dictionary containing an rmf/ascii
            stat file name as key and a list of frames as values
        @param number_best_scoring_models: keep only the N best-scoring
            models (None keeps everything)
        @param score_key: stat-file key holding the score
            (default "Total_Score")
        @param StatHierarchyHandler: copy constructor input object
        @param cache: cache coordinates and rigid body transformations.
        """

        if StatHierarchyHandler is not None:
            # overrides all other arguments
            # copy constructor: create a copy with
            # different RMFHierarchyHandler
            self.model = StatHierarchyHandler.model
            self.data = StatHierarchyHandler.data
            self.number_best_scoring_models = \
                StatHierarchyHandler.number_best_scoring_models
            self.is_setup = True
            self.current_rmf = StatHierarchyHandler.current_rmf
            self.current_frame = None
            self.current_index = None
            self.score_threshold = StatHierarchyHandler.score_threshold
            self.score_key = StatHierarchyHandler.score_key
            self.cache = StatHierarchyHandler.cache
            super().__init__(self.model, self.current_rmf)
            # the cache holds per-hierarchy particle handles, so a copy
            # needs its own cache object
            if self.cache:
                self.cache = CacheHierarchyCoordinates(self)
            else:
                self.cache = None
            self.set_frame(0)

        else:
            # standard constructor
            self.model = model
            self.data = []
            self.number_best_scoring_models = number_best_scoring_models
            self.cache = cache

            if score_key is None:
                self.score_key = "Total_Score"
            else:
                self.score_key = score_key
            self.is_setup = None
            self.current_rmf = None
            self.current_frame = None
            self.current_index = None
            self.score_threshold = None

            # accept a single file name or a list of them
            if isinstance(stat_file, str):
                self.add_stat_file(stat_file)
            elif isinstance(stat_file, list):
                for f in stat_file:
                    self.add_stat_file(f)
1314 
1315  def add_stat_file(self, stat_file):
1316  try:
1317  '''check that it is not a pickle file with saved data
1318  from a previous calculation'''
1319  self.load_data(stat_file)
1320 
1321  if self.number_best_scoring_models:
1322  scores = self.get_scores()
1323  max_score = sorted(scores)[
1324  0:min(len(self), self.number_best_scoring_models)][-1]
1325  self.do_filter_by_score(max_score)
1326 
1327  except pickle.UnpicklingError:
1328  '''alternatively read the ascii stat files'''
1329  try:
1330  scores, rmf_files, rmf_frame_indexes, features = \
1331  self.get_info_from_stat_file(stat_file,
1332  self.score_threshold)
1333  except (KeyError, SyntaxError):
1334  # in this case check that is it an rmf file, probably
1335  # without stat stored in
1336  try:
1337  # let's see if that is an rmf file
1338  rh = RMF.open_rmf_file_read_only(stat_file)
1339  nframes = rh.get_number_of_frames()
1340  scores = [0.0]*nframes
1341  rmf_files = [stat_file]*nframes
1342  rmf_frame_indexes = range(nframes)
1343  features = {}
1344  except: # noqa: E722
1345  return
1346 
1347  if len(set(rmf_files)) > 1:
1348  raise ("Multiple RMF files found")
1349 
1350  if not rmf_files:
1351  print("StatHierarchyHandler: Error: Trying to set none as "
1352  "rmf_file (probably empty stat file), aborting")
1353  return
1354 
1355  for n, index in enumerate(rmf_frame_indexes):
1356  featn_dict = dict([(k, features[k][n]) for k in features])
1357  self.data.append(IMP.pmi.output.DataEntry(
1358  stat_file, rmf_files[n], index, scores[n], featn_dict))
1359 
1360  if self.number_best_scoring_models:
1361  scores = self.get_scores()
1362  max_score = sorted(scores)[
1363  0:min(len(self), self.number_best_scoring_models)][-1]
1364  self.do_filter_by_score(max_score)
1365 
1366  if not self.is_setup:
1367  RMFHierarchyHandler.__init__(
1368  self, self.model, self.get_rmf_names()[0])
1369  if self.cache:
1370  self.cache = CacheHierarchyCoordinates(self)
1371  else:
1372  self.cache = None
1373  self.is_setup = True
1374  self.current_rmf = self.get_rmf_names()[0]
1375 
1376  self.set_frame(0)
1377 
1378  def save_data(self, filename='data.pkl'):
1379  with open(filename, 'wb') as fl:
1380  pickle.dump(self.data, fl)
1381 
1382  def load_data(self, filename='data.pkl'):
1383  with open(filename, 'rb') as fl:
1384  data_structure = pickle.load(fl)
1385  # first check that it is a list
1386  if not isinstance(data_structure, list):
1387  raise TypeError(
1388  "%filename should contain a list of IMP.pmi.output.DataEntry "
1389  "or IMP.pmi.output.Cluster" % filename)
1390  # second check the types
1391  if all(isinstance(item, IMP.pmi.output.DataEntry)
1392  for item in data_structure):
1393  self.data = data_structure
1394  elif all(isinstance(item, IMP.pmi.output.Cluster)
1395  for item in data_structure):
1396  nmodels = 0
1397  for cluster in data_structure:
1398  nmodels += len(cluster)
1399  self.data = [None]*nmodels
1400  for cluster in data_structure:
1401  for n, data in enumerate(cluster):
1402  index = cluster.members[n]
1403  self.data[index] = data
1404  else:
1405  raise TypeError(
1406  "%filename should contain a list of IMP.pmi.output.DataEntry "
1407  "or IMP.pmi.output.Cluster" % filename)
1408 
    def set_frame(self, index):
        # Load model `index`: from the coordinate cache when available,
        # otherwise from the backing RMF file (relinking if the entry
        # lives in a different file than the current one).
        if self.cache is not None and self.cache[index]:
            self.cache.do_update(index)
        else:
            nm = self.data[index].rmf_name
            fidx = self.data[index].rmf_index
            if nm != self.current_rmf:
                self.link_to_rmf(nm)
                self.current_rmf = nm
                # force the frame reload below after a relink
                self.current_frame = -1
            if fidx != self.current_frame:
                RMFHierarchyHandler.set_frame(self, fidx)
                self.current_frame = fidx
            if self.cache is not None:
                # remember these coordinates for the next visit
                self.cache.do_store(index)

        self.current_index = index
1426 
1427  def __getitem__(self, int_slice_adaptor):
1428  if isinstance(int_slice_adaptor, int):
1429  self.set_frame(int_slice_adaptor)
1430  return self.data[int_slice_adaptor]
1431  elif isinstance(int_slice_adaptor, slice):
1432  return self.__iter__(int_slice_adaptor)
1433  else:
1434  raise TypeError("Unknown Type")
1435 
    def __len__(self):
        # one entry per model listed in the stat file(s)
        return len(self.data)
1438 
1439  def __iter__(self, slice_key=None):
1440  if slice_key is None:
1441  for i in range(len(self)):
1442  yield self[i]
1443  else:
1444  for i in range(len(self))[slice_key]:
1445  yield self[i]
1446 
1447  def do_filter_by_score(self, maximum_score):
1448  self.data = [d for d in self.data if d.score <= maximum_score]
1449 
1450  def get_scores(self):
1451  return [d.score for d in self.data]
1452 
1453  def get_feature_series(self, feature_name):
1454  return [d.features[feature_name] for d in self.data]
1455 
1456  def get_feature_names(self):
1457  return self.data[0].features.keys()
1458 
1459  def get_rmf_names(self):
1460  return [d.rmf_name for d in self.data]
1461 
1462  def get_stat_files_names(self):
1463  return [d.stat_file for d in self.data]
1464 
1465  def get_rmf_indexes(self):
1466  return [d.rmf_index for d in self.data]
1467 
1468  def get_info_from_stat_file(self, stat_file, score_threshold=None):
1469  po = ProcessOutput(stat_file)
1470  fs = po.get_keys()
1471  models = IMP.pmi.io.get_best_models(
1472  [stat_file], score_key=self.score_key, feature_keys=fs,
1473  rmf_file_key="rmf_file", rmf_file_frame_key="rmf_frame_index",
1474  prefiltervalue=score_threshold, get_every=1)
1475 
1476  scores = [float(y) for y in models[2]]
1477  rmf_files = models[0]
1478  rmf_frame_indexes = models[1]
1479  features = models[3]
1480  return scores, rmf_files, rmf_frame_indexes, features
1481 
1482 
1484  '''
1485  A class to store data associated to a model
1486  '''
1487  def __init__(self, stat_file=None, rmf_name=None, rmf_index=None,
1488  score=None, features=None):
1489  self.rmf_name = rmf_name
1490  self.rmf_index = rmf_index
1491  self.score = score
1492  self.features = features
1493  self.stat_file = stat_file
1494 
1495  def __repr__(self):
1496  s = "IMP.pmi.output.DataEntry\n"
1497  s += "---- stat file %s \n" % (self.stat_file)
1498  s += "---- rmf file %s \n" % (self.rmf_name)
1499  s += "---- rmf index %s \n" % (str(self.rmf_index))
1500  s += "---- score %s \n" % (str(self.score))
1501  s += "---- number of features %s \n" % (str(len(self.features.keys())))
1502  return s
1503 
1504 
class Cluster:
    '''
    A container for models organized into clusters
    '''
    def __init__(self, cid=None):
        """@param cid arbitrary cluster identifier"""
        self.cluster_id = cid
        self.members = []          # model indices belonging to this cluster
        self.precision = None
        self.center_index = None
        self.members_data = {}     # model index -> DataEntry (or None)
        # Fix: average_score used to be created only by add_member(), so
        # repr() (or any read) on an empty cluster raised AttributeError.
        self.average_score = None

    def add_member(self, index, data=None):
        """Add model `index` (with optional data object carrying a .score)
        and refresh the average score."""
        self.members.append(index)
        self.members_data[index] = data
        self.average_score = self.compute_score()

    def compute_score(self):
        """Return the mean member score, or None if the member data do not
        carry scores (or the cluster is empty)."""
        try:
            score = sum([d.score for d in self])/len(self)
        except (AttributeError, ZeroDivisionError):
            # AttributeError: member data is None / lacks .score;
            # ZeroDivisionError: no members yet
            score = None
        return score

    def __repr__(self):
        s = "IMP.pmi.output.Cluster\n"
        s += "---- cluster_id %s \n" % str(self.cluster_id)
        s += "---- precision %s \n" % str(self.precision)
        s += "---- average score %s \n" % str(self.average_score)
        s += "---- number of members %s \n" % str(len(self.members))
        s += "---- center index %s \n" % str(self.center_index)
        return s

    def __getitem__(self, int_slice_adaptor):
        """Integer index: return the data of the i-th member.
        Slice: return an iterator over the sliced members."""
        if isinstance(int_slice_adaptor, int):
            index = self.members[int_slice_adaptor]
            return self.members_data[index]
        elif isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.members)

    def __iter__(self, slice_key=None):
        if slice_key is None:
            for i in range(len(self)):
                yield self[i]
        else:
            for i in range(len(self))[slice_key]:
                yield self[i]

    def __add__(self, other):
        """Merge `other` into this cluster IN PLACE and return self;
        precision and center become undefined and are reset."""
        self.members += other.members
        self.members_data.update(other.members_data)
        self.average_score = self.compute_score()
        self.precision = None
        self.center_index = None
        return self
1564 
1565 
def plot_clusters_populations(clusters):
    """Show a bar chart of the number of members in each cluster."""
    cluster_ids = [cl.cluster_id for cl in clusters]
    sizes = [len(cl) for cl in clusters]

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(cluster_ids, sizes, 0.5, color='r')
    ax.set_ylabel('Population')
    ax.set_xlabel(('Cluster index'))
    plt.show()
1579 
1580 
def plot_clusters_precisions(clusters):
    """Show a bar chart of each cluster's precision (None plotted as 0)."""
    cluster_ids = []
    precision_values = []
    for cl in clusters:
        cluster_ids.append(cl.cluster_id)
        prec = cl.precision
        print(cl.cluster_id, prec)
        precision_values.append(0.0 if prec is None else prec)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(cluster_ids, precision_values, 0.5, color='r')
    ax.set_ylabel('Precision [A]')
    ax.set_xlabel(('Cluster index'))
    plt.show()
1599 
1600 
def plot_clusters_scores(clusters):
    """Box-plot the member scores of each cluster (written to scores.pdf)."""
    cluster_ids = [cl.cluster_id for cl in clusters]
    score_lists = [[entry.score for entry in cl] for cl in clusters]

    plot_fields_box_plots("scores.pdf", score_lists, cluster_ids,
                          frequencies=None, valuename="Scores",
                          positionname="Cluster index",
                          xlabels=None, scale_plot_length=1.0)
1613 
1614 
class CrossLinkIdentifierDatabase:
    """Per-cross-link attribute store, keyed by an arbitrary identifier.

    Each record is a plain dict of typed fields (protein names, residue
    numbers, scores, ...). The whole database can be pickled to disk with
    write() and restored with load().
    """
    def __init__(self):
        self.clidb = dict()

    def check_key(self, key):
        """Create an empty record for `key` if it does not exist yet."""
        if key not in self.clidb:
            self.clidb[key] = {}

    def _store(self, key, field, value, convert):
        # Shared body of all typed setters: ensure the record exists,
        # then store convert(value) under the given field name.
        self.check_key(key)
        self.clidb[key][field] = convert(value)

    def _fetch(self, key, field):
        # Shared body of all getters.
        return self.clidb[key][field]

    def set_unique_id(self, key, value):
        self._store(key, "XLUniqueID", value, str)

    def set_protein1(self, key, value):
        self._store(key, "Protein1", value, str)

    def set_protein2(self, key, value):
        self._store(key, "Protein2", value, str)

    def set_residue1(self, key, value):
        self._store(key, "Residue1", value, int)

    def set_residue2(self, key, value):
        self._store(key, "Residue2", value, int)

    def set_idscore(self, key, value):
        self._store(key, "IDScore", value, float)

    def set_state(self, key, value):
        self._store(key, "State", value, int)

    def set_sigma1(self, key, value):
        self._store(key, "Sigma1", value, str)

    def set_sigma2(self, key, value):
        self._store(key, "Sigma2", value, str)

    def set_psi(self, key, value):
        self._store(key, "Psi", value, str)

    def get_unique_id(self, key):
        return self._fetch(key, "XLUniqueID")

    def get_protein1(self, key):
        return self._fetch(key, "Protein1")

    def get_protein2(self, key):
        return self._fetch(key, "Protein2")

    def get_residue1(self, key):
        return self._fetch(key, "Residue1")

    def get_residue2(self, key):
        return self._fetch(key, "Residue2")

    def get_idscore(self, key):
        return self._fetch(key, "IDScore")

    def get_state(self, key):
        return self._fetch(key, "State")

    def get_sigma1(self, key):
        return self._fetch(key, "Sigma1")

    def get_sigma2(self, key):
        return self._fetch(key, "Sigma2")

    def get_psi(self, key):
        return self._fetch(key, "Psi")

    def set_float_feature(self, key, value, feature_name):
        """Store an arbitrary named feature, coerced to float."""
        self._store(key, feature_name, value, float)

    def set_int_feature(self, key, value, feature_name):
        """Store an arbitrary named feature, coerced to int."""
        self._store(key, feature_name, value, int)

    def set_string_feature(self, key, value, feature_name):
        """Store an arbitrary named feature, coerced to str."""
        self._store(key, feature_name, value, str)

    def get_feature(self, key, feature_name):
        return self._fetch(key, feature_name)

    def write(self, filename):
        """Pickle the whole database to `filename`."""
        with open(filename, 'wb') as handle:
            pickle.dump(self.clidb, handle)

    def load(self, filename):
        """Restore a database pickled by write().
        NOTE: pickle.load can execute arbitrary code; only load trusted
        files."""
        with open(filename, 'rb') as handle:
            self.clidb = pickle.load(handle)
1715 
1716 
def plot_fields(fields, output, framemin=None, framemax=None):
    """Plot the given fields and save a figure as `output`.
    The fields generally are extracted from a stat file
    using ProcessOutput.get_fields()."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    fig, axs = plt.subplots(nrows=len(fields))
    fig.set_size_inches(10.5, 5.5 * len(fields))
    plt.rc('axes')

    for n, key in enumerate(fields):
        # default to the full series; note these defaults stick for the
        # remaining keys once set (matches the original behavior)
        if framemin is None:
            framemin = 0
        if framemax is None:
            framemax = len(fields[key])
        xs = list(range(framemin, framemax))
        ys = [float(v) for v in fields[key][framemin:framemax]]
        # plt.subplots returns a bare Axes (not an array) when nrows == 1
        ax = axs[n] if len(fields) > 1 else axs
        ax.plot(xs, ys)
        ax.set_title(key, size="xx-large")
        ax.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(output)
1751 
1752 
def plot_field_histogram(name, values_lists, valuename=None, bins=40,
                         colors=None, format="png", reference_xline=None,
                         yplotrange=None, xplotrange=None, normalized=True,
                         leg_names=None):
    '''Plot a list of histograms from a value list.
    @param name the name of the plot
    @param values_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename the y-label
    @param bins the number of bins
    @param colors If None, will use rainbow. Else will use specific list
    @param format output format
    @param reference_xline plot a reference line parallel to the y-axis
    @param yplotrange the range for the y-axis
    @param xplotrange the range for the x-axis
    @param normalized whether the histogram is normalized or not
    @param leg_names names for the legend
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        if leg_names is not None:
            label = leg_names[nv]
        else:
            label = str(nv)
        try:
            # modern matplotlib spells the normalization flag "density"
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                density=normalized, histtype='step', lw=4, label=label)
        except AttributeError:
            # fall back to the older "normed" keyword
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                normed=normalized, histtype='step', lw=4, label=label)

    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        # Fix: plt.ylim() was called without arguments, so the requested
        # y range was silently ignored
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
1816 
1817 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None",
                          xlabels=None, scale_plot_length=1.0):
    '''
    Plot time series as boxplots.
    fields is a list of time series, positions are the x-values
    valuename is the y-label, positionname is the x-label
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(float(len(positions))*scale_plot_length, 5.0))
    fig.canvas.manager.set_window_title(name)

    ax1 = fig.add_subplot(111)

    plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)

    boxplot = plt.boxplot(values, notch=0, sym='', vert=1,
                          whis=1.5, positions=positions)
    plt.setp(boxplot['boxes'], color='black', lw=1.5)
    plt.setp(boxplot['whiskers'], color='black', ls=":", lw=1.5)

    # overlay the raw values as green crosses when frequencies are given
    if frequencies is not None:
        for idx, series in enumerate(values):
            xcoords = [positions[idx]] * len(series)
            ax1.plot(xcoords, series, 'gx', alpha=0.7, markersize=7)

    if xlabels is not None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name + ".pdf", dpi=150)
    plt.show()
1859 
1860 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Line-plot `y` against `x`; optionally fix the y range, save the
    figure to `out_fn`.pdf, and display it."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    figure, axis = plt.subplots(nrows=1)
    figure.set_size_inches(8, 4.5)
    if title is not None:
        figure.canvas.manager.set_window_title(title)

    axis.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        # keep the autoscaled x limits, override only the y limits
        xlo, xhi, _, _ = plt.axis()
        plt.axis((xlo, xhi,
                  set_plot_yaxis_range[0], set_plot_yaxis_range[1]))
    if title is not None:
        axis.set_title(title)
    if xlabel is not None:
        axis.set_xlabel(xlabel)
    if ylabel is not None:
        axis.set_ylabel(ylabel)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(figure)
1890 
1891 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot `y` against `x`; optionally clamp the axes and save
    the figure to `filename`."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, ax = plt.subplots(1)

    ax.set_xlabel(labelx, size="xx-large")
    ax.set_ylabel(labely, size="xx-large")
    ax.tick_params(labelsize=18, pad=10)

    ax.plot(x, y, 'o', color='k', lw=2, ms=0.1, alpha=alpha, c="w")

    ax.legend(
        loc=0,
        frameon=False,
        scatterpoints=1,
        numpoints=1,
        columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if (ymin is not None) and (ymax is not None):
        ax.set_ylim(ymin, ymax)
    if (xmin is not None) and (xmax is not None):
        ax.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
1931 
1932 
def get_graph_from_hierarchy(hier):
    """Build the parent/child edge graph of `hier` and draw it, labelling
    only nodes at depth < 3 (deeper nodes get empty labels)."""
    (graph, depth, depth_dict) = recursive_graph(hier, [], 0, {})

    # filters node labels according to depth_dict
    node_labels_dict = {node: (node if node_depth < 3 else "")
                        for node, node_depth in depth_dict.items()}
    draw_graph(graph, labels_dict=node_labels_dict)
1948 
1949 
def recursive_graph(hier, graph, depth, depth_dict):
    """Depth-first walk of an IMP hierarchy, appending (parent, child)
    name edges to `graph` and recording each node's depth in `depth_dict`.

    Node names have the form "<hierarchy name>|#<particle index>".
    A node with exactly one child is treated as a leaf (not descended).
    Returns the updated (graph, depth, depth_dict).
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # Fix: test for None *before* calling len(); the original order would
    # raise TypeError if get_children() ever returned None, making the
    # None check unreachable.
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)

    else:
        for c in children:
            (graph, depth, depth_dict) = recursive_graph(
                c, graph, depth, depth_dict)
            nameh = IMP.atom.Hierarchy(c).get_name()
            index = str(c.get_particle().get_index())
            namec = nameh + "|#" + index
            graph.append((name1, namec))

        depth = depth - 1
        return (graph, depth, depth_dict)
1974 
1975 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw a graph given as a list of (node1, node2) edge tuples.

    @param graph            list of (node1, node2) edges
    @param labels_dict      map node -> display label
    @param graph_layout     'spring', 'spectral', 'random'; anything else
                            uses the shell layout
    @param node_size        scalar size, or dict node -> area (converted
                            to an approximate display radius)
    @param node_color       dict node -> hex color code, or None for black
    @param fixed, pos       passed through to the spring layout
    @param edge_thickness   scalar width, or list of per-edge weights
    @param validation_edges edges to highlight: green if present in the
                            graph (either direction), red (and added) if
                            absent
    @param out_filename     if given, save the figure to this file

    Side effects: always writes the graph as 'out.gml' in the current
    directory, and calls plt.show().
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges, carrying per-edge weights if a list was supplied
    if isinstance(edge_thickness, list):
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        # One hex string per node: the attribute loop below indexes this
        # by node position.  The previous revision stored the bare string
        # "000000", so indexing yielded a single *character* per node.
        node_color_hex = ["000000"] * G.number_of_nodes()
    else:
        # Convert per-node hex color codes to matplotlib RGB tuples.
        # NOTE(review): the hex<->rgb helper lives in IMP.pmi.tools;
        # confirm the class name against the installed version.
        cc = IMP.pmi.tools.ColorChange()
        tmpcolor_rgb = []
        tmpcolor_hex = []
        for node in G.nodes():
            cctuple = cc.rgb(node_color[node])
            tmpcolor_rgb.append((cctuple[0]/255,
                                 cctuple[1]/255,
                                 cctuple[2]/255))
            tmpcolor_hex.append(node_color[node])
        node_color_rgb = tmpcolor_rgb
        node_color_hex = tmpcolor_hex

    # get node sizes if dictionary; treat each value as an area and
    # derive a display radius from it
    if isinstance(node_size, dict):
        node_size = [sqrt(node_size[node])/pi*10.0 for node in G.nodes()]

    # NOTE(review): the (G, name, values) argument order below is the
    # networkx 1.x API; networkx >= 2 expects (G, values, name) — confirm
    # which version this file targets.
    for n, node in enumerate(G.nodes()):
        color = node_color_hex[n]
        # A scalar node_size applies to every node; indexing an int
        # raised TypeError in the previous revision.
        size = node_size[n] if isinstance(node_size, list) else node_size
        nx.set_node_attributes(
            G, "graphics",
            {node: {'type': 'ellipse', 'w': size, 'h': size,
                    'fill': '#' + color, 'label': node}})
        nx.set_node_attributes(
            G, "LabelGraphics",
            {node: {'type': 'text', 'text': node, 'color': '#000000',
                    'visible': 'true'}})

    for edge in G.edges():
        nx.set_edge_attributes(
            G, "graphics",
            {edge: {'width': 1, 'fill': '#000000'}})

    # validation_edges defaults to None; guard so the default does not
    # raise TypeError when iterated
    for ve in (validation_edges or []):
        print(ve)
        if (ve[0], ve[1]) in G.edges():
            print("found forward")
            nx.set_edge_attributes(
                G, "graphics",
                {ve: {'width': 1, 'fill': '#00FF00'}})
        elif (ve[1], ve[0]) in G.edges():
            print("found backward")
            nx.set_edge_attributes(
                G, "graphics",
                {(ve[1], ve[0]): {'width': 1, 'fill': '#00FF00'}})
        else:
            G.add_edge(ve[0], ve[1])
            print("not found")
            nx.set_edge_attributes(
                G, "graphics",
                {ve: {'width': 1, 'fill': '#FF0000'}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0/8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
    nx.write_gml(G, 'out.gml')
    plt.show()
2087 
2088 
def draw_table():
    """Render an example data table with d3js inside IPython.

    Still example code: the data is hard-coded, and the third-party
    ipyD3 package must be installed for this to run.
    """

    # still an example!

    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  # fixed duplicated word ("created created")
                  desc='An example table created with d3js with '
                       'data generated with Python.')
    # one inner list per product; columns are months (transposed below)
    data = [[1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
             4441.0, 4670.0, 944.0, 110.0],
            [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
             7078.0, 6561.0, 2354.0, 710.0],
            [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
             5661.0, 4991.0, 2032.0, 680.0],
            [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
             4727.0, 3428.0, 1559.0, 620.0],
            [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
             9552.0, 13709.0, 2460.0, 720.0],
            [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
             32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
            [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
             5115.0, 3510.0, 1390.0, 170.0],
            [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
             360.0, 156.0, 100.0]]
    # transpose so rows are months, columns are products
    data = [list(i) for i in zip(*data)]
    # row headers: month names (fixed typo "Deecember")
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              'December']]
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: Residue.h:158
A container for models organized into clusters.
Definition: output.py:1505
A class for reading stat files (either rmf or ascii v1 and v2)
Definition: output.py:931
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: atom/Atom.h:245
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:1753
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:1818
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
Miscellaneous utilities.
Definition: pmi/tools.py:1
A class to store data associated to a model.
Definition: output.py:1483
void handle_use_deprecated(std::string message)
Break in this method in gdb to find deprecated uses at runtime.
std::string get_module_version()
Return the version of this module, as a string.
Change color code to hexadecimal to rgb.
Definition: pmi/tools.py:736
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
Definition: output.py:920
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: XYZR.h:47
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:996
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
Definition: Fragment.h:46
def link_to_rmf
Link to another RMF file.
Definition: output.py:1155
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
Definition: pmi/utilities.h:85
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
Definition: output.py:257
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Definition: atom/Atom.h:238
Base class for capturing a modeling protocol.
Definition: output.py:41
The type for a residue.
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
A base class for Keys.
Definition: Key.h:45
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:199
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Definition: rigid_bodies.h:768
Display a segment connecting a pair of particles.
Definition: XYZR.h:170
A decorator for a residue.
Definition: Residue.h:137
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
Definition: output.py:226
def get_prot_name_from_particle
Get the protein name from the particle.
Definition: output.py:343
class to link stat files to several rmf files
Definition: output.py:1253
class to allow more advanced handling of RMF files.
Definition: output.py:1129
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
Definition: output.py:1717
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Definition: Chain.h:61
Python classes to represent, score, sample and analyze models.
def init_pdb_best_scoring
Prepare for writing best-scoring PDBs (or mmCIFs) for a sampling run.
Definition: output.py:464
Functionality for loading, creating, manipulating and scoring atomic structures.
def get_rbs_and_beads
Returns unique objects in original order.
Definition: pmi/tools.py:1135
Select hierarchy particles identified by the biological name.
Definition: Selection.h:70
def init_rmf
Initialize an RMF file.
Definition: output.py:561
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: pmi/tools.py:499
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: rigid_bodies.h:770
std::string get_module_version()
Return the version of this module, as a string.
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: pmi/tools.py:574
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27