LCOV - code coverage report
Current view: top level - elsa/functionals - Functional.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 12 28 42.9 %
Date: 2024-05-16 04:22:26 Functions: 24 40 60.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "Cloneable.h"
       4             : #include "DataContainer.h"
       5             : #include "DataDescriptor.h"
       6             : #include "Error.h"
       7             : #include "LinearOperator.h"
       8             : #include "TypeCasts.hpp"
       9             : #include "elsaDefines.h"
      10             : 
      11             : namespace elsa
      12             : {
      13             :     /**
      14             :      * @brief Abstract base class representing a functional, i.e. a mapping from vectors to scalars.
      15             :      *
      16             :      * A functional maps a vector to a scalar value (e.g. mapping the output of a Residual
      17             :      * to a scalar). Typical examples of functionals are norms or semi-norms, such as the L2 or L1
      18             :      * norms.
      19             :      *
      20             :      * Using LinearOperators, Residuals (e.g. LinearResidual) and a Functional (e.g. LeastSquares)
      21             :      * enables the formulation of typical terms in an OptimizationProblem.
      22             :      *
      23             :      * @tparam data_t data type for the domain of the residual of the functional, defaulting to
      24             :      * real_t
      25             :      *
      26             :      * @author
      27             :      * * Matthias Wieczorek - initial code
      28             :      * * Maximilian Hornung - modularization
      29             :      * * Tobias Lasser - rewrite
      30             :      *
      31             :      */
      32             :     template <typename data_t = real_t>
      33             :     class Functional : public Cloneable<Functional<data_t>>
      34             :     {
      35             :     public:
      36             :         /**
      37             :          * @brief Constructor for the functional, mapping a domain vector to a scalar (without a
      38             :          * residual)
      39             :          *
      40             :          * @param[in] domainDescriptor describing the domain of the functional
      41             :          */
      42             :         explicit Functional(const DataDescriptor& domainDescriptor);
      43             : 
      44             :         /// default destructor
      45        1480 :         ~Functional() override = default;
      46             : 
      47             :         /// return the domain descriptor
      48             :         const DataDescriptor& getDomainDescriptor() const;
      49             : 
      50             :         /**
      51             :          * @brief Indicate if a functional is differentiable. The default implementation returns
      52             :          * `false`. Functionals which are at least once differentiable should override this
      53             :          * function.
      54             :          */
      55             :         virtual bool isDifferentiable() const;
      56             : 
      57             :         /// @brief Indicate if the functional has a simple to compute proximal
      58             :         virtual bool isProxFriendly() const;
      59             : 
      60             :         /// @brief Indicate if the functional can compute the proximal of the dual
      61             :         virtual bool hasProxDual() const;
      62             : 
      63             :         /**
      64             :          * @brief evaluate the functional at x and return the result
      65             :          *
      66             :          * @param[in] x input DataContainer (in the domain of the functional)
      67             :          *
      68             :          * @returns result the scalar of the functional evaluated at x
      69             :          *
      70             :          * Please note: after evaluating the residual at x, this method calls the method
      71             :          * evaluateImpl that has to be overridden in derived classes to compute the functional's
      72             :          * value.
      73             :          */
      74             :         data_t evaluate(const DataContainer<data_t>& x) const;
      75             : 
      76             :         /**
      77             :          * @brief compute the gradient of the functional at x and return the result
      78             :          *
      79             :          * @param[in] x input DataContainer (in the domain of the functional)
      80             :          *
      81             :          * @returns result DataContainer (in the domain of the functional) containing the result of
      82             :          * the gradient at x.
      83             :          *
      84             :          * Please note: this method uses getGradient(x, result) to perform the actual operation.
      85             :          */
      86             :         DataContainer<data_t> getGradient(const DataContainer<data_t>& x) const;
      87             : 
      88             :         /**
      89             :          * @brief Compute the convex conjugate of the functional
      90             :          *
      91             :          * @param[in] x input DataContainer (in the domain of the functional)
      92             :          */
      93             :         virtual data_t convexConjugate(const DataContainer<data_t>& x) const;
      94             : 
      95             :         /**
      96             :          * @brief compute the gradient of the functional at x and store in result
      97             :          *
      98             :          * @param[in] x input DataContainer (in the domain of the functional)
      99             :          * @param[out] result output DataContainer (in the domain of the functional)
     100             :          */
     101             :         void getGradient(const DataContainer<data_t>& x, DataContainer<data_t>& result) const;
     102             : 
     103             :         /**
     104             :          * @brief return the Hessian of the functional at x
     105             :          *
     106             :          * @param[in] x input DataContainer (in the domain of the functional)
     107             :          *
     108             :          * @returns a LinearOperator (the Hessian)
     109             :          *
     110             :          * Note: some derived classes might decide to use only the diagonal of the Hessian as a fast
     111             :          * approximation!
     112             :          *
     113             :          * Please note: after evaluating the residual at x, this method calls the method
     114             :          * getHessianImpl that has to be overridden in derived classes to compute the functional's
     115             :          * Hessian, and after that the chain rule for the residual is applied (if necessary).
     116             :          */
     117             :         LinearOperator<data_t> getHessian(const DataContainer<data_t>& x) const;
     118             : 
     119             :         /**
     120             :          * @brief compute the proximal of the given functional
     121             :          *
     122             :          * @param[in] v input DataContainer (in the domain of the functional)
     123             :          * @param[in] tau threshold/scaling parameter for proximal
     124             :          */
     125             :         virtual DataContainer<data_t> proximal(const DataContainer<data_t>& v,
     126             :                                                SelfType_t<data_t> tau) const;
     127             : 
     128             :         /**
     129             :          * @brief compute the proximal of the given functional and write the result to the output
     130             :          * DataContainer
     131             :          *
     132             :          * @param[in] v input DataContainer (in the domain of the functional)
     133             :          * @param[in] t threshold/scaling parameter for proximal
     134             :          * @param[out] out output DataContainer (in the domain of the functional)
     135             :          */
     136             :         virtual void proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t,
     137             :                               DataContainer<data_t>& out) const;
     138             : 
     139             :         /// @brief compute the proximal of the convex conjugate of the functional.
     140             :         /// This method can either be overridden, or by default it computes the
     141             :         /// proximal of the convex conjugate using Moreau’s identity. It is
     142             :         /// given as:
     143             :         /// @f[
     144             :         /// \operatorname{prox}_{\tau f^*}(x) = x - \tau \operatorname{prox}_{\tau^{-1}f}(\tau^{-1}
     145             :         /// x)
     146             :         /// @f]
     147             :         virtual DataContainer<data_t> proxdual(const DataContainer<data_t>& x,
     148             :                                                SelfType_t<data_t> tau) const;
     149             : 
     150             :         /// @brief compute the proximal of the convex conjugate of the functional
     151             :         virtual void proxdual(const DataContainer<data_t>& x, SelfType_t<data_t> tau,
     152             :                               DataContainer<data_t>& out) const;
     153             : 
     154             :     protected:
     155             :         /// the data descriptor of the domain of the functional
     156             :         std::unique_ptr<DataDescriptor> _domainDescriptor;
     157             : 
     158             :         /// implement the polymorphic comparison operation
     159             :         bool isEqual(const Functional<data_t>& other) const override;
     160             : 
     161             :         /**
     162             :          * @brief the evaluateImpl method that has to be overridden in derived classes
     163             :          *
     164             :          * @param[in] Rx the residual evaluated at x
     165             :          *
     166             :          * @returns the evaluated functional
     167             :          *
     168             :          * Please note: the evaluation of the residual is already performed in evaluate, so this
     169             :          * method only has to compute the functional's value itself.
     170             :          */
     171             :         virtual data_t evaluateImpl(const DataContainer<data_t>& Rx) const = 0;
     172             : 
     173             :         /**
     174             :          * @brief the getGradientImpl method that has to be overridden in derived classes
     175             :          *
     176             :          * @param[in] Rx the value to evaluate the gradient of the functional
     177             :          * @param[in,out] out the evaluated gradient of the functional
     178             :          *
     179             :          * Please note: the evaluation of the residual is already performed in getGradient, as well
     180             :          * as the application of the chain rule. This method here only has to compute the gradient
     181             :          * of the functional itself, in an in-place manner (to avoid unnecessary DataContainers).
     182             :          */
     183             :         virtual void getGradientImpl(const DataContainer<data_t>& Rx,
     184             :                                      DataContainer<data_t>& out) const = 0;
     185             : 
     186             :         /**
     187             :          * @brief the getHessianImpl method that has to be overridden in derived classes
     188             :          *
     189             :          * @param[in] Rx the residual evaluated at x
     190             :          *
     191             :          * @returns the LinearOperator representing the Hessian of the functional
     192             :          *
     193             :          * Please note: the evaluation of the residual is already performed in getHessian, as well
     194             :          * as the application of the chain rule. This method here only has to compute the Hessian of
     195             :          * the functional itself.
     196             :          */
     197             :         virtual LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const = 0;
     198             :     };
     199             : 
     200             :     /**
     201             :      * @brief Class representing a sum of two functionals
     202             :      * \f[
     203             :      * f(x) = h(x) + g(x)
     204             :      * \f]
     205             :      *
     206             :      * The gradient at \f$x\f$ is given as:
     207             :      * \f[
     208             :      * \nabla f(x) = \nabla h(x) + \nabla g(x)
     209             :      * \f]
     210             :      *
     211             :      * and finally the hessian is given by:
     212             :      * \f[
     213             :      * \nabla^2 f(x) = \nabla^2 h(x) + \nabla^2 g(x)
     214             :      * \f]
     215             :      *
     216             :      * The gradient and hessian are only valid if the functional is (twice)
     217             :      * differentiable. The `operator+` is overloaded for functionals, to conveniently
     218             :      * create this class. It should not be necessary to create it explicitly.
     219             :      */
     220             :     template <class data_t>
     221             :     class FunctionalSum final : public Functional<data_t>
     222             :     {
     223             :     public:
     224             :         /// Construct from two functionals
     225             :         FunctionalSum(const Functional<data_t>& lhs, const Functional<data_t>& rhs);
     226             : 
     227             :         /// Make deletion of copy constructor explicit
     228             :         FunctionalSum(const FunctionalSum<data_t>&) = delete;
     229             : 
     230             :         /// Move constructor, taking over the stored left and right hand side functionals
     231             :         FunctionalSum(FunctionalSum<data_t>&& other)
     232             :             : Functional<data_t>(other.getDomainDescriptor()),
     233             :               lhs_(std::move(other.lhs_)),
     234             :               rhs_(std::move(other.rhs_))
     235           0 :         {
     236           0 :         }
     237             : 
     238             :         /// Make deletion of copy assignment explicit
     239             :         FunctionalSum& operator=(const FunctionalSum<data_t>&) = delete;
     240             : 
     241             :         /// Move assignment, taking over the domain descriptor and both stored functionals
     242             :         FunctionalSum& operator=(FunctionalSum<data_t>&& other) noexcept
     243           0 :         {
     244           0 :             this->_domainDescriptor = std::move(other._domainDescriptor);
     245           0 :             lhs_ = std::move(other.lhs_);
     246           0 :             rhs_ = std::move(other.rhs_);
     247             : 
     248           0 :             return *this;
     249           0 :         }
     250             : 
     251             :         /// Default destructor
     252          88 :         ~FunctionalSum() override = default;
     253             : 
     254             :     private:
     255             :         /// evaluate the functional as \f$g(x) + h(x)\f$
     256             :         data_t evaluateImpl(const DataContainer<data_t>& Rx) const override;
     257             : 
     258             :         /// evaluate the gradient as: \f$\nabla g(x) + \nabla h(x)\f$
     259             :         void getGradientImpl(const DataContainer<data_t>& Rx,
     260             :                              DataContainer<data_t>& out) const override;
     261             : 
     262             :         /// construct the hessian as: \f$\nabla^2 g(x) + \nabla^2 h(x)\f$
     263             :         LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override;
     264             : 
     265             :         /// Implement polymorphic clone
     266             :         FunctionalSum<data_t>* cloneImpl() const override;
     267             : 
     268             :         /// Implement polymorphic equality
     269             :         bool isEqual(const Functional<data_t>& other) const override;
     270             : 
     271             :         /// Store the left hand side functional
     272             :         std::unique_ptr<Functional<data_t>> lhs_{};
     273             : 
     274             :         /// Store the right hand side functional
     275             :         std::unique_ptr<Functional<data_t>> rhs_{};
     276             :     };
     277             : 
     278             :     /**
     279             :      * @brief Class representing a functional with a scalar multiplication:
     280             :      * \f[
     281             :      * f(x) = \lambda * g(x)
     282             :      * \f]
     283             :      *
     284             :      * The gradient at \f$x\f$ is given as:
     285             :      * \f[
     286             :      * \nabla f(x) = \lambda \nabla g(x)
     287             :      * \f]
     288             :      *
     289             :      * and finally the hessian is given by:
     290             :      * \f[
     291             :      * \nabla^2 f(x) = \lambda \nabla^2 g(x)
     292             :      * \f]
     293             :      *
     294             :      * The gradient and hessian are only valid if the functional is differentiable.
     295             :      * The `operator*` is overloaded for scalar values with functionals, to
     296             :      * conveniently create this class. It should not be necessary to create it
     297             :      * explicitly.
     298             :      */
     299             :     template <class data_t>
     300             :     class FunctionalScalarMul final : public Functional<data_t>
     301             :     {
     302             :     public:
     303             :         /// Construct functional from other functional and scalar
     304             :         FunctionalScalarMul(const Functional<data_t>& fn, SelfType_t<data_t> scalar);
     305             : 
     306             :         /// Make deletion of copy constructor explicit
     307             :         FunctionalScalarMul(const FunctionalScalarMul<data_t>&) = delete;
     308             : 
     309             :         /// Implement the move constructor
     310             :         FunctionalScalarMul(FunctionalScalarMul<data_t>&& other)
     311             :             : Functional<data_t>(other.getDomainDescriptor()),
     312             :               fn_(std::move(other.fn_)),
     313             :               scalar_(std::move(other.scalar_))
     314           0 :         {
     315           0 :         }
     316             : 
     317             :         /// Make deletion of copy assignment explicit
     318             :         FunctionalScalarMul& operator=(const FunctionalScalarMul<data_t>&) = delete;
     319             : 
     320             :         /// Implement the move assignment operator
     321             :         FunctionalScalarMul& operator=(FunctionalScalarMul<data_t>&& other) noexcept
     322           0 :         {
     323           0 :             this->_domainDescriptor = std::move(other._domainDescriptor);
     324           0 :             fn_ = std::move(other.fn_);
     325           0 :             scalar_ = std::move(other.scalar_);
     326             : 
     327           0 :             return *this;
     328           0 :         }
     329             : 
     330             :         /// Default destructor
     331          40 :         ~FunctionalScalarMul() override = default;
     332             : 
     333             :         bool isProxFriendly() const override;
     334             : 
     335             :         /**
     336             :          * @brief The convex conjugate of a scaled function @f$f(x) = \lambda g(x)@f$ is given as:
     337             :          * @f[
     338             :          * f^*(x) = \lambda g^*(\frac{x}{\lambda})
     339             :          * @f]
     340             :          */
     341             :         data_t convexConjugate(const DataContainer<data_t>& x) const override;
     342             : 
     343             :         DataContainer<data_t> proximal(const DataContainer<data_t>& v,
     344             :                                        SelfType_t<data_t> t) const override;
     345             : 
     346             :         void proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t,
     347             :                       DataContainer<data_t>& out) const override;
     348             : 
     349             :     private:
     350             :         /// Evaluate as \f$\lambda * g(x)\f$
     351             :         data_t evaluateImpl(const DataContainer<data_t>& Rx) const override;
     352             : 
     353             :         /// Evaluate gradient as: \f$\lambda * \nabla g(x)\f$
     354             :         void getGradientImpl(const DataContainer<data_t>& Rx,
     355             :                              DataContainer<data_t>& out) const override;
     356             : 
     357             :         /// Construct hessian as: \f$\lambda * \nabla^2 g(x)\f$
     358             :         LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override;
     359             : 
     360             :         /// Implementation of polymorphic clone
     361             :         FunctionalScalarMul<data_t>* cloneImpl() const override;
     362             : 
     363             :         /// Implementation of polymorphic equality
     364             :         bool isEqual(const Functional<data_t>& other) const override;
     365             : 
     366             :         /// Store other functional \f$g\f$
     367             :         std::unique_ptr<Functional<data_t>> fn_{};
     368             : 
     369             :         /// The scalar
     370             :         data_t scalar_;
     371             :     };
     372             : 
     373             :     template <class data_t>
     374             :     FunctionalScalarMul<data_t> operator*(SelfType_t<data_t> s, const Functional<data_t>& f)
     375           8 :     {
     376             :         // TODO: consider returning the ZeroFunctional, if s == 0, but then
     377             :         // it's necessary to return unique_ptr and I hate that
     378           8 :         return FunctionalScalarMul<data_t>(f, s);
     379           8 :     }
     380             : 
     381             :     template <class data_t>
     382             :     FunctionalScalarMul<data_t> operator*(const Functional<data_t>& f, SelfType_t<data_t> s)
     383           4 :     {
     384             :         // TODO: consider returning the ZeroFunctional, if s == 0, but then
     385             :         // it's necessary to return unique_ptr and I hate that
     386           4 :         return FunctionalScalarMul<data_t>(f, s);
     387           4 :     }
     388             : 
     389             :     template <class data_t>
     390             :     FunctionalSum<data_t> operator+(const Functional<data_t>& lhs, const Functional<data_t>& rhs)
     391          24 :     {
     392          24 :         return FunctionalSum<data_t>(lhs, rhs);
     393          24 :     }
     394             : 
     395             : } // namespace elsa

Generated by: LCOV version 1.14