00001 #ifndef _theplu_yat_classifier_knn_
00002 #define _theplu_yat_classifier_knn_
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>
00038
00039 namespace theplu {
00040 namespace yat {
00041 namespace classifier {
00042
00059 template <typename Distance, typename NeighborWeighting=KNN_Uniform>
00060 class KNN : public SupervisedClassifier
00061 {
00062
00063 public:
00071 KNN(void);
00072
00073
00083 KNN(const Distance&);
00084
00085
00089 virtual ~KNN();
00090
00091
00096 unsigned int k() const;
00097
00103 void k(unsigned int k_in);
00104
00105
00106 KNN<Distance,NeighborWeighting>* make_classifier(void) const;
00107
00118 void predict(const MatrixLookup& data , utility::Matrix& results) const;
00119
00132 void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
00133
00134
00149 void train(const MatrixLookup& data, const Target& targets);
00150
00157 void train(const MatrixLookupWeighted& data, const Target& targets);
00158
00159 private:
00160
00161 const MatrixLookup* data_ml_;
00162 const MatrixLookupWeighted* data_mlw_;
00163 const Target* target_;
00164
00165
00166 unsigned int k_;
00167
00168 Distance distance_;
00169 NeighborWeighting weighting_;
00170
00171 void calculate_unweighted(const MatrixLookup&,
00172 const MatrixLookup&,
00173 utility::Matrix*) const;
00174 void calculate_weighted(const MatrixLookupWeighted&,
00175 const MatrixLookupWeighted&,
00176 utility::Matrix*) const;
00177
00178 void predict_common(const utility::Matrix& distances,
00179 utility::Matrix& prediction) const;
00180
00181 };
00182
00183
00184
00185
00186 template <typename Distance, typename NeighborWeighting>
00187 KNN<Distance, NeighborWeighting>::KNN()
00188 : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
00189 {
00190 }
00191
00192 template <typename Distance, typename NeighborWeighting>
00193 KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
00194 : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
00195 {
00196 }
00197
00198
00199 template <typename Distance, typename NeighborWeighting>
00200 KNN<Distance, NeighborWeighting>::~KNN()
00201 {
00202 }
00203
00204
00205 template <typename Distance, typename NeighborWeighting>
00206 void KNN<Distance, NeighborWeighting>::calculate_unweighted
00207 (const MatrixLookup& training, const MatrixLookup& test,
00208 utility::Matrix* distances) const
00209 {
00210 for(size_t i=0; i<training.columns(); i++) {
00211 for(size_t j=0; j<test.columns(); j++) {
00212 (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i),
00213 test.begin_column(j));
00214 utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
00215 }
00216 }
00217 }
00218
00219
00220 template <typename Distance, typename NeighborWeighting>
00221 void
00222 KNN<Distance, NeighborWeighting>::calculate_weighted
00223 (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
00224 utility::Matrix* distances) const
00225 {
00226 for(size_t i=0; i<training.columns(); i++) {
00227 for(size_t j=0; j<test.columns(); j++) {
00228 (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i),
00229 test.begin_column(j));
00230
00231
00232 if(std::isnan((*distances)(i,j)))
00233 (*distances)(i,j)=std::numeric_limits<double>::infinity();
00234 }
00235 }
00236 }
00237
00238
00239 template <typename Distance, typename NeighborWeighting>
00240 unsigned int KNN<Distance, NeighborWeighting>::k() const
00241 {
00242 return k_;
00243 }
00244
00245 template <typename Distance, typename NeighborWeighting>
00246 void KNN<Distance, NeighborWeighting>::k(unsigned int k)
00247 {
00248 k_=k;
00249 }
00250
00251
00252 template <typename Distance, typename NeighborWeighting>
00253 KNN<Distance, NeighborWeighting>*
00254 KNN<Distance, NeighborWeighting>::make_classifier() const
00255 {
00256
00257
00258 KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
00259 knn->weighting_=this->weighting_;
00260 knn->k(this->k());
00261 return knn;
00262 }
00263
00264
00265 template <typename Distance, typename NeighborWeighting>
00266 void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
00267 const Target& target)
00268 {
00269 utility::yat_assert<std::runtime_error>
00270 (data.columns()==target.size(),
00271 "KNN::train called with different sizes of target and data");
00272
00273 if(data.columns()<k_)
00274 k_=data.columns();
00275 data_ml_=&data;
00276 data_mlw_=0;
00277 target_=⌖
00278 }
00279
00280 template <typename Distance, typename NeighborWeighting>
00281 void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
00282 const Target& target)
00283 {
00284 utility::yat_assert<std::runtime_error>
00285 (data.columns()==target.size(),
00286 "KNN::train called with different sizes of target and data");
00287
00288 if(data.columns()<k_)
00289 k_=data.columns();
00290 data_ml_=0;
00291 data_mlw_=&data;
00292 target_=⌖
00293 }
00294
00295
00296 template <typename Distance, typename NeighborWeighting>
00297 void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
00298 utility::Matrix& prediction) const
00299 {
00300
00301 utility::Matrix* distances = 0;
00302
00303 if(data_ml_ && !data_mlw_) {
00304 utility::yat_assert<std::runtime_error>
00305 (data_ml_->rows()==test.rows(),
00306 "KNN::predict different number of rows in training and test data");
00307 distances=new utility::Matrix(data_ml_->columns(),test.columns());
00308 calculate_unweighted(*data_ml_,test,distances);
00309 }
00310 else if (data_mlw_ && !data_ml_) {
00311
00312 utility::yat_assert<std::runtime_error>
00313 (data_mlw_->rows()==test.rows(),
00314 "KNN::predict different number of rows in training and test data");
00315 distances=new utility::Matrix(data_mlw_->columns(),test.columns());
00316 calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
00317 distances);
00318 }
00319 else {
00320 std::runtime_error("KNN::predict no training data");
00321 }
00322
00323 prediction.resize(target_->nof_classes(),test.columns(),0.0);
00324 predict_common(*distances,prediction);
00325 if(distances)
00326 delete distances;
00327 }
00328
00329 template <typename Distance, typename NeighborWeighting>
00330 void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
00331 utility::Matrix& prediction) const
00332 {
00333
00334 utility::Matrix* distances=0;
00335
00336 if(data_ml_ && !data_mlw_) {
00337 utility::yat_assert<std::runtime_error>
00338 (data_ml_->rows()==test.rows(),
00339 "KNN::predict different number of rows in training and test data");
00340 distances=new utility::Matrix(data_ml_->columns(),test.columns());
00341 calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
00342 }
00343
00344 else if (data_mlw_ && !data_ml_) {
00345 utility::yat_assert<std::runtime_error>
00346 (data_mlw_->rows()==test.rows(),
00347 "KNN::predict different number of rows in training and test data");
00348 distances=new utility::Matrix(data_mlw_->columns(),test.columns());
00349 calculate_weighted(*data_mlw_,test,distances);
00350 }
00351 else {
00352 std::runtime_error("KNN::predict no training data");
00353 }
00354
00355 prediction.resize(target_->nof_classes(),test.columns(),0.0);
00356 predict_common(*distances,prediction);
00357
00358 if(distances)
00359 delete distances;
00360 }
00361
00362 template <typename Distance, typename NeighborWeighting>
00363 void KNN<Distance, NeighborWeighting>::predict_common
00364 (const utility::Matrix& distances, utility::Matrix& prediction) const
00365 {
00366 for(size_t sample=0;sample<distances.columns();sample++) {
00367 std::vector<size_t> k_index;
00368 utility::VectorConstView dist=distances.column_const_view(sample);
00369 utility::sort_smallest_index(k_index,k_,dist);
00370 utility::VectorView pred=prediction.column_view(sample);
00371 weighting_(dist,k_index,*target_,pred);
00372 }
00373
00374
00375
00376 for(size_t c=0;c<target_->nof_classes(); c++)
00377 if(!target_->size(c))
00378 for(size_t j=0;j<prediction.columns();j++)
00379 prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
00380 }
00381
00382
00383 }}}
00384
00385 #endif
00386