IMP logo
IMP Reference Guide  2.18.0
The Integrative Modeling Platform
output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 from __future__ import print_function, division
6 import IMP
7 import IMP.atom
8 import IMP.core
9 import IMP.pmi
10 import IMP.pmi.tools
11 import IMP.pmi.io
12 import os
13 import sys
14 import ast
15 import RMF
16 import numpy as np
17 import operator
18 import itertools
19 import warnings
20 import string
21 import ihm.format
22 import collections
23 try:
24  import cPickle as pickle
25 except ImportError:
26  import pickle
27 
28 
29 class _ChainIDs(object):
30  """Map indices to multi-character chain IDs.
31  We label the first 26 chains A-Z, then we move to two-letter
32  chain IDs: AA through AZ, then BA through BZ, through to ZZ.
33  This continues with longer chain IDs."""
34  def __getitem__(self, ind):
35  chars = string.ascii_uppercase
36  lc = len(chars)
37  ids = []
38  while ind >= lc:
39  ids.append(chars[ind % lc])
40  ind = ind // lc - 1
41  ids.append(chars[ind])
42  return "".join(reversed(ids))
43 
44 
class ProtocolOutput(object):
    """Base class for capturing a modeling protocol.

    A full protocol records more than the model coordinates: it also
    covers the input data, details of the restraints, sampling and
    clustering, as well as the output models.
    Use via IMP.pmi.topology.System.add_protocol_output().

    @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
         mmCIF files.
    """
56 
57 
58 def _flatten(seq):
59  for elt in seq:
60  if isinstance(elt, (tuple, list)):
61  for elt2 in _flatten(elt):
62  yield elt2
63  else:
64  yield elt
65 
66 
67 def _disambiguate_chain(chid, seen_chains):
68  """Make sure that the chain ID is unique; warn and correct if it isn't"""
69  # Handle null chain IDs
70  if chid == '\0':
71  chid = ' '
72 
73  if chid in seen_chains:
74  warnings.warn("Duplicate chain ID '%s' encountered" % chid,
76 
77  for suffix in itertools.count(1):
78  new_chid = chid + "%d" % suffix
79  if new_chid not in seen_chains:
80  seen_chains.add(new_chid)
81  return new_chid
82  seen_chains.add(chid)
83  return chid
84 
85 
def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
                        write_all_residues_per_bead):
    """Write particles as PDB ATOM records, followed by an ENDMDL record.

    @param flpdb open, writable text file handle
    @param particle_infos_for_pdb sequence of tuples
           (xyz, atom_type, residue_type, chain_id, residue_index,
           all_indexes, radius), as built by
           Output.get_particle_infos_for_pdb_writing()
    @param geometric_center (x, y, z) subtracted from every coordinate
    @param write_all_residues_per_bead if True, a particle that carries a
           list of residue indexes is written once per residue in that
           list, instead of once with its representative residue index
    """
    # Atom serial numbers count the records actually written (like the
    # mmCIF writer's ordinal), so they stay unique even when a bead is
    # expanded into one record per residue.
    serial = 1
    for (xyz, atom_type, residue_type, chain_id, residue_index,
         all_indexes, radius) in particle_infos_for_pdb:
        if atom_type is None:
            atom_type = IMP.atom.AT_CA
        coords = (xyz[0] - geometric_center[0],
                  xyz[1] - geometric_center[1],
                  xyz[2] - geometric_center[2])
        if write_all_residues_per_bead and all_indexes is not None:
            residue_numbers = all_indexes
        else:
            residue_numbers = [residue_index]
        for residue_number in residue_numbers:
            # PDB only has room for a single-character chain ID
            flpdb.write(
                IMP.atom.get_pdb_string(coords, serial, atom_type,
                                        residue_type, chain_id[:1],
                                        residue_number, ' ', 1.00, radius))
            serial += 1
    flpdb.write("ENDMDL\n")
111 
112 
113 _Entity = collections.namedtuple('_Entity', ('id', 'seq'))
114 _ChainInfo = collections.namedtuple('_ChainInfo', ('entity', 'name'))
115 
116 
def _get_chain_info(chains, root_hier):
    """Group the molecules of a hierarchy into entities by sequence.

    @param chains dict mapping molecule name (as returned by
           IMP.pmi.get_molecule_name_and_copy) to chain ID
    @param root_hier hierarchy whose Molecule nodes are examined
    @return (chain_info, entities): chain_info maps each chain ID to a
            _ChainInfo; entities lists one _Entity per distinct sequence,
            numbered from 1 in order of first appearance.
    """
    info_by_chain = {}
    entity_by_seq = {}
    ordered_entities = []
    for mol in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
        molname = IMP.pmi.get_molecule_name_and_copy(mol)
        chain_id = chains[molname]
        seq = IMP.atom.Chain(mol).get_sequence()
        entity = entity_by_seq.get(seq)
        if entity is None:
            entity = _Entity(id=len(entity_by_seq) + 1, seq=seq)
            entity_by_seq[seq] = entity
            ordered_entities.append(entity)
        info_by_chain[chain_id] = _ChainInfo(entity=entity, name=molname)
    return info_by_chain, ordered_entities
133 
134 
def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
                          write_all_residues_per_bead, chains, root_hier):
    """Write particles as a single-model mmCIF data block.

    @param flpdb open, writable text file handle
    @param particle_infos_for_pdb sequence of tuples
           (xyz, atom_type, residue_type, chain_id, residue_index,
           all_indexes, radius), as built by
           Output.get_particle_infos_for_pdb_writing()
    @param geometric_center (x, y, z) subtracted from every coordinate
    @param write_all_residues_per_bead if True, a particle that carries a
           list of residue indexes is written once per residue in that
           list, each record numbered with that residue
    @param chains dict mapping molecule name to chain ID
    @param root_hier hierarchy used to look up each chain's sequence
    """
    # get dict with keys=chain IDs, values=chain info
    chain_info, entities = _get_chain_info(chains, root_hier)

    writer = ihm.format.CifWriter(flpdb)
    writer.start_block('model')
    with writer.category("_entry") as lp:
        lp.write(id='model')

    with writer.loop("_entity", ["id", "type"]) as lp:
        for e in entities:
            lp.write(id=e.id, type="polymer")

    with writer.loop("_entity_poly",
                     ["entity_id", "pdbx_seq_one_letter_code"]) as lp:
        for e in entities:
            lp.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)

    with writer.loop("_struct_asym", ["id", "entity_id", "details"]) as lp:
        # Longer chain IDs (e.g. AA) should always come after shorter (e.g. Z)
        for chid in sorted(chains.values(), key=lambda x: (len(x.strip()), x)):
            ci = chain_info[chid]
            lp.write(id=chid, entity_id=ci.entity.id, details=ci.name)

    with writer.loop("_atom_site",
                     ["group_PDB", "type_symbol", "label_atom_id",
                      "label_comp_id", "label_asym_id", "label_seq_id",
                      "auth_seq_id",
                      "Cartn_x", "Cartn_y", "Cartn_z", "label_entity_id",
                      "pdbx_pdb_model_num",
                      "id"]) as lp:
        ordinal = 1
        for (xyz, atom_type, residue_type, chain_id, residue_index,
             all_indexes, _radius) in particle_infos_for_pdb:
            ci = chain_info[chain_id]
            if atom_type is None:
                atom_type = IMP.atom.AT_CA
            c = (xyz[0] - geometric_center[0],
                 xyz[1] - geometric_center[1],
                 xyz[2] - geometric_center[2])
            if write_all_residues_per_bead and all_indexes is not None:
                residue_numbers = all_indexes
            else:
                residue_numbers = [residue_index]
            for residue_number in residue_numbers:
                # Each expanded record must carry its own residue number;
                # using residue_index here made every copy the same residue.
                lp.write(group_PDB='ATOM', type_symbol='C',
                         label_atom_id=atom_type.get_string(),
                         label_comp_id=residue_type.get_string(),
                         label_asym_id=chain_id,
                         label_seq_id=residue_number,
                         auth_seq_id=residue_number, Cartn_x=c[0],
                         Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                         pdbx_pdb_model_num=1,
                         label_entity_id=ci.entity.id)
                ordinal += 1
201 
202 
class Output(object):
    """Class for easy writing of PDBs, RMFs, and stat files

    @note Model should be updated prior to writing outputs.
    """
    def __init__(self, ascii=True, atomistic=False):
        """Constructor.
        @param ascii If True, stat (v1) files are written as text;
               otherwise they are written with pickle.
        @param atomistic If True, particles that are atoms are written with
               their own atom type and residue in PDB output (this is
               switched on automatically for canonical PMI hierarchies).
        """
        # PDB/mmCIF filename -> hierarchy written to that file
        self.dictionary_pdbs = {}
        # PDB/mmCIF filename -> True if that file is written as mmCIF
        self._pdb_mmcif = {}
        # RMF filename -> (RMF handle, stat category, key map, objects)
        self.dictionary_rmfs = {}
        # stat (v1) filename -> list of objects providing get_output()
        self.dictionary_stats = {}
        # stat2 filename -> (objects, key->index map, summed objects,
        #                    extra labels)
        self.dictionary_stats2 = {}
        self.best_score_list = None
        self.nbestscoring = None
        self.prefixes = []
        self.replica_exchange = False
        self.ascii = ascii
        # extra entries (see set_output_entry) merged into stat output
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # 1-character chain IDs, suitable for PDB output
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
            "abcdefghijklmnopqrstuvwxyz0123456789"
        # Multi-character chain IDs, suitable for mmCIF output
        self.multi_chainids = _ChainIDs()
        # keys are PDB/mmCIF filenames; values are dicts mapping
        # molecule names to chain IDs
        self.dictchain = {}
        self.particle_infos_for_pdb = {}
        self.atomistic = atomistic
        self.use_pmi2 = False

    def get_pdb_names(self):
        """Get a list of all PDB files being output by this instance"""
        return list(self.dictionary_pdbs.keys())

    def get_rmf_names(self):
        """Get a list of all RMF files being output by this instance"""
        return list(self.dictionary_rmfs.keys())

    def get_stat_names(self):
        """Get a list of all registered (v1) stat files"""
        return list(self.dictionary_stats.keys())

    def _init_dictchain(self, name, prot, multichar_chain=False, mmcif=False):
        """Build the molecule-name -> chain ID mapping for output file `name`.

        For canonical PMI hierarchies each molecule keeps its own Chain ID
        (made unique if necessary), keyed by
        IMP.pmi.get_molecule_name_and_copy(). Otherwise each child of the
        hierarchy gets the next ID in sequence, keyed by its name.

        @param multichar_chain For non-PMI hierarchies, use multi-character
               IDs (A..Z, AA, AB, ...) instead of single characters.
        @param mmcif If False, chain IDs longer than one character are
               rejected, since PDB cannot represent them.
        @raise ValueError on a multi-character chain ID when not mmCIF
        """
        self.dictchain[name] = {}
        self.use_pmi2 = False
        seen_chains = set()

        # attempt to find PMI objects.
        if IMP.pmi.get_is_canonical(prot):
            self.use_pmi2 = True
            self.atomistic = True  # detects automatically
            for mol in IMP.atom.get_by_type(prot, IMP.atom.MOLECULE_TYPE):
                chid = IMP.atom.Chain(mol).get_id()
                if not mmcif and len(chid) > 1:
                    raise ValueError(
                        "The system contains at least one chain ID (%s) that "
                        "is more than 1 character long; this cannot be "
                        "represented in PDB. Either write mmCIF files "
                        "instead, or assign 1-character IDs to all chains "
                        "(this can be done with the `chain_ids` argument to "
                        "BuildSystem.add_state())." % chid)
                chid = _disambiguate_chain(chid, seen_chains)
                molname = IMP.pmi.get_molecule_name_and_copy(mol)
                self.dictchain[name][molname] = chid
        else:
            chainids = self.multi_chainids if multichar_chain \
                else self.chainids
            for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
                self.dictchain[name][i.get_name()] = chainids[n]

    def init_pdb(self, name, prot, mmcif=False):
        """Init PDB Writing.
        @param name The PDB filename
        @param prot The hierarchy to write to this pdb file
        @param mmcif If True, write PDBs in mmCIF format
        @note if the PDB name is 'System' then will use Selection
              to get molecules
        """
        # create (or truncate) the file now; models are appended later
        open(name, 'w').close()
        self.dictionary_pdbs[name] = prot
        self._pdb_mmcif[name] = mmcif
        self._init_dictchain(name, prot, mmcif=mmcif)

    def write_psf(self, filename, name):
        """Write a PSF topology file for the particles of PDB output `name`.

        One atom is written per particle returned by
        get_particle_infos_for_pdb_writing(). Bonds are generated per chain
        with IMP.pmi.tools.sublist_iterator (pairs, lmin=lmax=2) over the
        chain's particles sorted by residue index.
        """
        with open(filename, 'w') as flpsf:
            flpsf.write("PSF CMAP CHEQ" + "\n")
            index_residue_pair_list = {}
            (particle_infos_for_pdb, geometric_center) = \
                self.get_particle_infos_for_pdb_writing(name)
            nparticles = len(particle_infos_for_pdb)
            flpsf.write(str(nparticles) + " !NATOM" + "\n")
            for n, p in enumerate(particle_infos_for_pdb):
                atom_index = n + 1
                residue_type = p[2]
                chain = p[3]
                resid = p[4]
                flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}'
                            '{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}'
                            '{14:14.6f}{15:14.6f}'.format(
                                atom_index, " ", chain, " ", str(resid), " ",
                                '"' + residue_type.get_string() + '"', " ",
                                "C", " ", "C", 1.0, 0.0, 0, 0.0, 0.0))
                flpsf.write('\n')
                index_residue_pair_list.setdefault(chain, []).append(
                    (atom_index, resid))

            # now write the connectivity
            indexes_pairs = []
            for chain in sorted(index_residue_pair_list.keys()):
                # sort by residue, then keep only the atom indexes
                ls = sorted(index_residue_pair_list[chain],
                            key=lambda tup: tup[1])
                indexes = [x[0] for x in ls]
                # get the contiguous pairs
                indexes_pairs.extend(IMP.pmi.tools.sublist_iterator(
                    indexes, lmin=2, lmax=2))
            nbonds = len(indexes_pairs)
            flpsf.write(str(nbonds) + " !NBOND: bonds" + "\n")

            # save bonds in fixed column format, four bonds per line
            for i in range(0, len(indexes_pairs), 4):
                for bond in indexes_pairs[i:i + 4]:
                    flpsf.write('{0:8d}{1:8d}'.format(*bond))
                flpsf.write('\n')

    def write_pdb(self, name, appendmode=True,
                  translate_to_geometric_center=False,
                  write_all_residues_per_bead=False):
        """Write the current model to PDB/mmCIF file `name`.

        @param appendmode If True, append to the file; otherwise overwrite
        @param translate_to_geometric_center If True, subtract the mean
               particle coordinate from every position
        @param write_all_residues_per_bead If True, write multi-residue
               particles once per residue they cover
        """
        (particle_infos_for_pdb,
         geometric_center) = self.get_particle_infos_for_pdb_writing(name)

        if not translate_to_geometric_center:
            geometric_center = (0, 0, 0)

        filemode = 'a' if appendmode else 'w'
        with open(name, filemode) as flpdb:
            if self._pdb_mmcif[name]:
                _write_mmcif_internal(flpdb, particle_infos_for_pdb,
                                      geometric_center,
                                      write_all_residues_per_bead,
                                      self.dictchain[name],
                                      self.dictionary_pdbs[name])
            else:
                _write_pdb_internal(flpdb, particle_infos_for_pdb,
                                    geometric_center,
                                    write_all_residues_per_bead)

    def get_prot_name_from_particle(self, name, p):
        """Get the protein name from the particle.
        This is done by traversing the hierarchy.

        @return a (molecule name, is_a_bead) pair; for canonical PMI
                hierarchies is_a_bead is always True
        """
        if self.use_pmi2:
            return IMP.pmi.get_molecule_name_and_copy(p), True
        else:
            return IMP.pmi.tools.get_prot_name_from_particle(
                p, self.dictchain[name])

    def get_particle_infos_for_pdb_writing(self, name):
        """Collect the information needed to write PDB output file `name`.

        Each entry of the returned list is a tuple
        (xyz, atom_type, residue_type, chain_id, residue_index,
        all_residue_indexes, radius). atom_type is None for non-atomic
        particles; all_residue_indexes is None for atoms and residues.
        Residues and fragments whose residue index was already emitted for
        the same molecule are skipped.

        @return (particle_infos, geometric_center): the list is sorted by
                chain ID (shorter IDs first) then residue index, and
                geometric_center is the mean particle coordinate
                ([0, 0, 0] if nothing was collected).
        """
        # residue indexes already added, per molecule, to avoid duplication;
        # highest resolution has highest priority. Sets give O(1) lookups
        # (a list here made the whole loop quadratic).
        resindexes_dict = {}

        # this list will contain the sequence of tuples needed to
        # write the pdb
        particle_infos_for_pdb = []

        if self.use_pmi2:
            # select highest resolution
            sel = IMP.atom.Selection(self.dictionary_pdbs[name], resolution=0)
            ps = sel.get_selected_particles()
        else:
            ps = IMP.atom.get_leaves(self.dictionary_pdbs[name])

        for p in ps:
            protname, is_a_bead = self.get_prot_name_from_particle(name, p)
            seen = resindexes_dict.setdefault(protname, set())

            if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
                atom = IMP.atom.Atom(p)
                residue = IMP.atom.Residue(atom.get_parent())
                rt = residue.get_residue_type()
                resind = residue.get_index()
                atomtype = atom.get_atom_type()
                xyz = list(IMP.core.XYZ(p).get_coordinates())
                radius = IMP.core.XYZR(p).get_radius()
                particle_infos_for_pdb.append(
                    (xyz, atomtype, rt, self.dictchain[name][protname],
                     resind, None, radius))
                seen.add(resind)

            elif IMP.atom.Residue.get_is_setup(p):
                residue = IMP.atom.Residue(p)
                resind = residue.get_index()
                # skip if the residue was already added by atomistic
                # resolution 0
                if resind in seen:
                    continue
                seen.add(resind)
                rt = residue.get_residue_type()
                xyz = IMP.core.XYZ(p).get_coordinates()
                radius = IMP.core.XYZR(p).get_radius()
                particle_infos_for_pdb.append(
                    (xyz, None, rt, self.dictchain[name][protname], resind,
                     None, radius))

            elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
                resindexes = list(IMP.pmi.tools.get_residue_indexes(p))
                # represent the fragment by its middle residue
                resind = resindexes[len(resindexes) // 2]
                if resind in seen:
                    continue
                seen.add(resind)
                rt = IMP.atom.ResidueType('BEA')
                xyz = IMP.core.XYZ(p).get_coordinates()
                radius = IMP.core.XYZR(p).get_radius()
                particle_infos_for_pdb.append(
                    (xyz, None, rt, self.dictchain[name][protname], resind,
                     resindexes, radius))

            else:
                if is_a_bead:
                    rt = IMP.atom.ResidueType('BEA')
                # NOTE(review): when not a bead, `rt` is left over from a
                # previous iteration (or unbound on the first one) -- confirm
                # this branch is only ever reached for beads.
                resindexes = list(IMP.pmi.tools.get_residue_indexes(p))
                if len(resindexes) > 0:
                    resind = resindexes[len(resindexes) // 2]
                    xyz = IMP.core.XYZ(p).get_coordinates()
                    radius = IMP.core.XYZR(p).get_radius()
                    particle_infos_for_pdb.append(
                        (xyz, None, rt, self.dictchain[name][protname],
                         resind, resindexes, radius))

        # every collected particle contributes to the geometric center;
        # sum in collection order, exactly as a running total would
        geometric_center = [0, 0, 0]
        if particle_infos_for_pdb:
            totals = [0, 0, 0]
            for info in particle_infos_for_pdb:
                xyz = info[0]
                totals[0] += xyz[0]
                totals[1] += xyz[1]
                totals[2] += xyz[2]
            count = len(particle_infos_for_pdb)
            geometric_center = (totals[0] / count,
                                totals[1] / count,
                                totals[2] / count)

        # sort by chain ID, then residue index. Longer chain IDs (e.g. AA)
        # should always come after shorter (e.g. Z)
        particle_infos_for_pdb = sorted(particle_infos_for_pdb,
                                        key=lambda x: (len(x[3]), x[3], x[4]))

        return (particle_infos_for_pdb, geometric_center)

    def write_pdbs(self, appendmode=True, mmcif=False):
        """Write the current model to every registered PDB/mmCIF file.

        @note `mmcif` is ignored; each file uses the format chosen when it
              was registered.
        """
        for pdb in self.dictionary_pdbs.keys():
            self.write_pdb(pdb, appendmode)

    def init_pdb_best_scoring(self, prefix, prot, nbestscoring,
                              replica_exchange=False, mmcif=False,
                              best_score_file='best.scores.rex.py'):
        """Prepare for writing best-scoring PDBs (or mmCIFs) for a
           sampling run.

        @param prefix Initial part of each PDB filename (e.g. 'model').
        @param prot The top-level Hierarchy to output.
        @param nbestscoring The number of best-scoring files to output.
        @param replica_exchange Whether to combine best scores from a
               replica exchange run.
        @param mmcif If True, output models in mmCIF format. If False
               (the default) output in legacy PDB format.
        @param best_score_file The filename to use for replica
               exchange scores.
        """
        self._pdb_best_scoring_mmcif = mmcif
        fileext = '.cif' if mmcif else '.pdb'
        self.prefixes.append(prefix)
        self.replica_exchange = replica_exchange
        self.best_score_list = []
        if self.replica_exchange:
            # the replicas must communicate through a common file to know
            # what the best scores are
            self.best_score_file_name = best_score_file
            with open(self.best_score_file_name, "w") as fh:
                fh.write(
                    "self.best_score_list=" + str(self.best_score_list) + "\n")

        self.nbestscoring = nbestscoring
        for i in range(self.nbestscoring):
            name = prefix + "." + str(i) + fileext
            open(name, 'w').close()
            self.dictionary_pdbs[name] = prot
            self._pdb_mmcif[name] = mmcif
            self._init_dictchain(name, prot, mmcif=mmcif)

    def write_pdb_best_scoring(self, score):
        """Write the current model if `score` is among the best seen so far.

        Files are named <prefix>.<rank><ext>, rank 0 being the best (lowest)
        score; worse-ranked files are shifted down by renaming to make room.
        In replica exchange mode the shared score list is read from, and
        written back to, the common score file.

        @param score score of the current model (lower is better)
        """
        if self.nbestscoring is None:
            print("Output.write_pdb_best_scoring: init_pdb_best_scoring "
                  "not run")
            # Nothing has been set up (file extension, prefixes, score
            # list), so continuing would only fail with an AttributeError.
            return

        mmcif = self._pdb_best_scoring_mmcif
        fileext = '.cif' if mmcif else '.pdb'
        # update the score list
        if self.replica_exchange:
            # read the self.best_score_list from the file
            with open(self.best_score_file_name) as fh:
                self.best_score_list = ast.literal_eval(
                    fh.read().split('=')[1])

        if len(self.best_score_list) < self.nbestscoring:
            self.best_score_list.append(score)
            self.best_score_list.sort()
            index = self.best_score_list.index(score)
            for prefix in self.prefixes:
                for i in range(len(self.best_score_list) - 2, index - 1, -1):
                    oldname = prefix + "." + str(i) + fileext
                    newname = prefix + "." + str(i + 1) + fileext
                    # rename on Windows fails if newname already exists
                    if os.path.exists(newname):
                        os.unlink(newname)
                    os.rename(oldname, newname)
                filetoadd = prefix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        elif score < self.best_score_list[-1]:
            self.best_score_list.append(score)
            self.best_score_list.sort()
            self.best_score_list.pop(-1)
            index = self.best_score_list.index(score)
            for prefix in self.prefixes:
                for i in range(len(self.best_score_list) - 1,
                               index - 1, -1):
                    oldname = prefix + "." + str(i) + fileext
                    newname = prefix + "." + str(i + 1) + fileext
                    os.rename(oldname, newname)
                # the previously worst model was shifted out of range
                filenametoremove = prefix + \
                    "." + str(self.nbestscoring) + fileext
                os.remove(filenametoremove)
                filetoadd = prefix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        if self.replica_exchange:
            # write the self.best_score_list to the file
            with open(self.best_score_file_name, "w") as fh:
                fh.write(
                    "self.best_score_list=" + str(self.best_score_list) + '\n')

    def init_rmf(self, name, hierarchies, rs=None, geometries=None,
                 listofobjects=None):
        """
        Initialize an RMF file

        @param name the name of the RMF file
        @param hierarchies the hierarchies to be included (it is a list)
        @param rs optional, the restraint sets (it is a list)
        @param geometries optional, the geometries (it is a list)
        @param listofobjects optional, the list of objects for the stat
               (it is a list)
        @raise ValueError if an object lacks a get_output() method
        """
        rh = RMF.create_rmf_file(name)
        IMP.rmf.add_hierarchies(rh, hierarchies)
        cat = None
        outputkey_rmfkey = None

        if rs is not None:
            IMP.rmf.add_restraints(rh, rs)
        if geometries is not None:
            IMP.rmf.add_geometries(rh, geometries)
        if listofobjects is not None:
            cat = rh.get_category("stat")
            outputkey_rmfkey = {}
            for o in listofobjects:
                if "get_output" not in dir(o):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(o))
                output = o.get_output()
                for outputkey in output:
                    value = output[outputkey]
                    # floats and ints get typed keys; everything else is
                    # stored as a string
                    if isinstance(value, float):
                        rmftag = RMF.float_tag
                    elif isinstance(value, int):
                        rmftag = RMF.int_tag
                    else:
                        rmftag = RMF.string_tag
                    outputkey_rmfkey[outputkey] = \
                        rh.get_key(cat, outputkey, rmftag)
            outputkey_rmfkey["rmf_file"] = \
                rh.get_key(cat, "rmf_file", RMF.string_tag)
            outputkey_rmfkey["rmf_frame_index"] = \
                rh.get_key(cat, "rmf_frame_index", RMF.int_tag)

        self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey, listofobjects)

    def add_restraints_to_rmf(self, name, objectlist):
        """Add restraints to RMF file `name`.

        `objectlist` may be arbitrarily nested lists/tuples. Each object's
        get_restraint_for_rmf() result (a restraint or a list of them) is
        used if that call succeeds; otherwise get_restraint() is used.
        """
        for o in _flatten(objectlist):
            try:
                rs = o.get_restraint_for_rmf()
                if not isinstance(rs, (list, tuple)):
                    rs = [rs]
            except Exception:
                rs = [o.get_restraint()]
            IMP.rmf.add_restraints(
                self.dictionary_rmfs[name][0], rs)

    def add_geometries_to_rmf(self, name, objectlist):
        """Add each object's get_geometries() to RMF file `name`."""
        for o in objectlist:
            geos = o.get_geometries()
            IMP.rmf.add_geometries(self.dictionary_rmfs[name][0], geos)

    def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
        """Add an edge geometry to RMF file `name` for every particle pair
        returned by each object's get_particle_pairs()."""
        for o in objectlist:
            pps = o.get_particle_pairs()
            for pp in pps:
                IMP.rmf.add_geometry(
                    self.dictionary_rmfs[name][0],
                    IMP.core.EdgePairGeometry(pp))

    def write_rmf(self, name):
        """Save the current frame to RMF file `name`.

        If stat objects were given to init_rmf(), their outputs, the RMF
        file name and the current frame index are also stored as values on
        the root node. The file is flushed afterwards.
        """
        rh, cat, outputkey_rmfkey, listofobjects = self.dictionary_rmfs[name]
        IMP.rmf.save_frame(rh)
        if cat is not None:
            for o in listofobjects:
                output = o.get_output()
                for outputkey in output:
                    rmfkey = outputkey_rmfkey[outputkey]
                    try:
                        n = rh.get_root_node()
                        n.set_value(rmfkey, output[outputkey])
                    except NotImplementedError:
                        continue
            rh.get_root_node().set_value(
                outputkey_rmfkey["rmf_file"], name)
            nframes = rh.get_number_of_frames()
            rh.get_root_node().set_value(
                outputkey_rmfkey["rmf_frame_index"], nframes - 1)
        rh.flush()

    def close_rmf(self, name):
        """Stop writing RMF file `name` and drop this object's references
        to its handle (presumably the file is closed once the handle is
        garbage collected)."""
        rh = self.dictionary_rmfs[name][0]
        del self.dictionary_rmfs[name]
        del rh

    def write_rmfs(self):
        """Save the current frame to every registered RMF file."""
        # The dict keys are the RMF file names themselves; taking [0] of a
        # key would pass only its first character.
        for rmfname in list(self.dictionary_rmfs.keys()):
            self.write_rmf(rmfname)

    def init_stat(self, name, listofobjects):
        """Create (truncate) a v1 stat file and register its objects.

        @raise ValueError if an object lacks a get_output() method
        """
        with open(name, 'w' if self.ascii else 'wb'):
            pass

        # check that all objects in listofobjects have a get_output method
        for o in listofobjects:
            if "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() method"
                    % str(o))
        self.dictionary_stats[name] = listofobjects

    def set_output_entry(self, key, value):
        """Set an extra entry that is merged into every stat output."""
        self.initoutput.update({key: value})

    def write_stat(self, name, appendmode=True):
        """Write one record to v1 stat file `name`.

        The record is the extra entries from set_output_entry() updated
        with every registered object's get_output(), minus private keys
        (those starting with '_'). self.initoutput itself is not modified.
        """
        # copy, so object outputs do not leak into self.initoutput
        output = dict(self.initoutput)
        for obj in self.dictionary_stats[name]:
            d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")

        writeflag = 'a' if appendmode else 'w'

        if self.ascii:
            with open(name, writeflag) as flstat:
                flstat.write("%s \n" % output)
        else:
            with open(name, writeflag + 'b') as flstat:
                pickle.dump(output, flstat, 2)

    def write_stats(self):
        """Write one record to every registered v1 stat file."""
        for stat in self.dictionary_stats.keys():
            self.write_stat(stat)

    def get_stat(self, name):
        """Return the merged get_output() of the objects registered for
        stat file `name` (private keys included)."""
        output = {}
        for obj in self.dictionary_stats[name]:
            output.update(obj.get_output())
        return output

    def write_test(self, name, listofobjects):
        """Write a reference file for test() from the objects' outputs.

        get_test_output() is used when it succeeds, else get_output().
        The objects are also registered as stat objects for `name`.

        @raise ValueError if an object has neither output method
        """
        output = dict(self.initoutput)
        for o in listofobjects:
            if "get_test_output" not in dir(o) and "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        self.dictionary_stats[name] = listofobjects

        for obj in self.dictionary_stats[name]:
            try:
                d = obj.get_test_output()
            except Exception:
                d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")
        with open(name, 'w') as flstat:
            flstat.write("%s \n" % output)

    def test(self, name, listofobjects, tolerance=1e-5):
        """Compare the objects' current outputs to a file from write_test().

        Values that parse as floats must agree within `tolerance`; other
        values must compare equal. Failures are reported on stderr.

        @return True if every key present in both matched
        @raise ValueError if an object has neither output method
        """
        output = dict(self.initoutput)
        for o in listofobjects:
            if "get_test_output" not in dir(o) and "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        for obj in listofobjects:
            try:
                output.update(obj.get_test_output())
            except Exception:
                output.update(obj.get_output())

        passed = True
        with open(name, 'r') as flstat:
            for fl in flstat:
                test_dict = ast.literal_eval(fl)
                for k in test_dict:
                    if k not in output:
                        print("%s from old objects (file %s) not in new "
                              "objects" % (str(k), str(name)),
                              file=sys.stderr)
                        continue
                    old_value = str(test_dict[k])
                    new_value = str(output[k])
                    try:
                        float(old_value)
                        is_float = True
                    except ValueError:
                        is_float = False

                    if is_float:
                        fold = float(old_value)
                        fnew = float(new_value)
                        diff = abs(fold - fnew)
                        if diff > tolerance:
                            print("%s: test failed, old value: %s new value "
                                  "%s; diff %f > %f"
                                  % (str(k), str(old_value), str(new_value),
                                     diff, tolerance), file=sys.stderr)
                            passed = False
                    elif test_dict[k] != output[k]:
                        if len(old_value) < 50 and len(new_value) < 50:
                            print("%s: test failed, old value: %s new value %s"
                                  % (str(k), old_value, new_value),
                                  file=sys.stderr)
                        else:
                            print("%s: test failed, omitting results "
                                  "(too long)" % str(k), file=sys.stderr)
                        passed = False
        return passed

    def get_environment_variables(self):
        """Return the process environment as a string."""
        return str(os.environ)

    def get_versions_of_relevant_modules(self):
        """Return a dict of IMP/PMI versions, plus those of the optional
        IMP.isd2 and IMP.isd_emxl modules if they can be imported."""
        versions = {}
        versions["IMP_VERSION"] = IMP.get_module_version()
        versions["PMI_VERSION"] = IMP.pmi.get_module_version()
        try:
            import IMP.isd2
            versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
        except ImportError:
            pass
        try:
            import IMP.isd_emxl
            versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
        except ImportError:
            pass
        return versions

    def init_stat2(self, name, listofobjects, extralabels=None,
                   listofsummedobjects=None):
        """Initialize a v2 stat file, which stores key names only once.

        The header line maps integer indices to key names (plus some
        STAT2HEADER* metadata entries); records written by write_stat2()
        then use the integer indices as keys. The file is only created
        once all objects have been validated.

        @param listofobjects objects with a get_output() method
        @param extralabels extra keys whose values are taken from
               set_output_entry() at write time
        @param listofsummedobjects list of ([obj1, obj2, ...], label)
               pairs; each label is written with the sum of the objects'
               _TotalScore outputs
        @raise ValueError if an object lacks get_output(), or a summed
               object lacks a _TotalScore entry
        """
        if listofsummedobjects is None:
            listofsummedobjects = []
        if extralabels is None:
            extralabels = []
        output = {}
        stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
        stat2_keywords.update(
            {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
        stat2_keywords.update(
            {"STAT2HEADER_IMP_VERSIONS":
             str(self.get_versions_of_relevant_modules())})
        stat2_inverse = {}

        for obj in listofobjects:
            if "get_output" not in dir(obj):
                raise ValueError(
                    "Output: object %s doesn't have get_output() method"
                    % str(obj))
            d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            output.update((k, v) for k, v in d.items() if k[0] != "_")

        # check for customizable entries
        for obj in listofsummedobjects:
            for t in obj[0]:
                if "get_output" not in dir(t):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(t))
                if "_TotalScore" not in t.get_output():
                    raise ValueError(
                        "Output: object %s doesn't have _TotalScore "
                        "entry to be summed" % str(t))
                output.update({obj[1]: 0.0})

        for k in extralabels:
            output.update({k: 0.0})

        for n, k in enumerate(output):
            stat2_keywords.update({n: k})
            stat2_inverse.update({k: n})

        with open(name, 'w') as flstat:
            flstat.write("%s \n" % stat2_keywords)
        self.dictionary_stats2[name] = (
            listofobjects,
            stat2_inverse,
            listofsummedobjects,
            extralabels)

    def write_stat2(self, name, appendmode=True):
        """Write one record to v2 stat file `name`, keyed by the integer
        indices assigned in init_stat2(). Extra labels not set via
        set_output_entry() are written as "None"."""
        output = {}
        (listofobjects, stat2_inverse, listofsummedobjects,
         extralabels) = self.dictionary_stats2[name]

        # writing objects
        for obj in listofobjects:
            od = obj.get_output()
            for k in od:
                if k[0] != "_":
                    output.update({stat2_inverse[k]: od[k]})

        # writing summedobjects
        for so in listofsummedobjects:
            partial_score = 0.0
            for t in so[0]:
                d = t.get_output()
                partial_score += float(d["_TotalScore"])
            output.update({stat2_inverse[so[1]]: str(partial_score)})

        # writing extralabels
        for k in extralabels:
            if k in self.initoutput:
                output.update({stat2_inverse[k]: self.initoutput[k]})
            else:
                output.update({stat2_inverse[k]: "None"})

        writeflag = 'a' if appendmode else 'w'
        with open(name, writeflag) as flstat:
            flstat.write("%s \n" % output)

    def write_stats2(self):
        """Write one record to every registered v2 stat file."""
        for stat in self.dictionary_stats2.keys():
            self.write_stat2(stat)
934 
935 
class OutputStatistics(object):
    """Collect statistics from ProcessOutput.get_fields().

    Holds a counter of the total number of frames read, plus one counter
    per filter used in get_fields() (get_every, filterout, filtertuple)
    for the models that passed it."""
    def __init__(self):
        self.total = 0
        self.passed_get_every = self.passed_filterout = \
            self.passed_filtertuple = 0
946 
class ProcessOutput(object):
    """A class for reading stat files (either rmf or ascii v1 and v2)"""
    def __init__(self, filename):
        """Open a stat file and read its list of keys.

        The file is first tried as an RMF file, with keys read from its
        'stat' category. If RMF raises IOError, it is read as an ascii stat
        file holding one Python dict literal per line. Only the first line
        is parsed here. A first line with a "STAT2HEADER" key marks a stat2
        file; otherwise the file is treated as a (deprecated) stat1 file.

        @param filename path of the stat file
        @raise ValueError if filename is None
        @raise SyntaxError or ValueError (from ast.literal_eval) if the
               first ascii line is not a valid Python literal
        """
        self.filename = filename
        self.isstat1 = False
        self.isstat2 = False
        self.isrmf = False
        # so that get_keys() returns [] (instead of failing) for an empty
        # ascii file
        self.klist = []

        if self.filename is None:
            raise ValueError("No file name provided. Use -h for help")

        try:
            # let's see if that is an rmf file
            rh = RMF.open_rmf_file_read_only(self.filename)
            self.isrmf = True
            cat = rh.get_category('stat')
            rmf_klist = rh.get_keys(cat)
            self.rmf_names_keys = dict([(rh.get_name(k), k)
                                        for k in rmf_klist])
            del rh

        except IOError:
            # Try with an ascii stat file. The keys come from the first
            # line only, so do not read the (possibly huge) rest of it.
            with open(self.filename, "r") as f:
                first_line = f.readline()
            if first_line:
                self._parse_ascii_header(first_line)

    def _parse_ascii_header(self, line):
        """Set the format flags and key list from the first ascii line."""
        d = ast.literal_eval(line)
        self.klist = list(d.keys())
        # check if it is a stat2 file
        if "STAT2HEADER" in self.klist:
            self.isstat2 = True
            # drop the header metadata, keeping the index -> name entries
            for k in self.klist:
                if "STAT2HEADER" in str(k):
                    del d[k]
            # (index, name) pairs sorted by field name
            items = sorted(d.items(), key=operator.itemgetter(1))
            self.klist = [name for _, name in items]
            # field name -> column index used in the data lines
            self.invstat2_dict = dict((name, index) for index, name in items)
        else:
            IMP.handle_use_deprecated(
                "statfile v1 is deprecated. "
                "Please convert to statfile v2.\n")
            self.isstat1 = True
            self.klist.sort()

    def get_keys(self):
        """Return the list of field names available in the file."""
        if self.isrmf:
            return sorted(self.rmf_names_keys.keys())
        else:
            return self.klist

    def show_keys(self, ncolumns=2, truncate=65):
        """Print the available keys in columns."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)

    def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
                   statistics=None):
        '''
        Get the desired field names, and return a dictionary.
        Namely, "fields" are the queried keys in the stat file
        (eg. ["Total_Score",...])
        The returned data structure is a dictionary, where each key is
        a field and the value is the time series (ie, frame ordered series)
        of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,,...],....})

        @param fields (list of strings) queried keys in the stat file
               (eg. "Total_Score"....)
        @param filterout specify if you want to "grep" out something from
               the file, so that it is faster
        @param filtertuple a tuple that contains
               ("TheKeyToBeFiltered",relationship,value)
               where relationship = "<", "==", or ">"
        @param get_every only read every Nth line from the file
        @param statistics if provided, accumulate statistics in an
               OutputStatistics object
        '''
        if statistics is None:
            statistics = OutputStatistics()
        outdict = dict((field, []) for field in fields)

        if self.isrmf:
            rh = RMF.open_rmf_file_read_only(self.filename)
            nframes = rh.get_number_of_frames()
            for i in range(nframes):
                statistics.total += 1
                # "get_every" and "filterout" not enforced for RMF
                statistics.passed_get_every += 1
                statistics.passed_filterout += 1
                IMP.rmf.load_frame(rh, RMF.FrameID(i))
                root = rh.get_root_node()
                if filtertuple is not None:
                    datavalue = root.get_value(
                        self.rmf_names_keys[filtertuple[0]])
                    if self.isfiltered(datavalue, filtertuple[1],
                                       filtertuple[2]):
                        continue

                statistics.passed_filtertuple += 1
                for field in fields:
                    outdict[field].append(
                        root.get_value(self.rmf_names_keys[field]))
            return outdict

        with open(self.filename, "r") as f:
            line_number = 0
            for line in f:
                statistics.total += 1
                if filterout is not None and filterout in line:
                    continue
                statistics.passed_filterout += 1
                line_number += 1

                if line_number % get_every != 0:
                    # the stat2 header is not a model: do not count it
                    if line_number == 1 and self.isstat2:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                    continue
                statistics.passed_get_every += 1
                try:
                    d = ast.literal_eval(line)
                except Exception:
                    print("# Warning: skipped line number " + str(line_number)
                          + " not a valid line")
                    continue

                if self.isstat1:
                    if filtertuple is not None:
                        datavalue = d[filtertuple[0]]
                        if self.isfiltered(datavalue, filtertuple[1],
                                           filtertuple[2]):
                            continue

                    statistics.passed_filtertuple += 1
                    for field in fields:
                        outdict[field].append(d[field])

                elif self.isstat2:
                    if line_number == 1:
                        # header line: undo the counts made above
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                        statistics.passed_get_every -= 1
                        continue

                    if filtertuple is not None:
                        datavalue = d[self.invstat2_dict[filtertuple[0]]]
                        if self.isfiltered(datavalue, filtertuple[1],
                                           filtertuple[2]):
                            continue

                    statistics.passed_filtertuple += 1
                    for field in fields:
                        outdict[field].append(d[self.invstat2_dict[field]])

        return outdict

    def isfiltered(self, datavalue, relationship, refvalue):
        """Return True if a frame should be discarded by a filtertuple.

        @param datavalue the frame's value (converted with float())
        @param relationship "<", ">" or "=="; the frame is kept only if
               ``datavalue <relationship> refvalue`` holds. Any other
               relationship never filters.
        @param refvalue the reference value
        @raise ValueError if datavalue cannot be converted to a float
        """
        try:
            value = float(datavalue)
        except ValueError:
            raise ValueError("ProcessOutput.filter: datavalue cannot be "
                             "converted into a float")

        if relationship == "<":
            return bool(value >= refvalue)
        if relationship == ">":
            return bool(value <= refvalue)
        if relationship == "==":
            return bool(value != refvalue)
        return False
1143 
1144 
    """Class to allow more advanced handling of RMF files.
    It is both a container and an IMP.atom.Hierarchy.
     - it is iterable (loading each frame in turn)
     - item brackets [] load the corresponding frame
     - a slice creates an iterator
     - it can relink to another RMF file
    """
1153  def __init__(self, model, rmf_file_name):
1154  """
1155  @param model: the IMP.Model()
1156  @param rmf_file_name: str, path of the rmf file
1157  """
1158  self.model = model
1159  try:
1160  self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1161  except TypeError:
1162  raise TypeError("Wrong rmf file name or type: %s"
1163  % str(rmf_file_name))
1164  hs = IMP.rmf.create_hierarchies(self.rh_ref, self.model)
1165  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(0))
1166  self.root_hier_ref = hs[0]
1167  IMP.atom.Hierarchy.__init__(self, self.root_hier_ref)
1168  self.model.update()
1169  self.ColorHierarchy = None
1170 
1171  def link_to_rmf(self, rmf_file_name):
1172  """
1173  Link to another RMF file
1174  """
1175  self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1176  IMP.rmf.link_hierarchies(self.rh_ref, [self])
1177  if self.ColorHierarchy:
1178  self.ColorHierarchy.method()
1179  RMFHierarchyHandler.set_frame(self, 0)
1180 
1181  def set_frame(self, index):
1182  try:
1183  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(index))
1184  except: # noqa: E722
1185  print("skipping frame %s:%d\n" % (self.current_rmf, index))
1186  self.model.update()
1187 
1188  def get_number_of_frames(self):
1189  return self.rh_ref.get_number_of_frames()
1190 
1191  def __getitem__(self, int_slice_adaptor):
1192  if isinstance(int_slice_adaptor, int):
1193  self.set_frame(int_slice_adaptor)
1194  return int_slice_adaptor
1195  elif isinstance(int_slice_adaptor, slice):
1196  return self.__iter__(int_slice_adaptor)
1197  else:
1198  raise TypeError("Unknown Type")
1199 
1200  def __len__(self):
1201  return self.get_number_of_frames()
1202 
1203  def __iter__(self, slice_key=None):
1204  if slice_key is None:
1205  for nframe in range(len(self)):
1206  yield self[nframe]
1207  else:
1208  for nframe in list(range(len(self)))[slice_key]:
1209  yield self[nframe]
1210 
1211 
class CacheHierarchyCoordinates(object):
    """Per-model cache of coordinates for a hierarchy handler.

    For every stored index it keeps the rigid-body reference frames, the
    internal coordinates of non-rigid members and the coordinates of the
    other beads, so a model can be restored without reading the RMF file.
    """
    def __init__(self, StatHierarchyHandler):
        """
        @param StatHierarchyHandler the hierarchy handler whose rigid bodies
               and beads (from IMP.pmi.tools.get_rbs_and_beads) are cached
        """
        self.xyzs = []
        self.nrms = []
        self.nrm_coors = {}
        self.xyz_coors = {}
        self.rb_trans = {}
        self.current_index = None
        self.rmfh = StatHierarchyHandler
        rbs, xyzs = IMP.pmi.tools.get_rbs_and_beads([self.rmfh])
        self.model = self.rmfh.get_model()
        self.rbs = rbs
        for xyz in xyzs:
            if IMP.core.NonRigidMember.get_is_setup(xyz):
                # stored through the rigid body's internal coordinates
                self.nrms.append(IMP.core.NonRigidMember(xyz))
            else:
                self.xyzs.append(IMP.core.XYZ(xyz))

    def do_store(self, index):
        """Save the current coordinates under `index`."""
        self.rb_trans[index] = {}
        self.nrm_coors[index] = {}
        self.xyz_coors[index] = {}
        for rb in self.rbs:
            self.rb_trans[index][rb] = rb.get_reference_frame()
        for nrm in self.nrms:
            self.nrm_coors[index][nrm] = nrm.get_internal_coordinates()
        for xyz in self.xyzs:
            self.xyz_coors[index][xyz] = xyz.get_coordinates()
        self.current_index = index

    def do_update(self, index):
        """Restore the coordinates stored under `index` and update the
        model. Raises KeyError if `index` was never stored."""
        if self.current_index != index:
            for rb in self.rbs:
                rb.set_reference_frame(self.rb_trans[index][rb])
            for nrm in self.nrms:
                nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
            for xyz in self.xyzs:
                xyz.set_coordinates(self.xyz_coors[index][xyz])
            self.current_index = index
        self.model.update()

    def get_number_of_frames(self):
        """Return the number of stored models."""
        return len(self.rb_trans)

    def __getitem__(self, index):
        """Return True if coordinates for `index` are cached.

        @raise TypeError if index is not an int
        """
        if isinstance(index, int):
            return index in self.rb_trans
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return self.get_number_of_frames()
1267 
1268 
    """Class to link stat files to one or more RMF files."""
1271  def __init__(self, model=None, stat_file=None,
1272  number_best_scoring_models=None, score_key=None,
1273  StatHierarchyHandler=None, cache=None):
1274  """
1275 
1276  @param model: IMP.Model()
1277  @param stat_file: either 1) a list or 2) a single stat file names
1278  (either rmfs or ascii, or pickled data or pickled cluster),
1279  3) a dictionary containing an rmf/ascii
1280  stat file name as key and a list of frames as values
1281  @param number_best_scoring_models:
1282  @param StatHierarchyHandler: copy constructor input object
1283  @param cache: cache coordinates and rigid body transformations.
1284  """
1285 
1286  if StatHierarchyHandler is not None:
1287  # overrides all other arguments
1288  # copy constructor: create a copy with
1289  # different RMFHierarchyHandler
1290  self.model = StatHierarchyHandler.model
1291  self.data = StatHierarchyHandler.data
1292  self.number_best_scoring_models = \
1293  StatHierarchyHandler.number_best_scoring_models
1294  self.is_setup = True
1295  self.current_rmf = StatHierarchyHandler.current_rmf
1296  self.current_frame = None
1297  self.current_index = None
1298  self.score_threshold = StatHierarchyHandler.score_threshold
1299  self.score_key = StatHierarchyHandler.score_key
1300  self.cache = StatHierarchyHandler.cache
1301  RMFHierarchyHandler.__init__(self, self.model,
1302  self.current_rmf)
1303  if self.cache:
1304  self.cache = CacheHierarchyCoordinates(self)
1305  else:
1306  self.cache = None
1307  self.set_frame(0)
1308 
1309  else:
1310  # standard constructor
1311  self.model = model
1312  self.data = []
1313  self.number_best_scoring_models = number_best_scoring_models
1314  self.cache = cache
1315 
1316  if score_key is None:
1317  self.score_key = "Total_Score"
1318  else:
1319  self.score_key = score_key
1320  self.is_setup = None
1321  self.current_rmf = None
1322  self.current_frame = None
1323  self.current_index = None
1324  self.score_threshold = None
1325 
1326  if isinstance(stat_file, str):
1327  self.add_stat_file(stat_file)
1328  elif isinstance(stat_file, list):
1329  for f in stat_file:
1330  self.add_stat_file(f)
1331 
1332  def add_stat_file(self, stat_file):
1333  try:
1334  '''check that it is not a pickle file with saved data
1335  from a previous calculation'''
1336  self.load_data(stat_file)
1337 
1338  if self.number_best_scoring_models:
1339  scores = self.get_scores()
1340  max_score = sorted(scores)[
1341  0:min(len(self), self.number_best_scoring_models)][-1]
1342  self.do_filter_by_score(max_score)
1343 
1344  except pickle.UnpicklingError:
1345  '''alternatively read the ascii stat files'''
1346  try:
1347  scores, rmf_files, rmf_frame_indexes, features = \
1348  self.get_info_from_stat_file(stat_file,
1349  self.score_threshold)
1350  except (KeyError, SyntaxError):
1351  # in this case check that is it an rmf file, probably
1352  # without stat stored in
1353  try:
1354  # let's see if that is an rmf file
1355  rh = RMF.open_rmf_file_read_only(stat_file)
1356  nframes = rh.get_number_of_frames()
1357  scores = [0.0]*nframes
1358  rmf_files = [stat_file]*nframes
1359  rmf_frame_indexes = range(nframes)
1360  features = {}
1361  except: # noqa: E722
1362  return
1363 
1364  if len(set(rmf_files)) > 1:
1365  raise ("Multiple RMF files found")
1366 
1367  if not rmf_files:
1368  print("StatHierarchyHandler: Error: Trying to set none as "
1369  "rmf_file (probably empty stat file), aborting")
1370  return
1371 
1372  for n, index in enumerate(rmf_frame_indexes):
1373  featn_dict = dict([(k, features[k][n]) for k in features])
1374  self.data.append(IMP.pmi.output.DataEntry(
1375  stat_file, rmf_files[n], index, scores[n], featn_dict))
1376 
1377  if self.number_best_scoring_models:
1378  scores = self.get_scores()
1379  max_score = sorted(scores)[
1380  0:min(len(self), self.number_best_scoring_models)][-1]
1381  self.do_filter_by_score(max_score)
1382 
1383  if not self.is_setup:
1384  RMFHierarchyHandler.__init__(
1385  self, self.model, self.get_rmf_names()[0])
1386  if self.cache:
1387  self.cache = CacheHierarchyCoordinates(self)
1388  else:
1389  self.cache = None
1390  self.is_setup = True
1391  self.current_rmf = self.get_rmf_names()[0]
1392 
1393  self.set_frame(0)
1394 
1395  def save_data(self, filename='data.pkl'):
1396  with open(filename, 'wb') as fl:
1397  pickle.dump(self.data, fl)
1398 
1399  def load_data(self, filename='data.pkl'):
1400  with open(filename, 'rb') as fl:
1401  data_structure = pickle.load(fl)
1402  # first check that it is a list
1403  if not isinstance(data_structure, list):
1404  raise TypeError(
1405  "%filename should contain a list of IMP.pmi.output.DataEntry "
1406  "or IMP.pmi.output.Cluster" % filename)
1407  # second check the types
1408  if all(isinstance(item, IMP.pmi.output.DataEntry)
1409  for item in data_structure):
1410  self.data = data_structure
1411  elif all(isinstance(item, IMP.pmi.output.Cluster)
1412  for item in data_structure):
1413  nmodels = 0
1414  for cluster in data_structure:
1415  nmodels += len(cluster)
1416  self.data = [None]*nmodels
1417  for cluster in data_structure:
1418  for n, data in enumerate(cluster):
1419  index = cluster.members[n]
1420  self.data[index] = data
1421  else:
1422  raise TypeError(
1423  "%filename should contain a list of IMP.pmi.output.DataEntry "
1424  "or IMP.pmi.output.Cluster" % filename)
1425 
1426  def set_frame(self, index):
1427  if self.cache is not None and self.cache[index]:
1428  self.cache.do_update(index)
1429  else:
1430  nm = self.data[index].rmf_name
1431  fidx = self.data[index].rmf_index
1432  if nm != self.current_rmf:
1433  self.link_to_rmf(nm)
1434  self.current_rmf = nm
1435  self.current_frame = -1
1436  if fidx != self.current_frame:
1437  RMFHierarchyHandler.set_frame(self, fidx)
1438  self.current_frame = fidx
1439  if self.cache is not None:
1440  self.cache.do_store(index)
1441 
1442  self.current_index = index
1443 
1444  def __getitem__(self, int_slice_adaptor):
1445  if isinstance(int_slice_adaptor, int):
1446  self.set_frame(int_slice_adaptor)
1447  return self.data[int_slice_adaptor]
1448  elif isinstance(int_slice_adaptor, slice):
1449  return self.__iter__(int_slice_adaptor)
1450  else:
1451  raise TypeError("Unknown Type")
1452 
1453  def __len__(self):
1454  return len(self.data)
1455 
1456  def __iter__(self, slice_key=None):
1457  if slice_key is None:
1458  for i in range(len(self)):
1459  yield self[i]
1460  else:
1461  for i in range(len(self))[slice_key]:
1462  yield self[i]
1463 
1464  def do_filter_by_score(self, maximum_score):
1465  self.data = [d for d in self.data if d.score <= maximum_score]
1466 
1467  def get_scores(self):
1468  return [d.score for d in self.data]
1469 
1470  def get_feature_series(self, feature_name):
1471  return [d.features[feature_name] for d in self.data]
1472 
1473  def get_feature_names(self):
1474  return self.data[0].features.keys()
1475 
1476  def get_rmf_names(self):
1477  return [d.rmf_name for d in self.data]
1478 
1479  def get_stat_files_names(self):
1480  return [d.stat_file for d in self.data]
1481 
1482  def get_rmf_indexes(self):
1483  return [d.rmf_index for d in self.data]
1484 
1485  def get_info_from_stat_file(self, stat_file, score_threshold=None):
1486  po = ProcessOutput(stat_file)
1487  fs = po.get_keys()
1488  models = IMP.pmi.io.get_best_models(
1489  [stat_file], score_key=self.score_key, feature_keys=fs,
1490  rmf_file_key="rmf_file", rmf_file_frame_key="rmf_frame_index",
1491  prefiltervalue=score_threshold, get_every=1)
1492 
1493  scores = [float(y) for y in models[2]]
1494  rmf_files = models[0]
1495  rmf_frame_indexes = models[1]
1496  features = models[3]
1497  return scores, rmf_files, rmf_frame_indexes, features
1498 
1499 
class DataEntry(object):
    '''
    A class to store data associated to a model
    '''
    def __init__(self, stat_file=None, rmf_name=None, rmf_index=None,
                 score=None, features=None):
        """
        @param stat_file name of the stat file the model was read from
        @param rmf_name name of the RMF file containing the model
        @param rmf_index frame index of the model in the RMF file
        @param score the model's score
        @param features dict of additional per-model values (may be None)
        """
        self.rmf_name = rmf_name
        self.rmf_index = rmf_index
        self.score = score
        self.features = features
        self.stat_file = stat_file

    def __repr__(self):
        # features defaults to None; report that as zero features instead
        # of failing inside repr()
        nfeatures = 0 if self.features is None else len(self.features)
        s = "IMP.pmi.output.DataEntry\n"
        s += "---- stat file %s \n" % (self.stat_file)
        s += "---- rmf file %s \n" % (self.rmf_name)
        s += "---- rmf index %s \n" % (str(self.rmf_index))
        s += "---- score %s \n" % (str(self.score))
        s += "---- number of features %s \n" % (str(nfeatures))
        return s
1519  return s
1520 
1521 
class Cluster(object):
    '''
    A container for models organized into clusters
    '''
    def __init__(self, cid=None):
        """@param cid the cluster identifier"""
        self.cluster_id = cid
        self.members = []
        self.precision = None
        self.center_index = None
        self.members_data = {}
        # defined up front so __repr__ works before any member is added
        self.average_score = None

    def add_member(self, index, data=None):
        """Add model `index` (with optional data, e.g. a DataEntry) and
        recompute the average score."""
        self.members.append(index)
        self.members_data[index] = data
        self.average_score = self.compute_score()

    def compute_score(self):
        """Return the mean score of the members.

        Returns None if any member has no usable score (no data, a data
        object without a numeric .score) or the cluster is empty.
        """
        try:
            score = sum([d.score for d in self])/len(self)
        except (AttributeError, TypeError, ZeroDivisionError):
            score = None
        return score

    def __repr__(self):
        s = "IMP.pmi.output.Cluster\n"
        s += "---- cluster_id %s \n" % str(self.cluster_id)
        s += "---- precision %s \n" % str(self.precision)
        s += "---- average score %s \n" % str(self.average_score)
        s += "---- number of members %s \n" % str(len(self.members))
        s += "---- center index %s \n" % str(self.center_index)
        return s

    def __getitem__(self, int_slice_adaptor):
        """Return the data of the n-th member (int), or an iterator over
        the members selected by a slice."""
        if isinstance(int_slice_adaptor, int):
            index = self.members[int_slice_adaptor]
            return self.members_data[index]
        elif isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.members)

    def __iter__(self, slice_key=None):
        if slice_key is None:
            for i in range(len(self)):
                yield self[i]
        else:
            for i in range(len(self))[slice_key]:
                yield self[i]

    def __add__(self, other):
        """Merge `other` into this cluster IN PLACE and return self.

        Precision and center index are reset because they no longer apply
        to the merged cluster.
        """
        self.members += other.members
        self.members_data.update(other.members_data)
        self.average_score = self.compute_score()
        self.precision = None
        self.center_index = None
        return self
1581 
1582 
def plot_clusters_populations(clusters):
    """Show a bar chart of the number of models in each cluster."""
    # one pass, so `clusters` may also be a one-shot iterator
    pairs = [(cluster.cluster_id, len(cluster)) for cluster in clusters]
    indexes = [cid for cid, _ in pairs]
    populations = [size for _, size in pairs]

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, populations, 0.5, color='r')
    ax.set_ylabel('Population')
    ax.set_xlabel('Cluster index')
    plt.show()
1596 
1597 
def plot_clusters_precisions(clusters):
    """Print and bar-plot the precision of each cluster.

    Clusters with no precision (None) are plotted as 0.0.
    """
    indexes = []
    precisions = []
    for cluster in clusters:
        cid = cluster.cluster_id
        prec = cluster.precision
        indexes.append(cid)
        print(cid, prec)
        precisions.append(0.0 if prec is None else prec)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, precisions, 0.5, color='r')
    ax.set_ylabel('Precision [A]')
    ax.set_xlabel('Cluster index')
    plt.show()
1616 
1617 
def plot_clusters_scores(clusters):
    """Save box plots of the member scores of each cluster to scores.pdf.

    @param clusters iterable of Cluster objects whose members have a
           .score attribute
    """
    indexes = []
    values = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        values.append([data.score for data in cluster])

    # plot_fields_box_plots() appends ".pdf" itself; passing "scores.pdf"
    # would produce "scores.pdf.pdf"
    plot_fields_box_plots("scores", values, indexes, frequencies=None,
                          valuename="Scores", positionname="Cluster index",
                          xlabels=None, scale_plot_length=1.0)
1630 
1631 
class CrossLinkIdentifierDatabase(object):
    """Store per-cross-link attributes, indexed by a cross-link key.

    self.clidb maps each key to a dict of field name -> value. The typed
    setters convert the value (str, int or float) before storing it. The
    getters return the stored value and raise KeyError if the key or the
    field is missing. The whole table can be pickled to and loaded from
    a file.
    """
    def __init__(self):
        self.clidb = {}

    def check_key(self, key):
        """Make sure an (initially empty) record exists for key."""
        self.clidb.setdefault(key, {})

    def _store(self, key, field, value, convert):
        """Create the record for key if needed, then set field to
        convert(value)."""
        self.check_key(key)
        self.clidb[key][field] = convert(value)

    def _lookup(self, key, field):
        """Return the stored value of field for key."""
        return self.clidb[key][field]

    def set_unique_id(self, key, value):
        self._store(key, "XLUniqueID", value, str)

    def set_protein1(self, key, value):
        self._store(key, "Protein1", value, str)

    def set_protein2(self, key, value):
        self._store(key, "Protein2", value, str)

    def set_residue1(self, key, value):
        self._store(key, "Residue1", value, int)

    def set_residue2(self, key, value):
        self._store(key, "Residue2", value, int)

    def set_idscore(self, key, value):
        self._store(key, "IDScore", value, float)

    def set_state(self, key, value):
        self._store(key, "State", value, int)

    def set_sigma1(self, key, value):
        self._store(key, "Sigma1", value, str)

    def set_sigma2(self, key, value):
        self._store(key, "Sigma2", value, str)

    def set_psi(self, key, value):
        self._store(key, "Psi", value, str)

    def get_unique_id(self, key):
        return self._lookup(key, "XLUniqueID")

    def get_protein1(self, key):
        return self._lookup(key, "Protein1")

    def get_protein2(self, key):
        return self._lookup(key, "Protein2")

    def get_residue1(self, key):
        return self._lookup(key, "Residue1")

    def get_residue2(self, key):
        return self._lookup(key, "Residue2")

    def get_idscore(self, key):
        return self._lookup(key, "IDScore")

    def get_state(self, key):
        return self._lookup(key, "State")

    def get_sigma1(self, key):
        return self._lookup(key, "Sigma1")

    def get_sigma2(self, key):
        return self._lookup(key, "Sigma2")

    def get_psi(self, key):
        return self._lookup(key, "Psi")

    def set_float_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, float)

    def set_int_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, int)

    def set_string_feature(self, key, value, feature_name):
        self._store(key, feature_name, value, str)

    def get_feature(self, key, feature_name):
        return self._lookup(key, feature_name)

    def write(self, filename):
        """Pickle the whole table to filename."""
        with open(filename, 'wb') as out:
            pickle.dump(self.clidb, out)

    def load(self, filename):
        """Replace the table with the one pickled in filename.

        NOTE: unpickling can execute arbitrary code; only load trusted
        files.
        """
        with open(filename, 'rb') as src:
            self.clidb = pickle.load(src)
1732 
1733 
def plot_fields(fields, output, framemin=None, framemax=None):
    """Plot the given fields and save a figure as `output`.
    The fields generally are extracted from a stat file
    using ProcessOutput.get_fields().

    @param fields dict mapping a field name to its list of values; one
           subplot is drawn per field
    @param output file name passed to matplotlib's savefig
    @param framemin first frame to plot (default 0)
    @param framemax end (exclusive) of the frame range to plot; by default
           each field is plotted up to its own length
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    # squeeze=False always gives a 2D array of axes, even for one field
    fig, axs = plt.subplots(nrows=len(fields), squeeze=False)
    fig.set_size_inches(10.5, 5.5 * len(fields))

    for ax, key in zip(axs[:, 0], fields):
        # Resolve the frame range per field. The arguments must not be
        # overwritten, or fields of different lengths would all reuse the
        # first field's range.
        start = 0 if framemin is None else framemin
        stop = len(fields[key]) if framemax is None else framemax
        y = [float(v) for v in fields[key][start:stop]]
        # x follows the values actually sliced, so x and y lengths match
        x = list(range(start, start + len(y)))
        ax.plot(x, y)
        ax.set_title(key, size="xx-large")
        ax.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(output)
    plt.close(fig)
1768 
1769 
def plot_field_histogram(name, values_lists, valuename=None, bins=40,
                         colors=None, format="png", reference_xline=None,
                         yplotrange=None, xplotrange=None, normalized=True,
                         leg_names=None):
    '''Plot a list of histograms from a value list.
    The figure is saved to name + "." + format.
    @param name the name of the plot
    @param values_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename the x-label (defaults to name)
    @param bins the number of bins
    @param colors If None, will use rainbow. Else will use specific list
    @param format output format
    @param reference_xline plot a reference line parallel to the y-axis
    @param yplotrange the (min, max) range for the y-axis
    @param xplotrange the (min, max) range for the x-axis
    @param normalized whether the histogram is normalized or not
    @param leg_names names for the legend
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        if leg_names is not None:
            label = leg_names[nv]
        else:
            label = str(nv)
        try:
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                density=normalized, histtype='step', lw=4, label=label)
        except AttributeError:
            # older matplotlib versions only know the 'normed' keyword
            plt.hist(
                [float(y) for y in values], bins=bins, color=col,
                normed=normalized, histtype='step', lw=4, label=label)

    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        # previously plt.ylim() was called without arguments, so the
        # requested y range was silently ignored
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
    plt.close()
1833 
1834 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None",
                          xlabels=None, scale_plot_length=1.0):
    '''
    Plot time series as boxplots and save them to name + ".pdf".
    @param values list of series, one box per series
    @param positions the x positions of the boxes
    @param frequencies if not None, the raw values of each series are also
           drawn as points (its contents are not used)
    @param valuename the y-label; positionname the x-label
    @param xlabels optional tick labels for the x-axis
    @param scale_plot_length figure width per box, in inches
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    width = float(len(positions)) * scale_plot_length
    fig = plt.figure(figsize=(width, 5.0))
    fig.canvas.manager.set_window_title(name)
    ax1 = fig.add_subplot(111)
    plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)

    box_plot = plt.boxplot(values, notch=0, sym='', vert=1,
                           whis=1.5, positions=positions)
    plt.setp(box_plot['boxes'], color='black', lw=1.5)
    plt.setp(box_plot['whiskers'], color='black', ls=":", lw=1.5)

    if frequencies is not None:
        for n, series in enumerate(values):
            ax1.plot([positions[n]] * len(series), series, 'gx',
                     alpha=0.7, markersize=7)

    if xlabels is not None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name + ".pdf", dpi=150)
    plt.show()
1876 
1877 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Draw y against x as a red line.

    @param title optional window and axes title
    @param out_fn if given, the figure is saved to out_fn + ".pdf"
    @param display if True, call plt.show() before closing the figure
    @param set_plot_yaxis_range optional (ymin, ymax) for the y-axis
    @param xlabel, ylabel optional axis labels
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.manager.set_window_title(title)

    ax.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        xlo, xhi, _, _ = plt.axis()
        plt.axis((xlo, xhi,
                  set_plot_yaxis_range[0], set_plot_yaxis_range[1]))

    for setter, text in ((ax.set_title, title),
                         (ax.set_xlabel, xlabel),
                         (ax.set_ylabel, ylabel)):
        if text is not None:
            setter(text)

    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
1907 
1908 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot y against x as small black points.

    @param labelx, labely axis labels
    @param xmin, xmax x-axis limits (applied only if both are given)
    @param ymin, ymax y-axis limits (applied only if both are given)
    @param savefile if True, save the figure to filename at 300 dpi
    @param alpha point transparency
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, axs0 = plt.subplots(1)

    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    # Pass only one color: the former call also gave c="w", and matplotlib
    # rejects 'color' together with its alias 'c' in the same call.
    axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1, alpha=alpha)

    axs0.legend(
        loc=0,
        frameon=False,
        scatterpoints=1,
        numpoints=1,
        columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    if (ymin is not None) and (ymax is not None):
        axs0.set_ylim(ymin, ymax)
    if (xmin is not None) and (xmax is not None):
        axs0.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
1948 
1949 
def get_graph_from_hierarchy(hier):
    """Draw the parent/child graph of an IMP hierarchy.

    Only nodes at depth < 3 (the root is depth 1) get a text label.
    Deeper nodes are drawn with an empty label.
    """
    graph, _, depth_dict = recursive_graph(hier, [], 0, {})

    node_labels_dict = dict(
        (name, name if depth < 3 else "")
        for name, depth in depth_dict.items())
    draw_graph(graph, labels_dict=node_labels_dict)
1965 
1966 
def recursive_graph(hier, graph, depth, depth_dict):
    """Recursively collect parent->child edges below `hier`.

    Node names are "<hierarchy name>|#<particle index>". Each visited node
    is recorded in depth_dict with its depth (the input depth + 1). A node
    with exactly one child (or with no children list) is not expanded.

    @return (graph, depth, depth_dict), with depth equal to its input value
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # test for None before calling len(), which would fail on None
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)

    for c in children:
        (graph, depth, depth_dict) = recursive_graph(
            c, graph, depth, depth_dict)
        nameh = IMP.atom.Hierarchy(c).get_name()
        index = str(c.get_particle().get_index())
        namec = nameh + "|#" + index
        graph.append((name1, namec))

    depth = depth - 1
    return (graph, depth, depth_dict)
1991 
1992 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw an undirected graph with networkx/matplotlib and export it as GML.

    @param graph list of (node1, node2) edges
    @param labels_dict optional dict node -> label text, passed to
           networkx.draw_networkx_labels (None labels nodes by their names)
    @param graph_layout 'spring', 'spectral' or 'random'; any other value
           selects the shell layout
    @param node_size one size used for every node, a list with one size per
           node, or a dict node -> value v, drawn with size sqrt(v)/pi*10
    @param node_color None (all nodes black) or a dict node -> hex color
           string without the leading '#'. The string is prefixed with '#'
           for the GML output and converted with
           IMP.pmi.tools.ColorChange().rgb for drawing.
    @param node_alpha opacity of the drawn nodes
    @param node_text_size font size of the node labels
    @param fixed passed to networkx.spring_layout (spring layout only)
    @param pos passed to networkx.spring_layout (spring layout only)
    @param edge_color color of the drawn edges
    @param edge_alpha opacity of the drawn edges
    @param edge_thickness one width for every edge, or a list of weights,
           one per edge in `graph`
    @param edge_text_pos unused
    @param validation_edges optional list of (node1, node2) edges to check.
           In the GML output, edges already in the graph are colored green;
           missing ones are added to the graph and colored red.
    @param text_font font family of the node labels
    @param out_filename if given, the figure is saved to this file

    The graph, with its GML 'graphics' attributes, is always written to
    'out.gml' in the current directory.
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges
    if isinstance(edge_thickness, list):
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    nodes = list(G.nodes())

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        # One hex entry per node; a bare "000000" string would be indexed
        # character by character below.
        node_color_hex = ["000000"] * len(nodes)
    else:
        cc = IMP.pmi.tools.ColorChange()
        node_color_rgb = []
        node_color_hex = []
        for node in nodes:
            cctuple = cc.rgb(node_color[node])
            node_color_rgb.append((cctuple[0] / 255,
                                   cctuple[1] / 255,
                                   cctuple[2] / 255))
            node_color_hex.append(node_color[node])

    # get node sizes if dictionary
    if isinstance(node_size, dict):
        node_size = [sqrt(node_size[node]) / pi * 10.0 for node in nodes]
    # Per-node sizes for the GML output; a scalar applies to every node.
    if isinstance(node_size, (list, tuple)):
        gml_sizes = list(node_size)
    else:
        gml_sizes = [node_size] * len(nodes)

    # GML attributes. Keyword arguments are used because the positional
    # order of name/values differs between networkx 1.x and 2.x.
    nx.set_node_attributes(
        G, name="graphics",
        values={node: {'type': 'ellipse', 'w': size, 'h': size,
                       'fill': '#' + color, 'label': node}
                for node, color, size in zip(nodes, node_color_hex,
                                             gml_sizes)})
    nx.set_node_attributes(
        G, name="LabelGraphics",
        values={node: {'type': 'text', 'text': node, 'color': '#000000',
                       'visible': 'true'}
                for node in nodes})
    nx.set_edge_attributes(
        G, name="graphics",
        values={edge: {'width': 1, 'fill': '#000000'}
                for edge in G.edges()})

    # NOTE(review): nodes introduced by validation edges get no GML node
    # graphics. When node_color/node_size are dicts, the per-node lists
    # computed above do not cover those nodes either.
    for ve in validation_edges or []:
        print(ve)
        key = (ve[0], ve[1])
        if G.has_edge(*key):
            # G is undirected, so this matches either orientation.
            print("found forward")
            fill = '#00FF00'
        else:
            G.add_edge(*key)
            print("not found")
            fill = '#FF0000'
        nx.set_edge_attributes(G, name="graphics",
                               values={key: {'width': 1, 'fill': fill}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0/8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
    nx.write_gml(G, 'out.gml')
    plt.show()
2104 
2105 
def draw_table():
    """Render an example d3js table in a Jupyter notebook.

    This is only a demonstration (it takes no input). It displays a fixed
    table of 12 rows (months) by 8 columns (products) using the
    third-party ipyD3 package and IPython.display.
    """

    # still an example!

    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  desc='An example table created with d3js with '
                       'data generated with Python.')
    data = [[1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
             4441.0, 4670.0, 944.0, 110.0],
            [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
             7078.0, 6561.0, 2354.0, 710.0],
            [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
             5661.0, 4991.0, 2032.0, 680.0],
            [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
             4727.0, 3428.0, 1559.0, 620.0],
            [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
             9552.0, 13709.0, 2460.0, 720.0],
            [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
             32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
            [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
             5115.0, 3510.0, 1390.0, 170.0],
            [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
             360.0, 156.0, 100.0]]
    # Transpose the 8 product rows into 12 month rows.
    data = [list(i) for i in zip(*data)]
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              'December']]
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: Residue.h:158
A container for models organized into clusters.
Definition: output.py:1522
A class for reading stat files (either rmf or ascii v1 and v2)
Definition: output.py:947
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: atom/Atom.h:245
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:1770
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:1835
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
Miscellaneous utilities.
Definition: tools.py:1
A class to store data associated to a model.
Definition: output.py:1500
void handle_use_deprecated(std::string message)
Break in this method in gdb to find deprecated uses at runtime.
std::string get_module_version()
Return the version of this module, as a string.
Change color code to hexadecimal to rgb.
Definition: tools.py:807
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
Definition: output.py:936
def get_prot_name_from_particle
Return the component name provided a particle and a list of names.
Definition: tools.py:550
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:1012
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
Definition: Fragment.h:46
def link_to_rmf
Link to another RMF file.
Definition: output.py:1171
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
Definition: pmi/utilities.h:85
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
Definition: output.py:270
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Definition: atom/Atom.h:238
Base class for capturing a modeling protocol.
Definition: output.py:45
The type for a residue.
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
A base class for Keys.
Definition: Key.h:44
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:203
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Definition: rigid_bodies.h:750
bool get_is_canonical(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and check if the root is named System.
Definition: pmi/utilities.h:91
Display a segment connecting a pair of particles.
Definition: XYZR.h:170
A decorator for a residue.
Definition: Residue.h:137
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
Definition: output.py:231
def get_prot_name_from_particle
Get the protein name from the particle.
Definition: output.py:356
class to link stat files to several rmf files
Definition: output.py:1269
class to allow more advanced handling of RMF files.
Definition: output.py:1145
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
Definition: output.py:1734
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Definition: Chain.h:61
Python classes to represent, score, sample and analyze models.
def init_pdb_best_scoring
Prepare for writing best-scoring PDBs (or mmCIFs) for a sampling run.
Definition: output.py:480
Functionality for loading, creating, manipulating and scoring atomic structures.
def get_rbs_and_beads
Returns unique objects in original order.
Definition: tools.py:1209
Hierarchies get_leaves(const Selection &h)
Select hierarchy particles identified by the biological name.
Definition: Selection.h:70
def init_rmf
Initialize an RMF file.
Definition: output.py:577
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: tools.py:570
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: rigid_bodies.h:752
std::string get_module_version()
Return the version of this module, as a string.
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: tools.py:645
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27