Line data Source code
1 : #include "MatrixOperator.h" 2 : #include "Timer.h" 3 : #include "TypeCasts.hpp" 4 : #include "VolumeDescriptor.h" 5 : #include "elsaDefines.h" 6 : #include "thrust/detail/raw_pointer_cast.h" 7 : #include <algorithm> 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : MatrixOperator<data_t>::MatrixOperator(const Matrix_t<data_t>& mat) 13 : : MatrixOperator<data_t>(VolumeDescriptor({{mat.cols()}}), VolumeDescriptor({{mat.rows()}}), 14 : mat) 15 44 : { 16 44 : } 17 : 18 : template <typename data_t> 19 : MatrixOperator<data_t>::MatrixOperator(const DataDescriptor& domain, 20 : const DataDescriptor& range, const Matrix_t<data_t>& mat) 21 : : LinearOperator<data_t>(domain, range), 22 : storage_(mat.data(), mat.data() + mat.size()), 23 : mat_(thrust::raw_pointer_cast(storage_.data()), mat.rows(), mat.cols()) 24 304 : { 25 304 : } 26 : 27 : template <typename data_t> 28 : void MatrixOperator<data_t>::applyImpl(const DataContainer<data_t>& x, 29 : DataContainer<data_t>& Ax) const 30 971 : { 31 971 : Timer timeguard("MatrixOperator", "apply"); 32 : 33 971 : if (x.getSize() != this->getDomainDescriptor().getNumberOfCoefficients()) { 34 2 : throw Error("MatrixOperator: x needs to be of size {} (is {})", 35 2 : this->getDomainDescriptor().getNumberOfCoefficients(), x.getSize()); 36 2 : } 37 : 38 : // Wrap data into Eigen Map 39 969 : const data_t* ptr = thrust::raw_pointer_cast(x.storage().data()); 40 969 : Eigen::Map<const Vector_t<data_t>> vec(ptr, mat_.cols()); 41 : 42 969 : Vector_t<data_t> result = mat_ * vec; 43 969 : Ax = DataContainer(Ax.getDataDescriptor(), result); 44 969 : } 45 : 46 : template <typename data_t> 47 : void MatrixOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 48 : DataContainer<data_t>& Aty) const 49 593 : { 50 593 : Timer timeguard("MatrixOperator", "applyAdjoint"); 51 : 52 593 : if (y.getSize() != this->getRangeDescriptor().getNumberOfCoefficients()) { 53 2 : throw Error("MatrixOperator: y needs to be of size {} (is {})", 54 2 : this->getRangeDescriptor().getNumberOfCoefficients(), y.getSize()); 55 2 : } 56 : 57 591 : const data_t* ptr = thrust::raw_pointer_cast(y.storage().data()); 58 591 : Eigen::Map<const Vector_t<data_t>> vec(ptr, mat_.rows()); 59 : 60 591 : Vector_t<data_t> result = mat_.transpose() * vec; 61 591 : Aty = DataContainer(Aty.getDataDescriptor(), result); 62 591 : } 63 : 64 : template <typename data_t> 65 : MatrixOperator<data_t>* MatrixOperator<data_t>::cloneImpl() const 66 260 : { 67 260 : return new MatrixOperator(this->getDomainDescriptor(), this->getRangeDescriptor(), mat_); 68 260 : } 69 : 70 : template <typename data_t> 71 : bool MatrixOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const 72 18 : { 73 18 : if (!LinearOperator<data_t>::isEqual(other)) 74 0 : return false; 75 : 76 18 : if (!is<MatrixOperator>(other)) { 77 0 : return false; 78 0 : } 79 : 80 18 : const auto& otherOp = downcast<MatrixOperator>(other); 81 18 : return mat_.isApprox(otherOp.mat_); 82 18 : } 83 : 84 : // ------------------------------------------ 85 : // explicit template instantiation 86 : template class MatrixOperator<float>; 87 : template class MatrixOperator<double>; 88 : 89 : } // namespace elsa