1 #ifndef _theplu_yat_classifier_ncc_
2 #define _theplu_yat_classifier_ncc_
27 #include "MatrixLookup.h"
28 #include "MatrixLookupWeighted.h"
29 #include "SupervisedClassifier.h"
32 #include "yat/statistics/Averager.h"
33 #include "yat/statistics/AveragerWeighted.h"
34 #include "yat/utility/concept_check.h"
35 #include "yat/utility/Exception.h"
36 #include "yat/utility/Matrix.h"
37 #include "yat/utility/MatrixWeighted.h"
38 #include "yat/utility/Vector.h"
40 #include "yat/utility/yat_assert.h"
42 #include <boost/concept_check.hpp>
50 namespace classifier {
64 template <
typename Distance>
180 template <
typename Distance>
187 template <
typename Distance>
195 template <
typename Distance>
201 template <
typename Distance>
208 template <
typename Distance>
// Fragment of a member-function template definition parameterized on Distance.
// NOTE(review): the signature and several body lines are absent from this
// listing (the embedded numbers jump 217 -> 221, 222 -> 224 and 225 -> 228).
// Presumably this is NCC<Distance>::train for MatrixLookup data, which the
// reference index at the end of the listing places at NCC.h:218 - confirm
// against the original header.
217 template <
typename Distance>
// Outer loop: one pass per row i of data.
221 for(
size_t i=0; i<data.
rows(); i++) {
// Averager accumulators for the current row, declared anew on every pass and
// indexed below by target(j).
// NOTE(review): no sizing of this vector is visible before it is indexed;
// embedded line 223 is missing and presumably does it - confirm.
222 std::vector<statistics::Averager> class_averager;
// Inner loop: one pass per column j of data.
224 for(
size_t j=0; j<data.
columns(); j++) {
// Feed element (i,j) to the accumulator selected by target(j).
225 class_averager[target(j)].add(data(i,j));
// centroids_(i,c) receives the mean gathered in accumulator c; the loop that
// introduces c is not part of this listing.
228 centroids_(i,c) = class_averager[c].mean();
// Fragment of a member-function template definition parameterized on Distance.
// NOTE(review): the signature and several body lines are absent from this
// listing (the embedded numbers jump 234 -> 238, 239 -> 241, 242 -> 244 and
// 244 -> 247). Presumably this is the NCC<Distance>::train overload for
// weighted data, since it reads data.data(i,j) and data.weight(i,j) -
// confirm against the original header.
234 template <
typename Distance>
// Outer loop: one pass per row i of data.
238 for(
size_t i=0; i<data.
rows(); i++) {
// AveragerWeighted accumulators for the current row, declared anew on every
// pass and indexed below by target(j).
// NOTE(review): no sizing of this vector is visible before it is indexed;
// embedded line 240 is missing and presumably does it - confirm.
239 std::vector<statistics::AveragerWeighted> class_averager;
// Inner loop: one pass per column j of data.
241 for(
size_t j=0; j<data.
columns(); j++)
// Feed the value and the weight of element (i,j) to the accumulator selected
// by target(j).
242 class_averager[target(j)].add(data.
data(i,j),data.
weight(i,j));
// Accumulator c is special-cased when its sum_w() is exactly zero; the body
// of this branch and the loop that introduces c are not part of this listing.
244 if(class_averager[c].sum_w()==0) {
// centroids_(i,c) receives the mean gathered in accumulator c.
// NOTE(review): presumably this sits in the else path of the check above -
// confirm.
247 centroids_(i,c) = class_averager[c].mean();
// Fragment of a member-function template definition parameterized on Distance.
// NOTE(review): the signature and embedded lines 260-269 are absent from this
// listing. Presumably this is NCC<Distance>::predict for MatrixLookup test
// data, which the reference index at the end of the listing places at
// NCC.h:254 - confirm against the original header.
253 template <
typename Distance>
// Check that test has exactly as many rows as centroids_.
// NOTE(review): yat_assert is defined elsewhere; presumably it raises
// utility::runtime_error carrying the message below when the check fails -
// confirm.
257 utility::yat_assert<utility::runtime_error>
258 (centroids_.rows()==test.
rows(),
259 "NCC::predict test data with incorrect number of rows");
// Hand test and prediction on to predict_unweighted. The lines between the
// check and this call are not shown, so any condition guarding it is unknown.
270 predict_unweighted(test,prediction);
// Fragment of a member-function template definition parameterized on Distance.
// NOTE(review): the signature and embedded lines 281-282 are absent from this
// listing. Presumably this is the NCC<Distance>::predict overload for
// weighted test data, since it forwards to predict_weighted - confirm against
// the original header.
274 template <
typename Distance>
// Check that test has exactly as many rows as centroids_.
// NOTE(review): yat_assert is defined elsewhere; presumably it raises
// utility::runtime_error carrying the message below when the check fails -
// confirm.
278 utility::yat_assert<utility::runtime_error>
279 (centroids_.rows()==test.
rows(),
280 "NCC::predict test data with incorrect number of rows");
// Hand test and prediction on to predict_weighted.
283 predict_weighted(test,prediction);
// Fragment of a member-function template definition parameterized on Distance.
// NOTE(review): the signature (embedded lines 288-290) and the start of the
// innermost statement (embedded line 293) are absent from this listing.
// Presumably this is NCC<Distance>::predict_unweighted - confirm against the
// original header.
287 template <
typename Distance>
// Outer loop: one pass per column j of test.
291 for(
size_t j=0; j<test.
columns();j++)
// Inner loop: one pass per column k of centroids_.
292 for(
size_t k=0; k<centroids_.columns();k++)
// Tail of a call whose last argument is an iterator to the start of column k
// of centroids_; the callee and the assignment target are on the missing
// line.
294 centroids_.begin_column(k));
// NCC<Distance>::predict_weighted: fills prediction(k,j) with the value
// distance_ returns for column j of test and column k of weighted_centroids.
// NOTE(review): the rest of the parameter list and embedded lines 299-301
// (including whatever defines weighted_centroids) are absent from this
// listing - confirm against the original header.
297 template <
typename Distance>
298 void NCC<Distance>::predict_weighted(
const MatrixLookupWeighted& test,
// Outer loop: one pass per column j of test.
302 for(
size_t j=0; j<test.columns();j++)
// Inner loop: one pass per column k of centroids_.
303 for(
size_t k=0; k<centroids_.columns();k++)
// distance_ receives the iterator range over column j of test plus an
// iterator to the start of column k of weighted_centroids; the result is
// stored at row k, column j of prediction.
304 prediction(k,j) = distance_(test.begin_column(j), test.end_column(j),
305 weighted_centroids.begin_column(k));
General view into utility::Matrix.
Definition: MatrixLookup.h:70
const utility::Matrix & centroids(void) const
Get the centroids for all classes.
Definition: NCC.h:202
size_t columns(void) const
Class for containing sample labels.
Definition: Target.h:47
virtual ~NCC(void)
Definition: NCC.h:196
void resize(size_t r, size_t c, double init_value=0)
Resize Matrix.
void predict(const MatrixLookup &data, utility::Matrix &results) const
Make predictions for unweighted test data.
Definition: NCC.h:254
NCC< Distance > * make_classifier(void) const
Create an untrained copy of the classifier.
Definition: NCC.h:210
Nearest Centroid Classifier.
Definition: NCC.h:65
Interface class for supervised classifiers that use data in a matrix format.
Definition: SupervisedClassifier.h:56
double weight(size_t row, size_t column) const
size_t columns(void) const
General view into utility::MatrixWeighted.
Definition: MatrixLookupWeighted.h:63
void train(const MatrixLookup &data, const Target &targets)
Train the NCC using unweighted training data with known targets.
Definition: NCC.h:218
const_column_iterator end_column(size_t) const
Concept check for a Distance.
Definition: concept_check.h:278
Interface to GSL matrix.
Definition: Matrix.h:63
Weighted Matrix.
Definition: MatrixWeighted.h:44
const_column_iterator begin_column(size_t) const
NCC(void)
Constructor.
Definition: NCC.h:181
size_t nof_classes(void) const
double data(size_t row, size_t column) const