LCOV - code coverage report
Current view: top level - elsa/operators - FiniteDifferences.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 141 149 94.6 %
Date: 2022-08-25 03:05:39 Functions: 22 44 50.0 %

          Line data    Source code
       1             : #include "FiniteDifferences.h"
       2             : #include "Timer.h"
       3             : #include "VolumeDescriptor.h"
       4             : #include "TypeCasts.hpp"
       5             : 
       6             : namespace elsa
       7             : {
       8             :     template <typename data_t>
       9             :     FiniteDifferences<data_t>::FiniteDifferences(const DataDescriptor& domainDescriptor,
      10             :                                                  DiffType type)
      11             :         : FiniteDifferences(domainDescriptor,
      12             :                             BooleanVector_t::Ones(domainDescriptor.getNumberOfDimensions()), type)
      13          16 :     {
      14          16 :     }
      15             : 
      16             :     template <typename data_t>
      17             :     FiniteDifferences<data_t>::FiniteDifferences(const DataDescriptor& domainDescriptor,
      18             :                                                  const BooleanVector_t& activeDims, DiffType type)
      19             :         : LinearOperator<data_t>(domainDescriptor,
      20             :                                  domainDescriptor), // setting range in body of constructor
      21             :           _type{type},
      22             :           _activeDims{activeDims},
      23             :           _coordDiff{activeDims.size()},
      24             :           _coordDelta{activeDims.size()},
      25             :           _dimCounter{activeDims.size()}
      26          30 :     {
      27             :         // build the range descriptor of appropriate size
      28          30 :         IndexVector_t coefficients(domainDescriptor.getNumberOfDimensions() + 1);
      29          30 :         coefficients << domainDescriptor.getNumberOfCoefficientsPerDimension(),
      30          30 :             activeDims.cast<index_t>().sum();
      31             : 
      32          30 :         RealVector_t spacing(domainDescriptor.getNumberOfDimensions() + 1);
      33          30 :         spacing << domainDescriptor.getSpacingPerDimension(), 1;
      34             : 
      35          30 :         this->_rangeDescriptor = std::make_unique<VolumeDescriptor>(coefficients, spacing);
      36             : 
      37          30 :         precomputeHelpers();
      38          30 :     }
      39             : 
      40             :     template <typename data_t>
      41             :     void FiniteDifferences<data_t>::precomputeHelpers()
      42          30 :     {
      43          30 :         IndexVector_t numberOfCoefficients =
      44          30 :             this->_rangeDescriptor->getNumberOfCoefficientsPerDimension();
      45             : 
      46          30 :         index_t deltaTmp = 1;
      47          30 :         int count = -1;
      48          84 :         for (index_t ic = 0; ic < this->getDomainDescriptor().getNumberOfDimensions(); ++ic) {
      49          54 :             _coordDiff[ic] = numberOfCoefficients.head(ic).prod();
      50             : 
      51          54 :             deltaTmp *= numberOfCoefficients[ic];
      52          54 :             _coordDelta[ic] = deltaTmp;
      53             : 
      54          54 :             if (_activeDims[ic])
      55          42 :                 ++count;
      56          54 :             _dimCounter[ic] = count;
      57          54 :         }
      58          30 :     }
      59             : 
      60             :     template <typename data_t>
      61             :     void FiniteDifferences<data_t>::applyImpl(const DataContainer<data_t>& x,
      62             :                                               DataContainer<data_t>& Ax) const
      63          24 :     {
      64          24 :         Timer timeguard("FiniteDifferences", "apply");
      65             : 
      66          24 :         Ax = 0;
      67             : 
      68          24 :         switch (_type) {
      69           8 :             case DiffType::FORWARD:
      70           8 :                 applyHelper(x, Ax, DiffType::FORWARD);
      71           8 :                 break;
      72           8 :             case DiffType::BACKWARD:
      73           8 :                 applyHelper(x, Ax, DiffType::BACKWARD);
      74           8 :                 break;
      75           8 :             case DiffType::CENTRAL:
      76           8 :                 applyHelper(x, Ax, DiffType::CENTRAL);
      77           8 :                 break;
      78           0 :             default:
      79           0 :                 throw LogicError("FiniteDifferences::apply: invalid DiffType");
      80          24 :         }
      81          24 :     }
      82             : 
      83             :     template <typename data_t>
      84             :     void FiniteDifferences<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
      85             :                                                      DataContainer<data_t>& Aty) const
      86           6 :     {
      87           6 :         Timer timeguard("FiniteDifferences", "applyAdjoint");
      88             : 
      89           6 :         Aty = 0;
      90             : 
      91           6 :         switch (_type) {
      92           2 :             case DiffType::FORWARD:
      93           2 :                 applyAdjointHelper(y, Aty, DiffType::FORWARD);
      94           2 :                 break;
      95           2 :             case DiffType::BACKWARD:
      96           2 :                 applyAdjointHelper(y, Aty, DiffType::BACKWARD);
      97           2 :                 break;
      98           2 :             case DiffType::CENTRAL:
      99           2 :                 applyAdjointHelper(y, Aty, DiffType::CENTRAL);
     100           2 :                 break;
     101           0 :             default:
     102           0 :                 throw LogicError("FiniteDifferences::applyAdjoint: invalid DiffType");
     103           6 :         }
     104           6 :     }
     105             : 
     106             :     template <typename data_t>
     107             :     template <typename FDtype>
     108             :     void FiniteDifferences<data_t>::applyHelper(const DataContainer<data_t>& x,
     109             :                                                 DataContainer<data_t>& Ax, FDtype type) const
     110          24 :     {
     111          24 :         index_t sizeOfDomain = this->getDomainDescriptor().getNumberOfCoefficients();
     112          24 :         index_t numDim = this->getDomainDescriptor().getNumberOfDimensions();
     113             : 
     114          24 :         IndexVector_t numberOfCoefficients =
     115          24 :             this->getRangeDescriptor().getNumberOfCoefficientsPerDimension();
     116          24 :         IndexVector_t decrementedCoefficients =
     117          24 :             numberOfCoefficients
     118          24 :             - IndexVector_t::Ones(this->getRangeDescriptor().getNumberOfDimensions());
     119             : 
     120          24 : #pragma omp parallel
     121        3190 :         for (long currDim = 0; currDim < numDim; ++currDim) {
     122        1595 :             if (!_activeDims[currDim])
     123         488 :                 continue;
     124             : 
     125        1107 :             index_t modulus = numberOfCoefficients.head(currDim + 1).prod();
     126        1107 :             index_t divisor = numberOfCoefficients.head(currDim).prod();
     127             : 
     128        1107 : #pragma omp for nowait
     129        1107 :             for (index_t id = 0; id < sizeOfDomain; ++id) {
     130        1107 :                 index_t icCount = (id % modulus) / divisor; //_domainDescriptor.index(id, ic);
     131        1107 :                 index_t ir = id + _dimCounter[currDim] * _coordDelta[currDim];
     132             : 
     133             :                 // store result depending on mode
     134        1107 :                 switch (type) {
     135        1107 :                     case DiffType::FORWARD:
     136        1107 :                         Ax[ir] = -x[id];
     137        1107 :                         if (icCount < decrementedCoefficients[currDim])
     138        1107 :                             Ax[ir] += x[id + _coordDiff[currDim]];
     139        1107 :                         break;
     140        1107 :                     case DiffType::BACKWARD:
     141        1107 :                         Ax[ir] = x[id];
     142        1107 :                         if (icCount > 0)
     143        1107 :                             Ax[ir] -= x[id - _coordDiff[currDim]];
     144        1107 :                         break;
     145        1107 :                     case DiffType::CENTRAL:
     146        1107 :                         Ax[ir] = static_cast<data_t>(0.0);
     147        1107 :                         if (icCount < decrementedCoefficients[currDim])
     148        1107 :                             Ax[ir] += static_cast<data_t>(0.5) * x[id + _coordDiff[currDim]];
     149        1107 :                         if (icCount > 0)
     150        1107 :                             Ax[ir] -= static_cast<data_t>(0.5) * x[id - _coordDiff[currDim]];
     151        1107 :                         break;
     152        1107 :                 }
     153        1107 :             }
     154        1107 :         }
     155          24 :     }
     156             : 
     157             :     template <typename data_t>
     158             :     template <typename FDtype>
     159             :     void FiniteDifferences<data_t>::applyAdjointHelper(const DataContainer<data_t>& y,
     160             :                                                        DataContainer<data_t>& Aty,
     161             :                                                        FDtype type) const
     162           6 :     {
     163           6 :         index_t sizeOfDomain = this->getDomainDescriptor().getNumberOfCoefficients();
     164           6 :         index_t numDim = this->getDomainDescriptor().getNumberOfDimensions();
     165             : 
     166           6 :         IndexVector_t numberOfCoefficients =
     167           6 :             this->getDomainDescriptor().getNumberOfCoefficientsPerDimension();
     168           6 :         IndexVector_t decrementedCoefficients = numberOfCoefficients - IndexVector_t::Ones(numDim);
     169             : 
     170           6 : #pragma omp parallel
     171         394 :         for (long currDim = 0; currDim < numDim; ++currDim) {
     172         197 :             if (!_activeDims[currDim])
     173           0 :                 continue;
     174             : 
     175         197 :             index_t modulus = numberOfCoefficients.head(currDim + 1).prod();
     176         197 :             index_t divisor = numberOfCoefficients.head(currDim).prod();
     177             : 
     178         197 : #pragma omp for nowait
     179         197 :             for (index_t id = 0; id < sizeOfDomain; ++id) {
     180         197 :                 index_t icCount = (id % modulus) / divisor;
     181         197 :                 index_t ir = id + _dimCounter[currDim] * _coordDelta[currDim];
     182             : 
     183         197 :                 switch (type) {
     184         197 :                     case DiffType::FORWARD:
     185         197 :                         if (icCount > 0)
     186         197 :                             Aty[id] += y[ir - _coordDiff[currDim]];
     187             : 
     188         197 :                         Aty[id] -= y[ir];
     189         197 :                         break;
     190         197 :                     case DiffType::BACKWARD:
     191         197 :                         if (icCount < decrementedCoefficients(currDim))
     192         197 :                             Aty[id] -= y[ir + _coordDiff[currDim]];
     193             : 
     194         197 :                         Aty[id] += y[ir];
     195         197 :                         break;
     196         197 :                     case DiffType::CENTRAL:
     197         197 :                         if (icCount > 0)
     198         197 :                             Aty[ir] += static_cast<data_t>(0.5) * y[ir - _coordDiff[currDim]];
     199             : 
     200         197 :                         if (icCount < decrementedCoefficients(currDim))
     201         197 :                             Aty[ir] -= static_cast<data_t>(0.5) * y[ir + _coordDiff[currDim]];
     202         197 :                         break;
     203         197 :                 }
     204         197 :             }
     205         197 :         }
     206           6 :     }
     207             : 
     208             :     template <typename data_t>
     209             :     FiniteDifferences<data_t>* FiniteDifferences<data_t>::cloneImpl() const
     210           2 :     {
     211           2 :         return new FiniteDifferences(this->getDomainDescriptor(), _activeDims, _type);
     212           2 :     }
     213             : 
     214             :     template <typename data_t>
     215             :     bool FiniteDifferences<data_t>::isEqual(const LinearOperator<data_t>& other) const
     216           2 :     {
     217           2 :         if (!LinearOperator<data_t>::isEqual(other))
     218           0 :             return false;
     219             : 
     220           2 :         auto otherFD = downcast_safe<FiniteDifferences>(&other);
     221           2 :         if (!otherFD)
     222           0 :             return false;
     223             : 
     224           2 :         if (_type != otherFD->_type || _activeDims != otherFD->_activeDims)
     225           0 :             return false;
     226             : 
     227           2 :         return true;
     228           2 :     }
     229             : 
     230             :     // ------------------------------------------
     231             :     // explicit template instantiation
     232             :     template class FiniteDifferences<float>;
     233             :     template class FiniteDifferences<double>;
     234             :     template class FiniteDifferences<complex<float>>;
     235             :     template class FiniteDifferences<complex<double>>;
     236             : 
     237             : } // namespace elsa

Generated by: LCOV version 1.14