LCOV - code coverage report
Current view: top level - elsa/functionals - LeastSquares.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 32 33 97.0 %
Date: 2024-05-16 04:22:26 Functions: 18 18 100.0 %

          Line data    Source code
       1             : #include "LeastSquares.h"
       2             : 
       3             : #include "DataContainer.h"
       4             : #include "DataDescriptor.h"
       5             : #include "LinearOperator.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     LeastSquares<data_t>::LeastSquares(const LinearOperator<data_t>& A,
      12             :                                        const DataContainer<data_t>& b)
      13             :         : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), b_(b)
      14         358 :     {
      15         358 :     }
      16             : 
      17             :     template <typename data_t>
      18             :     bool LeastSquares<data_t>::isDifferentiable() const
      19          18 :     {
      20          18 :         return true;
      21          18 :     }
      22             : 
      23             :     template <typename data_t>
      24             :     const LinearOperator<data_t>& LeastSquares<data_t>::getOperator() const
      25          20 :     {
      26          20 :         return *A_;
      27          20 :     }
      28             : 
      29             :     template <typename data_t>
      30             :     const DataContainer<data_t>& LeastSquares<data_t>::getDataVector() const
      31          20 :     {
      32          20 :         return b_;
      33          20 :     }
      34             : 
      35             :     template <typename data_t>
      36             :     data_t LeastSquares<data_t>::evaluateImpl(const DataContainer<data_t>& x) const
      37        1716 :     {
      38        1716 :         auto Ax = A_->apply(x);
      39        1716 :         Ax -= b_;
      40             : 
      41        1716 :         return static_cast<data_t>(0.5) * Ax.squaredL2Norm();
      42        1716 :     }
      43             : 
      44             :     template <typename data_t>
      45             :     void LeastSquares<data_t>::getGradientImpl(const DataContainer<data_t>& x,
      46             :                                                DataContainer<data_t>& out) const
      47        3223 :     {
      48        3223 :         auto temp = A_->apply(x);
      49        3223 :         temp -= b_;
      50             : 
      51             :         // Apply chain rule
      52        3223 :         A_->applyAdjoint(temp, out);
      53        3223 :     }
      54             : 
      55             :     template <typename data_t>
      56             :     LinearOperator<data_t> LeastSquares<data_t>::getHessianImpl(const DataContainer<data_t>&) const
      57          20 :     {
      58          20 :         return leaf(adjoint(*A_) * (*A_));
      59          20 :     }
      60             : 
      61             :     template <typename data_t>
      62             :     LeastSquares<data_t>* LeastSquares<data_t>::cloneImpl() const
      63         252 :     {
      64         252 :         return new LeastSquares(*A_, b_);
      65         252 :     }
      66             : 
      67             :     template <typename data_t>
      68             :     bool LeastSquares<data_t>::isEqual(const Functional<data_t>& other) const
      69           4 :     {
      70           4 :         if (!Functional<data_t>::isEqual(other))
      71           0 :             return false;
      72             : 
      73           4 :         auto fn = downcast_safe<LeastSquares<data_t>>(&other);
      74           4 :         return fn && *A_ == *fn->A_ && b_ == fn->b_;
      75           4 :     }
      76             : 
      77             :     // ------------------------------------------
      78             :     // explicit template instantiation
      79             :     template class LeastSquares<float>;
      80             :     template class LeastSquares<double>;
      81             : } // namespace elsa

Generated by: LCOV version 1.14