3709 |
08 Nov 17 |
peter |
// $Id$ |
3709 |
08 Nov 17 |
peter |
2 |
|
3709 |
08 Nov 17 |
peter |
3 |
/* |
3709 |
08 Nov 17 |
peter |
Copyright (C) 2017 Peter Johansson |
3709 |
08 Nov 17 |
peter |
5 |
|
3709 |
08 Nov 17 |
peter |
This file is part of the yat library, http://dev.thep.lu.se/yat |
3709 |
08 Nov 17 |
peter |
7 |
|
3709 |
08 Nov 17 |
peter |
The yat library is free software; you can redistribute it and/or |
3709 |
08 Nov 17 |
peter |
modify it under the terms of the GNU General Public License as |
3709 |
08 Nov 17 |
peter |
published by the Free Software Foundation; either version 3 of the |
3709 |
08 Nov 17 |
peter |
License, or (at your option) any later version. |
3709 |
08 Nov 17 |
peter |
12 |
|
3709 |
08 Nov 17 |
peter |
The yat library is distributed in the hope that it will be useful, |
3709 |
08 Nov 17 |
peter |
but WITHOUT ANY WARRANTY; without even the implied warranty of |
3709 |
08 Nov 17 |
peter |
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU |
3709 |
08 Nov 17 |
peter |
General Public License for more details. |
3709 |
08 Nov 17 |
peter |
17 |
|
3709 |
08 Nov 17 |
peter |
You should have received a copy of the GNU General Public License |
3709 |
08 Nov 17 |
peter |
along with yat. If not, see <http://www.gnu.org/licenses/>. |
3709 |
08 Nov 17 |
peter |
20 |
*/ |
3709 |
08 Nov 17 |
peter |
21 |
|
3709 |
08 Nov 17 |
peter |
22 |
#include <config.h> |
3709 |
08 Nov 17 |
peter |
23 |
|
3709 |
08 Nov 17 |
peter |
24 |
#include "Suite.h" |
3709 |
08 Nov 17 |
peter |
25 |
|
3709 |
08 Nov 17 |
peter |
26 |
#include "yat/classifier/Perceptron.h" |
3709 |
08 Nov 17 |
peter |
27 |
#include "yat/classifier/Target.h" |
3709 |
08 Nov 17 |
peter |
28 |
#include "yat/random/random.h" |
3709 |
08 Nov 17 |
peter |
29 |
#include "yat/statistics/Averager.h" |
3709 |
08 Nov 17 |
peter |
30 |
#include "yat/utility/Matrix.h" |
3709 |
08 Nov 17 |
peter |
31 |
#include "yat/utility/Vector.h" |
3709 |
08 Nov 17 |
peter |
32 |
|
3709 |
08 Nov 17 |
peter |
33 |
#include <algorithm> |
3709 |
08 Nov 17 |
peter |
34 |
#include <cassert> |
3709 |
08 Nov 17 |
peter |
35 |
#include <iostream> |
3709 |
08 Nov 17 |
peter |
36 |
|
3709 |
08 Nov 17 |
peter |
37 |
using namespace theplu::yat; |
3709 |
08 Nov 17 |
peter |
38 |
|
3709 |
08 Nov 17 |
peter |
39 |
struct Stats |
3709 |
08 Nov 17 |
peter |
40 |
{ |
3709 |
08 Nov 17 |
peter |
41 |
Stats(size_t nof_samples, size_t nof_features) |
3709 |
08 Nov 17 |
peter |
42 |
: weight(nof_features), variance(nof_features), prediction(nof_samples) |
3709 |
08 Nov 17 |
peter |
43 |
{} |
3709 |
08 Nov 17 |
peter |
44 |
|
3709 |
08 Nov 17 |
peter |
45 |
std::vector<theplu::yat::statistics::Averager> weight; |
3709 |
08 Nov 17 |
peter |
46 |
std::vector<theplu::yat::statistics::Averager> variance; |
3709 |
08 Nov 17 |
peter |
47 |
std::vector<theplu::yat::statistics::Averager> prediction; |
3709 |
08 Nov 17 |
peter |
48 |
}; |
3709 |
08 Nov 17 |
peter |
49 |
|
3709 |
08 Nov 17 |
peter |
50 |
|
3709 |
08 Nov 17 |
peter |
51 |
// Draw a random two-class target for the rows of X.
//
// For each row x of X one uniform random number is drawn and compared
// with the logistic probability 1 / (1 + exp(-w'x)): the row is labelled
// "pos" when the draw is smaller, "neg" otherwise. The result replaces
// the content of y, and every entry of y.labels() equal to "pos" is set
// as binary true while all other entries are set as binary false.
//
// w must have one element per column of X (checked with assert).
void generate_y(classifier::Target& y, const utility::Vector& w,
                const utility::Matrix& X)
{
  assert(X.columns() == w.size());

  std::vector<std::string> sample_labels;
  sample_labels.reserve(X.rows());
  random::ContinuousUniform uniform;
  for (size_t row=0; row<X.rows(); ++row) {
    const double activation = w * X.row_const_view(row);
    const double prob_pos = 1.0 / (1.0 + std::exp(-activation));
    // exactly one random draw per row
    sample_labels.push_back(uniform() < prob_pos ? "pos" : "neg");
  }
  assert(sample_labels.size() == X.rows());

  y = classifier::Target(sample_labels);
  assert(y.size() == sample_labels.size());

  // "pos" is the binary-true label, anything else is binary-false
  for (size_t k=0; k<y.labels().size(); ++k)
    y.set_binary(k, y.labels()[k] == "pos");

  assert(y.size() == X.rows());
}
3709 |
08 Nov 17 |
peter |
77 |
|
3709 |
08 Nov 17 |
peter |
78 |
|
3709 |
08 Nov 17 |
peter |
79 |
// Run one simulation round: generate a random target from (w, X), train
// a Perceptron on it, sanity check the fit and accumulate the results.
//
// For every fitted weight the weight itself and the corresponding
// diagonal covariance element are added to stats. The odds ratio and its
// confidence interval (CI) are checked for consistency with each other
// and with the p-value; every violation is reported through suite.
// Finally the prediction for every row of X is added to stats.
void analyse(const utility::Vector& w, const utility::Matrix& X, Stats& stats,
             test::Suite& suite)
{
  classifier::Target y;
  generate_y(y, w, X);
  assert(y.size() == X.rows());

  classifier::Perceptron perceptron;
  perceptron.train(X, y);

  for (size_t k=0; k<perceptron.weight().size(); ++k) {
    stats.weight[k].add(perceptron.weight()(k));
    stats.variance[k].add(perceptron.covariance()(k,k));

    const double p = perceptron.p_value(k);
    const double x = perceptron.oddsratio(k);
    const double lower = perceptron.oddsratio_lower_CI(k);
    const double upper = perceptron.oddsratio_upper_CI(k);

    // the odds ratio must lie strictly inside its own CI
    const bool inside_ci = lower<x && x<upper;
    if (!inside_ci) {
      suite.add(false);
      suite.err() << "error: incorrect CI: " << lower << " "
                  << x << " " << upper << "\n";
    }

    // p-value and CI should agree on significance: a significant
    // p-value must not come with a CI containing 1.0, a nonsignificant
    // one must not come with a CI excluding 1.0
    if (p < 0.05 && (lower < 1.0 && upper > 1.0)) {
      suite.add(false);
      suite.err() << "significant p " << p << " expected CI not to "
                  << "overlap with 1.0: " << lower << " " << upper << "\n";
    }
    else if (p > 0.05 && (lower > 1.0 || upper < 1.0)) {
      suite.add(false);
      suite.err() << "nonsignificant p " << p << " expected CI to "
                  << "overlap with 1.0: " << lower << " " << upper << "\n";
    }
  }

  assert(y.size() == X.rows());
  for (size_t row=0; row<y.size(); ++row)
    stats.prediction[row].add(perceptron.predict(X.row_const_view(row)));
}
3709 |
08 Nov 17 |
peter |
124 |
|
3709 |
08 Nov 17 |
peter |
125 |
|
3709 |
08 Nov 17 |
peter |
126 |
int main(int argc, char* argv[]) |
3709 |
08 Nov 17 |
peter |
127 |
{ |
3709 |
08 Nov 17 |
peter |
128 |
test::Suite suite(argc, argv); |
3709 |
08 Nov 17 |
peter |
129 |
|
3709 |
08 Nov 17 |
peter |
// data are modeled as y = 1 / (1 + exp(-w'x)) where w and x are vectors |
3709 |
08 Nov 17 |
peter |
131 |
|
3709 |
08 Nov 17 |
peter |
132 |
size_t n = 5000; |
3709 |
08 Nov 17 |
peter |
133 |
|
3709 |
08 Nov 17 |
peter |
134 |
utility::Vector w(5, 0); |
3709 |
08 Nov 17 |
peter |
135 |
w(0) = 3.14; |
3709 |
08 Nov 17 |
peter |
136 |
w(1) = 1.0; |
3709 |
08 Nov 17 |
peter |
137 |
w(2) = 5.0; |
3709 |
08 Nov 17 |
peter |
138 |
w(3) = -2.0; |
3709 |
08 Nov 17 |
peter |
139 |
|
3709 |
08 Nov 17 |
peter |
140 |
size_t p = w.size(); |
3709 |
08 Nov 17 |
peter |
141 |
utility::Matrix X(n, p); |
3709 |
08 Nov 17 |
peter |
142 |
std::generate(X.begin(), X.end(), random::Gaussian()); |
3709 |
08 Nov 17 |
peter |
143 |
X.column_view(0).all(1.0); |
3709 |
08 Nov 17 |
peter |
144 |
|
3709 |
08 Nov 17 |
peter |
145 |
Stats stats(n, p); |
3709 |
08 Nov 17 |
peter |
146 |
for (size_t i=0; i<100; ++i) { |
3709 |
08 Nov 17 |
peter |
147 |
analyse(w, X, stats, suite); |
3709 |
08 Nov 17 |
peter |
148 |
} |
3709 |
08 Nov 17 |
peter |
149 |
|
3709 |
08 Nov 17 |
peter |
150 |
suite.out() << "weight\ni\tweight\t<inferred>\tz\n"; |
3709 |
08 Nov 17 |
peter |
151 |
for (size_t i=0; i<p; ++i) { |
3709 |
08 Nov 17 |
peter |
152 |
double z = (stats.weight[i].mean()-w(i))/stats.weight[i].standard_error(); |
3709 |
08 Nov 17 |
peter |
153 |
suite.out() << i << "\t" << w(i) << "\t" |
3709 |
08 Nov 17 |
peter |
154 |
<< stats.weight[i].mean() << "\t" |
3709 |
08 Nov 17 |
peter |
155 |
<< z |
3709 |
08 Nov 17 |
peter |
156 |
<< "\n"; |
3709 |
08 Nov 17 |
peter |
157 |
if (std::abs(z) > 5.0) { |
3709 |
08 Nov 17 |
peter |
158 |
suite.add(false); |
3709 |
08 Nov 17 |
peter |
159 |
suite.err() << "error: z: " << z << " for param " << i << "\n"; |
3709 |
08 Nov 17 |
peter |
160 |
} |
3709 |
08 Nov 17 |
peter |
161 |
} |
3709 |
08 Nov 17 |
peter |
162 |
|
3709 |
08 Nov 17 |
peter |
163 |
suite.out() << "covariance\n"; |
3709 |
08 Nov 17 |
peter |
164 |
suite.out() << "i\t<inferred>\tobserved\tdelta\trelative error\n"; |
3709 |
08 Nov 17 |
peter |
165 |
for (size_t i=0; i<p; ++i) { |
3709 |
08 Nov 17 |
peter |
166 |
double delta = stats.variance[i].mean() - stats.weight[i].variance(); |
3709 |
08 Nov 17 |
peter |
167 |
double relative_error = delta / stats.weight[i].variance(); |
3709 |
08 Nov 17 |
peter |
168 |
suite.out() << i << "\t" |
3709 |
08 Nov 17 |
peter |
169 |
<< stats.variance[i].mean() << "\t" |
3709 |
08 Nov 17 |
peter |
170 |
<< stats.weight[i].variance() << "\t" |
3709 |
08 Nov 17 |
peter |
171 |
<< delta << "\t" |
3709 |
08 Nov 17 |
peter |
172 |
<< relative_error << "\t" |
3709 |
08 Nov 17 |
peter |
173 |
<< "\n"; |
3709 |
08 Nov 17 |
peter |
174 |
if (relative_error > 0.4) { |
3709 |
08 Nov 17 |
peter |
175 |
suite.add(false); |
3709 |
08 Nov 17 |
peter |
176 |
suite.err() << "error: covariance param: " << i << "\n"; |
3709 |
08 Nov 17 |
peter |
177 |
} |
3709 |
08 Nov 17 |
peter |
178 |
} |
3709 |
08 Nov 17 |
peter |
179 |
|
3709 |
08 Nov 17 |
peter |
180 |
suite.out() << "predictions\n"; |
3709 |
08 Nov 17 |
peter |
181 |
double sum = 0; |
3709 |
08 Nov 17 |
peter |
182 |
double slack = 0.01; |
3709 |
08 Nov 17 |
peter |
183 |
for (size_t i=0; i<n; ++i) { |
3709 |
08 Nov 17 |
peter |
184 |
double theoretical = 1.0 / (1 + std::exp(- w*X.row_const_view(i))); |
3709 |
08 Nov 17 |
peter |
185 |
double delta = stats.prediction[i].mean() - theoretical; |
3709 |
08 Nov 17 |
peter |
186 |
sum += delta*delta; |
3709 |
08 Nov 17 |
peter |
187 |
if (std::abs(delta) > slack) |
3709 |
08 Nov 17 |
peter |
188 |
suite.out() << i << "\t" |
3709 |
08 Nov 17 |
peter |
189 |
<< theoretical << "\t" |
3709 |
08 Nov 17 |
peter |
190 |
<< stats.prediction[i].mean() << "\t" |
3709 |
08 Nov 17 |
peter |
191 |
<< delta << "\t" |
3709 |
08 Nov 17 |
peter |
192 |
<< "\n"; |
3709 |
08 Nov 17 |
peter |
193 |
} |
3709 |
08 Nov 17 |
peter |
194 |
suite.out() << "sum squared error: " << sum << "\n"; |
3709 |
08 Nov 17 |
peter |
195 |
if (sum > n * slack*slack) { |
3709 |
08 Nov 17 |
peter |
196 |
suite.err() << "error: too large sum\n"; |
3709 |
08 Nov 17 |
peter |
197 |
suite.add(false); |
3709 |
08 Nov 17 |
peter |
198 |
} |
3709 |
08 Nov 17 |
peter |
199 |
|
3709 |
08 Nov 17 |
peter |
200 |
return suite.return_value(); |
3709 |
08 Nov 17 |
peter |
201 |
} |