LCOV - code coverage report
Current view: top level - elsa/operators - SymmetrizedDerivative.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 59 62 95.2 %
Date: 2024-05-16 04:22:26 Functions: 14 14 100.0 %

          Line data    Source code
       1             : #include "SymmetrizedDerivative.h"
       2             : #include "Timer.h"
       3             : #include "Identity.h"
       4             : #include "FiniteDifferences.h"
       5             : #include "BlockLinearOperator.h"
       6             : #include "DataContainer.h"
       7             : 
       8             : namespace elsa
       9             : {
      10             : 
      11             :     using namespace std;
      12             : 
      13             :     template <typename data_t>
      14             :     SymmetrizedDerivative<data_t>::SymmetrizedDerivative(const DataDescriptor& domainDescriptor)
      15             :         : LinearOperator<data_t>(createBase(domainDescriptor))
      16          18 :     {
      17          18 :         precomputeHelpers();
      18          18 :     }
      19             : 
      20             :     using namespace std;
      21             : 
      22             :     template <typename data_t>
      23             :     void SymmetrizedDerivative<data_t>::applyImpl(const DataContainer<data_t>& x,
      24             :                                                   DataContainer<data_t>& Ax) const
      25           2 :     {
      26           2 :         Timer timeguard("SymmetrizedDerivative", "apply");
      27             : 
      28             :         // D_1 x_1
      29           2 :         Ax.getBlock(0) = forwardX_->apply(x.getBlock(0));
      30             : 
      31             :         // D_2 x_2
      32           2 :         Ax.getBlock(1) = forwardY_->apply(x.getBlock(1));
      33             : 
      34             :         // 0.5 (D_2 x_1 + D_1 x_2)
      35           2 :         Ax.getBlock(2) = lincomb(data_t(0.5), forwardY_->apply(x.getBlock(0)), data_t(0.5),
      36           2 :                                  forwardX_->apply(x.getBlock(1)));
      37           2 :     }
      38             : 
      39             :     template <typename data_t>
      40             :     void SymmetrizedDerivative<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
      41             :                                                          DataContainer<data_t>& Aty) const
      42           2 :     {
      43           2 :         Timer timeguard("SymmetrizedDerivative", "applyAdjoint");
      44             : 
      45             :         //-D_1^{-} y_1 - D_2^{-} y_3
      46           2 :         Aty.getBlock(0) = lincomb(data_t(-1), backwardX_->apply(y.getBlock(0)), data_t(-1),
      47           2 :                                   backwardY_->apply(y.getBlock(2)));
      48             : 
      49             :         //-D_1^{-} y_3 - D_2^{-} y_2
      50           2 :         Aty.getBlock(1) = lincomb(data_t(-1), backwardX_->apply(y.getBlock(2)), data_t(-1),
      51           2 :                                   backwardY_->apply(y.getBlock(1)));
      52           2 :     }
      53             : 
      54             :     template <typename data_t>
      55             :     void SymmetrizedDerivative<data_t>::precomputeHelpers()
      56          10 :     {
      57             : 
      58          10 :         auto& domainBlocked = downcast_safe<BlockDescriptor>(this->getDomainDescriptor());
      59             : 
      60          10 :         BooleanVector_t bv1(2);
      61          10 :         bv1 << true, false;
      62          10 :         BooleanVector_t bv2(2);
      63          10 :         bv2 << false, true;
      64             : 
      65          10 :         forwardX_ =
      66          10 :             make_unique<FiniteDifferences<data_t>>(domainBlocked.getDescriptorOfBlock(0), bv1);
      67          10 :         forwardY_ =
      68          10 :             make_unique<FiniteDifferences<data_t>>(domainBlocked.getDescriptorOfBlock(0), bv2);
      69          10 :         backwardX_ =
      70          10 :             make_unique<FiniteDifferences<data_t>>(domainBlocked.getDescriptorOfBlock(0), bv1,
      71          10 :                                                    FiniteDifferences<data_t>::DiffType::BACKWARD);
      72          10 :         backwardY_ =
      73          10 :             make_unique<FiniteDifferences<data_t>>(domainBlocked.getDescriptorOfBlock(0), bv2,
      74          10 :                                                    FiniteDifferences<data_t>::DiffType::BACKWARD);
      75          10 :     }
      76             : 
      77             :     template <typename data_t>
      78             :     LinearOperator<data_t>
      79             :         SymmetrizedDerivative<data_t>::createBase(const DataDescriptor& domainDescriptor)
      80          18 :     {
      81             : 
      82          18 :         if (!is<IdenticalBlocksDescriptor>(domainDescriptor))
      83           4 :             throw LogicError("SymmetrizedDerivative: cannot construct. Domain should be "
      84           4 :                              "IdenticalBlocksDescriptor");
      85             : 
      86          14 :         auto& domainBlocked = downcast_safe<BlockDescriptor>(domainDescriptor);
      87          14 :         if (domainBlocked.getNumberOfBlocks() != 2)
      88           2 :             throw LogicError(
      89           2 :                 "SymmetrizedDerivative: cannot construct. Domain should have 2 blocks");
      90             : 
      91          12 :         if (domainBlocked.getDescriptorOfBlock(0).getNumberOfDimensions() != 2)
      92           2 :             throw LogicError("SymmetrizedDerivative: cannot construct. Domain blocks should be "
      93           2 :                              "identical and 2-dimensional");
      94             : 
      95          10 :         return LinearOperator<data_t>(
      96          10 :             domainBlocked, IdenticalBlocksDescriptor{3, domainBlocked.getDescriptorOfBlock(0)});
      97          10 :     }
      98             : 
      99             :     template <typename data_t>
     100             :     SymmetrizedDerivative<data_t>* SymmetrizedDerivative<data_t>::cloneImpl() const
     101           2 :     {
     102           2 :         return new SymmetrizedDerivative(this->getDomainDescriptor());
     103           2 :     }
     104             : 
     105             :     template <typename data_t>
     106             :     bool SymmetrizedDerivative<data_t>::isEqual(const LinearOperator<data_t>& other) const
     107           2 :     {
     108           2 :         if (!LinearOperator<data_t>::isEqual(other))
     109           0 :             return false;
     110             : 
     111           2 :         auto otherSD = downcast_safe<SymmetrizedDerivative>(&other);
     112           2 :         if (!otherSD)
     113           0 :             return false;
     114             : 
     115           2 :         if (otherSD->getDomainDescriptor() != this->getDomainDescriptor()
     116           2 :             || otherSD->getRangeDescriptor() != this->getRangeDescriptor())
     117           0 :             return false;
     118             : 
     119           2 :         return true;
     120           2 :     }
     121             : 
    // ------------------------------------------
    // explicit template instantiation
    // The member definitions above live in this .cpp, so the float and double specializations
    // are instantiated here. Presumably the header carries declarations only — confirm; if so,
    // any other element type would need its own explicit instantiation here to link.
    template class SymmetrizedDerivative<float>;
    template class SymmetrizedDerivative<double>;
     126             : 
     127             : } // namespace elsa

Generated by: LCOV version 1.14