我要问的是一个概括这个问题 https://stackoverflow.com/questions/13548253/eigen-library-return-a-matrix-block-in-a-function-as-lvalue。具体来说,我想围绕遗留的 C 和 Fortran 库制作一个 C++ Eigen 包装器,它使用 2D 数据结构:
[ x[0,0] ... x[0,w-1] ]
[ u[0,0] ... u[0,w-1] ]
[ ... ]
[ x[c-1,0] ... x[c-1,w-1] ]
[ u[c-1,0] ... u[c-1,w-1] ]
其中每个条目x[i,j]
and u[i,j]
它们本身就是大小为 (nx1
) and (mx1
) 分别。
这会导致一些复杂(且容易出错)的指针算术以及一些非常不可读的代码。
因此,我想编写一个 Eigen 类,其唯一目的是使提取该矩阵的条目尽可能容易。在 C++14 中,看起来像这样data_getter.h
:
#ifndef DATA_GETTER_HEADER
#define DATA_GETTER_HEADER
#include "Eigen/Dense"
template<typename T, int n, int m, int c, int w>
class DataGetter {
public:
/** Return a reference to the data as a matrix */
static auto asMatrix(T *raw_ptr) {
auto out = Eigen::Map<Eigen::Matrix<T, (n + m) * c, w>>(raw_ptr);
static_assert(decltype(out)::RowsAtCompileTime == (n + m) * c);
static_assert(decltype(out)::ColsAtCompileTime == w);
return out;
}
/** Return a reference to the submatrix
* [ x[i,0], ..., x[i,w-1]]
* [ u[i,0], ..., u[i,w-1]] */
static auto W(T *raw_ptr, int i) {
auto out = asMatrix(raw_ptr).template middleRows<n + m>((n + m) * i);
static_assert(decltype(out)::RowsAtCompileTime == (n + m));
static_assert(decltype(out)::ColsAtCompileTime == w);
return out;
}
/** Return a reference to the submatrix [ x[i,0], ..., x[i,w-1]] */
static auto X(T *raw_ptr, int i) {
auto out = W(raw_ptr, i).template topRows<n>();
static_assert(decltype(out)::RowsAtCompileTime == n);
static_assert(decltype(out)::ColsAtCompileTime == w);
return out;
}
/** Return a reference to x[i,j] */
static auto X(T *raw_ptr, int i, int j) {
auto out = X(raw_ptr, i).col(j);
static_assert(decltype(out)::RowsAtCompileTime == n);
static_assert(decltype(out)::ColsAtCompileTime == 1);
return out;
}
/** Return a reference to the submatrix [ u[i,0], ..., u[i,w-1]] */
static auto U(T *raw_ptr, int i) {
auto out = W(raw_ptr, i).template bottomRows<m>();
static_assert(decltype(out)::RowsAtCompileTime == m);
static_assert(decltype(out)::ColsAtCompileTime == w);
return out;
}
/** Return a reference to u[i,j] */
static auto U(T *raw_ptr, int i, int j) {
auto out = U(raw_ptr, i).col(j);
static_assert(decltype(out)::RowsAtCompileTime == m);
static_assert(decltype(out)::ColsAtCompileTime == 1);
return out;
}
/** Return a reference to the submatrix
* [ x[0,i], ..., x[c-1,i]]
* [ u[0,i], ..., u[c-1,i]] */
static auto C(T *raw_ptr, int i) {
auto out = Eigen::Map<Eigen::Matrix<T, n + m, c>>(
asMatrix(raw_ptr).col(i).template topRows<(n + m) * c>().data());
static_assert(decltype(out)::RowsAtCompileTime == (n + m));
static_assert(decltype(out)::ColsAtCompileTime == c);
return out;
}
/** Return a reference to the submatrix [ x[0,i], ..., x[c-1,i]] */
static auto Xc(T *raw_ptr, int i) {
auto out = C(raw_ptr, i).template topRows<n>();
static_assert(decltype(out)::RowsAtCompileTime == n);
static_assert(decltype(out)::ColsAtCompileTime == c);
return out;
}
/** Return a reference to the submatrix [ u[0,i], ..., u[c-1,i]] */
static auto Uc(T *raw_ptr, int i) {
auto out = C(raw_ptr, i).template bottomRows<m>();
static_assert(decltype(out)::RowsAtCompileTime == m);
static_assert(decltype(out)::ColsAtCompileTime == c);
return out;
}
};
#endif /* DATA_GETTER_HEADER */
这是一个测试程序,演示了它是如何工作的:
#include <iostream>
#include <vector>
#include "Eigen/Dense"
#include "data_getter.h"
using namespace std;
using namespace Eigen;
template<typename T>
void printSize(MatrixBase<T> &mat) {
cout << T::RowsAtCompileTime << " x " << T::ColsAtCompileTime;
}
int main() {
using T = double;
const int n = 2;
const int m = 3;
const int c = 2;
const int w = 5;
const int size = w * (c * (n + m));
std::vector<T> vec;
for (int i = 0; i < size; ++i)
vec.push_back(i);
/* Define the interface that we will use a lot */
using Data = DataGetter<T, n, m, c, w>;
/* Now let's map that pointer to some submatrices */
Ref<Matrix<T, (n + m) * c, w>> allData = Data::asMatrix(vec.data());
Ref<Matrix<T, n, w>> x1 = Data::X(vec.data(), 1);
Ref<Matrix<T, n, c>> xc2 = Data::Xc(vec.data(), 2);
Ref<Matrix<T, n + m, c>> xuc2 = Data::C(vec.data(), 2);
Ref<Matrix<T, n, 1>> x12 = Data::X(vec.data(), 1, 2);
cout << "Data::asMatrix( T* ): ";
printSize(allData);
cout << endl << endl << allData << endl << endl;
cout << "Data::X( T*, 1 ) : ";
printSize(x1);
cout << endl << endl << x1 << endl << endl;
cout << "Data::Xc( T*, 2 ) : ";
printSize(xc2);
cout << endl << endl << xc2 << endl << endl;
cout << "Data::C( T*, 2 ) : ";
printSize(xuc2);
cout << endl << endl << xuc2 << endl << endl;
cout << "Data::X( T*, 1, 2 ) : ";
printSize(x12);
cout << endl << endl << x12 << endl << endl;
/* Now changes to x12 should be reflected in the other variables */
x12.setZero();
cout << "-----" << endl << endl << "x12.setZero() " << endl << endl << "-----" << endl;
cout << "allData" << endl << endl << allData << endl << endl;
cout << "x1" << endl << endl << x1 << endl << endl;
cout << "xc2" << endl << endl << xc2 << endl << endl;
cout << "xuc2" << endl << endl << xuc2 << endl << endl;
cout << "x12" << endl << endl << x12 << endl << endl;
return 0;
}
具体来说,它产生以下输出(如预期):
Data::asMatrix( T* ): 10 x 5
0 10 20 30 40
1 11 21 31 41
2 12 22 32 42
3 13 23 33 43
4 14 24 34 44
5 15 25 35 45
6 16 26 36 46
7 17 27 37 47
8 18 28 38 48
9 19 29 39 49
Data::X( T*, 1 ) : 2 x 5
5 15 25 35 45
6 16 26 36 46
Data::Xc( T*, 2 ) : 2 x 2
20 25
21 26
Data::C( T*, 2 ) : 5 x 2
20 25
21 26
22 27
23 28
24 29
Data::X( T*, 1, 2 ) : 2 x 1
25
26
-----
x12.setZero()
-----
allData
0 10 20 30 40
1 11 21 31 41
2 12 22 32 42
3 13 23 33 43
4 14 24 34 44
5 15 0 35 45
6 16 0 36 46
7 17 27 37 47
8 18 28 38 48
9 19 29 39 49
x1
5 15 0 35 45
6 16 0 36 46
xc2
20 0
21 0
xuc2
20 0
21 0
22 27
23 28
24 29
x12
0
0
问题是对维度的编译时检查似乎不起作用。在里面data_getter.h
,你可能会注意到我放了一堆static_assert
尺寸。这可能看起来有点矫枉过正,但我想确保表达式确实执行编译时操作,以便我们可以检查维度。如果它们是动态表达式,那么大小都将为-1。
然而,尽管事实上所有的static_assert
通过,似乎没有对引用进行任何编译时检查。例如,如果我们更改测试程序中的以下行
Ref<Matrix<T, (n + m) * c, w>> allData = Data::asMatrix(vec.data());
into
Ref<Matrix<T, (n + m) * c + 1, w>> allData = Data::asMatrix(vec.data());
代码可以编译,但会产生运行时崩溃。这似乎表明Ref
正在丢弃维度。那么我应该如何定义这些变量呢?
可能想到的一个想法是将这些返回值定义为auto
以及。然而,这是Eigen 文档明确不鼓励 https://eigen.tuxfamily.org/dox/TopicPitfalls.html因为如果我们最终在循环中使用输出,则可能会导致表达式被一遍又一遍地求值。这就是我使用的原因Ref
s。另外,明确说明大小似乎是个好主意,因为我们在编译时就知道它......
那么这是 Ref 中的一个错误吗?对于所有访问器方法吐出的变量,我应该使用什么类型?