test/matrix_expression.cc

Code
Comments
Other
Rev Date Author Line
3603 23 Jan 17 peter 1 // $Id$
3603 23 Jan 17 peter 2
3603 23 Jan 17 peter 3 /*
4207 26 Aug 22 peter 4   Copyright (C) 2017, 2020, 2022 Peter Johansson
3603 23 Jan 17 peter 5
3603 23 Jan 17 peter 6   This file is part of the yat library, http://dev.thep.lu.se/yat
3603 23 Jan 17 peter 7
3603 23 Jan 17 peter 8   The yat library is free software; you can redistribute it and/or
3603 23 Jan 17 peter 9   modify it under the terms of the GNU General Public License as
3603 23 Jan 17 peter 10   published by the Free Software Foundation; either version 3 of the
3603 23 Jan 17 peter 11   License, or (at your option) any later version.
3603 23 Jan 17 peter 12
3603 23 Jan 17 peter 13   The yat library is distributed in the hope that it will be useful,
3603 23 Jan 17 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
3603 23 Jan 17 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
3603 23 Jan 17 peter 16   General Public License for more details.
3603 23 Jan 17 peter 17
3603 23 Jan 17 peter 18   You should have received a copy of the GNU General Public License
3603 23 Jan 17 peter 19   along with yat. If not, see <http://www.gnu.org/licenses/>.
3603 23 Jan 17 peter 20 */
3603 23 Jan 17 peter 21
3603 23 Jan 17 peter 22 #include <config.h>
3603 23 Jan 17 peter 23
3603 23 Jan 17 peter 24 #include "Suite.h"
3603 23 Jan 17 peter 25
3603 23 Jan 17 peter 26 #include "yat/utility/Matrix.h"
3603 23 Jan 17 peter 27
3603 23 Jan 17 peter 28 using namespace theplu::yat;
3603 23 Jan 17 peter 29 using utility::Matrix;
3603 23 Jan 17 peter 30
3603 23 Jan 17 peter 31 bool check(const Matrix& lhs, const Matrix& rhs, test::Suite& suite,
3603 23 Jan 17 peter 32            unsigned int N=1)
3603 23 Jan 17 peter 33 {
3603 23 Jan 17 peter 34   if (lhs.rows() != rhs.rows() || lhs.columns() != rhs.columns()) {
3603 23 Jan 17 peter 35     suite.err() << "error: wrong dimensions: comparing "
3603 23 Jan 17 peter 36                 << lhs.rows() << " x " << lhs.columns() << " with "
3603 23 Jan 17 peter 37                 << rhs.rows() << " x " << rhs.columns() << "\n";
3603 23 Jan 17 peter 38     suite.add(false);
3603 23 Jan 17 peter 39     return false;
3603 23 Jan 17 peter 40   }
3603 23 Jan 17 peter 41   if (!suite.equal_range(lhs.begin(), lhs.end(), rhs.begin(), N)) {
3603 23 Jan 17 peter 42     suite.add(false);
3603 23 Jan 17 peter 43     return false;
3603 23 Jan 17 peter 44   }
3603 23 Jan 17 peter 45   return true;
3603 23 Jan 17 peter 46 }
3603 23 Jan 17 peter 47
3603 23 Jan 17 peter 48
4139 29 Jan 22 peter 49 void func_matrix_const(const Matrix&)
4139 29 Jan 22 peter 50 {
4139 29 Jan 22 peter 51 }
4139 29 Jan 22 peter 52
4139 29 Jan 22 peter 53
4139 29 Jan 22 peter 54 void func_matrix_rvalue(Matrix&&)
4139 29 Jan 22 peter 55 {
4139 29 Jan 22 peter 56 }
4139 29 Jan 22 peter 57
4139 29 Jan 22 peter 58
4139 29 Jan 22 peter 59 void func_matrix_base(const utility::MatrixBase&)
4139 29 Jan 22 peter 60 {
4139 29 Jan 22 peter 61 }
4139 29 Jan 22 peter 62
4139 29 Jan 22 peter 63
4139 29 Jan 22 peter 64 void func_matrix_mutable_rvalue(utility::MatrixMutable&&)
4139 29 Jan 22 peter 65 {
4139 29 Jan 22 peter 66 }
4139 29 Jan 22 peter 67
4139 29 Jan 22 peter 68
3603 23 Jan 17 peter 69 int main(int argc, char* argv[])
3603 23 Jan 17 peter 70 {
3603 23 Jan 17 peter 71   test::Suite suite(argc, argv);
3603 23 Jan 17 peter 72   Matrix A(4, 3, 1);
3603 23 Jan 17 peter 73   Matrix B(4, 3, 2);
3603 23 Jan 17 peter 74   Matrix C(4, 3, 4);
4139 29 Jan 22 peter 75
4139 29 Jan 22 peter 76   // test that we can pass a matrix expression to a function
4139 29 Jan 22 peter 77   func_matrix_const(A+B);
4139 29 Jan 22 peter 78   func_matrix_rvalue(A+B);
4139 29 Jan 22 peter 79 #ifdef YAT_TICKET897
4139 29 Jan 22 peter 80   func_matrix_base(A+B);
4139 29 Jan 22 peter 81   func_matrix_mutable_rvalue(A+B);
4139 29 Jan 22 peter 82 #endif
4139 29 Jan 22 peter 83
3603 23 Jan 17 peter 84   // addition operator+
3603 23 Jan 17 peter 85   {
3603 23 Jan 17 peter 86     suite.out() << "testing operator+\n";
3603 23 Jan 17 peter 87     Matrix oldres(A);
3604 23 Jan 17 peter 88     oldres += B;
3604 23 Jan 17 peter 89     oldres += C;
3603 23 Jan 17 peter 90     Matrix sum = A + B + C;
3603 23 Jan 17 peter 91     check(sum, oldres, suite);
3603 23 Jan 17 peter 92   }
3603 23 Jan 17 peter 93
3605 27 Jan 17 peter 94
3604 23 Jan 17 peter 95   // subtraction operator-
3603 23 Jan 17 peter 96   {
3603 23 Jan 17 peter 97     suite.out() << "testing operator-\n";
3603 23 Jan 17 peter 98     Matrix oldres(A);
3604 23 Jan 17 peter 99     oldres -= B;
3604 23 Jan 17 peter 100     oldres -= C;
3603 23 Jan 17 peter 101     Matrix sum = A - B - C;
3603 23 Jan 17 peter 102     check(sum, oldres, suite);
3603 23 Jan 17 peter 103   }
3603 23 Jan 17 peter 104
3603 23 Jan 17 peter 105   // combining addition and subtraction
3603 23 Jan 17 peter 106   {
3603 23 Jan 17 peter 107     suite.out() << "testing operator+ and operator-\n";
3603 23 Jan 17 peter 108     Matrix oldres(A);
3604 23 Jan 17 peter 109     oldres += B;
3604 23 Jan 17 peter 110     oldres -= C;
3603 23 Jan 17 peter 111     Matrix sum = A + B - C;
3603 23 Jan 17 peter 112     check(sum, oldres, suite);
3603 23 Jan 17 peter 113   }
3603 23 Jan 17 peter 114
3603 23 Jan 17 peter 115   Matrix At(A);
3603 23 Jan 17 peter 116   At.transpose();
3603 23 Jan 17 peter 117   Matrix Bt(B);
3603 23 Jan 17 peter 118   Bt.transpose();
3603 23 Jan 17 peter 119   Matrix Ct(C);
3603 23 Jan 17 peter 120   Ct.transpose();
3603 23 Jan 17 peter 121
3603 23 Jan 17 peter 122   // operator*
3603 23 Jan 17 peter 123   {
3603 23 Jan 17 peter 124     suite.out() << "testing operator*\n";
3603 23 Jan 17 peter 125     Matrix oldres(A);
3603 23 Jan 17 peter 126     oldres *= At;
3603 23 Jan 17 peter 127     oldres *= B;
3603 23 Jan 17 peter 128     oldres *= Bt;
3603 23 Jan 17 peter 129     Matrix res = A * At * B * Bt;
3603 23 Jan 17 peter 130
3603 23 Jan 17 peter 131     check(res, oldres, suite);
3603 23 Jan 17 peter 132   }
3603 23 Jan 17 peter 133
3603 23 Jan 17 peter 134   // operator*
3603 23 Jan 17 peter 135   {
3603 23 Jan 17 peter 136     suite.out() << "testing operator*(double)\n";
3603 23 Jan 17 peter 137     double a = 2;
3603 23 Jan 17 peter 138     double b = 3;
3603 23 Jan 17 peter 139     Matrix oldres(A);
3603 23 Jan 17 peter 140     oldres *= a;
3603 23 Jan 17 peter 141     oldres *= b;
3603 23 Jan 17 peter 142     Matrix res = a * A * b;
3603 23 Jan 17 peter 143
3603 23 Jan 17 peter 144     check(res, oldres, suite);
3603 23 Jan 17 peter 145   }
3603 23 Jan 17 peter 146
3603 23 Jan 17 peter 147   // combining everything
3603 23 Jan 17 peter 148   {
3603 23 Jan 17 peter 149     Matrix oldres(A);
3603 23 Jan 17 peter 150     oldres += B;
3603 23 Jan 17 peter 151     Matrix factor2(At);
3603 23 Jan 17 peter 152     factor2 += Ct;
3603 23 Jan 17 peter 153     oldres *= factor2;
3603 23 Jan 17 peter 154     Matrix term2(A);
3603 23 Jan 17 peter 155     term2 *= At;
3603 23 Jan 17 peter 156     term2 *= 2;
3603 23 Jan 17 peter 157     oldres += term2;
3603 23 Jan 17 peter 158     Matrix res = (A + B) * (At + Ct) + 2*A*At;
3603 23 Jan 17 peter 159     check(res, oldres, suite);
3603 23 Jan 17 peter 160   }
3603 23 Jan 17 peter 161
3609 27 Jan 17 peter 162   // operator+=
3609 27 Jan 17 peter 163   {
3609 27 Jan 17 peter 164     suite.out() << "test: Matrix += MatrixExpression\n";
3609 27 Jan 17 peter 165     Matrix res1 = A + B*Bt*C;
3609 27 Jan 17 peter 166     Matrix res2(A);
3609 27 Jan 17 peter 167     res2 += B*Bt*C;
3609 27 Jan 17 peter 168     check(res1, res2, suite);
3609 27 Jan 17 peter 169   }
3609 27 Jan 17 peter 170
3609 27 Jan 17 peter 171
3609 27 Jan 17 peter 172   // operator-=
3609 27 Jan 17 peter 173   {
3609 27 Jan 17 peter 174     suite.out() << "test: Matrix += MatrixExpression\n";
3609 27 Jan 17 peter 175     Matrix res1 = A - B*Bt*C;
3609 27 Jan 17 peter 176     Matrix res2(A);
3609 27 Jan 17 peter 177     res2 -= B*Bt*C;
3609 27 Jan 17 peter 178     check(res1, res2, suite);
3609 27 Jan 17 peter 179   }
3609 27 Jan 17 peter 180
3610 27 Jan 17 peter 181   // negation
3610 27 Jan 17 peter 182   {
3610 27 Jan 17 peter 183     suite.out() << "test: negation\n";
3610 27 Jan 17 peter 184     Matrix res = -A;
3610 27 Jan 17 peter 185     res = -res;
3610 27 Jan 17 peter 186     check(res, A, suite);
3610 27 Jan 17 peter 187   }
3610 27 Jan 17 peter 188
3605 27 Jan 17 peter 189   // testing assignment operator
3605 27 Jan 17 peter 190   A = B + C;
3603 23 Jan 17 peter 191
3654 10 Jul 17 peter 192   // test transpose(A) (ticket 880)
3654 10 Jul 17 peter 193   {
3654 10 Jul 17 peter 194     suite.out() << "test: transpose\n";
3654 10 Jul 17 peter 195     // calculate A * Bt * C
3654 10 Jul 17 peter 196     // old interface
3654 10 Jul 17 peter 197     Matrix old(A);
3654 10 Jul 17 peter 198     old *= Bt;
3654 10 Jul 17 peter 199     old *= C;
3654 10 Jul 17 peter 200     // new interface
3654 10 Jul 17 peter 201     Matrix m = A * transpose(B) * C;
3654 10 Jul 17 peter 202     check(m, old, suite);
3654 10 Jul 17 peter 203     m.transpose();
3654 10 Jul 17 peter 204     Matrix m2 = transpose(A * transpose(B) * C);
3654 10 Jul 17 peter 205     check(m2, m, suite);
3904 07 May 20 peter 206     Matrix m3 = transpose(transpose(m));
3904 07 May 20 peter 207     check(m3, m, suite);
3904 07 May 20 peter 208     Matrix m4 = transpose(2*transpose(m));
3904 07 May 20 peter 209     Matrix m5 = transpose(transpose(2*m));
3904 07 May 20 peter 210     check(m4, m5, suite);
3654 10 Jul 17 peter 211   }
3654 10 Jul 17 peter 212
3654 10 Jul 17 peter 213
3904 07 May 20 peter 214   // test dgemm expression
3904 07 May 20 peter 215   {
3904 07 May 20 peter 216     suite.out() << "testing dgemm\n";
3904 07 May 20 peter 217
3904 07 May 20 peter 218     Matrix oldres(Bt);
3904 07 May 20 peter 219     oldres *= 0.5;
3904 07 May 20 peter 220     oldres *= C;
3904 07 May 20 peter 221     Matrix res = 0.5 * transpose(B) * C;
3904 07 May 20 peter 222     check(res, oldres, suite);
3904 07 May 20 peter 223
3904 07 May 20 peter 224     res = 0.5 * Bt * C;
3904 07 May 20 peter 225     check(res, oldres, suite);
3904 07 May 20 peter 226
3904 07 May 20 peter 227     res = 0.5 * transpose(B) * 1.0 * transpose(Ct);
3904 07 May 20 peter 228     check(res, oldres, suite);
3904 07 May 20 peter 229
3904 07 May 20 peter 230     res = 1.0 * Bt * 0.5 * transpose(Ct);
3904 07 May 20 peter 231     check(res, oldres, suite);
3904 07 May 20 peter 232
3904 07 May 20 peter 233     res = transpose(B) * C;
3904 07 May 20 peter 234   }
3603 23 Jan 17 peter 235   return suite.return_value();
3603 23 Jan 17 peter 236 }