test/perceptron.cc

Code
Comments
Other
Rev Date Author Line
3709 08 Nov 17 peter 1 // $Id$
3709 08 Nov 17 peter 2
3709 08 Nov 17 peter 3 /*
3709 08 Nov 17 peter 4   Copyright (C) 2017 Peter Johansson
3709 08 Nov 17 peter 5
3709 08 Nov 17 peter 6   This file is part of the yat library, http://dev.thep.lu.se/yat
3709 08 Nov 17 peter 7
3709 08 Nov 17 peter 8   The yat library is free software; you can redistribute it and/or
3709 08 Nov 17 peter 9   modify it under the terms of the GNU General Public License as
3709 08 Nov 17 peter 10   published by the Free Software Foundation; either version 3 of the
3709 08 Nov 17 peter 11   License, or (at your option) any later version.
3709 08 Nov 17 peter 12
3709 08 Nov 17 peter 13   The yat library is distributed in the hope that it will be useful,
3709 08 Nov 17 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
3709 08 Nov 17 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
3709 08 Nov 17 peter 16   General Public License for more details.
3709 08 Nov 17 peter 17
3709 08 Nov 17 peter 18   You should have received a copy of the GNU General Public License
3709 08 Nov 17 peter 19   along with yat. If not, see <http://www.gnu.org/licenses/>.
3709 08 Nov 17 peter 20 */
3709 08 Nov 17 peter 21
3709 08 Nov 17 peter 22 #include <config.h>
3709 08 Nov 17 peter 23
3709 08 Nov 17 peter 24 #include "Suite.h"
3709 08 Nov 17 peter 25
3709 08 Nov 17 peter 26 #include "yat/classifier/Perceptron.h"
3709 08 Nov 17 peter 27 #include "yat/classifier/Target.h"
3709 08 Nov 17 peter 28 #include "yat/random/random.h"
3709 08 Nov 17 peter 29 #include "yat/statistics/Averager.h"
3709 08 Nov 17 peter 30 #include "yat/utility/Matrix.h"
3709 08 Nov 17 peter 31 #include "yat/utility/Vector.h"
3709 08 Nov 17 peter 32
3709 08 Nov 17 peter 33 #include <algorithm>
3709 08 Nov 17 peter 34 #include <cassert>
3709 08 Nov 17 peter 35 #include <iostream>
3709 08 Nov 17 peter 36
3709 08 Nov 17 peter 37 using namespace theplu::yat;
3709 08 Nov 17 peter 38
3709 08 Nov 17 peter 39 struct Stats
3709 08 Nov 17 peter 40 {
3709 08 Nov 17 peter 41   Stats(size_t nof_samples, size_t nof_features)
3709 08 Nov 17 peter 42     : weight(nof_features), variance(nof_features), prediction(nof_samples)
3709 08 Nov 17 peter 43   {}
3709 08 Nov 17 peter 44
3709 08 Nov 17 peter 45   std::vector<theplu::yat::statistics::Averager> weight;
3709 08 Nov 17 peter 46   std::vector<theplu::yat::statistics::Averager> variance;
3709 08 Nov 17 peter 47   std::vector<theplu::yat::statistics::Averager> prediction;
3709 08 Nov 17 peter 48 };
3709 08 Nov 17 peter 49
3709 08 Nov 17 peter 50
3709 08 Nov 17 peter 51 void generate_y(classifier::Target& y, const utility::Vector& w,
3709 08 Nov 17 peter 52                 const utility::Matrix& X)
3709 08 Nov 17 peter 53 {
3709 08 Nov 17 peter 54   assert(X.columns() == w.size());
3709 08 Nov 17 peter 55   std::vector<std::string> vec;
3709 08 Nov 17 peter 56   vec.reserve(X.rows());
3709 08 Nov 17 peter 57   random::ContinuousUniform rnd;
3709 08 Nov 17 peter 58   for (size_t i=0; i<X.rows(); ++i) {
3709 08 Nov 17 peter 59     double p = 1.0 / (1.0 + std::exp(- (w * X.row_const_view(i))));
3709 08 Nov 17 peter 60     if (rnd() < p)
3709 08 Nov 17 peter 61       vec.push_back("pos");
3709 08 Nov 17 peter 62     else
3709 08 Nov 17 peter 63       vec.push_back("neg");
3709 08 Nov 17 peter 64   }
3709 08 Nov 17 peter 65   assert(vec.size() == X.rows());
3709 08 Nov 17 peter 66
3709 08 Nov 17 peter 67   y = classifier::Target(vec);
3709 08 Nov 17 peter 68   assert(y.size() == vec.size());
3709 08 Nov 17 peter 69   for (size_t i=0; i<y.labels().size(); ++i) {
3709 08 Nov 17 peter 70     if (y.labels()[i] == "pos")
3709 08 Nov 17 peter 71       y.set_binary(i, true);
3709 08 Nov 17 peter 72     else
3709 08 Nov 17 peter 73       y.set_binary(i, false);
3709 08 Nov 17 peter 74   }
3709 08 Nov 17 peter 75   assert(y.size() == X.rows());
3709 08 Nov 17 peter 76 }
3709 08 Nov 17 peter 77
3709 08 Nov 17 peter 78
3709 08 Nov 17 peter 79 void analyse(const utility::Vector& w, const utility::Matrix& X, Stats& stats,
3709 08 Nov 17 peter 80              test::Suite& suite)
3709 08 Nov 17 peter 81 {
3709 08 Nov 17 peter 82   classifier::Target y;
3709 08 Nov 17 peter 83   generate_y(y, w, X);
3709 08 Nov 17 peter 84   assert(y.size() == X.rows());
3709 08 Nov 17 peter 85
3709 08 Nov 17 peter 86   classifier::Perceptron perceptron;
3709 08 Nov 17 peter 87   perceptron.train(X, y);
3709 08 Nov 17 peter 88
3709 08 Nov 17 peter 89   for (size_t i=0; i<perceptron.weight().size(); ++i) {
3709 08 Nov 17 peter 90     stats.weight[i].add(perceptron.weight()(i));
3709 08 Nov 17 peter 91     stats.variance[i].add(perceptron.covariance()(i,i));
3709 08 Nov 17 peter 92     double p = perceptron.p_value(i);
3709 08 Nov 17 peter 93     double x = perceptron.oddsratio(i);
3709 08 Nov 17 peter 94     double lower = perceptron.oddsratio_lower_CI(i);
3709 08 Nov 17 peter 95     double upper = perceptron.oddsratio_upper_CI(i);
3709 08 Nov 17 peter 96     // some sanity checks
3709 08 Nov 17 peter 97     if (!(lower<x && x<upper)) {
3709 08 Nov 17 peter 98       suite.add(false);
3709 08 Nov 17 peter 99       suite.err() << "error: incorrect CI: " << lower << " "
3709 08 Nov 17 peter 100                   << x << " " << upper << "\n";
3709 08 Nov 17 peter 101     }
3709 08 Nov 17 peter 102
3709 08 Nov 17 peter 103     if (p < 0.05) {
3709 08 Nov 17 peter 104       if (lower < 1.0 && upper > 1.0) {
3709 08 Nov 17 peter 105         suite.add(false);
3709 08 Nov 17 peter 106         suite.err() << "significant p " << p << " expected CI not to "
3709 08 Nov 17 peter 107                     << "overlap with 1.0: " << lower << " " << upper << "\n";
3709 08 Nov 17 peter 108       }
3709 08 Nov 17 peter 109     }
3709 08 Nov 17 peter 110     else if (p > 0.05) {
3709 08 Nov 17 peter 111       if (lower > 1.0 || upper < 1.0) {
3709 08 Nov 17 peter 112         suite.add(false);
3709 08 Nov 17 peter 113         suite.err() << "nonsignificant p " << p << " expected CI to "
3709 08 Nov 17 peter 114                     << "overlap with 1.0: " << lower << " " << upper << "\n";
3709 08 Nov 17 peter 115       }
3709 08 Nov 17 peter 116     }
3709 08 Nov 17 peter 117   }
3709 08 Nov 17 peter 118
3709 08 Nov 17 peter 119   assert(y.size() == X.rows());
3709 08 Nov 17 peter 120   for (size_t i=0; i<y.size(); ++i) {
3709 08 Nov 17 peter 121     stats.prediction[i].add(perceptron.predict(X.row_const_view(i)));
3709 08 Nov 17 peter 122   }
3709 08 Nov 17 peter 123 }
3709 08 Nov 17 peter 124
3709 08 Nov 17 peter 125
3709 08 Nov 17 peter 126 int main(int argc, char* argv[])
3709 08 Nov 17 peter 127 {
3709 08 Nov 17 peter 128   test::Suite suite(argc, argv);
3709 08 Nov 17 peter 129
3709 08 Nov 17 peter 130   // data are modeled as y = 1 / (1 + exp(-w'x)) where w and x are vectors
3709 08 Nov 17 peter 131
3709 08 Nov 17 peter 132   size_t n = 5000;
3709 08 Nov 17 peter 133
3709 08 Nov 17 peter 134   utility::Vector w(5, 0);
3709 08 Nov 17 peter 135   w(0) = 3.14;
3709 08 Nov 17 peter 136   w(1) = 1.0;
3709 08 Nov 17 peter 137   w(2) = 5.0;
3709 08 Nov 17 peter 138   w(3) = -2.0;
3709 08 Nov 17 peter 139
3709 08 Nov 17 peter 140   size_t p = w.size();
3709 08 Nov 17 peter 141   utility::Matrix X(n, p);
3709 08 Nov 17 peter 142   std::generate(X.begin(), X.end(), random::Gaussian());
3709 08 Nov 17 peter 143   X.column_view(0).all(1.0);
3709 08 Nov 17 peter 144
3709 08 Nov 17 peter 145   Stats stats(n, p);
3709 08 Nov 17 peter 146   for (size_t i=0; i<100; ++i) {
3709 08 Nov 17 peter 147     analyse(w, X, stats, suite);
3709 08 Nov 17 peter 148   }
3709 08 Nov 17 peter 149
3709 08 Nov 17 peter 150   suite.out() << "weight\ni\tweight\t<inferred>\tz\n";
3709 08 Nov 17 peter 151   for (size_t i=0; i<p; ++i) {
3709 08 Nov 17 peter 152     double z = (stats.weight[i].mean()-w(i))/stats.weight[i].standard_error();
3709 08 Nov 17 peter 153     suite.out() << i << "\t" << w(i) << "\t"
3709 08 Nov 17 peter 154                 << stats.weight[i].mean() << "\t"
3709 08 Nov 17 peter 155                 << z
3709 08 Nov 17 peter 156                 << "\n";
3709 08 Nov 17 peter 157     if (std::abs(z) > 5.0) {
3709 08 Nov 17 peter 158       suite.add(false);
3709 08 Nov 17 peter 159       suite.err() << "error: z: " << z << " for param " << i << "\n";
3709 08 Nov 17 peter 160     }
3709 08 Nov 17 peter 161   }
3709 08 Nov 17 peter 162
3709 08 Nov 17 peter 163   suite.out() << "covariance\n";
3709 08 Nov 17 peter 164   suite.out() << "i\t<inferred>\tobserved\tdelta\trelative error\n";
3709 08 Nov 17 peter 165   for (size_t i=0; i<p; ++i) {
3709 08 Nov 17 peter 166     double delta = stats.variance[i].mean() - stats.weight[i].variance();
3709 08 Nov 17 peter 167     double relative_error = delta / stats.weight[i].variance();
3709 08 Nov 17 peter 168     suite.out() << i << "\t"
3709 08 Nov 17 peter 169               << stats.variance[i].mean() << "\t"
3709 08 Nov 17 peter 170               << stats.weight[i].variance() << "\t"
3709 08 Nov 17 peter 171               << delta << "\t"
3709 08 Nov 17 peter 172               << relative_error << "\t"
3709 08 Nov 17 peter 173               << "\n";
3709 08 Nov 17 peter 174     if (relative_error > 0.4) {
3709 08 Nov 17 peter 175       suite.add(false);
3709 08 Nov 17 peter 176       suite.err() << "error: covariance param: " << i << "\n";
3709 08 Nov 17 peter 177     }
3709 08 Nov 17 peter 178   }
3709 08 Nov 17 peter 179
3709 08 Nov 17 peter 180   suite.out() << "predictions\n";
3709 08 Nov 17 peter 181   double sum = 0;
3709 08 Nov 17 peter 182   double slack = 0.01;
3709 08 Nov 17 peter 183   for (size_t i=0; i<n; ++i) {
3709 08 Nov 17 peter 184     double theoretical = 1.0 / (1 + std::exp(- w*X.row_const_view(i)));
3709 08 Nov 17 peter 185     double delta = stats.prediction[i].mean() - theoretical;
3709 08 Nov 17 peter 186     sum += delta*delta;
3709 08 Nov 17 peter 187     if (std::abs(delta) > slack)
3709 08 Nov 17 peter 188       suite.out() << i << "\t"
3709 08 Nov 17 peter 189                 << theoretical << "\t"
3709 08 Nov 17 peter 190                 << stats.prediction[i].mean() << "\t"
3709 08 Nov 17 peter 191                 << delta << "\t"
3709 08 Nov 17 peter 192                 << "\n";
3709 08 Nov 17 peter 193   }
3709 08 Nov 17 peter 194   suite.out() << "sum squared error: " << sum << "\n";
3709 08 Nov 17 peter 195   if (sum > n * slack*slack) {
3709 08 Nov 17 peter 196     suite.err() << "error: too large sum\n";
3709 08 Nov 17 peter 197     suite.add(false);
3709 08 Nov 17 peter 198   }
3709 08 Nov 17 peter 199
3709 08 Nov 17 peter 200   return suite.return_value();
3709 08 Nov 17 peter 201 }