Line data Source code
1 : #pragma once 2 : 3 : #include "Cloneable.h" 4 : #include "DataDescriptor.h" 5 : #include "Residual.h" 6 : #include "LinearOperator.h" 7 : 8 : namespace elsa 9 : { 10 : /** 11 : * @brief Abstract base class representing a functional, i.e. a mapping from vectors to scalars. 12 : * 13 : * @author Matthias Wieczorek - initial code 14 : * @author Maximilian Hornung - modularization 15 : * @author Tobias Lasser - rewrite 16 : * 17 : * @tparam data_t data type for the domain of the residual of the functional, defaulting to 18 : * real_t 19 : * 20 : * A functional is a mapping a vector to a scalar value (e.g. mapping the output of a Residual 21 : * to a scalar). Typical examples of functionals are norms or semi-norms, such as the L2 or L1 22 : * norms. 23 : * 24 : * Using LinearOperators, Residuals (e.g. LinearResidual) and a Functional (e.g. L2NormPow2) 25 : * enables the formulation of typical terms in an OptimizationProblem. 26 : */ 27 : template <typename data_t = real_t> 28 : class Functional : public Cloneable<Functional<data_t>> 29 : { 30 : public: 31 : /** 32 : * @brief Constructor for the functional, mapping a domain vector to a scalar (without a 33 : * residual) 34 : * 35 : * @param[in] domainDescriptor describing the domain of the functional 36 : */ 37 : explicit Functional(const DataDescriptor& domainDescriptor); 38 : 39 : /** 40 : * @brief Constructor for the functional, using a Residual as input to map to a scalar 41 : * 42 : * @param[in] residual to be used when evaluating the functional (or its derivatives) 43 : */ 44 : explicit Functional(const Residual<data_t>& residual); 45 : 46 : /// default destructor 47 2700 : ~Functional() override = default; 48 : 49 : /// return the domain descriptor 50 : const DataDescriptor& getDomainDescriptor() const; 51 : 52 : /// return the residual (will be trivial if Functional was constructed without one) 53 : const Residual<data_t>& getResidual() const; 54 : 55 : /** 56 : * @brief evaluate the functional at x and return the result 57 : * 58 : * @param[in] x input DataContainer (in the domain of the functional) 59 : * 60 : * @returns result the scalar of the functional evaluated at x 61 : * 62 : * Please note: after evaluating the residual at x, this method calls the method 63 : * evaluateImpl that has to be overridden in derived classes to compute the functional's 64 : * value. 65 : */ 66 : data_t evaluate(const DataContainer<data_t>& x); 67 : 68 : /** 69 : * @brief compute the gradient of the functional at x and return the result 70 : * 71 : * @param[in] x input DataContainer (in the domain of the functional) 72 : * 73 : * @returns result DataContainer (in the domain of the functional) containing the result of 74 : * the gradient at x. 75 : * 76 : * Please note: this method uses getGradient(x, result) to perform the actual operation. 77 : */ 78 : DataContainer<data_t> getGradient(const DataContainer<data_t>& x); 79 : 80 : /** 81 : * @brief compute the gradient of the functional at x and store in result 82 : * 83 : * @param[in] x input DataContainer (in the domain of the functional) 84 : * @param[out] result output DataContainer (in the domain of the functional) 85 : * 86 : * Please note: after evaluating the residual at x, this methods calls the method 87 : * getGradientInPlaceImpl that has to be overridden in derived classes to compute the 88 : * functional's gradient, and after that the chain rule for the residual is applied (if 89 : * necessary). 90 : */ 91 : void getGradient(const DataContainer<data_t>& x, DataContainer<data_t>& result); 92 : 93 : /** 94 : * @brief return the Hessian of the functional at x 95 : * 96 : * @param[in] x input DataContainer (in the domain of the functional) 97 : * 98 : * @returns a LinearOperator (the Hessian) 99 : * 100 : * Note: some derived classes might decide to use only the diagonal of the Hessian as a fast 101 : * approximation! 102 : * 103 : * Please note: after evaluating the residual at x, this method calls the method 104 : * getHessianImpl that has to be overridden in derived classes to compute the functional's 105 : * Hessian, and after that the chain rule for the residual is applied (if necessary). 106 : */ 107 : LinearOperator<data_t> getHessian(const DataContainer<data_t>& x); 108 : 109 : protected: 110 : /// the data descriptor of the domain of the functional 111 : std::unique_ptr<DataDescriptor> _domainDescriptor; 112 : 113 : /// the residual 114 : std::unique_ptr<Residual<data_t>> _residual; 115 : 116 : /// implement the polymorphic comparison operation 117 : bool isEqual(const Functional<data_t>& other) const override; 118 : 119 : /** 120 : * @brief the evaluateImpl method that has to be overridden in derived classes 121 : * 122 : * @param[in] Rx the residual evaluated at x 123 : * 124 : * @returns the evaluated functional 125 : * 126 : * Please note: the evaluation of the residual is already performed in evaluate, so this 127 : * method only has to compute the functional's value itself. 128 : */ 129 : virtual data_t evaluateImpl(const DataContainer<data_t>& Rx) = 0; 130 : 131 : /** 132 : * @brief the getGradientInPlaceImpl method that has to be overridden in derived classes 133 : * 134 : * @param[in,out] Rx the residual evaluated at x (in), and the gradient of the functional 135 : * (out) 136 : * 137 : * Please note: the evaluation of the residual is already performed in getGradient, as well 138 : * as the application of the chain rule. This method here only has to compute the gradient 139 : * of the functional itself, in an in-place manner (to avoid unnecessary DataContainers). 140 : */ 141 : virtual void getGradientInPlaceImpl(DataContainer<data_t>& Rx) = 0; 142 : 143 : /** 144 : * @brief the getHessianImpl method that has to be overridden in derived classes 145 : * 146 : * @param[in] Rx the residual evaluated at x 147 : * 148 : * @returns the LinearOperator representing the Hessian of the functional 149 : * 150 : * Please note: the evaluation of the residual is already performed in getHessian, as well 151 : * as the application of the chain rule. This method here only has to compute the Hessian of 152 : * the functional itself. 153 : */ 154 : virtual LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) = 0; 155 : }; 156 : } // namespace elsa