Line data Source code
1 : #pragma once 2 : 3 : #include "elsaDefines.h" 4 : #include "Cloneable.h" 5 : #include "Error.h" 6 : #include "ExpressionPredicates.h" 7 : 8 : #ifdef ELSA_CUDA_VECTOR 9 : #include "Quickvec.cuh" 10 : #endif 11 : 12 : #include <Eigen/Core> 13 : 14 : namespace elsa 15 : { 16 : class DataDescriptor; 17 : 18 : /** 19 : * @brief Base class encapsulating data handling. The data is stored transparently, for example 20 : * on CPU or GPU. 21 : * 22 : * @author David Frank - initial code 23 : * @author Tobias Lasser - modularization, modernization 24 : * @author Nikola Dinev - add block support 25 : * 26 : * This abstract base class serves as an interface for data handlers, which encapsulate the 27 : * actual data being stored e.g. in main memory of the CPU or in various memory types of GPUs. 28 : * The data itself is treated as a vector, i.e. an array of data_t elements (which usually comes 29 : * from linearized n-dimensional signals). 30 : * 31 : * Caveat: If data is not stored in main memory (e.g. on GPUs), then some operations may trigger 32 : * an automatic synchronization of GPU to main memory. Please see the GPU-based handlers' 33 : * documentation for details. 34 : */ 35 : template <typename data_t = real_t> 36 : class DataHandler : public Cloneable<DataHandler<data_t>> 37 : { 38 : protected: 39 : /// convenience typedef for the Eigen::Matrix data vector 40 : using DataVector_t = Eigen::Matrix<data_t, Eigen::Dynamic, 1>; 41 : 42 : /// convenience typedef for the Eigen::Map 43 : using DataMap_t = Eigen::Map<DataVector_t>; 44 : 45 : public: 46 : /// convenience typedef to access data type that is internally stored 47 : using value_type = data_t; 48 : 49 : /// return the size of the stored data (i.e. number of elements in linearized data vector) 50 : virtual index_t getSize() const = 0; 51 : 52 : /// return the index-th element of the data vector (not bounds-checked!) 53 : virtual data_t& operator[](index_t index) = 0; 54 : 55 : /// return the index-th element of the data vector as read-only (not bounds-checked!) 56 : virtual const data_t& operator[](index_t index) const = 0; 57 : 58 : /// return the dot product of the data vector with vector v 59 : virtual data_t dot(const DataHandler<data_t>& v) const = 0; 60 : 61 : /// return the squared l2 norm of the data vector (dot product with itself) 62 : virtual GetFloatingPointType_t<data_t> squaredL2Norm() const = 0; 63 : 64 : /// return the l2 norm of the data vector (square root of dot product with itself) 65 : virtual GetFloatingPointType_t<data_t> l2Norm() const = 0; 66 : 67 : /// return the l0 pseudo-norm of the data vector (number of non-zero values) 68 : virtual index_t l0PseudoNorm() const = 0; 69 : 70 : /// return the l1 norm of the data vector (sum of absolute values) 71 : virtual GetFloatingPointType_t<data_t> l1Norm() const = 0; 72 : 73 : /// return the linf norm of the data vector (maximum of absolute values) 74 : virtual GetFloatingPointType_t<data_t> lInfNorm() const = 0; 75 : 76 : /// return the sum of all elements of the data vector 77 : virtual data_t sum() const = 0; 78 : 79 : /// return the min of all elements of the data vector 80 : virtual data_t minElement() const = 0; 81 : 82 : /// return the max of all elements of the data vector 83 : virtual data_t maxElement() const = 0; 84 : 85 : /// in-place create the fourier transformed of the data vector 86 : virtual DataHandler<data_t>& fft(const DataDescriptor& source_desc, FFTNorm norm) = 0; 87 : 88 : /// in-place create the inverse fourier transformed of the data vector 89 : virtual DataHandler<data_t>& ifft(const DataDescriptor& source_desc, FFTNorm norm) = 0; 90 : 91 : /// compute in-place element-wise addition of another vector v 92 : virtual DataHandler<data_t>& operator+=(const DataHandler<data_t>& v) = 0; 93 : 94 : /// compute in-place element-wise subtraction of another vector v 95 : virtual DataHandler<data_t>& operator-=(const DataHandler<data_t>& v) = 0; 96 : 97 : /// compute in-place element-wise multiplication by another vector v 98 : virtual DataHandler<data_t>& operator*=(const DataHandler<data_t>& v) = 0; 99 : 100 : /// compute in-place element-wise division by another vector v 101 : virtual DataHandler<data_t>& operator/=(const DataHandler<data_t>& v) = 0; 102 : 103 : /// compute in-place addition of a scalar 104 : virtual DataHandler<data_t>& operator+=(data_t scalar) = 0; 105 : 106 : /// compute in-place subtraction of a scalar 107 : virtual DataHandler<data_t>& operator-=(data_t scalar) = 0; 108 : 109 : /// compute in-place multiplication by a scalar 110 : virtual DataHandler<data_t>& operator*=(data_t scalar) = 0; 111 : 112 : /// compute in-place division by a scalar 113 : virtual DataHandler<data_t>& operator/=(data_t scalar) = 0; 114 : 115 : /// assign a scalar to all elements of the data vector 116 : virtual DataHandler<data_t>& operator=(data_t scalar) = 0; 117 : 118 : /// copy assignment operator 119 0 : DataHandler<data_t>& operator=(const DataHandler<data_t>& other) 120 : { 121 0 : if (other.getSize() != getSize()) 122 0 : throw InvalidArgumentError("DataHandler: assignment argument has wrong size"); 123 : 124 0 : assign(other); 125 0 : return *this; 126 : } 127 : 128 : /// move assignment operator 129 0 : DataHandler<data_t>& operator=(DataHandler<data_t>&& other) 130 : { 131 0 : if (other.getSize() != getSize()) 132 0 : throw InvalidArgumentError("DataHandler: assignment argument has wrong size"); 133 : 134 0 : assign(std::move(other)); 135 0 : return *this; 136 : } 137 : 138 : /// return a reference to the sequential block starting at startIndex and containing 139 : /// numberOfElements elements 140 : virtual std::unique_ptr<DataHandler<data_t>> getBlock(index_t startIndex, 141 : index_t numberOfElements) = 0; 142 : 143 : /// return a const reference to the sequential block starting at startIndex and containing 144 : /// numberOfElements elements 145 : virtual std::unique_ptr<const DataHandler<data_t>> 146 : getBlock(index_t startIndex, index_t numberOfElements) const = 0; 147 : 148 : protected: 149 : /// slow element-wise dot product fall-back for when DataHandler types do not match 150 0 : data_t slowDotProduct(const DataHandler<data_t>& v) const 151 : { 152 0 : data_t result = 0; 153 0 : for (index_t i = 0; i < getSize(); ++i) 154 0 : result += (*this)[i] * v[i]; 155 0 : return result; 156 : } 157 : 158 : /// slow element-wise addition fall-back for when DataHandler types do not match 159 0 : void slowAddition(const DataHandler<data_t>& v) 160 : { 161 0 : for (index_t i = 0; i < getSize(); ++i) 162 0 : (*this)[i] += v[i]; 163 0 : } 164 : 165 : /// slow element-wise subtraction fall-back for when DataHandler types do not match 166 0 : void slowSubtraction(const DataHandler<data_t>& v) 167 : { 168 0 : for (index_t i = 0; i < getSize(); ++i) 169 0 : (*this)[i] -= v[i]; 170 0 : } 171 : 172 : /// slow element-wise multiplication fall-back for when DataHandler types do not match 173 0 : void slowMultiplication(const DataHandler<data_t>& v) 174 : { 175 0 : for (index_t i = 0; i < getSize(); ++i) 176 0 : (*this)[i] *= v[i]; 177 0 : } 178 : 179 : /// slow element-wise division fall-back for when DataHandler types do not match 180 0 : void slowDivision(const DataHandler<data_t>& v) 181 : { 182 0 : for (index_t i = 0; i < getSize(); ++i) 183 0 : (*this)[i] /= v[i]; 184 0 : } 185 : 186 : /// slow element-wise assignment fall-back for when DataHandler types do not match 187 0 : void slowAssign(const DataHandler<data_t>& other) 188 : { 189 0 : for (index_t i = 0; i < getSize(); ++i) 190 0 : (*this)[i] = other[i]; 191 0 : } 192 : 193 : /// derived classes should override this method to implement copy assignment 194 : virtual void assign(const DataHandler<data_t>& other) = 0; 195 : 196 : /// derived classes should override this method to implement move assignment 197 : virtual void assign(DataHandler<data_t>&& other) = 0; 198 : }; 199 : } // namespace elsa