00001
00002
00003
00004
00005
00006
00007
00008 #ifndef IMP_OPTIMIZER_H
00009 #define IMP_OPTIMIZER_H
00010
00011 #include "kernel_config.h"
00012 #include "base_types.h"
00013 #include "VersionInfo.h"
00014 #include "Object.h"
00015 #include "utility.h"
00016 #include "Model.h"
00017 #include "Particle.h"
00018 #include "Pointer.h"
00019 #include "OptimizerState.h"
00020 #include <limits>
00021 #include <cmath>
00022
00023 IMP_BEGIN_NAMESPACE
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042 class IMPEXPORT Optimizer: public Object
00043 {
00044 public:
00045 Optimizer(Model *m= NULL, std::string name="Optimizer %1%");
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056 virtual Float optimize(unsigned int max_steps) = 0;
00057
00058
00059 virtual VersionInfo get_version_info() const = 0;
00060
00061
00062 Model *get_model() const {
00063 return model_.get();
00064 }
00065
00066
00067
00068
00069
00070
00071
00072
00073 void set_model(Model *m) {model_=m;}
00074
00075
00076
00077 virtual void show(std::ostream &out= std::cout) const {
00078 out << "Some optimizer" << std::endl;
00079 }
00080
00081
00082
00083
00084
00085
00086
00087
00088 IMP_LIST(public, OptimizerState, optimizer_state, OptimizerState*,
00089 OptimizerStates);
00090
00091
00092
00093 IMP_REF_COUNTED_DESTRUCTOR(Optimizer);
00094
00095 protected:
00096
00097 void update_states() const ;
00098
00099 struct FloatIndex
00100 {
00101
00102
00103
00104 friend class Optimizer;
00105 friend class FloatIndexIterator;
00106 Model::ParticleConstIterator p_;
00107 Particle::OptimizedKeyIterator fk_;
00108 FloatIndex(Model::ParticleConstIterator p): p_(p){}
00109 public:
00110 FloatIndex() {}
00111 std::string get_string() const {
00112 return (*p_)->get_name() + ": " + fk_->get_string();
00113 }
00114 };
00115
00116
00117 class FloatIndexIterator
00118 {
00119 typedef FloatIndexIterator This;
00120 Model::ParticleConstIterator pe_;
00121 mutable FloatIndex i_;
00122
00123 void search_valid() const {
00124 while (i_.fk_ == (*i_.p_)->optimized_keys_end()) {
00125 if (i_.fk_ == (*i_.p_)->optimized_keys_end()) {
00126 ++i_.p_;
00127 if (i_.p_== pe_) return;
00128 else {
00129 i_.fk_= (*i_.p_)->optimized_keys_begin();
00130 }
00131 } else {
00132 ++i_.fk_;
00133 }
00134 }
00135 IMP_INTERNAL_CHECK(i_.p_ != pe_, "Should have just returned");
00136 IMP_INTERNAL_CHECK(i_.fk_ != (*i_.p_)->optimized_keys_end(),
00137 "Broken iterator end");
00138 IMP_INTERNAL_CHECK((*i_.p_)->get_is_optimized(*i_.fk_),
00139 "Why did the loop end?");
00140 }
00141 void find_next() const {
00142 ++i_.fk_;
00143 search_valid();
00144 }
00145 public:
00146 FloatIndexIterator(Model::ParticleConstIterator pc,
00147 Model::ParticleConstIterator pe): pe_(pe), i_(pc) {
00148 if (pc != pe) {
00149 i_.fk_= (*pc)->optimized_keys_begin();
00150 search_valid();
00151 }
00152 }
00153 typedef FloatIndex value_type;
00154 typedef FloatIndex& reference;
00155 typedef FloatIndex* pointer;
00156 typedef std::forward_iterator_tag iterator_category;
00157 typedef int difference_type;
00158
00159 const This &operator++() {
00160 find_next();
00161 return *this;
00162 }
00163 This operator++(int) {
00164 This ret=*this;
00165 find_next();
00166 return ret;
00167 }
00168 reference operator*() const {
00169 IMP_INTERNAL_CHECK((*i_.p_)->get_is_optimized(*i_.fk_),
00170 "The iterator is broken");
00171 return i_;
00172 }
00173 pointer operator->() const {
00174 return &i_;
00175 }
00176 bool operator==(const This &o) const {
00177 if (i_.p_ != o.i_.p_) return false;
00178 if (i_.p_== pe_) return o.i_.p_ ==pe_;
00179 else return i_.fk_ == o.i_.fk_;
00180 }
00181 bool operator!=(const This &o) const {
00182 return !operator==(o);
00183 }
00184 };
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197 FloatIndexIterator float_indexes_begin() const {
00198 return FloatIndexIterator(model_->particles_begin(),
00199 model_->particles_end());
00200 }
00201
00202 FloatIndexIterator float_indexes_end() const {
00203 return FloatIndexIterator(model_->particles_end(),
00204 model_->particles_end());
00205 }
00206
00207 void set_value(FloatIndex fi, Float v) const {
00208 IMP_INTERNAL_CHECK(fi.p_ != model_->particles_end(),
00209 "Out of range FloatIndex in Optimizer");
00210 IMP_INTERNAL_CHECK((*fi.p_)->get_is_optimized(*fi.fk_),
00211 "Keep your mits off unoptimized attributes "
00212 << (*fi.p_)->get_name() << " " << *fi.fk_ << std::endl);
00213 (*fi.p_)->set_value(*fi.fk_, v);
00214 }
00215
00216 Float get_value(FloatIndex fi) const {
00217
00218 IMP_INTERNAL_CHECK(static_cast<Model::ParticleConstIterator>(fi.p_)
00219 != model_->particles_end(),
00220 "Out of range FloatIndex in Optimizer");
00221 return (*fi.p_)->get_value(*fi.fk_);
00222 }
00223
00224 Float get_derivative(FloatIndex fi) const {
00225 IMP_INTERNAL_CHECK(fi.p_ != model_->particles_end(),
00226 "Out of range FloatIndex in Optimizer");
00227 return (*fi.p_)->get_derivative(*fi.fk_);
00228 }
00229
00230
00231
00232 typedef std::vector<FloatIndex> FloatIndexes;
00233
00234 double width(FloatKey k) const {
00235 if (!widths_.fits(k.get_index())
00236 || !FloatTable::Traits::get_is_valid(widths_.get(k.get_index())) ) {
00237 FloatRange w= model_->get_range(k);
00238 double wid=static_cast<double>(w.second)- w.first;
00239 if (wid > .0001) {
00240
00241 widths_.add(k.get_index(), wid);
00242 } else {
00243 widths_.add(k.get_index(), 1);
00244 }
00245 }
00246 return widths_.get(k.get_index());
00247
00248 }
00249
00250
00251
00252
00253
00254
00255
00256
00257 void set_scaled_value(FloatIndex fi, Float v) const {
00258 double wid = width(*fi.fk_);
00259 set_value(fi, v*wid);
00260 }
00261
00262 double get_scaled_value(FloatIndex fi) const {
00263 double uv= get_value(fi);
00264 double wid = width(*fi.fk_);
00265 return uv/wid;
00266 }
00267
00268 double get_scaled_derivative(FloatIndex fi) const {
00269 double uv=get_derivative(fi);
00270 double wid= width(*fi.fk_);
00271 return uv*wid;
00272 }
00273
00274
00275 void clear_range_cache() {
00276 widths_.clear();
00277 }
00278
00279
00280 private:
00281 typedef internal::VectorStorage<internal::FloatAttributeTableTraits>
00282 FloatTable;
00283 mutable FloatTable widths_;
00284 WeakPointer<Model> model_;
00285 };
00286
00287
00288
00289 class SaveOptimizeds: public RAII {
00290 ParticlesTemp pt_;
00291 std::vector<internal::ParticleStorage::OptimizedTable> saved_;
00292 public:
00293 IMP_RAII(SaveOptimizeds, (const ParticlesTemp &pt),,
00294 {
00295 pt_=pt;
00296 saved_= std::vector<internal::ParticleStorage::OptimizedTable>
00297 (pt_.size());
00298 for (unsigned int i=0; i< pt_.size(); ++i) {
00299 saved_[i]= pt_[i]->ps_->optimizeds_;
00300 }
00301 },
00302 {
00303 for (unsigned int i=0; i< pt_.size(); ++i) {
00304 pt_[i]->ps_->optimizeds_= saved_[i];
00305 }
00306 });
00307 };
00308
// Declares the standard IMP helper declarations for Optimizer.
// NOTE(review): presumably these are the Optimizers / OptimizersTemp container
// typedefs — confirm against the IMP_OBJECTS macro definition.
IMP_OBJECTS(Optimizer);
00310
00311 IMP_END_NAMESPACE
00312
00313 #endif