1 #ifndef _theplu_yat_classifier_knn_
2 #define _theplu_yat_classifier_knn_
26 #include "DataLookup1D.h"
27 #include "DataLookupWeighted1D.h"
28 #include "KNN_Uniform.h"
29 #include "MatrixLookup.h"
30 #include "MatrixLookupWeighted.h"
31 #include "SupervisedClassifier.h"
33 #include "yat/utility/concept_check.h"
34 #include "yat/utility/Exception.h"
35 #include "yat/utility/Matrix.h"
36 #include "yat/utility/Vector.h"
37 #include "yat/utility/VectorConstView.h"
38 #include "yat/utility/VectorView.h"
39 #include "yat/utility/yat_assert.h"
41 #include <boost/concept_check.hpp>
51 namespace classifier {
// Nearest Neighbor Classifier (excerpt: the class head and most members are
// not visible here). Distance is the functor type used to compare samples;
// NeighborWeighting is the voting functor type and defaults to KNN_Uniform.
69 template <
typename Distance,
typename NeighborWeighting=KNN_Uniform>
// Get the number of nearest neighbors.
106 unsigned int k()
const;
// Overload taking k_in; presumably sets the number of nearest neighbors --
// TODO confirm against the definition, which is not visible here.
113 void k(
unsigned int k_in);
// Neighbor weighting functor; invoked during prediction as
// weighting_(dist, k_index, *target_, pred).
180 NeighborWeighting weighting_;
// Concept check for a Neighbor Weighting Method (excerpt: the struct head
// and the BOOST_CONCEPT_USAGE line are not visible here). A conforming T
// must model boost::DefaultConstructible and boost::Assignable.
215 :
public boost::DefaultConstructible<T>,
public boost::Assignable<T>
// Instance of the type under test, used by the usage check below.
223 T neighbor_weighting;
// Index vector handed to the functor as its second argument.
227 std::vector<size_t> k_sorted;
// Required expression: T must be callable with these four arguments. The
// declarations of distance, target and prediction are not visible here.
229 neighbor_weighting(distance, k_sorted, target, prediction);
236 template <
typename Distance,
typename NeighborWeighting>
244 template <
typename Distance,
typename NeighborWeighting>
254 template <
typename Distance,
typename NeighborWeighting>
// Distance computation between training and test columns (excerpt: the
// signature and the distance call are not visible; presumably this is
// calculate_unweighted, which predict() calls -- TODO confirm).
260 template <
typename Distance,
typename NeighborWeighting>
// Outer loop: one iteration per training column i.
265 for(
size_t i=0; i<training.
columns(); i++) {
// Inner loop: one iteration per test column j.
266 for(
size_t j=0; j<test.
columns(); j++) {
// A NaN at (i,j) is treated as a programming error here, whereas
// calculate_weighted replaces NaN with infinity instead.
270 YAT_ASSERT(!std::isnan((*distances)(i,j)));
// Fill *distances with the distance between every training column and every
// test column of two weighted lookups (excerpt: braces and the return type
// line are not visible here).
//
// training  - weighted training data; column i is one training sample
// test      - weighted test data; column j is one test sample
// distances - output matrix written at (i,j); must already be large enough
//             to be indexed by (training column, test column) -- the
//             allocation is not visible in this excerpt
276 template <
typename Distance,
typename NeighborWeighting>
278 KNN<Distance, NeighborWeighting>::calculate_weighted
279 (
const MatrixLookupWeighted& training,
const MatrixLookupWeighted& test,
280 utility::Matrix* distances)
const
// Rows of *distances follow training columns, columns follow test columns.
282 for(
size_t i=0; i<training.columns(); i++) {
283 for(
size_t j=0; j<test.columns(); j++) {
// distance_ receives the iterator range of training column i and the start
// iterator of test column j.
284 (*distances)(i,j) = distance_(training.begin_column(i),
285 training.end_column(i),
286 test.begin_column(j));
// A NaN distance is replaced by +infinity; presumably so that such a pair
// is never selected as a nearest neighbor -- TODO confirm intent.
289 if(std::isnan((*distances)(i,j)))
290 (*distances)(i,j)=std::numeric_limits<double>::infinity();
296 template <
typename Distance,
typename NeighborWeighting>
302 template <
typename Distance,
typename NeighborWeighting>
// Create an untrained copy of the classifier (excerpt: the signature and the
// construction of knn are not visible here).
309 template <
typename Distance,
typename NeighborWeighting>
// Carry this object's neighbor weighting functor over to the new instance.
316 knn->weighting_=this->weighting_;
// Train the KNN using unweighted training data with known targets (excerpt:
// the signature, the asserted condition and the rest of the body are not
// visible here).
322 template <
typename Distance,
typename NeighborWeighting>
// Size check reported through utility::runtime_error; per the message, the
// target and the data must have matching sizes. The condition itself is on a
// line missing from this excerpt.
326 utility::yat_assert<utility::runtime_error>
328 "KNN::train called with different sizes of target and data");
// Second train overload (excerpt: the signature is not visible; presumably
// the weighted-data counterpart of the unweighted train -- TODO confirm).
337 template <
typename Distance,
typename NeighborWeighting>
// Same size check and message as the other train overload; the asserted
// condition is on a line missing from this excerpt.
341 utility::yat_assert<utility::runtime_error>
343 "KNN::train called with different sizes of target and data");
// Make predictions for unweighted test data (excerpt: the signature, the
// setup of distances and parts of both branches are not visible here).
// NOTE(review): distances is passed on as a pointer and dereferenced below;
// its allocation and release are not visible -- confirm it is freed on every
// path, including when a yat_assert fails.
353 template <
typename Distance,
typename NeighborWeighting>
// Branch taken when only data_ml_ is set (presumably the unweighted
// training data -- TODO confirm).
361 if(data_ml_ && !data_mlw_) {
// Training and test data must have the same number of rows.
362 utility::yat_assert<utility::runtime_error>
363 (data_ml_->rows()==test.
rows(),
364 "KNN::predict different number of rows in training and test data");
366 calculate_unweighted(*data_ml_,test,distances);
// Branch taken when only data_mlw_ is set; the distance call for this
// branch is on lines missing from this excerpt.
368 else if (data_mlw_ && !data_ml_) {
370 utility::yat_assert<utility::runtime_error>
371 (data_mlw_->rows()==test.
rows(),
372 "KNN::predict different number of rows in training and test data");
// Result matrix: one row per class, one column per test column, reset to 0.0
// before the votes are computed.
381 prediction.
resize(target_->nof_classes(),test.
columns(),0.0);
382 predict_common(*distances,prediction);
// Prediction for weighted test data (excerpt: the signature, the setup of
// distances and parts of both branches are not visible here). test is
// passed to calculate_weighted, which takes a MatrixLookupWeighted.
// NOTE(review): distances is passed on as a pointer and dereferenced below;
// its allocation and release are not visible -- confirm it is freed on every
// path, including when a yat_assert fails.
387 template <
typename Distance,
typename NeighborWeighting>
// Branch taken when only data_ml_ is set; the distance call for this branch
// is on lines missing from this excerpt.
395 if(data_ml_ && !data_mlw_) {
// Training and test data must have the same number of rows.
396 utility::yat_assert<utility::runtime_error>
397 (data_ml_->rows()==test.
rows(),
398 "KNN::predict different number of rows in training and test data");
// Branch taken when only data_mlw_ is set: weighted training data against
// weighted test data.
403 else if (data_mlw_ && !data_ml_) {
404 utility::yat_assert<utility::runtime_error>
405 (data_mlw_->rows()==test.
rows(),
406 "KNN::predict different number of rows in training and test data");
408 calculate_weighted(*data_mlw_,test,distances);
// Result matrix: one row per class, one column per test column, reset to 0.0
// before the votes are computed.
414 prediction.
resize(target_->nof_classes(),test.
columns(),0.0);
415 predict_common(*distances,prediction);
// Shared prediction step operating on a distance matrix (excerpt: the
// signature is not visible; presumably predict_common, which both predict
// overloads call with (*distances, prediction) -- TODO confirm).
421 template <
typename Distance,
typename NeighborWeighting>
// One iteration per column of distances, i.e. per test sample.
425 for(
size_t sample=0;sample<distances.
columns();sample++) {
// Indices handed to the weighting functor; how the vector is filled is on
// lines missing from this excerpt (presumably via sort_smallest_index).
426 std::vector<size_t> k_index;
// The weighting functor produces this sample's votes. dist and pred are
// defined on lines not visible here (presumably column views of distances
// and prediction -- TODO confirm).
430 weighting_(dist,k_index,*target_,pred);
// Post-processing: for every class c where target_->size(c) is zero, the
// whole prediction row c is overwritten with quiet NaN.
435 for(
size_t c=0;c<target_->nof_classes(); c++)
436 if(!target_->size(c))
437 for(
size_t j=0;j<prediction.
columns();j++)
438 prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
const VectorConstView column_const_view(size_t) const
General view into utility::Matrix.
Definition: MatrixLookup.h:70
size_t columns(void) const
virtual ~KNN()
Definition: KNN.h:255
Class for containing sample labels.
Definition: Target.h:47
BOOST_CONCEPT_USAGE(NeighborWeightingConcept)
function doing the concept test
Definition: KNN.h:221
Nearest Neighbor Classifier.
Definition: KNN.h:70
This is the yat interface to gsl_vector_view.
Definition: VectorView.h:79
void predict(const MatrixLookup &data, utility::Matrix &results) const
Make predictions for unweighted test data.
Definition: KNN.h:355
void train(const MatrixLookup &data, const Target &targets)
Train the KNN using unweighted training data with known targets.
Definition: KNN.h:323
KNN(void)
Default constructor.
Definition: KNN.h:237
void resize(size_t r, size_t c, double init_value=0)
Resize Matrix.
void sort_smallest_index(std::vector< size_t > &sort_index, size_t k, const VectorBase &invec)
Interface class for supervised classifiers that use data in a matrix format.
Definition: SupervisedClassifier.h:56
unsigned int k() const
Get the number of nearest neighbors.
Definition: KNN.h:297
Read-only view.
Definition: VectorConstView.h:55
Class used for all runtime error detected within yat library.
Definition: Exception.h:38
This is the yat interface to GSL vector.
Definition: Vector.h:57
size_t columns(void) const
This is the yat interface to GSL vector.
Definition: VectorBase.h:52
General view into utility::MatrixWeighted.
Definition: MatrixLookupWeighted.h:63
KNN< Distance, NeighborWeighting > * make_classifier(void) const
Create an untrained copy of the classifier.
Definition: KNN.h:311
VectorView column_view(size_t i)
This is the mutable interface to GSL vector.
Definition: VectorMutable.h:55
const_column_iterator end_column(size_t) const
Concept check for a Distance.
Definition: concept_check.h:278
Interface to GSL matrix.
Definition: Matrix.h:63
const_column_iterator begin_column(size_t) const
Concept check for a Neighbor Weighting Method.
Definition: KNN.h:214
size_t columns(void) const