LCOV - code coverage report
Current view: top level - elsa/functionals - L2Reg.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 44 56 78.6 %
Date: 2024-05-16 04:22:26 Functions: 32 40 80.0 %

          Line data    Source code
       1             : #include "L2Reg.h"
       2             : 
       3             : #include "DataContainer.h"
       4             : #include "DataDescriptor.h"
       5             : #include "Identity.h"
       6             : #include "LinearOperator.h"
       7             : #include "TypeCasts.hpp"
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     L2Reg<data_t>::L2Reg(const DataDescriptor& domainDescriptor)
      13             :         : Functional<data_t>(domainDescriptor)
      14          24 :     {
      15          24 :     }
      16             : 
      17             :     template <typename data_t>
      18             :     L2Reg<data_t>::L2Reg(const LinearOperator<data_t>& A)
      19             :         : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone())
      20          24 :     {
      21          24 :     }
      22             : 
      23             :     template <typename data_t>
      24             :     bool L2Reg<data_t>::isDifferentiable() const
      25           0 :     {
      26           0 :         return true;
      27           0 :     }
      28             : 
      29             :     template <typename data_t>
      30             :     bool L2Reg<data_t>::hasOperator() const
      31          32 :     {
      32          32 :         return static_cast<bool>(A_);
      33          32 :     }
      34             : 
      35             :     template <typename data_t>
      36             :     const LinearOperator<data_t>& L2Reg<data_t>::getOperator() const
      37           0 :     {
      38           0 :         if (!hasOperator()) {
      39           0 :             throw Error("L2Reg: No operator present");
      40           0 :         }
      41             : 
      42           0 :         return *A_;
      43           0 :     }
      44             : 
      45             :     template <typename data_t>
      46             :     data_t L2Reg<data_t>::evaluateImpl(const DataContainer<data_t>& x) const
      47           8 :     {
      48             :         // If we have an operator, apply it
      49           8 :         if (hasOperator()) {
      50           4 :             auto tmp = DataContainer<data_t>(A_->getRangeDescriptor());
      51           4 :             A_->apply(x, tmp);
      52             : 
      53           4 :             return data_t{0.5} * tmp.squaredL2Norm();
      54           4 :         }
      55             : 
      56           4 :         return data_t{0.5} * x.squaredL2Norm();
      57           4 :     }
      58             : 
      59             :     template <typename data_t>
      60             :     void L2Reg<data_t>::getGradientImpl(const DataContainer<data_t>& x,
      61             :                                         DataContainer<data_t>& out) const
      62           8 :     {
      63           8 :         if (hasOperator()) {
      64           4 :             auto temp = A_->apply(x);
      65             : 
      66             :             // Apply chain rule
      67           4 :             A_->applyAdjoint(temp, out);
      68           4 :         } else {
      69           4 :             out = x;
      70           4 :         }
      71           8 :     }
      72             : 
      73             :     template <typename data_t>
      74             :     LinearOperator<data_t> L2Reg<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const
      75           8 :     {
      76           8 :         if (hasOperator()) {
      77           4 :             return leaf(adjoint(*A_) * (*A_));
      78           4 :         }
      79           4 :         return leaf(Identity<data_t>(Rx.getDataDescriptor()));
      80           4 :     }
      81             : 
      82             :     template <typename data_t>
      83             :     L2Reg<data_t>* L2Reg<data_t>::cloneImpl() const
      84           8 :     {
      85           8 :         if (hasOperator()) {
      86           4 :             return new L2Reg(*A_);
      87           4 :         }
      88           4 :         return new L2Reg(this->getDomainDescriptor());
      89           4 :     }
      90             : 
      91             :     template <typename data_t>
      92             :     bool L2Reg<data_t>::isEqual(const Functional<data_t>& other) const
      93           8 :     {
      94           8 :         if (!Functional<data_t>::isEqual(other))
      95           0 :             return false;
      96             : 
      97           8 :         auto fn = downcast_safe<L2Reg<data_t>>(&other);
      98           8 :         if (!fn) {
      99           0 :             return false;
     100           0 :         }
     101             : 
     102           8 :         if (A_ && fn->A_) {
     103           4 :             return *A_ == *fn->A_;
     104           4 :         }
     105             : 
     106           4 :         return A_ == fn->A_;
     107           4 :     }
     108             : 
     109             :     // ------------------------------------------
     110             :     // explicit template instantiation
     111             :     template class L2Reg<float>;
     112             :     template class L2Reg<double>;
     113             :     template class L2Reg<complex<double>>;
     114             :     template class L2Reg<complex<float>>;
     115             : } // namespace elsa

Generated by: LCOV version 1.14