IMP logo
IMP Reference Guide  2.5.0
The Integrative Modeling Platform
pmi/Analysis.py
1 #!/usr/bin/env python
2 
3 """@namespace IMP.pmi.analysis
4  Tools for clustering and cluster analysis
5 """
6 from __future__ import print_function
7 import IMP
8 import IMP.algebra
9 import IMP.em
10 import IMP.pmi
11 import IMP.pmi.tools
12 import IMP.pmi.output
13 import IMP.rmf
14 import RMF
15 import IMP.pmi.analysis
16 from operator import itemgetter
17 from copy import deepcopy
18 from math import log,sqrt
19 import itertools
20 import numpy as np
21 
22 
class Alignment(object):
    """Performs alignment and RMSD calculation for two sets of coordinates

    The class also takes into account non-equal stoichiometry of the proteins.
    If this is the case, the protein names of proteins in multiple copies
    should be specified in the following form:
    nameA..1, nameA..2 (note two dots).
    All orderings of equivalent copies are tried and the best match is kept.
    """

    def __init__(self, template, query, weights=None):
        """Constructor.
        @param query {'p1':coords(L,3), 'p2':coords(L,3)}
        @param template {'p1':coords(L,3), 'p2':coords(L,3)}
        @param weights optional weights for each set of coordinates, keyed
               by protein name like template/query (used by get_rmsd())
        @raise ValueError if template and query hold a different number
               of proteins
        """
        self.query = query
        self.template = template
        self.weights = weights

        if len(self.query.keys()) != len(self.template.keys()):
            raise ValueError('''the number of proteins
                               in template and query does not match!''')

    @staticmethod
    def _flatten_order(comb):
        """Flatten one element of self.Product into a flat list of names."""
        return [name for perm in comb for name in perm]

    def permute(self):
        """Enumerate the ways of matching equivalent protein copies.

        Names of the form 'name..N' are treated as copies of 'name'. For each
        base name every ordering of its copies is generated (self.P), and
        self.Product holds the Cartesian product of those orderings across
        all base names. Also sets self.proteins (sorted query names).
        """
        self.proteins = sorted(self.query.keys())
        # base names in first-appearance order, each listed once so its
        # permutations are only computed once
        base_names = []
        for name in self.proteins:
            base = name.split('..')[0]
            if base not in base_names:
                base_names.append(base)
        P = {}
        for base in base_names:
            copies = [i for i in self.proteins if i.split('..')[0] == base]
            P[base] = list(itertools.permutations(copies, len(copies)))
        self.P = P
        self.Product = list(itertools.product(*P.values()))

    def get_rmsd(self):
        """Return the lowest RMSD over all copy permutations (no superposition).

        The template is laid out in the first permutation's order; each
        permutation of the query is compared against it. A weighted RMSD is
        used when weights were given. The result is also stored in self.rmsd.
        """
        self.permute()

        template_xyz = []
        weights = []
        for t in self._flatten_order(self.Product[0]):
            template_xyz += [IMP.algebra.Vector3D(i) for i in self.template[t]]
            if self.weights is not None:
                weights += list(self.weights[t])

        self.rmsd = 10000000000.
        for comb in self.Product:
            query_xyz = []
            for p in self._flatten_order(comb):
                query_xyz += [IMP.algebra.Vector3D(i) for i in self.query[p]]

            if self.weights is not None:
                dist = IMP.algebra.get_weighted_rmsd(
                    template_xyz, query_xyz, weights)
            else:
                dist = IMP.algebra.get_rmsd(template_xyz, query_xyz)
            if dist < self.rmsd:
                self.rmsd = dist
        return self.rmsd

    def align(self):
        """Superpose the query on the template and return the best fit.

        For every permutation of equivalent copies, compute the rigid
        transformation aligning the query to the template and the RMSD after
        applying it; keep the permutation with the lowest RMSD.
        @return (rmsd, transformation); transformation is the
                IMP.algebra transformation of the best permutation ('' if no
                permutation improved on the initial sentinel RMSD)
        @raise ValueError if template and query hold different numbers of
               coordinates
        """
        self.permute()

        template_xyz = []
        for t in self._flatten_order(self.Product[0]):
            template_xyz += [IMP.algebra.Vector3D(i) for i in self.template[t]]

        self.rmsd, Transformation = 10000000000., ''
        for comb in self.Product:
            query_xyz = []
            for p in self._flatten_order(comb):
                query_xyz += [IMP.algebra.Vector3D(i) for i in self.query[p]]

            if len(template_xyz) != len(query_xyz):
                raise ValueError('''the number of coordinates
                                   in template and query does not match!''')

            transformation = \
                IMP.algebra.get_transformation_aligning_first_to_second(
                    query_xyz, template_xyz)
            query_xyz_tr = [transformation.get_transformed(n)
                            for n in query_xyz]

            # RMSD of corresponding points; equivalent to the diagonal of a
            # full cdist matrix without building the N x N matrix
            dist = IMP.algebra.get_rmsd(template_xyz, query_xyz_tr)
            if dist < self.rmsd:
                self.rmsd = dist
                Transformation = transformation

        return (self.rmsd, Transformation)
133 
134 
# Self-test for the Alignment class. Kept in a function (not run on import)
# instead of a dead string literal so it stays valid, executable code.
def _test_alignment():
    """Check that comparing a set of proteins with itself gives zero RMSD.

    Several proteins have multiple copies ('a..1'..'a..3', 'c..1', 'c..2'),
    so the permutation search over equivalent copies is exercised.
    Coordinates are 3-component because IMP.algebra.Vector3D needs three.
    @return True if Alignment.get_rmsd() returns exactly 0.0
    """
    proteins = {'a..1': [(-1., 1., 0.)],
                'a..2': [(1., 1., 0.)],
                'a..3': [(-2., 1., 0.)],
                'b': [(0., -1., 0.)],
                'c..1': [(-1., -1., 0.)],
                'c..2': [(1., -1., 0.)],
                'd': [(0., 0., 0.)],
                'e': [(0., 1., 0.)]}
    ali = Alignment(proteins, proteins)
    return ali.get_rmsd() == 0.0
151 
152 
153 # ----------------------------------
class Violations(object):
    """Count restraints whose score exceeds a per-restraint threshold.

    Thresholds are read from a whitespace-separated text file with one
    "<restraint_name> <threshold>" pair per line.
    """

    def __init__(self, filename):
        """Constructor.
        @param filename text file with one "<restraint_name> <threshold>"
               pair per line; blank lines are ignored
        """
        self.violation_thresholds = {}
        # restraint name -> number of times it was found violated
        self.violation_counts = {}

        with open(filename) as data:
            for line in data:
                fields = line.split()
                if not fields:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                self.violation_thresholds[fields[0]] = float(fields[1])

    def get_number_violated_restraints(self, rsts_dict):
        """Count the restraints in rsts_dict that exceed their threshold.

        Restraints without a threshold, or missing from rsts_dict, are
        ignored. Every violation also increments self.violation_counts.
        @param rsts_dict mapping restraint name -> score (any value that
               float() accepts)
        @return the number of violated restraints
        """
        num_violated = 0
        for rst, threshold in self.violation_thresholds.items():
            if rst not in rsts_dict:
                continue
            if float(rsts_dict[rst]) > threshold:
                num_violated += 1
                self.violation_counts[rst] = \
                    self.violation_counts.get(rst, 0) + 1
        return num_violated
181 
182 
183 # ----------------------------------
class Clustering(object):
    """A class to cluster structures.

    Pairwise RMSDs between models are computed with Alignment (optionally
    after superposing each pair on a template subset of proteins, see
    set_template()) and the resulting distance matrix is clustered with
    sklearn's KMeans. The distance calculation is split across MPI
    processes when mpi4py is available.
    """

    def __init__(self, rmsd_weights=None):
        """Constructor.
        @param rmsd_weights optional per-particle weights for the RMSD,
               passed to Alignment, which looks them up by protein name
               (i.e. {protein name: weights}); None for unweighted RMSD
        """
        try:
            from mpi4py import MPI
            self.comm = MPI.COMM_WORLD
            self.rank = self.comm.Get_rank()
            self.number_of_processes = self.comm.size
        except ImportError:
            self.number_of_processes = 1
            self.rank = 0
        self.all_coords = {}
        self.structure_cluster_ids = None
        self.tmpl_coords = None
        self.rmsd_weights = rmsd_weights

    def set_template(self, part_coords):
        """Choose the proteins used to superpose model pairs before the RMSD.

        @param part_coords dict keyed by protein name; only its keys are
               used by matrix_calculation(). None disables superposition.
        """
        self.tmpl_coords = part_coords

    def fill(self, frame, Coords):
        """Add coordinates for a single model."""
        self.all_coords[frame] = Coords

    def dist_matrix(self):
        """Compute the symmetric pairwise RMSD matrix of all filled models.

        Pairs are split across MPI processes (if any) and gathered back.
        Fills self.raw_distance_matrix (indexed like self.model_list_names)
        and self.transformation_distance_dict.
        """
        self.model_list_names = list(self.all_coords.keys())
        self.model_indexes = list(range(len(self.model_list_names)))
        self.model_indexes_dict = dict(
            zip(self.model_list_names, self.model_indexes))
        model_indexes_unique_pairs = list(
            itertools.combinations(self.model_indexes, 2))

        my_model_indexes_unique_pairs = IMP.pmi.tools.chunk_list_into_segments(
            model_indexes_unique_pairs,
            self.number_of_processes)[self.rank]

        print("process %s assigned with %s pairs" % (
            str(self.rank), str(len(my_model_indexes_unique_pairs))))

        (raw_distance_dict, self.transformation_distance_dict) = \
            self.matrix_calculation(self.all_coords,
                                    self.tmpl_coords,
                                    my_model_indexes_unique_pairs)

        if self.number_of_processes > 1:
            raw_distance_dict = IMP.pmi.tools.scatter_and_gather(
                raw_distance_dict)
            pickable_transformations = \
                self.get_pickable_transformation_distance_dict()
            pickable_transformations = IMP.pmi.tools.scatter_and_gather(
                pickable_transformations)
            self.set_transformation_distance_dict_from_pickable(
                pickable_transformations)

        n_models = len(self.model_list_names)
        self.raw_distance_matrix = np.zeros((n_models, n_models))
        for (f1, f2), rmsd in raw_distance_dict.items():
            self.raw_distance_matrix[f1, f2] = rmsd
            self.raw_distance_matrix[f2, f1] = rmsd

    def get_dist_matrix(self):
        """Return the pairwise RMSD matrix computed or loaded earlier."""
        return self.raw_distance_matrix

    def do_cluster(self, number_of_clusters, seed=None):
        """Run K-means clustering
        @param number_of_clusters Num means
        @param seed the random seed (seeds numpy's global generator)
        """
        from sklearn.cluster import KMeans
        if seed is not None:
            np.random.seed(seed)
        try:
            # check whether we have the right version of sklearn
            kmeans = KMeans(n_clusters=number_of_clusters)
        except TypeError:
            # sklearn older than 0.12
            kmeans = KMeans(k=number_of_clusters)
        kmeans.fit_predict(self.raw_distance_matrix)

        self.structure_cluster_ids = kmeans.labels_

    def get_pickable_transformation_distance_dict(self):
        """Return the transformations as picklable (quaternion, translation)
        tuples, keyed like self.transformation_distance_dict."""
        pickable_transformations = {}
        for label, tr in self.transformation_distance_dict.items():
            trans = tuple(tr.get_translation())
            rot = tuple(tr.get_rotation().get_quaternion())
            pickable_transformations[label] = (rot, trans)
        return pickable_transformations

    def set_transformation_distance_dict_from_pickable(
            self,
            pickable_transformations):
        """Rebuild self.transformation_distance_dict from the tuples made by
        get_pickable_transformation_distance_dict()."""
        self.transformation_distance_dict = {}
        for label, tr in pickable_transformations.items():
            trans = IMP.algebra.Vector3D(tr[1])
            rot = IMP.algebra.Rotation3D(tr[0])
            self.transformation_distance_dict[label] = \
                IMP.algebra.Transformation3D(rot, trans)

    def save_distance_matrix_file(self, file_name='cluster.rawmatrix.pkl'):
        """Save cluster ids, model names and transformations to
        file_name + ".data" (pickle) and the matrix to file_name + ".npy"."""
        import pickle
        # transformations are stored as (quaternion, translation) tuples
        # so that they can be pickled
        pickable_transformations = \
            self.get_pickable_transformation_distance_dict()
        with open(file_name + ".data", 'wb') as outf:
            pickle.dump(
                (self.structure_cluster_ids,
                 self.model_list_names,
                 pickable_transformations),
                outf)

        np.save(file_name + ".npy", self.raw_distance_matrix)

    def load_distance_matrix_file(self, file_name='cluster.rawmatrix.pkl'):
        """Load data written by save_distance_matrix_file().

        @note Uses pickle: only load files from a trusted source.
        """
        import pickle

        with open(file_name + ".data", 'rb') as inputf:
            (self.structure_cluster_ids, self.model_list_names,
             pickable_transformations) = pickle.load(inputf)

        self.raw_distance_matrix = np.load(file_name + ".npy")

        self.set_transformation_distance_dict_from_pickable(
            pickable_transformations)
        self.model_indexes = list(range(len(self.model_list_names)))
        self.model_indexes_dict = dict(
            zip(self.model_list_names, self.model_indexes))

    def plot_matrix(self, figurename="clustermatrix.pdf"):
        """Plot a dendrogram and the leaf-ordered distance matrix."""
        import pylab as pl
        from scipy.cluster import hierarchy as hrc

        fig = pl.figure(figsize=(10, 8))
        ax = fig.add_subplot(212)
        dendrogram = hrc.dendrogram(
            hrc.linkage(self.raw_distance_matrix),
            color_threshold=7,
            no_labels=True)
        leaves_order = dendrogram['leaves']
        ax.set_xlabel('Model')
        ax.set_ylabel('RMSD [Angstroms]')

        ax2 = fig.add_subplot(221)
        cax = ax2.imshow(
            self.raw_distance_matrix[leaves_order, :][:, leaves_order],
            interpolation='nearest')
        cb = fig.colorbar(cax)
        cb.set_label('RMSD [Angstroms]')
        ax2.set_xlabel('Model')
        ax2.set_ylabel('Model')

        pl.savefig(figurename, dpi=300)
        pl.close(fig)

    def get_model_index_from_name(self, name):
        """Return the matrix index of the model called name."""
        return self.model_indexes_dict[name]

    def get_cluster_labels(self):
        """Return the distinct cluster labels."""
        return list(set(self.structure_cluster_ids))

    def get_number_of_clusters(self):
        """Return the number of distinct cluster labels."""
        return len(self.get_cluster_labels())

    def get_cluster_label_indexes(self, label):
        """Return the model indexes assigned to cluster label."""
        return [i for i, l in enumerate(self.structure_cluster_ids)
                if l == label]

    def get_cluster_label_names(self, label):
        """Return the model names assigned to cluster label."""
        return [self.model_list_names[i]
                for i in self.get_cluster_label_indexes(label)]

    def get_cluster_label_average_rmsd(self, label):
        """Return the mean off-diagonal RMSD within a cluster (0.0 for a
        single-member cluster)."""
        indexes = self.get_cluster_label_indexes(label)

        if len(indexes) > 1:
            sub_distance_matrix = self.raw_distance_matrix[
                indexes, :][:, indexes]
            n = len(sub_distance_matrix)
            # the diagonal is zero, so divide by the off-diagonal count
            average_rmsd = np.sum(sub_distance_matrix) / (n ** 2 - n)
        else:
            average_rmsd = 0.0
        return average_rmsd

    def get_cluster_label_size(self, label):
        """Return the number of models in cluster label."""
        return len(self.get_cluster_label_indexes(label))

    def get_transformation_to_first_member(
            self,
            cluster_label,
            structure_index):
        """Return the stored transformation between the first member of
        cluster_label and structure_index."""
        reference = self.get_cluster_label_indexes(cluster_label)[0]
        return self.transformation_distance_dict[(reference, structure_index)]

    def matrix_calculation(self, all_coords, template_coords, list_of_pairs):
        """Compute the RMSD and the transformation for pairs of models.

        @param all_coords dict model name -> {protein name: coordinates}
        @param template_coords None to compare models in place (identity
               transformation), or a dict whose keys name the proteins used
               to superpose model f2 on model f1 before computing the RMSD
               over all proteins
        @param list_of_pairs (f1, f2) indexes into list(all_coords.keys())
        @return (raw_distance_dict, transformation_distance_dict), both
                keyed by (f1, f2) and (f2, f1)
        """
        raw_distance_dict = {}
        transformation_distance_dict = {}
        if not list_of_pairs:
            # nothing to compute; also avoids indexing an empty all_coords
            return raw_distance_dict, transformation_distance_dict

        model_list_names = list(all_coords.keys())
        rmsd_protein_names = list(all_coords[model_list_names[0]].keys())
        do_alignment = template_coords is not None
        if do_alignment:
            alignment_template_protein_names = list(template_coords.keys())

        for (f1, f2) in list_of_pairs:
            model1 = all_coords[model_list_names[f1]]
            model2 = all_coords[model_list_names[f2]]

            if do_alignment:
                # superpose model 2 on model 1 using only the template
                # proteins, then move all of model 2 with that transformation
                template_f1 = dict((pr, model1[pr])
                                   for pr in alignment_template_protein_names)
                template_f2 = dict((pr, model2[pr])
                                   for pr in alignment_template_protein_names)
                _, transformation = Alignment(template_f1, template_f2).align()
                coords_f2 = dict(
                    (pr, [transformation.get_transformed(i) for i in model2[pr]])
                    for pr in rmsd_protein_names)
            else:
                # only the RMSD, e.g. to cluster conformations globally when
                # an external reference such as an EM map fixes the frame
                transformation = IMP.algebra.get_identity_transformation_3d()
                coords_f2 = dict((pr, model2[pr]) for pr in rmsd_protein_names)

            coords_f1 = dict((pr, model1[pr]) for pr in rmsd_protein_names)
            rmsd = Alignment(coords_f1, coords_f2,
                             self.rmsd_weights).get_rmsd()

            raw_distance_dict[(f1, f2)] = rmsd
            raw_distance_dict[(f2, f1)] = rmsd
            transformation_distance_dict[(f1, f2)] = transformation
            transformation_distance_dict[(f2, f1)] = transformation

        return raw_distance_dict, transformation_distance_dict
459 
460 
class Precision(object):
    """A class to evaluate the precision of an ensemble.

    Also can evaluate the cross-precision of multiple ensembles.
    Supports MPI for coordinate reading.
    Recommended procedure:
      -# initialize object and pass the selection for evaluating precision
      -# call add_structures() to read in the data (specify group name)
      -# call get_precision() to evaluate inter/intra precision
      -# call get_rmsf() to evaluate within-group fluctuations
    """
    def __init__(self, model,
                 resolution=1,
                 selection_dictionary=None):
        """Constructor.
        @param model The IMP Model
        @param resolution Use 1 or 10 (kluge: requires that "_Res:X" is
               part of the hier name)
        @param selection_dictionary Dictionary where keys are names for
               selections and values are selection tuples for scoring
               precision. "All" is automatically made as well. The dict is
               updated in place; when omitted, a new empty dict is created
               for this instance.
        @raise KeyError if resolution is not 1 or 10
        """
        try:
            from mpi4py import MPI
            self.comm = MPI.COMM_WORLD
            self.rank = self.comm.Get_rank()
            self.number_of_processes = self.comm.size
        except ImportError:
            self.number_of_processes = 1
            self.rank = 0

        # 'drmsd_from_center' is accepted here but not implemented;
        # _compute_distance() raises ValueError for it
        self.styles = ['pairwise_rmsd', 'pairwise_drmsd_k', 'pairwise_drmsd_Q',
                       'pairwise_drms_k', 'drmsd_from_center']
        self.style = 'pairwise_drmsd_k'
        self.structures_dictionary = {}
        self.reference_structures_dictionary = {}
        self.prots = []
        self.protein_names = None
        self.len_particles_resolution_one = None
        self.model = model
        self.rmf_names_frames = {}
        self.reference_rmf_names_frames = None
        self.reference_structure = None
        self.reference_prot = None
        # a fresh dict per instance: it is mutated later (add_structure,
        # get_rmsf), so a shared default would leak selections between
        # Precision objects
        if selection_dictionary is None:
            selection_dictionary = {}
        self.selection_dictionary = selection_dictionary
        self.threshold = 40.0
        self.residue_particle_index_map = None
        if resolution in [1, 10]:
            self.resolution = resolution
        else:
            raise KeyError("no such resolution")

    def _get_structure(self, rmf_frame_index, rmf_name):
        """Read an RMF frame and return (particles, hierarchies).

        @return (flat list of particles at self.resolution, list of the
                hierarchies created from the file)
        @raise ValueError if the particle count differs from earlier reads
        """
        rh = RMF.open_rmf_file_read_only(rmf_name)
        prots = IMP.rmf.create_hierarchies(rh, self.model)
        IMP.rmf.load_frame(rh, RMF.FrameID(rmf_frame_index))
        print("getting coordinates for frame %i rmf file %s" % (rmf_frame_index, rmf_name))
        del rh

        if self.resolution == 1:
            particle_dict = get_particles_at_resolution_one(prots[0])
        elif self.resolution == 10:
            particle_dict = get_particles_at_resolution_ten(prots[0])

        protein_names = list(particle_dict.keys())
        particles_resolution_one = []
        for k in particle_dict:
            particles_resolution_one += particle_dict[k]

        if self.protein_names is None:
            self.protein_names = protein_names
        elif self.protein_names != protein_names:
            print("Error: the protein names of the new coordinate set is not compatible with the previous one")

        if self.len_particles_resolution_one is None:
            self.len_particles_resolution_one = len(particles_resolution_one)
        elif self.len_particles_resolution_one != len(particles_resolution_one):
            raise ValueError("the new coordinate set is not compatible with the previous one")

        return particles_resolution_one, prots

    def add_structure(self,
                      rmf_name,
                      rmf_frame_index,
                      structure_set_name,
                      setup_index_map=False):
        """ Read a structure into the ensemble and store (as coordinates).
        @param rmf_name The name of the RMF file
        @param rmf_frame_index The frame to read
        @param structure_set_name Name for the set that includes this structure
               (e.g. "cluster 1")
        @param setup_index_map if requested, set up a dictionary to help
               find residue indexes
        @return 0 if the RMF could not be read, otherwise None
        """
        # decide where to put this structure
        cdict = self.structures_dictionary.setdefault(structure_set_name, {})
        rmflist = self.rmf_names_frames.setdefault(structure_set_name, [])

        # read the particles
        try:
            (particles_resolution_one, prots) = self._get_structure(rmf_frame_index, rmf_name)
        except ValueError:
            print("something wrong with the rmf")
            return 0

        self.selection_dictionary.update({"All": self.protein_names})

        for selection_name, selection_tuple in self.selection_dictionary.items():
            coords = self._select_coordinates(selection_tuple,
                                              particles_resolution_one,
                                              prots[0])
            cdict.setdefault(selection_name, []).append(coords)

        rmflist.append((rmf_name, rmf_frame_index))

        # if requested, set up a dictionary to help find residue indexes
        if setup_index_map:
            self.residue_particle_index_map = {}
            for prot_name in self.protein_names:
                self.residue_particle_index_map[prot_name] = \
                    self._get_residue_particle_index_map(
                        prot_name,
                        particles_resolution_one, prots[0])
        for prot in prots:
            IMP.atom.destroy(prot)

    def add_structures(self,
                       rmf_name_frame_tuples,
                       structure_set_name):
        """Read a list of RMFs, supports parallel
        @param rmf_name_frame_tuples list of (rmf_file_name,frame_number)
        @param structure_set_name Name this set of structures (e.g. "cluster.1")
        """
        # split up the requested list to read in parallel
        my_rmf_name_frame_tuples = IMP.pmi.tools.chunk_list_into_segments(
            rmf_name_frame_tuples, self.number_of_processes)[self.rank]
        for tup in my_rmf_name_frame_tuples:
            # the first frame stores the map between residues and particles
            setup_index_map = self.residue_particle_index_map is None
            self.add_structure(tup[0],
                               tup[1],
                               structure_set_name,
                               setup_index_map)

        # synchronize the structures
        if self.number_of_processes > 1:
            self.rmf_names_frames = IMP.pmi.tools.scatter_and_gather(self.rmf_names_frames)
            if self.rank != 0:
                self.comm.send(self.structures_dictionary, dest=0, tag=11)
            else:
                for i in range(1, self.number_of_processes):
                    data_tmp = self.comm.recv(source=i, tag=11)
                    for key in self.structures_dictionary:
                        # NOTE(review): dict.update replaces each selection's
                        # coordinate list with the other rank's list instead
                        # of concatenating; verify the intended merge
                        self.structures_dictionary[key].update(data_tmp[key])
                for i in range(1, self.number_of_processes):
                    self.comm.send(self.structures_dictionary, dest=i, tag=11)
            if self.rank != 0:
                self.structures_dictionary = self.comm.recv(source=0, tag=11)

    def _get_residue_particle_index_map(self, prot_name, structure, hier):
        """Return, per particle of prot_name in structure (sorted by
        residue), the list of residue indexes it covers."""
        s = IMP.atom.Selection(hier, molecules=[prot_name])
        intersection = list(set(s.get_selected_particles()) & set(structure))
        sorted_intersection = IMP.pmi.tools.sort_by_residues(intersection)
        return [IMP.pmi.tools.get_residue_indexes(p) for p in sorted_intersection]

    def _select_coordinates(self, tuple_selections, structure, prot):
        """Return the coordinates of the particles matching the selections.

        @param tuple_selections iterable of either protein names (str) or
               (first_residue, last_residue, protein_name) tuples
        @param structure particles to restrict the selection to
        @param prot hierarchy to select from
        @raise ValueError for any other kind of selection item
        """
        selected_coordinates = []
        structure_set = set(structure)
        for t in tuple_selections:
            if isinstance(t, tuple) and len(t) == 3:
                s = IMP.atom.Selection(prot, molecules=[t[2]],
                                       residue_indexes=list(range(t[0], t[1] + 1)))
            elif isinstance(t, str):
                s = IMP.atom.Selection(prot, molecules=[t])
            else:
                raise ValueError("Selection error")
            intersection = list(set(s.get_selected_particles()) & structure_set)
            sorted_intersection = IMP.pmi.tools.sort_by_residues(intersection)
            selected_coordinates += [tuple(IMP.core.XYZ(p).get_coordinates())
                                     for p in sorted_intersection]
        return selected_coordinates

    def set_threshold(self, threshold):
        """Set the distance threshold used by the 'pairwise_drmsd_Q' style."""
        self.threshold = threshold

    def _compute_distance(self, coordinates1, coordinates2):
        """Distance between two lists of Vector3D under self.style.

        @raise ValueError if self.style has no implementation
        """
        if self.style == 'pairwise_drmsd_k':
            return IMP.atom.get_drmsd(coordinates1, coordinates2)
        if self.style == 'pairwise_drms_k':
            return IMP.atom.get_drms(coordinates1, coordinates2)
        if self.style == 'pairwise_drmsd_Q':
            return IMP.atom.get_drmsd_Q(coordinates1, coordinates2, self.threshold)
        if self.style == 'pairwise_rmsd':
            return IMP.algebra.get_rmsd(coordinates1, coordinates2)
        raise ValueError("distance style %s is not implemented" % self.style)

    def _get_distance(self,
                      structure_set_name1,
                      structure_set_name2,
                      selection_name,
                      index1,
                      index2):
        """ Compute distance between structures with various metrics """
        c1 = self.structures_dictionary[structure_set_name1][selection_name][index1]
        c2 = self.structures_dictionary[structure_set_name2][selection_name][index2]

        coordinates1 = [IMP.algebra.Vector3D(c) for c in c1]
        coordinates2 = [IMP.algebra.Vector3D(c) for c in c2]
        return self._compute_distance(coordinates1, coordinates2)

    def _get_particle_distances(self, structure_set_name1, structure_set_name2,
                                selection_name, index1, index2):
        """Return per-particle distances between two stored structures."""
        c1 = self.structures_dictionary[structure_set_name1][selection_name][index1]
        c2 = self.structures_dictionary[structure_set_name2][selection_name][index2]

        coordinates1 = [IMP.algebra.Vector3D(c) for c in c1]
        coordinates2 = [IMP.algebra.Vector3D(c) for c in c2]

        return [np.linalg.norm(a - b) for (a, b) in zip(coordinates1, coordinates2)]

    @staticmethod
    def _get_centroid_index(distances, structure_pointers):
        """Return the structure with the lowest mean distance to the others.

        @param distances dict (i, j) -> distance; every index in
               structure_pointers must appear in at least one key
        @param structure_pointers the structure indexes to consider
        """
        totals = dict((n, 0.0) for n in structure_pointers)
        counts = dict((n, 0) for n in structure_pointers)
        for (i, j), d in distances.items():
            totals[i] += d
            totals[j] += d
            counts[i] += 1
            counts[j] += 1
        means = dict((n, totals[n] / counts[n]) for n in structure_pointers)
        min_distance = min(means.values())
        return [k for k, v in means.items() if v == min_distance][0]

    def get_precision(self,
                      structure_set_name1,
                      structure_set_name2,
                      outfile=None,
                      skip=1,
                      selection_keywords=None):
        """ Evaluate the precision of two named structure groups. Supports MPI.
        When the structure_set_name1 is different from the structure_set_name2,
        this evaluates the cross-precision (average pairwise distances).
        @param outfile Name of the precision output file
        @param structure_set_name1 string name of the first structure set
        @param structure_set_name2 string name of the second structure set
        @param skip analyze every (skip) structure for the distance matrix calculation
        @param selection_keywords Specify the selection name you want to calculate on.
               By default this is computed for everything you provided in the constructor,
               plus all the subunits together.
        @return the centroid index (0 unless both set names are equal; only
                computed on MPI rank 0)
        @raise KeyError for an unknown selection keyword
        @raise ValueError if skip leaves no structure pairs
        """
        if selection_keywords is None:
            sel_keys = list(self.selection_dictionary.keys())
        else:
            for k in selection_keywords:
                if k not in self.selection_dictionary:
                    raise KeyError("you are trying to find named selection "
                                   + k + " which was not requested in the constructor")
            sel_keys = selection_keywords

        of = open(outfile, "w") if outfile is not None else None
        centroid_index = 0
        try:
            for selection_name in sel_keys:
                number_of_structures_1 = len(self.structures_dictionary[structure_set_name1][selection_name])
                number_of_structures_2 = len(self.structures_dictionary[structure_set_name2][selection_name])

                structure_pointers_1 = list(range(0, number_of_structures_1, skip))
                structure_pointers_2 = list(range(0, number_of_structures_2, skip))

                pair_combination_list = list(itertools.product(structure_pointers_1,
                                                               structure_pointers_2))
                if len(pair_combination_list) == 0:
                    raise ValueError("no structure selected. Check the skip parameter.")

                my_pair_combination_list = IMP.pmi.tools.chunk_list_into_segments(
                    pair_combination_list, self.number_of_processes)[self.rank]
                distances = {}
                for pair in my_pair_combination_list:
                    distances[pair] = self._get_distance(structure_set_name1, structure_set_name2,
                                                         selection_name, pair[0], pair[1])

                if self.number_of_processes > 1:
                    distances = IMP.pmi.tools.scatter_and_gather(distances)
                if self.rank != 0:
                    continue

                if structure_set_name1 == structure_set_name2:
                    # the centroid is the structure with the smallest mean
                    # distance to all the others
                    centroid_index = self._get_centroid_index(distances, structure_pointers_1)
                    centroid_rmf_name = self.rmf_names_frames[structure_set_name1][centroid_index]

                    centroid_distance = 0.0
                    for n in range(number_of_structures_1):
                        centroid_distance += self._get_distance(structure_set_name1, structure_set_name1,
                                                                selection_name, centroid_index, n)
                    centroid_distance /= number_of_structures_1
                    if of is not None:
                        of.write(str(selection_name) + " " + structure_set_name1 +
                                 " average centroid distance " + str(centroid_distance) + "\n")
                        of.write(str(selection_name) + " " + structure_set_name1 +
                                 " centroid index " + str(centroid_index) + "\n")
                        of.write(str(selection_name) + " " + structure_set_name1 +
                                 " centroid rmf name " + str(centroid_rmf_name) + "\n")

                average_pairwise_distances = sum(distances.values()) / len(distances)
                if of is not None:
                    of.write(str(selection_name) + " " + structure_set_name1 + " " + structure_set_name2 +
                             " average pairwise distance " + str(average_pairwise_distances) + "\n")
        finally:
            if of is not None:
                of.close()
        return centroid_index

    def get_rmsf(self,
                 structure_set_name,
                 outdir="./",
                 skip=1,
                 set_plot_yaxis_range=None):
        """ Calculate the residue mean square fluctuations (RMSF).
        For each protein and residue, this is the standard deviation of the
        particle's distance to the centroid structure across the set.
        Automatically outputs as data file and pdf
        @param structure_set_name Which structure set to calculate RMSF for
        @param outdir Where to write the files
        @param skip Skip this number of structures
        @param set_plot_yaxis_range In case you need to change the plot
        """
        # get the centroid structure for the whole complex
        centroid_index = self.get_precision(
            structure_set_name,
            structure_set_name,
            outfile=None,
            skip=skip)
        if self.rank != 0:
            return
        for sel_name in self.protein_names:
            self.selection_dictionary.update({sel_name: [sel_name]})
            try:
                number_of_structures = len(self.structures_dictionary[structure_set_name][sel_name])
            except KeyError:
                # that protein was not included in the selection
                continue
            rpim = self.residue_particle_index_map[sel_name]
            outfile = outdir + "/rmsf." + sel_name + ".dat"
            residue_distances = {}
            residue_nblock = {}
            for index in range(number_of_structures):
                distances = self._get_particle_distances(structure_set_name,
                                                         structure_set_name,
                                                         sel_name,
                                                         centroid_index, index)
                for nblock, block in enumerate(rpim):
                    for residue_number in block:
                        residue_nblock[residue_number] = nblock
                        residue_distances.setdefault(residue_number, []).append(distances[nblock])

            residues = []
            rmsfs = []
            with open(outfile, "w") as of:
                for rn in residue_distances:
                    residues.append(rn)
                    rmsf = np.std(residue_distances[rn])
                    rmsfs.append(rmsf)
                    of.write(str(rn) + " " + str(residue_nblock[rn]) + " " + str(rmsf) + "\n")

            IMP.pmi.output.plot_xy_data(residues, rmsfs, title=sel_name,
                                        out_fn=outdir + "/rmsf." + sel_name, display=False,
                                        set_plot_yaxis_range=set_plot_yaxis_range,
                                        xlabel='Residue Number', ylabel='Standard error')

    def set_reference_structure(self, rmf_name, rmf_frame_index):
        """Read in a structure used for reference computation.
        Needed before calling get_average_distance_wrt_reference_structure()
        @param rmf_name The RMF file to read the reference
        @param rmf_frame_index The index in that file
        """
        (particles_resolution_one, prot) = self._get_structure(rmf_frame_index, rmf_name)
        self.reference_rmf_names_frames = (rmf_name, rmf_frame_index)

        # create the automatic "All" selection (as add_structure does) so
        # the reference has coordinates for it
        self.selection_dictionary.update({"All": self.protein_names})

        for selection_name, selection_tuple in self.selection_dictionary.items():
            coords = self._select_coordinates(selection_tuple,
                                              particles_resolution_one, prot)
            self.reference_structures_dictionary[selection_name] = coords

    def get_average_distance_wrt_reference_structure(self, structure_set_name):
        """Compare the structure set to the reference structure.
        Prints the average and minimum distance per selection.
        @param structure_set_name The structure set to compute this on
        @note First call set_reference_structure()
        """
        if self.reference_structures_dictionary == {}:
            print("Cannot compute until you set a reference structure")
            return
        for selection_name in self.selection_dictionary:
            reference_coordinates = self.reference_structures_dictionary[selection_name]
            coordinates2 = [IMP.algebra.Vector3D(c) for c in reference_coordinates]
            distances = []
            for sc in self.structures_dictionary[structure_set_name][selection_name]:
                coordinates1 = [IMP.algebra.Vector3D(c) for c in sc]
                distances.append(self._compute_distance(coordinates1, coordinates2))
            if not distances:
                # avoid dividing by zero for an empty structure set
                print(selection_name, "no structures to compare")
                continue

            print(selection_name, "average distance", sum(distances) / len(distances),
                  "minimum distance", min(distances))

    def get_coordinates(self):
        """Placeholder; not implemented."""
        pass

    def set_precision_style(self, style):
        """Select the distance metric; must be one of self.styles.
        @raise ValueError for an unknown style
        """
        if style in self.styles:
            self.style = style
        else:
            raise ValueError("No such style")
929 
930 
class GetModelDensity(object):
    """Compute mean density maps from structures.

    Keeps a dictionary of density maps,
    keys are in the custom ranges. When you call add_subunits_density, it adds
    particle coordinates to the existing density maps.
    """

    def __init__(self, custom_ranges, representation=None, voxel=5.0):
        """Constructor.
        @param custom_ranges Required. It's a dictionary, keys are the
               density component names, values are selection tuples
               e.g. {'kin28':[['kin28',1,-1]],
                     'density_name_1' :[('ccl1')],
                     'density_name_2' :[(1,142,'tfb3d1'),
                                        (143,700,'tfb3d2')],
               When a hierarchy is passed to add_subunits_density(), each
               selection must be a molecule name (str) or a
               (first_residue, last_residue, molecule_name) tuple.
        @param representation PMI representation, for doing selections.
               Not needed if you only pass hierarchies
        @param voxel The voxel size for the output map (lower is slower)
        """
        self.representation = representation
        self.voxel = voxel
        self.densities = {}
        # number of frames added so far; write_mrc() divides by it
        self.count_models = 0.0
        self.custom_ranges = custom_ranges

    def add_subunits_density(self, hierarchy=None):
        """Add a frame to the densities.
        @param hierarchy Optionally read the hierarchy from somewhere.
               If not passed, will just read the representation.
        @raise Exception if a selection cannot be interpreted (hierarchy mode)
        """
        self.count_models += 1.0

        if hierarchy:
            part_dict = get_particles_at_resolution_one(hierarchy)
            # only resolution-one particles (and beads) contribute
            resolution_one = set()
            for name in part_dict:
                resolution_one.update(part_dict[name])

        for density_name in self.custom_ranges:
            parts = []
            if hierarchy:
                all_particles_by_segments = []

            for seg in self.custom_ranges[density_name]:
                if not hierarchy:
                    parts += IMP.pmi.tools.select_by_tuple(
                        self.representation, seg, resolution=1,
                        name_is_ambiguous=False)
                else:
                    if isinstance(seg, str):
                        s = IMP.atom.Selection(hierarchy, molecule=seg)
                    elif isinstance(seg, tuple):
                        s = IMP.atom.Selection(
                            hierarchy, molecule=seg[2],
                            residue_indexes=list(range(seg[0], seg[1] + 1)))
                    else:
                        raise Exception(
                            'could not understand selection tuple ' + str(seg))
                    all_particles_by_segments += s.get_selected_particles()

            if hierarchy:
                parts = list(set(all_particles_by_segments) & resolution_one)

            self._create_density_from_particles(parts, density_name)

    def normalize_density(self):
        pass

    def _create_density_from_particles(self, ps, name,
                                       resolution=1,
                                       kernel_type='GAUSSIAN'):
        '''Internal function for adding to densities.
        Pass XYZR particles with mass and create a density from them. If a
        density called `name` already exists, the new map is summed with it
        on a grid covering both bounding boxes.
        @note kernel_type (GAUSSIAN, BINARIZED_SPHERE or SPHERE) is accepted
              for API compatibility but is not passed on to
              IMP.em.SampledDensityMap.'''
        dmap = IMP.em.SampledDensityMap(ps, resolution, self.voxel)
        dmap.calcRMS()
        if name not in self.densities:
            self.densities[name] = dmap
        else:
            bbox1 = IMP.em.get_bounding_box(self.densities[name])
            bbox2 = IMP.em.get_bounding_box(dmap)
            bbox1 += bbox2
            dmap3 = IMP.em.create_density_map(bbox1, self.voxel)
            dmap3.add(dmap)
            dmap3.add(self.densities[name])
            self.densities[name] = dmap3

    def get_density_keys(self):
        return list(self.densities.keys())

    def get_density(self, name):
        """Get the current density for some component name"""
        if name not in self.densities:
            return None
        else:
            return self.densities[name]

    def write_mrc(self, path="./"):
        """Write every density, averaged over the added models, as MRC files.

        Each map is divided in place by the number of models added, so
        calling this more than once rescales the stored maps again.
        @param path directory in which <density_name>.mrc files are written
        """
        for density_name in self.densities:
            self.densities[density_name].multiply(1. / self.count_models)
            IMP.em.write_map(
                self.densities[density_name],
                path + "/" + density_name + ".mrc",
                IMP.em.MRCReaderWriter())
1042 
1043 
class GetContactMap(object):
    """Accumulate contact maps from model structures and plot them.

    Two particles are in contact when the distance between their surfaces
    (center distance minus both radii) is at most `distance`.
    """

    def __init__(self, distance=15.):
        """Constructor.
        @param distance cutoff on the surface-to-surface distance
        """
        self.distance = distance
        self.contactmap = ''
        self.namelist = []
        self.xlinks = 0
        self.XL = {}
        self.expanded = {}
        self.resmap = {}

    def _add_contacts(self, coords, radii, namelist):
        """Add the contacts between the given spheres to self.contactmap.

        The first call fixes self.namelist and the matrix size; later calls
        must supply the same number of spheres.
        """
        from scipy.spatial.distance import cdist
        coords = np.array(coords)
        radii = np.array(radii)
        if len(self.namelist) == 0:
            self.namelist = namelist
            self.contactmap = np.zeros((len(coords), len(coords)))
        distances = cdist(coords, coords)
        # subtract both radii to get surface-to-surface distances
        distances = (distances - radii).T - radii
        self.contactmap += distances <= self.distance

    def set_prot(self, prot):
        """Add the contacts of one hierarchy (resolution-one particles).

        Every residue covered by a particle becomes one row/column of the
        contact map, placed at that particle's coordinates.
        @param prot the hierarchy to read
        """
        self.prot = prot
        self.protnames = []
        coords = []
        radii = []
        namelist = []

        particles_dictionary = get_particles_at_resolution_one(self.prot)

        for name in particles_dictionary:
            has_residues = False
            for p in particles_dictionary[name]:
                residue_indexes = IMP.pmi.tools.get_residue_indexes(p)
                if len(residue_indexes) == 0:
                    continue
                has_residues = True
                d = IMP.core.XYZR(p)
                crd = np.array([d.get_x(), d.get_y(), d.get_z()])
                radius = d.get_radius()
                residue_names = self.resmap.setdefault(name, {})
                for res in range(min(residue_indexes), max(residue_indexes) + 1):
                    new_name = name + ":" + str(res)
                    residue_names[res] = new_name
                    namelist.append(new_name)
                    coords.append(crd)
                    radii.append(radius)
            if has_residues:
                self.protnames.append(name)

        self._add_contacts(coords, radii, namelist)

    def get_subunit_coords(self, frame, align=0):
        """Add the contacts of self.prot, one leaf per row/column.

        Segments of each component are ordered by their first residue;
        self.expanded records how many residues each leaf covers.
        @param frame unused, kept for API compatibility
        @param align unused, kept for API compatibility
        """
        coords = []
        radii = []
        namelist = []
        for part in self.prot.get_children():
            sorted_segments = []
            for chl in part.get_children():
                start = IMP.atom.get_leaves(chl)[0]
                startres = IMP.atom.Fragment(start).get_residue_indexes()[0]
                sorted_segments.append((chl, startres))
            sorted_segments.sort(key=itemgetter(1))

            residue_names = self.resmap.setdefault(part.get_name(), {})
            for segment, _ in sorted_segments:
                for leaf in IMP.atom.get_leaves(segment):
                    p = IMP.core.XYZR(leaf)
                    coords.append(np.array([p.get_x(), p.get_y(), p.get_z()]))
                    radii.append(p.get_radius())

                    leaf_residues = IMP.atom.Fragment(leaf).get_residue_indexes()
                    new_name = (part.get_name() + '_' + segment.get_name() +
                                '_' + str(leaf_residues[0]))
                    namelist.append(new_name)
                    self.expanded[new_name] = len(leaf_residues)
                    for res in leaf_residues:
                        residue_names[res] = new_name

        self._add_contacts(coords, radii, namelist)

    def add_xlinks(
            self,
            filname,
            identification_string='ISDCrossLinkMS_Distance_'):
        """Read cross-links from the lines of a file.

        Each line containing identification_string is split on '_', '-' and
        ':'; fields 0 and 1 are used as the pair key and fields 2 and 3 as
        residue numbers, stored shifted by +1, in both orientations.
        """
        # 'ISDCrossLinkMS_Distance_interrb_6629-State:0-20:RPS30_218:eIF3j-1-1-0.1_None'
        # NOTE(review): the field positions used below do not match this
        # example once split -- confirm the expected line format.
        self.xlinks = 1
        with open(filname) as data:
            lines = data.readlines()

        for line in lines:
            if identification_string not in line:
                continue
            d = line.replace("_", " ").replace("-", " ").replace(":", " ").split()
            t1, t2 = (d[0], d[1]), (d[1], d[0])
            # setdefault keeps both orientations even when t1 == t2
            self.XL.setdefault(t1, []).append((int(d[2]) + 1, int(d[3]) + 1))
            self.XL.setdefault(t2, []).append((int(d[3]) + 1, int(d[2]) + 1))

    def dist_matrix(self, skip_cmap=0, skip_xl=1, filename=None):
        """Build per-protein-pair contact/cross-link matrices and plot them.

        @param skip_cmap if 0, plot the accumulated contact map (log scale)
        @param skip_xl if 0, overlay the cross-links read by add_xlinks()
        @param filename if given, the figure is saved there using a
               display-less backend; otherwise plt.show() is called
        @raise ValueError if cross-links are requested but none were added
        """
        K = self.namelist
        M = self.contactmap
        if skip_cmap == 0:
            proteins = [p.get_name() for p in self.prot.get_children()]
        else:
            proteins = self.protnames

        Matrices = {}
        if skip_cmap == 0:
            # first position of each residue name in the contact map
            index_of = {}
            for position, res_name in enumerate(K):
                index_of.setdefault(res_name, position)
            for p1 in range(len(proteins)):
                for p2 in range(p1, len(proteins)):
                    pn1, pn2 = proteins[p1], proteins[p2]
                    pl1 = max(self.resmap[pn1].keys())
                    pl2 = max(self.resmap[pn2].keys())
                    mtr = np.zeros((pl1 + 1, pl2 + 1))
                    print('Creating matrix for: ', p1, p2, pn1, pn2, mtr.shape, pl1, pl2)
                    for i1 in range(1, pl1 + 1):
                        for i2 in range(1, pl2 + 1):
                            try:
                                r1 = index_of[self.resmap[pn1][i1]]
                                r2 = index_of[self.resmap[pn2][i2]]
                            except KeyError:
                                continue  # residue not represented
                            mtr[i1 - 1, i2 - 1] = M[r1, r2]
                    Matrices[(pn1, pn2)] = mtr

        Matrices_xl = {}
        if skip_xl == 0:
            if self.XL == {}:
                raise ValueError("cross-links were not provided, use add_xlinks function!")
            for p1 in range(len(proteins)):
                for p2 in range(p1, len(proteins)):
                    pn1, pn2 = proteins[p1], proteins[p2]
                    pl1 = max(self.resmap[pn1].keys())
                    pl2 = max(self.resmap[pn2].keys())
                    mtr = np.zeros((pl1 + 1, pl2 + 1))
                    if (pn1, pn2) in self.XL:
                        xls, swap = self.XL[(pn1, pn2)], False
                    elif (pn2, pn1) in self.XL:
                        xls, swap = self.XL[(pn2, pn1)], True
                    else:
                        xls, swap = [], False  # no cross-links for this pair
                    for a, b in xls:
                        if swap:
                            a, b = b, a  # make a index pn1 and b index pn2
                        # clamp residues beyond the last modeled one
                        mtr[min(a, pl1) - 1, min(b, pl2) - 1] = 100
                    Matrices_xl[(pn1, pn2)] = mtr

        C = proteins
        # W is the component length list (last residue index of each protein)
        W = [max(self.resmap[c].keys()) for c in C]

        import matplotlib as mpl
        if filename:
            # Don't require a display when writing to a file
            mpl.use('Agg')
        import matplotlib.pyplot as plt
        import matplotlib.gridspec as gridspec
        import scipy.sparse as sparse

        plt.figure()
        gs = gridspec.GridSpec(len(W), len(W),
                               width_ratios=W,
                               height_ratios=W)

        if skip_cmap == 0:
            # log color scale; clamp so an all-zero map keeps vmax >= vmin
            vmax = np.log(max(M.max(), 1.0))

        cnt = 0
        for x1 in range(len(C)):
            for x2 in range(len(C)):
                ax = plt.subplot(gs[cnt])
                if skip_cmap == 0:
                    if (C[x1], C[x2]) in Matrices:
                        mtr = Matrices[(C[x1], C[x2])]
                    else:
                        mtr = Matrices[(C[x2], C[x1])].T
                    with np.errstate(divide='ignore'):
                        ax.imshow(np.log(mtr), interpolation='nearest',
                                  vmin=0., vmax=vmax)
                    ax.set_xticks([])
                    ax.set_yticks([])
                if skip_xl == 0:
                    if (C[x1], C[x2]) in Matrices_xl:
                        mtr = Matrices_xl[(C[x1], C[x2])]
                    else:
                        mtr = Matrices_xl[(C[x2], C[x1])].T
                    ax.spy(sparse.csr_matrix(mtr), markersize=10,
                           color='white', linewidth=100, alpha=0.5)
                    ax.set_xticks([])
                    ax.set_yticks([])

                cnt += 1
                if x2 == 0:
                    ax.set_ylabel(C[x1], rotation=90)
        if filename:
            plt.savefig(filename)
        else:
            plt.show()
1336 
1337 
1338 # ------------------------------------------------------------------
1339 # a few random tools
1340 
def get_hiers_from_rmf(model, frame_number, rmf_file):
    """Create the hierarchies stored in an RMF file and load one frame.

    @note A later definition of get_hiers_from_rmf in this module replaces
          this one when the module is imported.
    @param model the model to create the hierarchies in; updated on success
    @param frame_number index of the frame to load
    @param rmf_file path of the RMF file
    @return the list of hierarchies, or None if the hierarchies or the
            frame could not be read
    """
    # I have to deprecate this function
    print("getting coordinates for frame %i rmf file %s" % (frame_number, rmf_file))

    # load the frame
    rh = RMF.open_rmf_file_read_only(rmf_file)

    try:
        prots = IMP.rmf.create_hierarchies(rh, model)
    except IOError:
        print("Unable to open rmf file %s" % (rmf_file))
        return None
    try:
        IMP.rmf.load_frame(rh, RMF.FrameID(frame_number))
    except IOError:
        print("Unable to open frame %i of file %s" % (frame_number, rmf_file))
        return None
    model.update()
    del rh
    return prots
1365 
def link_hiers_to_rmf(model, hiers, frame_number, rmf_file):
    """Attach existing hierarchies to an RMF file and load one frame into them.

    @param model the model owning the hierarchies; updated after loading
    @param hiers the hierarchies to link
    @param frame_number index of the frame to load
    @param rmf_file path of the RMF file
    """
    message = "linking hierarchies for frame %i rmf file %s" % (frame_number, rmf_file)
    print(message)
    handle = RMF.open_rmf_file_read_only(rmf_file)
    IMP.rmf.link_hierarchies(handle, hiers)
    frame = RMF.FrameID(frame_number)
    IMP.rmf.load_frame(handle, frame)
    model.update()
    del handle
1373 
def get_hiers_and_restraints_from_rmf(model, frame_number, rmf_file):
    """Create hierarchies and restraints stored in an RMF file and load a frame.

    @param model the model to create them in; updated on success
    @param frame_number index of the frame to load
    @param rmf_file path of the RMF file
    @return (hierarchies, restraints), or (None, None) if the file content
            or the frame could not be read
    """
    # I have to deprecate this function
    print("getting coordinates for frame %i rmf file %s" % (frame_number, rmf_file))

    # load the frame
    rh = RMF.open_rmf_file_read_only(rmf_file)

    try:
        prots = IMP.rmf.create_hierarchies(rh, model)
        rs = IMP.rmf.create_restraints(rh, model)
    except Exception:
        print("Unable to open rmf file %s" % (rmf_file))
        return None, None
    try:
        IMP.rmf.load_frame(rh, RMF.FrameID(frame_number))
    except Exception:
        print("Unable to open frame %i of file %s" % (frame_number, rmf_file))
        return None, None
    model.update()
    del rh
    return prots, rs
1399 
def link_hiers_and_restraints_to_rmf(model, hiers, rs, frame_number, rmf_file):
    """Attach existing hierarchies and restraints to an RMF file, then load a frame.

    @param model the model owning them; updated after loading
    @param hiers the hierarchies to link
    @param rs the restraints to link
    @param frame_number index of the frame to load
    @param rmf_file path of the RMF file
    """
    message = "linking hierarchies for frame %i rmf file %s" % (frame_number, rmf_file)
    print(message)
    handle = RMF.open_rmf_file_read_only(rmf_file)
    IMP.rmf.link_hierarchies(handle, hiers)
    IMP.rmf.link_restraints(handle, rs)
    frame = RMF.FrameID(frame_number)
    IMP.rmf.load_frame(handle, frame)
    model.update()
    del handle
1408 
def get_hiers_from_rmf(model, frame_number, rmf_file):
    """Create the hierarchies stored in an RMF file and load one frame.

    @param model the model to create the hierarchies in; always updated
           once the hierarchies exist
    @param frame_number index of the frame to load
    @param rmf_file path of the RMF file
    @return the list of hierarchies, or None if the hierarchies or the
            frame could not be read
    """
    print("getting coordinates for frame %i rmf file %s" % (frame_number, rmf_file))

    # load the frame
    rh = RMF.open_rmf_file_read_only(rmf_file)

    try:
        prots = IMP.rmf.create_hierarchies(rh, model)
    except Exception:
        print("Unable to open rmf file %s" % (rmf_file))
        return None
    try:
        IMP.rmf.load_frame(rh, RMF.FrameID(frame_number))
    except Exception:
        print("Unable to open frame %i of file %s" % (frame_number, rmf_file))
        prots = None
    model.update()
    del rh
    return prots
1430 
1431 
def get_particles_at_resolution_one(prot):
    """
    Get particles at res 1, or any beads, based on the name.
    No Representation is needed. This is mainly used when the hierarchy
    is read from an RMF file.
    @param prot hierarchy whose children are the components
    @return a dictionary of component names and their particles, sorted
            by residue
    """
    particle_dict = {}
    allparticles = []
    for c in prot.get_children():
        name = c.get_name()
        particle_dict[name] = IMP.atom.get_leaves(c)
        for s in c.get_children():
            sname = s.get_name()
            # "_Res:10" also contains "_Res:1", so it is excluded explicitly
            if ("_Res:1" in sname and "_Res:10" not in sname) or "Beads" in sname:
                allparticles += IMP.atom.get_leaves(s)

    selected = set(allparticles)
    for name in particle_dict:
        particle_dict[name] = IMP.pmi.tools.sort_by_residues(
            list(set(particle_dict[name]) & selected))
    return particle_dict
1455 
def get_particles_at_resolution_ten(prot):
    """
    Get particles at res 10, or any beads, based on the name.
    No Representation is needed.
    This is mainly used when the hierarchy is read from an RMF file.
    @param prot hierarchy whose children are the components
    @return a dictionary of component names and their particles, sorted
            by residue
    """
    particle_dict = {}
    allparticles = []
    for c in prot.get_children():
        name = c.get_name()
        particle_dict[name] = IMP.atom.get_leaves(c)
        for s in c.get_children():
            sname = s.get_name()
            # NOTE(review): a substring test would also match e.g. "_Res:100";
            # confirm no such representation names exist.
            if "_Res:10" in sname or "Beads" in sname:
                allparticles += IMP.atom.get_leaves(s)

    selected = set(allparticles)
    for name in particle_dict:
        particle_dict[name] = IMP.pmi.tools.sort_by_residues(
            list(set(particle_dict[name]) & selected))
    return particle_dict
1478 
1479 
1480 
def select_by_tuple(first_res_last_res_name_tuple):
    """Split a (first_residue, last_residue, name) selection tuple.

    Only parses the tuple; it does not select any particle.
    @param first_res_last_res_name_tuple a sequence whose first three items
           are the first residue, the last residue and the component name
    @return the tuple (first_res, last_res, name)
    """
    first_res = first_res_last_res_name_tuple[0]
    last_res = first_res_last_res_name_tuple[1]
    name = first_res_last_res_name_tuple[2]
    return first_res, last_res, name
1485 
1486 class CrossLinkTable(object):
1487  """Visualization of crosslinks"""
1488  def __init__(self):
1489  self.crosslinks = []
1490  self.external_csv_data = None
1491  self.crosslinkedprots = set()
1492  self.mindist = +10000000.0
1493  self.maxdist = -10000000.0
1494  self.contactmap = None
1495 
1496  def set_hierarchy(self, prot):
1497  self.prot_length_dict = {}
1498  self.model=prot.get_model()
1499 
1500  for i in prot.get_children():
1501  name = i.get_name()
1502  residue_indexes = []
1503  for p in IMP.atom.get_leaves(i):
1504  residue_indexes += IMP.pmi.tools.get_residue_indexes(p)
1505 
1506  if len(residue_indexes) != 0:
1507  self.prot_length_dict[name] = max(residue_indexes)
1508 
1509  def set_coordinates_for_contact_map(self, rmf_name,rmf_frame_index):
1510  from scipy.spatial.distance import cdist
1511 
1512  rh= RMF.open_rmf_file_read_only(rmf_name)
1513  prots=IMP.rmf.create_hierarchies(rh, self.model)
1514  IMP.rmf.load_frame(rh, RMF.FrameID(rmf_frame_index))
1515  print("getting coordinates for frame %i rmf file %s" % (rmf_frame_index, rmf_name))
1516  del rh
1517 
1518 
1519  coords = []
1520  radii = []
1521  namelist = []
1522 
1523  particles_dictionary = get_particles_at_resolution_one(prots[0])
1524 
1525  resindex = 0
1526  self.index_dictionary = {}
1527 
1528  for name in particles_dictionary:
1529  residue_indexes = []
1530  for p in particles_dictionary[name]:
1531  print(p.get_name())
1532  residue_indexes = IMP.pmi.tools.get_residue_indexes(p)
1533  #residue_indexes.add( )
1534 
1535  if len(residue_indexes) != 0:
1536 
1537  for res in range(min(residue_indexes), max(residue_indexes) + 1):
1538  d = IMP.core.XYZR(p)
1539 
1540  crd = np.array([d.get_x(), d.get_y(), d.get_z()])
1541  coords.append(crd)
1542  radii.append(d.get_radius())
1543  if name not in self.index_dictionary:
1544  self.index_dictionary[name] = [resindex]
1545  else:
1546  self.index_dictionary[name].append(resindex)
1547  resindex += 1
1548 
1549  coords = np.array(coords)
1550  radii = np.array(radii)
1551 
1552  distances = cdist(coords, coords)
1553  distances = (distances - radii).T - radii
1554 
1555  distances = np.where(distances <= 20.0, 1.0, 0)
1556  if self.contactmap is None:
1557  self.contactmap = np.zeros((len(coords), len(coords)))
1558  self.contactmap += distances
1559 
1560  for prot in prots: IMP.atom.destroy(prot)
1561 
1562  def set_crosslinks(
1563  self, data_file, search_label='ISDCrossLinkMS_Distance_',
1564  mapping=None,
1565  filter_label=None,
1566  filter_rmf_file_names=None, #provide a list of rmf base names to filter the stat file
1567  external_csv_data_file=None,
1568  external_csv_data_file_unique_id_key="Unique ID"):
1569 
1570  # example key ISDCrossLinkMS_Distance_intrarb_937-State:0-108:RPS3_55:RPS30-1-1-0.1_None
1571  # mapping is a dictionary that maps standard keywords to entry positions in the key string
1572  # confidence class is a filter that
1573  # external datafile is a datafile that contains further information on the crosslinks
1574  # it will use the unique id to create the dictionary keys
1575 
1576  po = IMP.pmi.output.ProcessOutput(data_file)
1577  keys = po.get_keys()
1578 
1579  xl_keys = [k for k in keys if search_label in k]
1580 
1581  if filter_rmf_file_names is not None:
1582  rmf_file_key="local_rmf_file_name"
1583  fs = po.get_fields(xl_keys+[rmf_file_key])
1584  else:
1585  fs = po.get_fields(xl_keys)
1586 
1587  # this dictionary stores the occurency of given crosslinks
1588  self.cross_link_frequency = {}
1589 
1590  # this dictionary stores the series of distances for given crosslinked
1591  # residues
1592  self.cross_link_distances = {}
1593 
1594  # this dictionary stores the series of distances for given crosslinked
1595  # residues
1596  self.cross_link_distances_unique = {}
1597 
1598  if not external_csv_data_file is None:
1599  # this dictionary stores the further information on crosslinks
1600  # labeled by unique ID
1601  self.external_csv_data = {}
1602  xldb = IMP.pmi.tools.get_db_from_csv(external_csv_data_file)
1603 
1604  for xl in xldb:
1605  self.external_csv_data[
1606  xl[external_csv_data_file_unique_id_key]] = xl
1607 
1608  # this list keeps track the tuple of cross-links and sample
1609  # so that we don't count twice the same crosslinked residues in the
1610  # same sample
1611  cross_link_frequency_list = []
1612 
1613  self.unique_cross_link_list = []
1614 
1615  for key in xl_keys:
1616  print(key)
1617  keysplit = key.replace(
1618  "_",
1619  " ").replace(
1620  "-",
1621  " ").replace(
1622  ":",
1623  " ").split(
1624  )
1625 
1626  if filter_label!=None:
1627  if filter_label not in keysplit: continue
1628 
1629  if mapping is None:
1630  r1 = int(keysplit[5])
1631  c1 = keysplit[6]
1632  r2 = int(keysplit[7])
1633  c2 = keysplit[8]
1634  try:
1635  confidence = keysplit[12]
1636  except:
1637  confidence = '0.0'
1638  try:
1639  unique_identifier = keysplit[3]
1640  except:
1641  unique_identifier = '0'
1642  else:
1643  r1 = int(keysplit[mapping["Residue1"]])
1644  c1 = keysplit[mapping["Protein1"]]
1645  r2 = int(keysplit[mapping["Residue2"]])
1646  c2 = keysplit[mapping["Protein2"]]
1647  try:
1648  confidence = keysplit[mapping["Confidence"]]
1649  except:
1650  confidence = '0.0'
1651  try:
1652  unique_identifier = keysplit[mapping["Unique Identifier"]]
1653  except:
1654  unique_identifier = '0'
1655 
1656  self.crosslinkedprots.add(c1)
1657  self.crosslinkedprots.add(c2)
1658 
1659  # construct the list of distances corresponding to the input rmf
1660  # files
1661 
1662  dists=[]
1663  if filter_rmf_file_names is not None:
1664  for n,d in enumerate(fs[key]):
1665  if fs[rmf_file_key][n] in filter_rmf_file_names:
1666  dists.append(float(d))
1667  else:
1668  dists=[float(f) for f in fs[key]]
1669 
1670  # check if the input confidence class corresponds to the
1671  # one of the cross-link
1672 
1673  mdist = self.median(dists)
1674 
1675  stdv = np.std(np.array(dists))
1676  if self.mindist > mdist:
1677  self.mindist = mdist
1678  if self.maxdist < mdist:
1679  self.maxdist = mdist
1680 
1681  # calculate the frequency of unique crosslinks within the same
1682  # sample
1683  if not self.external_csv_data is None:
1684  sample = self.external_csv_data[unique_identifier]["Sample"]
1685  else:
1686  sample = "None"
1687 
1688  if (r1, c1, r2, c2,mdist) not in cross_link_frequency_list:
1689  if (r1, c1, r2, c2) not in self.cross_link_frequency:
1690  self.cross_link_frequency[(r1, c1, r2, c2)] = 1
1691  self.cross_link_frequency[(r2, c2, r1, c1)] = 1
1692  else:
1693  self.cross_link_frequency[(r2, c2, r1, c1)] += 1
1694  self.cross_link_frequency[(r1, c1, r2, c2)] += 1
1695  cross_link_frequency_list.append((r1, c1, r2, c2))
1696  cross_link_frequency_list.append((r2, c2, r1, c1))
1697  self.unique_cross_link_list.append(
1698  (r1, c1, r2, c2,mdist))
1699 
1700  if (r1, c1, r2, c2) not in self.cross_link_distances:
1701  self.cross_link_distances[(
1702  r1,
1703  c1,
1704  r2,
1705  c2,
1706  mdist,
1707  confidence)] = dists
1708  self.cross_link_distances[(
1709  r2,
1710  c2,
1711  r1,
1712  c1,
1713  mdist,
1714  confidence)] = dists
1715  self.cross_link_distances_unique[(r1, c1, r2, c2)] = dists
1716  else:
1717  self.cross_link_distances[(
1718  r2,
1719  c2,
1720  r1,
1721  c1,
1722  mdist,
1723  confidence)] += dists
1724  self.cross_link_distances[(
1725  r1,
1726  c1,
1727  r2,
1728  c2,
1729  mdist,
1730  confidence)] += dists
1731 
1732  self.crosslinks.append(
1733  (r1,
1734  c1,
1735  r2,
1736  c2,
1737  mdist,
1738  stdv,
1739  confidence,
1740  unique_identifier,
1741  'original'))
1742  self.crosslinks.append(
1743  (r2,
1744  c2,
1745  r1,
1746  c1,
1747  mdist,
1748  stdv,
1749  confidence,
1750  unique_identifier,
1751  'reversed'))
1752 
1753  self.cross_link_frequency_inverted = {}
1754  for xl in self.unique_cross_link_list:
1755  (r1, c1, r2, c2, mdist) = xl
1756  frequency = self.cross_link_frequency[(r1, c1, r2, c2)]
1757  if frequency not in self.cross_link_frequency_inverted:
1758  self.cross_link_frequency_inverted[
1759  frequency] = [(r1, c1, r2, c2)]
1760  else:
1761  self.cross_link_frequency_inverted[
1762  frequency].append((r1, c1, r2, c2))
1763 
1764  # -------------
1765 
1766  def median(self, mylist):
1767  sorts = sorted(mylist)
1768  length = len(sorts)
1769  print(length)
1770  if length == 1:
1771  return mylist[0]
1772  if not length % 2:
1773  return (sorts[length / 2] + sorts[length / 2 - 1]) / 2.0
1774  return sorts[length / 2]
1775 
1776  def set_threshold(self,threshold):
1777  self.threshold=threshold
1778 
1779  def set_tolerance(self,tolerance):
1780  self.tolerance=tolerance
1781 
1782  def colormap(self, dist):
1783  if dist < self.threshold - self.tolerance:
1784  return "Green"
1785  elif dist >= self.threshold + self.tolerance:
1786  return "Orange"
1787  else:
1788  return "Red"
1789 
1790  def write_cross_link_database(self, filename, format='csv'):
1791  import csv
1792 
1793  fieldnames = [
1794  "Unique ID", "Protein1", "Residue1", "Protein2", "Residue2",
1795  "Median Distance", "Standard Deviation", "Confidence", "Frequency", "Arrangement"]
1796 
1797  if not self.external_csv_data is None:
1798  keys = list(self.external_csv_data.keys())
1799  innerkeys = list(self.external_csv_data[keys[0]].keys())
1800  innerkeys.sort()
1801  fieldnames += innerkeys
1802 
1803  dw = csv.DictWriter(
1804  open(filename,
1805  "w"),
1806  delimiter=',',
1807  fieldnames=fieldnames)
1808  dw.writeheader()
1809  for xl in self.crosslinks:
1810  (r1, c1, r2, c2, mdist, stdv, confidence,
1811  unique_identifier, descriptor) = xl
1812  if descriptor == 'original':
1813  outdict = {}
1814  outdict["Unique ID"] = unique_identifier
1815  outdict["Protein1"] = c1
1816  outdict["Protein2"] = c2
1817  outdict["Residue1"] = r1
1818  outdict["Residue2"] = r2
1819  outdict["Median Distance"] = mdist
1820  outdict["Standard Deviation"] = stdv
1821  outdict["Confidence"] = confidence
1822  outdict["Frequency"] = self.cross_link_frequency[
1823  (r1, c1, r2, c2)]
1824  if c1 == c2:
1825  arrangement = "Intra"
1826  else:
1827  arrangement = "Inter"
1828  outdict["Arrangement"] = arrangement
1829  if not self.external_csv_data is None:
1830  outdict.update(self.external_csv_data[unique_identifier])
1831 
1832  dw.writerow(outdict)
1833 
1834  def plot(self, prot_listx=None, prot_listy=None, no_dist_info=False,
1835  no_confidence_info=False, filter=None, layout="whole", crosslinkedonly=False,
1836  filename=None, confidence_classes=None, alphablend=0.1, scale_symbol_size=1.0,
1837  gap_between_components=0,
1838  rename_protein_map=None):
1839  # layout can be:
1840  # "lowerdiagonal" print only the lower diagonal plot
1841  # "upperdiagonal" print only the upper diagonal plot
1842  # "whole" print all
1843  # crosslinkedonly: plot only components that have crosslinks
1844  # no_dist_info: if True will plot only the cross-links as grey spots
1845  # filter = tuple the tuple contains a keyword to be search in the database
1846  # a relationship ">","==","<"
1847  # and a value
1848  # example ("ID_Score",">",40)
1849  # scale_symbol_size rescale the symbol for the crosslink
1850  # rename_protein_map is a dictionary to rename proteins
1851 
1852  import matplotlib.pyplot as plt
1853  import matplotlib.cm as cm
1854 
1855  fig = plt.figure(figsize=(10, 10))
1856  ax = fig.add_subplot(111)
1857 
1858  ax.set_xticks([])
1859  ax.set_yticks([])
1860 
1861  # set the list of proteins on the x axis
1862  if prot_listx is None:
1863  if crosslinkedonly:
1864  prot_listx = list(self.crosslinkedprots)
1865  else:
1866  prot_listx = list(self.prot_length_dict.keys())
1867  prot_listx.sort()
1868 
1869  nresx = gap_between_components + \
1870  sum([self.prot_length_dict[name]
1871  + gap_between_components for name in prot_listx])
1872 
1873  # set the list of proteins on the y axis
1874 
1875  if prot_listy is None:
1876  if crosslinkedonly:
1877  prot_listy = list(self.crosslinkedprots)
1878  else:
1879  prot_listy = list(self.prot_length_dict.keys())
1880  prot_listy.sort()
1881 
1882  nresy = gap_between_components + \
1883  sum([self.prot_length_dict[name]
1884  + gap_between_components for name in prot_listy])
1885 
1886  # this is the residue offset for each protein
1887  resoffsetx = {}
1888  resendx = {}
1889  res = gap_between_components
1890  for prot in prot_listx:
1891  resoffsetx[prot] = res
1892  res += self.prot_length_dict[prot]
1893  resendx[prot] = res
1894  res += gap_between_components
1895 
1896  resoffsety = {}
1897  resendy = {}
1898  res = gap_between_components
1899  for prot in prot_listy:
1900  resoffsety[prot] = res
1901  res += self.prot_length_dict[prot]
1902  resendy[prot] = res
1903  res += gap_between_components
1904 
1905  resoffsetdiagonal = {}
1906  res = gap_between_components
1907  for prot in IMP.pmi.tools.OrderedSet(prot_listx + prot_listy):
1908  resoffsetdiagonal[prot] = res
1909  res += self.prot_length_dict[prot]
1910  res += gap_between_components
1911 
1912  # plot protein boundaries
1913 
1914  xticks = []
1915  xlabels = []
1916  for n, prot in enumerate(prot_listx):
1917  res = resoffsetx[prot]
1918  end = resendx[prot]
1919  for proty in prot_listy:
1920  resy = resoffsety[proty]
1921  endy = resendy[proty]
1922  ax.plot([res, res], [resy, endy], 'k-', lw=0.4)
1923  ax.plot([end, end], [resy, endy], 'k-', lw=0.4)
1924  xticks.append((float(res) + float(end)) / 2)
1925  if rename_protein_map is not None:
1926  if prot in rename_protein_map:
1927  prot=rename_protein_map[prot]
1928  xlabels.append(prot)
1929 
1930  yticks = []
1931  ylabels = []
1932  for n, prot in enumerate(prot_listy):
1933  res = resoffsety[prot]
1934  end = resendy[prot]
1935  for protx in prot_listx:
1936  resx = resoffsetx[protx]
1937  endx = resendx[protx]
1938  ax.plot([resx, endx], [res, res], 'k-', lw=0.4)
1939  ax.plot([resx, endx], [end, end], 'k-', lw=0.4)
1940  yticks.append((float(res) + float(end)) / 2)
1941  if rename_protein_map is not None:
1942  if prot in rename_protein_map:
1943  prot=rename_protein_map[prot]
1944  ylabels.append(prot)
1945 
1946  # plot the contact map
1947  print(prot_listx, prot_listy)
1948 
1949  if not self.contactmap is None:
1950  import matplotlib.cm as cm
1951  tmp_array = np.zeros((nresx, nresy))
1952 
1953  for px in prot_listx:
1954  print(px)
1955  for py in prot_listy:
1956  print(py)
1957  resx = resoffsety[px]
1958  lengx = resendx[px] - 1
1959  resy = resoffsety[py]
1960  lengy = resendy[py] - 1
1961  indexes_x = self.index_dictionary[px]
1962  minx = min(indexes_x)
1963  maxx = max(indexes_x)
1964  indexes_y = self.index_dictionary[py]
1965  miny = min(indexes_y)
1966  maxy = max(indexes_y)
1967 
1968  print(px, py, minx, maxx, miny, maxy)
1969 
1970  try:
1971  tmp_array[
1972  resx:lengx,
1973  resy:lengy] = self.contactmap[
1974  minx:maxx,
1975  miny:maxy]
1976  except:
1977  continue
1978 
1979 
1980  ax.imshow(tmp_array,
1981  cmap=cm.binary,
1982  origin='lower',
1983  interpolation='nearest')
1984 
1985  ax.set_xticks(xticks)
1986  ax.set_xticklabels(xlabels, rotation=90)
1987  ax.set_yticks(yticks)
1988  ax.set_yticklabels(ylabels)
1989  ax.set_xlim(0,nresx)
1990  ax.set_ylim(0,nresy)
1991 
1992 
1993  # set the crosslinks
1994 
1995  already_added_xls = []
1996 
1997  for xl in self.crosslinks:
1998 
1999  (r1, c1, r2, c2, mdist, stdv, confidence,
2000  unique_identifier, descriptor) = xl
2001 
2002  if confidence_classes is not None:
2003  if confidence not in confidence_classes:
2004  continue
2005 
2006  try:
2007  pos1 = r1 + resoffsetx[c1]
2008  except:
2009  continue
2010  try:
2011  pos2 = r2 + resoffsety[c2]
2012  except:
2013  continue
2014 
2015  if not filter is None:
2016  xldb = self.external_csv_data[unique_identifier]
2017  xldb.update({"Distance": mdist})
2018  xldb.update({"Distance_stdv": stdv})
2019 
2020  if filter[1] == ">":
2021  if float(xldb[filter[0]]) <= float(filter[2]):
2022  continue
2023 
2024  if filter[1] == "<":
2025  if float(xldb[filter[0]]) >= float(filter[2]):
2026  continue
2027 
2028  if filter[1] == "==":
2029  if float(xldb[filter[0]]) != float(filter[2]):
2030  continue
2031 
2032  # all that below is used for plotting the diagonal
2033  # when you have a rectangolar plots
2034 
2035  pos_for_diagonal1 = r1 + resoffsetdiagonal[c1]
2036  pos_for_diagonal2 = r2 + resoffsetdiagonal[c2]
2037 
2038  if layout == 'lowerdiagonal':
2039  if pos_for_diagonal1 <= pos_for_diagonal2:
2040  continue
2041  if layout == 'upperdiagonal':
2042  if pos_for_diagonal1 >= pos_for_diagonal2:
2043  continue
2044 
2045  already_added_xls.append((r1, c1, r2, c2))
2046 
2047  if not no_confidence_info:
2048  if confidence == '0.01':
2049  markersize = 14 * scale_symbol_size
2050  elif confidence == '0.05':
2051  markersize = 9 * scale_symbol_size
2052  elif confidence == '0.1':
2053  markersize = 6 * scale_symbol_size
2054  else:
2055  markersize = 15 * scale_symbol_size
2056  else:
2057  markersize = 5 * scale_symbol_size
2058 
2059  if not no_dist_info:
2060  color = self.colormap(mdist)
2061  else:
2062  color = "gray"
2063 
2064  ax.plot(
2065  [pos1],
2066  [pos2],
2067  'o',
2068  c=color,
2069  alpha=alphablend,
2070  markersize=markersize)
2071 
2072 
2073 
2074  fig.set_size_inches(0.004 * nresx, 0.004 * nresy)
2075 
2076  [i.set_linewidth(2.0) for i in ax.spines.values()]
2077 
2078  #plt.tight_layout()
2079 
2080  if filename:
2081  plt.savefig(filename + ".pdf", dpi=300, transparent="False")
2082  else:
2083  plt.show()
2084 
2085  def get_frequency_statistics(self, prot_list,
2086  prot_list2=None):
2087 
2088  violated_histogram = {}
2089  satisfied_histogram = {}
2090  unique_cross_links = []
2091 
2092  for xl in self.unique_cross_link_list:
2093  (r1, c1, r2, c2, mdist) = xl
2094 
2095  # here we filter by the protein
2096  if prot_list2 is None:
2097  if not c1 in prot_list:
2098  continue
2099  if not c2 in prot_list:
2100  continue
2101  else:
2102  if c1 in prot_list and c2 in prot_list2:
2103  pass
2104  elif c1 in prot_list2 and c2 in prot_list:
2105  pass
2106  else:
2107  continue
2108 
2109  frequency = self.cross_link_frequency[(r1, c1, r2, c2)]
2110 
2111  if (r1, c1, r2, c2) not in unique_cross_links:
2112  if mdist > 35.0:
2113  if frequency not in violated_histogram:
2114  violated_histogram[frequency] = 1
2115  else:
2116  violated_histogram[frequency] += 1
2117  else:
2118  if frequency not in satisfied_histogram:
2119  satisfied_histogram[frequency] = 1
2120  else:
2121  satisfied_histogram[frequency] += 1
2122  unique_cross_links.append((r1, c1, r2, c2))
2123  unique_cross_links.append((r2, c2, r1, c1))
2124 
2125  print("# satisfied")
2126 
2127  total_number_of_crosslinks=0
2128 
2129  for i in satisfied_histogram:
2130  # if i in violated_histogram:
2131  # print i, satisfied_histogram[i]+violated_histogram[i]
2132  # else:
2133  if i in violated_histogram:
2134  print(i, violated_histogram[i]+satisfied_histogram[i])
2135  else:
2136  print(i, satisfied_histogram[i])
2137  total_number_of_crosslinks+=i*satisfied_histogram[i]
2138 
2139  print("# violated")
2140 
2141  for i in violated_histogram:
2142  print(i, violated_histogram[i])
2143  total_number_of_crosslinks+=i*violated_histogram[i]
2144 
2145  print(total_number_of_crosslinks)
2146 
2147 
2148 # ------------
2149  def print_cross_link_binary_symbols(self, prot_list,
2150  prot_list2=None):
2151  tmp_matrix = []
2152  confidence_list = []
2153  for xl in self.crosslinks:
2154  (r1, c1, r2, c2, mdist, stdv, confidence,
2155  unique_identifier, descriptor) = xl
2156 
2157  if prot_list2 is None:
2158  if not c1 in prot_list:
2159  continue
2160  if not c2 in prot_list:
2161  continue
2162  else:
2163  if c1 in prot_list and c2 in prot_list2:
2164  pass
2165  elif c1 in prot_list2 and c2 in prot_list:
2166  pass
2167  else:
2168  continue
2169 
2170  if descriptor != "original":
2171  continue
2172 
2173  confidence_list.append(confidence)
2174 
2175  dists = self.cross_link_distances_unique[(r1, c1, r2, c2)]
2176  tmp_dist_binary = []
2177  for d in dists:
2178  if d < 35:
2179  tmp_dist_binary.append(1)
2180  else:
2181  tmp_dist_binary.append(0)
2182  tmp_matrix.append(tmp_dist_binary)
2183 
2184  matrix = list(zip(*tmp_matrix))
2185 
2186  satisfied_high_sum = 0
2187  satisfied_mid_sum = 0
2188  satisfied_low_sum = 0
2189  total_satisfied_sum = 0
2190  for k, m in enumerate(matrix):
2191  satisfied_high = 0
2192  total_high = 0
2193  satisfied_mid = 0
2194  total_mid = 0
2195  satisfied_low = 0
2196  total_low = 0
2197  total_satisfied = 0
2198  total = 0
2199  for n, b in enumerate(m):
2200  if confidence_list[n] == "0.01":
2201  total_high += 1
2202  if b == 1:
2203  satisfied_high += 1
2204  satisfied_high_sum += 1
2205  elif confidence_list[n] == "0.05":
2206  total_mid += 1
2207  if b == 1:
2208  satisfied_mid += 1
2209  satisfied_mid_sum += 1
2210  elif confidence_list[n] == "0.1":
2211  total_low += 1
2212  if b == 1:
2213  satisfied_low += 1
2214  satisfied_low_sum += 1
2215  if b == 1:
2216  total_satisfied += 1
2217  total_satisfied_sum += 1
2218  total += 1
2219  print(k, satisfied_high, total_high)
2220  print(k, satisfied_mid, total_mid)
2221  print(k, satisfied_low, total_low)
2222  print(k, total_satisfied, total)
2223  print(float(satisfied_high_sum) / len(matrix))
2224  print(float(satisfied_mid_sum) / len(matrix))
2225  print(float(satisfied_low_sum) / len(matrix))
2226 # ------------
2227 
2228  def get_unique_crosslinks_statistics(self, prot_list,
2229  prot_list2=None):
2230 
2231  print(prot_list)
2232  print(prot_list2)
2233  satisfied_high = 0
2234  total_high = 0
2235  satisfied_mid = 0
2236  total_mid = 0
2237  satisfied_low = 0
2238  total_low = 0
2239  total = 0
2240  tmp_matrix = []
2241  satisfied_string = []
2242  for xl in self.crosslinks:
2243  (r1, c1, r2, c2, mdist, stdv, confidence,
2244  unique_identifier, descriptor) = xl
2245 
2246  if prot_list2 is None:
2247  if not c1 in prot_list:
2248  continue
2249  if not c2 in prot_list:
2250  continue
2251  else:
2252  if c1 in prot_list and c2 in prot_list2:
2253  pass
2254  elif c1 in prot_list2 and c2 in prot_list:
2255  pass
2256  else:
2257  continue
2258 
2259  if descriptor != "original":
2260  continue
2261 
2262  total += 1
2263  if confidence == "0.01":
2264  total_high += 1
2265  if mdist <= 35:
2266  satisfied_high += 1
2267  if confidence == "0.05":
2268  total_mid += 1
2269  if mdist <= 35:
2270  satisfied_mid += 1
2271  if confidence == "0.1":
2272  total_low += 1
2273  if mdist <= 35:
2274  satisfied_low += 1
2275  if mdist <= 35:
2276  satisfied_string.append(1)
2277  else:
2278  satisfied_string.append(0)
2279 
2280  dists = self.cross_link_distances_unique[(r1, c1, r2, c2)]
2281  tmp_dist_binary = []
2282  for d in dists:
2283  if d < 35:
2284  tmp_dist_binary.append(1)
2285  else:
2286  tmp_dist_binary.append(0)
2287  tmp_matrix.append(tmp_dist_binary)
2288 
2289  print("unique satisfied_high/total_high", satisfied_high, "/", total_high)
2290  print("unique satisfied_mid/total_mid", satisfied_mid, "/", total_mid)
2291  print("unique satisfied_low/total_low", satisfied_low, "/", total_low)
2292  print("total", total)
2293 
2294  matrix = list(zip(*tmp_matrix))
2295  satisfied_models = 0
2296  satstr = ""
2297  for b in satisfied_string:
2298  if b == 0:
2299  satstr += "-"
2300  if b == 1:
2301  satstr += "*"
2302 
2303  for m in matrix:
2304  all_satisfied = True
2305  string = ""
2306  for n, b in enumerate(m):
2307  if b == 0:
2308  string += "0"
2309  if b == 1:
2310  string += "1"
2311  if b == 1 and satisfied_string[n] == 1:
2312  pass
2313  elif b == 1 and satisfied_string[n] == 0:
2314  pass
2315  elif b == 0 and satisfied_string[n] == 0:
2316  pass
2317  elif b == 0 and satisfied_string[n] == 1:
2318  all_satisfied = False
2319  if all_satisfied:
2320  satisfied_models += 1
2321  print(string)
2322  print(satstr, all_satisfied)
2323  print("models that satisfies the median satisfied crosslinks/total models", satisfied_models, len(matrix))
2324 
2325  def plot_matrix_cross_link_distances_unique(self, figurename, prot_list,
2326  prot_list2=None):
2327 
2328  import pylab as pl
2329 
2330  tmp_matrix = []
2331  for kw in self.cross_link_distances_unique:
2332  (r1, c1, r2, c2) = kw
2333  dists = self.cross_link_distances_unique[kw]
2334 
2335  if prot_list2 is None:
2336  if not c1 in prot_list:
2337  continue
2338  if not c2 in prot_list:
2339  continue
2340  else:
2341  if c1 in prot_list and c2 in prot_list2:
2342  pass
2343  elif c1 in prot_list2 and c2 in prot_list:
2344  pass
2345  else:
2346  continue
2347  # append the sum of dists to order by that in the matrix plot
2348  dists.append(sum(dists))
2349  tmp_matrix.append(dists)
2350 
2351  tmp_matrix.sort(key=itemgetter(len(tmp_matrix[0]) - 1))
2352 
2353  # print len(tmp_matrix), len(tmp_matrix[0])-1
2354  matrix = np.zeros((len(tmp_matrix), len(tmp_matrix[0]) - 1))
2355 
2356  for i in range(len(tmp_matrix)):
2357  for k in range(len(tmp_matrix[i]) - 1):
2358  matrix[i][k] = tmp_matrix[i][k]
2359 
2360  print(matrix)
2361 
2362  fig = pl.figure()
2363  ax = fig.add_subplot(211)
2364 
2365  cax = ax.imshow(matrix, interpolation='nearest')
2366  # ax.set_yticks(range(len(self.model_list_names)))
2367  #ax.set_yticklabels( [self.model_list_names[i] for i in leaves_order] )
2368  fig.colorbar(cax)
2369  pl.savefig(figurename, dpi=300)
2370  pl.show()
2371 
2372  def plot_bars(
2373  self,
2374  filename,
2375  prots1,
2376  prots2,
2377  nxl_per_row=20,
2378  arrangement="inter",
2379  confidence_input="None"):
2380 
2381  data = []
2382  for xl in self.cross_link_distances:
2383  (r1, c1, r2, c2, mdist, confidence) = xl
2384  if c1 in prots1 and c2 in prots2:
2385  if arrangement == "inter" and c1 == c2:
2386  continue
2387  if arrangement == "intra" and c1 != c2:
2388  continue
2389  if confidence_input == confidence:
2390  label = str(c1) + ":" + str(r1) + \
2391  "-" + str(c2) + ":" + str(r2)
2392  values = self.cross_link_distances[xl]
2393  frequency = self.cross_link_frequency[(r1, c1, r2, c2)]
2394  data.append((label, values, mdist, frequency))
2395 
2396  sort_by_dist = sorted(data, key=lambda tup: tup[2])
2397  sort_by_dist = list(zip(*sort_by_dist))
2398  values = sort_by_dist[1]
2399  positions = list(range(len(values)))
2400  labels = sort_by_dist[0]
2401  frequencies = list(map(float, sort_by_dist[3]))
2402  frequencies = [f * 10.0 for f in frequencies]
2403 
2404  nchunks = int(float(len(values)) / nxl_per_row)
2405  values_chunks = IMP.pmi.tools.chunk_list_into_segments(values, nchunks)
2406  positions_chunks = IMP.pmi.tools.chunk_list_into_segments(
2407  positions,
2408  nchunks)
2409  frequencies_chunks = IMP.pmi.tools.chunk_list_into_segments(
2410  frequencies,
2411  nchunks)
2412  labels_chunks = IMP.pmi.tools.chunk_list_into_segments(labels, nchunks)
2413 
2414  for n, v in enumerate(values_chunks):
2415  p = positions_chunks[n]
2416  f = frequencies_chunks[n]
2417  l = labels_chunks[n]
2419  filename + "." + str(n), v, p, f,
2420  valuename="Distance (Ang)", positionname="Unique " + arrangement + " Crosslinks", xlabels=l)
2421 
2422  def crosslink_distance_histogram(self, filename,
2423  prot_list=None,
2424  prot_list2=None,
2425  confidence_classes=None,
2426  bins=40,
2427  color='#66CCCC',
2428  yplotrange=[0, 1],
2429  format="png",
2430  normalized=False):
2431  if prot_list is None:
2432  prot_list = list(self.prot_length_dict.keys())
2433 
2434  distances = []
2435  for xl in self.crosslinks:
2436  (r1, c1, r2, c2, mdist, stdv, confidence,
2437  unique_identifier, descriptor) = xl
2438 
2439  if not confidence_classes is None:
2440  if confidence not in confidence_classes:
2441  continue
2442 
2443  if prot_list2 is None:
2444  if not c1 in prot_list:
2445  continue
2446  if not c2 in prot_list:
2447  continue
2448  else:
2449  if c1 in prot_list and c2 in prot_list2:
2450  pass
2451  elif c1 in prot_list2 and c2 in prot_list:
2452  pass
2453  else:
2454  continue
2455 
2456  distances.append(mdist)
2457 
2459  filename, distances, valuename="C-alpha C-alpha distance [Ang]",
2460  bins=bins, color=color,
2461  format=format,
2462  reference_xline=35.0,
2463  yplotrange=yplotrange, normalized=normalized)
2464 
2465  def scatter_plot_xl_features(self, filename,
2466  feature1=None,
2467  feature2=None,
2468  prot_list=None,
2469  prot_list2=None,
2470  yplotrange=None,
2471  reference_ylines=None,
2472  distance_color=True,
2473  format="png"):
2474  import matplotlib.pyplot as plt
2475  import matplotlib.cm as cm
2476 
2477  fig = plt.figure(figsize=(10, 10))
2478  ax = fig.add_subplot(111)
2479 
2480  for xl in self.crosslinks:
2481  (r1, c1, r2, c2, mdist, stdv, confidence,
2482  unique_identifier, arrangement) = xl
2483 
2484  if prot_list2 is None:
2485  if not c1 in prot_list:
2486  continue
2487  if not c2 in prot_list:
2488  continue
2489  else:
2490  if c1 in prot_list and c2 in prot_list2:
2491  pass
2492  elif c1 in prot_list2 and c2 in prot_list:
2493  pass
2494  else:
2495  continue
2496 
2497  xldb = self.external_csv_data[unique_identifier]
2498  xldb.update({"Distance": mdist})
2499  xldb.update({"Distance_stdv": stdv})
2500 
2501  xvalue = float(xldb[feature1])
2502  yvalue = float(xldb[feature2])
2503 
2504  if distance_color:
2505  color = self.colormap(mdist)
2506  else:
2507  color = "gray"
2508 
2509  ax.plot([xvalue], [yvalue], 'o', c=color, alpha=0.1, markersize=7)
2510 
2511  if not yplotrange is None:
2512  ax.set_ylim(yplotrange)
2513  if not reference_ylines is None:
2514  for rl in reference_ylines:
2515  ax.axhline(rl, color='red', linestyle='dashed', linewidth=1)
2516 
2517  if filename:
2518  plt.savefig(filename + "." + format, dpi=150, transparent="False")
2519 
2520  plt.show()
Simple 3D transformation class.
Visualization of crosslinks.
A decorator to associate a particle with a part of a protein/DNA/RNA.
Definition: Fragment.h:20
double get_drms(const Vector3DsOrXYZs0 &m1, const Vector3DsOrXYZs1 &m2)
DensityMap * create_density_map(const algebra::GridD< 3, algebra::DenseGridStorageD< 3, float >, float > &grid)
A class for reading stat files.
Definition: output.py:633
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
def plot_field_histogram
Plot a list of histograms from a value list.
Definition: output.py:894
def plot_fields_box_plots
Plot time series as boxplots.
Definition: output.py:959
double get_drmsd(const Vector3DsOrXYZs0 &m0, const Vector3DsOrXYZs1 &m1)
Calculate distance the root mean square deviation between two sets of 3D points.
Definition: atom/distance.h:49
Miscellaneous utilities.
Definition: tools.py:1
def __init__
Constructor.
double get_weighted_rmsd(const Vector3DsOrXYZs0 &m1, const Vector3DsOrXYZs1 &m2, const Floats &weights)
void link_restraints(RMF::FileConstHandle fh, const Restraints &hs)
A class to evaluate the precision of an ensemble.
log
Definition: log.py:1
A class to cluster structures.
void write_map(DensityMap *m, std::string filename)
Write a density map to a file.
def __init__
Constructor.
Definition: pmi/Analysis.py:32
Class for sampling a density map from particles.
def add_structure
Read a structure into the ensemble and store (as coordinates).
DensityMap * multiply(const DensityMap *m1, const DensityMap *m2)
Performs alignment and RMSD calculation for two sets of coordinates.
Definition: pmi/Analysis.py:23
def add_subunits_density
Add a frame to the densities.
def scatter_and_gather
Synchronize data over a parallel run.
Definition: tools.py:1001
double get_rmsd(const Vector3DsOrXYZs0 &m1, const Vector3DsOrXYZs1 &m2)
Basic utilities for handling cryo-electron microscopy 3D density maps.
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
def get_particles_at_resolution_one
Get particles at res 1, or any beads, based on the name.
def fill
Add coordinates for a single model.
Tools for clustering and cluster analysis.
Definition: pmi/Analysis.py:1
def do_cluster
Run K-means clustering.
void destroy(Hierarchy d)
Delete the Hierarchy.
3D rotation class.
Definition: Rotation3D.h:46
algebra::BoundingBoxD< 3 > get_bounding_box(const DensityMap *m)
Definition: DensityMap.h:464
Transformation3D get_identity_transformation_3d()
Return a transformation that does not do anything.
Classes for writing output files and processing them.
Definition: output.py:1
def set_reference_structure
Read in a structure used for reference computation.
def get_rmsf
Calculate the residue mean square fluctuations (RMSF).
def get_average_distance_wrt_reference_structure
Compare the structure set to the reference structure.
def get_density
Get the current density for some component name.
General purpose algebraic and geometric methods that are expected to be used by a wide variety of IMP...
The general base class for IMP exceptions.
Definition: exception.h:49
def get_precision
Evaluate the precision of two named structure groups.
VectorD< 3 > Vector3D
Definition: VectorD.h:395
Restraints create_restraints(RMF::FileConstHandle fh, Model *m)
void link_hierarchies(RMF::FileConstHandle fh, const atom::Hierarchies &hs)
def get_particles_at_resolution_ten
Get particles at res 10, or any beads, based on the name.
Python classes to represent, score, sample and analyze models.
def add_structures
Read a list of RMFs, supports parallel.
Transformation3D get_transformation_aligning_first_to_second(Vector3Ds a, Vector3Ds b)
Hierarchies get_leaves(const Selection &h)
double get_drmsd_Q(const Vector3DsOrXYZs0 &m0, const Vector3DsOrXYZs1 &m1, double threshold)
Definition: atom/distance.h:85
Select hierarchy particles identified by the biological name.
Definition: Selection.h:65
Compute mean density maps from structures.
Support for the RMF file format for storing hierarchical molecular data and markup.
def get_residue_indexes
Retrieve the residue indexes for the given particle.
Definition: tools.py:959
A decorator for a particle with x,y,z coordinates and a radius.
Definition: XYZR.h:27