Line data Source code
1 : #include "FiniteDifferences.h" 2 : #include "Timer.h" 3 : #include "VolumeDescriptor.h" 4 : #include "TypeCasts.hpp" 5 : 6 : namespace elsa 7 : { 8 : template <typename data_t> 9 : FiniteDifferences<data_t>::FiniteDifferences(const DataDescriptor& domainDescriptor, 10 : DiffType type) 11 : : FiniteDifferences(domainDescriptor, 12 : BooleanVector_t::Ones(domainDescriptor.getNumberOfDimensions()), type) 13 16 : { 14 16 : } 15 : 16 : template <typename data_t> 17 : FiniteDifferences<data_t>::FiniteDifferences(const DataDescriptor& domainDescriptor, 18 : const BooleanVector_t& activeDims, DiffType type) 19 : : LinearOperator<data_t>(domainDescriptor, 20 : domainDescriptor), // setting range in body of constructor 21 : _type{type}, 22 : _activeDims{activeDims}, 23 : _coordDiff{activeDims.size()}, 24 : _coordDelta{activeDims.size()}, 25 : _dimCounter{activeDims.size()} 26 30 : { 27 : // build the range descriptor of appropriate size 28 30 : IndexVector_t coefficients(domainDescriptor.getNumberOfDimensions() + 1); 29 30 : coefficients << domainDescriptor.getNumberOfCoefficientsPerDimension(), 30 30 : activeDims.cast<index_t>().sum(); 31 : 32 30 : RealVector_t spacing(domainDescriptor.getNumberOfDimensions() + 1); 33 30 : spacing << domainDescriptor.getSpacingPerDimension(), 1; 34 : 35 30 : this->_rangeDescriptor = std::make_unique<VolumeDescriptor>(coefficients, spacing); 36 : 37 30 : precomputeHelpers(); 38 30 : } 39 : 40 : template <typename data_t> 41 : void FiniteDifferences<data_t>::precomputeHelpers() 42 30 : { 43 30 : IndexVector_t numberOfCoefficients = 44 30 : this->_rangeDescriptor->getNumberOfCoefficientsPerDimension(); 45 : 46 30 : index_t deltaTmp = 1; 47 30 : int count = -1; 48 84 : for (index_t ic = 0; ic < this->getDomainDescriptor().getNumberOfDimensions(); ++ic) { 49 54 : _coordDiff[ic] = numberOfCoefficients.head(ic).prod(); 50 : 51 54 : deltaTmp *= numberOfCoefficients[ic]; 52 54 : _coordDelta[ic] = deltaTmp; 53 : 54 54 : if (_activeDims[ic]) 55 42 : ++count; 56 54 : _dimCounter[ic] = count; 57 54 : } 58 30 : } 59 : 60 : template <typename data_t> 61 : void FiniteDifferences<data_t>::applyImpl(const DataContainer<data_t>& x, 62 : DataContainer<data_t>& Ax) const 63 24 : { 64 24 : Timer timeguard("FiniteDifferences", "apply"); 65 : 66 24 : Ax = 0; 67 : 68 24 : switch (_type) { 69 8 : case DiffType::FORWARD: 70 8 : applyHelper(x, Ax, DiffType::FORWARD); 71 8 : break; 72 8 : case DiffType::BACKWARD: 73 8 : applyHelper(x, Ax, DiffType::BACKWARD); 74 8 : break; 75 8 : case DiffType::CENTRAL: 76 8 : applyHelper(x, Ax, DiffType::CENTRAL); 77 8 : break; 78 0 : default: 79 0 : throw LogicError("FiniteDifferences::apply: invalid DiffType"); 80 24 : } 81 24 : } 82 : 83 : template <typename data_t> 84 : void FiniteDifferences<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 85 : DataContainer<data_t>& Aty) const 86 6 : { 87 6 : Timer timeguard("FiniteDifferences", "applyAdjoint"); 88 : 89 6 : Aty = 0; 90 : 91 6 : switch (_type) { 92 2 : case DiffType::FORWARD: 93 2 : applyAdjointHelper(y, Aty, DiffType::FORWARD); 94 2 : break; 95 2 : case DiffType::BACKWARD: 96 2 : applyAdjointHelper(y, Aty, DiffType::BACKWARD); 97 2 : break; 98 2 : case DiffType::CENTRAL: 99 2 : applyAdjointHelper(y, Aty, DiffType::CENTRAL); 100 2 : break; 101 0 : default: 102 0 : throw LogicError("FiniteDifferences::applyAdjoint: invalid DiffType"); 103 6 : } 104 6 : } 105 : 106 : template <typename data_t> 107 : template <typename FDtype> 108 : void FiniteDifferences<data_t>::applyHelper(const DataContainer<data_t>& x, 109 : DataContainer<data_t>& Ax, FDtype type) const 110 24 : { 111 24 : index_t sizeOfDomain = this->getDomainDescriptor().getNumberOfCoefficients(); 112 24 : index_t numDim = this->getDomainDescriptor().getNumberOfDimensions(); 113 : 114 24 : IndexVector_t numberOfCoefficients = 115 24 : this->getRangeDescriptor().getNumberOfCoefficientsPerDimension(); 116 24 : IndexVector_t decrementedCoefficients = 117 24 : numberOfCoefficients 118 24 : - IndexVector_t::Ones(this->getRangeDescriptor().getNumberOfDimensions()); 119 : 120 24 : #pragma omp parallel 121 3190 : for (long currDim = 0; currDim < numDim; ++currDim) { 122 1595 : if (!_activeDims[currDim]) 123 488 : continue; 124 : 125 1107 : index_t modulus = numberOfCoefficients.head(currDim + 1).prod(); 126 1107 : index_t divisor = numberOfCoefficients.head(currDim).prod(); 127 : 128 1107 : #pragma omp for nowait 129 1107 : for (index_t id = 0; id < sizeOfDomain; ++id) { 130 1107 : index_t icCount = (id % modulus) / divisor; //_domainDescriptor.index(id, ic); 131 1107 : index_t ir = id + _dimCounter[currDim] * _coordDelta[currDim]; 132 : 133 : // store result depending on mode 134 1107 : switch (type) { 135 1107 : case DiffType::FORWARD: 136 1107 : Ax[ir] = -x[id]; 137 1107 : if (icCount < decrementedCoefficients[currDim]) 138 1107 : Ax[ir] += x[id + _coordDiff[currDim]]; 139 1107 : break; 140 1107 : case DiffType::BACKWARD: 141 1107 : Ax[ir] = x[id]; 142 1107 : if (icCount > 0) 143 1107 : Ax[ir] -= x[id - _coordDiff[currDim]]; 144 1107 : break; 145 1107 : case DiffType::CENTRAL: 146 1107 : Ax[ir] = static_cast<data_t>(0.0); 147 1107 : if (icCount < decrementedCoefficients[currDim]) 148 1107 : Ax[ir] += static_cast<data_t>(0.5) * x[id + _coordDiff[currDim]]; 149 1107 : if (icCount > 0) 150 1107 : Ax[ir] -= static_cast<data_t>(0.5) * x[id - _coordDiff[currDim]]; 151 1107 : break; 152 1107 : } 153 1107 : } 154 1107 : } 155 24 : } 156 : 157 : template <typename data_t> 158 : template <typename FDtype> 159 : void FiniteDifferences<data_t>::applyAdjointHelper(const DataContainer<data_t>& y, 160 : DataContainer<data_t>& Aty, 161 : FDtype type) const 162 6 : { 163 6 : index_t sizeOfDomain = this->getDomainDescriptor().getNumberOfCoefficients(); 164 6 : index_t numDim = this->getDomainDescriptor().getNumberOfDimensions(); 165 : 166 6 : IndexVector_t numberOfCoefficients = 167 6 : this->getDomainDescriptor().getNumberOfCoefficientsPerDimension(); 168 6 : IndexVector_t decrementedCoefficients = numberOfCoefficients - IndexVector_t::Ones(numDim); 169 : 170 6 : #pragma omp parallel 171 394 : for (long currDim = 0; currDim < numDim; ++currDim) { 172 197 : if (!_activeDims[currDim]) 173 0 : continue; 174 : 175 197 : index_t modulus = numberOfCoefficients.head(currDim + 1).prod(); 176 197 : index_t divisor = numberOfCoefficients.head(currDim).prod(); 177 : 178 197 : #pragma omp for nowait 179 197 : for (index_t id = 0; id < sizeOfDomain; ++id) { 180 197 : index_t icCount = (id % modulus) / divisor; 181 197 : index_t ir = id + _dimCounter[currDim] * _coordDelta[currDim]; 182 : 183 197 : switch (type) { 184 197 : case DiffType::FORWARD: 185 197 : if (icCount > 0) 186 197 : Aty[id] += y[ir - _coordDiff[currDim]]; 187 : 188 197 : Aty[id] -= y[ir]; 189 197 : break; 190 197 : case DiffType::BACKWARD: 191 197 : if (icCount < decrementedCoefficients(currDim)) 192 197 : Aty[id] -= y[ir + _coordDiff[currDim]]; 193 : 194 197 : Aty[id] += y[ir]; 195 197 : break; 196 197 : case DiffType::CENTRAL: 197 197 : if (icCount > 0) 198 197 : Aty[ir] += static_cast<data_t>(0.5) * y[ir - _coordDiff[currDim]]; 199 : 200 197 : if (icCount < decrementedCoefficients(currDim)) 201 197 : Aty[ir] -= static_cast<data_t>(0.5) * y[ir + _coordDiff[currDim]]; 202 197 : break; 203 197 : } 204 197 : } 205 197 : } 206 6 : } 207 : 208 : template <typename data_t> 209 : FiniteDifferences<data_t>* FiniteDifferences<data_t>::cloneImpl() const 210 2 : { 211 2 : return new FiniteDifferences(this->getDomainDescriptor(), _activeDims, _type); 212 2 : } 213 : 214 : template <typename data_t> 215 : bool FiniteDifferences<data_t>::isEqual(const LinearOperator<data_t>& other) const 216 2 : { 217 2 : if (!LinearOperator<data_t>::isEqual(other)) 218 0 : return false; 219 : 220 2 : auto otherFD = downcast_safe<FiniteDifferences>(&other); 221 2 : if (!otherFD) 222 0 : return false; 223 : 224 2 : if (_type != otherFD->_type || _activeDims != otherFD->_activeDims) 225 0 : return false; 226 : 227 2 : return true; 228 2 : } 229 : 230 : // ------------------------------------------ 231 : // explicit template instantiation 232 : template class FiniteDifferences<float>; 233 : template class FiniteDifferences<double>; 234 : template class FiniteDifferences<complex<float>>; 235 : template class FiniteDifferences<complex<double>>; 236 : 237 : } // namespace elsa