1 """@namespace IMP.pmi.restraints.crosslinking_atomic
2 Restraints for handling crosslinking data at atomic resolution.
5 from __future__
import print_function
14 from collections
import defaultdict
18 def setup_nuisance(m,rs,
27 nuisance.set_lower(min_val_nuis)
28 nuisance.set_upper(max_val_nuis)
29 nuisance.set_is_optimized(nuisance.get_nuisance_key(),is_opt)
31 max_val_prior,min_val_prior))
37 class RestraintSetupError(Exception):
40 class MyGetRestraint(object):
41 def __init__(self,rs):
43 def get_restraint_for_rmf(self):
46 class AtomicCrossLinkMSRestraint(object):
56 nuisances_are_optimized=
True,
61 """Experimental ATOMIC XL restraint. Provide selections for the particles to restrain.
62 Automatically creates one "sigma" per crosslinked residue and one "psis" per pair.
63 Other nuisance options are available.
64 \note Will return an error if the data+extra_sel don't specify two particles per XL pair.
65 @param root The root hierarchy on which you'll do selection
66 @param data CrossLinkData object
67 @param extra_sel Additional selections to add to each data point. Defaults to:
68 {'atom_type':IMP.atom.AtomType('NZ')}
69 @param length The XL linker length
70 @param nstates The number of states to model. Defaults to the number of states in root.
71 @param label The output label for the restraint
72 @param nuisances_are_optimized Whether to optimize nuisances
73 @param sigma_init The initial value for all the sigmas
74 @param psi_init The initial value for all the psis
75 @param one_psi Use a single psi for all restraints (if False, creates one per XL)
76 @param create_nz Coarse-graining hack - add 'NZ' atoms to every XL'd lysine
79 self.mdl = root.get_model()
84 self.nuis_opt = nuisances_are_optimized
85 self.nstates = nstates
89 print(
"Warning: nstates is not the same as the number of states in root")
93 self.particles=defaultdict(set)
94 self.one_psi = one_psi
95 self.create_nz = create_nz
97 print(
'creating a single psi for all XLs')
99 print(
'creating one psi for each XL')
103 psi_max_nuis = 0.4999999
106 sigma_min_nuis = 1e-7
107 sigma_max_nuis = 100.1
108 sigma_min_prior = 1e-3
109 sigma_max_prior = 100.0
113 self.sig_low = setup_nuisance(self.mdl,self.rs_nuis,init_val=sigma_init,min_val=1.0,
114 max_val=100.0,is_opt=self.nuis_opt)
115 self.sig_high = setup_nuisance(self.mdl,self.rs_nuis,init_val=sigma_init,min_val=1.0,
116 max_val=100.0,is_opt=self.nuis_opt)
118 self.sigma = setup_nuisance(self.mdl,self.rs_nuis,
120 min_val_nuis=sigma_min_nuis,
121 max_val_nuis=sigma_max_nuis,
122 min_val_prior=sigma_min_prior,
123 max_val_prior=sigma_max_prior,
124 is_opt=self.nuis_opt,
127 self.psi = setup_nuisance(self.mdl,self.rs_nuis,
129 min_val_nuis=psi_min_nuis,
130 max_val_nuis=psi_max_nuis,
131 min_val_prior=psi_min_prior,
132 max_val_prior=psi_max_prior,
133 is_opt=self.nuis_opt,
137 for unique_id
in data:
138 self.psis[unique_id]=setup_nuisance(self.mdl,self.rs_nuis,
140 min_val_nuis=psi_min_nuis,
141 max_val_nuis=psi_max_nuis,
142 min_val_prior=psi_min_prior,
143 max_val_prior=psi_max_prior,
144 is_opt=self.nuis_opt,
158 self.bonded_pairs = []
164 for nstate
in range(self.nstates):
165 for unique_id
in data:
166 for xl
in data[unique_id]:
167 xl_pairs = xl.get_selection(root,state_index=nstate)
175 frag = res.get_parent()
190 self.bonded_pairs.append([ca,nz])
191 self.rset_bonds.add_restraint(pr)
195 sel_pre =
IMP.atom.Selection(frag,residue_index=res.get_index()-1).get_selected_particles()
196 sel_post =
IMP.atom.Selection(frag,residue_index=res.get_index()+1).get_selected_particles()
197 if len(sel_pre)>1
or len(sel_post)>1:
198 print(
"SOMETHING WRONG WITH THIS FRAG")
213 ca_post = sel_post[0]
216 if nter
and not cter:
218 self.rset_angles.add_restraint(ar_post)
219 elif cter
and not nter:
221 self.rset_angles.add_restraint(ar_pre)
227 self.rset_angles.add_restraint(idr)
232 for unique_id
in data:
235 psip = self.psi.get_particle_index()
237 psip = self.psis[unique_id].get_particle_index()
248 for nstate
in range(self.nstates):
249 for xl
in data[unique_id]:
250 xl_pairs = xl.get_selection(root,state_index=nstate,
256 num1=num_xls_per_res[str(xl.r1)]
257 num2=num_xls_per_res[str(xl.r2)]
258 if num1<sig_threshold:
262 if num2<sig_threshold:
271 for p1,p2
in xl_pairs:
272 self.particles[nstate]|=set([p1,p2])
273 if max_dist
is not None:
277 r.add_contribution([p1.get_index(),p2.get_index()],
278 [sig1.get_particle_index(),sig2.get_particle_index()])
280 if num_contributions==0:
281 raise RestraintSetupError(
"No contributions!")
283 print(
'created',len(xlrs),
'XL restraints')
286 def set_weight(self,weight):
288 self.rs.set_weight(weight)
290 def set_label(self, label):
293 def add_to_model(self):
294 self.mdl.add_restraint(self.rs)
295 self.mdl.add_restraint(self.rs_nuis)
297 self.mdl.add_restraint(self.rset_bonds)
298 self.mdl.add_restraint(self.rset_angles)
299 def get_hierarchy(self):
302 def get_restraint_set(self):
305 def create_restraints_for_rmf(self):
306 """ create dummy harmonic restraints for each XL but don't add to model
307 Makes it easy to see each contribution to each XL in RMF
312 for nxl
in range(self.rs.get_number_of_restraints()):
313 xl=IMP.isd.AtomicCrossLinkMSRestraint.get_from(self.rs.get_restraint(nxl))
315 for ncontr
in range(xl.get_number_of_contributions()):
316 ps=xl.get_contribution(ncontr)
318 'xl%i_contr%i'%(nxl,ncontr))
320 dummy_rs.append(MyGetRestraint(rs))
325 """ Get particles involved in the restraint """
326 if state_num
is None:
327 return list(reduce(
lambda x,y: self.particles[x]|self.particles[y],self.particles))
329 return list(self.particles[state_num])
332 def get_bonded_pairs(self):
333 return self.bonded_pairs
335 def get_mc_sample_objects(self,max_step_sigma,max_step_psi):
336 """ HACK! Make a SampleObjects class that can be used with PMI::samplers"""
338 psigma=[[self.sigma],max_step_sigma]
340 ppsi=[[self.psi],max_step_psi]
342 ppsi=[[self.psis[p]
for p
in self.psis],max_step_psi]
348 return 'XL restraint with '+str(len(self.rs.get_restraint(0).get_number_of_restraints())) \
351 def load_nuisances_from_stat_file(self,in_fn,nframe):
352 """Read a stat file and load all the sigmas.
353 This is potentially quite stupid.
354 It's also a hack since the sigmas should be stored in the RMF file.
355 Also, requires one sigma and one psi for ALL XLs.
358 sig_val = float(subprocess.check_output([
"process_output.py",
"-f",in_fn,
359 "-s",
"AtomicXLRestraint_sigma"]).split(
'\n>')[1+nframe])
360 psi_val = float(subprocess.check_output([
"process_output.py",
"-f",in_fn,
361 "-s",
"AtomicXLRestraint_psi"]).split(
'\n>')[1+nframe])
362 for nxl
in range(self.rs.get_number_of_restraints()):
363 xl=IMP.isd.AtomicCrossLinkMSRestraint.get_from(self.rs.get_restraint(nxl))
366 for contr
in range(xl.get_number_of_contributions()):
367 sig1,sig2=xl.get_contribution_sigmas(contr)
370 print(
'loaded nuisances from file')
372 def plot_violations(self,out_prefix,
373 max_prob_for_violation=0.1,
374 min_dist_for_violation=1e9,
376 limit_to_chains=
None,
378 """Create CMM files, one for each state, of all xinks.
379 will draw in GREEN if non-violated in all states (or if only one state)
380 will draw in PURPLE if non-violated only in a subset of states (draws nothing elsewhere)
381 will draw in RED in ALL states if all violated
382 (if only one state, you'll only see green and red)
384 @param out_prefix Output xlink files prefix
385 @param max_prob_for_violation It's a violation if the probability is below this
386 @param min_dist_for_violation It's a violation if the min dist is above this
387 @param coarsen Use CA positions
388 @param limit_to_chains Try to visualize just these chains
389 @param exclude_to_chains Try to NOT visualize these chains
391 print(
'going to calculate violations and plot CMM files')
392 all_stats = self.get_best_stats()
393 all_dists = [s[
"low_dist"]
for s
in all_stats]
399 cmds = defaultdict(set)
400 for nstate
in range(self.nstates):
401 outf=open(out_prefix+str(nstate)+
'.cmm',
'w')
402 outf.write(
'<marker_set name="xlinks_state%i"> \n' % nstate)
405 print(
'will limit to',limit_to_chains)
406 print(
'will exclude',exclude_chains)
407 state_info.append(self.get_best_stats(nstate,
411 for nxl
in range(self.rs.get_number_of_restraints()):
415 for nstate
in range(self.nstates):
416 prob = state_info[nstate][nxl][
"prob"]
417 low_dist = state_info[nstate][nxl][
"low_dist"]
418 if prob<max_prob_for_violation
or low_dist>min_dist_for_violation:
426 if len(npass)==self.nstates:
428 elif len(nviol)==self.nstates:
432 print(nxl,
'state dists:',[state_info[nstate][nxl][
"low_dist"]
for nstate
in range(self.nstates)],
433 'viol states:',nviol,
'all viol?',all_viol)
434 for nstate
in range(self.nstates):
436 r=0.365; g=0.933; b=0.365;
439 r=0.980; g=0.302; b=0.247;
446 r=0.365; g=0.933; b=0.365;
448 pp = state_info[nstate][nxl][
"low_pp"]
457 cmds[nstate].add((ch1,r1))
458 cmds[nstate].add((ch2,r2))
460 outf = out_fns[nstate]
462 outf.write(
'<marker id= "%d" x="%.3f" y="%.3f" z="%.3f" radius="0.8" '
463 'r="%.2f" g="%.2f" b="%.2f"/> \n' % (nv,c1[0],c1[1],c1[2],r,g,b))
464 outf.write(
'<marker id= "%d" x="%.3f" y="%.3f" z="%.3f" radius="0.8" '
465 'r="%.2f" g="%.2f" b="%.2f"/> \n' % (nv+1,c2[0],c2[1],c2[2],r,g,b))
466 outf.write(
'<link id1= "%d" id2="%d" radius="0.8" '
467 'r="%.2f" g="%.2f" b="%.2f"/> \n' % (nv,nv+1,r,g,b))
470 for nstate
in range(self.nstates):
471 out_fns[nstate].write(
'</marker_set>\n')
472 out_fns[nstate].close()
474 for ch,r
in cmds[nstate]:
475 cmd+=
'#%i:%i.%s '%(nstate,r,ch)
479 def _get_contribution_info(self,xl,ncontr,use_CA=False):
480 """Return the particles at that contribution. If requested will return CA's instead"""
481 idx1=xl.get_contribution(ncontr)[0]
482 idx2=xl.get_contribution(ncontr)[1]
490 return idx1,idx2,dist
492 def get_best_stats(self,limit_to_state=None,limit_to_chains=None,exclude_chains='',use_CA=False):
493 ''' return the probability, best distance, two coords, and possibly the psi for each xl
494 @param limit_to_state Only examine contributions from one state
495 @param limit_to_chains Returns the particles for certain "easy to visualize" chains
496 @param exclude_chains Even if you limit, don't let one end be in this list.
497 Only works if you also limit chains
500 for nxl
in range(self.rs.get_number_of_restraints()):
502 xl=IMP.isd.AtomicCrossLinkMSRestraint.get_from(self.rs.get_restraint(nxl))
509 for contr
in range(xl.get_number_of_contributions()):
510 pp = xl.get_contribution(contr)
517 if limit_to_state
is not None:
519 if nstate!=limit_to_state:
521 state_contrs.append(contr)
524 if limit_to_chains
is not None:
527 if (c1
in limit_to_chains
or c2
in limit_to_chains)
and (
528 c1
not in exclude_chains
and c2
not in exclude_chains):
529 if dist<low_dist_lim:
536 if limit_to_state
is not None:
537 this_info[
"prob"] = xl.evaluate_for_contributions(state_contrs,
None)
539 this_info[
"prob"] = xl.unprotected_evaluate(
None)
540 if limit_to_chains
is not None:
541 this_info[
"low_pp"] = low_pp_lim
543 this_info[
"low_pp"] = low_pp
545 this_info[
"low_dist"] = low_dist
548 this_info[
"psi"] = pval
549 ret.append(this_info)
552 def print_stats(self):
554 stats = self.get_best_stats()
555 for nxl,s
in enumerate(stats):
560 def get_output(self):
563 score = self.weight * self.rs.unprotected_evaluate(
None)
564 output[
"_TotalScore"] = str(score)
565 output[
"AtomicXLRestraint" + self.label] = str(score)
568 output[
"AtomicXLRestraint_sigma"] = self.sigma.get_scale()
569 output[
"AtomicXLRestraint_priors"] = self.rs_nuis.unprotected_evaluate(
None)
571 output[
"AtomicXLRestraint_psi"] = self.psi.get_scale()
575 output[
"AtomicXLRestraint_NZBonds"] = self.rset_bonds.evaluate(
False)
576 output[
"AtomicXLRestraint_NZAngles"] = self.rset_angles.evaluate(
False)
580 stats = self.get_best_stats()
581 for nxl,s
in enumerate(stats):
582 if s[
'low_dist']>20.0:
584 output[
"AtomicXLRestraint_%i_%s"%(nxl,
"Prob")]=str(s[
'prob'])
585 output[
"AtomicXLRestraint_%i_%s"%(nxl,
"BestDist")]=str(s[
'low_dist'])
587 output[
"AtomicXLRestraint_%i_%s"%(nxl,
"psi")]=str(s[
'psi'])
588 output[
"AtomicXLRestraint_NumViol"] = str(bad_count)
void show_molecular_hierarchy(Hierarchy h)
Print out the molecular hierarchy.
Ints get_index(const kernel::ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
static Atom setup_particle(kernel::Model *m, ParticleIndex pi, Atom other)
Various classes to hold sets of particles.
Upper bound harmonic function (non-zero when feature > mean)
ParticlesTemp get_particles(kernel::Model *m, const ParticleIndexes &ps)
Object used to hold a set of restraints.
Restrain atom pairs based on a set of crosslinks.
Low level functionality (logging, error handling, profiling, command line flags etc) that is used by ...
static Scale setup_particle(kernel::Model *m, ParticleIndex pi)
Dihedral restraint between four particles.
Add scale parameter to particle.
Hierarchy get_residue(Hierarchy mhd, unsigned int index)
Get the residue with the specified index.
double get_distance(XYZR a, XYZR b)
Compute the sphere distance between a and b.
A class to store an fixed array of same-typed values.
Angle restraint between three particles.
static XYZR setup_particle(kernel::Model *m, ParticleIndex pi)
A decorator for a particle representing an atom.
Hierarchies get_by_type(Hierarchy mhd, GetByType t)
A decorator for a particle with x,y,z coordinates.
Class to handle individual model particles.
int get_state_index(Hierarchy h)
Walk up the hierarchy to find the current state.
Basic functionality that is expected to be used by a wide variety of IMP users.
General purpose algebraic and geometric methods that are expected to be used by a wide variety of IMP...
Calculate the -Log of a list of restraints.
double get_distance(const VectorD< D > &v1, const VectorD< D > &v2)
Compute the distance between two vectors.
Applies a PairScore to a Pair.
Functionality for loading, creating, manipulating and scoring atomic structures.
std::string get_chain_id(Hierarchy h)
Select hierarchy particles identified by the biological name.
Class for storing model, its restraints, constraints, and particles.
Inferential scoring building on methods developed as part of the Inferential Structure Determination ...
Harmonic function (symmetric about the mean)