IMP logo
IMP Reference Guide  develop.d97d4ead1f,2024/11/21
The Integrative Modeling Platform
output.py
"""@namespace IMP.pmi.output
   Classes for writing output files and processing them.

   Output writes PDB/mmCIF coordinate files, PSF files, RMF files and
   stat files; ProcessOutput reads stat files (RMF or ASCII v1/v2).
"""
4 
5 import IMP
6 import IMP.atom
7 import IMP.core
8 import IMP.pmi
9 import IMP.pmi.tools
10 import IMP.pmi.io
11 import os
12 import sys
13 import ast
14 import RMF
15 import numpy as np
16 import operator
17 import itertools
18 import warnings
19 import string
20 import ihm.format
21 import collections
22 import pickle
23 
24 
25 class _ChainIDs:
26  """Map indices to multi-character chain IDs.
27  We label the first 26 chains A-Z, then we move to two-letter
28  chain IDs: AA through AZ, then BA through BZ, through to ZZ.
29  This continues with longer chain IDs."""
30  def __getitem__(self, ind):
31  chars = string.ascii_uppercase
32  lc = len(chars)
33  ids = []
34  while ind >= lc:
35  ids.append(chars[ind % lc])
36  ind = ind // lc - 1
37  ids.append(chars[ind])
38  return "".join(reversed(ids))
39 
40 
class ProtocolOutput:
    """Base class for capturing a modeling protocol.
    Unlike simple output of model coordinates, a complete
    protocol includes the input data used, details on the restraints,
    sampling, and clustering, as well as output models.
    Use via IMP.pmi.topology.System.add_protocol_output().

    @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
         mmCIF files.
    """
    pass
52 
53 
54 def _flatten(seq):
55  for elt in seq:
56  if isinstance(elt, (tuple, list)):
57  for elt2 in _flatten(elt):
58  yield elt2
59  else:
60  yield elt
61 
62 
63 def _disambiguate_chain(chid, seen_chains):
64  """Make sure that the chain ID is unique; warn and correct if it isn't"""
65  # Handle null chain IDs
66  if chid == '\0':
67  chid = ' '
68 
69  if chid in seen_chains:
70  warnings.warn("Duplicate chain ID '%s' encountered" % chid,
72 
73  for suffix in itertools.count(1):
74  new_chid = chid + "%d" % suffix
75  if new_chid not in seen_chains:
76  seen_chains.add(new_chid)
77  return new_chid
78  seen_chains.add(chid)
79  return chid
80 
81 
82 def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
83  write_all_residues_per_bead):
84  for n, tupl in enumerate(particle_infos_for_pdb):
85  (xyz, atom_type, residue_type,
86  chain_id, residue_index, all_indexes, radius) = tupl
87  if atom_type is None:
88  atom_type = IMP.atom.AT_CA
89  if write_all_residues_per_bead and all_indexes is not None:
90  for residue_number in all_indexes:
91  flpdb.write(
92  IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
93  xyz[1] - geometric_center[1],
94  xyz[2] - geometric_center[2]),
95  n+1, atom_type, residue_type,
96  chain_id[:1], residue_number, ' ',
97  1.00, radius))
98  else:
99  flpdb.write(
100  IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
101  xyz[1] - geometric_center[1],
102  xyz[2] - geometric_center[2]),
103  n+1, atom_type, residue_type,
104  chain_id[:1], residue_index, ' ',
105  1.00, radius))
106  flpdb.write("ENDMDL\n")
107 
108 
109 _Entity = collections.namedtuple('_Entity', ('id', 'seq'))
110 _ChainInfo = collections.namedtuple('_ChainInfo', ('entity', 'name'))
111 
112 
def _get_chain_info(chains, root_hier):
    """Group the molecules under root_hier into mmCIF entities and chains.

    @param chains Dict mapping molecule name (as returned by
           IMP.pmi.get_molecule_name_and_copy) to chain ID.
    @param root_hier Hierarchy whose MOLECULE_TYPE descendants are scanned.
    @return (chain_info, entities). chain_info maps each chain ID to a
            _ChainInfo. entities lists one _Entity per distinct sequence,
            numbered from 1 in order of first appearance.
    """
    chain_info = {}
    entities = {}
    all_entities = []
    for mol in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
        # chains is keyed by molecule name plus copy index, the same key
        # used when looking up chain IDs for individual particles
        molname = IMP.pmi.get_molecule_name_and_copy(mol)
        chain_id = chains[molname]
        seq = IMP.atom.Chain(mol).get_sequence()
        if seq not in entities:
            # Chains with identical sequences share one entity
            entities[seq] = _Entity(id=len(entities) + 1, seq=seq)
            all_entities.append(entities[seq])
        chain_info[chain_id] = _ChainInfo(entity=entities[seq], name=molname)
    return chain_info, all_entities
129 
130 
def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
                          write_all_residues_per_bead, chains, root_hier):
    """Write one model as a minimal mmCIF data block named 'model'.

    @param flpdb Open, writable text file handle.
    @param particle_infos_for_pdb Sequence of (xyz, atom_type, residue_type,
           chain_id, residue_index, all_indexes, radius) tuples.
    @param geometric_center Offset subtracted from every coordinate.
    @param write_all_residues_per_bead If True, a particle with a list of
           residue indexes gets one _atom_site row per residue, all at the
           particle's position. Otherwise each particle gets one row.
    @param chains Dict mapping molecule name to chain ID.
    @param root_hier Hierarchy from which entities/chains are derived.
    """
    # get dict with keys=chain IDs, values=chain info
    chain_info, entities = _get_chain_info(chains, root_hier)

    writer = ihm.format.CifWriter(flpdb)
    writer.start_block('model')
    with writer.category("_entry") as lp:
        lp.write(id='model')

    with writer.loop("_entity", ["id", "type"]) as lp:
        for e in entities:
            lp.write(id=e.id, type="polymer")

    with writer.loop("_entity_poly",
                     ["entity_id", "pdbx_seq_one_letter_code"]) as lp:
        for e in entities:
            lp.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)

    with writer.loop("_struct_asym", ["id", "entity_id", "details"]) as lp:
        # Longer chain IDs (e.g. AA) should always come after shorter (e.g. Z)
        for chid in sorted(chains.values(), key=lambda x: (len(x.strip()), x)):
            ci = chain_info[chid]
            lp.write(id=chid, entity_id=ci.entity.id, details=ci.name)

    with writer.loop("_atom_site",
                     ["group_PDB", "type_symbol", "label_atom_id",
                      "label_comp_id", "label_asym_id", "label_seq_id",
                      "auth_seq_id",
                      "Cartn_x", "Cartn_y", "Cartn_z", "label_entity_id",
                      "pdbx_pdb_model_num",
                      "id"]) as lp:
        ordinal = 1
        for tupl in particle_infos_for_pdb:
            (xyz, atom_type, residue_type,
             chain_id, residue_index, all_indexes, radius) = tupl
            ci = chain_info[chain_id]
            if atom_type is None:
                atom_type = IMP.atom.AT_CA
            c = (xyz[0] - geometric_center[0],
                 xyz[1] - geometric_center[1],
                 xyz[2] - geometric_center[2])
            if write_all_residues_per_bead and all_indexes is not None:
                residue_numbers = all_indexes
            else:
                residue_numbers = [residue_index]
            atom_name = atom_type.get_string()
            comp_name = residue_type.get_string()
            for residue_number in residue_numbers:
                # Each row gets its own residue number (matching the PDB
                # writer), not the bead's representative residue_index
                lp.write(group_PDB='ATOM', type_symbol='C',
                         label_atom_id=atom_name,
                         label_comp_id=comp_name,
                         label_asym_id=chain_id,
                         label_seq_id=residue_number,
                         auth_seq_id=residue_number,
                         Cartn_x=c[0], Cartn_y=c[1], Cartn_z=c[2],
                         id=ordinal, pdbx_pdb_model_num=1,
                         label_entity_id=ci.entity.id)
                ordinal += 1
197 
198 
class Output:
    """Class for easy writing of PDBs, RMFs, and stat files

    @note Model should be updated prior to writing outputs.
    """
    def __init__(self, ascii=True, atomistic=False):
        self.dictionary_pdbs = {}
        self._pdb_mmcif = {}
        self.dictionary_rmfs = {}
        self.dictionary_stats = {}
        self.dictionary_stats2 = {}
        self.best_score_list = None
        self.nbestscoring = None
        self.prefixes = []
        self.replica_exchange = False
        self.ascii = ascii
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # 1-character chain IDs, suitable for PDB output
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
            "abcdefghijklmnopqrstuvwxyz0123456789"
        # Multi-character chain IDs, suitable for mmCIF output
        self.multi_chainids = _ChainIDs()
        self.dictchain = {}  # keys are molecule names, values are chain ids
        self.particle_infos_for_pdb = {}
        self.atomistic = atomistic

    @staticmethod
    def _check_get_output(obj):
        """Raise ValueError unless obj has a get_output() method."""
        if "get_output" not in dir(obj):
            raise ValueError(
                "Output: object %s doesn't have get_output() method"
                % str(obj))

    @staticmethod
    def _check_get_test_output(obj):
        """Raise ValueError unless obj has get_test_output() or
        get_output()."""
        if "get_test_output" not in dir(obj) and "get_output" not in dir(obj):
            raise ValueError(
                "Output: object %s doesn't have get_output() or "
                "get_test_output() method" % str(obj))

    def get_pdb_names(self):
        """Get a list of all PDB files being output by this instance"""
        return list(self.dictionary_pdbs.keys())

    def get_rmf_names(self):
        """Get a list of all RMF files being output by this instance"""
        return list(self.dictionary_rmfs.keys())

    def get_stat_names(self):
        """Get a list of all stat files registered by init_stat() or
        write_test()"""
        return list(self.dictionary_stats.keys())

    def _init_dictchain(self, name, prot, multichar_chain=False, mmcif=False):
        """Assign a unique chain ID to every molecule under prot.

        The result is stored in self.dictchain[name], keyed by molecule
        name plus copy index (IMP.pmi.get_molecule_name_and_copy).
        @param multichar_chain Unused; kept for backward compatibility.
        @param mmcif If False, reject chain IDs longer than one character.
        @raises ValueError if a multi-character chain ID is found and
                mmcif is False.
        """
        self.dictchain[name] = {}
        seen_chains = set()

        # attempt to find PMI objects.
        self.atomistic = True  # detects automatically
        for mol in IMP.atom.get_by_type(prot, IMP.atom.MOLECULE_TYPE):
            chid = IMP.atom.Chain(mol).get_id()
            if not mmcif and len(chid) > 1:
                raise ValueError(
                    "The system contains at least one chain ID (%s) that "
                    "is more than 1 character long; this cannot be "
                    "represented in PDB. Either write mmCIF files "
                    "instead, or assign 1-character IDs to all chains "
                    "(this can be done with the `chain_ids` argument to "
                    "BuildSystem.add_state())." % chid)
            chid = _disambiguate_chain(chid, seen_chains)
            molname = IMP.pmi.get_molecule_name_and_copy(mol)
            self.dictchain[name][molname] = chid

    def init_pdb(self, name, prot, mmcif=False):
        """Init PDB Writing.
        @param name The PDB filename
        @param prot The hierarchy to write to this pdb file
        @param mmcif If True, write PDBs in mmCIF format
        @note if the PDB name is 'System' then will use Selection
              to get molecules
        """
        # Create (or truncate) the file; models are appended later
        with open(name, 'w'):
            pass
        self.dictionary_pdbs[name] = prot
        self._pdb_mmcif[name] = mmcif
        self._init_dictchain(name, prot, mmcif=mmcif)

    def write_psf(self, filename, name):
        """Write a PSF topology for the particles of PDB output `name`.

        Each particle is written as one atom. Within each chain, atoms are
        sorted by residue index and consecutive ones are bonded.
        """
        particle_infos_for_pdb, _ = \
            self.get_particle_infos_for_pdb_writing(name)
        index_residue_pair_list = {}
        with open(filename, 'w') as flpsf:
            flpsf.write("PSF CMAP CHEQ" + "\n")
            nparticles = len(particle_infos_for_pdb)
            flpsf.write(str(nparticles) + " !NATOM" + "\n")
            for n, p in enumerate(particle_infos_for_pdb):
                atom_index = n + 1
                residue_type = p[2]
                chain = p[3]
                resid = p[4]
                flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}'
                            '{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}'
                            '{14:14.6f}{15:14.6f}'.format(
                                atom_index, " ", chain, " ", str(resid), " ",
                                '"'+residue_type.get_string()+'"', " ", "C",
                                " ", "C", 1.0, 0.0, 0, 0.0, 0.0))
                flpsf.write('\n')
                index_residue_pair_list.setdefault(chain, []).append(
                    (atom_index, resid))

            # now write the connectivity
            indexes_pairs = []
            for chain in sorted(index_residue_pair_list.keys()):
                # sort by residue
                ls = sorted(index_residue_pair_list[chain],
                            key=lambda tup: tup[1])
                # get the index list
                indexes = [x[0] for x in ls]
                # get the contiguous pairs
                indexes_pairs.extend(IMP.pmi.tools.sublist_iterator(
                    indexes, lmin=2, lmax=2))
            nbonds = len(indexes_pairs)
            flpsf.write(str(nbonds)+" !NBOND: bonds"+"\n")

            # save bonds in fixed column format, 4 bonds per line
            for i in range(0, len(indexes_pairs), 4):
                for bond in indexes_pairs[i:i+4]:
                    flpsf.write('{0:8d}{1:8d}'.format(*bond))
                flpsf.write('\n')

    def write_pdb(self, name, appendmode=True,
                  translate_to_geometric_center=False,
                  write_all_residues_per_bead=False):
        """Write the current coordinates to a file set up by init_pdb().

        @param name The filename given to init_pdb()
        @param appendmode If True append a new model, else overwrite
        @param translate_to_geometric_center If True, subtract the mean
               particle coordinate from every position
        @param write_all_residues_per_bead If True, write one record per
               residue covered by a bead rather than one per bead
        """
        (particle_infos_for_pdb,
         geometric_center) = self.get_particle_infos_for_pdb_writing(name)

        if not translate_to_geometric_center:
            geometric_center = (0, 0, 0)

        filemode = 'a' if appendmode else 'w'
        with open(name, filemode) as flpdb:
            if self._pdb_mmcif[name]:
                _write_mmcif_internal(flpdb, particle_infos_for_pdb,
                                      geometric_center,
                                      write_all_residues_per_bead,
                                      self.dictchain[name],
                                      self.dictionary_pdbs[name])
            else:
                _write_pdb_internal(flpdb, particle_infos_for_pdb,
                                    geometric_center,
                                    write_all_residues_per_bead)

    def get_prot_name_from_particle(self, name, p):
        """Get the protein name from the particle.
        This is done by traversing the hierarchy.
        @return (molecule name with copy index, is_a_bead); is_a_bead is
                always True in this implementation."""
        return IMP.pmi.get_molecule_name_and_copy(p), True

    def get_particle_infos_for_pdb_writing(self, name):
        """Collect the per-particle data needed for PDB/mmCIF/PSF output.

        Particles come from the resolution-0 selection of the hierarchy
        registered under `name`. Residue and non-bead fragment particles
        are skipped if their (representative) residue index has already
        been recorded for the same molecule. Atoms (when self.atomistic)
        and beads are always included.

        @return (particle_infos, geometric_center). particle_infos is a list
                of (xyz, atom_type, residue_type, chain_id, residue_index,
                all_residue_indexes, radius) tuples sorted by chain ID and
                residue index. geometric_center is the mean particle
                coordinate, or [0, 0, 0] if there are no particles.
        """
        hier = self.dictionary_pdbs[name]
        chains = self.dictchain[name]

        # residue indexes already written for each molecule, so residues
        # are not duplicated; highest resolution has highest priority
        resindexes_dict = {}

        # the sequence of tuples needed to write the pdb
        particle_infos_for_pdb = []

        geometric_center = [0, 0, 0]
        atom_count = 0

        # select highest resolution, if hierarchy is non-empty
        if (not IMP.core.XYZR.get_is_setup(hier)
                and hier.get_number_of_children() == 0):
            ps = []
        else:
            sel = IMP.atom.Selection(hier, resolution=0)
            ps = sel.get_selected_particles()

        for p in ps:
            protname, is_a_bead = self.get_prot_name_from_particle(name, p)
            seen = resindexes_dict.setdefault(protname, set())

            if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
                atom = IMP.atom.Atom(p)
                residue = IMP.atom.Residue(atom.get_parent())
                rt = residue.get_residue_type()
                resind = residue.get_index()
                atomtype = atom.get_atom_type()
                xyz = list(IMP.core.XYZ(p).get_coordinates())
                all_indexes = None
                seen.add(resind)

            elif IMP.atom.Residue.get_is_setup(p):
                residue = IMP.atom.Residue(p)
                resind = residue.get_index()
                # skip if the residue was already added by atomistic
                # resolution 0
                if resind in seen:
                    continue
                seen.add(resind)
                rt = residue.get_residue_type()
                atomtype = None
                xyz = IMP.core.XYZ(p).get_coordinates()
                all_indexes = None

            elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
                all_indexes = list(IMP.pmi.tools.get_residue_indexes(p))
                # represent the fragment by its middle residue
                resind = all_indexes[len(all_indexes) // 2]
                if resind in seen:
                    continue
                seen.add(resind)
                rt = IMP.atom.ResidueType('BEA')
                atomtype = None
                xyz = IMP.core.XYZ(p).get_coordinates()

            else:
                # Written as a bead. rt is set unconditionally so that a
                # residue type left over from an earlier particle is never
                # reused.
                rt = IMP.atom.ResidueType('BEA')
                all_indexes = list(IMP.pmi.tools.get_residue_indexes(p))
                if not all_indexes:
                    continue
                resind = all_indexes[len(all_indexes) // 2]
                atomtype = None
                xyz = IMP.core.XYZ(p).get_coordinates()

            radius = IMP.core.XYZR(p).get_radius()
            for i in range(3):
                geometric_center[i] += xyz[i]
            atom_count += 1
            particle_infos_for_pdb.append(
                (xyz, atomtype, rt, chains[protname], resind, all_indexes,
                 radius))

        if atom_count > 0:
            geometric_center = (geometric_center[0] / atom_count,
                                geometric_center[1] / atom_count,
                                geometric_center[2] / atom_count)

        # sort by chain ID, then residue index. Longer chain IDs (e.g. AA)
        # should always come after shorter (e.g. Z)
        particle_infos_for_pdb = sorted(particle_infos_for_pdb,
                                        key=lambda x: (len(x[3]), x[3], x[4]))

        return (particle_infos_for_pdb, geometric_center)

    def write_pdbs(self, appendmode=True, mmcif=False):
        """Write every file registered with init_pdb().

        @param mmcif Unused; each file's format was fixed by init_pdb().
        """
        for pdb in self.dictionary_pdbs.keys():
            self.write_pdb(pdb, appendmode)

    def init_pdb_best_scoring(self, prefix, prot, nbestscoring,
                              replica_exchange=False, mmcif=False,
                              best_score_file='best.scores.rex.py'):
        """Prepare for writing best-scoring PDBs (or mmCIFs) for a
        sampling run.

        @param prefix Initial part of each PDB filename (e.g. 'model').
        @param prot The top-level Hierarchy to output.
        @param nbestscoring The number of best-scoring files to output.
        @param replica_exchange Whether to combine best scores from a
               replica exchange run.
        @param mmcif If True, output models in mmCIF format. If False
               (the default) output in legacy PDB format.
        @param best_score_file The filename to use for replica
               exchange scores.
        """

        self._pdb_best_scoring_mmcif = mmcif
        fileext = '.cif' if mmcif else '.pdb'
        self.prefixes.append(prefix)
        self.replica_exchange = replica_exchange
        self.best_score_list = []
        if self.replica_exchange:
            # the replicas communicate through a common file to know
            # what the best scores are
            self.best_score_file_name = best_score_file
            with open(self.best_score_file_name, "w") as fh:
                fh.write(
                    "self.best_score_list=" + str(self.best_score_list) + "\n")

        self.nbestscoring = nbestscoring
        for i in range(self.nbestscoring):
            name = prefix + "." + str(i) + fileext
            with open(name, 'w'):
                pass
            self.dictionary_pdbs[name] = prot
            self._pdb_mmcif[name] = mmcif
            self._init_dictchain(name, prot, mmcif=mmcif)

    def write_pdb_best_scoring(self, score):
        """Write the current model if `score` is among the best so far.

        Scores are kept in increasing order (lower is better). File
        prefix.i holds the model with the i-th best score, so existing
        files are renamed to make room for the new one.
        @note Prints a message and does nothing if init_pdb_best_scoring()
              has not been called.
        """
        if self.nbestscoring is None:
            print("Output.write_pdb_best_scoring: init_pdb_best_scoring "
                  "not run")
            return

        mmcif = self._pdb_best_scoring_mmcif
        fileext = '.cif' if mmcif else '.pdb'
        # update the score list
        if self.replica_exchange:
            # read the self.best_score_list from the file
            with open(self.best_score_file_name) as fh:
                self.best_score_list = ast.literal_eval(
                    fh.read().split('=')[1])

        if len(self.best_score_list) < self.nbestscoring:
            self.best_score_list.append(score)
            self.best_score_list.sort()
            index = self.best_score_list.index(score)
            for prefix in self.prefixes:
                for i in range(len(self.best_score_list) - 2, index - 1, -1):
                    oldname = prefix + "." + str(i) + fileext
                    newname = prefix + "." + str(i + 1) + fileext
                    # rename on Windows fails if newname already exists
                    if os.path.exists(newname):
                        os.unlink(newname)
                    os.rename(oldname, newname)
                filetoadd = prefix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        elif score < self.best_score_list[-1]:
            self.best_score_list.append(score)
            self.best_score_list.sort()
            self.best_score_list.pop(-1)
            index = self.best_score_list.index(score)
            for prefix in self.prefixes:
                for i in range(len(self.best_score_list) - 1,
                               index - 1, -1):
                    oldname = prefix + "." + str(i) + fileext
                    newname = prefix + "." + str(i + 1) + fileext
                    os.rename(oldname, newname)
                # the previous worst model was shifted out of range
                filenametoremove = prefix + \
                    "." + str(self.nbestscoring) + fileext
                os.remove(filenametoremove)
                filetoadd = prefix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        if self.replica_exchange:
            # write the self.best_score_list to the file
            with open(self.best_score_file_name, "w") as fh:
                fh.write(
                    "self.best_score_list=" + str(self.best_score_list) + '\n')

    def init_rmf(self, name, hierarchies, rs=None, geometries=None,
                 listofobjects=None):
        """
        Initialize an RMF file

        @param name the name of the RMF file
        @param hierarchies the hierarchies to be included (it is a list)
        @param rs optional, the restraint sets (it is a list)
        @param geometries optional, the geometries (it is a list)
        @param listofobjects optional, the list of objects for the stat
               (it is a list)
        @raises ValueError if an object in listofobjects has no
                get_output() method.
        """
        rh = RMF.create_rmf_file(name)
        IMP.rmf.add_hierarchies(rh, hierarchies)
        cat = None
        outputkey_rmfkey = None

        if rs is not None:
            IMP.rmf.add_restraints(rh, rs)
        if geometries is not None:
            IMP.rmf.add_geometries(rh, geometries)
        if listofobjects is not None:
            cat = rh.get_category("stat")
            outputkey_rmfkey = {}
            for o in listofobjects:
                self._check_get_output(o)
                output = o.get_output()
                for outputkey in output:
                    value = output[outputkey]
                    # bool is a subclass of int, so it also gets int_tag
                    if isinstance(value, float):
                        rmftag = RMF.float_tag
                    elif isinstance(value, int):
                        rmftag = RMF.int_tag
                    else:
                        rmftag = RMF.string_tag
                    outputkey_rmfkey[outputkey] = rh.get_key(
                        cat, outputkey, rmftag)
            outputkey_rmfkey["rmf_file"] = \
                rh.get_key(cat, "rmf_file", RMF.string_tag)
            outputkey_rmfkey["rmf_frame_index"] = \
                rh.get_key(cat, "rmf_frame_index", RMF.int_tag)

        self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey, listofobjects)

    def add_restraints_to_rmf(self, name, objectlist):
        """Add each object's restraints to RMF file `name`.

        objectlist may be nested lists/tuples. get_restraint_for_rmf() is
        used when it works; otherwise get_restraint() is the fallback.
        """
        for o in _flatten(objectlist):
            try:
                rs = o.get_restraint_for_rmf()
                if not isinstance(rs, (list, tuple)):
                    rs = [rs]
            except Exception:
                rs = [o.get_restraint()]
            IMP.rmf.add_restraints(
                self.dictionary_rmfs[name][0], rs)

    def add_geometries_to_rmf(self, name, objectlist):
        """Add each object's get_geometries() to RMF file `name`."""
        for o in objectlist:
            geos = o.get_geometries()
            IMP.rmf.add_geometries(self.dictionary_rmfs[name][0], geos)

    def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
        """Add an edge geometry for each particle pair of each object."""
        for o in objectlist:
            pps = o.get_particle_pairs()
            for pp in pps:
                IMP.rmf.add_geometry(
                    self.dictionary_rmfs[name][0],
                    IMP.core.EdgePairGeometry(pp))

    def write_rmf(self, name):
        """Save a frame to RMF file `name`, along with any stat values."""
        rh, cat, outputkey_rmfkey, listofobjects = self.dictionary_rmfs[name]
        IMP.rmf.save_frame(rh)
        if cat is not None:
            root = rh.get_root_node()
            for o in listofobjects:
                output = o.get_output()
                for outputkey in output:
                    rmfkey = outputkey_rmfkey[outputkey]
                    try:
                        root.set_value(rmfkey, output[outputkey])
                    except NotImplementedError:
                        continue
            root.set_value(outputkey_rmfkey["rmf_file"], name)
            root.set_value(outputkey_rmfkey["rmf_frame_index"],
                           rh.get_number_of_frames() - 1)
        rh.flush()

    def close_rmf(self, name):
        """Stop tracking RMF file `name` and drop this object's references
        to its handle (presumably closing the file if no other references
        remain)."""
        rh = self.dictionary_rmfs[name][0]
        del self.dictionary_rmfs[name]
        del rh

    def write_rmfs(self):
        """Write a frame to every RMF file set up with init_rmf()."""
        # The keys are the RMF file names themselves
        for rmfname in list(self.dictionary_rmfs.keys()):
            self.write_rmf(rmfname)

    def init_stat(self, name, listofobjects):
        """Create (or truncate) stat file `name` and register the objects
        whose get_output() is written by write_stat().

        @raises ValueError if an object has no get_output() method.
        """
        with open(name, 'w' if self.ascii else 'wb'):
            pass

        # check that all objects in listofobjects have a get_output method
        for o in listofobjects:
            self._check_get_output(o)
        self.dictionary_stats[name] = listofobjects

    def set_output_entry(self, key, value):
        """Add a fixed key/value written to every stat record."""
        self.initoutput.update({key: value})

    def write_stat(self, name, appendmode=True):
        """Write one record (a dict) to stat file `name`.

        Entries whose key starts with '_' are treated as private and
        omitted. ASCII files get the dict's repr; otherwise it is pickled.
        """
        # Work on a copy so object outputs do not leak into
        # self.initoutput (and from there into other stat files)
        output = dict(self.initoutput)
        for obj in self.dictionary_stats[name]:
            d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")

        writeflag = 'a' if appendmode else 'w'
        if self.ascii:
            with open(name, writeflag) as flstat:
                flstat.write("%s \n" % output)
        else:
            with open(name, writeflag + 'b') as flstat:
                pickle.dump(output, flstat, 2)

    def write_stats(self):
        """Write one record to every stat file set up with init_stat()."""
        for stat in self.dictionary_stats.keys():
            self.write_stat(stat)

    def get_stat(self, name):
        """Return the merged get_output() dicts of stat file `name`'s
        objects (private entries included)."""
        output = {}
        for obj in self.dictionary_stats[name]:
            output.update(obj.get_output())
        return output

    def write_test(self, name, listofobjects):
        """Write reference values for test() to file `name`.

        get_test_output() is preferred, with get_output() as the fallback.
        Private ('_'-prefixed) entries are omitted.
        @raises ValueError if an object has neither method.
        """
        for o in listofobjects:
            self._check_get_test_output(o)
        self.dictionary_stats[name] = listofobjects

        output = dict(self.initoutput)
        for obj in self.dictionary_stats[name]:
            try:
                d = obj.get_test_output()
            except Exception:
                d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")
        with open(name, 'w') as flstat:
            flstat.write("%s \n" % output)

    def test(self, name, listofobjects, tolerance=1e-5):
        """Compare current object outputs against file `name`.

        The file is one written by write_test(). Numeric values must agree
        within `tolerance`; other values must be equal. Mismatches are
        reported on stderr. Keys missing from the current output are
        reported but do not fail the test.
        @return True if all compared values matched.
        """
        for o in listofobjects:
            self._check_get_test_output(o)
        output = dict(self.initoutput)
        for obj in listofobjects:
            try:
                output.update(obj.get_test_output())
            except Exception:
                output.update(obj.get_output())

        passed = True
        with open(name, 'r') as flstat:
            for fl in flstat:
                test_dict = ast.literal_eval(fl)
                for k in test_dict:
                    if k not in output:
                        print("%s from old objects (file %s) not in new "
                              "objects" % (str(k), str(name)),
                              file=sys.stderr)
                        continue
                    old_value = str(test_dict[k])
                    new_value = str(output[k])
                    try:
                        float(old_value)
                        is_float = True
                    except ValueError:
                        is_float = False

                    if is_float:
                        fold = float(old_value)
                        fnew = float(new_value)
                        diff = abs(fold - fnew)
                        if diff > tolerance:
                            print("%s: test failed, old value: %s new value "
                                  "%s; diff %f > %f"
                                  % (str(k), str(old_value), str(new_value),
                                     diff, tolerance), file=sys.stderr)
                            passed = False
                    elif test_dict[k] != output[k]:
                        if len(old_value) < 50 and len(new_value) < 50:
                            print("%s: test failed, old value: %s new value %s"
                                  % (str(k), old_value, new_value),
                                  file=sys.stderr)
                        else:
                            print("%s: test failed, omitting results (too "
                                  "long)" % str(k), file=sys.stderr)
                        passed = False
        return passed

    def get_environment_variables(self):
        """Return the process environment as a string."""
        return str(os.environ)

    def get_versions_of_relevant_modules(self):
        """Return a dict of version strings for IMP, PMI and, if importable,
        the optional isd2/isd_emxl modules."""
        # This local import must stay: the `import IMP.isd2` statements
        # below make IMP a local name throughout this function
        import IMP
        versions = {}
        versions["IMP_VERSION"] = IMP.get_module_version()
        versions["PMI_VERSION"] = IMP.pmi.get_module_version()
        try:
            import IMP.isd2
            versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
        except ImportError:
            pass
        try:
            import IMP.isd_emxl
            versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
        except ImportError:
            pass
        return versions

    def init_stat2(self, name, listofobjects, extralabels=None,
                   listofsummedobjects=None):
        """Initialize a (more compact) v2 stat file.

        The header maps integer column numbers to key names, so each later
        record only stores numbers.
        @param listofsummedobjects List of ([obj1, obj2, ...], label)
               tuples. Each label's column holds the sum of the objects'
               '_TotalScore' outputs.
        @param extralabels Extra keys whose values come from entries set
               with set_output_entry().
        @raises ValueError if an object lacks get_output() or a summed
                object lacks a '_TotalScore' entry.
        """
        if listofsummedobjects is None:
            listofsummedobjects = []
        if extralabels is None:
            extralabels = []
        output = {}
        stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
        stat2_keywords.update(
            {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
        stat2_keywords.update(
            {"STAT2HEADER_IMP_VERSIONS":
             str(self.get_versions_of_relevant_modules())})
        stat2_inverse = {}

        for obj in listofobjects:
            self._check_get_output(obj)
            d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")

        # check for customizable entries
        for obj in listofsummedobjects:
            for t in obj[0]:
                self._check_get_output(t)
                if "_TotalScore" not in t.get_output():
                    raise ValueError(
                        "Output: object %s doesn't have _TotalScore "
                        "entry to be summed" % str(t))
                output.update({obj[1]: 0.0})

        for k in extralabels:
            output.update({k: 0.0})

        for n, k in enumerate(output):
            stat2_keywords.update({n: k})
            stat2_inverse.update({k: n})

        with open(name, 'w') as flstat:
            flstat.write("%s \n" % stat2_keywords)
        self.dictionary_stats2[name] = (
            listofobjects,
            stat2_inverse,
            listofsummedobjects,
            extralabels)

    def write_stat2(self, name, appendmode=True):
        """Write one record, keyed by column number, to v2 stat file
        `name`."""
        output = {}
        (listofobjects, stat2_inverse, listofsummedobjects,
         extralabels) = self.dictionary_stats2[name]

        # writing objects (private '_' entries are skipped)
        for obj in listofobjects:
            od = obj.get_output()
            for k, v in od.items():
                if k[0] != "_":
                    output[stat2_inverse[k]] = v

        # writing summedobjects
        for so in listofsummedobjects:
            partial_score = 0.0
            for t in so[0]:
                d = t.get_output()
                partial_score += float(d["_TotalScore"])
            output.update({stat2_inverse[so[1]]: str(partial_score)})

        # writing extralabels
        for k in extralabels:
            output.update({stat2_inverse[k]: self.initoutput.get(k, "None")})

        with open(name, 'a' if appendmode else 'w') as flstat:
            flstat.write("%s \n" % output)

    def write_stats2(self):
        """Write one record to every v2 stat file."""
        for stat in self.dictionary_stats2.keys():
            self.write_stat2(stat)
918 
919 
class OutputStatistics:
    """Collect statistics from ProcessOutput.get_fields().
    Counters of the total number of frames read, plus the models that
    passed the various filters used in get_fields(), are provided."""
    def __init__(self):
        # total number of frames read
        self.total = 0
        # frames that passed the get_every subsampling
        self.passed_get_every = 0
        # frames that passed the filterout filter
        self.passed_filterout = 0
        # frames that passed the filtertuple filter
        self.passed_filtertuple = 0
929 
930 
"""A class for reading stat files (either rmf or ascii v1 and v2)"""
def __init__(self, filename):
    """Open a stat file and work out its format and available keys.

    The file is first tried as an RMF file; if RMF raises IOError it is
    treated as an ASCII stat file whose first line is a Python dict
    literal (a "STAT2HEADER" key marks the v2 format).

    @param filename path of the stat file
    @raise ValueError if filename is None
    """
    self.filename = filename
    self.isstat1 = False
    self.isstat2 = False
    self.isrmf = False
    # Default so that get_keys() returns [] rather than raising
    # AttributeError for an empty ASCII file.
    self.klist = []

    if self.filename is None:
        raise ValueError("No file name provided. Use -h for help")

    try:
        # let's see if that is an rmf file
        rh = RMF.open_rmf_file_read_only(self.filename)
        self.isrmf = True
        cat = rh.get_category('stat')
        rmf_klist = rh.get_keys(cat)
        self.rmf_names_keys = {rh.get_name(k): k for k in rmf_klist}
        del rh

    except IOError:
        # Not an RMF file: try with an ascii stat file. Only the first
        # line (the keys) is needed, so do not read the whole, possibly
        # very large, file; the handle is closed even if parsing fails.
        with open(self.filename, "r") as f:
            first_line = f.readline()
        if first_line:
            self._parse_ascii_header(first_line)


def _parse_ascii_header(self, line):
    """Set klist (and invstat2_dict for v2 files) from the first line."""
    d = ast.literal_eval(line)
    self.klist = list(d.keys())
    # check if it is a stat2 file
    if "STAT2HEADER" in self.klist:
        self.isstat2 = True
        for k in self.klist:
            if "STAT2HEADER" in str(k):
                del d[k]
        # The remaining header entries map the key used in data lines to
        # a field name; order them by field name.
        items = sorted(d.items(), key=operator.itemgetter(1))
        self.klist = [field for _, field in items]
        # field name -> key used in the data lines
        self.invstat2_dict = {field: column for column, field in items}
    else:
        IMP.handle_use_deprecated(
            "statfile v1 is deprecated. "
            "Please convert to statfile v2.\n")
        self.isstat1 = True
        self.klist.sort()
986 
def get_keys(self):
    """Return the field names available in the stat file.

    For RMF files the names are returned sorted; for ASCII files the
    stored key list itself is returned.
    """
    if not self.isrmf:
        return self.klist
    return sorted(self.rmf_names_keys)

def show_keys(self, ncolumns=2, truncate=65):
    """Print the available field names in ``ncolumns`` columns, each
    entry truncated to ``truncate`` characters."""
    keys = self.get_keys()
    IMP.pmi.tools.print_multicolumn(keys, ncolumns, truncate)
995 
def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
               statistics=None):
    '''
    Get the desired field names, and return a dictionary.
    Namely, "fields" are the queried keys in the stat file
    (eg. ["Total_Score",...])
    The returned data structure is a dictionary, where each key is
    a field and the value is the time series (ie, frame ordered series)
    of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,,...],....})

    @param fields (list of strings) queried keys in the stat file
           (eg. "Total_Score"....)
    @param filterout specify if you want to "grep" out something from
           the file, so that it is faster (ASCII files only)
    @param filtertuple a tuple that contains
           ("TheKeyToBeFiltered",relationship,value)
           where relationship = "<", "==", or ">"
    @param get_every only read every Nth line from the file
           (ASCII files only)
    @param statistics if provided, accumulate statistics in an
           OutputStatistics object
    '''
    if statistics is None:
        statistics = OutputStatistics()
    outdict = {field: [] for field in fields}

    if filtertuple is not None:
        keytobefiltered = filtertuple[0]
        relationship = filtertuple[1]
        value = filtertuple[2]

    if self.isrmf:
        rh = RMF.open_rmf_file_read_only(self.filename)
        for i in range(rh.get_number_of_frames()):
            statistics.total += 1
            # "get_every" and "filterout" not enforced for RMF
            statistics.passed_get_every += 1
            statistics.passed_filterout += 1
            rh.set_current_frame(RMF.FrameID(i))
            root = rh.get_root_node()
            if filtertuple is not None:
                datavalue = root.get_value(
                    self.rmf_names_keys[keytobefiltered])
                if self.isfiltered(datavalue, relationship, value):
                    continue

            statistics.passed_filtertuple += 1
            for field in fields:
                outdict[field].append(
                    root.get_value(self.rmf_names_keys[field]))
        return outdict

    line_number = 0
    # "with" closes the file even if a KeyError etc. escapes mid-read;
    # iterating the handle avoids loading the whole file into memory.
    with open(self.filename, "r") as f:
        for line in f:
            statistics.total += 1
            if filterout is not None and filterout in line:
                continue
            statistics.passed_filterout += 1
            line_number += 1

            if line_number % get_every != 0:
                if line_number == 1 and self.isstat2:
                    # the v2 header is not a frame: undo its counts
                    statistics.total -= 1
                    statistics.passed_filterout -= 1
                continue
            statistics.passed_get_every += 1
            try:
                d = ast.literal_eval(line)
            except Exception:
                print("# Warning: skipped line number " + str(line_number)
                      + " not a valid line")
                continue

            if self.isstat1:
                # v1 lines are keyed directly by field name
                keymap = None
            elif self.isstat2:
                if line_number == 1:
                    # the v2 header is not a frame: undo its counts
                    statistics.total -= 1
                    statistics.passed_filterout -= 1
                    statistics.passed_get_every -= 1
                    continue
                # v2 lines are keyed by the header's column keys
                keymap = self.invstat2_dict
            else:
                continue

            if filtertuple is not None:
                fkey = (keytobefiltered if keymap is None
                        else keymap[keytobefiltered])
                if self.isfiltered(d[fkey], relationship, value):
                    continue

            statistics.passed_filtertuple += 1
            for field in fields:
                key = field if keymap is None else keymap[field]
                outdict[field].append(d[key])

    return outdict
1108 
def isfiltered(self, datavalue, relationship, refvalue):
    """Return True if ``datavalue`` should be discarded.

    A value is kept only when ``float(datavalue) <relationship> refvalue``
    holds, for relationship "<", ">" or "=="; any other relationship
    string never filters anything out.

    @raise ValueError if datavalue cannot be converted into a float
    """
    try:
        number = float(datavalue)
    except (TypeError, ValueError) as err:
        # float(None) raises TypeError; report it as documented
        raise ValueError("ProcessOutput.filter: datavalue cannot be "
                         "converted into a float") from err

    if relationship == "<":
        return bool(number >= refvalue)
    if relationship == ">":
        return bool(number <= refvalue)
    if relationship == "==":
        return bool(number != refvalue)
    return False
1127 
1128 
""" Class to allow more advanced handling of RMF files.
It is both a container and an IMP.atom.Hierarchy.
 - it is iterable (while loading the corresponding frame)
 - item brackets [] load the corresponding frame
 - slices create an iterator
 - it can relink to another RMF file
"""
def __init__(self, model, rmf_file_name):
    """Open ``rmf_file_name`` and build the hierarchy from its frame 0.

    @param model: the IMP.Model()
    @param rmf_file_name: str, path of the rmf file
    @raise TypeError if RMF rejects the file name or its type
    """
    self.model = model
    try:
        self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
    except TypeError:
        raise TypeError("Wrong rmf file name or type: %s"
                        % str(rmf_file_name))
    created = IMP.rmf.create_hierarchies(self.rh_ref, self.model)
    IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(0))
    # the first hierarchy in the file becomes this object's root
    self.root_hier_ref = created[0]
    IMP.atom.Hierarchy.__init__(self, self.root_hier_ref)
    self.model.update()
    # optional colouring helper; link_to_rmf() calls its method()
    self.ColorHierarchy = None
1154 
def link_to_rmf(self, rmf_file_name):
    """
    Link to another RMF file.

    The existing hierarchy is relinked to ``rmf_file_name``; any
    ColorHierarchy helper is re-applied and frame 0 is loaded.
    """
    handle = RMF.open_rmf_file_read_only(rmf_file_name)
    self.rh_ref = handle
    IMP.rmf.link_hierarchies(handle, [self])
    colorer = self.ColorHierarchy
    if colorer:
        colorer.method()
    # call the base-class loader explicitly, bypassing subclass overrides
    RMFHierarchyHandler.set_frame(self, 0)
1164 
def set_frame(self, index):
    """Load frame ``index`` of the linked RMF file and update the model.

    A frame that cannot be loaded is reported and skipped; the model is
    updated in either case.
    """
    try:
        IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(index))
    except Exception:
        # current_rmf is only defined by subclasses (StatHierarchyHandler);
        # reading it unconditionally made a plain RMFHierarchyHandler
        # raise AttributeError here instead of reporting the frame.
        print("skipping frame %s:%d\n"
              % (getattr(self, "current_rmf", None), index))
    self.model.update()
1171 
def get_number_of_frames(self):
    """Return the number of frames in the linked RMF file."""
    handle = self.rh_ref
    return handle.get_number_of_frames()

def __getitem__(self, int_slice_adaptor):
    """Load a frame and return its index; a slice gives a lazy iterator
    that loads each selected frame in turn."""
    if isinstance(int_slice_adaptor, slice):
        return self.__iter__(int_slice_adaptor)
    if isinstance(int_slice_adaptor, int):
        self.set_frame(int_slice_adaptor)
        return int_slice_adaptor
    raise TypeError("Unknown Type")

def __len__(self):
    return self.get_number_of_frames()

def __iter__(self, slice_key=None):
    """Yield each (selected) frame index after loading that frame."""
    frames = range(len(self))
    if slice_key is not None:
        frames = frames[slice_key]
    for frame in frames:
        yield self[frame]
1194 
1195 
class CacheHierarchyCoordinates:
    """Per-frame coordinate cache for a StatHierarchyHandler.

    For each stored frame index it keeps the reference frame of every
    rigid body, the internal coordinates of every non-rigid member and the
    coordinates of every other bead. A stored frame can then be restored
    with do_update() instead of being re-read from the RMF file.
    """
    def __init__(self, StatHierarchyHandler):
        """@param StatHierarchyHandler the handler whose hierarchy is cached"""
        self.xyzs = []
        self.nrms = []
        # frame index -> {decorator -> stored value}
        self.nrm_coors = {}
        self.xyz_coors = {}
        self.rb_trans = {}
        self.current_index = None
        self.rmfh = StatHierarchyHandler
        rbs, xyzs = IMP.pmi.tools.get_rbs_and_beads([self.rmfh])
        self.model = self.rmfh.get_model()
        self.rbs = rbs
        for xyz in xyzs:
            # Non-rigid members are stored via their internal coordinates,
            # every other bead via its plain coordinates.
            if IMP.core.NonRigidMember.get_is_setup(xyz):
                self.nrms.append(IMP.core.NonRigidMember(xyz))
            else:
                self.xyzs.append(IMP.core.XYZ(xyz))

    def do_store(self, index):
        """Record the current state of the hierarchy as frame ``index``."""
        self.rb_trans[index] = {}
        self.nrm_coors[index] = {}
        self.xyz_coors[index] = {}
        for rb in self.rbs:
            self.rb_trans[index][rb] = rb.get_reference_frame()
        for nrm in self.nrms:
            self.nrm_coors[index][nrm] = nrm.get_internal_coordinates()
        for xyz in self.xyzs:
            self.xyz_coors[index][xyz] = xyz.get_coordinates()
        self.current_index = index

    def do_update(self, index):
        """Restore stored frame ``index`` (unless it is already current)
        and update the model. Raises KeyError if it was never stored."""
        if self.current_index != index:
            for rb in self.rbs:
                rb.set_reference_frame(self.rb_trans[index][rb])
            for nrm in self.nrms:
                nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
            for xyz in self.xyzs:
                xyz.set_coordinates(self.xyz_coors[index][xyz])
            self.current_index = index
        self.model.update()

    def get_number_of_frames(self):
        """Return how many frames have been stored."""
        return len(self.rb_trans)

    def __getitem__(self, index):
        """Return True if frame ``index`` has been stored."""
        if isinstance(index, int):
            return index in self.rb_trans
        raise TypeError("Unknown Type")

    def __len__(self):
        return self.get_number_of_frames()
1251 
1252 
""" class to link stat files to several rmf files """
def __init__(self, model=None, stat_file=None,
             number_best_scoring_models=None, score_key=None,
             StatHierarchyHandler=None, cache=None):
    """

    @param model: IMP.Model()
    @param stat_file: a single stat file name or a list of them
           (either rmfs or ascii, or pickled data or pickled cluster).
           Only str and list values are handled; anything else
           (including a dict) is currently ignored.
    @param number_best_scoring_models: if set, keep only the models
           scoring at or below the score of the Nth best model
    @param score_key: name of the score field (default "Total_Score")
    @param StatHierarchyHandler: copy constructor input object; when
           given, all other arguments are ignored
    @param cache: cache coordinates and rigid body transformations.
    """

    if StatHierarchyHandler is not None:
        # overrides all other arguments
        # copy constructor: create a copy with
        # different RMFHierarchyHandler
        self.model = StatHierarchyHandler.model
        self.data = StatHierarchyHandler.data
        self.number_best_scoring_models = \
            StatHierarchyHandler.number_best_scoring_models
        self.is_setup = True
        self.current_rmf = StatHierarchyHandler.current_rmf
        self.current_frame = None
        self.current_index = None
        self.score_threshold = StatHierarchyHandler.score_threshold
        self.score_key = StatHierarchyHandler.score_key
        self.cache = StatHierarchyHandler.cache
        RMFHierarchyHandler.__init__(self, self.model,
                                     self.current_rmf)
        # The source's cache is normally a CacheHierarchyCoordinates,
        # which defines __len__: a plain truth test would treat a cache
        # that has not stored any frame yet as "caching disabled".
        if isinstance(self.cache, CacheHierarchyCoordinates):
            use_cache = True
        else:
            use_cache = bool(self.cache)
        if use_cache:
            self.cache = CacheHierarchyCoordinates(self)
        else:
            self.cache = None
        self.set_frame(0)

    else:
        # standard constructor
        self.model = model
        self.data = []
        self.number_best_scoring_models = number_best_scoring_models
        self.cache = cache

        if score_key is None:
            self.score_key = "Total_Score"
        else:
            self.score_key = score_key
        self.is_setup = None
        self.current_rmf = None
        self.current_frame = None
        self.current_index = None
        self.score_threshold = None

        if isinstance(stat_file, str):
            self.add_stat_file(stat_file)
        elif isinstance(stat_file, list):
            for f in stat_file:
                self.add_stat_file(f)
1315 
def add_stat_file(self, stat_file):
    """Add the models described by ``stat_file`` and load the first one.

    ``stat_file`` is tried, in order, as a pickle readable by load_data(),
    as an ASCII stat file, and finally as an RMF file without stored
    stats (all scores set to 0.0). If none of these work, or no model is
    found, a message is printed and nothing is added.

    @raise ValueError if an ASCII stat file refers to several RMF files
    """
    try:
        # check that it is not a pickle file with saved data
        # from a previous calculation
        self.load_data(stat_file)
        self._keep_best_scoring_models()

    except pickle.UnpicklingError:
        # alternatively read the ascii stat files
        try:
            scores, rmf_files, rmf_frame_indexes, features = \
                self.get_info_from_stat_file(stat_file,
                                             self.score_threshold)
        except (KeyError, SyntaxError):
            # in this case check that is it an rmf file, probably
            # without stat stored in
            try:
                rh = RMF.open_rmf_file_read_only(stat_file)
                nframes = rh.get_number_of_frames()
            except Exception:
                print("StatHierarchyHandler: Warning: cannot read %s as a "
                      "stat or rmf file, skipping" % stat_file)
                return
            scores = [0.0]*nframes
            rmf_files = [stat_file]*nframes
            rmf_frame_indexes = range(nframes)
            features = {}

        if len(set(rmf_files)) > 1:
            # was ``raise ("...")``: raising a str is itself a TypeError
            raise ValueError("Multiple RMF files found")

        if not rmf_files:
            print("StatHierarchyHandler: Error: Trying to set none as "
                  "rmf_file (probably empty stat file), aborting")
            return

        for n, index in enumerate(rmf_frame_indexes):
            featn_dict = {k: features[k][n] for k in features}
            self.data.append(IMP.pmi.output.DataEntry(
                stat_file, rmf_files[n], index, scores[n], featn_dict))

        self._keep_best_scoring_models()

    if not self.data:
        print("StatHierarchyHandler: Error: no models found in %s, "
              "aborting" % stat_file)
        return

    if not self.is_setup:
        RMFHierarchyHandler.__init__(
            self, self.model, self.get_rmf_names()[0])
        if self.cache:
            self.cache = CacheHierarchyCoordinates(self)
        else:
            self.cache = None
        self.is_setup = True
        self.current_rmf = self.get_rmf_names()[0]

    self.set_frame(0)


def _keep_best_scoring_models(self):
    """Apply number_best_scoring_models: keep every model whose score is
    <= the Nth lowest score (ties may keep more than N models)."""
    if not self.number_best_scoring_models or not self.data:
        return
    scores = sorted(self.get_scores())
    nkeep = min(len(scores), self.number_best_scoring_models)
    self.do_filter_by_score(scores[nkeep - 1])
1378 
def save_data(self, filename='data.pkl'):
    """Pickle the list of model entries (self.data) to ``filename``;
    load_data() reads it back."""
    entries = self.data
    with open(filename, 'wb') as out_handle:
        pickle.dump(entries, out_handle)
1382 
def load_data(self, filename='data.pkl'):
    """Replace self.data with the models pickled in ``filename``.

    The pickle must hold a list of DataEntry objects, or a list of
    Cluster objects; for clusters, each member's data is stored at the
    position given by its member index.

    NOTE: unpickling can execute arbitrary code; only load trusted files.

    @raise TypeError if the pickle holds anything else
    """
    with open(filename, 'rb') as fl:
        data_structure = pickle.load(fl)
    # The old messages began with "%filename", which '%' parses as a
    # "%f" float conversion, so building the error raised its own
    # unrelated TypeError.
    message = ("%s should contain a list of IMP.pmi.output.DataEntry "
               "or IMP.pmi.output.Cluster" % filename)
    # first check that it is a list
    if not isinstance(data_structure, list):
        raise TypeError(message)
    # second check the types
    if all(isinstance(item, IMP.pmi.output.DataEntry)
           for item in data_structure):
        self.data = data_structure
    elif all(isinstance(item, IMP.pmi.output.Cluster)
             for item in data_structure):
        nmodels = sum(len(cluster) for cluster in data_structure)
        self.data = [None]*nmodels
        for cluster in data_structure:
            for n, data in enumerate(cluster):
                self.data[cluster.members[n]] = data
    else:
        raise TypeError(message)
1409 
def set_frame(self, index):
    """Make model ``index`` the current one.

    A frame already in the coordinate cache is restored from it;
    otherwise the right RMF file is linked (if needed), the frame is
    loaded (if not already loaded), and the result is stored in the cache
    when caching is enabled.
    """
    cache = self.cache
    if cache is None or not cache[index]:
        entry = self.data[index]
        rmf_name = entry.rmf_name
        frame_index = entry.rmf_index
        if rmf_name != self.current_rmf:
            self.link_to_rmf(rmf_name)
            self.current_rmf = rmf_name
            # force a reload after switching files
            self.current_frame = -1
        if frame_index != self.current_frame:
            RMFHierarchyHandler.set_frame(self, frame_index)
            self.current_frame = frame_index
        if self.cache is not None:
            self.cache.do_store(index)
    else:
        cache.do_update(index)

    self.current_index = index
1427 
def __getitem__(self, int_slice_adaptor):
    """Load model ``int_slice_adaptor`` and return its data entry; a
    slice gives a lazy iterator that loads each selected model."""
    if isinstance(int_slice_adaptor, slice):
        return self.__iter__(int_slice_adaptor)
    if isinstance(int_slice_adaptor, int):
        self.set_frame(int_slice_adaptor)
        return self.data[int_slice_adaptor]
    raise TypeError("Unknown Type")

def __len__(self):
    return len(self.data)

def __iter__(self, slice_key=None):
    """Yield the (selected) data entries, loading each model in turn."""
    positions = range(len(self))
    if slice_key is not None:
        positions = positions[slice_key]
    for pos in positions:
        yield self[pos]
1447 
def do_filter_by_score(self, maximum_score):
    """Drop every model whose score is above ``maximum_score``."""
    self.data = list(filter(lambda entry: entry.score <= maximum_score,
                            self.data))

def get_scores(self):
    """Return the score of every model, in order."""
    return list(map(lambda entry: entry.score, self.data))

def get_feature_series(self, feature_name):
    """Return the value of ``feature_name`` for every model, in order."""
    series = []
    for entry in self.data:
        series.append(entry.features[feature_name])
    return series

def get_feature_names(self):
    """Return the feature names of the first model."""
    first = self.data[0]
    return first.features.keys()

def get_rmf_names(self):
    """Return the RMF file name of every model, in order."""
    return list(map(lambda entry: entry.rmf_name, self.data))

def get_stat_files_names(self):
    """Return the stat file name of every model, in order."""
    return list(map(lambda entry: entry.stat_file, self.data))

def get_rmf_indexes(self):
    """Return the RMF frame index of every model, in order."""
    return list(map(lambda entry: entry.rmf_index, self.data))
1468 
def get_info_from_stat_file(self, stat_file, score_threshold=None):
    """Read the models listed in an ASCII stat file.

    @param stat_file path of the stat file
    @param score_threshold passed as ``prefiltervalue`` to
           IMP.pmi.io.get_best_models()
    @return (scores as floats, rmf file names, rmf frame indexes,
            features), all as returned by get_best_models()
    """
    feature_keys = ProcessOutput(stat_file).get_keys()
    best = IMP.pmi.io.get_best_models(
        [stat_file], score_key=self.score_key, feature_keys=feature_keys,
        rmf_file_key="rmf_file", rmf_file_frame_key="rmf_frame_index",
        prefiltervalue=score_threshold, get_every=1)

    scores = [float(raw) for raw in best[2]]
    return scores, best[0], best[1], best[3]
1482 
1483 
'''
A class to store data associated to a model
'''
def __init__(self, stat_file=None, rmf_name=None, rmf_index=None,
             score=None, features=None):
    """
    @param stat_file name of the stat file the model was read from
    @param rmf_name name of the RMF file holding the model
    @param rmf_index frame index of the model in that RMF file
    @param score score of the model
    @param features dict of extra per-model values (feature name -> value)
    """
    self.rmf_name = rmf_name
    self.rmf_index = rmf_index
    self.score = score
    self.features = features
    self.stat_file = stat_file

def __repr__(self):
    # features defaults to None; len(None.keys()) used to crash here
    nfeatures = 0 if self.features is None else len(self.features)
    s = "IMP.pmi.output.DataEntry\n"
    s += "---- stat file %s \n" % (self.stat_file)
    s += "---- rmf file %s \n" % (self.rmf_name)
    s += "---- rmf index %s \n" % (str(self.rmf_index))
    s += "---- score %s \n" % (str(self.score))
    s += "---- number of features %s \n" % (str(nfeatures))
    return s
1504 
1505 
class Cluster:
    '''
    A container for models organized into clusters.

    Members are model indices; each may carry a data object (typically a
    DataEntry with a ``score``). Indexing and iteration return the
    members' data objects in insertion order.
    '''
    def __init__(self, cid=None):
        self.cluster_id = cid
        self.members = []
        self.precision = None
        self.center_index = None
        # member index -> data object
        self.members_data = {}
        # defined up front so __repr__ works on an empty cluster
        self.average_score = None

    def add_member(self, index, data=None):
        """Add model ``index`` (with optional ``data``) and refresh the
        average score."""
        self.members.append(index)
        self.members_data[index] = data
        self.average_score = self.compute_score()

    def compute_score(self):
        """Return the mean ``score`` of the members' data, or None if the
        cluster is empty or any member has no usable score."""
        if not self.members:
            return None
        try:
            return sum(d.score for d in self) / len(self)
        except (AttributeError, TypeError):
            return None

    def __repr__(self):
        s = "IMP.pmi.output.Cluster\n"
        s += "---- cluster_id %s \n" % str(self.cluster_id)
        s += "---- precision %s \n" % str(self.precision)
        s += "---- average score %s \n" % str(self.average_score)
        s += "---- number of members %s \n" % str(len(self.members))
        s += "---- center index %s \n" % str(self.center_index)
        return s

    def __getitem__(self, int_slice_adaptor):
        if isinstance(int_slice_adaptor, int):
            index = self.members[int_slice_adaptor]
            return self.members_data[index]
        elif isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.members)

    def __iter__(self, slice_key=None):
        if slice_key is None:
            for i in range(len(self)):
                yield self[i]
        else:
            for i in range(len(self))[slice_key]:
                yield self[i]

    def __add__(self, other):
        """Merge ``other`` into this cluster IN PLACE and return self.

        Precision and center index are reset because they no longer
        describe the merged cluster.
        """
        self.members += other.members
        self.members_data.update(other.members_data)
        self.average_score = self.compute_score()
        self.precision = None
        self.center_index = None
        return self
1565 
1566 
def plot_clusters_populations(clusters):
    """Show a bar chart of the number of models in each cluster.

    @param clusters iterable of Cluster objects (iterated once)
    """
    pairs = [(cluster.cluster_id, len(cluster)) for cluster in clusters]
    indexes = [cid for cid, _ in pairs]
    populations = [size for _, size in pairs]

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, populations, 0.5, color='r')
    ax.set_ylabel('Population')
    ax.set_xlabel('Cluster index')
    plt.show()
1580 
1581 
def plot_clusters_precisions(clusters):
    """Print and show a bar chart of each cluster's precision.

    Clusters whose precision is None are drawn with precision 0.0.

    @param clusters iterable of Cluster objects
    """
    indexes = []
    precisions = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        print(cluster.cluster_id, cluster.precision)
        precisions.append(0.0 if cluster.precision is None
                          else cluster.precision)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, precisions, 0.5, color='r')
    ax.set_ylabel('Precision [A]')
    ax.set_xlabel('Cluster index')
    plt.show()
1600 
1601 
def plot_clusters_scores(clusters):
    """Draw a box plot of the model scores in each cluster.

    The figure is written to ``scores.pdf`` by plot_fields_box_plots().

    @param clusters iterable of Cluster objects whose members' data have a
           ``score`` attribute
    """
    indexes = []
    values = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        values.append([data.score for data in cluster])

    # plot_fields_box_plots() appends ".pdf" itself; passing "scores.pdf"
    # produced a file named "scores.pdf.pdf".
    plot_fields_box_plots("scores", values, indexes, frequencies=None,
                          valuename="Scores", positionname="Cluster index",
                          xlabels=None, scale_plot_length=1.0)
1614 
1615 
class CrossLinkIdentifierDatabase:
    """Per-crosslink attribute store, keyed by an arbitrary crosslink key.

    Each key maps to a dict of named attributes. The typed setters coerce
    their value (str, int or float) before storing it; the entry for a key
    is created on first use, even if the coercion then fails. The whole
    database can be pickled with write() and restored with load().
    """

    def __init__(self):
        # crosslink key -> {attribute name -> value}
        self.clidb = dict()

    def check_key(self, key):
        """Create an empty attribute dict for ``key`` if it has none."""
        self.clidb.setdefault(key, {})

    def _set_coerced(self, key, attribute, value, convert):
        """Ensure ``key`` exists, then store ``convert(value)``."""
        self.check_key(key)
        self.clidb[key][attribute] = convert(value)

    def set_unique_id(self, key, value):
        self._set_coerced(key, "XLUniqueID", value, str)

    def set_protein1(self, key, value):
        self._set_coerced(key, "Protein1", value, str)

    def set_protein2(self, key, value):
        self._set_coerced(key, "Protein2", value, str)

    def set_residue1(self, key, value):
        self._set_coerced(key, "Residue1", value, int)

    def set_residue2(self, key, value):
        self._set_coerced(key, "Residue2", value, int)

    def set_idscore(self, key, value):
        self._set_coerced(key, "IDScore", value, float)

    def set_state(self, key, value):
        self._set_coerced(key, "State", value, int)

    def set_sigma1(self, key, value):
        self._set_coerced(key, "Sigma1", value, str)

    def set_sigma2(self, key, value):
        self._set_coerced(key, "Sigma2", value, str)

    def set_psi(self, key, value):
        self._set_coerced(key, "Psi", value, str)

    def get_unique_id(self, key):
        return self.get_feature(key, "XLUniqueID")

    def get_protein1(self, key):
        return self.get_feature(key, "Protein1")

    def get_protein2(self, key):
        return self.get_feature(key, "Protein2")

    def get_residue1(self, key):
        return self.get_feature(key, "Residue1")

    def get_residue2(self, key):
        return self.get_feature(key, "Residue2")

    def get_idscore(self, key):
        return self.get_feature(key, "IDScore")

    def get_state(self, key):
        return self.get_feature(key, "State")

    def get_sigma1(self, key):
        return self.get_feature(key, "Sigma1")

    def get_sigma2(self, key):
        return self.get_feature(key, "Sigma2")

    def get_psi(self, key):
        return self.get_feature(key, "Psi")

    def set_float_feature(self, key, value, feature_name):
        self._set_coerced(key, feature_name, value, float)

    def set_int_feature(self, key, value, feature_name):
        self._set_coerced(key, feature_name, value, int)

    def set_string_feature(self, key, value, feature_name):
        self._set_coerced(key, feature_name, value, str)

    def get_feature(self, key, feature_name):
        """Return attribute ``feature_name`` of crosslink ``key``."""
        return self.clidb[key][feature_name]

    def write(self, filename):
        """Pickle the whole database to ``filename``."""
        with open(filename, 'wb') as handle:
            pickle.dump(self.clidb, handle)

    def load(self, filename):
        """Replace the database with the one pickled in ``filename``.

        NOTE: unpickling can execute arbitrary code; only load trusted
        files.
        """
        with open(filename, 'rb') as handle:
            self.clidb = pickle.load(handle)
1716 
1717 
def plot_fields(fields, output, framemin=None, framemax=None):
    """Plot the given fields and save a figure as `output`.
    The fields generally are extracted from a stat file
    using ProcessOutput.get_fields().

    @param fields dict mapping a field name to its series of values; one
           subplot is drawn per field
    @param output file name the figure is saved to
    @param framemin first frame to plot (default 0)
    @param framemax frame after the last one to plot (default, and upper
           bound: the length of each series)
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    # squeeze=False always returns a 2-D array of axes, so one field needs
    # no special case.
    fig, axs = plt.subplots(nrows=len(fields), squeeze=False)
    fig.set_size_inches(10.5, 5.5 * len(fields))

    start = 0 if framemin is None else framemin
    for ax, key in zip(axs[:, 0], fields):
        series = fields[key]
        # Resolve the end per field: it used to be fixed from the first
        # field, so series of other lengths gave mismatched x/y data.
        end = len(series) if framemax is None else min(framemax,
                                                       len(series))
        x = list(range(start, end))
        y = [float(v) for v in series[start:end]]
        ax.plot(x, y)
        ax.set_title(key, size="xx-large")
        ax.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(output)
    plt.close(fig)
1752 
1753 
def plot_field_histogram(name, values_lists, valuename=None, bins=40,
                         colors=None, format="png", reference_xline=None,
                         yplotrange=None, xplotrange=None, normalized=True,
                         leg_names=None):
    '''Plot a list of histograms from a value list.
    @param name the name of the plot; the figure is saved as name.format
    @param values_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename the x-axis label (defaults to name)
    @param bins the number of bins
    @param colors If None, will use rainbow. Else will use specific list
    @param format output format
    @param reference_xline plot a reference line parallel to the y-axis
    @param yplotrange the range for the y-axis
    @param xplotrange the range for the x-axis
    @param normalized whether the histogram is normalized or not
    @param leg_names names for the legend
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    fig = plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        if leg_names is not None:
            label = leg_names[nv]
        else:
            label = str(nv)
        try:
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                density=normalized, histtype='step', lw=4, label=label)
        except AttributeError:
            # very old matplotlib without the "density" keyword
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                normed=normalized, histtype='step', lw=4, label=label)

    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    # ylim() used to be called without arguments, ignoring yplotrange
    if yplotrange is not None:
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
    plt.close(fig)
1817 
1818 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None",
                          xlabels=None, scale_plot_length=1.0):
    '''
    Plot a set of series as box plots and save them to name + ".pdf".

    @param name window title and output file stem (".pdf" is appended)
    @param values list of series, one box per series
    @param positions x position of each box
    @param frequencies if not None, also draw every individual value of
           each series as a marker at its box position
    @param valuename y-axis label
    @param positionname x-axis label
    @param xlabels optional tick labels for the boxes
    @param scale_plot_length figure width per box
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    width = float(len(positions)) * scale_plot_length
    fig = plt.figure(figsize=(width, 5.0))
    fig.canvas.manager.set_window_title(name)
    ax1 = fig.add_subplot(111)
    plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)

    boxes = plt.boxplot(values, notch=0, sym='', vert=1,
                        whis=1.5, positions=positions)
    plt.setp(boxes['boxes'], color='black', lw=1.5)
    plt.setp(boxes['whiskers'], color='black', ls=":", lw=1.5)

    if frequencies is not None:
        for idx, series in enumerate(values):
            xs = [positions[idx]] * len(series)
            ax1.plot(xs, series, 'gx', alpha=0.7, markersize=7)

    if xlabels is not None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name + ".pdf", dpi=150)
    plt.show()
1860 
1861 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Draw ``y`` against ``x`` as a red line.

    @param title window and axes title
    @param out_fn if given, save the figure as out_fn + ".pdf"
    @param display whether to call plt.show()
    @param set_plot_yaxis_range (ymin, ymax) for the y axis
    @param xlabel, ylabel axis labels
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.manager.set_window_title(title)

    ax.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        xmin, xmax, _, _ = plt.axis()
        plt.axis((xmin, xmax,
                  set_plot_yaxis_range[0], set_plot_yaxis_range[1]))
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
1891 
1892 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Draw ``y`` against ``x`` as small black point markers.

    @param labelx, labely axis labels
    @param xmin, xmax x-axis limits (applied only if both are given)
    @param ymin, ymax y-axis limits (applied only if both are given)
    @param savefile whether to save the figure to ``filename``
    @param alpha marker transparency
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, axs0 = plt.subplots(1)

    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    # The call used to pass both color= and its alias c=, which current
    # matplotlib rejects with a TypeError; only color= is kept.
    axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1, alpha=alpha)

    axs0.legend(
        loc=0,
        frameon=False,
        scatterpoints=1,
        numpoints=1,
        columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if (ymin is not None) and (ymax is not None):
        axs0.set_ylim(ymin, ymax)
    if (xmin is not None) and (xmax is not None):
        axs0.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
1932 
1933 
def get_graph_from_hierarchy(hier):
    """Draw the parent/child graph of ``hier`` with draw_graph().

    Only nodes at depth < 3 keep their name as a label; deeper nodes get
    an empty label.
    """
    graph, _, depth_dict = recursive_graph(hier, [], 0, {})

    # filters node labels according to depth_dict
    node_labels_dict = {
        node: (node if level < 3 else "")
        for node, level in depth_dict.items()}
    draw_graph(graph, labels_dict=node_labels_dict)
1949 
1950 
def recursive_graph(hier, graph, depth, depth_dict):
    """Collect parent->child edges of an IMP hierarchy.

    Each node is named "<name>|#<particle index>"; its depth (root = 1)
    is recorded in ``depth_dict`` and every (parent, child) name pair is
    appended to ``graph``. A node with exactly one child is treated as a
    leaf: its subtree is not explored.

    @return (graph, depth, depth_dict), with depth equal to the value
            passed in
    """
    depth = depth + 1
    node = IMP.atom.Hierarchy(hier)
    name1 = node.get_name() + "|#" + str(hier.get_particle().get_index())
    depth_dict[name1] = depth

    children = node.get_children()

    # The None test must come first; after len() it could never be reached.
    if children is None or len(children) == 1:
        return (graph, depth - 1, depth_dict)

    for c in children:
        (graph, depth, depth_dict) = recursive_graph(
            c, graph, depth, depth_dict)
        namec = (IMP.atom.Hierarchy(c).get_name() + "|#"
                 + str(c.get_particle().get_index()))
        graph.append((name1, namec))

    return (graph, depth - 1, depth_dict)
1975 
1976 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw a graph with networkx/matplotlib and export it as GML.

    @param graph Iterable of (node1, node2) edges.
    @param labels_dict Optional dict mapping node -> label text.
    @param graph_layout 'spring', 'spectral' or 'random'; any other value
           selects the shell layout.
    @param node_size A single size for every node, a list of sizes in
           graph node order, or a dict mapping node -> value v (drawn with
           size sqrt(v)/pi*10).
    @param node_color None (all nodes black) or a dict mapping node ->
           hexadecimal color string without a leading '#', e.g. "ff0000".
    @param node_alpha Node transparency.
    @param node_text_size Font size of the node labels.
    @param fixed Passed to networkx.spring_layout (spring layout only).
    @param pos Passed to networkx.spring_layout (spring layout only).
    @param edge_color Edge color.
    @param edge_alpha Edge transparency.
    @param edge_thickness A single edge width, or a list of widths matching
           the order of ``graph`` (also stored as each edge's 'weight').
    @param edge_text_pos Unused; kept for backward compatibility.
    @param validation_edges Optional iterable of (node1, node2) edges to
           highlight in the GML output: marked green if already present,
           otherwise added to the graph and marked red.
    @param text_font Font family for the node labels.
    @param out_filename If given, save the matplotlib figure to this file.

    The graph, including 'graphics'/'LabelGraphics' attributes, is always
    written to 'out.gml' in the current directory.
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi
    from numbers import Number

    def _hex_to_rgb(hexcolor):
        # "rrggbb" (no leading '#') -> (r, g, b) with components in [0, 1]
        return tuple(int(hexcolor[i:i + 2], 16) / 255 for i in (0, 2, 4))

    G = nx.Graph()

    # Add edges. A per-edge thickness list is stored as the 'weight'
    # attribute so each width stays attached to its own edge.
    if isinstance(edge_thickness, list):
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    # Fix the node order once; the color and size lists below parallel it.
    nodes = list(G.nodes())

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        node_color_hex = ["000000"] * len(nodes)
    else:
        node_color_rgb = [_hex_to_rgb(node_color[node]) for node in nodes]
        node_color_hex = [node_color[node] for node in nodes]

    if isinstance(node_size, dict):
        node_size = [sqrt(node_size[node]) / pi * 10.0 for node in nodes]
    elif isinstance(node_size, Number):
        # a scalar size applies to every node
        node_size = [node_size] * len(nodes)

    # Graphics attributes for the GML export.
    node_graphics = {}
    node_label_graphics = {}
    for node, color, size in zip(nodes, node_color_hex, node_size):
        node_graphics[node] = {'type': 'ellipse', 'w': size, 'h': size,
                               'fill': '#' + color, 'label': node}
        node_label_graphics[node] = {'type': 'text', 'text': node,
                                     'color': '#000000', 'visible': 'true'}
    # Keyword arguments work with both the networkx 1.x (G, name, values)
    # and 2.x+ (G, values, name) parameter orders.
    nx.set_node_attributes(G, name="graphics", values=node_graphics)
    nx.set_node_attributes(G, name="LabelGraphics",
                           values=node_label_graphics)
    nx.set_edge_attributes(
        G, name="graphics",
        values={edge: {'width': 1, 'fill': '#000000'} for edge in G.edges()})

    # Highlight validation edges. G is undirected, so has_edge() matches
    # either orientation.
    for ve in validation_edges or []:
        edge = (ve[0], ve[1])
        if G.has_edge(*edge):
            fill = '#00FF00'
        else:
            G.add_edge(*edge)
            fill = '#FF0000'
        nx.set_edge_attributes(
            G, name="graphics", values={edge: {'width': 1, 'fill': fill}})

    # Widths must follow the order in which networkx iterates G.edges(),
    # not the order of the input list. Edges without a stored weight
    # (e.g. added validation edges) get width 1.
    if isinstance(edge_thickness, list):
        edge_width = [d.get('weight', 1) for _, _, d in G.edges(data=True)]
    else:
        edge_width = edge_thickness

    if graph_layout == 'spring':
        graph_pos = nx.spring_layout(G, k=1.0/8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # nodelist keeps the size/color lists aligned even if validation edges
    # introduced new nodes (those are not drawn as markers).
    nx.draw_networkx_nodes(G, graph_pos, nodelist=nodes, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_width,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
    nx.write_gml(G, 'out.gml')
    plt.show()
2088 
2089 
def draw_table():
    """Render a hard-coded example table with ipyD3.

    This is a demonstration only: the data, the row/column labels and the
    styling are all fixed, and nothing is returned. The rendered table is
    shown with IPython.display.display(). Requires the third-party
    ``ipyD3`` package and IPython.
    """
    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  desc='An example table created with d3js with '
                       'data generated with Python.')
    # 8 lists (one per product) of 12 values (one per month); transposed
    # below so that each row of the table is a month.
    data = [[1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
             4441.0, 4670.0, 944.0, 110.0],
            [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
             7078.0, 6561.0, 2354.0, 710.0],
            [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
             5661.0, 4991.0, 2032.0, 680.0],
            [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
             4727.0, 3428.0, 1559.0, 620.0],
            [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
             9552.0, 13709.0, 2460.0, 720.0],
            [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
             32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
            [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
             5115.0, 3510.0, 1390.0, 170.0],
            [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
             360.0, 156.0, 100.0]]
    data = [list(row) for row in zip(*data)]
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              'December']]
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: Residue.h:158
A container for models organized into clusters.
Definition: output.py:1506
A class for reading stat files (either RMF or ASCII v1 and v2)
Definition: output.py:931
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: atom/Atom.h:245
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:1754
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:1819
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
Miscellaneous utilities.
Definition: tools.py:1
A class to store data associated to a model.
Definition: output.py:1484
void handle_use_deprecated(std::string message)
Break in this method in gdb to find deprecated uses at runtime.
std::string get_module_version()
Return the version of this module, as a string.
Convert a hexadecimal color code to RGB.
Definition: tools.py:736
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
Definition: output.py:920
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: XYZR.h:47
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:996
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
Definition: Fragment.h:46
def link_to_rmf
Link to another RMF file.
Definition: output.py:1155
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
Definition: pmi/utilities.h:85
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
Definition: output.py:257
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Definition: atom/Atom.h:238
Base class for capturing a modeling protocol.
Definition: output.py:41
The type for a residue.
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
A base class for Keys.
Definition: Key.h:45
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:199
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Definition: rigid_bodies.h:768
Display a segment connecting a pair of particles.
Definition: XYZR.h:170
A decorator for a residue.
Definition: Residue.h:137
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
Definition: output.py:226
def get_prot_name_from_particle
Get the protein name from the particle.
Definition: output.py:343
Class to link stat files to several RMF files.
Definition: output.py:1253
Class to allow more advanced handling of RMF files.
Definition: output.py:1129
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
Definition: output.py:1718
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Definition: Chain.h:61
Python classes to represent, score, sample and analyze models.
def init_pdb_best_scoring
Prepare for writing best-scoring PDBs (or mmCIFs) for a sampling run.
Definition: output.py:464
Functionality for loading, creating, manipulating and scoring atomic structures.
def get_rbs_and_beads
Returns unique objects in original order.
Definition: tools.py:1135
Select hierarchy particles identified by the biological name.
Definition: Selection.h:70
def init_rmf
Initialize an RMF file.
Definition: output.py:561
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: tools.py:499
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: rigid_bodies.h:770
std::string get_module_version()
Return the version of this module, as a string.
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: tools.py:574
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27