IMP logo
IMP Reference Guide  2.5.0
The Integrative Modeling Platform
xltable.py
1 """@namespace IMP.pmi.io.xltable
2  Tools to plot a contact map overlaid with cross-links.
3 """
4 
5 from __future__ import print_function
6 from math import sqrt
7 from Bio import SeqIO
8 from Bio.PDB.PDBParser import PDBParser
9 import numpy as np
10 from scipy.spatial.distance import cdist
11 import matplotlib.pyplot as plt
12 import matplotlib.cm as cm
13 
14 from collections import defaultdict
15 import pickle
16 import IMP.pmi
17 import IMP.pmi.io
20 
21 class XLTable():
22  """ class to read, analyze, and plot xlink data on contact maps
23  Canonical way to read the data:
24  1) load sequences and name them
25  2) load coordinates for those sequences from PDB file
26  3) add crosslinks
27  4) create contact map
28  5) plot
29  """
30 
def __init__(self, contact_threshold):
    """Create an empty table with no sequences, coordinates or cross-links.

    @param contact_threshold upper distance bound used by the coordinate
           loaders when marking a residue pair as a contact
    """
    self.sequence_dict = {}
    self.cross_link_db = None
    # list of special residue pairs to display
    self.residue_pair_list = []
    # Distance map for each copy of the complex. The loaders and
    # _get_percentage_satisfaction() use the name `dist_maps`; the older
    # `distance_maps` name is kept bound to the same initial list.
    self.dist_maps = []
    self.distance_maps = self.dist_maps
    # running sum of the distance maps; set by the loaders / load_maps()
    self.av_dist_map = None
    self.contact_freqs = None
    self.num_pdbs = 0
    self.num_rmfs = 0
    # location in the dmap of each residue
    self.index_dict = defaultdict(list)
    self.contact_threshold = contact_threshold
    # internal things
    self._first = True
43 
def _colormap_distance(self, dist, threshold=35, tolerance=0):
    """Translate a distance into the colour name of a cross-link marker.

    @param dist distance to classify
    @param threshold cutoff distance; None selects the default of 35
    @param tolerance half-width of the band around the threshold
    @return "Green" if dist < threshold - tolerance, "Orange" if
            dist >= threshold + tolerance, "Red" otherwise
    """
    if threshold is None:
        # plot_table() forwards its crosslink_threshold argument, whose
        # default is None; None - tolerance would raise a TypeError.
        threshold = 35
    if dist < threshold - tolerance:
        return "Green"
    if dist >= threshold + tolerance:
        return "Orange"
    # the distance lies inside the tolerance band around the threshold
    return "Red"
51 
def _colormap_satisfaction(self, sat, threshold=0.5, tolerance=0.1):
    """Translate a satisfaction fraction into a marker colour name.

    The value and the chosen colour are also printed.

    @param sat fraction to classify
    @param threshold centre of the borderline band
    @param tolerance half-width of the borderline band
    @return "Green" if sat >= threshold + tolerance, "Orange" otherwise
    """
    is_satisfied = sat >= threshold + tolerance
    if not is_satisfied:
        # NOTE(review): values inside the band and values below it are
        # reported identically; a separate colour for the lowest class may
        # have been intended -- confirm before changing.
        print(sat, "orange")
        return "Orange"
    print(sat, "green")
    return "Green"
62 
def _get_percentage_satisfaction(self, r1, c1, r2, c2, threshold=35):
    """Fraction of the stored distance maps in which a residue pair is
    closer than a threshold.

    @param r1 position of the first residue in its chain's index list
    @param c1 name of the chain holding the first residue
    @param r2 position of the second residue in its chain's index list
    @param c2 name of the chain holding the second residue
    @param threshold distance below which the pair counts as satisfied
    @return a float between 0 and 1, or None when either residue cannot be
            located or no distance map has been loaded
    """
    # NOTE(review): r1/r2 are used directly as list positions -- confirm
    # whether callers pass 0-based or 1-based residue numbers.
    located = []
    for chain, resnum in ((c1, r1), (c2, r2)):
        try:
            # .get() instead of [] so that asking for an unknown chain does
            # not insert an empty entry into the defaultdict
            located.append(self.index_dict.get(chain, ())[resnum])
        except (IndexError, TypeError):
            return None
    idx1, idx2 = located

    dist_maps = getattr(self, "dist_maps", None)
    if not dist_maps:
        # nothing loaded yet: avoid dividing by zero
        return None
    nsatisfied = sum(1 for dists in dist_maps if dists[idx1, idx2] < threshold)
    return float(nsatisfied) / len(dist_maps)
77 
def _get_distance(self, r1, c1, r2, c2):
    """Look up the entry of the stored distance map for a residue pair.

    @param r1 position of the first residue in its chain's index list
    @param c1 name of the chain holding the first residue
    @param r2 position of the second residue in its chain's index list
    @param c2 name of the chain holding the second residue
    @return self.av_dist_map[idx1, idx2], or None when either residue
            cannot be located in self.index_dict
    """
    # NOTE(review): r1/r2 are used directly as list positions -- confirm
    # whether callers pass 0-based or 1-based residue numbers.
    located = []
    for chain, resnum in ((c1, r1), (c2, r2)):
        try:
            # .get() instead of [] so that asking for an unknown chain does
            # not insert an empty entry into the defaultdict
            located.append(self.index_dict.get(chain, ())[resnum])
        except (IndexError, TypeError):
            return None
    idx1, idx2 = located
    return self.av_dist_map[idx1, idx2]
88 
def _internal_load_maps(self, maps_fn):
    """Read the arrays stored under the keys 'cname_array', 'idx_array',
    'av_dist_map' and 'contact_map' from a numpy archive.

    @param maps_fn name of the file to read
    @return tuple (index_dict, av_dist_map, contact_map); index_dict maps
            each chain name to its list of indexes with the -1 padding
            removed
    """
    # np.load keeps the archive open; read everything inside the with
    # block so the file handle is released before returning
    with np.load(maps_fn) as npzfile:
        cname_array = npzfile['cname_array']
        idx_array = npzfile['idx_array']
        av_dist_map = npzfile['av_dist_map']
        contact_map = npzfile['contact_map']

    index_dict = {}
    for cname, idxs in zip(cname_array, idx_array):
        tmp = list(idxs)
        # rows are padded with -1 up to the length of the longest chain
        if -1 in tmp:
            tmp = tmp[0:tmp.index(-1)]
        index_dict[cname] = tmp
    return index_dict, av_dist_map, contact_map
103 
def load_sequence_from_fasta_file(self, fasta_file, id_in_fasta_file, protein_name):
    """ read sequence. structures are displayed in the same order as sequences are read.
    @param fasta_file file to read
    @param id_in_fasta_file id of desired sequence; None means that the
           record id is the same as protein_name
    @param protein_name identifier for this sequence (use same name when handling coordinates)

    The sequence is stored in self.sequence_dict[protein_name] with any
    "*" characters removed. Raises SystemExit if the id is not present in
    the file.
    """
    if id_in_fasta_file is None:
        # the original fell back to an undefined variable here
        id_in_fasta_file = protein_name

    # "rU" is rejected by recent Python 3 releases; plain text mode is used
    # and the with-block closes the file even if parsing fails
    with open(fasta_file, "r") as handle:
        record_dict = SeqIO.to_dict(SeqIO.parse(handle, "fasta"))

    try:
        record = record_dict[id_in_fasta_file]
    except KeyError:
        # keep terminating the program as before, but with a non-zero
        # exit status and the message attached to the exception
        raise SystemExit("add_component_sequence: id %s not found in fasta file"
                         % id_in_fasta_file)
    self.sequence_dict[protein_name] = str(record.seq).replace("*", "")
121 
def load_pdb_coordinates(self, pdbfile, chain_to_name_map):
    """ read coordinates from a pdb file. also appends to distance maps
    @param pdbfile file for reading coords
    @param chain_to_name_map correspond chain ID with protein name (will ONLY read these chains)
    @note Prints an error and returns without loading anything if the
          sequence for one of the requested chains has NOT been read

    Only CA atoms are used. Residues whose number lies outside
    1..len(sequence) are skipped. If the file holds several models, they
    are written into the same coordinate array, so later models overwrite
    earlier ones.
    """
    structure = PDBParser().get_structure(pdbfile, pdbfile)
    total_len = sum(len(self.sequence_dict[s]) for s in self.sequence_dict)
    coords = np.ones((total_len, 3)) * 1e5  # default to coords "very far away"

    # Work out the first row of every requested chain once, so that the
    # offsets do not keep growing from one model to the next.
    chain_layout = []
    prev_stop = 0
    for cid in chain_to_name_map:
        cname = chain_to_name_map[cid]
        if cname not in self.sequence_dict:
            print("ERROR: chain",cname,'has not been read or has a naming mismatch')
            return
        seq_len = len(self.sequence_dict[cname])
        chain_layout.append((cid, cname, prev_stop, seq_len))
        prev_stop += seq_len

    if self._first:
        for cid, cname, start, seq_len in chain_layout:
            # a real list (not a Python 3 range) so save_maps() can pad it
            self.index_dict[cname] = list(range(start, start + seq_len))

    for model in structure:
        for cid, cname, start, seq_len in chain_layout:
            for residue in model[cid]:
                if "CA" not in residue:
                    continue
                rnum = residue.id[1]
                if not 1 <= rnum <= seq_len:
                    # would land in another chain's rows (or wrap around)
                    continue
                coords[rnum + start - 1, :] = residue["CA"].get_coord()

    dists = cdist(coords, coords)
    binary_dists = np.where((dists <= self.contact_threshold) & (dists >= 1.0), 1.0, 0.0)
    if self._first:
        self.dist_maps = [dists]
        # copy: the in-place += on later calls must not modify dist_maps[0]
        self.av_dist_map = dists.copy()
        self.contact_freqs = binary_dists
        self._first = False
    else:
        self.dist_maps.append(dists)
        self.av_dist_map += dists
        self.contact_freqs += binary_dists
    self.num_pdbs += 1
161 
def load_rmf_coordinates(self, rmf_name, rmf_frame_index, chain_names):
    """ read coordinates from a rmf file. It needs IMP to run.
    rmf has been created using IMP.pmi conventions. It gets the
    highest resolution automatically. Also appends to distance maps
    @param rmf_name file for reading coords
    @param rmf_frame_index frame index from the rmf
    @param chain_names names of the molecules to read; each must be a key
           of self.sequence_dict

    A residue is skipped when the selection does not resolve to exactly
    one of the resolution-one particles.
    """
    import IMP
    import IMP.atom
    import IMP.core  # XYZ is used below

    self.imp_model = IMP.Model()
    (particles_resolution_one, prots) = self._get_rmf_structure(rmf_name, rmf_frame_index)

    total_len = sum(len(self.sequence_dict[s]) for s in self.sequence_dict)
    coords = np.ones((total_len, 3)) * 1e5  # default to coords "very far away"

    # build the lookup set once instead of once per residue
    resolution_one = set(particles_resolution_one)
    prev_stop = 0
    print(chain_names)
    for cname in chain_names:
        seq_len = len(self.sequence_dict[cname])
        if self._first:
            # a real list (not a Python 3 range) so save_maps() can pad it
            self.index_dict[cname] = list(range(prev_stop, prev_stop + seq_len))
        for rnum in range(1, seq_len + 1):
            sel = IMP.atom.Selection(prots, molecule=cname, residue_index=rnum)
            matches = resolution_one.intersection(sel.get_selected_particles())
            if len(matches) != 1:
                continue
            coords[rnum + prev_stop - 1, :] = IMP.core.XYZ(matches.pop()).get_coordinates()
        prev_stop += seq_len

    dists = cdist(coords, coords)
    binary_dists = np.where((dists <= self.contact_threshold) & (dists >= 1.0), 1.0, 0.0)
    if self._first:
        self.dist_maps = [dists]
        # copy: the in-place += on later calls must not modify dist_maps[0]
        self.av_dist_map = dists.copy()
        self.contact_freqs = binary_dists
        self._first = False
    else:
        self.dist_maps.append(dists)
        self.av_dist_map += dists
        self.contact_freqs += binary_dists
    self.num_rmfs += 1
210 
211 
def _get_rmf_structure(self, rmf_name, rmf_frame_index):
    """Load one frame of an RMF file into self.imp_model.

    @param rmf_name RMF file to open read-only
    @param rmf_frame_index frame to load
    @return tuple (particles_resolution_one, prots): a flat list of the
            particles returned by get_particles_at_resolution_one, and the
            hierarchies created from the file
    """
    import IMP.pmi
    import IMP.pmi.analysis
    import IMP.rmf
    import RMF

    rh = RMF.open_rmf_file_read_only(rmf_name)
    prots = IMP.rmf.create_hierarchies(rh, self.imp_model)
    # NOTE(review): load_frame is documented as taking an RMF::FrameID;
    # confirm that a plain integer is accepted by the installed version.
    IMP.rmf.load_frame(rh, rmf_frame_index)
    print("getting coordinates for frame %i rmf file %s" % (rmf_frame_index, rmf_name))
    del rh

    # NOTE(review): the statement defining particle_dict is missing from
    # the listing this file was recovered from; it is restored here as a
    # call on the first hierarchy -- verify against the IMP sources.
    particle_dict = IMP.pmi.analysis.get_particles_at_resolution_one(prots[0])

    particles_resolution_one = []
    for k in particle_dict:
        particles_resolution_one += particle_dict[k]

    return particles_resolution_one, prots
232 
233 
def save_maps(self, maps_fn):
    """Write the index dictionary, distance map and contact map to a numpy
    archive (np.savez) under the keys 'cname_array', 'idx_array',
    'av_dist_map' and 'contact_map'.

    @param maps_fn name of the file to write
    Raises ValueError if self.index_dict is empty.
    """
    cnames = list(self.index_dict)
    # list() so that range objects and other sequences can be padded
    rows = [list(self.index_dict[cname]) for cname in cnames]
    if not rows:
        raise ValueError("save_maps: no residue indexes to save; "
                         "load coordinates or maps first")
    maxlen = max(len(row) for row in rows)
    # pad every row with -1 so the index lists form a rectangular array
    idx_array = np.array([row + [-1] * (maxlen - len(row)) for row in rows])
    cname_array = np.array(cnames)
    np.savez(maps_fn,
             cname_array=cname_array,
             idx_array=idx_array,
             av_dist_map=self.av_dist_map,
             contact_map=self.contact_freqs)
248 
def load_maps(self, maps_fn):
    """Replace the index dictionary, distance map and contact map of this
    table with the ones stored in a maps file.

    @param maps_fn name of the file to read
    """
    loaded = self._internal_load_maps(maps_fn)
    (self.index_dict,
     self.av_dist_map,
     self.contact_freqs) = loaded
251 
def load_crosslinks(self, CrossLinkDataBase):
    """ store the cross-link database that plot_table() iterates over.
    @param CrossLinkDataBase an IMP.pmi.io.crosslink.CrossLinkDataBase
           object (instances of subclasses are accepted as well)

    Raises TypeError for any other kind of object.
    """
    # imported here because no file-level import of the submodule is
    # visible in this file
    import IMP.pmi.io.crosslink
    if not isinstance(CrossLinkDataBase, IMP.pmi.io.crosslink.CrossLinkDataBase):
        raise TypeError("Crosslink database must be a "
                        "IMP.pmi.io.crosslink.CrossLinkDataBase type")
    self.cross_link_db = CrossLinkDataBase
258 
def set_residue_pairs_to_display(self, residue_type_pair):
    """ register every pair of residues of the given two types for display
    residue types must be given in single letter code
    e.g. residue_type_pair=("K","K")

    Each match is appended to self.residue_pair_list as
    (residue number, protein, residue number, protein), with residue
    numbers counted from 1.
    """
    wanted = sorted(residue_type_pair)
    # flatten all sequences into (number, protein, code) records
    residues = [(number, prot, code)
                for prot in self.sequence_dict
                for number, code in enumerate(self.sequence_dict[prot], 1)]
    for number1, prot1, code1 in residues:
        for number2, prot2, code2 in residues:
            if sorted((code1, code2)) == wanted:
                self.residue_pair_list.append((number1, prot1, number2, prot2))
273 
def setup_contact_map(self):
    """ turn the accumulated sums into averages.
    self.av_dist_map and self.contact_freqs are divided by the number of
    structures loaded so far (PDB files plus RMF frames). Nothing happens
    if no structure has been loaded.

    Each call divides again, so call it once after all structures have
    been loaded.
    """
    # Both loaders add into the same arrays, so the divisor is the total;
    # previously nothing was normalised when both kinds had been loaded.
    n_structures = self.num_pdbs + self.num_rmfs
    if n_structures == 0:
        return
    scale = 1.0 / n_structures
    self.av_dist_map = scale * self.av_dist_map
    self.contact_freqs = scale * self.contact_freqs
284 
def setup_difference_map(self, maps_fn1, maps_fn2, thresh):
    """Compare the contact maps of two saved map files.

    self.index_dict and self.av_dist_map are taken from the first file and
    self.contact_freqs is set to an integer array with one class per entry:
      0  both contact values are 0
      1  first > thresh and second < thresh
      2  first < thresh and second > thresh
      3  anything else

    @param maps_fn1 first maps file
    @param maps_fn2 second maps file
    @param thresh contact value separating the classes
    Raises SystemExit if the index dictionaries of the two files differ.
    """
    idx1, av1, contact1 = self._internal_load_maps(maps_fn1)
    idx2, av2, contact2 = self._internal_load_maps(maps_fn2)
    if idx1 != idx2:
        # still terminates the program, now with a non-zero exit status
        raise SystemExit("UH OH: index dictionaries do not match!")
    self.index_dict = idx1
    self.av_dist_map = av1  # should we store both somehow? only needed for XL

    print('computing contact map')
    contact1 = np.asarray(contact1)
    contact2 = np.asarray(contact2)
    # np.select takes the first matching condition, which reproduces the
    # if/elif order of the classification; np.int no longer exists in
    # current NumPy, so the builtin int is used for the result type
    self.contact_freqs = np.select(
        [(contact1 == 0) & (contact2 == 0),      # white
         (contact1 > thresh) & (contact2 < thresh),   # red
         (contact1 < thresh) & (contact2 > thresh)],  # blue
        [0, 1, 2],
        default=3).astype(int)                   # green
    print('done')
307 
308 
def plot_table(self, prot_listx=None,
               prot_listy=None,
               no_dist_info=False,
               confidence_info=False,
               filter=None,
               display_residue_pairs=False,
               contactmap=False,
               filename=None,
               confidence_classes=None,
               alphablend=0.1,
               scale_symbol_size=1.0,
               gap_between_components=0,
               colormap=cm.binary,
               crosslink_threshold=None,
               colornorm=None,
               cbar_labels=None,
               color_crosslinks_by_distance=True):
    """ plot the xlink table with optional contact map.
    prot_listx: list of protein names on the x-axis (default: all loaded
                sequences, sorted by name)
    prot_listy: list of protein names on the y-axis (same default)
    no_dist_info: accepted for compatibility; not used by this method
    confidence_info: choose the marker size from the cross-link confidence
                instead of using one fixed size (see the note in the code)
    filter: accepted for compatibility; not used by this method
    display_residue_pairs: display the pairs defined in
                self.residue_pair_list whose distance is at most 40
    contactmap: display the contact map
    filename: save the figure to this file (the name is used as given)
    confidence_classes: accepted for compatibility; not used by this method
    alphablend: alpha value of the cross-link markers
    scale_symbol_size: rescale the symbol for the crosslink
    gap_between_components: empty space, in residues, around each protein
    colormap: matplotlib colormap for the contact map
    crosslink_threshold: distance threshold used to colour cross-links;
                None uses the default of _colormap_distance
    colornorm: matplotlib norm for the contact map
    cbar_labels: four tick labels for a colorbar of the contact map
    color_crosslinks_by_distance: colour by distance if True, otherwise by
                the fraction of satisfied distance maps
    """
    # imported here because no file-level import of the submodule is
    # visible in this file
    import IMP.pmi.io.crosslink

    if cbar_labels is not None and len(cbar_labels) != 4:
        # checked before the figure is created; the message used to ask
        # for 3 fields although 4 are required
        raise SystemExit("to provide cbar labels, give 4 fields (first=first input file, last=last input) in oppose order of input contact maps")

    # prepare figure
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.set_xticks([])
    ax.set_yticks([])

    # sorted() works for Python 2 lists and Python 3 key views alike
    if prot_listx is None:
        prot_listx = sorted(self.sequence_dict)
    if prot_listy is None:
        prot_listy = sorted(self.sequence_dict)

    def layout(names):
        """Return (first position, end position, total size) for an axis."""
        offsets = {}
        ends = {}
        pos = gap_between_components
        for name in names:
            offsets[name] = pos
            pos += len(self.sequence_dict[name])
            ends[name] = pos
            pos += gap_between_components
        return offsets, ends, pos

    # this is the residue offset for each protein
    resoffsetx, resendx, nresx = layout(prot_listx)
    resoffsety, resendy, nresy = layout(prot_listy)

    # plot protein boundaries
    xticks = []
    for prot in prot_listx:
        start, end = resoffsetx[prot], resendx[prot]
        for proty in prot_listy:
            span = [resoffsety[proty], resendy[proty]]
            ax.plot([start, start], span, 'k-', lw=0.4)
            ax.plot([end, end], span, 'k-', lw=0.4)
        xticks.append((float(start) + float(end)) / 2)
    xlabels = list(prot_listx)

    yticks = []
    for prot in prot_listy:
        start, end = resoffsety[prot], resendy[prot]
        for protx in prot_listx:
            span = [resoffsetx[protx], resendx[protx]]
            ax.plot(span, [start, start], 'k-', lw=0.4)
            ax.plot(span, [end, end], 'k-', lw=0.4)
        yticks.append((float(start) + float(end)) / 2)
    ylabels = list(prot_listy)

    # plot the contact map
    cax = None
    if contactmap:
        # imshow puts the first array axis on the vertical axis, so the
        # array is indexed [y, x] and each block is stored transposed
        tmp_array = np.zeros((nresy, nresx))
        for px in prot_listx:
            indexes_x = self.index_dict[px]
            minx = min(indexes_x)
            maxx = max(indexes_x) + 1  # slice end is exclusive
            for py in prot_listy:
                indexes_y = self.index_dict[py]
                miny = min(indexes_y)
                maxy = max(indexes_y) + 1
                block = self.contact_freqs[minx:maxx, miny:maxy]
                tmp_array[resoffsety[py]:resendy[py],
                          resoffsetx[px]:resendx[px]] = block.T

        cax = ax.imshow(tmp_array,
                        cmap=colormap,
                        norm=colornorm,
                        origin='lower',
                        alpha=0.6,
                        interpolation='nearest')

    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels, rotation=90)
    ax.set_yticks(yticks)
    ax.set_yticklabels(ylabels)

    # set the crosslinks
    # also used for the residue pairs when there are no cross-links
    markersize = 5 * scale_symbol_size
    confidence_sizes = {'0.01': 14, '0.05': 9, '0.1': 6}
    cross_links = self.cross_link_db if self.cross_link_db is not None else ()
    for xl in cross_links:

        (c1, c2, r1, r2) = IMP.pmi.io.crosslink._ProteinsResiduesArray(xl)

        try:
            if color_crosslinks_by_distance:
                mdist = self._get_distance(r1, c1, r2, c2)
                if mdist is None:
                    continue
                if crosslink_threshold is None:
                    # let the helper apply its own default threshold
                    color = self._colormap_distance(mdist)
                else:
                    color = self._colormap_distance(mdist, threshold=crosslink_threshold)
            else:
                ps = self._get_percentage_satisfaction(r1, c1, r2, c2)
                if ps is None:
                    continue
                color = self._colormap_satisfaction(ps, threshold=0.2, tolerance=0.1)
        except KeyError:
            color = "gray"

        # skip cross-links whose proteins are not on the axes
        try:
            pos1 = r1 + resoffsetx[c1]
            pos2 = r2 + resoffsety[c2]
        except (KeyError, TypeError):
            continue

        if confidence_info:
            # NOTE(review): nothing in this method ever assigned the
            # confidence of a cross-link (the original raised NameError
            # here). Until the database field is known, every marker gets
            # the fallback size -- TODO read the value from xl.
            confidence = None
            markersize = confidence_sizes.get(confidence, 15) * scale_symbol_size
        else:
            markersize = 5 * scale_symbol_size

        ax.plot([pos1],
                [pos2],
                'o',
                c=color,
                alpha=alphablend,
                markersize=markersize)

        ax.plot([pos2],
                [pos1],
                'o',
                c=color,
                alpha=alphablend,
                markersize=markersize)

    # plot requested residue pairs
    if display_residue_pairs:
        for rp in self.residue_pair_list:
            r1, c1, r2, c2 = rp[0], rp[1], rp[2], rp[3]

            try:
                dist = self._get_distance(r1, c1, r2, c2)
            except (KeyError, IndexError, TypeError, AttributeError):
                continue

            # unknown distance (None) or more than 40: nothing to draw
            if dist is None or not dist <= 40.0:
                continue
            print(rp)
            try:
                pos1 = r1 + resoffsetx[c1]
                pos2 = r2 + resoffsety[c2]
            except (KeyError, TypeError):
                continue

            ax.plot([pos1],
                    [pos2],
                    '+',
                    c="blue",
                    alpha=0.1,
                    markersize=markersize)

    # display and write to file
    fig.set_size_inches(0.002 * nresx, 0.002 * nresy)
    for spine in ax.spines.values():
        spine.set_linewidth(2.0)
    if cbar_labels is not None and cax is not None:
        # a colorbar needs the image created for the contact map
        cbar = fig.colorbar(cax, ticks=[0.5, 1.5, 2.5, 3.5])
        cbar.ax.set_yticklabels(cbar_labels)  # vertically oriented colorbar

    if filename:
        # the string "False" that used to be passed here is truthy and
        # therefore switched transparency on
        fig.savefig(filename, dpi=300, transparent=False)
Utility classes and functions for IO.
Definition: utilities.py:1
atom::Hierarchies create_hierarchies(RMF::FileConstHandle fh, Model *m)
class to read, analyze, and plot xlink data on contact maps Canonical way to read the data: 1) load s...
Definition: xltable.py:21
Utility classes and functions for reading and storing PMI files.
def plot_table
plot the xlink table with optional contact map.
Definition: xltable.py:309
Miscellaneous utilities.
Definition: tools.py:1
def load_rmf_coordinates
read coordinates from a rmf file.
Definition: xltable.py:162
def load_pdb_coordinates
read coordinates from a pdb file.
Definition: xltable.py:122
def setup_contact_map
loop through each distance map and get frequency of contacts upperbound: maximum distance to be marke...
Definition: xltable.py:274
Class for storing model, its restraints, constraints, and particles.
Definition: Model.h:72
def load_sequence_from_fasta_file
read sequence.
Definition: xltable.py:104
def set_residue_pairs_to_display
select the atom names of residue pairs to plot on the contact map list of residues types must be sing...
Definition: xltable.py:259
def load_crosslinks
read crosslinks from a CSV file.
Definition: xltable.py:252
void load_frame(RMF::FileConstHandle file, RMF::FrameID frame)
A decorator for a particle with x,y,z coordinates.
Definition: XYZ.h:30
def get_particles_at_resolution_one
Get particles at res 1, or any beads, based on the name.
Tools for clustering and cluster analysis.
Definition: pmi/Analysis.py:1
Python classes to represent, score, sample and analyze models.
Functionality for loading, creating, manipulating and scoring atomic structures.
Select hierarchy particles identified by the biological name.
Definition: Selection.h:65
Support for the RMF file format for storing hierarchical molecular data and markup.