LCOV - code coverage report
Current view: top level - elsa/operators - FiniteDifferences.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 135 143 94.4 %
Date: 2024-05-16 04:22:26 Functions: 34 44 77.3 %

          Line data    Source code
       1             : #include "FiniteDifferences.h"
       2             : #include "Timer.h"
       3             : #include "VolumeDescriptor.h"
       4             : #include "TypeCasts.hpp"
       5             : #include "IdenticalBlocksDescriptor.h"
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename data_t>
      10             :     FiniteDifferences<data_t>::FiniteDifferences(const DataDescriptor& domainDescriptor,
      11             :                                                  DiffType type)
      12             :         : FiniteDifferences(domainDescriptor,
      13             :                             BooleanVector_t::Ones(domainDescriptor.getNumberOfDimensions()), type)
      14          52 :     {
      15          52 :     }
      16             : 
      17             :     template <typename data_t>
      18             :     FiniteDifferences<data_t>::FiniteDifferences(const DataDescriptor& domainDescriptor,
      19             :                                                  const BooleanVector_t& activeDims, DiffType type)
      20             :         : LinearOperator<data_t>(
      21             :             domainDescriptor,
      22             :             IdenticalBlocksDescriptor{activeDims.cast<index_t>().sum(), domainDescriptor}),
      23             :           _type{type},
      24             :           _activeDims{activeDims},
      25             :           _coordDiff{activeDims.size()},
      26             :           _coordDelta{activeDims.size()},
      27             :           _dimCounter{activeDims.size()}
      28         142 :     {
      29         142 :         precomputeHelpers();
      30         142 :     }
      31             : 
      32             :     template <typename data_t>
      33             :     void FiniteDifferences<data_t>::precomputeHelpers()
      34         142 :     {
      35         142 :         IndexVector_t numberOfCoefficients =
      36         142 :             this->_rangeDescriptor->getNumberOfCoefficientsPerDimension();
      37             : 
      38         142 :         index_t deltaTmp = 1;
      39         142 :         int count = -1;
      40         472 :         for (index_t ic = 0; ic < this->getDomainDescriptor().getNumberOfDimensions(); ++ic) {
      41         330 :             _coordDiff[ic] = numberOfCoefficients.head(ic).prod();
      42             : 
      43         330 :             deltaTmp *= numberOfCoefficients[ic];
      44         330 :             _coordDelta[ic] = deltaTmp;
      45             : 
      46         330 :             if (_activeDims[ic])
      47         224 :                 ++count;
      48         330 :             _dimCounter[ic] = count;
      49         330 :         }
      50         142 :     }
      51             : 
      52             :     template <typename data_t>
      53             :     void FiniteDifferences<data_t>::applyImpl(const DataContainer<data_t>& x,
      54             :                                               DataContainer<data_t>& Ax) const
      55         112 :     {
      56         112 :         Timer timeguard("FiniteDifferences", "apply");
      57             : 
      58         112 :         Ax = 0;
      59             : 
      60         112 :         switch (_type) {
      61          60 :             case DiffType::FORWARD:
      62          60 :                 applyHelper(x, Ax, DiffType::FORWARD);
      63          60 :                 break;
      64          30 :             case DiffType::BACKWARD:
      65          30 :                 applyHelper(x, Ax, DiffType::BACKWARD);
      66          30 :                 break;
      67          22 :             case DiffType::CENTRAL:
      68          22 :                 applyHelper(x, Ax, DiffType::CENTRAL);
      69          22 :                 break;
      70           0 :             default:
      71           0 :                 throw LogicError("FiniteDifferences::apply: invalid DiffType");
      72         112 :         }
      73         112 :     }
      74             : 
      75             :     template <typename data_t>
      76             :     void FiniteDifferences<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
      77             :                                                      DataContainer<data_t>& Aty) const
      78           6 :     {
      79           6 :         Timer timeguard("FiniteDifferences", "applyAdjoint");
      80             : 
      81           6 :         Aty = 0;
      82             : 
      83           6 :         switch (_type) {
      84           2 :             case DiffType::FORWARD:
      85           2 :                 applyAdjointHelper(y, Aty, DiffType::FORWARD);
      86           2 :                 break;
      87           2 :             case DiffType::BACKWARD:
      88           2 :                 applyAdjointHelper(y, Aty, DiffType::BACKWARD);
      89           2 :                 break;
      90           2 :             case DiffType::CENTRAL:
      91           2 :                 applyAdjointHelper(y, Aty, DiffType::CENTRAL);
      92           2 :                 break;
      93           0 :             default:
      94           0 :                 throw LogicError("FiniteDifferences::applyAdjoint: invalid DiffType");
      95           6 :         }
      96           6 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     template <typename FDtype>
     100             :     void FiniteDifferences<data_t>::applyHelper(const DataContainer<data_t>& x,
     101             :                                                 DataContainer<data_t>& Ax, FDtype type) const
     102         112 :     {
     103         112 :         index_t sizeOfDomain = this->getDomainDescriptor().getNumberOfCoefficients();
     104         112 :         index_t numDim = this->getDomainDescriptor().getNumberOfDimensions();
     105             : 
     106         112 :         IndexVector_t numberOfCoefficients =
     107         112 :             this->getRangeDescriptor().getNumberOfCoefficientsPerDimension();
     108         112 :         IndexVector_t decrementedCoefficients =
     109         112 :             numberOfCoefficients
     110         112 :             - IndexVector_t::Ones(this->getRangeDescriptor().getNumberOfDimensions());
     111             : 
     112         112 : #pragma omp parallel
     113       27354 :         for (long currDim = 0; currDim < numDim; ++currDim) {
     114       13677 :             if (!_activeDims[currDim])
     115        3743 :                 continue;
     116             : 
     117        9934 :             index_t modulus = numberOfCoefficients.head(currDim + 1).prod();
     118        9934 :             index_t divisor = numberOfCoefficients.head(currDim).prod();
     119             : 
     120        9934 : #pragma omp for nowait
     121        9934 :             for (index_t id = 0; id < sizeOfDomain; ++id) {
     122        9934 :                 index_t icCount = (id % modulus) / divisor; //_domainDescriptor.index(id, ic);
     123        9934 :                 index_t ir = id + _dimCounter[currDim] * sizeOfDomain;
     124             : 
     125             :                 // store result depending on mode
     126        9934 :                 switch (type) {
     127        9934 :                     case DiffType::FORWARD:
     128        9934 :                         Ax[ir] = -x[id];
     129        9934 :                         if (icCount < decrementedCoefficients[currDim])
     130        9934 :                             Ax[ir] += x[id + _coordDiff[currDim]];
     131        9934 :                         break;
     132        9934 :                     case DiffType::BACKWARD:
     133        9934 :                         Ax[ir] = x[id];
     134        9934 :                         if (icCount > 0)
     135        9934 :                             Ax[ir] -= x[id - _coordDiff[currDim]];
     136        9934 :                         break;
     137        9934 :                     case DiffType::CENTRAL:
     138        9934 :                         Ax[ir] = static_cast<data_t>(0.0);
     139        9934 :                         if (icCount < decrementedCoefficients[currDim])
     140        9934 :                             Ax[ir] += static_cast<data_t>(0.5) * x[id + _coordDiff[currDim]];
     141        9934 :                         if (icCount > 0)
     142        9934 :                             Ax[ir] -= static_cast<data_t>(0.5) * x[id - _coordDiff[currDim]];
     143        9934 :                         break;
     144        9934 :                 }
     145        9934 :             }
     146        9934 :         }
     147         112 :     }
     148             : 
     149             :     template <typename data_t>
     150             :     template <typename FDtype>
     151             :     void FiniteDifferences<data_t>::applyAdjointHelper(const DataContainer<data_t>& y,
     152             :                                                        DataContainer<data_t>& Aty,
     153             :                                                        FDtype type) const
     154           6 :     {
     155           6 :         index_t sizeOfDomain = this->getDomainDescriptor().getNumberOfCoefficients();
     156           6 :         index_t numDim = this->getDomainDescriptor().getNumberOfDimensions();
     157             : 
     158           6 :         IndexVector_t numberOfCoefficients =
     159           6 :             this->getDomainDescriptor().getNumberOfCoefficientsPerDimension();
     160           6 :         IndexVector_t decrementedCoefficients = numberOfCoefficients - IndexVector_t::Ones(numDim);
     161             : 
     162           6 : #pragma omp parallel
     163         612 :         for (long currDim = 0; currDim < numDim; ++currDim) {
     164         306 :             if (!_activeDims[currDim])
     165           0 :                 continue;
     166             : 
     167         306 :             index_t modulus = numberOfCoefficients.head(currDim + 1).prod();
     168         306 :             index_t divisor = numberOfCoefficients.head(currDim).prod();
     169             : 
     170         306 : #pragma omp for nowait
     171         306 :             for (index_t id = 0; id < sizeOfDomain; ++id) {
     172         306 :                 index_t icCount = (id % modulus) / divisor;
     173         306 :                 index_t ir = id + _dimCounter[currDim] * _coordDelta[currDim];
     174             : 
     175         306 :                 switch (type) {
     176         306 :                     case DiffType::FORWARD:
     177         306 :                         if (icCount > 0)
     178         306 :                             Aty[id] += y[ir - _coordDiff[currDim]];
     179             : 
     180         306 :                         Aty[id] -= y[ir];
     181         306 :                         break;
     182         306 :                     case DiffType::BACKWARD:
     183         306 :                         if (icCount < decrementedCoefficients(currDim))
     184         306 :                             Aty[id] -= y[ir + _coordDiff[currDim]];
     185             : 
     186         306 :                         Aty[id] += y[ir];
     187         306 :                         break;
     188         306 :                     case DiffType::CENTRAL:
     189         306 :                         if (icCount > 0)
     190         306 :                             Aty[ir] += static_cast<data_t>(0.5) * y[ir - _coordDiff[currDim]];
     191             : 
     192         306 :                         if (icCount < decrementedCoefficients(currDim))
     193         306 :                             Aty[ir] -= static_cast<data_t>(0.5) * y[ir + _coordDiff[currDim]];
     194         306 :                         break;
     195         306 :                 }
     196         306 :             }
     197         306 :         }
     198           6 :     }
     199             : 
     200             :     template <typename data_t>
     201             :     FiniteDifferences<data_t>* FiniteDifferences<data_t>::cloneImpl() const
     202           2 :     {
     203           2 :         return new FiniteDifferences(this->getDomainDescriptor(), _activeDims, _type);
     204           2 :     }
     205             : 
     206             :     template <typename data_t>
     207             :     bool FiniteDifferences<data_t>::isEqual(const LinearOperator<data_t>& other) const
     208           2 :     {
     209           2 :         if (!LinearOperator<data_t>::isEqual(other))
     210           0 :             return false;
     211             : 
     212           2 :         auto otherFD = downcast_safe<FiniteDifferences>(&other);
     213           2 :         if (!otherFD)
     214           0 :             return false;
     215             : 
     216           2 :         if (_type != otherFD->_type || _activeDims != otherFD->_activeDims)
     217           0 :             return false;
     218             : 
     219           2 :         return true;
     220           2 :     }
     221             : 
     222             :     // ------------------------------------------
     223             :     // explicit template instantiation
     224             :     template class FiniteDifferences<float>;
     225             :     template class FiniteDifferences<double>;
     226             :     template class FiniteDifferences<complex<float>>;
     227             :     template class FiniteDifferences<complex<double>>;
     228             : 
     229             : } // namespace elsa

Generated by: LCOV version 1.14