1 """@namespace IMP.pmi.restraints.crosslinking_atomic
2 Restraints for handling crosslinking data at atomic resolution.
5 from __future__
import print_function
14 from collections
import defaultdict
18 def setup_nuisance(m,rs,
27 nuisance.set_lower(min_val_nuis)
28 nuisance.set_upper(max_val_nuis)
29 nuisance.set_is_optimized(nuisance.get_nuisance_key(),is_opt)
31 max_val_prior,min_val_prior))
40 class MyGetRestraint(object):
41 def __init__(self,rs):
43 def get_restraint_for_rmf(self):
46 class AtomicCrossLinkMSRestraint(object):
56 nuisances_are_optimized=
True,
61 """Experimental ATOMIC XL restraint. Provide selections for the particles to restrain.
62 Automatically creates one "sigma" per crosslinked residue and one "psis" per pair.
63 Other nuisance options are available.
64 \note Will return an error if the data+extra_sel don't specify two particles per XL pair.
65 @param root The root hierarchy on which you'll do selection
66 @param data CrossLinkData object
67 @param extra_sel Additional selections to add to each data point. Defaults to:
68 {'atom_type':IMP.atom.AtomType('NZ')}
69 @param length The XL linker length
70 @param nstates The number of states to model. Defaults to the number of states in root.
71 @param label The output label for the restraint
72 @param nuisances_are_optimized Whether to optimize nuisances
73 @param sigma_init The initial value for all the sigmas
74 @param psi_init The initial value for all the psis
75 @param one_psi Use a single psi for all restraints (if False, creates one per XL)
76 @param create_nz Coarse-graining hack - add 'NZ' atoms to every XL'd lysine
79 self.mdl = root.get_model()
84 self.nuis_opt = nuisances_are_optimized
85 self.nstates = nstates
89 print(
"Warning: nstates is not the same as the number of states in root")
93 self.particles=defaultdict(set)
94 self.one_psi = one_psi
95 self.create_nz = create_nz
97 print(
'creating a single psi for all XLs')
99 print(
'creating one psi for each XL')
103 psi_max_nuis = 0.4999999
106 sigma_min_nuis = 1e-7
107 sigma_max_nuis = 100.1
108 sigma_min_prior = 1e-3
109 sigma_max_prior = 100.0
113 self.sig_low = setup_nuisance(self.mdl,self.rs_nuis,init_val=sigma_init,min_val=1.0,
114 max_val=100.0,is_opt=self.nuis_opt)
115 self.sig_high = setup_nuisance(self.mdl,self.rs_nuis,init_val=sigma_init,min_val=1.0,
116 max_val=100.0,is_opt=self.nuis_opt)
118 self.sigma = setup_nuisance(self.mdl,self.rs_nuis,
120 min_val_nuis=sigma_min_nuis,
121 max_val_nuis=sigma_max_nuis,
122 min_val_prior=sigma_min_prior,
123 max_val_prior=sigma_max_prior,
124 is_opt=self.nuis_opt,
127 self.psi = setup_nuisance(self.mdl,self.rs_nuis,
129 min_val_nuis=psi_min_nuis,
130 max_val_nuis=psi_max_nuis,
131 min_val_prior=psi_min_prior,
132 max_val_prior=psi_max_prior,
133 is_opt=self.nuis_opt,
137 for unique_id
in data:
138 self.psis[unique_id]=setup_nuisance(self.mdl,self.rs_nuis,
140 min_val_nuis=psi_min_nuis,
141 max_val_nuis=psi_max_nuis,
142 min_val_prior=psi_min_prior,
143 max_val_prior=psi_max_prior,
144 is_opt=self.nuis_opt,
158 self.bonded_pairs = []
164 for nstate
in range(self.nstates):
165 for unique_id
in data:
166 for xl
in data[unique_id]:
167 xl_pairs = xl.get_selection(root,state_index=nstate)
175 frag = res.get_parent()
190 self.bonded_pairs.append([ca,nz])
191 self.rset_bonds.add_restraint(pr)
195 sel_pre =
IMP.atom.Selection(frag,residue_index=res.get_index()-1).get_selected_particles()
196 sel_post =
IMP.atom.Selection(frag,residue_index=res.get_index()+1).get_selected_particles()
197 if len(sel_pre)>1
or len(sel_post)>1:
198 print(
"SOMETHING WRONG WITH THIS FRAG")
213 ca_post = sel_post[0]
216 if nter
and not cter:
218 self.rset_angles.add_restraint(ar_post)
219 elif cter
and not nter:
221 self.rset_angles.add_restraint(ar_pre)
227 self.rset_angles.add_restraint(idr)
232 for unique_id
in data:
235 psip = self.psi.get_particle_index()
237 psip = self.psis[unique_id].get_particle_index()
248 for nstate
in range(self.nstates):
249 for xl
in data[unique_id]:
250 xl_pairs = xl.get_selection(root,state_index=nstate,
256 num1=num_xls_per_res[str(xl.r1)]
257 num2=num_xls_per_res[str(xl.r2)]
258 if num1<sig_threshold:
262 if num2<sig_threshold:
271 for p1,p2
in xl_pairs:
272 self.particles[nstate]|=set([p1,p2])
273 if max_dist
is not None:
277 r.add_contribution([p1.get_index(),p2.get_index()],
278 [sig1.get_particle_index(),sig2.get_particle_index()])
280 if num_contributions==0:
281 raise RestraintSetupError(
"No contributions!")
283 print(
'created',len(xlrs),
'XL restraints')
286 def set_weight(self,weight):
288 self.rs.set_weight(weight)
290 def set_label(self, label):
293 def add_to_model(self):
300 def get_hierarchy(self):
303 def get_restraint_set(self):
306 def create_restraints_for_rmf(self):
307 """ create dummy harmonic restraints for each XL but don't add to model
308 Makes it easy to see each contribution to each XL in RMF
313 for nxl
in range(self.rs.get_number_of_restraints()):
314 xl=IMP.isd.AtomicCrossLinkMSRestraint.get_from(self.rs.get_restraint(nxl))
316 for ncontr
in range(xl.get_number_of_contributions()):
317 ps=xl.get_contribution(ncontr)
319 'xl%i_contr%i'%(nxl,ncontr))
321 dummy_rs.append(MyGetRestraint(rs))
326 """ Get particles involved in the restraint """
327 if state_num
is None:
328 return list(reduce(
lambda x,y: self.particles[x]|self.particles[y],self.particles))
330 return list(self.particles[state_num])
333 def get_bonded_pairs(self):
334 return self.bonded_pairs
336 def get_mc_sample_objects(self,max_step_sigma,max_step_psi):
337 """ HACK! Make a SampleObjects class that can be used with PMI::samplers"""
339 psigma=[[self.sigma],max_step_sigma]
341 ppsi=[[self.psi],max_step_psi]
343 ppsi=[[self.psis[p]
for p
in self.psis],max_step_psi]
349 return 'XL restraint with '+str(len(self.rs.get_restraint(0).get_number_of_restraints())) \
352 def load_nuisances_from_stat_file(self,in_fn,nframe):
353 """Read a stat file and load all the sigmas.
354 This is potentially quite stupid.
355 It's also a hack since the sigmas should be stored in the RMF file.
356 Also, requires one sigma and one psi for ALL XLs.
359 sig_val = float(subprocess.check_output([
"process_output.py",
"-f",in_fn,
360 "-s",
"AtomicXLRestraint_sigma"]).split(
'\n>')[1+nframe])
361 psi_val = float(subprocess.check_output([
"process_output.py",
"-f",in_fn,
362 "-s",
"AtomicXLRestraint_psi"]).split(
'\n>')[1+nframe])
363 for nxl
in range(self.rs.get_number_of_restraints()):
364 xl=IMP.isd.AtomicCrossLinkMSRestraint.get_from(self.rs.get_restraint(nxl))
367 for contr
in range(xl.get_number_of_contributions()):
368 sig1,sig2=xl.get_contribution_sigmas(contr)
371 print(
'loaded nuisances from file')
373 def plot_violations(self,out_prefix,
374 max_prob_for_violation=0.1,
375 min_dist_for_violation=1e9,
377 limit_to_chains=
None,
379 """Create CMM files, one for each state, of all xinks.
380 will draw in GREEN if non-violated in all states (or if only one state)
381 will draw in PURPLE if non-violated only in a subset of states (draws nothing elsewhere)
382 will draw in RED in ALL states if all violated
383 (if only one state, you'll only see green and red)
385 @param out_prefix Output xlink files prefix
386 @param max_prob_for_violation It's a violation if the probability is below this
387 @param min_dist_for_violation It's a violation if the min dist is above this
388 @param coarsen Use CA positions
389 @param limit_to_chains Try to visualize just these chains
390 @param exclude_to_chains Try to NOT visualize these chains
392 print(
'going to calculate violations and plot CMM files')
393 all_stats = self.get_best_stats()
394 all_dists = [s[
"low_dist"]
for s
in all_stats]
400 cmds = defaultdict(set)
401 for nstate
in range(self.nstates):
402 outf=open(out_prefix+str(nstate)+
'.cmm',
'w')
403 outf.write(
'<marker_set name="xlinks_state%i"> \n' % nstate)
406 print(
'will limit to',limit_to_chains)
407 print(
'will exclude',exclude_chains)
408 state_info.append(self.get_best_stats(nstate,
412 for nxl
in range(self.rs.get_number_of_restraints()):
416 for nstate
in range(self.nstates):
417 prob = state_info[nstate][nxl][
"prob"]
418 low_dist = state_info[nstate][nxl][
"low_dist"]
419 if prob<max_prob_for_violation
or low_dist>min_dist_for_violation:
427 if len(npass)==self.nstates:
429 elif len(nviol)==self.nstates:
433 print(nxl,
'state dists:',[state_info[nstate][nxl][
"low_dist"]
for nstate
in range(self.nstates)],
434 'viol states:',nviol,
'all viol?',all_viol)
435 for nstate
in range(self.nstates):
437 r=0.365; g=0.933; b=0.365;
440 r=0.980; g=0.302; b=0.247;
447 r=0.365; g=0.933; b=0.365;
449 pp = state_info[nstate][nxl][
"low_pp"]
458 cmds[nstate].add((ch1,r1))
459 cmds[nstate].add((ch2,r2))
461 outf = out_fns[nstate]
463 outf.write(
'<marker id= "%d" x="%.3f" y="%.3f" z="%.3f" radius="0.8" '
464 'r="%.2f" g="%.2f" b="%.2f"/> \n' % (nv,c1[0],c1[1],c1[2],r,g,b))
465 outf.write(
'<marker id= "%d" x="%.3f" y="%.3f" z="%.3f" radius="0.8" '
466 'r="%.2f" g="%.2f" b="%.2f"/> \n' % (nv+1,c2[0],c2[1],c2[2],r,g,b))
467 outf.write(
'<link id1= "%d" id2="%d" radius="0.8" '
468 'r="%.2f" g="%.2f" b="%.2f"/> \n' % (nv,nv+1,r,g,b))
471 for nstate
in range(self.nstates):
472 out_fns[nstate].write(
'</marker_set>\n')
473 out_fns[nstate].close()
475 for ch,r
in cmds[nstate]:
476 cmd+=
'#%i:%i.%s '%(nstate,r,ch)
480 def _get_contribution_info(self,xl,ncontr,use_CA=False):
481 """Return the particles at that contribution. If requested will return CA's instead"""
482 idx1=xl.get_contribution(ncontr)[0]
483 idx2=xl.get_contribution(ncontr)[1]
491 return idx1,idx2,dist
493 def get_best_stats(self,limit_to_state=None,limit_to_chains=None,exclude_chains='',use_CA=False):
494 ''' return the probability, best distance, two coords, and possibly the psi for each xl
495 @param limit_to_state Only examine contributions from one state
496 @param limit_to_chains Returns the particles for certain "easy to visualize" chains
497 @param exclude_chains Even if you limit, don't let one end be in this list.
498 Only works if you also limit chains
501 for nxl
in range(self.rs.get_number_of_restraints()):
503 xl=IMP.isd.AtomicCrossLinkMSRestraint.get_from(self.rs.get_restraint(nxl))
510 for contr
in range(xl.get_number_of_contributions()):
511 pp = xl.get_contribution(contr)
518 if limit_to_state
is not None:
520 if nstate!=limit_to_state:
522 state_contrs.append(contr)
525 if limit_to_chains
is not None:
528 if (c1
in limit_to_chains
or c2
in limit_to_chains)
and (
529 c1
not in exclude_chains
and c2
not in exclude_chains):
530 if dist<low_dist_lim:
537 if limit_to_state
is not None:
538 this_info[
"prob"] = xl.evaluate_for_contributions(state_contrs,
None)
540 this_info[
"prob"] = xl.unprotected_evaluate(
None)
541 if limit_to_chains
is not None:
542 this_info[
"low_pp"] = low_pp_lim
544 this_info[
"low_pp"] = low_pp
546 this_info[
"low_dist"] = low_dist
549 this_info[
"psi"] = pval
550 ret.append(this_info)
553 def print_stats(self):
555 stats = self.get_best_stats()
556 for nxl,s
in enumerate(stats):
561 def get_output(self):
564 score = self.weight * self.rs.unprotected_evaluate(
None)
565 output[
"_TotalScore"] = str(score)
566 output[
"AtomicXLRestraint" + self.label] = str(score)
569 output[
"AtomicXLRestraint_sigma"] = self.sigma.get_scale()
570 output[
"AtomicXLRestraint_priors"] = self.rs_nuis.unprotected_evaluate(
None)
572 output[
"AtomicXLRestraint_psi"] = self.psi.get_scale()
576 output[
"AtomicXLRestraint_NZBonds"] = self.rset_bonds.evaluate(
False)
577 output[
"AtomicXLRestraint_NZAngles"] = self.rset_angles.evaluate(
False)
581 stats = self.get_best_stats()
582 for nxl,s
in enumerate(stats):
583 if s[
'low_dist']>20.0:
585 output[
"AtomicXLRestraint_%i_%s"%(nxl,
"Prob")]=str(s[
'prob'])
586 output[
"AtomicXLRestraint_%i_%s"%(nxl,
"BestDist")]=str(s[
'low_dist'])
588 output[
"AtomicXLRestraint_%i_%s"%(nxl,
"psi")]=str(s[
'psi'])
589 output[
"AtomicXLRestraint_NumViol"] = str(bad_count)
void show_molecular_hierarchy(Hierarchy h)
Print out the molecular hierarchy.
static Atom setup_particle(Model *m, ParticleIndex pi, Atom other)
Various classes to hold sets of particles.
Upper bound harmonic function (non-zero when feature > mean)
static XYZR setup_particle(Model *m, ParticleIndex pi)
A class to store an fixed array of same-typed values.
Restrain atom pairs based on a set of crosslinks.
ParticlesTemp get_particles(Model *m, const ParticleIndexes &ps)
Dihedral restraint between four particles.
Add scale parameter to particle.
Hierarchy get_residue(Hierarchy mhd, unsigned int index)
Get the residue with the specified index.
double get_distance(XYZR a, XYZR b)
Compute the sphere distance between a and b.
Object used to hold a set of restraints.
Class for storing model, its restraints, constraints, and particles.
Angle restraint between three particles.
Ints get_index(const ParticlesTemp &particles, const Subset &subset, const Subsets &excluded)
A decorator for a particle representing an atom.
Hierarchies get_by_type(Hierarchy mhd, GetByType t)
Gather all the molecular particles of a certain level in the hierarchy.
A decorator for a particle with x,y,z coordinates.
static Scale setup_particle(Model *m, ParticleIndex pi)
int get_state_index(Hierarchy h)
Walk up the hierarchy to find the current state.
Basic functionality that is expected to be used by a wide variety of IMP users.
General purpose algebraic and geometric methods that are expected to be used by a wide variety of IMP...
The general base class for IMP exceptions.
Calculate the -Log of a list of restraints.
Class to handle individual model particles.
double get_distance(const VectorD< D > &v1, const VectorD< D > &v2)
Compute the distance between two vectors.
Applies a PairScore to a Pair.
Functionality for loading, creating, manipulating and scoring atomic structures.
std::string get_chain_id(Hierarchy h)
Select hierarchy particles identified by the biological name.
Inferential scoring building on methods developed as part of the Inferential Structure Determination ...
Harmonic function (symmetric about the mean)