Line data Source code
1 : #pragma once 2 : 3 : #include "DataDescriptor.h" 4 : #include "Functional.h" 5 : #include "DataContainer.h" 6 : #include "LinearOperator.h" 7 : #include "Scaling.h" 8 : 9 : namespace elsa 10 : { 11 : /** 12 : * @brief The least squares functional / loss functional. 13 : * 14 : * The least squares loss is given by: 15 : * \[ 16 : * \frac{1}{2} || A(x) - b ||_2^2 17 : * \] 18 : * i.e. the squared \f$\ell^2\f$ of the linear residual. 19 : * 20 : * @tparam data_t data type for the domain of the residual of the functional, defaulting to 21 : * real_t 22 : */ 23 : template <typename data_t = real_t> 24 : class WeightedLeastSquares : public Functional<data_t> 25 : { 26 : public: 27 : /** 28 : * @brief Constructor the l2 norm (squared) functional with a LinearResidual 29 : * 30 : * @param[in] A LinearOperator to use in the residual 31 : * @param[in] b data to use in the linear residual 32 : */ 33 : WeightedLeastSquares(const LinearOperator<data_t>& A, const DataContainer<data_t>& b, 34 : const DataContainer<data_t>& weights); 35 : 36 : /// make copy constructor deletion explicit 37 : WeightedLeastSquares(const WeightedLeastSquares<data_t>&) = delete; 38 : 39 : /// default destructor 40 36 : ~WeightedLeastSquares() override = default; 41 : 42 : bool isDifferentiable() const override; 43 : 44 : const LinearOperator<data_t>& getOperator() const; 45 : 46 : const DataContainer<data_t>& getDataVector() const; 47 : 48 : protected: 49 : /// the evaluation of the l2 norm (squared) 50 : data_t evaluateImpl(const DataContainer<data_t>& Rx) const override; 51 : 52 : /// the computation of the gradient (in place) 53 : void getGradientImpl(const DataContainer<data_t>& Rx, 54 : DataContainer<data_t>& out) const override; 55 : 56 : /// the computation of the Hessian 57 : LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override; 58 : 59 : /// implement the polymorphic clone operation 60 : WeightedLeastSquares<data_t>* cloneImpl() const override; 61 : 62 : /// implement the polymorphic comparison operation 63 : bool isEqual(const Functional<data_t>& other) const override; 64 : 65 : private: 66 : std::unique_ptr<LinearOperator<data_t>> A_{}; 67 : 68 : DataContainer<data_t> b_{}; 69 : 70 : Scaling<data_t> W_; 71 : }; 72 : 73 : } // namespace elsa