IMP logo
IMP Reference Guide  2.15.0
The Integrative Modeling Platform
output.py
1 """@namespace IMP.pmi.output
2  Classes for writing output files and processing them.
3 """
4 
5 from __future__ import print_function, division
6 import IMP
7 import IMP.atom
8 import IMP.core
9 import IMP.pmi
10 import IMP.pmi.tools
11 import IMP.pmi.io
12 import os
13 import sys
14 import ast
15 import RMF
16 import numpy as np
17 import operator
18 import itertools
19 import warnings
20 import string
21 import ihm.format
22 import collections
23 try:
24  import cPickle as pickle
25 except ImportError:
26  import pickle
27 
28 
class _ChainIDs(object):
    """Map indices to multi-character chain IDs.
    We label the first 26 chains A-Z, then we move to two-letter
    chain IDs: AA through AZ, then BA through BZ, through to ZZ.
    This continues with longer chain IDs."""
    def __getitem__(self, ind):
        alphabet = string.ascii_uppercase
        base = len(alphabet)
        # Bijective base-26: peel off the least significant letter first,
        # carrying with an extra -1 so that 26 maps to "AA", not "BA".
        digits = [alphabet[ind % base]]
        while ind >= base:
            ind = ind // base - 1
            digits.append(alphabet[ind % base])
        digits.reverse()
        return "".join(digits)
43 
44 
class ProtocolOutput(object):
    """Base class for capturing a modeling protocol.
    Unlike simple output of model coordinates, a complete
    protocol includes the input data used, details on the restraints,
    sampling, and clustering, as well as output models.
    Use via IMP.pmi.topology.System.add_protocol_output().

    @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
         mmCIF files.
    """
    # Intentionally empty: concrete subclasses implement the capture hooks.
    pass
56 
57 
def _flatten(seq):
    """Generate the leaf elements of `seq`, descending depth-first into
    any nested lists or tuples."""
    for item in seq:
        if not isinstance(item, (tuple, list)):
            yield item
        else:
            for leaf in _flatten(item):
                yield leaf
65 
66 
def _disambiguate_chain(chid, seen_chains):
    """Make sure that the chain ID is unique; warn and correct if it isn't.

    @param chid the proposed chain ID
    @param seen_chains set of IDs already in use; updated in place
    @return a chain ID guaranteed not to collide with `seen_chains`
    """
    # Handle null chain IDs
    if chid == '\0':
        chid = ' '

    if chid in seen_chains:
        # (reconstructed) the warning-category argument line was lost in
        # this copy; IMP.pmi.ParameterWarning matches upstream -- verify
        warnings.warn("Duplicate chain ID '%s' encountered" % chid,
                      IMP.pmi.ParameterWarning)

        # Append increasing numeric suffixes until a free ID is found
        for suffix in itertools.count(1):
            new_chid = chid + "%d" % suffix
            if new_chid not in seen_chains:
                seen_chains.add(new_chid)
                return new_chid
    seen_chains.add(chid)
    return chid
84 
85 
def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
                        write_all_residues_per_bead):
    """Write one PDB MODEL's worth of atom records followed by ENDMDL."""
    for n, (xyz, atom_type, residue_type, chain_id, residue_index,
            all_indexes, radius) in enumerate(particle_infos_for_pdb):
        if atom_type is None:
            atom_type = IMP.atom.AT_CA
        # Shift coordinates once per particle
        shifted = (xyz[0] - geometric_center[0],
                   xyz[1] - geometric_center[1],
                   xyz[2] - geometric_center[2])
        # Either one record per residue covered by the bead, or a single
        # record at the bead's representative residue index.
        if write_all_residues_per_bead and all_indexes is not None:
            residue_numbers = all_indexes
        else:
            residue_numbers = [residue_index]
        for resnum in residue_numbers:
            flpdb.write(
                IMP.atom.get_pdb_string(shifted, n+1, atom_type,
                                        residue_type, chain_id[:1],
                                        resnum, ' ', 1.00, radius))
    flpdb.write("ENDMDL\n")
111 
112 
# Lightweight records used for mmCIF output: an entity is one unique
# sequence; chain info ties a chain to its entity plus the molecule name.
_Entity = collections.namedtuple('_Entity', ('id', 'seq'))
_ChainInfo = collections.namedtuple('_ChainInfo', ('entity', 'name'))
115 
116 
def _get_chain_info(chains, root_hier):
    """Collect per-chain info and the set of unique entities (sequences).

    @param chains dict mapping molecule names to chain IDs
    @param root_hier root IMP.atom.Hierarchy of the system
    @return (chain_info, all_entities): chain_info maps chain ID to
            _ChainInfo; all_entities lists one _Entity per unique sequence
    """
    chain_info = {}
    entities = {}
    all_entities = []
    for mol in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
        # (reconstructed) this assignment was missing in this copy; the
        # chains dict is keyed by molecule name+copy elsewhere in this
        # file -- verify against upstream
        molname = IMP.pmi.get_molecule_name_and_copy(mol)
        chain_id = chains[molname]
        chain = IMP.atom.Chain(mol)
        seq = chain.get_sequence()
        # One _Entity per distinct sequence, numbered from 1
        if seq not in entities:
            entities[seq] = e = _Entity(id=len(entities)+1, seq=seq)
            all_entities.append(e)
        entity = entities[seq]
        info = _ChainInfo(entity=entity, name=molname)
        chain_info[chain_id] = info
    return chain_info, all_entities
133 
134 
def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
                          write_all_residues_per_bead, chains, root_hier):
    """Write one model's coordinates to `flpdb` in mmCIF format.

    @param flpdb open file-like object to write to
    @param particle_infos_for_pdb per-particle tuples produced by
           Output.get_particle_infos_for_pdb_writing()
    @param geometric_center offset subtracted from every coordinate;
           assumes xyz supports vector subtraction (numpy-like) -- verify
    @param write_all_residues_per_bead if True, one atom row per residue
           covered by each bead
    @param chains dict mapping molecule names to chain IDs
    @param root_hier root hierarchy, used to extract chain sequences
    """
    # get dict with keys=chain IDs, values=chain info
    chain_info, entities = _get_chain_info(chains, root_hier)

    writer = ihm.format.CifWriter(flpdb)
    writer.start_block('model')
    with writer.category("_entry") as lp:
        lp.write(id='model')

    # One row per unique sequence
    with writer.loop("_entity", ["id"]) as lp:
        for e in entities:
            lp.write(id=e.id)

    with writer.loop("_entity_poly",
                     ["entity_id", "pdbx_seq_one_letter_code"]) as lp:
        for e in entities:
            lp.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)

    # One row per chain, linking it to its entity
    with writer.loop("_struct_asym", ["id", "entity_id", "details"]) as lp:
        for chid in sorted(chains.values()):
            ci = chain_info[chid]
            lp.write(id=chid, entity_id=ci.entity.id, details=ci.name)

    with writer.loop("_atom_site",
                     ["group_PDB", "type_symbol", "label_atom_id",
                      "label_comp_id", "label_asym_id", "label_seq_id",
                      "auth_seq_id",
                      "Cartn_x", "Cartn_y", "Cartn_z", "label_entity_id",
                      "pdbx_pdb_model_num",
                      "id"]) as lp:
        ordinal = 1
        for n, tupl in enumerate(particle_infos_for_pdb):
            (xyz, atom_type, residue_type,
             chain_id, residue_index, all_indexes, radius) = tupl
            ci = chain_info[chain_id]
            if atom_type is None:
                atom_type = IMP.atom.AT_CA
            c = xyz - geometric_center
            if write_all_residues_per_bead and all_indexes is not None:
                # NOTE(review): the loop variable residue_number is unused;
                # every duplicated row repeats residue_index -- confirm
                # whether residue_number was intended here
                for residue_number in all_indexes:
                    lp.write(group_PDB='ATOM',
                             type_symbol='C',
                             label_atom_id=atom_type.get_string(),
                             label_comp_id=residue_type.get_string(),
                             label_asym_id=chain_id,
                             label_seq_id=residue_index,
                             auth_seq_id=residue_index, Cartn_x=c[0],
                             Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                             pdbx_pdb_model_num=1,
                             label_entity_id=ci.entity.id)
                    ordinal += 1
            else:
                lp.write(group_PDB='ATOM', type_symbol='C',
                         label_atom_id=atom_type.get_string(),
                         label_comp_id=residue_type.get_string(),
                         label_asym_id=chain_id,
                         label_seq_id=residue_index,
                         auth_seq_id=residue_index, Cartn_x=c[0],
                         Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
                         pdbx_pdb_model_num=1,
                         label_entity_id=ci.entity.id)
                ordinal += 1
198 
199 
200 class Output(object):
201  """Class for easy writing of PDBs, RMFs, and stat files
202 
203  @note Model should be updated prior to writing outputs.
204  """
    def __init__(self, ascii=True, atomistic=False):
        """Constructor.

        @param ascii write stat files as ascii repr() lines rather than
               binary pickles
        @param atomistic write per-atom records (rather than one record
               per residue/bead) for non-PMI2 hierarchies
        """
        self.dictionary_pdbs = {}    # PDB filename -> hierarchy
        self._pdb_mmcif = {}         # PDB filename -> True if mmCIF format
        self.dictionary_rmfs = {}    # RMF filename -> (handle, cat, keymap, objects)
        self.dictionary_stats = {}   # stat filename -> list of output objects
        self.dictionary_stats2 = {}  # stat2 filename -> bookkeeping tuple
        self.best_score_list = None
        self.nbestscoring = None
        self.suffixes = []
        self.replica_exchange = False
        self.ascii = ascii
        self.initoutput = {}
        self.residuetypekey = IMP.StringKey("ResidueName")
        # 1-character chain IDs, suitable for PDB output
        self.chainids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
            "abcdefghijklmnopqrstuvwxyz0123456789"
        # Multi-character chain IDs, suitable for mmCIF output
        self.multi_chainids = _ChainIDs()
        self.dictchain = {}  # keys are molecule names, values are chain ids
        self.particle_infos_for_pdb = {}
        self.atomistic = atomistic
        self.use_pmi2 = False  # set by _init_dictchain() per hierarchy
227 
228  def get_pdb_names(self):
229  """Get a list of all PDB files being output by this instance"""
230  return list(self.dictionary_pdbs.keys())
231 
232  def get_rmf_names(self):
233  return list(self.dictionary_rmfs.keys())
234 
235  def get_stat_names(self):
236  return list(self.dictionary_stats.keys())
237 
238  def _init_dictchain(self, name, prot, multichar_chain=False):
239  self.dictchain[name] = {}
240  self.use_pmi2 = False
241  seen_chains = set()
242 
243  # attempt to find PMI objects.
244  if IMP.pmi.get_is_canonical(prot):
245  self.use_pmi2 = True
246  self.atomistic = True # detects automatically
247  for n, mol in enumerate(IMP.atom.get_by_type(
248  prot, IMP.atom.MOLECULE_TYPE)):
249  chid = _disambiguate_chain(IMP.atom.Chain(mol).get_id(),
250  seen_chains)
252  self.dictchain[name][molname] = chid
253  else:
254  chainids = self.multi_chainids if multichar_chain \
255  else self.chainids
256  for n, i in enumerate(self.dictionary_pdbs[name].get_children()):
257  self.dictchain[name][i.get_name()] = chainids[n]
258 
259  def init_pdb(self, name, prot, mmcif=False):
260  """Init PDB Writing.
261  @param name The PDB filename
262  @param prot The hierarchy to write to this pdb file
263  @param mmcif If True, write PDBs in mmCIF format
264  @note if the PDB name is 'System' then will use Selection
265  to get molecules
266  """
267  flpdb = open(name, 'w')
268  flpdb.close()
269  self.dictionary_pdbs[name] = prot
270  self._pdb_mmcif[name] = mmcif
271  self._init_dictchain(name, prot)
272 
    def write_psf(self, filename, name):
        """Write a PSF topology file matching the PDB output `name`.

        Atom records are written in the same order as write_pdb();
        bonds link residue-consecutive particles within each chain.
        @param filename the PSF file to create
        @param name key of a previously init_pdb()'d output
        """
        flpsf = open(filename, 'w')
        flpsf.write("PSF CMAP CHEQ" + "\n")
        index_residue_pair_list = {}
        (particle_infos_for_pdb, geometric_center) = \
            self.get_particle_infos_for_pdb_writing(name)
        nparticles = len(particle_infos_for_pdb)
        flpsf.write(str(nparticles) + " !NATOM" + "\n")
        for n, p in enumerate(particle_infos_for_pdb):
            atom_index = n+1
            residue_type = p[2]
            chain = p[3]
            resid = p[4]
            # Fixed-column PSF atom record (CHARMM-style)
            flpsf.write('{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}'
                        '{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}'
                        '{14:14.6f}{15:14.6f}'.format(
                            atom_index, " ", chain, " ", str(resid), " ",
                            '"'+residue_type.get_string()+'"', " ", "C",
                            " ", "C", 1.0, 0.0, 0, 0.0, 0.0))
            flpsf.write('\n')
            # Remember (atom index, residue) per chain for bonds below
            if chain not in index_residue_pair_list:
                index_residue_pair_list[chain] = [(atom_index, resid)]
            else:
                index_residue_pair_list[chain].append((atom_index, resid))

        # now write the connectivity
        indexes_pairs = []
        for chain in sorted(index_residue_pair_list.keys()):

            ls = index_residue_pair_list[chain]
            # sort by residue
            ls = sorted(ls, key=lambda tup: tup[1])
            # get the index list
            indexes = [x[0] for x in ls]
            # get the contiguous pairs
            indexes_pairs += list(IMP.pmi.tools.sublist_iterator(
                indexes, lmin=2, lmax=2))
        nbonds = len(indexes_pairs)
        flpsf.write(str(nbonds)+" !NBOND: bonds"+"\n")

        # save bonds in fixed column format, four bond pairs per line
        for i in range(0, len(indexes_pairs), 4):
            for bond in indexes_pairs[i:i+4]:
                flpsf.write('{0:8d}{1:8d}'.format(*bond))
            flpsf.write('\n')

        del particle_infos_for_pdb
        flpsf.close()
321 
322  def write_pdb(self, name, appendmode=True,
323  translate_to_geometric_center=False,
324  write_all_residues_per_bead=False):
325 
326  (particle_infos_for_pdb,
327  geometric_center) = self.get_particle_infos_for_pdb_writing(name)
328 
329  if not translate_to_geometric_center:
330  geometric_center = (0, 0, 0)
331 
332  filemode = 'a' if appendmode else 'w'
333  with open(name, filemode) as flpdb:
334  if self._pdb_mmcif[name]:
335  _write_mmcif_internal(flpdb, particle_infos_for_pdb,
336  geometric_center,
337  write_all_residues_per_bead,
338  self.dictchain[name],
339  self.dictionary_pdbs[name])
340  else:
341  _write_pdb_internal(flpdb, particle_infos_for_pdb,
342  geometric_center,
343  write_all_residues_per_bead)
344 
345  def get_prot_name_from_particle(self, name, p):
346  """Get the protein name from the particle.
347  This is done by traversing the hierarchy."""
348  if self.use_pmi2:
349  return IMP.pmi.get_molecule_name_and_copy(p), True
350  else:
352  p, self.dictchain[name])
353 
354  def get_particle_infos_for_pdb_writing(self, name):
355  # index_residue_pair_list={}
356 
357  # the resindexes dictionary keep track of residues that have
358  # been already added to avoid duplication
359  # highest resolution have highest priority
360  resindexes_dict = {}
361 
362  # this dictionary dill contain the sequence of tuples needed to
363  # write the pdb
364  particle_infos_for_pdb = []
365 
366  geometric_center = [0, 0, 0]
367  atom_count = 0
368 
369  if self.use_pmi2:
370  # select highest resolution
371  sel = IMP.atom.Selection(self.dictionary_pdbs[name], resolution=0)
372  ps = sel.get_selected_particles()
373  else:
374  ps = IMP.atom.get_leaves(self.dictionary_pdbs[name])
375 
376  for n, p in enumerate(ps):
377  protname, is_a_bead = self.get_prot_name_from_particle(name, p)
378 
379  if protname not in resindexes_dict:
380  resindexes_dict[protname] = []
381 
382  if IMP.atom.Atom.get_is_setup(p) and self.atomistic:
383  residue = IMP.atom.Residue(IMP.atom.Atom(p).get_parent())
384  rt = residue.get_residue_type()
385  resind = residue.get_index()
386  atomtype = IMP.atom.Atom(p).get_atom_type()
387  xyz = list(IMP.core.XYZ(p).get_coordinates())
388  radius = IMP.core.XYZR(p).get_radius()
389  geometric_center[0] += xyz[0]
390  geometric_center[1] += xyz[1]
391  geometric_center[2] += xyz[2]
392  atom_count += 1
393  particle_infos_for_pdb.append(
394  (xyz, atomtype, rt, self.dictchain[name][protname],
395  resind, None, radius))
396  resindexes_dict[protname].append(resind)
397 
399 
400  residue = IMP.atom.Residue(p)
401  resind = residue.get_index()
402  # skip if the residue was already added by atomistic resolution
403  # 0
404  if resind in resindexes_dict[protname]:
405  continue
406  else:
407  resindexes_dict[protname].append(resind)
408  rt = residue.get_residue_type()
409  xyz = IMP.core.XYZ(p).get_coordinates()
410  radius = IMP.core.XYZR(p).get_radius()
411  geometric_center[0] += xyz[0]
412  geometric_center[1] += xyz[1]
413  geometric_center[2] += xyz[2]
414  atom_count += 1
415  particle_infos_for_pdb.append(
416  (xyz, None, rt, self.dictchain[name][protname], resind,
417  None, radius))
418 
419  elif IMP.atom.Fragment.get_is_setup(p) and not is_a_bead:
420  resindexes = IMP.pmi.tools.get_residue_indexes(p)
421  resind = resindexes[len(resindexes) // 2]
422  if resind in resindexes_dict[protname]:
423  continue
424  else:
425  resindexes_dict[protname].append(resind)
426  rt = IMP.atom.ResidueType('BEA')
427  xyz = IMP.core.XYZ(p).get_coordinates()
428  radius = IMP.core.XYZR(p).get_radius()
429  geometric_center[0] += xyz[0]
430  geometric_center[1] += xyz[1]
431  geometric_center[2] += xyz[2]
432  atom_count += 1
433  particle_infos_for_pdb.append(
434  (xyz, None, rt, self.dictchain[name][protname], resind,
435  resindexes, radius))
436 
437  else:
438  if is_a_bead:
439  rt = IMP.atom.ResidueType('BEA')
440  resindexes = IMP.pmi.tools.get_residue_indexes(p)
441  if len(resindexes) > 0:
442  resind = resindexes[len(resindexes) // 2]
443  xyz = IMP.core.XYZ(p).get_coordinates()
444  radius = IMP.core.XYZR(p).get_radius()
445  geometric_center[0] += xyz[0]
446  geometric_center[1] += xyz[1]
447  geometric_center[2] += xyz[2]
448  atom_count += 1
449  particle_infos_for_pdb.append(
450  (xyz, None, rt, self.dictchain[name][protname],
451  resind, resindexes, radius))
452 
453  if atom_count > 0:
454  geometric_center = (geometric_center[0] / atom_count,
455  geometric_center[1] / atom_count,
456  geometric_center[2] / atom_count)
457 
458  # sort by chain ID, then residue index. Longer chain IDs (e.g. AA)
459  # should always come after shorter (e.g. Z)
460  particle_infos_for_pdb = sorted(particle_infos_for_pdb,
461  key=lambda x: (len(x[3]), x[3], x[4]))
462 
463  return (particle_infos_for_pdb, geometric_center)
464 
    def write_pdbs(self, appendmode=True, mmcif=False):
        """Write a frame to every PDB file initialized with init_pdb().
        @param appendmode append to the files rather than overwriting
        @note the mmcif argument is unused here; the output format was
              fixed per-file at init_pdb() time
        """
        for pdb in self.dictionary_pdbs.keys():
            self.write_pdb(pdb, appendmode)
468 
    def init_pdb_best_scoring(self,
                              suffix,
                              prot,
                              nbestscoring,
                              replica_exchange=False, mmcif=False):
        """Prepare output files keeping only the best-scoring models.
        Files are named suffix.<rank>.pdb (or .cif).

        @param suffix filename prefix for the ranked files
        @param prot the hierarchy to write
        @param nbestscoring how many best-scoring models to keep
        @param replica_exchange share scores between replicas via a
               common file
        @param mmcif write mmCIF rather than PDB format
        """
        # save only the nbestscoring conformations
        # create as many pdbs as needed

        self._pdb_best_scoring_mmcif = mmcif
        fileext = '.cif' if mmcif else '.pdb'
        self.suffixes.append(suffix)
        self.replica_exchange = replica_exchange
        if not self.replica_exchange:
            # common usage
            # if you are not in replica exchange mode
            # initialize the array of scores internally
            self.best_score_list = []
        else:
            # otherwise the replicas must communicate
            # through a common file to know what are the best scores
            self.best_score_file_name = "best.scores.rex.py"
            self.best_score_list = []
            with open(self.best_score_file_name, "w") as best_score_file:
                best_score_file.write(
                    "self.best_score_list=" + str(self.best_score_list) + "\n")

        self.nbestscoring = nbestscoring
        # Create/truncate one file per rank slot
        for i in range(self.nbestscoring):
            name = suffix + "." + str(i) + fileext
            flpdb = open(name, 'w')
            flpdb.close()
            self.dictionary_pdbs[name] = prot
            self._pdb_mmcif[name] = mmcif
            self._init_dictchain(name, prot)
503 
    def write_pdb_best_scoring(self, score):
        """Write the current model into the best-scoring file set if
        `score` qualifies, shifting worse-ranked files down one slot.

        @param score the score of the current model (lower is better)
        """
        if self.nbestscoring is None:
            print("Output.write_pdb_best_scoring: init_pdb_best_scoring "
                  "not run")

        mmcif = self._pdb_best_scoring_mmcif
        fileext = '.cif' if mmcif else '.pdb'
        # update the score list
        if self.replica_exchange:
            # read the self.best_score_list from the file
            # NOTE(review): exec() of a shared on-disk file; all replicas
            # must trust the contents of best.scores.rex.py
            with open(self.best_score_file_name) as fh:
                exec(fh.read())

        if len(self.best_score_list) < self.nbestscoring:
            # Still filling the ranked slots: insert and shift files down
            self.best_score_list.append(score)
            self.best_score_list.sort()
            index = self.best_score_list.index(score)
            for suffix in self.suffixes:
                for i in range(len(self.best_score_list) - 2, index - 1, -1):
                    oldname = suffix + "." + str(i) + fileext
                    newname = suffix + "." + str(i + 1) + fileext
                    # rename on Windows fails if newname already exists
                    if os.path.exists(newname):
                        os.unlink(newname)
                    os.rename(oldname, newname)
                filetoadd = suffix + "." + str(index) + fileext
                self.write_pdb(filetoadd, appendmode=False)

        else:
            # Slots full: only insert if better than the current worst
            if score < self.best_score_list[-1]:
                self.best_score_list.append(score)
                self.best_score_list.sort()
                self.best_score_list.pop(-1)
                index = self.best_score_list.index(score)
                for suffix in self.suffixes:
                    for i in range(len(self.best_score_list) - 1,
                                   index - 1, -1):
                        oldname = suffix + "." + str(i) + fileext
                        newname = suffix + "." + str(i + 1) + fileext
                        os.rename(oldname, newname)
                    filenametoremove = suffix + \
                        "." + str(self.nbestscoring) + fileext
                    os.remove(filenametoremove)
                    filetoadd = suffix + "." + str(index) + fileext
                    self.write_pdb(filetoadd, appendmode=False)

        if self.replica_exchange:
            # write the self.best_score_list to the file
            with open(self.best_score_file_name, "w") as best_score_file:
                best_score_file.write(
                    "self.best_score_list=" + str(self.best_score_list) + '\n')
555 
    def init_rmf(self, name, hierarchies, rs=None, geometries=None,
                 listofobjects=None):
        """
        Initialize an RMF file

        @param name the name of the RMF file
        @param hierarchies the hierarchies to be included (it is a list)
        @param rs optional, the restraint sets (it is a list)
        @param geometries optional, the geometries (it is a list)
        @param listofobjects optional, the list of objects for the stat
               (it is a list)
        """
        rh = RMF.create_rmf_file(name)
        IMP.rmf.add_hierarchies(rh, hierarchies)
        cat = None
        outputkey_rmfkey = None

        if rs is not None:
            IMP.rmf.add_restraints(rh, rs)
        if geometries is not None:
            IMP.rmf.add_geometries(rh, geometries)
        if listofobjects is not None:
            # Create an RMF key per stat-output field so write_rmf() can
            # attach the values to each frame's root node
            cat = rh.get_category("stat")
            outputkey_rmfkey = {}
            for o in listofobjects:
                if "get_output" not in dir(o):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(o))
                output = o.get_output()
                for outputkey in output:
                    # Pick the RMF tag type from the current value's type;
                    # NOTE(review): the str branch and the else branch are
                    # redundant with the string default above
                    rmftag = RMF.string_tag
                    if isinstance(output[outputkey], float):
                        rmftag = RMF.float_tag
                    elif isinstance(output[outputkey], int):
                        rmftag = RMF.int_tag
                    elif isinstance(output[outputkey], str):
                        rmftag = RMF.string_tag
                    else:
                        rmftag = RMF.string_tag
                    rmfkey = rh.get_key(cat, outputkey, rmftag)
                    outputkey_rmfkey[outputkey] = rmfkey
            outputkey_rmfkey["rmf_file"] = \
                rh.get_key(cat, "rmf_file", RMF.string_tag)
            outputkey_rmfkey["rmf_frame_index"] = \
                rh.get_key(cat, "rmf_frame_index", RMF.int_tag)

        self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey, listofobjects)
604 
605  def add_restraints_to_rmf(self, name, objectlist):
606  for o in _flatten(objectlist):
607  try:
608  rs = o.get_restraint_for_rmf()
609  if not isinstance(rs, (list, tuple)):
610  rs = [rs]
611  except: # noqa: E722
612  rs = [o.get_restraint()]
614  self.dictionary_rmfs[name][0], rs)
615 
616  def add_geometries_to_rmf(self, name, objectlist):
617  for o in objectlist:
618  geos = o.get_geometries()
619  IMP.rmf.add_geometries(self.dictionary_rmfs[name][0], geos)
620 
621  def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
622  for o in objectlist:
623 
624  pps = o.get_particle_pairs()
625  for pp in pps:
627  self.dictionary_rmfs[name][0],
629 
    def write_rmf(self, name):
        """Save a frame to RMF file `name`, attaching stat values for any
        objects registered at init_rmf() time to the frame's root node."""
        IMP.rmf.save_frame(self.dictionary_rmfs[name][0])
        if self.dictionary_rmfs[name][1] is not None:
            outputkey_rmfkey = self.dictionary_rmfs[name][2]
            listofobjects = self.dictionary_rmfs[name][3]
            for o in listofobjects:
                output = o.get_output()
                for outputkey in output:
                    rmfkey = outputkey_rmfkey[outputkey]
                    try:
                        n = self.dictionary_rmfs[name][0].get_root_node()
                        n.set_value(rmfkey, output[outputkey])
                    except NotImplementedError:
                        # value type not settable on this node; skip field
                        continue
            # Record provenance: which file and which frame this is
            rmfkey = outputkey_rmfkey["rmf_file"]
            self.dictionary_rmfs[name][0].get_root_node().set_value(
                rmfkey, name)
            rmfkey = outputkey_rmfkey["rmf_frame_index"]
            nframes = self.dictionary_rmfs[name][0].get_number_of_frames()
            self.dictionary_rmfs[name][0].get_root_node().set_value(
                rmfkey, nframes-1)
        self.dictionary_rmfs[name][0].flush()
652 
653  def close_rmf(self, name):
654  rh = self.dictionary_rmfs[name][0]
655  del self.dictionary_rmfs[name]
656  del rh
657 
658  def write_rmfs(self):
659  for rmfinfo in self.dictionary_rmfs.keys():
660  self.write_rmf(rmfinfo[0])
661 
662  def init_stat(self, name, listofobjects):
663  if self.ascii:
664  flstat = open(name, 'w')
665  flstat.close()
666  else:
667  flstat = open(name, 'wb')
668  flstat.close()
669 
670  # check that all objects in listofobjects have a get_output method
671  for o in listofobjects:
672  if "get_output" not in dir(o):
673  raise ValueError(
674  "Output: object %s doesn't have get_output() method"
675  % str(o))
676  self.dictionary_stats[name] = listofobjects
677 
678  def set_output_entry(self, key, value):
679  self.initoutput.update({key: value})
680 
    def write_stat(self, name, appendmode=True):
        """Write one record with the current output of every object
        registered for stat file `name`.
        @param appendmode append to the file rather than overwriting it
        """
        # NOTE(review): `output` aliases (and therefore mutates)
        # self.initoutput, so fields accumulate across calls and across
        # stat files -- confirm this is intended
        output = self.initoutput
        for obj in self.dictionary_stats[name]:
            d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
            output.update(dfiltered)

        if appendmode:
            writeflag = 'a'
        else:
            writeflag = 'w'

        if self.ascii:
            # One repr() line per record
            flstat = open(name, writeflag)
            flstat.write("%s \n" % output)
            flstat.close()
        else:
            # Binary: one pickle (protocol 2) per record
            flstat = open(name, writeflag + 'b')
            pickle.dump(output, flstat, 2)
            flstat.close()
702 
703  def write_stats(self):
704  for stat in self.dictionary_stats.keys():
705  self.write_stat(stat)
706 
707  def get_stat(self, name):
708  output = {}
709  for obj in self.dictionary_stats[name]:
710  output.update(obj.get_output())
711  return output
712 
    def write_test(self, name, listofobjects):
        """Write a single-line reference file for later comparison by
        test(); uses each object's get_test_output() when available,
        otherwise get_output().

        @param name the reference file to create
        @param listofobjects objects providing get_(test_)output()
        """
        flstat = open(name, 'w')
        output = self.initoutput
        for o in listofobjects:
            if "get_test_output" not in dir(o) and "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        self.dictionary_stats[name] = listofobjects

        for obj in self.dictionary_stats[name]:
            try:
                d = obj.get_test_output()
            except:  # noqa: E722
                d = obj.get_output()
            # remove all entries that begin with _ (private entries)
            dfiltered = dict((k, v) for k, v in d.items() if k[0] != "_")
            output.update(dfiltered)
        flstat.write("%s \n" % output)
        flstat.close()
733 
    def test(self, name, listofobjects, tolerance=1e-5):
        """Compare the current output of `listofobjects` against the
        reference file `name` previously written by write_test().

        @param name the reference file to read
        @param listofobjects objects providing get_(test_)output()
        @param tolerance maximum absolute difference for numeric fields
        @return True if every shared field matches (within tolerance for
                numbers); mismatches are reported on stderr
        """
        output = self.initoutput
        for o in listofobjects:
            if "get_test_output" not in dir(o) and "get_output" not in dir(o):
                raise ValueError(
                    "Output: object %s doesn't have get_output() or "
                    "get_test_output() method" % str(o))
        for obj in listofobjects:
            try:
                output.update(obj.get_test_output())
            except:  # noqa: E722
                output.update(obj.get_output())

        flstat = open(name, 'r')

        passed = True
        for fl in flstat:
            test_dict = ast.literal_eval(fl)
            for k in test_dict:
                if k in output:
                    old_value = str(test_dict[k])
                    new_value = str(output[k])
                    # Decide numeric vs. string comparison from the old value
                    try:
                        float(old_value)
                        is_float = True
                    except ValueError:
                        is_float = False

                    if is_float:
                        fold = float(old_value)
                        fnew = float(new_value)
                        diff = abs(fold - fnew)
                        if diff > tolerance:
                            print("%s: test failed, old value: %s new value %s; "
                                  "diff %f > %f" % (str(k), str(old_value),
                                                    str(new_value), diff,
                                                    tolerance), file=sys.stderr)
                            passed = False
                    elif test_dict[k] != output[k]:
                        if len(old_value) < 50 and len(new_value) < 50:
                            print("%s: test failed, old value: %s new value %s"
                                  % (str(k), old_value, new_value),
                                  file=sys.stderr)
                            passed = False
                        else:
                            print("%s: test failed, omitting results (too long)"
                                  % str(k), file=sys.stderr)
                            passed = False

                else:
                    # Field present in the reference but not produced now
                    print("%s from old objects (file %s) not in new objects"
                          % (str(k), str(name)), file=sys.stderr)
        flstat.close()
        return passed
788 
789  def get_environment_variables(self):
790  import os
791  return str(os.environ)
792 
    def get_versions_of_relevant_modules(self):
        """Return a dict of version strings for IMP, PMI, and any optional
        modules (isd2, isd_emxl) that happen to be installed."""
        import IMP
        versions = {}
        versions["IMP_VERSION"] = IMP.get_module_version()
        versions["PMI_VERSION"] = IMP.pmi.get_module_version()
        # Optional modules are silently skipped when unavailable
        try:
            import IMP.isd2
            versions["ISD2_VERSION"] = IMP.isd2.get_module_version()
        except ImportError:
            pass
        try:
            import IMP.isd_emxl
            versions["ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
        except ImportError:
            pass
        return versions
809 
    def init_stat2(self, name, listofobjects, extralabels=None,
                   listofsummedobjects=None):
        """Initialize a v2 stat file. The first line maps small integer
        keys to field names; subsequent lines use the integers, making
        the file much smaller than the v1 format.

        @param name the stat2 file name
        @param listofobjects objects providing get_output()
        @param extralabels extra field names settable via set_output_entry()
        @param listofsummedobjects list of ([obj1, obj2, ...], label)
               tuples; the objects' _TotalScore values are summed and
               written under label
        """
        # this is a new stat file that should be less
        # space greedy!
        # listofsummedobjects must be in the form
        # [([obj1,obj2,obj3,obj4...],label)]
        # extralabels

        if listofsummedobjects is None:
            listofsummedobjects = []
        if extralabels is None:
            extralabels = []
        flstat = open(name, 'w')
        output = {}
        stat2_keywords = {"STAT2HEADER": "STAT2HEADER"}
        stat2_keywords.update(
            {"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
        stat2_keywords.update(
            {"STAT2HEADER_IMP_VERSIONS":
             str(self.get_versions_of_relevant_modules())})
        stat2_inverse = {}

        for obj in listofobjects:
            if "get_output" not in dir(obj):
                raise ValueError(
                    "Output: object %s doesn't have get_output() method"
                    % str(obj))
            else:
                d = obj.get_output()
                # remove all entries that begin with _ (private entries)
                dfiltered = dict((k, v)
                                 for k, v in d.items() if k[0] != "_")
                output.update(dfiltered)

        # check for customizable entries
        for obj in listofsummedobjects:
            for t in obj[0]:
                if "get_output" not in dir(t):
                    raise ValueError(
                        "Output: object %s doesn't have get_output() method"
                        % str(t))
                else:
                    if "_TotalScore" not in t.get_output():
                        raise ValueError(
                            "Output: object %s doesn't have _TotalScore "
                            "entry to be summed" % str(t))
                    else:
                        output.update({obj[1]: 0.0})

        for k in extralabels:
            output.update({k: 0.0})

        # Assign a stable integer to every field name; the inverse map is
        # what write_stat2() uses to encode each line
        for n, k in enumerate(output):
            stat2_keywords.update({n: k})
            stat2_inverse.update({k: n})

        flstat.write("%s \n" % stat2_keywords)
        flstat.close()
        self.dictionary_stats2[name] = (
            listofobjects,
            stat2_inverse,
            listofsummedobjects,
            extralabels)
873 
    def write_stat2(self, name, appendmode=True):
        """Write one line to stat2 file `name`, encoding fields with the
        integer keys established by init_stat2().
        @param appendmode append to the file rather than overwriting it
        """
        output = {}
        (listofobjects, stat2_inverse, listofsummedobjects,
         extralabels) = self.dictionary_stats2[name]

        # writing objects
        for obj in listofobjects:
            od = obj.get_output()
            dfiltered = dict((k, v) for k, v in od.items() if k[0] != "_")
            for k in dfiltered:
                output.update({stat2_inverse[k]: od[k]})

        # writing summedobjects
        for so in listofsummedobjects:
            partial_score = 0.0
            for t in so[0]:
                d = t.get_output()
                partial_score += float(d["_TotalScore"])
            output.update({stat2_inverse[so[1]]: str(partial_score)})

        # writing extralabels
        for k in extralabels:
            if k in self.initoutput:
                output.update({stat2_inverse[k]: self.initoutput[k]})
            else:
                output.update({stat2_inverse[k]: "None"})

        if appendmode:
            writeflag = 'a'
        else:
            writeflag = 'w'

        flstat = open(name, writeflag)
        flstat.write("%s \n" % output)
        flstat.close()
909 
910  def write_stats2(self):
911  for stat in self.dictionary_stats2.keys():
912  self.write_stat2(stat)
913 
914 
class OutputStatistics(object):
    """Collect statistics from ProcessOutput.get_fields().
    Counters of the total number of frames read, plus the models that
    passed the various filters used in get_fields(), are provided."""
    def __init__(self):
        # All counters start at zero; get_fields() increments them
        self.total = 0
        self.passed_get_every = 0
        self.passed_filterout = 0
        self.passed_filtertuple = 0
924 
925 
926 class ProcessOutput(object):
927  """A class for reading stat files (either rmf or ascii v1 and v2)"""
928  def __init__(self, filename):
929  self.filename = filename
930  self.isstat1 = False
931  self.isstat2 = False
932  self.isrmf = False
933 
934  if self.filename is None:
935  raise ValueError("No file name provided. Use -h for help")
936 
937  try:
938  # let's see if that is an rmf file
939  rh = RMF.open_rmf_file_read_only(self.filename)
940  self.isrmf = True
941  cat = rh.get_category('stat')
942  rmf_klist = rh.get_keys(cat)
943  self.rmf_names_keys = dict([(rh.get_name(k), k)
944  for k in rmf_klist])
945  del rh
946 
947  except IOError:
948  f = open(self.filename, "r")
949  # try with an ascii stat file
950  # get the keys from the first line
951  for line in f.readlines():
952  d = ast.literal_eval(line)
953  self.klist = list(d.keys())
954  # check if it is a stat2 file
955  if "STAT2HEADER" in self.klist:
956  self.isstat2 = True
957  for k in self.klist:
958  if "STAT2HEADER" in str(k):
959  # if print_header: print k, d[k]
960  del d[k]
961  stat2_dict = d
962  # get the list of keys sorted by value
963  kkeys = [k[0]
964  for k in sorted(stat2_dict.items(),
965  key=operator.itemgetter(1))]
966  self.klist = [k[1]
967  for k in sorted(stat2_dict.items(),
968  key=operator.itemgetter(1))]
969  self.invstat2_dict = {}
970  for k in kkeys:
971  self.invstat2_dict.update({stat2_dict[k]: k})
972  else:
974  "statfile v1 is deprecated. "
975  "Please convert to statfile v2.\n")
976  self.isstat1 = True
977  self.klist.sort()
978 
979  break
980  f.close()
981 
982  def get_keys(self):
983  if self.isrmf:
984  return sorted(self.rmf_names_keys.keys())
985  else:
986  return self.klist
987 
    def show_keys(self, ncolumns=2, truncate=65):
        """Print the available field names in `ncolumns` columns,
        truncating each name to `truncate` characters."""
        IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)
990 
    def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
                   statistics=None):
        '''
        Get the desired field names, and return a dictionary.
        Namely, "fields" are the queried keys in the stat file
        (eg. ["Total_Score",...])
        The returned data structure is a dictionary, where each key is
        a field and the value is the time series (ie, frame ordered series)
        of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,,...],....})

        @param fields (list of strings) queried keys in the stat file
               (eg. "Total_Score"....)
        @param filterout specify if you want to "grep" out something from
               the file, so that it is faster
        @param filtertuple a tuple that contains
               ("TheKeyToBeFiltered",relationship,value)
               where relationship = "<", "==", or ">"
        @param get_every only read every Nth line from the file
        @param statistics if provided, accumulate statistics in an
               OutputStatistics object
        '''

        if statistics is None:
            statistics = OutputStatistics()
        outdict = {}
        for field in fields:
            outdict[field] = []

        # print fields values
        if self.isrmf:
            rh = RMF.open_rmf_file_read_only(self.filename)
            nframes = rh.get_number_of_frames()
            for i in range(nframes):
                statistics.total += 1
                # "get_every" and "filterout" not enforced for RMF
                statistics.passed_get_every += 1
                statistics.passed_filterout += 1
                IMP.rmf.load_frame(rh, RMF.FrameID(i))
                if filtertuple is not None:
                    keytobefiltered = filtertuple[0]
                    relationship = filtertuple[1]
                    value = filtertuple[2]
                    datavalue = rh.get_root_node().get_value(
                        self.rmf_names_keys[keytobefiltered])
                    if self.isfiltered(datavalue, relationship, value):
                        continue

                statistics.passed_filtertuple += 1
                for field in fields:
                    outdict[field].append(rh.get_root_node().get_value(
                        self.rmf_names_keys[field]))

        else:
            f = open(self.filename, "r")
            line_number = 0

            for line in f.readlines():
                statistics.total += 1
                # cheap textual pre-filter, applied before parsing
                if filterout is not None:
                    if filterout in line:
                        continue
                statistics.passed_filterout += 1
                line_number += 1

                if line_number % get_every != 0:
                    # the first line of a stat2 file is a header, not a
                    # data frame: undo the counters bumped above
                    if line_number == 1 and self.isstat2:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                    continue
                statistics.passed_get_every += 1
                try:
                    d = ast.literal_eval(line)
                except:  # noqa: E722
                    print("# Warning: skipped line number " + str(line_number)
                          + " not a valid line")
                    continue

                if self.isstat1:

                    if filtertuple is not None:
                        keytobefiltered = filtertuple[0]
                        relationship = filtertuple[1]
                        value = filtertuple[2]
                        datavalue = d[keytobefiltered]
                        if self.isfiltered(datavalue, relationship, value):
                            continue

                    statistics.passed_filtertuple += 1
                    [outdict[field].append(d[field]) for field in fields]

                elif self.isstat2:
                    # skip the stat2 header line and undo its counters
                    if line_number == 1:
                        statistics.total -= 1
                        statistics.passed_filterout -= 1
                        statistics.passed_get_every -= 1
                        continue

                    if filtertuple is not None:
                        keytobefiltered = filtertuple[0]
                        relationship = filtertuple[1]
                        value = filtertuple[2]
                        # stat2 stores values under numeric keys; map the
                        # human-readable name back through invstat2_dict
                        datavalue = d[self.invstat2_dict[keytobefiltered]]
                        if self.isfiltered(datavalue, relationship, value):
                            continue

                    statistics.passed_filtertuple += 1
                    [outdict[field].append(d[self.invstat2_dict[field]])
                     for field in fields]

            f.close()

        return outdict
1103 
1104  def isfiltered(self, datavalue, relationship, refvalue):
1105  dofilter = False
1106  try:
1107  _ = float(datavalue)
1108  except ValueError:
1109  raise ValueError("ProcessOutput.filter: datavalue cannot be "
1110  "converted into a float")
1111 
1112  if relationship == "<":
1113  if float(datavalue) >= refvalue:
1114  dofilter = True
1115  if relationship == ">":
1116  if float(datavalue) <= refvalue:
1117  dofilter = True
1118  if relationship == "==":
1119  if float(datavalue) != refvalue:
1120  dofilter = True
1121  return dofilter
1122 
1123 
1125  """ class to allow more advanced handling of RMF files.
1126  It is both a container and a IMP.atom.Hierarchy.
1127  - it is iterable (while loading the corresponding frame)
1128  - Item brackets [] load the corresponding frame
1129  - slice create an iterator
1130  - can relink to another RMF file
1131  """
1132  def __init__(self, model, rmf_file_name):
1133  """
1134  @param model: the IMP.Model()
1135  @param rmf_file_name: str, path of the rmf file
1136  """
1137  self.model = model
1138  try:
1139  self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1140  except TypeError:
1141  raise TypeError("Wrong rmf file name or type: %s"
1142  % str(rmf_file_name))
1143  hs = IMP.rmf.create_hierarchies(self.rh_ref, self.model)
1144  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(0))
1145  self.root_hier_ref = hs[0]
1146  IMP.atom.Hierarchy.__init__(self, self.root_hier_ref)
1147  self.model.update()
1148  self.ColorHierarchy = None
1149 
1150  def link_to_rmf(self, rmf_file_name):
1151  """
1152  Link to another RMF file
1153  """
1154  self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1155  IMP.rmf.link_hierarchies(self.rh_ref, [self])
1156  if self.ColorHierarchy:
1157  self.ColorHierarchy.method()
1158  RMFHierarchyHandler.set_frame(self, 0)
1159 
1160  def set_frame(self, index):
1161  try:
1162  IMP.rmf.load_frame(self.rh_ref, RMF.FrameID(index))
1163  except: # noqa: E722
1164  print("skipping frame %s:%d\n" % (self.current_rmf, index))
1165  self.model.update()
1166 
1167  def get_number_of_frames(self):
1168  return self.rh_ref.get_number_of_frames()
1169 
1170  def __getitem__(self, int_slice_adaptor):
1171  if isinstance(int_slice_adaptor, int):
1172  self.set_frame(int_slice_adaptor)
1173  return int_slice_adaptor
1174  elif isinstance(int_slice_adaptor, slice):
1175  return self.__iter__(int_slice_adaptor)
1176  else:
1177  raise TypeError("Unknown Type")
1178 
1179  def __len__(self):
1180  return self.get_number_of_frames()
1181 
1182  def __iter__(self, slice_key=None):
1183  if slice_key is None:
1184  for nframe in range(len(self)):
1185  yield self[nframe]
1186  else:
1187  for nframe in list(range(len(self)))[slice_key]:
1188  yield self[nframe]
1189 
1190 
1191 class CacheHierarchyCoordinates(object):
1192  def __init__(self, StatHierarchyHandler):
1193  self.xyzs = []
1194  self.nrms = []
1195  self.rbs = []
1196  self.nrm_coors = {}
1197  self.xyz_coors = {}
1198  self.rb_trans = {}
1199  self.current_index = None
1200  self.rmfh = StatHierarchyHandler
1201  rbs, xyzs = IMP.pmi.tools.get_rbs_and_beads([self.rmfh])
1202  self.model = self.rmfh.get_model()
1203  self.rbs = rbs
1204  for xyz in xyzs:
1206  nrm = IMP.core.NonRigidMember(xyz)
1207  self.nrms.append(nrm)
1208  else:
1209  fb = IMP.core.XYZ(xyz)
1210  self.xyzs.append(fb)
1211 
1212  def do_store(self, index):
1213  self.rb_trans[index] = {}
1214  self.nrm_coors[index] = {}
1215  self.xyz_coors[index] = {}
1216  for rb in self.rbs:
1217  self.rb_trans[index][rb] = rb.get_reference_frame()
1218  for nrm in self.nrms:
1219  self.nrm_coors[index][nrm] = nrm.get_internal_coordinates()
1220  for xyz in self.xyzs:
1221  self.xyz_coors[index][xyz] = xyz.get_coordinates()
1222  self.current_index = index
1223 
1224  def do_update(self, index):
1225  if self.current_index != index:
1226  for rb in self.rbs:
1227  rb.set_reference_frame(self.rb_trans[index][rb])
1228  for nrm in self.nrms:
1229  nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
1230  for xyz in self.xyzs:
1231  xyz.set_coordinates(self.xyz_coors[index][xyz])
1232  self.current_index = index
1233  self.model.update()
1234 
1235  def get_number_of_frames(self):
1236  return len(self.rb_trans.keys())
1237 
1238  def __getitem__(self, index):
1239  if isinstance(index, int):
1240  return index in self.rb_trans.keys()
1241  else:
1242  raise TypeError("Unknown Type")
1243 
1244  def __len__(self):
1245  return self.get_number_of_frames()
1246 
1247 
1249  """ class to link stat files to several rmf files """
    def __init__(self, model=None, stat_file=None,
                 number_best_scoring_models=None, score_key=None,
                 StatHierarchyHandler=None, cache=None):
        """
        @param model: IMP.Model()
        @param stat_file: either 1) a list or 2) a single stat file names
               (either rmfs or ascii, or pickled data or pickled cluster),
               3) a dictionary containing an rmf/ascii
               stat file name as key and a list of frames as values
        @param number_best_scoring_models: if set, only keep the N
               best-scoring models when reading each stat file
        @param score_key: stat-file key holding the score
               (default "Total_Score")
        @param StatHierarchyHandler: copy constructor input object
               (note: this parameter shadows the class name on purpose)
        @param cache: cache coordinates and rigid body transformations.
        """

        if StatHierarchyHandler is not None:
            # overrides all other arguments
            # copy constructor: create a copy with
            # different RMFHierarchyHandler
            self.model = StatHierarchyHandler.model
            self.data = StatHierarchyHandler.data
            self.number_best_scoring_models = \
                StatHierarchyHandler.number_best_scoring_models
            self.is_setup = True
            self.current_rmf = StatHierarchyHandler.current_rmf
            self.current_frame = None
            self.current_index = None
            self.score_threshold = StatHierarchyHandler.score_threshold
            self.score_key = StatHierarchyHandler.score_key
            self.cache = StatHierarchyHandler.cache
            RMFHierarchyHandler.__init__(self, self.model,
                                         self.current_rmf)
            # the cache itself is not shared: build a fresh one bound
            # to this copy's hierarchy
            if self.cache:
                self.cache = CacheHierarchyCoordinates(self)
            else:
                self.cache = None
            self.set_frame(0)

        else:
            # standard constructor
            self.model = model
            self.data = []
            self.number_best_scoring_models = number_best_scoring_models
            self.cache = cache

            if score_key is None:
                self.score_key = "Total_Score"
            else:
                self.score_key = score_key
            self.is_setup = None
            self.current_rmf = None
            self.current_frame = None
            self.current_index = None
            self.score_threshold = None

            # accept either a single file name or a list of them
            if isinstance(stat_file, str):
                self.add_stat_file(stat_file)
            elif isinstance(stat_file, list):
                for f in stat_file:
                    self.add_stat_file(f)
1310 
1311  def add_stat_file(self, stat_file):
1312  try:
1313  '''check that it is not a pickle file with saved data
1314  from a previous calculation'''
1315  self.load_data(stat_file)
1316 
1317  if self.number_best_scoring_models:
1318  scores = self.get_scores()
1319  max_score = sorted(scores)[
1320  0:min(len(self), self.number_best_scoring_models)][-1]
1321  self.do_filter_by_score(max_score)
1322 
1323  except pickle.UnpicklingError:
1324  '''alternatively read the ascii stat files'''
1325  try:
1326  scores, rmf_files, rmf_frame_indexes, features = \
1327  self.get_info_from_stat_file(stat_file,
1328  self.score_threshold)
1329  except (KeyError, SyntaxError):
1330  # in this case check that is it an rmf file, probably
1331  # without stat stored in
1332  try:
1333  # let's see if that is an rmf file
1334  rh = RMF.open_rmf_file_read_only(stat_file)
1335  nframes = rh.get_number_of_frames()
1336  scores = [0.0]*nframes
1337  rmf_files = [stat_file]*nframes
1338  rmf_frame_indexes = range(nframes)
1339  features = {}
1340  except: # noqa: E722
1341  return
1342 
1343  if len(set(rmf_files)) > 1:
1344  raise ("Multiple RMF files found")
1345 
1346  if not rmf_files:
1347  print("StatHierarchyHandler: Error: Trying to set none as "
1348  "rmf_file (probably empty stat file), aborting")
1349  return
1350 
1351  for n, index in enumerate(rmf_frame_indexes):
1352  featn_dict = dict([(k, features[k][n]) for k in features])
1353  self.data.append(IMP.pmi.output.DataEntry(
1354  stat_file, rmf_files[n], index, scores[n], featn_dict))
1355 
1356  if self.number_best_scoring_models:
1357  scores = self.get_scores()
1358  max_score = sorted(scores)[
1359  0:min(len(self), self.number_best_scoring_models)][-1]
1360  self.do_filter_by_score(max_score)
1361 
1362  if not self.is_setup:
1363  RMFHierarchyHandler.__init__(
1364  self, self.model, self.get_rmf_names()[0])
1365  if self.cache:
1366  self.cache = CacheHierarchyCoordinates(self)
1367  else:
1368  self.cache = None
1369  self.is_setup = True
1370  self.current_rmf = self.get_rmf_names()[0]
1371 
1372  self.set_frame(0)
1373 
1374  def save_data(self, filename='data.pkl'):
1375  with open(filename, 'wb') as fl:
1376  pickle.dump(self.data, fl)
1377 
1378  def load_data(self, filename='data.pkl'):
1379  with open(filename, 'rb') as fl:
1380  data_structure = pickle.load(fl)
1381  # first check that it is a list
1382  if not isinstance(data_structure, list):
1383  raise TypeError(
1384  "%filename should contain a list of IMP.pmi.output.DataEntry "
1385  "or IMP.pmi.output.Cluster" % filename)
1386  # second check the types
1387  if all(isinstance(item, IMP.pmi.output.DataEntry)
1388  for item in data_structure):
1389  self.data = data_structure
1390  elif all(isinstance(item, IMP.pmi.output.Cluster)
1391  for item in data_structure):
1392  nmodels = 0
1393  for cluster in data_structure:
1394  nmodels += len(cluster)
1395  self.data = [None]*nmodels
1396  for cluster in data_structure:
1397  for n, data in enumerate(cluster):
1398  index = cluster.members[n]
1399  self.data[index] = data
1400  else:
1401  raise TypeError(
1402  "%filename should contain a list of IMP.pmi.output.DataEntry "
1403  "or IMP.pmi.output.Cluster" % filename)
1404 
    def set_frame(self, index):
        """Load model `index` into the hierarchy, either from the
        in-memory cache or from the corresponding RMF file/frame."""
        # fast path: coordinates for this index are already cached
        if self.cache is not None and self.cache[index]:
            self.cache.do_update(index)
        else:
            nm = self.data[index].rmf_name
            fidx = self.data[index].rmf_index
            # relink only when the frame lives in a different RMF file;
            # reset current_frame so the load below is forced
            if nm != self.current_rmf:
                self.link_to_rmf(nm)
                self.current_rmf = nm
                self.current_frame = -1
            if fidx != self.current_frame:
                RMFHierarchyHandler.set_frame(self, fidx)
                self.current_frame = fidx
            # remember these coordinates for future fast lookups
            if self.cache is not None:
                self.cache.do_store(index)

        self.current_index = index
1422 
1423  def __getitem__(self, int_slice_adaptor):
1424  if isinstance(int_slice_adaptor, int):
1425  self.set_frame(int_slice_adaptor)
1426  return self.data[int_slice_adaptor]
1427  elif isinstance(int_slice_adaptor, slice):
1428  return self.__iter__(int_slice_adaptor)
1429  else:
1430  raise TypeError("Unknown Type")
1431 
1432  def __len__(self):
1433  return len(self.data)
1434 
1435  def __iter__(self, slice_key=None):
1436  if slice_key is None:
1437  for i in range(len(self)):
1438  yield self[i]
1439  else:
1440  for i in range(len(self))[slice_key]:
1441  yield self[i]
1442 
1443  def do_filter_by_score(self, maximum_score):
1444  self.data = [d for d in self.data if d.score <= maximum_score]
1445 
1446  def get_scores(self):
1447  return [d.score for d in self.data]
1448 
1449  def get_feature_series(self, feature_name):
1450  return [d.features[feature_name] for d in self.data]
1451 
1452  def get_feature_names(self):
1453  return self.data[0].features.keys()
1454 
1455  def get_rmf_names(self):
1456  return [d.rmf_name for d in self.data]
1457 
1458  def get_stat_files_names(self):
1459  return [d.stat_file for d in self.data]
1460 
1461  def get_rmf_indexes(self):
1462  return [d.rmf_index for d in self.data]
1463 
1464  def get_info_from_stat_file(self, stat_file, score_threshold=None):
1465  po = ProcessOutput(stat_file)
1466  fs = po.get_keys()
1467  models = IMP.pmi.io.get_best_models(
1468  [stat_file], score_key=self.score_key, feature_keys=fs,
1469  rmf_file_key="rmf_file", rmf_file_frame_key="rmf_frame_index",
1470  prefiltervalue=score_threshold, get_every=1)
1471 
1472  scores = [float(y) for y in models[2]]
1473  rmf_files = models[0]
1474  rmf_frame_indexes = models[1]
1475  features = models[3]
1476  return scores, rmf_files, rmf_frame_indexes, features
1477 
1478 
class DataEntry(object):
    '''
    A class to store data associated to a model: the stat file it came
    from, the rmf file and frame index holding its coordinates, its
    score, and a dictionary of feature values.
    '''
    def __init__(self, stat_file=None, rmf_name=None, rmf_index=None,
                 score=None, features=None):
        self.rmf_name = rmf_name
        self.rmf_index = rmf_index
        self.score = score
        self.features = features
        self.stat_file = stat_file

    def __repr__(self):
        s = "IMP.pmi.output.DataEntry\n"
        s += "---- stat file %s \n" % (self.stat_file)
        s += "---- rmf file %s \n" % (self.rmf_name)
        s += "---- rmf index %s \n" % (str(self.rmf_index))
        s += "---- score %s \n" % (str(self.score))
        # bug fix: features defaults to None, which used to make repr()
        # raise AttributeError; report 0 features instead
        nfeatures = 0 if self.features is None else len(self.features)
        s += "---- number of features %s \n" % (str(nfeatures))
        return s
1499 
1500 
class Cluster(object):
    '''
    A container for models organized into clusters.

    Members are stored as global model indices; the payload (typically
    a DataEntry) for each member is kept in members_data, keyed by the
    member index.
    '''
    def __init__(self, cid=None):
        self.cluster_id = cid
        self.members = []
        self.precision = None
        self.center_index = None
        self.members_data = {}
        # bug fix: set here so repr()/compute_score() work before the
        # first add_member() call (previously AttributeError)
        self.average_score = None

    def add_member(self, index, data=None):
        """Add a model to the cluster and refresh the average score.

        @param index global index of the model
        @param data payload (typically a DataEntry) for the model
        """
        self.members.append(index)
        self.members_data[index] = data
        self.average_score = self.compute_score()

    def compute_score(self):
        """Return the mean member score, or None if it cannot be
        computed (empty cluster, or members without a score)."""
        try:
            # bug fix: an empty cluster used to raise ZeroDivisionError
            score = sum([d.score for d in self])/len(self)
        except (AttributeError, ZeroDivisionError):
            score = None
        return score

    def __repr__(self):
        s = "IMP.pmi.output.Cluster\n"
        s += "---- cluster_id %s \n" % str(self.cluster_id)
        s += "---- precision %s \n" % str(self.precision)
        s += "---- average score %s \n" % str(self.average_score)
        s += "---- number of members %s \n" % str(len(self.members))
        s += "---- center index %s \n" % str(self.center_index)
        return s

    def __getitem__(self, int_slice_adaptor):
        if isinstance(int_slice_adaptor, int):
            index = self.members[int_slice_adaptor]
            return self.members_data[index]
        elif isinstance(int_slice_adaptor, slice):
            return self.__iter__(int_slice_adaptor)
        else:
            raise TypeError("Unknown Type")

    def __len__(self):
        return len(self.members)

    def __iter__(self, slice_key=None):
        if slice_key is None:
            for i in range(len(self)):
                yield self[i]
        else:
            for i in range(len(self))[slice_key]:
                yield self[i]

    def __add__(self, other):
        # NOTE: merges `other` into self in place and returns self
        # (mutating __add__), resetting precision and center
        self.members += other.members
        self.members_data.update(other.members_data)
        self.average_score = self.compute_score()
        self.precision = None
        self.center_index = None
        return self
1560 
1561 
def plot_clusters_populations(clusters):
    """Bar-plot the number of members of each cluster.

    @param clusters list of Cluster objects
    """
    # single pass over clusters, collecting (id, population) pairs
    pairs = [(cluster.cluster_id, len(cluster)) for cluster in clusters]
    indexes = [p[0] for p in pairs]
    populations = [p[1] for p in pairs]

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, populations, 0.5, color='r')
    ax.set_ylabel('Population')
    ax.set_xlabel(('Cluster index'))
    plt.show()
1575 
1576 
def plot_clusters_precisions(clusters):
    """Bar-plot the precision of each cluster (None plotted as 0).

    @param clusters list of Cluster objects
    """
    indexes = []
    precisions = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        prec = cluster.precision
        print(cluster.cluster_id, prec)
        precisions.append(0.0 if prec is None else prec)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.bar(indexes, precisions, 0.5, color='r')
    ax.set_ylabel('Precision [A]')
    ax.set_xlabel(('Cluster index'))
    plt.show()
1595 
1596 
def plot_clusters_scores(clusters):
    """Box-plot the member scores of each cluster into scores.pdf.

    @param clusters list of Cluster objects
    """
    indexes = []
    values = []
    for cluster in clusters:
        indexes.append(cluster.cluster_id)
        values.append([member.score for member in cluster])

    plot_fields_box_plots("scores.pdf", values, indexes, frequencies=None,
                          valuename="Scores", positionname="Cluster index",
                          xlabels=None, scale_plot_length=1.0)
1609 
1610 
class CrossLinkIdentifierDatabase(object):
    """Store typed metadata (proteins, residues, scores, ...) for
    individual crosslinks, keyed by a unique crosslink identifier.

    Internally this is a dict of dicts: clidb[key][field_name].
    write()/load() persist the database with pickle.
    """
    def __init__(self):
        self.clidb = dict()

    def check_key(self, key):
        """Create an empty record for `key` if none exists yet."""
        if key not in self.clidb:
            self.clidb[key] = {}

    def _set_field(self, key, field, value):
        # shared implementation for all the typed setters below
        self.check_key(key)
        self.clidb[key][field] = value

    def set_unique_id(self, key, value):
        self._set_field(key, "XLUniqueID", str(value))

    def set_protein1(self, key, value):
        self._set_field(key, "Protein1", str(value))

    def set_protein2(self, key, value):
        self._set_field(key, "Protein2", str(value))

    def set_residue1(self, key, value):
        self._set_field(key, "Residue1", int(value))

    def set_residue2(self, key, value):
        self._set_field(key, "Residue2", int(value))

    def set_idscore(self, key, value):
        self._set_field(key, "IDScore", float(value))

    def set_state(self, key, value):
        self._set_field(key, "State", int(value))

    def set_sigma1(self, key, value):
        self._set_field(key, "Sigma1", str(value))

    def set_sigma2(self, key, value):
        self._set_field(key, "Sigma2", str(value))

    def set_psi(self, key, value):
        self._set_field(key, "Psi", str(value))

    def get_unique_id(self, key):
        return self.clidb[key]["XLUniqueID"]

    def get_protein1(self, key):
        return self.clidb[key]["Protein1"]

    def get_protein2(self, key):
        return self.clidb[key]["Protein2"]

    def get_residue1(self, key):
        return self.clidb[key]["Residue1"]

    def get_residue2(self, key):
        return self.clidb[key]["Residue2"]

    def get_idscore(self, key):
        return self.clidb[key]["IDScore"]

    def get_state(self, key):
        return self.clidb[key]["State"]

    def get_sigma1(self, key):
        return self.clidb[key]["Sigma1"]

    def get_sigma2(self, key):
        return self.clidb[key]["Sigma2"]

    def get_psi(self, key):
        return self.clidb[key]["Psi"]

    def set_float_feature(self, key, value, feature_name):
        """Store an arbitrary feature for `key`, coerced to float."""
        self._set_field(key, feature_name, float(value))

    def set_int_feature(self, key, value, feature_name):
        """Store an arbitrary feature for `key`, coerced to int."""
        self._set_field(key, feature_name, int(value))

    def set_string_feature(self, key, value, feature_name):
        """Store an arbitrary feature for `key`, coerced to str."""
        self._set_field(key, feature_name, str(value))

    def get_feature(self, key, feature_name):
        return self.clidb[key][feature_name]

    def write(self, filename):
        """Pickle the whole database to `filename`."""
        with open(filename, 'wb') as handle:
            pickle.dump(self.clidb, handle)

    def load(self, filename):
        """Replace the database with one pickled by write().

        NOTE: pickle.load must only be used on trusted files.
        """
        with open(filename, 'rb') as handle:
            self.clidb = pickle.load(handle)
1711 
1712 
def plot_fields(fields, output, framemin=None, framemax=None):
    """Plot the given fields and save a figure as `output`.

    The fields generally are extracted from a stat file
    using ProcessOutput.get_fields().

    @param fields dict mapping a field name to its frame-ordered series
    @param output file name for the saved figure
    @param framemin first frame to plot (default: 0)
    @param framemax last frame to plot, exclusive (default: the length
           of each series)
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    plt.rc('lines', linewidth=4)
    fig, axs = plt.subplots(nrows=len(fields))
    fig.set_size_inches(10.5, 5.5 * len(fields))
    plt.rc('axes')

    for n, key in enumerate(fields):
        # bug fix: the defaults were previously computed once from the
        # first field and then reused for every other field, truncating
        # or over-running series of different lengths; compute them per
        # field without clobbering the parameters
        fmin = 0 if framemin is None else framemin
        fmax = len(fields[key]) if framemax is None else framemax
        x = list(range(fmin, fmax))
        y = [float(v) for v in fields[key][fmin:fmax]]
        # with a single subplot plt.subplots returns a bare Axes rather
        # than an array of them
        ax = axs[n] if len(fields) > 1 else axs
        ax.plot(x, y)
        ax.set_title(key, size="xx-large")
        ax.tick_params(labelsize=18, pad=10)

    # Tweak spacing between subplots to prevent labels from overlapping
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(output)
1747 
1748 
def plot_field_histogram(name, values_lists, valuename=None, bins=40,
                         colors=None, format="png", reference_xline=None,
                         yplotrange=None, xplotrange=None, normalized=True,
                         leg_names=None):
    '''Plot a list of histograms from a value list.
    @param name the name of the plot (also the output file name prefix)
    @param values_lists the list of list of values eg: [[...],[...],[...]]
    @param valuename the x-label (defaults to `name`)
    @param bins the number of bins
    @param colors If None, will use rainbow. Else will use specific list
    @param format output format
    @param reference_xline plot a reference line parallel to the y-axis
    @param yplotrange the range for the y-axis
    @param xplotrange the range for the x-axis
    @param normalized whether the histogram is normalized or not
    @param leg_names names for the legend
    '''

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    plt.figure(figsize=(18.0, 9.0))

    if colors is None:
        colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
    for nv, values in enumerate(values_lists):
        col = colors[nv]
        if leg_names is not None:
            label = leg_names[nv]
        else:
            label = str(nv)
        plt.hist(
            [float(y) for y in values], bins=bins, color=col,
            density=normalized, histtype='step', lw=4, label=label)

    plt.tick_params(labelsize=12, pad=10)
    if valuename is None:
        plt.xlabel(name, size="xx-large")
    else:
        plt.xlabel(valuename, size="xx-large")
    plt.ylabel("Frequency", size="xx-large")

    if yplotrange is not None:
        # bug fix: plt.ylim() was called with no arguments, silently
        # ignoring the requested y range
        plt.ylim(yplotrange)
    if xplotrange is not None:
        plt.xlim(xplotrange)

    plt.legend(loc=2)

    if reference_xline is not None:
        plt.axvline(
            reference_xline,
            color='red',
            linestyle='dashed',
            linewidth=1)

    plt.savefig(name + "." + format, dpi=150, transparent=True)
1807 
1808 
def plot_fields_box_plots(name, values, positions, frequencies=None,
                          valuename="None", positionname="None",
                          xlabels=None, scale_plot_length=1.0):
    '''
    Plot time series as box plots and save them to "<name>.pdf".

    @param name output file name prefix (".pdf" is appended) and
           window title
    @param values list of per-position value lists
    @param positions x coordinate of each box
    @param frequencies if not None, overlay the raw data points
    @param valuename y-axis label
    @param positionname x-axis label
    @param xlabels optional custom tick labels
    @param scale_plot_length horizontal scaling of the figure
    '''
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(float(len(positions))*scale_plot_length, 5.0))
    fig.canvas.set_window_title(name)
    ax1 = fig.add_subplot(111)
    plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)

    bps = []
    bps.append(plt.boxplot(values, notch=0, sym='', vert=1,
                           whis=1.5, positions=positions))
    plt.setp(bps[-1]['boxes'], color='black', lw=1.5)
    plt.setp(bps[-1]['whiskers'], color='black', ls=":", lw=1.5)

    # overlay the individual data points when requested
    if frequencies is not None:
        for n, v in enumerate(values):
            plist = [positions[n]]*len(v)
            ax1.plot(plist, v, 'gx', alpha=0.7, markersize=7)

    if xlabels is not None:
        ax1.set_xticklabels(xlabels)
    plt.xticks(rotation=90)
    plt.xlabel(positionname)
    plt.ylabel(valuename)

    plt.savefig(name + ".pdf", dpi=150)
    plt.show()
1850 
1851 
def plot_xy_data(x, y, title=None, out_fn=None, display=True,
                 set_plot_yaxis_range=None, xlabel=None, ylabel=None):
    """Plot y against x as a red line; optionally save and/or display.

    @param set_plot_yaxis_range (ymin, ymax) pair for the y axis
    @param out_fn if given, save the figure to "<out_fn>.pdf"
    @param display if True, show the plot interactively
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    plt.rc('lines', linewidth=2)

    fig, ax = plt.subplots(nrows=1)
    fig.set_size_inches(8, 4.5)
    if title is not None:
        fig.canvas.set_window_title(title)
        ax.set_title(title)

    ax.plot(x, y, color='r')
    if set_plot_yaxis_range is not None:
        # keep the automatic x range, override only the y range
        x1, x2, _, _ = plt.axis()
        plt.axis((x1, x2,
                  set_plot_yaxis_range[0], set_plot_yaxis_range[1]))
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if out_fn is not None:
        plt.savefig(out_fn + ".pdf")
    if display:
        plt.show()
    plt.close(fig)
1881 
1882 
def plot_scatter_xy_data(x, y, labelx="None", labely="None",
                         xmin=None, xmax=None, ymin=None, ymax=None,
                         savefile=False, filename="None.eps", alpha=0.75):
    """Scatter-plot y against x, optionally fixing the axis ranges and
    saving the figure.

    @param savefile if True, save the figure to `filename`
    @param alpha transparency of the markers
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

    fig, axs = plt.subplots(1)
    axs0 = axs
    axs0.set_xlabel(labelx, size="xx-large")
    axs0.set_ylabel(labely, size="xx-large")
    axs0.tick_params(labelsize=18, pad=10)

    axs0.plot(x, y, 'o', color='k', lw=2, ms=0.1, alpha=alpha, c="w")

    axs0.legend(loc=0, frameon=False, scatterpoints=1, numpoints=1,
                columnspacing=1)

    fig.set_size_inches(8.0, 8.0)
    fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
    # axis ranges are applied only when both ends are given
    if (ymin is not None) and (ymax is not None):
        axs0.set_ylim(ymin, ymax)
    if (xmin is not None) and (xmax is not None):
        axs0.set_xlim(xmin, xmax)

    if savefile:
        fig.savefig(filename, dpi=300)
1922 
1923 
def get_graph_from_hierarchy(hier):
    """Build the parent/child edge graph of an IMP hierarchy and draw
    it, labeling only the shallow nodes.

    @param hier root of the IMP hierarchy
    """
    (graph, depth, depth_dict) = recursive_graph(hier, [], 0, {})

    # label only nodes within the first two levels, to keep the
    # drawing readable
    node_labels_dict = {}
    for key, key_depth in depth_dict.items():
        node_labels_dict[key] = key if key_depth < 3 else ""
    draw_graph(graph, labels_dict=node_labels_dict)
1939 
1940 
def recursive_graph(hier, graph, depth, depth_dict):
    """Recursively accumulate (parent, child) name pairs from an IMP
    hierarchy.

    Node names are "<hierarchy name>|#<particle index>".  Nodes with
    exactly one child stop the recursion (no edges are added below
    them), matching the original drawing behavior.

    @param hier current hierarchy node
    @param graph list of (parent_name, child_name) edges, appended to
    @param depth current recursion depth
    @param depth_dict mapping node name -> depth, filled in
    @return the tuple (graph, depth, depth_dict)
    """
    depth = depth + 1
    nameh = IMP.atom.Hierarchy(hier).get_name()
    index = str(hier.get_particle().get_index())
    name1 = nameh + "|#" + index
    depth_dict[name1] = depth

    children = IMP.atom.Hierarchy(hier).get_children()

    # bug fix: test for None *before* calling len(); the original order
    # would raise TypeError if get_children() ever returned None
    if children is None or len(children) == 1:
        depth = depth - 1
        return (graph, depth, depth_dict)

    else:
        for c in children:
            (graph, depth, depth_dict) = recursive_graph(
                c, graph, depth, depth_dict)
            nameh = IMP.atom.Hierarchy(c).get_name()
            index = str(c.get_particle().get_index())
            namec = nameh + "|#" + index
            graph.append((name1, namec))

        depth = depth - 1
        return (graph, depth, depth_dict)
1965 
1966 
def draw_graph(graph, labels_dict=None, graph_layout='spring',
               node_size=5, node_color=None, node_alpha=0.3,
               node_text_size=11, fixed=None, pos=None,
               edge_color='blue', edge_alpha=0.3, edge_thickness=1,
               edge_text_pos=0.3,
               validation_edges=None,
               text_font='sans-serif',
               out_filename=None):
    """Draw an edge-list graph with networkx/matplotlib and dump it to GML.

    @param graph            list of (node, node) edge tuples
    @param labels_dict      optional node -> label map for drawing
    @param graph_layout     'spring', 'spectral', 'random', else shell
    @param node_size        scalar size, or dict node -> area (converted
                            to a radius-like size per node)
    @param node_color       optional dict node -> color name (translated
                            via IMP.pmi.tools.ColorChange); black if None
    @param fixed, pos       passed through to nx.spring_layout
    @param edge_thickness   scalar width, or list of per-edge weights
    @param validation_edges optional edges to cross-check against `graph`:
                            found edges are colored green, missing ones are
                            added in red
    @param out_filename     if given, save the figure there; a GML file is
                            always written to 'out.gml'
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import networkx as nx
    import matplotlib.pyplot as plt
    from math import sqrt, pi

    # create networkx graph
    G = nx.Graph()

    # add edges; attach per-edge weights when a list of thicknesses is given
    if isinstance(edge_thickness, list):
        for edge, weight in zip(graph, edge_thickness):
            G.add_edge(edge[0], edge[1], weight=weight)
    else:
        for edge in graph:
            G.add_edge(edge[0], edge[1])

    if node_color is None:
        node_color_rgb = (0, 0, 0)
        node_color_hex = "000000"
    else:
        # Translate per-node color codes into RGB tuples (for drawing)
        # and hex strings (for the GML attributes)
        cc = IMP.pmi.tools.ColorChange()
        tmpcolor_rgb = []
        tmpcolor_hex = []
        for node in G.nodes():
            cctuple = cc.rgb(node_color[node])
            tmpcolor_rgb.append((cctuple[0]/255,
                                 cctuple[1]/255,
                                 cctuple[2]/255))
            tmpcolor_hex.append(node_color[node])
        node_color_rgb = tmpcolor_rgb
        node_color_hex = tmpcolor_hex

    # get node sizes if dictionary
    if isinstance(node_size, dict):
        tmpsize = []
        for node in G.nodes():
            size = sqrt(node_size[node])/pi*10.0
            tmpsize.append(size)
        node_size = tmpsize

    for n, node in enumerate(G.nodes()):
        # node_color_hex and node_size may be a single default value or a
        # per-node list; indexing the scalar defaults would be wrong
        # (single characters of "000000") or raise TypeError (int).
        if isinstance(node_color_hex, str):
            color = node_color_hex
        else:
            color = node_color_hex[n]
        size = node_size[n] if isinstance(node_size, list) else node_size
        nx.set_node_attributes(
            G, "graphics",
            {node: {'type': 'ellipse', 'w': size, 'h': size,
                    'fill': '#' + color, 'label': node}})
        nx.set_node_attributes(
            G, "LabelGraphics",
            {node: {'type': 'text', 'text': node, 'color': '#000000',
                    'visible': 'true'}})

    for edge in G.edges():
        nx.set_edge_attributes(
            G, "graphics",
            {edge: {'width': 1, 'fill': '#000000'}})

    # Mark validation edges: green if present (either direction), red and
    # newly added if absent. Guard against the default of None.
    for ve in (validation_edges or []):
        print(ve)
        if (ve[0], ve[1]) in G.edges():
            print("found forward")
            nx.set_edge_attributes(
                G, "graphics",
                {ve: {'width': 1, 'fill': '#00FF00'}})
        elif (ve[1], ve[0]) in G.edges():
            print("found backward")
            nx.set_edge_attributes(
                G, "graphics",
                {(ve[1], ve[0]): {'width': 1, 'fill': '#00FF00'}})
        else:
            G.add_edge(ve[0], ve[1])
            print("not found")
            nx.set_edge_attributes(
                G, "graphics",
                {ve: {'width': 1, 'fill': '#FF0000'}})

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        print(fixed, pos)
        graph_pos = nx.spring_layout(G, k=1.0/8.0, fixed=fixed, pos=pos)
    elif graph_layout == 'spectral':
        graph_pos = nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos = nx.random_layout(G)
    else:
        graph_pos = nx.shell_layout(G)

    # draw graph
    nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
                           alpha=node_alpha, node_color=node_color_rgb,
                           linewidths=0)
    nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
                           alpha=edge_alpha, edge_color=edge_color)
    nx.draw_networkx_labels(
        G, graph_pos, labels=labels_dict, font_size=node_text_size,
        font_family=text_font)
    if out_filename:
        plt.savefig(out_filename)
        nx.write_gml(G, 'out.gml')
    plt.show()
2078 
2079 
def draw_table():
    """Render an example HTML table with d3js via ipyD3 inside IPython.

    This is still an example: the data, row and column labels are
    hard-coded. Requires the third-party ipyD3 package and a running
    IPython display hook.
    """
    from ipyD3 import d3object
    from IPython.display import display

    d3 = d3object(width=800,
                  height=400,
                  style='JFTable',
                  number=1,
                  d3=None,
                  title='Example table with d3js',
                  desc='An example table created with d3js with '
                       'data generated with Python.')
    # 8 products x 12 months of example values
    data = [[1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
             4441.0, 4670.0, 944.0, 110.0],
            [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
             7078.0, 6561.0, 2354.0, 710.0],
            [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
             5661.0, 4991.0, 2032.0, 680.0],
            [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
             4727.0, 3428.0, 1559.0, 620.0],
            [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
             9552.0, 13709.0, 2460.0, 720.0],
            [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
             32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
            [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
             5115.0, 3510.0, 1390.0, 170.0],
            [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
             360.0, 156.0, 100.0]]
    # transpose so rows are months, columns are products
    data = [list(i) for i in zip(*data)]
    sRows = [['January',
              'February',
              'March',
              'April',
              'May',
              'June',
              'July',
              'August',
              'September',
              'October',
              'November',
              'December']]
    sColumns = [['Prod {0}'.format(i) for i in range(1, 9)],
                [None, '', None, None, 'Group 1', None, None, 'Group 2']]
    d3.addSimpleTable(data,
                      fontSizeCells=[12, ],
                      sRows=sRows,
                      sColumns=sColumns,
                      sRowsMargins=[5, 50, 0],
                      sColsMargins=[5, 20, 10],
                      spacing=0,
                      addBorders=1,
                      addOutsideBorders=-1,
                      rectWidth=45,
                      rectHeight=0
                      )
    html = d3.render(mode=['html', 'show'])
    display(html)
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: Residue.h:156
A container for models organized into clusters.
Definition: output.py:1501
A class for reading stat files (either rmf or ascii v1 and v2)
Definition: output.py:926
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: atom/Atom.h:241
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:1749
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:1809
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
Miscellaneous utilities.
Definition: tools.py:1
A class to store data associated to a model.
Definition: output.py:1479
void handle_use_deprecated(std::string message)
Break in this method in gdb to find deprecated uses at runtime.
std::string get_module_version()
Return the version of this module, as a string.
Change color code to hexadecimal to rgb.
Definition: tools.py:805
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
Definition: output.py:915
def get_prot_name_from_particle
Return the component name provided a particle and a list of names.
Definition: tools.py:545
def get_fields
Get the desired field names, and return a dictionary.
Definition: output.py:991
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
Definition: Fragment.h:46
def link_to_rmf
Link to another RMF file.
Definition: output.py:1150
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
Definition: pmi/utilities.h:85
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
Definition: output.py:259
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Definition: atom/Atom.h:234
Base class for capturing a modeling protocol.
Definition: output.py:45
The type for a residue.
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
A base class for Keys.
Definition: Key.h:44
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
Definition: output.py:200
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Definition: rigid_bodies.h:750
bool get_is_canonical(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and check if the root is named System.
Definition: pmi/utilities.h:91
Display a segment connecting a pair of particles.
Definition: XYZR.h:170
A decorator for a residue.
Definition: Residue.h:135
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
Definition: output.py:228
def get_prot_name_from_particle
Get the protein name from the particle.
Definition: output.py:345
class to link stat files to several rmf files
Definition: output.py:1248
class to allow more advanced handling of RMF files.
Definition: output.py:1124
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
Definition: output.py:1713
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Definition: Chain.h:61
Python classes to represent, score, sample and analyze models.
Functionality for loading, creating, manipulating and scoring atomic structures.
def get_rbs_and_beads
Returns unique objects in original order.
Definition: tools.py:1207
Hierarchies get_leaves(const Selection &h)
Select hierarchy particles identified by the biological name.
Definition: Selection.h:66
def init_rmf
Initialize an RMF file.
Definition: output.py:556
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: tools.py:565
static bool get_is_setup(const IMP::ParticleAdaptor &p)
Definition: rigid_bodies.h:752
std::string get_module_version()
Return the version of this module, as a string.
def sublist_iterator
Yield all sublists of length >= lmin and <= lmax.
Definition: tools.py:640
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27