LCOV - code coverage report
Current view: top level - elsa/functionals - WeightedLeastSquares.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 22 39 56.4 %
Date: 2024-05-16 04:22:26 Functions: 10 18 55.6 %

          Line data    Source code
       1             : #include "WeightedLeastSquares.h"
       2             : 
       3             : #include "DataContainer.h"
       4             : #include "DataDescriptor.h"
       5             : #include "Error.h"
       6             : #include "LinearOperator.h"
       7             : #include "TypeCasts.hpp"
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     WeightedLeastSquares<data_t>::WeightedLeastSquares(const LinearOperator<data_t>& A,
      13             :                                                        const DataContainer<data_t>& b,
      14             :                                                        const DataContainer<data_t>& weights)
      15             :         : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), b_(b), W_(weights)
      16          36 :     {
      17          36 :         if (A.getDomainDescriptor().getNumberOfCoefficientsPerDimension()
      18          36 :             != weights.getDataDescriptor().getNumberOfCoefficientsPerDimension()) {
      19           0 :             throw InvalidArgumentError("Domain of A and weights need to match");
      20           0 :         }
      21          36 :     }
      22             : 
      23             :     template <typename data_t>
      24             :     bool WeightedLeastSquares<data_t>::isDifferentiable() const
      25           6 :     {
      26           6 :         return true;
      27           6 :     }
      28             : 
      29             :     template <typename data_t>
      30             :     const LinearOperator<data_t>& WeightedLeastSquares<data_t>::getOperator() const
      31           0 :     {
      32           0 :         return *A_;
      33           0 :     }
      34             : 
      35             :     template <typename data_t>
      36             :     const DataContainer<data_t>& WeightedLeastSquares<data_t>::getDataVector() const
      37           0 :     {
      38           0 :         return b_;
      39           0 :     }
      40             : 
      41             :     template <typename data_t>
      42             :     data_t WeightedLeastSquares<data_t>::evaluateImpl(const DataContainer<data_t>& x) const
      43         104 :     {
      44             :         // Evaluate A(x) - b
      45         104 :         auto temp = A_->apply(x);
      46         104 :         temp -= b_;
      47             : 
      48             :         // evaluate weighted l2 norm
      49         104 :         W_.apply(x, temp);
      50         104 :         return static_cast<data_t>(0.5) * x.dot(temp);
      51         104 :     }
      52             : 
      53             :     template <typename data_t>
      54             :     void WeightedLeastSquares<data_t>::getGradientImpl(const DataContainer<data_t>& x,
      55             :                                                        DataContainer<data_t>& out) const
      56         108 :     {
      57             :         // Evaluate A(x) - b
      58         108 :         auto temp = A_->apply(x);
      59         108 :         temp -= b_;
      60             : 
      61         108 :         W_.apply(temp, temp);
      62             : 
      63             :         // Apply chain rule
      64         108 :         A_->applyAdjoint(temp, out);
      65         108 :     }
      66             : 
      67             :     template <typename data_t>
      68             :     LinearOperator<data_t>
      69             :         WeightedLeastSquares<data_t>::getHessianImpl(const DataContainer<data_t>&) const
      70           0 :     {
      71           0 :         return leaf(adjoint(*A_) * W_ * (*A_));
      72           0 :     }
      73             : 
      74             :     template <typename data_t>
      75             :     WeightedLeastSquares<data_t>* WeightedLeastSquares<data_t>::cloneImpl() const
      76          30 :     {
      77          30 :         return new WeightedLeastSquares(*A_, b_, W_.getScaleFactors());
      78          30 :     }
      79             : 
      80             :     template <typename data_t>
      81             :     bool WeightedLeastSquares<data_t>::isEqual(const Functional<data_t>& other) const
      82           0 :     {
      83           0 :         if (!Functional<data_t>::isEqual(other))
      84           0 :             return false;
      85             : 
      86           0 :         auto fn = downcast_safe<WeightedLeastSquares<data_t>>(&other);
      87           0 :         return fn && *A_ == *fn->A_ && b_ == fn->b_ && W_ == fn->W_;
      88           0 :     }
      89             : 
      90             :     // ------------------------------------------
      91             :     // explicit template instantiation
      92             :     template class WeightedLeastSquares<float>;
      93             :     template class WeightedLeastSquares<double>;
      94             : } // namespace elsa

Generated by: LCOV version 1.14