00001
00002
00003
00004
00005
00006
00007
00008
00009 #ifndef IMPSTATISTICS_KM_FILTER_CENTERS_H
00010 #define IMPSTATISTICS_KM_FILTER_CENTERS_H
00011
#include "KMCenters.h"
#include <iomanip>
#include <iostream>
#include <vector>
#include "KMData.h"
#include "KMCentersTree.h"
00016 IMPSTATISTICS_BEGIN_NAMESPACE
00017
00018 #ifndef IMP_DOXYGEN
00019
00020
00021
00022
00023
00024
00025 class IMPSTATISTICSEXPORT KMFilterCenters : public KMCenters{
00026 public:
00027 KMFilterCenters();
00028
00029
00030
00031
00032
00033
00034
00035 KMFilterCenters(int k, KMData* data,KMPointArray *ini_cen_arr=NULL,
00036 double df = 1);
00037 virtual ~KMFilterCenters();
00038 public:
00039
00040 KMPointArray *get_sums(bool auto_update = true){
00041 if (auto_update && !valid_) compute_distortion();
00042 return sums_;
00043 }
00044
00045 std::vector<double>* get_sum_sqs(bool auto_update = true){
00046 if (auto_update && !valid_) compute_distortion();
00047 return &sum_sqs_;
00048 }
00049
00050 std::vector<int>* get_weights(bool auto_update = true){
00051 if (auto_update && !valid_) compute_distortion();
00052 return &weights_;
00053 }
00054
00055 double get_distortion(bool auto_update = true) {
00056 if (auto_update && !valid_) compute_distortion();
00057 return curr_dist_;
00058 }
00059
00060 double get_average_distortion(bool auto_update = true){
00061 if (auto_update && !valid_) compute_distortion();
00062 return curr_dist_/double(get_number_of_points());
00063 }
00064
00065 std::vector<double>* get_distortions(bool auto_update = true) {
00066 if (auto_update && !valid_) compute_distortion();
00067 return &dists_;
00068 }
00069
00070
00071
00072
00073
00074
00075
00076 void get_assignments(std::vector<int> &close_center);
00077
00078
00079 virtual void generate_random_centers(int k);
00080 void show(std::ostream& out=std::cout) const;
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091 void move_to_centroid();
00092 protected:
00093 void clear_data();
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106 void compute_distortion();
00107 void validate(){ valid_ = true; }
00108
00109
00110
00111 void invalidate();
00112
00113 protected:
00114 KMPointArray *sums_;
00115 KMPoint sum_sqs_;
00116 std::vector<int> weights_;
00117 KMPointArray *ini_cen_arr_;
00118 KMPoint dists_;
00119 double curr_dist_;
00120 bool valid_;
00121 double damp_factor_;
00122
00123 KMCentersTree* tree_;
00124 };
00125
00126 class IMPSTATISTICSEXPORT KMFilterCentersResults : public KMCenters {
00127 public:
00128 KMFilterCentersResults(){};
00129
00130
00131
00132
00133
00134
00135
00136 KMFilterCentersResults(KMFilterCenters &full)
00137 : KMCenters(full) {
00138 close_center_.clear();
00139 full.get_assignments(close_center_);
00140 sums_ = new KMPointArray();
00141 copy_points(full.get_sums(),sums_);
00142 copy_point(full.get_sum_sqs(),&sum_sqs_);
00143 std::vector<int> *w = full.get_weights();
00144 weights_.clear();
00145 for(unsigned int i=0;i<w->size();i++) {
00146 weights_.push_back((*w)[i]);
00147 }
00148 copy_point(full.get_distortions(),&dists_);
00149 curr_dist_ = full.get_distortion();
00150 }
00151 KMFilterCentersResults & operator=(const KMFilterCentersResults &other) {
00152 if (this != &other) {
00153 KMCenters::operator=(other);
00154 close_center_.clear();
00155 for(unsigned int i=0;i<other.close_center_.size();i++) {
00156 close_center_.push_back(other.close_center_[i]);
00157 }
00158 sums_ = new KMPointArray();
00159 copy_points(other.sums_,sums_);
00160 copy_point(&other.sum_sqs_,&sum_sqs_);
00161 weights_.clear();
00162 for(unsigned int i=0;i<other.weights_.size();i++) {
00163 weights_.push_back(other.weights_[i]);
00164 }
00165 copy_point(&other.dists_,&dists_);
00166 curr_dist_ = other.curr_dist_;
00167 }
00168 return *this;
00169 }
00170 KMFilterCentersResults(const KMFilterCentersResults &other):KMCenters(other) {
00171 close_center_.clear();
00172 for(unsigned int i=0;i<other.close_center_.size();i++) {
00173 close_center_.push_back(other.close_center_[i]);
00174 }
00175 sums_ = new KMPointArray();
00176 copy_points(other.sums_,sums_);
00177 copy_point(&other.sum_sqs_,&sum_sqs_);
00178 weights_.clear();
00179 for(unsigned int i=0;i<other.weights_.size();i++) {
00180 weights_.push_back(other.weights_[i]);
00181 }
00182 copy_point(&other.dists_,&dists_);
00183 curr_dist_ = other.curr_dist_;
00184 }
00185 ~KMFilterCentersResults() {
00186 deallocate_points(sums_);
00187 }
00188 public:
00189
00190 KMPointArray *get_sums() const {
00191 return sums_;
00192 }
00193
00194 const std::vector<double>* get_sum_sqs() const {
00195 return &sum_sqs_;
00196 }
00197
00198 const std::vector<int>* get_weights() const {
00199 return &weights_;
00200 }
00201
00202 double get_distortion() const {
00203 return curr_dist_;
00204 }
00205
00206 double get_average_distortion() const {
00207 return curr_dist_/double(get_number_of_points());
00208 }
00209
00210 const std::vector<double>* get_distortions() const {
00211 return &dists_;
00212 }
00213
00214 const std::vector<int> * get_assignments() const {
00215 return &close_center_;
00216 }
00217
00218 void show(std::ostream& out=std::cout) const{
00219 for (int j = 0; j < get_number_of_centers(); j++) {
00220 out << " " << std::setw(4) << j << "\t";
00221 print_point(*((*centers_)[j]), out);
00222 out << " dist = " << std::setw(8) << dists_[j] <<
00223 " weight = " << std::setw(8) << weights_[j] <<
00224 std::endl;
00225 }
00226 }
00227 protected:
00228 KMPointArray *sums_;
00229 KMPoint sum_sqs_;
00230 std::vector<int> weights_;
00231 KMPoint dists_;
00232 double curr_dist_;
00233 std::vector<int> close_center_;
00234 };
00235
00236 #endif
00237
00238 IMPSTATISTICS_END_NAMESPACE
00239 #endif