LCOV - code coverage report
Current view: top level - elsa/functionals - WeightedLeastSquares.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 1 1 100.0 %
Date: 2024-05-16 04:22:26 Functions: 2 2 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "DataDescriptor.h"
       4             : #include "Functional.h"
       5             : #include "DataContainer.h"
       6             : #include "LinearOperator.h"
       7             : #include "Scaling.h"
       8             : 
       9             : namespace elsa
      10             : {
      11             :     /**
      12             :      * @brief The weighted least squares functional / loss functional.
      13             :      *
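      14             :      * The weighted least squares loss is given by:
      15             :      * \f[
      16             :      * \frac{1}{2} || A(x) - b ||_W^2 = \frac{1}{2} \langle A(x) - b, W (A(x) - b) \rangle
      17             :      * \f]
      18             :      * i.e. the squared \f$W\f$-weighted \f$\ell^2\f$ norm of the linear residual, with \f$W = \mathrm{diag}(w)\f$.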
      19             :      *
      20             :      * @tparam data_t data type for the domain of the residual of the functional, defaulting to
      21             :      * real_t
      22             :      */
      23             :     template <typename data_t = real_t>
      24             :     class WeightedLeastSquares : public Functional<data_t>
      25             :     {
      26             :     public:
      27             :         /**
      28             :          * @brief Construct the weighted least squares functional from A, b and the weights
      29             :          * @param[in] A LinearOperator to use in the residual
      30             :          * @param[in] b data to use in the linear residual
      31             :          * @param[in] weights per-element weights \f$w\f$ defining the diagonal weighting \f$W\f$
      32             :          */
      33             :         WeightedLeastSquares(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      34             :                              const DataContainer<data_t>& weights);
      35             : 
      36             :         /// make copy constructor deletion explicit
      37             :         WeightedLeastSquares(const WeightedLeastSquares<data_t>&) = delete;
      38             : 
      39             :         /// default destructor
      40          36 :         ~WeightedLeastSquares() override = default;
      41             : 
      42             :         bool isDifferentiable() const override;
      43             : 
      44             :         const LinearOperator<data_t>& getOperator() const;
      45             : 
      46             :         const DataContainer<data_t>& getDataVector() const;
      47             : 
      48             :     protected:
      49             :         /// the evaluation of the weighted least squares loss
      50             :         data_t evaluateImpl(const DataContainer<data_t>& Rx) const override;
      51             : 
      52             :         /// the computation of the gradient (in place)
      53             :         void getGradientImpl(const DataContainer<data_t>& Rx,
      54             :                              DataContainer<data_t>& out) const override;
      55             : 
      56             :         /// the computation of the Hessian
      57             :         LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override;
      58             : 
      59             :         /// implement the polymorphic clone operation
      60             :         WeightedLeastSquares<data_t>* cloneImpl() const override;
      61             : 
      62             :         /// implement the polymorphic comparison operation
      63             :         bool isEqual(const Functional<data_t>& other) const override;
      64             : 
      65             :     private:
      66             :         std::unique_ptr<LinearOperator<data_t>> A_{};
      67             : 
      68             :         DataContainer<data_t> b_{};
      69             : 
      70             :         Scaling<data_t> W_;
      71             :     };
      72             : 
      73             : } // namespace elsa

Generated by: LCOV version 1.14