9 #ifndef IMPMULTIFIT_RMSD_CLUSTERING_H
10 #define IMPMULTIFIT_RMSD_CLUSTERING_H
12 #include <IMP/multifit/multifit_config.h>
18 #include <boost/graph/adjacency_list.hpp>
20 IMPMULTIFIT_BEGIN_NAMESPACE
// Template header of the enclosing clustering class template (the
// RMSDClustering<TransT>:: member definitions later in this file use the same
// parameter). NOTE(review): several source lines of this region are missing
// from this view (presumably the enclosing class head and access specifiers,
// closing braces, the set_centroid header and the data members).
39 template <
class TransT>
// Bookkeeping wrapper around one TransT being clustered: a copy of the
// transformation plus a validity flag (cluster_graph clears the flag on the
// record that gets merged into another one).
43 class TransformationRecord {
// Wraps a copy of `trans`; a newly created record starts out valid.
46 inline TransformationRecord(
const TransT &trans):
47 valid_(
true), trans_(trans) {
49 virtual ~TransformationRecord() {}
// Merges `record` into this one: first hands the other record's score to
// TransT::update_score, then delegates the merge to TransT::join_into.
// `record` itself is not modified here; callers invalidate it separately.
51 void join_into(
const TransformationRecord& record) {
52 trans_.update_score(record.trans_.get_score());
53 trans_.join_into(record.trans_);
// Score of the wrapped transformation.
55 inline float get_score()
const {
return trans_.get_score();}
// Stores in centroid_ the point obtained through the wrapped record's
// representative transformation (get_transformed). NOTE(review): the
// enclosing member header (cluster() calls it as set_centroid(centroid_)) and
// the call's argument line are not visible here.
58 centroid_ = trans_.get_representative_transformation().get_transformed(
// Returns a copy (by value) of the wrapped transformation.
60 TransT get_record()
const {
return trans_;}
// Validity flag accessors. An invalid record has been absorbed by another
// record and is presumably dropped by clean().
61 bool get_valid()
const {
return valid_;}
62 void set_valid(
bool v) {valid_=v;}
68 typedef std::vector<TransformationRecord> TransformationRecords;
70 typedef GeometricHash<int, 3> Hash3;
71 typedef boost::property<boost::edge_weight_t, short> ClusEdgeWeightProperty;
72 typedef boost::property<boost::vertex_index_t, int> ClusVertexIndexProperty;
74 typedef boost::adjacency_list<boost::vecS, boost::vecS, boost::undirectedS,
75 ClusVertexIndexProperty, ClusEdgeWeightProperty> Graph;
76 typedef boost::graph_traits<Graph> RCGTraits;
77 typedef RCGTraits::vertex_descriptor RCVertex;
78 typedef RCGTraits::edge_descriptor RCEdge;
79 typedef RCGTraits::vertex_iterator RCVertexIt;
80 typedef RCGTraits::edge_iterator RCEdgeIt;
// Comparator for (edge, weight) pairs that orders them by ascending weight.
// cluster_graph sorts its edge list with it, so the closest pairs are
// considered for merging first.
// NOTE(review): the closing braces of operator() and of the struct are not
// visible in this view.
82 struct sort_by_weight {
83 bool operator()(
const std::pair<RCEdge,float> &s1,
84 const std::pair<RCEdge,float> &s2)
const {
85 return s1.second < s2.second;
95 void cluster(
float max_dist,
const std::vector<TransT>& input_trans,
96 std::vector<TransT>& output_trans);
98 void prepare(
const ParticlesTemp &ps);
99 void set_bin_size(
float bin_size) {bin_size_=bin_size;}
// Squared distance between two transformations, compared against
// max_dist*max_dist throughout. Virtual so a subclass can change the metric;
// the default implementation returns rmsd_calc_.get_squared_rmsd of the two
// representative transformations.
103 virtual float get_squared_distance(
const TransT& trans1,
104 const TransT& trans2);
// Builds a graph over one hash bucket `inds`: one vertex per entry and an
// edge for every pair whose squared distance is below max_dist^2.
// NOTE(review): the trailing parameters of this declaration are not visible
// here; the out-of-class definition takes (..., float max_dist, Graph &g).
106 void build_graph(
const Hash3::PointList &inds,
107 const std::vector<TransformationRecord*> &recs,
// Builds a graph over all records in `recs`, using the geometric hash `h` to
// limit distance evaluations to nearby candidates.
111 void build_full_graph(
const Hash3 &h,
112 const std::vector<TransformationRecord*> &recs,
113 float max_dist, Graph &g);
// Greedily merges records connected by edges of `g`, lowest weight first.
// Callers accumulate the returned int as a join count.
// NOTE(review): the remaining parameter(s) are not visible here; the
// definition reads max_dist.
115 int cluster_graph(Graph &g,
116 const std::vector<TransformationRecord*> &recs,
// Clusters within each bucket of a hash with bin size bin_size_. cluster()
// uses the return value as a loop condition (non-zero => run another pass).
119 int fast_clustering(
float max_dist,
120 std::vector<TransformationRecord *>& recs);
// Clusters across all records using a hash with bin size max_dist; returns 0
// immediately for fewer than two records. cluster() also uses its return
// value as a loop condition.
122 virtual int exhaustive_clustering(
float max_dist,
123 std::vector<TransformationRecord *>& recs);
// Keeps only the valid records and deletes the invalid ones. The vector
// pointer is taken by reference, presumably so it can be repointed at the
// newly allocated result vector — confirm against the definition.
126 virtual void clean(std::vector<TransformationRecord*>*& records);
// Computes the squared RMSD used by the default get_squared_distance().
134 atom::RMSDCalculator rmsd_calc_;
// Default distance metric: squared RMSD between the representative
// transformations of the two records, delegated to rmsd_calc_.
// NOTE(review): the declarator line (RMSDClustering<TransT>::
// get_squared_distance(const TransT& trans1, ...) and the closing brace are
// not visible in this view.
138 template<
class TransT>
float
140 const TransT& trans2) {
141 return rmsd_calc_.get_squared_rmsd(trans1.get_representative_transformation(),
142 trans2.get_representative_transformation());
// build_graph: fills `g` for the entries of one hash bucket (`inds`).
// It adds one vertex per entry (vertex property = its local position i). For
// every unordered pair (i, j), it adds an edge weighted by their squared
// distance when that distance is below max_dist^2. All pairs are compared, so
// the cost is quadratic in inds.size().
// NOTE(review): the declarator line and several statement lines (including
// the start of the get_squared_distance call that defines d2) are missing
// from this view.
145 template<
class TransT>
147 const std::vector<TransformationRecord*> &recs,
148 float max_dist, Graph &g){
// Squared threshold, so it can be compared with squared distances directly.
150 float max_dist2=max_dist*max_dist;
// One vertex per bucket entry, tagged with its local position.
153 std::vector<RCVertex> nodes(inds.size());
154 for (
unsigned int i=0; i<inds.size(); ++i) {
155 nodes[i]=boost::add_vertex(i,g);
// Pairwise distances; j starts at i+1 so each unordered pair is visited once.
// NOTE(review): the visible continuation below indexes `recs` with the local
// loop index j (recs[j]) rather than the bucket entry inds[j]. fast_clustering
// appears to pass its full record list, so unless a missing line maps the
// indices, this compares records 0..inds.size()-1 instead of the bucket's
// members — confirm.
159 for (
unsigned int i=0; i<inds.size(); ++i) {
160 for (
unsigned int j=i+1; j<inds.size(); ++j) {
162 recs[j]->get_record());
163 if (d2 < max_dist2) {
// d2 is stored through the graph's edge_weight property; the stored precision
// is that of ClusEdgeWeightProperty's value type.
164 boost::add_edge(nodes[i],nodes[j],d2,g);
// build_full_graph: fills `g` with one vertex per record (vertex property =
// its index in recs). For each record i, candidate neighbours come from the
// geometric hash `h`, queried around the record's transformed centroid with
// radius max_dist. A candidate j > i becomes an edge, weighted by d2, when a
// squared centroid distance and then the squared transformation distance d2
// are both below max_dist^2.
// NOTE(review): the arguments of the centroid-distance call and the line
// computing d2 are missing from this view, as are the closing braces.
169 template<
class TransT>
170 void RMSDClustering<TransT>::build_full_graph(
const Hash3 &h,
171 const std::vector<TransformationRecord*> &recs,
172 float max_dist, Graph &g){
173 float max_dist2=max_dist*max_dist;
// One vertex per record; vertex i corresponds to recs[i].
175 std::vector<RCVertex> nodes(recs.size());
176 for (
unsigned int i=0; i<recs.size(); ++i) {
177 nodes[i]=boost::add_vertex(i,g);
180 for (
int i = 0 ; i < (int)recs.size() ; ++i) {
181 TransT tr=recs[i]->get_record();
182 algebra::Transformation3D t = tr.get_representative_transformation();
// Hash lookup around the transformed centroid. `second` of each hit is the
// int stored by Hash3::add (the record index).
183 Hash3::HashResult result =
184 h.neighbors(Hash3::INF, t.get_transformed(centroid_), max_dist);
185 for (
size_t k=0; k<result.size(); ++k ) {
186 int j = result[k]->second;
// Test each unordered pair once (and skip self-hits).
187 if (i >= j)
continue;
// Cheap centroid pre-filter before the full distance evaluation.
188 float centroids_dist2 = algebra::get_squared_distance(
191 if (centroids_dist2 < max_dist2) {
193 recs[j]->get_record());
194 if (d2 < max_dist2) {
195 boost::add_edge(nodes[i],nodes[j],d2,g);
// cluster_graph: greedy merge over the edges of `g`.
// Returns 0 immediately when g has no edges. Otherwise it collects every edge
// with its weight, sorts them by ascending weight (sort_by_weight) and walks
// them in that order. An edge is acted on when neither endpoint is marked in
// `used`, its weight is below max_dist^2, and both records are still valid.
// In that case the higher-scoring record absorbs the other via join_into
// (ties go to the edge's target, presumably via an else branch not visible
// here) and the absorbed record is marked invalid. Vertex ids index `recs`
// directly.
// NOTE(review): several lines are missing from this view: the max_dist
// parameter, the ei/ei_end declarations, any code that sets `used`, the else
// header and the final return. The join-count return value and the `used`
// bookkeeping therefore cannot be confirmed here.
// NOTE(review): for graphs built by build_graph, vertex ids are positions
// within a hash bucket (add_vertex(i, g)), not indices into the caller's full
// record list — confirm the two agree.
199 template<
class TransT>
200 int RMSDClustering<TransT>::cluster_graph(Graph &g,
201 const std::vector<TransformationRecord*> &recs,
203 if (boost::num_edges(g)==0)
return 0;
205 <<boost::num_vertices(g)<<std::endl);
206 float max_dist2=max_dist*max_dist;
// Collect (edge, weight) pairs. Weights are read through the edge_weight
// property map, so their precision is that of ClusEdgeWeightProperty.
208 boost::property_map<Graph, boost::edge_weight_t>::type
209 weight =
get(boost::edge_weight, g);
210 std::vector<std::pair<RCEdge,float> > edge_weight;
212 for(boost::tie(ei,ei_end) = boost::edges(g); ei != ei_end; ++ei){
213 edge_weight.push_back(std::pair<RCEdge,float>(*ei,
214 boost::get(weight,*ei)));
// Closest pairs first.
218 std::sort(edge_weight.begin(),edge_weight.end(),sort_by_weight());
// One flag per vertex, all initially false.
220 std::vector<bool> used;
221 used.insert(used.end(),boost::num_vertices(g),
false);
222 for(
unsigned int i=0;i<edge_weight.size();i++) {
223 RCEdge e = edge_weight[i].first;
224 int v1_ind=boost::source(e,g);
225 int v2_ind=boost::target(e,g);
227 " and "<<v2_ind<<std::endl);
229 if (!used[v1_ind] && !used[v2_ind] &&
230 (edge_weight[i].second < max_dist2)){
235 TransformationRecord* rec1 = recs[v1_ind];
236 TransformationRecord* rec2 = recs[v2_ind];
// A record already absorbed earlier in this pass cannot take part again.
237 if (!(rec1->get_valid() &&rec2->get_valid()))
continue;
// The higher-scoring record absorbs the other; the absorbed one is
// invalidated.
238 if (rec1->get_score() > rec2->get_score()) {
239 rec1->join_into(*rec2);
240 rec2->set_valid(
false);
242 rec2->join_into(*rec1);
243 rec1->set_valid(
false);
// prepare: sets centroid_ to the mean coordinate of the given particles (the
// sum of their XYZ coordinates divided by ps.size()).
// NOTE(review): the declarator, the construction of `xyzs` from ps, and
// whether centroid_ is reset before the sum are not visible here — confirm it
// is zeroed first. An empty `ps` would divide by zero.
249 template<
class TransT>
255 for (core::XYZs::iterator it = xyzs.begin(); it != xyzs.end(); it++) {
256 centroid_ += it->get_coordinates();
258 centroid_ /= ps.size();
// fast_clustering: loads every record into a geometric hash with bin size
// bin_size_. Each record is keyed by trans_cen, derived from its
// representative transformation, with the record index as payload. It then
// builds and clusters a separate graph per hash bucket and accumulates the
// joins reported by cluster_graph in num_joins.
// NOTE(review): `used` and the local `edge_weight` vector are not referenced
// in the visible lines. Several lines (the declarator, num_joins/g
// declarations, the trans_cen computation and the return) are missing from
// this view.
262 template<
class TransT>
264 std::vector<TransformationRecord*>& recs) {
265 IMP_LOG_VERBOSE(
"start fast clustering with "<<recs.size()<<
" records\n");
267 boost::scoped_array<bool> used(
new bool[recs.size()]);
268 Hash3 g_hash((
double)(bin_size_));
// Hash every record; the payload i is its index in recs.
271 for (
int i = 0 ; i < (int)recs.size() ; ++i){
273 TransT tr=recs[i]->get_record();
275 tr.get_representative_transformation();
277 g_hash.add(trans_cen, i);
279 <<
" with center:"<<trans_cen<<std::endl);
// Cluster each bucket independently.
282 const Hash3::GeomMap &M = g_hash.Map();
283 for (Hash3::GeomMap::const_iterator bucket = M.begin();
284 bucket != M.end() ; ++bucket){
285 const Hash3::PointList &pb = bucket->second;
289 std::vector<std::pair<RCEdge,float> > edge_weight;
290 build_graph(pb,recs,max_dist,g);
291 IMP_LOG_VERBOSE(
"create graph with:"<<boost::num_vertices(g)<<
" nodes and"<<
292 boost::num_edges(g)<<
" edges out of "<<pb.size()<<
" points\n");
294 num_joins +=cluster_graph(g,recs,max_dist);
295 IMP_LOG_VERBOSE(
"after clustering number of joins::"<<num_joins<<std::endl);
// exhaustive_clustering: clusters across all records in a single graph.
// Returns 0 when there are fewer than two records. Otherwise it hashes each
// record's transformed centroid (payload = record index) with bin size
// max_dist, builds the full neighbour graph and merges via cluster_graph.
// NOTE(review): `used` is not referenced in the visible lines; the graph
// declaration and the final return (presumably num_joins) are not visible.
301 template<
class TransT>
302 int RMSDClustering<TransT>::exhaustive_clustering(
float max_dist,
303 std::vector<TransformationRecord *>& recs) {
304 IMP_LOG_VERBOSE(
"start full clustering with "<< recs.size()<<
" records \n");
// Nothing to merge.
305 if (recs.size()<2)
return 0;
306 boost::scoped_array<bool> used(
new bool[recs.size()]);
// Bin size equal to the clustering radius, so neighbour queries of radius
// max_dist in build_full_graph only touch adjacent bins.
307 Hash3 ghash((
double)(max_dist));
310 for (
int i = 0 ; i < (int)recs.size() ; ++i) {
312 algebra::Transformation3D t =
313 recs[i]->get_record().get_representative_transformation();
314 ghash.add(t.get_transformed(centroid_), i);
318 build_full_graph(ghash,recs,max_dist,g);
319 int num_joins = cluster_graph(g,recs,max_dist);
// clean: copies the still-valid record pointers into a newly allocated
// vector `results`, and deletes the records that are no longer valid
// (presumably in an else branch whose header is not visible).
// NOTE(review): the declarator and the tail (presumably deleting the old
// vector and repointing `records` at `results`) are not visible in this
// view — confirm neither vector leaks.
322 template<
class TransT>
324 std::vector<TransformationRecord*>*& records) {
325 std::vector<TransformationRecord*> *results =
326 new std::vector<TransformationRecord*>();
327 for (
int i = 0 ; i < (int)records->size() ; i++){
328 if ((*records)[i]->get_valid()) {
329 results->push_back((*records)[i]);
331 delete((*records)[i]);
// cluster: wraps each input transformation in a heap-allocated
// TransformationRecord and sets its centroid from centroid_. It then runs
// fast_clustering repeatedly while it returns non-zero, followed by
// exhaustive_clustering in the same way. Finally it copies each remaining
// record's transformation into `output` and deletes the record.
// NOTE(review): the declarator, the loop bodies (presumably calling clean())
// and the tail (presumably deleting `records` itself) are not visible here.
// Ownership is manual new/delete, so an exception thrown midway would leak
// the records.
338 template<
class TransT>
340 const std::vector<TransT> &input_trans,
341 std::vector<TransT> & output) {
344 std::vector<TransformationRecord*>* records =
345 new std::vector<TransformationRecord*>();
346 for (
typename std::vector<TransT>::const_iterator
347 it = input_trans.begin();it != input_trans.end() ; ++it){
348 TransformationRecord* record =
new TransformationRecord(*it);
349 record->set_centroid(centroid_);
350 records->push_back(record);
// Repeat each clustering pass until it reports no further work.
353 while (fast_clustering(max_dist, *records)){
358 while (exhaustive_clustering(max_dist, *records)){
// Emit the surviving transformations and release their records.
364 for (
int i = 0 ; i < (int)records->size() ; ++i){
365 output.push_back((*records)[i]->get_record());
366 delete((*records)[i]);
372 IMPMULTIFIT_END_NAMESPACE