LCOV - code coverage report
Current view: top level - elsa/functionals - WeightedL2NormPow2.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 31 36 86.1 %
Date: 2022-08-25 03:05:39 Functions: 30 32 93.8 %

          Line data    Source code
       1             : #include "WeightedL2NormPow2.h"
       2             : #include "LinearOperator.h"
       3             : #include "TypeCasts.hpp"
       4             : 
       5             : #include <stdexcept>
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename data_t>
      10             :     WeightedL2NormPow2<data_t>::WeightedL2NormPow2(const Scaling<data_t>& weightingOp)
      11             :         : Functional<data_t>(weightingOp.getDomainDescriptor()),
      12             :           _weightingOp{static_cast<Scaling<data_t>*>(weightingOp.clone().release())}
      13          46 :     {
      14          46 :     }
      15             : 
      16             :     template <typename data_t>
      17             :     WeightedL2NormPow2<data_t>::WeightedL2NormPow2(const Residual<data_t>& residual,
      18             :                                                    const Scaling<data_t>& weightingOp)
      19             :         : Functional<data_t>(residual),
      20             :           _weightingOp{static_cast<Scaling<data_t>*>(weightingOp.clone().release())}
      21         198 :     {
      22             :         // sanity check
      23         198 :         if (residual.getRangeDescriptor().getNumberOfCoefficients()
      24         198 :             != weightingOp.getDomainDescriptor().getNumberOfCoefficients())
      25           0 :             throw InvalidArgumentError(
      26           0 :                 "WeightedL2NormPow2: sizes of residual and weighting operator do not match");
      27         198 :     }
      28             : 
      29             :     template <typename data_t>
      30             :     const Scaling<data_t>& WeightedL2NormPow2<data_t>::getWeightingOperator() const
      31          16 :     {
      32          16 :         return *_weightingOp;
      33          16 :     }
      34             : 
      35             :     template <typename data_t>
      36             :     data_t WeightedL2NormPow2<data_t>::evaluateImpl(const DataContainer<data_t>& Rx)
      37           8 :     {
      38           8 :         auto temp = _weightingOp->apply(Rx);
      39             : 
      40           8 :         return static_cast<data_t>(0.5) * Rx.dot(temp);
      41           8 :     }
      42             : 
      43             :     template <typename data_t>
      44             :     void WeightedL2NormPow2<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx)
      45          12 :     {
      46          12 :         auto temp = _weightingOp->apply(Rx);
      47          12 :         Rx = temp;
      48          12 :     }
      49             : 
      50             :     template <typename data_t>
      51             :     LinearOperator<data_t>
      52             :         WeightedL2NormPow2<data_t>::getHessianImpl([[maybe_unused]] const DataContainer<data_t>& Rx)
      53          12 :     {
      54          12 :         return leaf(*_weightingOp);
      55          12 :     }
      56             : 
      57             :     template <typename data_t>
      58             :     WeightedL2NormPow2<data_t>* WeightedL2NormPow2<data_t>::cloneImpl() const
      59          62 :     {
      60             :         // this ugly cast has to go away at some point..
      61             :         // Still not nice, but still safe as _weightingOp is allways of type Scaling
      62          62 :         const auto& scaling = downcast<Scaling<data_t>>(*_weightingOp);
      63          62 :         return new WeightedL2NormPow2(this->getResidual(), scaling);
      64          62 :     }
      65             : 
      66             :     template <typename data_t>
      67             :     bool WeightedL2NormPow2<data_t>::isEqual(const Functional<data_t>& other) const
      68          18 :     {
      69          18 :         if (!Functional<data_t>::isEqual(other))
      70           0 :             return false;
      71             : 
      72          18 :         auto otherWL2 = downcast_safe<WeightedL2NormPow2>(&other);
      73          18 :         if (!otherWL2)
      74           0 :             return false;
      75             : 
      76          18 :         if (*_weightingOp != *otherWL2->_weightingOp)
      77           0 :             return false;
      78             : 
      79          18 :         return true;
      80          18 :     }
      81             : 
      82             :     // ------------------------------------------
      83             :     // explicit template instantiation
      84             :     template class WeightedL2NormPow2<float>;
      85             :     template class WeightedL2NormPow2<double>;
      86             :     template class WeightedL2NormPow2<complex<float>>;
      87             :     template class WeightedL2NormPow2<complex<double>>;
      88             : 
      89             : } // namespace elsa

Generated by: LCOV version 1.14