1 """@namespace IMP.pmi.output
2 Classes for writing output files and processing them.
26 """Map indices to multi-character chain IDs.
27 We label the first 26 chains A-Z, then we move to two-letter
28 chain IDs: AA through AZ, then BA through BZ, through to ZZ.
29 This continues with longer chain IDs."""
30 def __getitem__(self, ind):
31 chars = string.ascii_uppercase
35 ids.append(chars[ind % lc])
37 ids.append(chars[ind])
38 return "".join(reversed(ids))
42 """Base class for capturing a modeling protocol.
43 Unlike simple output of model coordinates, a complete
44 protocol includes the input data used, details on the restraints,
45 sampling, and clustering, as well as output models.
46 Use via IMP.pmi.topology.System.add_protocol_output().
48 @see IMP.pmi.mmcif.ProtocolOutput for a concrete subclass that outputs
56 if isinstance(elt, (tuple, list)):
57 for elt2
in _flatten(elt):
63 def _disambiguate_chain(chid, seen_chains):
64 """Make sure that the chain ID is unique; warn and correct if it isn't"""
69 if chid
in seen_chains:
70 warnings.warn(
"Duplicate chain ID '%s' encountered" % chid,
73 for suffix
in itertools.count(1):
74 new_chid = chid +
"%d" % suffix
75 if new_chid
not in seen_chains:
76 seen_chains.add(new_chid)
82 def _write_pdb_internal(flpdb, particle_infos_for_pdb, geometric_center,
83 write_all_residues_per_bead):
84 for n, tupl
in enumerate(particle_infos_for_pdb):
85 (xyz, atom_type, residue_type,
86 chain_id, residue_index, all_indexes, radius) = tupl
88 atom_type = IMP.atom.AT_CA
89 if write_all_residues_per_bead
and all_indexes
is not None:
90 for residue_number
in all_indexes:
92 IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
93 xyz[1] - geometric_center[1],
94 xyz[2] - geometric_center[2]),
95 n+1, atom_type, residue_type,
96 chain_id[:1], residue_number,
' ',
100 IMP.atom.get_pdb_string((xyz[0] - geometric_center[0],
101 xyz[1] - geometric_center[1],
102 xyz[2] - geometric_center[2]),
103 n+1, atom_type, residue_type,
104 chain_id[:1], residue_index,
' ',
106 flpdb.write(
"ENDMDL\n")
109 _Entity = collections.namedtuple(
'_Entity', (
'id',
'seq'))
110 _ChainInfo = collections.namedtuple(
'_ChainInfo', (
'entity',
'name'))
113 def _get_chain_info(chains, root_hier):
117 for mol
in IMP.atom.get_by_type(root_hier, IMP.atom.MOLECULE_TYPE):
119 chain_id = chains[molname]
121 seq = chain.get_sequence()
122 if seq
not in entities:
123 entities[seq] = e = _Entity(id=len(entities)+1, seq=seq)
124 all_entities.append(e)
125 entity = entities[seq]
126 info = _ChainInfo(entity=entity, name=molname)
127 chain_info[chain_id] = info
128 return chain_info, all_entities
131 def _write_mmcif_internal(flpdb, particle_infos_for_pdb, geometric_center,
132 write_all_residues_per_bead, chains, root_hier):
134 chain_info, entities = _get_chain_info(chains, root_hier)
136 writer = ihm.format.CifWriter(flpdb)
137 writer.start_block(
'model')
138 with writer.category(
"_entry")
as lp:
141 with writer.loop(
"_entity", [
"id",
"type"])
as lp:
143 lp.write(id=e.id, type=
"polymer")
145 with writer.loop(
"_entity_poly",
146 [
"entity_id",
"pdbx_seq_one_letter_code"])
as lp:
148 lp.write(entity_id=e.id, pdbx_seq_one_letter_code=e.seq)
150 with writer.loop(
"_struct_asym", [
"id",
"entity_id",
"details"])
as lp:
152 for chid
in sorted(chains.values(), key=
lambda x: (len(x.strip()), x)):
153 ci = chain_info[chid]
154 lp.write(id=chid, entity_id=ci.entity.id, details=ci.name)
156 with writer.loop(
"_atom_site",
157 [
"group_PDB",
"type_symbol",
"label_atom_id",
158 "label_comp_id",
"label_asym_id",
"label_seq_id",
160 "Cartn_x",
"Cartn_y",
"Cartn_z",
"label_entity_id",
161 "pdbx_pdb_model_num",
164 for n, tupl
in enumerate(particle_infos_for_pdb):
165 (xyz, atom_type, residue_type,
166 chain_id, residue_index, all_indexes, radius) = tupl
167 ci = chain_info[chain_id]
168 if atom_type
is None:
169 atom_type = IMP.atom.AT_CA
170 c = (xyz[0] - geometric_center[0],
171 xyz[1] - geometric_center[1],
172 xyz[2] - geometric_center[2])
173 if write_all_residues_per_bead
and all_indexes
is not None:
174 for residue_number
in all_indexes:
175 lp.write(group_PDB=
'ATOM',
177 label_atom_id=atom_type.get_string(),
178 label_comp_id=residue_type.get_string(),
179 label_asym_id=chain_id,
180 label_seq_id=residue_index,
181 auth_seq_id=residue_index, Cartn_x=c[0],
182 Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
183 pdbx_pdb_model_num=1,
184 label_entity_id=ci.entity.id)
187 lp.write(group_PDB=
'ATOM', type_symbol=
'C',
188 label_atom_id=atom_type.get_string(),
189 label_comp_id=residue_type.get_string(),
190 label_asym_id=chain_id,
191 label_seq_id=residue_index,
192 auth_seq_id=residue_index, Cartn_x=c[0],
193 Cartn_y=c[1], Cartn_z=c[2], id=ordinal,
194 pdbx_pdb_model_num=1,
195 label_entity_id=ci.entity.id)
200 """Class for easy writing of PDBs, RMFs, and stat files
202 @note Model should be updated prior to writing outputs.
204 def __init__(self, ascii=True, atomistic=False):
205 self.dictionary_pdbs = {}
207 self.dictionary_rmfs = {}
208 self.dictionary_stats = {}
209 self.dictionary_stats2 = {}
210 self.best_score_list =
None
211 self.nbestscoring =
None
213 self.replica_exchange =
False
218 self.chainids =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
219 "abcdefghijklmnopqrstuvwxyz0123456789"
221 self.multi_chainids = _ChainIDs()
223 self.particle_infos_for_pdb = {}
224 self.atomistic = atomistic
227 """Get a list of all PDB files being output by this instance"""
228 return list(self.dictionary_pdbs.keys())
230 def get_rmf_names(self):
231 return list(self.dictionary_rmfs.keys())
233 def get_stat_names(self):
234 return list(self.dictionary_stats.keys())
236 def _init_dictchain(self, name, prot, multichar_chain=False, mmcif=False):
237 self.dictchain[name] = {}
241 self.atomistic =
True
242 for n, mol
in enumerate(IMP.atom.get_by_type(
243 prot, IMP.atom.MOLECULE_TYPE)):
245 if not mmcif
and len(chid) > 1:
247 "The system contains at least one chain ID (%s) that "
248 "is more than 1 character long; this cannot be "
249 "represented in PDB. Either write mmCIF files "
250 "instead, or assign 1-character IDs to all chains "
251 "(this can be done with the `chain_ids` argument to "
252 "BuildSystem.add_state())." % chid)
253 chid = _disambiguate_chain(chid, seen_chains)
255 self.dictchain[name][molname] = chid
259 @param name The PDB filename
260 @param prot The hierarchy to write to this pdb file
261 @param mmcif If True, write PDBs in mmCIF format
262 @note if the PDB name is 'System' then will use Selection
265 flpdb = open(name,
'w')
267 self.dictionary_pdbs[name] = prot
268 self._pdb_mmcif[name] = mmcif
269 self._init_dictchain(name, prot, mmcif=mmcif)
271 def write_psf(self, filename, name):
272 flpsf = open(filename,
'w')
273 flpsf.write(
"PSF CMAP CHEQ" +
"\n")
274 index_residue_pair_list = {}
275 (particle_infos_for_pdb, geometric_center) = \
276 self.get_particle_infos_for_pdb_writing(name)
277 nparticles = len(particle_infos_for_pdb)
278 flpsf.write(str(nparticles) +
" !NATOM" +
"\n")
279 for n, p
in enumerate(particle_infos_for_pdb):
284 flpsf.write(
'{0:8d}{1:1s}{2:4s}{3:1s}{4:4s}{5:1s}{6:4s}{7:1s}'
285 '{8:4s}{9:1s}{10:4s}{11:14.6f}{12:14.6f}{13:8d}'
286 '{14:14.6f}{15:14.6f}'.format(
287 atom_index,
" ", chain,
" ", str(resid),
" ",
288 '"'+residue_type.get_string()+
'"',
" ",
"C",
289 " ",
"C", 1.0, 0.0, 0, 0.0, 0.0))
291 if chain
not in index_residue_pair_list:
292 index_residue_pair_list[chain] = [(atom_index, resid)]
294 index_residue_pair_list[chain].append((atom_index, resid))
298 for chain
in sorted(index_residue_pair_list.keys()):
300 ls = index_residue_pair_list[chain]
302 ls = sorted(ls, key=
lambda tup: tup[1])
304 indexes = [x[0]
for x
in ls]
307 indexes, lmin=2, lmax=2))
308 nbonds = len(indexes_pairs)
309 flpsf.write(str(nbonds)+
" !NBOND: bonds"+
"\n")
312 for i
in range(0, len(indexes_pairs), 4):
313 for bond
in indexes_pairs[i:i+4]:
314 flpsf.write(
'{0:8d}{1:8d}'.format(*bond))
317 del particle_infos_for_pdb
320 def write_pdb(self, name, appendmode=True,
321 translate_to_geometric_center=
False,
322 write_all_residues_per_bead=
False):
324 (particle_infos_for_pdb,
325 geometric_center) = self.get_particle_infos_for_pdb_writing(name)
327 if not translate_to_geometric_center:
328 geometric_center = (0, 0, 0)
330 filemode =
'a' if appendmode
else 'w'
331 with open(name, filemode)
as flpdb:
332 if self._pdb_mmcif[name]:
333 _write_mmcif_internal(flpdb, particle_infos_for_pdb,
335 write_all_residues_per_bead,
336 self.dictchain[name],
337 self.dictionary_pdbs[name])
339 _write_pdb_internal(flpdb, particle_infos_for_pdb,
341 write_all_residues_per_bead)
344 """Get the protein name from the particle.
345 This is done by traversing the hierarchy."""
348 def get_particle_infos_for_pdb_writing(self, name):
358 particle_infos_for_pdb = []
360 geometric_center = [0, 0, 0]
365 and self.dictionary_pdbs[name].get_number_of_children() == 0):
369 ps = sel.get_selected_particles()
371 for n, p
in enumerate(ps):
374 if protname
not in resindexes_dict:
375 resindexes_dict[protname] = []
379 rt = residue.get_residue_type()
380 resind = residue.get_index()
384 geometric_center[0] += xyz[0]
385 geometric_center[1] += xyz[1]
386 geometric_center[2] += xyz[2]
388 particle_infos_for_pdb.append(
389 (xyz, atomtype, rt, self.dictchain[name][protname],
390 resind,
None, radius))
391 resindexes_dict[protname].append(resind)
396 resind = residue.get_index()
399 if resind
in resindexes_dict[protname]:
402 resindexes_dict[protname].append(resind)
403 rt = residue.get_residue_type()
406 geometric_center[0] += xyz[0]
407 geometric_center[1] += xyz[1]
408 geometric_center[2] += xyz[2]
410 particle_infos_for_pdb.append(
411 (xyz,
None, rt, self.dictchain[name][protname], resind,
416 resind = resindexes[len(resindexes) // 2]
417 if resind
in resindexes_dict[protname]:
420 resindexes_dict[protname].append(resind)
424 geometric_center[0] += xyz[0]
425 geometric_center[1] += xyz[1]
426 geometric_center[2] += xyz[2]
428 particle_infos_for_pdb.append(
429 (xyz,
None, rt, self.dictchain[name][protname], resind,
436 if len(resindexes) > 0:
437 resind = resindexes[len(resindexes) // 2]
440 geometric_center[0] += xyz[0]
441 geometric_center[1] += xyz[1]
442 geometric_center[2] += xyz[2]
444 particle_infos_for_pdb.append(
445 (xyz,
None, rt, self.dictchain[name][protname],
446 resind, resindexes, radius))
449 geometric_center = (geometric_center[0] / atom_count,
450 geometric_center[1] / atom_count,
451 geometric_center[2] / atom_count)
455 particle_infos_for_pdb = sorted(particle_infos_for_pdb,
456 key=
lambda x: (len(x[3]), x[3], x[4]))
458 return (particle_infos_for_pdb, geometric_center)
460 def write_pdbs(self, appendmode=True, mmcif=False):
461 for pdb
in self.dictionary_pdbs.keys():
462 self.write_pdb(pdb, appendmode)
465 replica_exchange=
False, mmcif=
False,
466 best_score_file=
'best.scores.rex.py'):
467 """Prepare for writing best-scoring PDBs (or mmCIFs) for a
470 @param prefix Initial part of each PDB filename (e.g. 'model').
471 @param prot The top-level Hierarchy to output.
472 @param nbestscoring The number of best-scoring files to output.
473 @param replica_exchange Whether to combine best scores from a
474 replica exchange run.
475 @param mmcif If True, output models in mmCIF format. If False
476 (the default) output in legacy PDB format.
477 @param best_score_file The filename to use for replica
481 self._pdb_best_scoring_mmcif = mmcif
482 fileext =
'.cif' if mmcif
else '.pdb'
483 self.prefixes.append(prefix)
484 self.replica_exchange = replica_exchange
485 if not self.replica_exchange:
489 self.best_score_list = []
493 self.best_score_file_name = best_score_file
494 self.best_score_list = []
495 with open(self.best_score_file_name,
"w")
as best_score_file:
496 best_score_file.write(
497 "self.best_score_list=" + str(self.best_score_list) +
"\n")
499 self.nbestscoring = nbestscoring
500 for i
in range(self.nbestscoring):
501 name = prefix +
"." + str(i) + fileext
502 flpdb = open(name,
'w')
504 self.dictionary_pdbs[name] = prot
505 self._pdb_mmcif[name] = mmcif
506 self._init_dictchain(name, prot, mmcif=mmcif)
508 def write_pdb_best_scoring(self, score):
509 if self.nbestscoring
is None:
510 print(
"Output.write_pdb_best_scoring: init_pdb_best_scoring "
513 mmcif = self._pdb_best_scoring_mmcif
514 fileext =
'.cif' if mmcif
else '.pdb'
516 if self.replica_exchange:
518 with open(self.best_score_file_name)
as fh:
519 self.best_score_list = ast.literal_eval(
520 fh.read().split(
'=')[1])
522 if len(self.best_score_list) < self.nbestscoring:
523 self.best_score_list.append(score)
524 self.best_score_list.sort()
525 index = self.best_score_list.index(score)
526 for prefix
in self.prefixes:
527 for i
in range(len(self.best_score_list) - 2, index - 1, -1):
528 oldname = prefix +
"." + str(i) + fileext
529 newname = prefix +
"." + str(i + 1) + fileext
531 if os.path.exists(newname):
533 os.rename(oldname, newname)
534 filetoadd = prefix +
"." + str(index) + fileext
535 self.write_pdb(filetoadd, appendmode=
False)
538 if score < self.best_score_list[-1]:
539 self.best_score_list.append(score)
540 self.best_score_list.sort()
541 self.best_score_list.pop(-1)
542 index = self.best_score_list.index(score)
543 for prefix
in self.prefixes:
544 for i
in range(len(self.best_score_list) - 1,
546 oldname = prefix +
"." + str(i) + fileext
547 newname = prefix +
"." + str(i + 1) + fileext
548 os.rename(oldname, newname)
549 filenametoremove = prefix + \
550 "." + str(self.nbestscoring) + fileext
551 os.remove(filenametoremove)
552 filetoadd = prefix +
"." + str(index) + fileext
553 self.write_pdb(filetoadd, appendmode=
False)
555 if self.replica_exchange:
557 with open(self.best_score_file_name,
"w")
as best_score_file:
558 best_score_file.write(
559 "self.best_score_list=" + str(self.best_score_list) +
'\n')
561 def init_rmf(self, name, hierarchies, rs=None, geometries=None,
564 Initialize an RMF file
566 @param name the name of the RMF file
567 @param hierarchies the hierarchies to be included (it is a list)
568 @param rs optional, the restraint sets (it is a list)
569 @param geometries optional, the geometries (it is a list)
570 @param listofobjects optional, the list of objects for the stat
573 rh = RMF.create_rmf_file(name)
576 outputkey_rmfkey =
None
580 if geometries
is not None:
583 callable_objects = []
584 if listofobjects
is not None:
585 cat = rh.get_category(
"stat")
586 outputkey_rmfkey = {}
587 for o
in listofobjects:
588 if not hasattr(o,
"get_output"):
590 "Output: object %s doesn't have get_output() method"
594 output = o.get_output()
596 callable_objects.append(output)
597 output = output(
None)
599 dict_objects.append(o)
600 for outputkey
in output:
601 rmftag = RMF.string_tag
602 if isinstance(output[outputkey], float):
603 rmftag = RMF.float_tag
604 elif isinstance(output[outputkey], int):
606 elif isinstance(output[outputkey], str):
607 rmftag = RMF.string_tag
609 rmftag = RMF.string_tag
610 rmfkey = rh.get_key(cat, outputkey, rmftag)
611 outputkey_rmfkey[outputkey] = rmfkey
612 outputkey_rmfkey[
"rmf_file"] = \
613 rh.get_key(cat,
"rmf_file", RMF.string_tag)
614 outputkey_rmfkey[
"rmf_frame_index"] = \
615 rh.get_key(cat,
"rmf_frame_index", RMF.int_tag)
617 self.dictionary_rmfs[name] = (rh, cat, outputkey_rmfkey,
618 dict_objects, callable_objects)
620 def add_restraints_to_rmf(self, name, objectlist):
621 for o
in _flatten(objectlist):
623 rs = o.get_restraint_for_rmf()
624 if not isinstance(rs, (list, tuple)):
627 rs = [o.get_restraint()]
629 self.dictionary_rmfs[name][0], rs)
631 def add_geometries_to_rmf(self, name, objectlist):
633 geos = o.get_geometries()
636 def add_particle_pair_from_restraints_to_rmf(self, name, objectlist):
639 pps = o.get_particle_pairs()
642 self.dictionary_rmfs[name][0],
645 def write_rmf(self, name):
647 if self.dictionary_rmfs[name][1]
is not None:
648 outputkey_rmfkey = self.dictionary_rmfs[name][2]
649 dict_objects = self.dictionary_rmfs[name][3]
650 callable_objects = self.dictionary_rmfs[name][4]
653 for obj
in dict_objects:
654 yield obj.get_output()
655 for obj
in callable_objects:
658 for output
in all_output():
659 for outputkey
in output:
660 rmfkey = outputkey_rmfkey[outputkey]
662 n = self.dictionary_rmfs[name][0].get_root_node()
663 n.set_value(rmfkey, output[outputkey])
664 except NotImplementedError:
666 rmfkey = outputkey_rmfkey[
"rmf_file"]
667 self.dictionary_rmfs[name][0].get_root_node().set_value(
669 rmfkey = outputkey_rmfkey[
"rmf_frame_index"]
671 self.dictionary_rmfs[name][0].get_root_node().set_value(
673 self.dictionary_rmfs[name][0].flush()
675 def close_rmf(self, name):
676 rh = self.dictionary_rmfs[name][0]
677 del self.dictionary_rmfs[name]
680 def write_rmfs(self):
681 for rmfinfo
in self.dictionary_rmfs.keys():
682 self.write_rmf(rmfinfo[0])
685 def init_stat(self, name, listofobjects):
687 flstat = open(name,
'w')
690 flstat = open(name,
'wb')
694 for o
in listofobjects:
695 if not hasattr(o,
"get_output"):
697 "Output: object %s doesn't have get_output() method"
699 self.dictionary_stats[name] = listofobjects
701 def set_output_entry(self, key, value):
702 self.initoutput.update({key: value})
705 def write_stat(self, name, appendmode=True):
706 output = self.initoutput
707 for obj
in self.dictionary_stats[name]:
710 dfiltered = dict((k, v)
for k, v
in d.items()
if k[0] !=
"_")
711 output.update(dfiltered)
719 with open(name, writeflag)
as flstat:
720 flstat.write(
"%s \n" % output)
722 with open(name, writeflag +
'b')
as flstat:
723 pickle.dump(output, flstat, 2)
726 def write_stats(self):
727 for stat
in self.dictionary_stats.keys():
728 self.write_stat(stat)
730 def get_stat(self, name):
732 for obj
in self.dictionary_stats[name]:
733 output.update(obj.get_output())
736 def write_test(self, name, listofobjects):
737 flstat = open(name,
'w')
738 output = self.initoutput
739 for o
in listofobjects:
740 if (
not hasattr(o,
"get_test_output")
741 and not hasattr(o,
"get_output")):
743 "Output: object %s doesn't have get_output() or "
744 "get_test_output() method" % str(o))
745 self.dictionary_stats[name] = listofobjects
747 for obj
in self.dictionary_stats[name]:
749 d = obj.get_test_output()
750 except AttributeError:
756 dfiltered = dict((k, v)
for k, v
in d.items()
if k[0] !=
"_")
757 output.update(dfiltered)
758 flstat.write(
"%s \n" % output)
761 def test(self, name, listofobjects, tolerance=1e-5):
762 output = self.initoutput
763 for o
in listofobjects:
764 if (
not hasattr(o,
"get_test_output")
765 and not hasattr(o,
"get_output")):
767 "Output: object %s doesn't have get_output() or "
768 "get_test_output() method" % str(o))
769 for obj
in listofobjects:
771 out = obj.get_test_output()
772 except AttributeError:
773 out = obj.get_output()
779 flstat = open(name,
'r')
783 test_dict = ast.literal_eval(fl)
786 old_value = str(test_dict[k])
787 new_value = str(output[k])
795 fold = float(old_value)
796 fnew = float(new_value)
797 diff = abs(fold - fnew)
799 print(
"%s: test failed, old value: %s new value %s; "
800 "diff %f > %f" % (str(k), str(old_value),
801 str(new_value), diff,
802 tolerance), file=sys.stderr)
804 elif test_dict[k] != output[k]:
805 if len(old_value) < 50
and len(new_value) < 50:
806 print(
"%s: test failed, old value: %s new value %s"
807 % (str(k), old_value, new_value),
811 print(
"%s: test failed, omitting results (too long)"
812 % str(k), file=sys.stderr)
816 print(
"%s from old objects (file %s) not in new objects"
817 % (str(k), str(name)), file=sys.stderr)
821 def get_environment_variables(self):
823 return str(os.environ)
825 def get_versions_of_relevant_modules(self):
832 versions[
"ISD2_VERSION"] = IMP.isd2.get_module_version()
837 versions[
"ISD_EMXL_VERSION"] = IMP.isd_emxl.get_module_version()
843 listofsummedobjects=
None, jax_model=
None):
844 """Write the header for a stat file in v2 format.
845 Lines can then be written to the stat file by calling write_stat2()
846 with the same file name.
848 @param name The file name to write to.
849 @param listofobjects PMI objects that will be reported in the file.
850 Each object must implement the get_output() method.
851 This can either return a dict containing data from the
852 current state of the model, or a callable which returns
853 a similar dict of data each time it is called.
861 if listofsummedobjects
is None:
862 listofsummedobjects = []
863 if extralabels
is None:
865 flstat = open(name,
'w')
867 stat2_keywords = {
"STAT2HEADER":
"STAT2HEADER"}
868 stat2_keywords.update(
869 {
"STAT2HEADER_ENVIRON": str(self.get_environment_variables())})
870 stat2_keywords.update(
871 {
"STAT2HEADER_IMP_VERSIONS":
872 str(self.get_versions_of_relevant_modules())})
876 callable_objects = []
877 for obj
in listofobjects:
878 if not hasattr(obj,
"get_output"):
880 "Output: object %s doesn't have get_output() method"
887 callable_objects.append(d)
890 dict_objects.append(obj)
892 dfiltered = dict((k, v)
893 for k, v
in d.items()
if k[0] !=
"_")
894 output.update(dfiltered)
897 for obj
in listofsummedobjects:
899 if not hasattr(t,
"get_output"):
901 "Output: object %s doesn't have get_output() method"
904 if "_TotalScore" not in t.get_output():
906 "Output: object %s doesn't have _TotalScore "
907 "entry to be summed" % str(t))
909 output.update({obj[1]: 0.0})
911 for k
in extralabels:
912 output.update({k: 0.0})
914 for n, k
in enumerate(output):
915 stat2_keywords.update({n: k})
916 stat2_inverse.update({k: n})
918 flstat.write(
"%s \n" % stat2_keywords)
920 self.dictionary_stats2[name] = (
921 dict_objects, callable_objects,
927 """Write a single line to a stat file previously created
930 @param name The file name to write to.
933 (dict_objects, callable_objects, stat2_inverse, listofsummedobjects,
934 extralabels) = self.dictionary_stats2[name]
937 for obj
in dict_objects:
938 yield obj.get_output()
939 for obj
in callable_objects:
943 for od
in all_output():
944 dfiltered = dict((k, v)
for k, v
in od.items()
if k[0] !=
"_")
946 output.update({stat2_inverse[k]: od[k]})
949 for so
in listofsummedobjects:
953 partial_score += float(d[
"_TotalScore"])
954 output.update({stat2_inverse[so[1]]: str(partial_score)})
957 for k
in extralabels:
958 if k
in self.initoutput:
959 output.update({stat2_inverse[k]: self.initoutput[k]})
961 output.update({stat2_inverse[k]:
"None"})
963 with open(name,
'a' if appendmode
else 'w')
as flstat:
964 flstat.write(
"%s \n" % output)
966 def write_stats2(self):
967 for stat
in self.dictionary_stats2.keys():
972 """Collect statistics from ProcessOutput.get_fields().
973 Counters of the total number of frames read, plus the models that
974 passed the various filters used in get_fields(), are provided."""
977 self.passed_get_every = 0
978 self.passed_filterout = 0
979 self.passed_filtertuple = 0
983 """A class for reading stat files (either rmf or ascii v1 and v2)"""
984 def __init__(self, filename):
985 self.filename = filename
990 if self.filename
is None:
991 raise ValueError(
"No file name provided. Use -h for help")
995 rh = RMF.open_rmf_file_read_only(self.filename)
997 cat = rh.get_category(
'stat')
998 rmf_klist = rh.get_keys(cat)
999 self.rmf_names_keys = dict([(rh.get_name(k), k)
1000 for k
in rmf_klist])
1004 f = open(self.filename,
"r")
1007 for line
in f.readlines():
1008 d = ast.literal_eval(line)
1009 self.klist = list(d.keys())
1011 if "STAT2HEADER" in self.klist:
1013 for k
in self.klist:
1014 if "STAT2HEADER" in str(k):
1020 for k
in sorted(stat2_dict.items(),
1021 key=operator.itemgetter(1))]
1023 for k
in sorted(stat2_dict.items(),
1024 key=operator.itemgetter(1))]
1025 self.invstat2_dict = {}
1027 self.invstat2_dict.update({stat2_dict[k]: k})
1030 "statfile v1 is deprecated. "
1031 "Please convert to statfile v2.\n")
1040 return sorted(self.rmf_names_keys.keys())
1044 def show_keys(self, ncolumns=2, truncate=65):
1045 IMP.pmi.tools.print_multicolumn(self.get_keys(), ncolumns, truncate)
1047 def get_fields(self, fields, filtertuple=None, filterout=None, get_every=1,
1050 Get the desired field names, and return a dictionary.
1051 Namely, "fields" are the queried keys in the stat file
1052 (eg. ["Total_Score",...])
1053 The returned data structure is a dictionary, where each key is
1054 a field and the value is the time series (ie, frame ordered series)
1055 of that field (ie, {"Total_Score":[Score_0,Score_1,Score_2,,...],....})
1057 @param fields (list of strings) queried keys in the stat file
1058 (eg. "Total_Score"....)
1059 @param filterout specify if you want to "grep" out something from
1060 the file, so that it is faster
1061 @param filtertuple a tuple that contains
1062 ("TheKeyToBeFiltered",relationship,value)
1063 where relationship = "<", "==", or ">"
1064 @param get_every only read every Nth line from the file
1065 @param statistics if provided, accumulate statistics in an
1066 OutputStatistics object
1069 if statistics
is None:
1072 for field
in fields:
1077 rh = RMF.open_rmf_file_read_only(self.filename)
1078 nframes = rh.get_number_of_frames()
1079 for i
in range(nframes):
1080 statistics.total += 1
1082 statistics.passed_get_every += 1
1083 statistics.passed_filterout += 1
1084 rh.set_current_frame(RMF.FrameID(i))
1085 if filtertuple
is not None:
1086 keytobefiltered = filtertuple[0]
1087 relationship = filtertuple[1]
1088 value = filtertuple[2]
1089 datavalue = rh.get_root_node().get_value(
1090 self.rmf_names_keys[keytobefiltered])
1091 if self.isfiltered(datavalue, relationship, value):
1094 statistics.passed_filtertuple += 1
1095 for field
in fields:
1096 outdict[field].append(rh.get_root_node().get_value(
1097 self.rmf_names_keys[field]))
1100 f = open(self.filename,
"r")
1103 for line
in f.readlines():
1104 statistics.total += 1
1105 if filterout
is not None:
1106 if filterout
in line:
1108 statistics.passed_filterout += 1
1111 if line_number % get_every != 0:
1112 if line_number == 1
and self.isstat2:
1113 statistics.total -= 1
1114 statistics.passed_filterout -= 1
1116 statistics.passed_get_every += 1
1118 d = ast.literal_eval(line)
1120 print(
"# Warning: skipped line number " + str(line_number)
1121 +
" not a valid line")
1126 if filtertuple
is not None:
1127 keytobefiltered = filtertuple[0]
1128 relationship = filtertuple[1]
1129 value = filtertuple[2]
1130 datavalue = d[keytobefiltered]
1131 if self.isfiltered(datavalue, relationship, value):
1134 statistics.passed_filtertuple += 1
1135 [outdict[field].append(d[field])
for field
in fields]
1138 if line_number == 1:
1139 statistics.total -= 1
1140 statistics.passed_filterout -= 1
1141 statistics.passed_get_every -= 1
1144 if filtertuple
is not None:
1145 keytobefiltered = filtertuple[0]
1146 relationship = filtertuple[1]
1147 value = filtertuple[2]
1148 datavalue = d[self.invstat2_dict[keytobefiltered]]
1149 if self.isfiltered(datavalue, relationship, value):
1152 statistics.passed_filtertuple += 1
1153 [outdict[field].append(d[self.invstat2_dict[field]])
1154 for field
in fields]
1160 def isfiltered(self, datavalue, relationship, refvalue):
1163 _ = float(datavalue)
1165 raise ValueError(
"ProcessOutput.filter: datavalue cannot be "
1166 "converted into a float")
1168 if relationship ==
"<":
1169 if float(datavalue) >= refvalue:
1171 if relationship ==
">":
1172 if float(datavalue) <= refvalue:
1174 if relationship ==
"==":
1175 if float(datavalue) != refvalue:
1181 """ class to allow more advanced handling of RMF files.
1182 It is both a container and a IMP.atom.Hierarchy.
1183 - it is iterable (while loading the corresponding frame)
1184 - Item brackets [] load the corresponding frame
1185 - slice create an iterator
1186 - can relink to another RMF file
1190 @param model: the IMP.Model()
1191 @param rmf_file_name: str, path of the rmf file
1195 self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1197 raise TypeError(
"Wrong rmf file name or type: %s"
1198 % str(rmf_file_name))
1201 self.root_hier_ref = hs[0]
1202 super().
__init__(self.root_hier_ref)
1204 self.ColorHierarchy =
None
1208 Link to another RMF file
1210 self.rh_ref = RMF.open_rmf_file_read_only(rmf_file_name)
1212 if self.ColorHierarchy:
1213 self.ColorHierarchy.method()
1214 RMFHierarchyHandler.set_frame(self, 0)
1216 def set_frame(self, index):
1220 print(
"skipping frame %s:%d\n" % (self.current_rmf, index))
1224 return self.rh_ref.get_number_of_frames()
1226 def __getitem__(self, int_slice_adaptor):
1227 if isinstance(int_slice_adaptor, int):
1228 self.set_frame(int_slice_adaptor)
1229 return int_slice_adaptor
1230 elif isinstance(int_slice_adaptor, slice):
1231 return self.__iter__(int_slice_adaptor)
1233 raise TypeError(
"Unknown Type")
1236 return self.get_number_of_frames()
1238 def __iter__(self, slice_key=None):
1239 if slice_key
is None:
1240 for nframe
in range(len(self)):
1243 for nframe
in list(range(len(self)))[slice_key]:
1247 class CacheHierarchyCoordinates:
1248 def __init__(self, StatHierarchyHandler):
1255 self.current_index =
None
1256 self.rmfh = StatHierarchyHandler
1258 self.model = self.rmfh.get_model()
1263 self.nrms.append(nrm)
1266 self.xyzs.append(fb)
1268 def do_store(self, index):
1269 self.rb_trans[index] = {}
1270 self.nrm_coors[index] = {}
1271 self.xyz_coors[index] = {}
1273 self.rb_trans[index][rb] = rb.get_reference_frame()
1274 for nrm
in self.nrms:
1275 self.nrm_coors[index][nrm] = nrm.get_internal_coordinates()
1276 for xyz
in self.xyzs:
1277 self.xyz_coors[index][xyz] = xyz.get_coordinates()
1278 self.current_index = index
1280 def do_update(self, index):
1281 if self.current_index != index:
1283 rb.set_reference_frame(self.rb_trans[index][rb])
1284 for nrm
in self.nrms:
1285 nrm.set_internal_coordinates(self.nrm_coors[index][nrm])
1286 for xyz
in self.xyzs:
1287 xyz.set_coordinates(self.xyz_coors[index][xyz])
1288 self.current_index = index
1292 return len(self.rb_trans.keys())
1294 def __getitem__(self, index):
1295 if isinstance(index, int):
1296 return index
in self.rb_trans.keys()
1298 raise TypeError(
"Unknown Type")
1301 return self.get_number_of_frames()
1305 """ class to link stat files to several rmf files """
1306 def __init__(self, model=None, stat_file=None,
1307 number_best_scoring_models=
None, score_key=
None,
1308 StatHierarchyHandler=
None, cache=
None):
1311 @param model: IMP.Model()
1312 @param stat_file: either 1) a list or 2) a single stat file names
1313 (either rmfs or ascii, or pickled data or pickled cluster),
1314 3) a dictionary containing an rmf/ascii
1315 stat file name as key and a list of frames as values
1316 @param number_best_scoring_models:
1317 @param StatHierarchyHandler: copy constructor input object
1318 @param cache: cache coordinates and rigid body transformations.
1321 if StatHierarchyHandler
is not None:
1325 self.model = StatHierarchyHandler.model
1326 self.data = StatHierarchyHandler.data
1327 self.number_best_scoring_models = \
1328 StatHierarchyHandler.number_best_scoring_models
1329 self.is_setup =
True
1330 self.current_rmf = StatHierarchyHandler.current_rmf
1331 self.current_frame =
None
1332 self.current_index =
None
1333 self.score_threshold = StatHierarchyHandler.score_threshold
1334 self.score_key = StatHierarchyHandler.score_key
1335 self.cache = StatHierarchyHandler.cache
1336 super().
__init__(self.model, self.current_rmf)
1338 self.cache = CacheHierarchyCoordinates(self)
1347 self.number_best_scoring_models = number_best_scoring_models
1350 if score_key
is None:
1351 self.score_key =
"Total_Score"
1353 self.score_key = score_key
1354 self.is_setup =
None
1355 self.current_rmf =
None
1356 self.current_frame =
None
1357 self.current_index =
None
1358 self.score_threshold =
None
1360 if isinstance(stat_file, str):
1361 self.add_stat_file(stat_file)
1362 elif isinstance(stat_file, list):
1364 self.add_stat_file(f)
1366 def add_stat_file(self, stat_file):
1368 '''check that it is not a pickle file with saved data
1369 from a previous calculation'''
1370 self.load_data(stat_file)
1372 if self.number_best_scoring_models:
1373 scores = self.get_scores()
1374 max_score = sorted(scores)[
1375 0:min(len(self), self.number_best_scoring_models)][-1]
1376 self.do_filter_by_score(max_score)
1378 except pickle.UnpicklingError:
1379 '''alternatively read the ascii stat files'''
1381 scores, rmf_files, rmf_frame_indexes, features = \
1382 self.get_info_from_stat_file(stat_file,
1383 self.score_threshold)
1384 except (KeyError, SyntaxError):
1389 rh = RMF.open_rmf_file_read_only(stat_file)
1390 nframes = rh.get_number_of_frames()
1391 scores = [0.0]*nframes
1392 rmf_files = [stat_file]*nframes
1393 rmf_frame_indexes = range(nframes)
1398 if len(set(rmf_files)) > 1:
1399 raise (
"Multiple RMF files found")
1402 print(
"StatHierarchyHandler: Error: Trying to set none as "
1403 "rmf_file (probably empty stat file), aborting")
1406 for n, index
in enumerate(rmf_frame_indexes):
1407 featn_dict = dict([(k, features[k][n])
for k
in features])
1409 stat_file, rmf_files[n], index, scores[n], featn_dict))
1411 if self.number_best_scoring_models:
1412 scores = self.get_scores()
1413 max_score = sorted(scores)[
1414 0:min(len(self), self.number_best_scoring_models)][-1]
1415 self.do_filter_by_score(max_score)
1417 if not self.is_setup:
1418 RMFHierarchyHandler.__init__(
1419 self, self.model, self.get_rmf_names()[0])
1421 self.cache = CacheHierarchyCoordinates(self)
1424 self.is_setup =
True
1425 self.current_rmf = self.get_rmf_names()[0]
def save_data(self, filename='data.pkl'):
    """Pickle this handler's data list to *filename*."""
    with open(filename, 'wb') as out:
        pickle.dump(self.data, out)
1433 def load_data(self, filename='data.pkl'):
1434 with open(filename,
'rb')
as fl:
1435 data_structure = pickle.load(fl)
1437 if not isinstance(data_structure, list):
1439 "%filename should contain a list of IMP.pmi.output.DataEntry "
1440 "or IMP.pmi.output.Cluster" % filename)
1443 for item
in data_structure):
1444 self.data = data_structure
1446 for item
in data_structure):
1448 for cluster
in data_structure:
1449 nmodels += len(cluster)
1450 self.data = [
None]*nmodels
1451 for cluster
in data_structure:
1452 for n, data
in enumerate(cluster):
1453 index = cluster.members[n]
1454 self.data[index] = data
1457 "%filename should contain a list of IMP.pmi.output.DataEntry "
1458 "or IMP.pmi.output.Cluster" % filename)
1460 def set_frame(self, index):
1461 if self.cache
is not None and self.cache[index]:
1462 self.cache.do_update(index)
1464 nm = self.data[index].rmf_name
1465 fidx = self.data[index].rmf_index
1466 if nm != self.current_rmf:
1468 self.current_rmf = nm
1469 self.current_frame = -1
1470 if fidx != self.current_frame:
1471 RMFHierarchyHandler.set_frame(self, fidx)
1472 self.current_frame = fidx
1473 if self.cache
is not None:
1474 self.cache.do_store(index)
1476 self.current_index = index
1478 def __getitem__(self, int_slice_adaptor):
1479 if isinstance(int_slice_adaptor, int):
1480 self.set_frame(int_slice_adaptor)
1481 return self.data[int_slice_adaptor]
1482 elif isinstance(int_slice_adaptor, slice):
1483 return self.__iter__(int_slice_adaptor)
1485 raise TypeError(
"Unknown Type")
1488 return len(self.data)
1490 def __iter__(self, slice_key=None):
1491 if slice_key
is None:
1492 for i
in range(len(self)):
1495 for i
in range(len(self))[slice_key]:
def do_filter_by_score(self, maximum_score):
    """Keep only entries whose score is <= *maximum_score* (in place)."""
    kept = []
    for entry in self.data:
        if entry.score <= maximum_score:
            kept.append(entry)
    self.data = kept
def get_scores(self):
    """Return the score of every stored entry, in order."""
    scores = []
    for entry in self.data:
        scores.append(entry.score)
    return scores
def get_feature_series(self, feature_name):
    """Return the value of feature *feature_name* from every entry, in order."""
    return [entry.features[feature_name] for entry in self.data]
def get_feature_names(self):
    """Return the feature keys of the first entry.

    NOTE(review): assumes every entry carries the same feature set —
    confirm against how entries are built.
    """
    first = self.data[0]
    return first.features.keys()
def get_rmf_names(self):
    """Return the RMF file name associated with each entry, in order."""
    return [entry.rmf_name for entry in self.data]
def get_stat_files_names(self):
    """Return the stat file name associated with each entry, in order."""
    return [entry.stat_file for entry in self.data]
def get_rmf_indexes(self):
    """Return the RMF frame index associated with each entry, in order."""
    return [entry.rmf_index for entry in self.data]
1519 def get_info_from_stat_file(self, stat_file, score_threshold=None):
1523 [stat_file], score_key=self.score_key, feature_keys=fs,
1524 rmf_file_key=
"rmf_file", rmf_file_frame_key=
"rmf_frame_index",
1525 prefiltervalue=score_threshold, get_every=1)
1527 scores = [float(y)
for y
in models[2]]
1528 rmf_files = models[0]
1529 rmf_frame_indexes = models[1]
1530 features = models[3]
1531 return scores, rmf_files, rmf_frame_indexes, features
1536 A class to store data associated to a model
1538 def __init__(self, stat_file=None, rmf_name=None, rmf_index=None,
1539 score=
None, features=
None):
1540 self.rmf_name = rmf_name
1541 self.rmf_index = rmf_index
1543 self.features = features
1544 self.stat_file = stat_file
1547 s =
"IMP.pmi.output.DataEntry\n"
1548 s +=
"---- stat file %s \n" % (self.stat_file)
1549 s +=
"---- rmf file %s \n" % (self.rmf_name)
1550 s +=
"---- rmf index %s \n" % (str(self.rmf_index))
1551 s +=
"---- score %s \n" % (str(self.score))
1552 s +=
"---- number of features %s \n" % (str(len(self.features.keys())))
1558 A container for models organized into clusters
1560 def __init__(self, cid=None):
1561 self.cluster_id = cid
1563 self.precision =
None
1564 self.center_index =
None
1565 self.members_data = {}
def add_member(self, index, data=None):
    """Register model *index* (with optional associated *data*) in this
    cluster and refresh the cluster's average score."""
    self.members_data[index] = data
    self.members.append(index)
    self.average_score = self.compute_score()
1572 def compute_score(self):
1574 score = sum([d.score
for d
in self])/len(self)
1575 except AttributeError:
1580 s =
"IMP.pmi.output.Cluster\n"
1581 s +=
"---- cluster_id %s \n" % str(self.cluster_id)
1582 s +=
"---- precision %s \n" % str(self.precision)
1583 s +=
"---- average score %s \n" % str(self.average_score)
1584 s +=
"---- number of members %s \n" % str(len(self.members))
1585 s +=
"---- center index %s \n" % str(self.center_index)
1588 def __getitem__(self, int_slice_adaptor):
1589 if isinstance(int_slice_adaptor, int):
1590 index = self.members[int_slice_adaptor]
1591 return self.members_data[index]
1592 elif isinstance(int_slice_adaptor, slice):
1593 return self.__iter__(int_slice_adaptor)
1595 raise TypeError(
"Unknown Type")
1598 return len(self.members)
1600 def __iter__(self, slice_key=None):
1601 if slice_key
is None:
1602 for i
in range(len(self)):
1605 for i
in range(len(self))[slice_key]:
1608 def __add__(self, other):
1609 self.members += other.members
1610 self.members_data.update(other.members_data)
1611 self.average_score = self.compute_score()
1612 self.precision =
None
1613 self.center_index =
None
1617 def plot_clusters_populations(clusters):
1620 for cluster
in clusters:
1621 indexes.append(cluster.cluster_id)
1622 populations.append(len(cluster))
1624 import matplotlib.pyplot
as plt
1625 fig, ax = plt.subplots()
1626 ax.bar(indexes, populations, 0.5, color=
'r')
1627 ax.set_ylabel('Population')
1628 ax.set_xlabel((
'Cluster index'))
1632 def plot_clusters_precisions(clusters):
1635 for cluster
in clusters:
1636 indexes.append(cluster.cluster_id)
1638 prec = cluster.precision
1639 print(cluster.cluster_id, prec)
1642 precisions.append(prec)
1644 import matplotlib.pyplot
as plt
1645 fig, ax = plt.subplots()
1646 ax.bar(indexes, precisions, 0.5, color=
'r')
1647 ax.set_ylabel('Precision [A]')
1648 ax.set_xlabel((
'Cluster index'))
1652 def plot_clusters_scores(clusters):
1655 for cluster
in clusters:
1656 indexes.append(cluster.cluster_id)
1658 for data
in cluster:
1659 values[-1].append(data.score)
1662 valuename=
"Scores", positionname=
"Cluster index",
1663 xlabels=
None, scale_plot_length=1.0)
1666 class CrossLinkIdentifierDatabase:
def check_key(self, key):
    """Ensure *key* has a (possibly empty) record dict in the database."""
    self.clidb.setdefault(key, {})
1674 def set_unique_id(self, key, value):
1676 self.clidb[key][
"XLUniqueID"] = str(value)
1678 def set_protein1(self, key, value):
1680 self.clidb[key][
"Protein1"] = str(value)
1682 def set_protein2(self, key, value):
1684 self.clidb[key][
"Protein2"] = str(value)
1686 def set_residue1(self, key, value):
1688 self.clidb[key][
"Residue1"] = int(value)
1690 def set_residue2(self, key, value):
1692 self.clidb[key][
"Residue2"] = int(value)
1694 def set_idscore(self, key, value):
1696 self.clidb[key][
"IDScore"] = float(value)
1698 def set_state(self, key, value):
1700 self.clidb[key][
"State"] = int(value)
1702 def set_sigma1(self, key, value):
1704 self.clidb[key][
"Sigma1"] = str(value)
1706 def set_sigma2(self, key, value):
1708 self.clidb[key][
"Sigma2"] = str(value)
1710 def set_psi(self, key, value):
1712 self.clidb[key][
"Psi"] = str(value)
def get_unique_id(self, key):
    """Return the stored XLUniqueID value for crosslink *key*."""
    record = self.clidb[key]
    return record["XLUniqueID"]
def get_protein1(self, key):
    """Return the stored Protein1 value for crosslink *key*."""
    record = self.clidb[key]
    return record["Protein1"]
def get_protein2(self, key):
    """Return the stored Protein2 value for crosslink *key*."""
    record = self.clidb[key]
    return record["Protein2"]
def get_residue1(self, key):
    """Return the stored Residue1 value for crosslink *key*."""
    record = self.clidb[key]
    return record["Residue1"]
def get_residue2(self, key):
    """Return the stored Residue2 value for crosslink *key*."""
    record = self.clidb[key]
    return record["Residue2"]
def get_idscore(self, key):
    """Return the stored IDScore value for crosslink *key*."""
    record = self.clidb[key]
    return record["IDScore"]
def get_state(self, key):
    """Return the stored State value for crosslink *key*."""
    record = self.clidb[key]
    return record["State"]
def get_sigma1(self, key):
    """Return the stored Sigma1 value for crosslink *key*."""
    record = self.clidb[key]
    return record["Sigma1"]
def get_sigma2(self, key):
    """Return the stored Sigma2 value for crosslink *key*."""
    record = self.clidb[key]
    return record["Sigma2"]
def get_psi(self, key):
    """Return the stored Psi value for crosslink *key*."""
    record = self.clidb[key]
    return record["Psi"]
1744 def set_float_feature(self, key, value, feature_name):
1746 self.clidb[key][feature_name] = float(value)
1748 def set_int_feature(self, key, value, feature_name):
1750 self.clidb[key][feature_name] = int(value)
1752 def set_string_feature(self, key, value, feature_name):
1754 self.clidb[key][feature_name] = str(value)
def get_feature(self, key, feature_name):
    """Return the value stored under *feature_name* for crosslink *key*."""
    record = self.clidb[key]
    return record[feature_name]
def write(self, filename):
    """Serialize the crosslink database to *filename* with pickle."""
    with open(filename, 'wb') as out:
        pickle.dump(self.clidb, out)
def load(self, filename):
    """Replace the in-memory database with the pickle stored in *filename*.

    NOTE(review): pickle.load can execute arbitrary code from the file;
    only load files from trusted sources.
    """
    with open(filename, 'rb') as src:
        self.clidb = pickle.load(src)
1769 """Plot the given fields and save a figure as `output`.
1770 The fields generally are extracted from a stat file
1771 using ProcessOutput.get_fields()."""
1772 import matplotlib
as mpl
1774 import matplotlib.pyplot
as plt
1776 plt.rc(
'lines', linewidth=4)
1777 fig, axs = plt.subplots(nrows=len(fields))
1778 fig.set_size_inches(10.5, 5.5 * len(fields))
1783 if framemin
is None:
1785 if framemax
is None:
1786 framemax = len(fields[key])
1787 x = list(range(framemin, framemax))
1788 y = [float(y)
for y
in fields[key][framemin:framemax]]
1791 axs[n].set_title(key, size=
"xx-large")
1792 axs[n].tick_params(labelsize=18, pad=10)
1795 axs.set_title(key, size=
"xx-large")
1796 axs.tick_params(labelsize=18, pad=10)
1800 plt.subplots_adjust(hspace=0.3)
1805 colors=
None, format=
"png", reference_xline=
None,
1806 yplotrange=
None, xplotrange=
None, normalized=
True,
1808 '''Plot a list of histograms from a value list.
1809 @param name the name of the plot
1810 @param values_lists the list of list of values eg: [[...],[...],[...]]
1811 @param valuename the y-label
1812 @param bins the number of bins
1813 @param colors If None, will use rainbow. Else will use specific list
1814 @param format output format
1815 @param reference_xline plot a reference line parallel to the y-axis
1816 @param yplotrange the range for the y-axis
1817 @param normalized whether the histogram is normalized or not
1818 @param leg_names names for the legend
1821 import matplotlib
as mpl
1823 import matplotlib.pyplot
as plt
1824 import matplotlib.cm
as cm
1825 plt.figure(figsize=(18.0, 9.0))
1828 colors = cm.rainbow(np.linspace(0, 1, len(values_lists)))
1829 for nv, values
in enumerate(values_lists):
1831 if leg_names
is not None:
1832 label = leg_names[nv]
1837 [float(y)
for y
in values], bins=bins, color=col,
1838 density=normalized, histtype=
'step', lw=4, label=label)
1839 except AttributeError:
1841 [float(y)
for y
in values], bins=bins, color=col,
1842 normed=normalized, histtype=
'step', lw=4, label=label)
1845 plt.tick_params(labelsize=12, pad=10)
1846 if valuename
is None:
1847 plt.xlabel(name, size=
"xx-large")
1849 plt.xlabel(valuename, size=
"xx-large")
1850 plt.ylabel(
"Frequency", size=
"xx-large")
1852 if yplotrange
is not None:
1854 if xplotrange
is not None:
1855 plt.xlim(xplotrange)
1859 if reference_xline
is not None:
1866 plt.savefig(name +
"." + format, dpi=150, transparent=
True)
1870 valuename=
"None", positionname=
"None",
1871 xlabels=
None, scale_plot_length=1.0):
1873 Plot time series as boxplots.
1874 fields is a list of time series, positions are the x-values
1875 valuename is the y-label, positionname is the x-label
1878 import matplotlib
as mpl
1880 import matplotlib.pyplot
as plt
1883 fig = plt.figure(figsize=(float(len(positions))*scale_plot_length, 5.0))
1884 fig.canvas.manager.set_window_title(name)
1886 ax1 = fig.add_subplot(111)
1888 plt.subplots_adjust(left=0.1, right=0.990, top=0.95, bottom=0.4)
1890 bps.append(plt.boxplot(values, notch=0, sym=
'', vert=1,
1891 whis=1.5, positions=positions))
1893 plt.setp(bps[-1][
'boxes'], color=
'black', lw=1.5)
1894 plt.setp(bps[-1][
'whiskers'], color=
'black', ls=
":", lw=1.5)
1896 if frequencies
is not None:
1897 for n, v
in enumerate(values):
1898 plist = [positions[n]]*len(v)
1899 ax1.plot(plist, v,
'gx', alpha=0.7, markersize=7)
1902 if xlabels
is not None:
1903 ax1.set_xticklabels(xlabels)
1904 plt.xticks(rotation=90)
1905 plt.xlabel(positionname)
1906 plt.ylabel(valuename)
1908 plt.savefig(name +
".pdf", dpi=150)
1912 def plot_xy_data(x, y, title=None, out_fn=None, display=True,
1913 set_plot_yaxis_range=
None, xlabel=
None, ylabel=
None):
1914 import matplotlib
as mpl
1916 import matplotlib.pyplot
as plt
1917 plt.rc(
'lines', linewidth=2)
1919 fig, ax = plt.subplots(nrows=1)
1920 fig.set_size_inches(8, 4.5)
1921 if title
is not None:
1922 fig.canvas.manager.set_window_title(title)
1924 ax.plot(x, y, color=
'r')
1925 if set_plot_yaxis_range
is not None:
1926 x1, x2, y1, y2 = plt.axis()
1927 y1 = set_plot_yaxis_range[0]
1928 y2 = set_plot_yaxis_range[1]
1929 plt.axis((x1, x2, y1, y2))
1930 if title
is not None:
1932 if xlabel
is not None:
1933 ax.set_xlabel(xlabel)
1934 if ylabel
is not None:
1935 ax.set_ylabel(ylabel)
1936 if out_fn
is not None:
1937 plt.savefig(out_fn +
".pdf")
1943 def plot_scatter_xy_data(x, y, labelx="None", labely="None",
1944 xmin=
None, xmax=
None, ymin=
None, ymax=
None,
1945 savefile=
False, filename=
"None.eps", alpha=0.75):
1947 import matplotlib
as mpl
1949 import matplotlib.pyplot
as plt
1950 from matplotlib
import rc
1951 rc(
'font', **{
'family':
'sans-serif',
'sans-serif': [
'Helvetica']})
1953 fig, axs = plt.subplots(1)
1957 axs0.set_xlabel(labelx, size=
"xx-large")
1958 axs0.set_ylabel(labely, size=
"xx-large")
1959 axs0.tick_params(labelsize=18, pad=10)
1963 plot2.append(axs0.plot(x, y,
'o', color=
'k', lw=2, ms=0.1, alpha=alpha,
1973 fig.set_size_inches(8.0, 8.0)
1974 fig.subplots_adjust(left=0.161, right=0.850, top=0.95, bottom=0.11)
1975 if (ymin
is not None)
and (ymax
is not None):
1976 axs0.set_ylim(ymin, ymax)
1977 if (xmin
is not None)
and (xmax
is not None):
1978 axs0.set_xlim(xmin, xmax)
1981 fig.savefig(filename, dpi=300)
1984 def get_graph_from_hierarchy(hier):
1988 (graph, depth, depth_dict) = recursive_graph(
1989 hier, graph, depth, depth_dict)
1992 node_labels_dict = {}
1993 for key
in depth_dict:
1994 if depth_dict[key] < 3:
1995 node_labels_dict[key] = key
1997 node_labels_dict[key] =
""
1998 draw_graph(graph, labels_dict=node_labels_dict)
2001 def recursive_graph(hier, graph, depth, depth_dict):
2004 index = str(hier.get_particle().
get_index())
2005 name1 = nameh +
"|#" + index
2006 depth_dict[name1] = depth
2010 if len(children) == 1
or children
is None:
2012 return (graph, depth, depth_dict)
2016 (graph, depth, depth_dict) = recursive_graph(
2017 c, graph, depth, depth_dict)
2019 index = str(c.get_particle().
get_index())
2020 namec = nameh +
"|#" + index
2021 graph.append((name1, namec))
2024 return (graph, depth, depth_dict)
2027 def draw_graph(graph, labels_dict=None, graph_layout='spring',
2028 node_size=5, node_color=
None, node_alpha=0.3,
2029 node_text_size=11, fixed=
None, pos=
None,
2030 edge_color=
'blue', edge_alpha=0.3, edge_thickness=1,
2032 validation_edges=
None,
2033 text_font=
'sans-serif',
2036 import matplotlib
as mpl
2038 import networkx
as nx
2039 import matplotlib.pyplot
as plt
2040 from math
import sqrt, pi
2046 if isinstance(edge_thickness, list):
2047 for edge, weight
in zip(graph, edge_thickness):
2048 G.add_edge(edge[0], edge[1], weight=weight)
2051 G.add_edge(edge[0], edge[1])
2053 if node_color
is None:
2054 node_color_rgb = (0, 0, 0)
2055 node_color_hex =
"000000"
2060 for node
in G.nodes():
2061 cctuple = cc.rgb(node_color[node])
2062 tmpcolor_rgb.append((cctuple[0]/255,
2065 tmpcolor_hex.append(node_color[node])
2066 node_color_rgb = tmpcolor_rgb
2067 node_color_hex = tmpcolor_hex
2070 if isinstance(node_size, dict):
2072 for node
in G.nodes():
2073 size = sqrt(node_size[node])/pi*10.0
2074 tmpsize.append(size)
2077 for n, node
in enumerate(G.nodes()):
2078 color = node_color_hex[n]
2080 nx.set_node_attributes(
2082 {node: {
'type':
'ellipse',
'w': size,
'h': size,
2083 'fill':
'#' + color,
'label': node}})
2084 nx.set_node_attributes(
2086 {node: {
'type':
'text',
'text': node,
'color':
'#000000',
2087 'visible':
'true'}})
2089 for edge
in G.edges():
2090 nx.set_edge_attributes(
2092 {edge: {
'width': 1,
'fill':
'#000000'}})
2094 for ve
in validation_edges:
2096 if (ve[0], ve[1])
in G.edges():
2097 print(
"found forward")
2098 nx.set_edge_attributes(
2100 {ve: {
'width': 1,
'fill':
'#00FF00'}})
2101 elif (ve[1], ve[0])
in G.edges():
2102 print(
"found backward")
2103 nx.set_edge_attributes(
2105 {(ve[1], ve[0]): {
'width': 1,
'fill':
'#00FF00'}})
2107 G.add_edge(ve[0], ve[1])
2109 nx.set_edge_attributes(
2111 {ve: {
'width': 1,
'fill':
'#FF0000'}})
2115 if graph_layout ==
'spring':
2117 graph_pos = nx.spring_layout(G, k=1.0/8.0, fixed=fixed, pos=pos)
2118 elif graph_layout ==
'spectral':
2119 graph_pos = nx.spectral_layout(G)
2120 elif graph_layout ==
'random':
2121 graph_pos = nx.random_layout(G)
2123 graph_pos = nx.shell_layout(G)
2126 nx.draw_networkx_nodes(G, graph_pos, node_size=node_size,
2127 alpha=node_alpha, node_color=node_color_rgb,
2129 nx.draw_networkx_edges(G, graph_pos, width=edge_thickness,
2130 alpha=edge_alpha, edge_color=edge_color)
2131 nx.draw_networkx_labels(
2132 G, graph_pos, labels=labels_dict, font_size=node_text_size,
2133 font_family=text_font)
2135 plt.savefig(out_filename)
2136 nx.write_gml(G,
'out.gml')
2144 from ipyD3
import d3object
2145 from IPython.display
import display
2147 d3 = d3object(width=800,
2152 title=
'Example table with d3js',
2153 desc=
'An example table created created with d3js with '
2154 'data generated with Python.')
2155 data = [[1277.0, 654.0, 288.0, 1976.0, 3281.0, 3089.0, 10336.0, 4650.0,
2156 4441.0, 4670.0, 944.0, 110.0],
2157 [1318.0, 664.0, 418.0, 1952.0, 3581.0, 4574.0, 11457.0, 6139.0,
2158 7078.0, 6561.0, 2354.0, 710.0],
2159 [1783.0, 774.0, 564.0, 1470.0, 3571.0, 3103.0, 9392.0, 5532.0,
2160 5661.0, 4991.0, 2032.0, 680.0],
2161 [1301.0, 604.0, 286.0, 2152.0, 3282.0, 3369.0, 10490.0, 5406.0,
2162 4727.0, 3428.0, 1559.0, 620.0],
2163 [1537.0, 1714.0, 724.0, 4824.0, 5551.0, 8096.0, 16589.0, 13650.0,
2164 9552.0, 13709.0, 2460.0, 720.0],
2165 [5691.0, 2995.0, 1680.0, 11741.0, 16232.0, 14731.0, 43522.0,
2166 32794.0, 26634.0, 31400.0, 7350.0, 3010.0],
2167 [1650.0, 2096.0, 60.0, 50.0, 1180.0, 5602.0, 15728.0, 6874.0,
2168 5115.0, 3510.0, 1390.0, 170.0],
2169 [72.0, 60.0, 60.0, 10.0, 120.0, 172.0, 1092.0, 675.0, 408.0,
2170 360.0, 156.0, 100.0]]
2171 data = [list(i)
for i
in zip(*data)]
2172 sRows = [[
'January',
2184 sColumns = [[
'Prod {0}'.format(i)
for i
in range(1, 9)],
2185 [
None,
'',
None,
None,
'Group 1',
None,
None,
'Group 2']]
2186 d3.addSimpleTable(data,
2187 fontSizeCells=[12, ],
2190 sRowsMargins=[5, 50, 0],
2191 sColsMargins=[5, 20, 10],
2194 addOutsideBorders=-1,
2198 html = d3.render(mode=[
'html',
'show'])
static bool get_is_setup(const IMP::ParticleAdaptor &p)
A container for models organized into clusters.
A class for reading stat files (either rmf or ascii v1 and v2)
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
RMF::FrameID save_frame(RMF::FileHandle file, std::string name="")
Save the current state of the linked objects as a new RMF frame.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
def plot_field_histogram
Plot a list of histograms from a value list.
def plot_fields_box_plots
Plot time series as boxplots.
Utility classes and functions for reading and storing PMI files.
def get_best_models
Given a list of stat files, read them all and find the best models.
A class to store data associated to a model.
void handle_use_deprecated(std::string message)
Break in this method in gdb to find deprecated uses at runtime.
std::string get_module_version()
Return the version of this module, as a string.
void write_pdb(const Selection &mhd, TextOutput out, unsigned int model=1)
Collect statistics from ProcessOutput.get_fields().
static bool get_is_setup(const IMP::ParticleAdaptor &p)
def get_fields
Get the desired field names, and return a dictionary.
Warning related to handling of structures.
static bool get_is_setup(Model *m, ParticleIndex pi)
def link_to_rmf
Link to another RMF file.
std::string get_molecule_name_and_copy(atom::Hierarchy h)
Walk up a PMI2 hierarchy/representations and get the "molname.copynum".
The standard decorator for manipulating molecular structures.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
def init_pdb
Init PDB Writing.
def deprecated_method
Python decorator to mark a method as deprecated.
int get_number_of_frames(const ::npctransport_proto::Assignment &config, double time_step)
A decorator for a particle representing an atom.
Base class for capturing a modeling protocol.
def write_stat2
Write a single line to a stat file previously created with init_stat2().
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
Load the given RMF frame into the state of the linked objects.
A decorator for a particle with x,y,z coordinates.
void add_hierarchies(RMF::NodeHandle fh, const atom::Hierarchies &hs)
Class for easy writing of PDBs, RMFs, and stat files.
void add_geometries(RMF::NodeHandle parent, const display::GeometriesTemp &r)
Add geometries to a given parent node.
void add_restraints(RMF::NodeHandle fh, const Restraints &hs)
A decorator for a particle that is part of a rigid body but not rigid.
Display a segment connecting a pair of particles.
A decorator for a residue.
Basic functionality that is expected to be used by a wide variety of IMP users.
def get_pdb_names
Get a list of all PDB files being output by this instance.
def get_prot_name_from_particle
Get the protein name from the particle.
class to link stat files to several rmf files
class to allow more advanced handling of RMF files.
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def plot_fields
Plot the given fields and save a figure as output.
void add_geometry(RMF::FileHandle file, display::Geometry *r)
Add a single geometry to the file.
Store info for a chain of a protein.
Python classes to represent, score, sample and analyze models.
def init_pdb_best_scoring
Prepare for writing best-scoring PDBs (or mmCIFs) for a sampling run.
Functionality for loading, creating, manipulating and scoring atomic structures.
Select hierarchy particles identified by the biological name.
def init_rmf
Initialize an RMF file.
static bool get_is_setup(const IMP::ParticleAdaptor &p)
def init_stat2
Write the header for a stat file in v2 format.
std::string get_module_version()
Return the version of this module, as a string.
A decorator for a particle with x,y,z coordinates and a radius.