640 |
07 Sep 06 |
peter |
// $Id$ |
640 |
07 Sep 06 |
peter |
2 |
|
675 |
10 Oct 06 |
jari |
3 |
/* |
2119 |
12 Dec 09 |
peter |
Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér |
4359 |
23 Aug 23 |
peter |
Copyright (C) 2007 Peter Johansson |
4359 |
23 Aug 23 |
peter |
Copyright (C) 2008 Jari Häkkinen, Peter Johansson |
4359 |
23 Aug 23 |
peter |
Copyright (C) 2012 Peter Johansson |
640 |
07 Sep 06 |
peter |
8 |
|
1437 |
25 Aug 08 |
peter |
This file is part of the yat library, http://dev.thep.lu.se/yat |
675 |
10 Oct 06 |
jari |
10 |
|
675 |
10 Oct 06 |
jari |
The yat library is free software; you can redistribute it and/or |
675 |
10 Oct 06 |
jari |
modify it under the terms of the GNU General Public License as |
1486 |
09 Sep 08 |
jari |
published by the Free Software Foundation; either version 3 of the |
675 |
10 Oct 06 |
jari |
License, or (at your option) any later version. |
675 |
10 Oct 06 |
jari |
15 |
|
675 |
10 Oct 06 |
jari |
The yat library is distributed in the hope that it will be useful, |
675 |
10 Oct 06 |
jari |
but WITHOUT ANY WARRANTY; without even the implied warranty of |
675 |
10 Oct 06 |
jari |
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU |
675 |
10 Oct 06 |
jari |
General Public License for more details. |
675 |
10 Oct 06 |
jari |
20 |
|
675 |
10 Oct 06 |
jari |
You should have received a copy of the GNU General Public License |
1487 |
10 Sep 08 |
jari |
along with yat. If not, see <http://www.gnu.org/licenses/>. |
675 |
10 Oct 06 |
jari |
23 |
*/ |
675 |
10 Oct 06 |
jari |
24 |
|
2881 |
18 Nov 12 |
peter |
25 |
#include <config.h> |
2881 |
18 Nov 12 |
peter |
26 |
|
1248 |
19 Mar 08 |
peter |
27 |
#include "Suite.h" |
1248 |
19 Mar 08 |
peter |
28 |
|
824 |
19 Mar 07 |
peter |
29 |
#include "yat/classifier/BootstrapSampler.h" |
675 |
10 Oct 06 |
jari |
30 |
#include "yat/classifier/CrossValidationSampler.h" |
675 |
10 Oct 06 |
jari |
31 |
#include "yat/classifier/FeatureSelectorIR.h" |
675 |
10 Oct 06 |
jari |
32 |
#include "yat/classifier/Kernel_SEV.h" |
675 |
10 Oct 06 |
jari |
33 |
#include "yat/classifier/KernelLookup.h" |
675 |
10 Oct 06 |
jari |
34 |
#include "yat/classifier/MatrixLookup.h" |
675 |
10 Oct 06 |
jari |
35 |
#include "yat/classifier/PolynomialKernelFunction.h" |
675 |
10 Oct 06 |
jari |
36 |
#include "yat/classifier/SubsetGenerator.h" |
820 |
17 Mar 07 |
peter |
37 |
#include "yat/statistics/AUC.h" |
1121 |
22 Feb 08 |
peter |
38 |
#include "yat/utility/Matrix.h" |
675 |
10 Oct 06 |
jari |
39 |
|
781 |
05 Mar 07 |
peter |
40 |
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
640 |
07 Sep 06 |
peter |
44 |
|
680 |
11 Oct 06 |
jari |
45 |
using namespace theplu::yat; |
640 |
07 Sep 06 |
peter |
46 |
|
1248 |
19 Mar 08 |
peter |
47 |
bool class_count_test(const std::vector<size_t>&, test::Suite&); |
1248 |
19 Mar 08 |
peter |
48 |
bool sample_count_test(const std::vector<size_t>&, test::Suite&); |
1248 |
19 Mar 08 |
peter |
49 |
bool test_nested(test::Suite&); |
1248 |
19 Mar 08 |
peter |
50 |
bool test_cv(test::Suite&); |
1248 |
19 Mar 08 |
peter |
51 |
bool test_creation(test::Suite&); |
1248 |
19 Mar 08 |
peter |
52 |
bool test_bootstrap(test::Suite&); |
824 |
19 Mar 07 |
peter |
53 |
|
824 |
19 Mar 07 |
peter |
54 |
|
1248 |
19 Mar 08 |
peter |
55 |
int main(int argc, char* argv[]) |
4200 |
19 Aug 22 |
peter |
56 |
{ |
1248 |
19 Mar 08 |
peter |
57 |
test::Suite suite(argc, argv); |
1248 |
19 Mar 08 |
peter |
58 |
suite.err() << "testing subset_generator" << std::endl; |
640 |
07 Sep 06 |
peter |
59 |
|
1248 |
19 Mar 08 |
peter |
60 |
test_creation(suite); |
1248 |
19 Mar 08 |
peter |
61 |
test_nested(suite); |
1248 |
19 Mar 08 |
peter |
62 |
test_cv(suite); |
640 |
07 Sep 06 |
peter |
63 |
|
1248 |
19 Mar 08 |
peter |
64 |
return suite.return_value(); |
824 |
19 Mar 07 |
peter |
65 |
} |
824 |
19 Mar 07 |
peter |
66 |
|
824 |
19 Mar 07 |
peter |
67 |
|
1248 |
19 Mar 08 |
peter |
68 |
/*
  Construction test: reads a target and a data matrix from the test data
  files and builds a kernel, a CrossValidationSampler, a feature
  selector, and SubsetGenerators over both MatrixLookup and
  KernelLookup from them.

  Returns false (after writing a diagnostic to suite.err()) if an input
  file cannot be opened or if the number of data columns differs from
  the number of targets; returns true once all objects have been
  constructed.
*/
bool test_creation(test::Suite& suite)
{
  // A separate stream per file: reusing one ifstream across
  // close()/open() may keep stale eof/fail bits on pre-C++11 libraries.
  std::ifstream target_is(test::filename("data/nm_target_bin.txt").c_str());
  suite.err() << "loading target " << std::endl;
  if (!target_is.is_open()) {
    suite.err() << "ERROR: cannot open data/nm_target_bin.txt" << std::endl;
    return false;
  }
  classifier::Target target(target_is);
  target_is.close();
  suite.err() << "number of targets: " << target.size() << std::endl;
  suite.err() << "number of classes: " << target.nof_classes() << std::endl;

  std::ifstream data_is(test::filename("data/nm_data_centralized.txt").c_str());
  suite.err() << "loading data " << std::endl;
  if (!data_is.is_open()) {
    suite.err() << "ERROR: cannot open data/nm_data_centralized.txt"
                << std::endl;
    return false;
  }
  utility::Matrix m(data_is);
  data_is.close();
  classifier::MatrixLookup data(m);
  suite.err() << "number of samples: " << data.columns() << std::endl;
  suite.err() << "number of features: " << data.rows() << std::endl;

  // Runtime check instead of assert: it must also hold in NDEBUG builds,
  // and everything below pairs data columns with targets.
  if (data.columns()!=target.size()) {
    suite.err() << "ERROR: number of samples in data (" << data.columns()
                << ") differs from number of targets (" << target.size()
                << ")" << std::endl;
    return false;
  }

  suite.err() << "building kernel" << std::endl;
  classifier::PolynomialKernelFunction kf(1);
  classifier::Kernel_SEV kernel_core(data,kf);
  classifier::KernelLookup kernel(kernel_core);

  suite.err() << "building Sampler" << std::endl;
  classifier::CrossValidationSampler sampler(target, 30, 3);

  statistics::AUC score;
  classifier::FeatureSelectorIR fs(score, 96, 0);

  suite.err() << "building SubsetGenerator" << std::endl;
  classifier::SubsetGenerator<classifier::MatrixLookup>
    subset_data(sampler, data, fs);
  classifier::SubsetGenerator<classifier::KernelLookup>
    subset_kernel(sampler, kernel,fs);
  return true;
}
640 |
07 Sep 06 |
peter |
102 |
|
1248 |
19 Mar 08 |
peter |
103 |
/*
  Test two nested SubsetGenerators.

  Nine samples in three classes ("0", "1", "2", three samples each) are
  split by an outer CrossValidationSampler(target,3,3) into
  training+validation (tv) and test parts. Each tv part is split again
  by an inner CrossValidationSampler(tv_target,2,2) into training (t)
  and validation (v) parts.

  The data matrix has element (i,j) = 10*(i+1) + (j+1), i.e. row 0 holds
  11..19 and row 1 holds 21..29, which makes the expected sums below
  easy to derive (e.g. 135 = 11+...+19, 36 = 11+12+13).

  Returns false, after writing diagnostics to suite.err(), if any
  check failed. All checks are always run, so one failure does not
  hide the diagnostics of later ones.
*/
bool test_nested(test::Suite& suite)
{
  bool ok=true;

  suite.err() << "\ntesting two nested crossplitters" << std::endl;
  std::vector<std::string> label(9);
  label[0]=label[1]=label[2]="0";
  label[3]=label[4]=label[5]="1";
  label[6]=label[7]=label[8]="2";

  classifier::Target target(label);
  utility::Matrix raw_data2(2,9);
  for(size_t i=0;i<raw_data2.rows();i++)
    for(size_t j=0;j<raw_data2.columns();j++)
      raw_data2(i,j)=i*10+10+j+1;

  classifier::MatrixLookup data2(raw_data2);
  classifier::CrossValidationSampler cv2(target,3,3);
  classifier::SubsetGenerator<classifier::MatrixLookup> cv_test(cv2,data2);

  // Accumulators over all outer partitions. In the *_value vectors,
  // element 0 is the total sum and element c+1 is the sum for class c.
  std::vector<size_t> test_sample_count(9,0);
  std::vector<size_t> test_class_count(3,0);
  std::vector<double> test_value1(4,0);
  std::vector<double> test_value2(4,0);
  std::vector<double> t_value(4,0);
  std::vector<double> v_value(4,0);
  for(unsigned long k=0;k<cv_test.size();k++) {

    const classifier::MatrixLookup& tv_view=cv_test.training_data(k);
    const classifier::Target& tv_target=cv_test.training_target(k);
    const utility::Index& tv_index=cv_test.training_index(k);
    const classifier::MatrixLookup& test_view=cv_test.validation_data(k);
    const classifier::Target& test_target=cv_test.validation_target(k);
    const utility::Index& test_index=cv_test.validation_index(k);

    for (size_t i=0; i<test_index.size(); i++) {
      // Bound against the vector that is actually indexed on the next line.
      assert(test_index[i]<test_sample_count.size());
      test_sample_count[test_index[i]]++;
      test_class_count[target(test_index[i])]++;
      test_value1[0]+=test_view(0,i);
      test_value2[0]+=test_view(1,i);
      test_value1[test_target(i)+1]+=test_view(0,i);
      test_value2[test_target(i)+1]+=test_view(1,i);
      // The target of the i-th test sample must agree with the original
      // target at the index it was drawn from.
      if(test_target(i)!=target(test_index[i])) {
        ok=false;
        suite.err() << "ERROR: incorrect mapping of test indices" << std:: endl;
      }
    }

    classifier::CrossValidationSampler sampler_training(tv_target,2,2);
    classifier::SubsetGenerator<classifier::MatrixLookup>
      cv_training(sampler_training,tv_view);
    // Accumulators over all inner partitions of this outer partition.
    std::vector<size_t> v_sample_count(6,0);
    std::vector<size_t> t_sample_count(6,0);
    std::vector<size_t> v_class_count(3,0);
    std::vector<size_t> t_class_count(3,0);
    std::vector<size_t> t_class_count2(3,0);
    for(unsigned long l=0;l<cv_training.size();l++) {
      const classifier::MatrixLookup& t_view=cv_training.training_data(l);
      const classifier::Target& t_target=cv_training.training_target(l);
      const utility::Index& t_index=cv_training.training_index(l);
      const classifier::MatrixLookup& v_view=cv_training.validation_data(l);
      const classifier::Target& v_target=cv_training.validation_target(l);
      const utility::Index& v_index=cv_training.validation_index(l);

      if (test_index.size()+tv_index.size()!=target.size()
          || t_index.size()+v_index.size() != tv_target.size()
          || test_index.size()+v_index.size()+t_index.size() != target.size()){
        ok = false;
        suite.err() << "ERROR: size of training samples, validation samples "
                    << "and test samples is invalid."
                    << std::endl;
      }
      if (test_index.size()!=3 || tv_index.size()!=6 || t_index.size()!=3 ||
          v_index.size()!=3){
        ok = false;
        suite.err() << "ERROR: size of training, validation, and test samples"
                    << " is invalid."
                    << " Expected 3 test, 6 training+validation, 3 training"
                    << " and 3 validation samples" << std::endl;
      }

      // Within one inner partition, training and validation indices
      // together must hit each of the six tv samples exactly once.
      std::vector<size_t> tv_sample_count(6,0);
      for (size_t i=0; i<t_index.size(); i++) {
        assert(t_index[i]<t_sample_count.size());
        tv_sample_count[t_index[i]]++;
        t_sample_count[t_index[i]]++;
        t_class_count[t_target(i)]++;
        t_class_count2[tv_target(t_index[i])]++;
        t_value[0]+=t_view(0,i);
        t_value[t_target(i)+1]+=t_view(0,i);
      }
      for (size_t i=0; i<v_index.size(); i++) {
        assert(v_index[i]<v_sample_count.size());
        tv_sample_count[v_index[i]]++;
        v_sample_count[v_index[i]]++;
        v_class_count[v_target(i)]++;
        v_value[0]+=v_view(0,i);
        v_value[v_target(i)+1]+=v_view(0,i);
      }

      // Call the check first so it runs (and reports) even when an
      // earlier check already failed.
      ok = sample_count_test(tv_sample_count,suite) && ok;

    }
    ok = sample_count_test(v_sample_count,suite) && ok;
    ok = sample_count_test(t_sample_count,suite) && ok;

    ok = class_count_test(t_class_count,suite) && ok;
    ok = class_count_test(t_class_count2,suite) && ok;
    ok = class_count_test(v_class_count,suite) && ok;

  }
  ok = sample_count_test(test_sample_count,suite) && ok;
  ok = class_count_test(test_class_count,suite) && ok;

  // The sums are small integers stored in doubles, so exact comparison
  // is fine here.
  if(test_value1[0]!=135 || test_value1[1]!=36 || test_value1[2]!=45 ||
     test_value1[3]!=54) {
    ok=false;
    suite.err() << "ERROR: incorrect sums of test values in row 1"
                << " found: " << test_value1[0] << ", " << test_value1[1]
                << ", " << test_value1[2] << " and " << test_value1[3]
                << std::endl;
  }

  if(test_value2[0]!=225 || test_value2[1]!=66 || test_value2[2]!=75 ||
     test_value2[3]!=84) {
    ok=false;
    suite.err() << "ERROR: incorrect sums of test values in row 2"
                << " found: " << test_value2[0] << ", " << test_value2[1]
                << ", " << test_value2[2] << " and " << test_value2[3]
                << std::endl;
  }

  // Expected training/validation sums are twice the row-0 sums
  // (270 = 2*135, ...). NOTE(review): presumably because every sample
  // is in the outer training part of two of the three outer partitions
  // and then once in inner training and once in inner validation -
  // confirm against the CrossValidationSampler documentation.
  if(t_value[0]!=270 || t_value[1]!=72 || t_value[2]!=90 || t_value[3]!=108) {
    ok=false;
    suite.err() << "ERROR: incorrect sums of training values in row 1"
                << " found: " << t_value[0] << ", " << t_value[1]
                << ", " << t_value[2] << " and " << t_value[3]
                << std::endl;
  }

  if(v_value[0]!=270 || v_value[1]!=72 || v_value[2]!=90 || v_value[3]!=108) {
    ok=false;
    suite.err() << "ERROR: incorrect sums of validation values in row 1"
                << " found: " << v_value[0] << ", " << v_value[1]
                << ", " << v_value[2] << " and " << v_value[3]
                << std::endl;
  }
  return ok;
}
824 |
19 Mar 07 |
peter |
258 |
|
4200 |
19 Aug 22 |
peter |
259 |
bool class_count_test(const std::vector<size_t>& class_count, |
4200 |
19 Aug 22 |
peter |
260 |
test::Suite& suite) |
824 |
19 Mar 07 |
peter |
261 |
{ |
824 |
19 Mar 07 |
peter |
262 |
bool ok=true; |
824 |
19 Mar 07 |
peter |
263 |
for (size_t i=0; i<class_count.size(); i++) |
824 |
19 Mar 07 |
peter |
264 |
if (class_count[i]==0){ |
824 |
19 Mar 07 |
peter |
265 |
ok = false; |
4200 |
19 Aug 22 |
peter |
266 |
suite.err() << "ERROR: class " << i << " was not in set." |
4200 |
19 Aug 22 |
peter |
267 |
<< " Expected at least one sample from each class." |
824 |
19 Mar 07 |
peter |
268 |
<< std::endl; |
824 |
19 Mar 07 |
peter |
269 |
} |
824 |
19 Mar 07 |
peter |
270 |
return ok; |
824 |
19 Mar 07 |
peter |
271 |
} |
824 |
19 Mar 07 |
peter |
272 |
|
4200 |
19 Aug 22 |
peter |
273 |
bool sample_count_test(const std::vector<size_t>& sample_count, |
4200 |
19 Aug 22 |
peter |
274 |
test::Suite& suite) |
824 |
19 Mar 07 |
peter |
275 |
{ |
824 |
19 Mar 07 |
peter |
276 |
bool ok=true; |
824 |
19 Mar 07 |
peter |
277 |
for (size_t i=0; i<sample_count.size(); i++){ |
824 |
19 Mar 07 |
peter |
278 |
if (sample_count[i]!=1){ |
824 |
19 Mar 07 |
peter |
279 |
ok = false; |
4200 |
19 Aug 22 |
peter |
280 |
suite.err() << "ERROR: sample " << i << " was in a group " << sample_count[i] |
824 |
19 Mar 07 |
peter |
281 |
<< " times." << " Expected to be 1 time" << std::endl; |
824 |
19 Mar 07 |
peter |
282 |
} |
824 |
19 Mar 07 |
peter |
283 |
} |
824 |
19 Mar 07 |
peter |
284 |
return ok; |
824 |
19 Mar 07 |
peter |
285 |
} |
824 |
19 Mar 07 |
peter |
286 |
|
824 |
19 Mar 07 |
peter |
287 |
|
1248 |
19 Mar 08 |
peter |
288 |
/*
  Construction-only test for BootstrapSampler: builds a ten-sample
  target with five labels (two samples per label), a 10x10 matrix with
  a MatrixLookup on top of it, and a BootstrapSampler(target, 3).

  Nothing is verified beyond the objects being constructible; the
  function always returns true and does not use `suite`.
*/
bool test_bootstrap(test::Suite& suite)
{
  // Indices 0 and 1 keep the initial "default" label.
  std::vector<std::string> label(10,"default");
  label[7]="white";
  label[2]="white";
  label[5]="black";
  label[4]="black";
  label[3]="green";
  label[6]="green";
  label[9]="red";
  label[8]="red";

  classifier::Target target(label);
  utility::Matrix raw_data(10,10);
  classifier::MatrixLookup data(raw_data);
  classifier::BootstrapSampler cv(target,3);

  // No check can fail above, so the result is unconditionally true.
  return true;
}
824 |
19 Mar 07 |
peter |
303 |
|
824 |
19 Mar 07 |
peter |
304 |
|
1248 |
19 Mar 08 |
peter |
305 |
/*
  Test of CrossValidationSampler(target,3,3) on ten samples with five
  labels (two samples per label).

  For every partition j the following is checked:
   - training and validation index sets together have target.size()
     elements,
   - the validation set has 3 or 4 elements,
   - every validation index is a valid sample index,
   - every class occurs at least once in the training set.
  After all partitions, each sample must have been in a validation set
  exactly once.

  Returns false, after writing diagnostics to suite.err(), if any
  check failed. All checks are always run.
*/
bool test_cv(test::Suite& suite)
{
  bool ok=true;
  // Indices 0 and 1 keep the initial "default" label.
  std::vector<std::string> label(10,"default");
  label[2]=label[7]="white";
  label[4]=label[5]="black";
  label[6]=label[3]="green";
  label[8]=label[9]="red";

  classifier::Target target(label);
  classifier::CrossValidationSampler cv(target,3,3);

  std::vector<size_t> sample_count(10,0);
  for (size_t j=0; j<cv.size(); ++j){
    std::vector<size_t> class_count(5,0);
    if (cv.training_index(j).size()+cv.validation_index(j).size()!=
        target.size()){
      ok = false;
      suite.err() << "ERROR: size of training samples plus "
                  << "size of validation samples is invalid." << std::endl;
    }
    if (cv.validation_index(j).size()!=3 && cv.validation_index(j).size()!=4){
      ok = false;
      suite.err() << "ERROR: size of validation samples is invalid."
                  << " expected size to be 3 or 4" << std::endl;
    }
    for (size_t i=0; i<cv.validation_index(j).size(); i++) {
      const size_t sample = cv.validation_index(j)[i];
      // Runtime check rather than assert, so an out-of-range index is
      // reported (instead of writing out of bounds) in NDEBUG builds too.
      if (sample>=sample_count.size()) {
        ok = false;
        suite.err() << "ERROR: validation index " << sample
                    << " is out of range" << std::endl;
        continue;
      }
      sample_count[sample]++;
    }
    for (size_t i=0; i<cv.training_index(j).size(); i++) {
      class_count[target(cv.training_index(j)[i])]++;
    }
    // Call the check first so it runs (and reports) for every partition,
    // even after an earlier failure.
    ok = class_count_test(class_count,suite) && ok;
  }
  ok = sample_count_test(sample_count,suite) && ok;

  return ok;
}