Line data Source code
1 : #pragma once 2 : 3 : #include "Cloneable.h" 4 : #include "DataContainer.h" 5 : #include "DataDescriptor.h" 6 : #include "Error.h" 7 : #include "LinearOperator.h" 8 : #include "TypeCasts.hpp" 9 : #include "elsaDefines.h" 10 : 11 : namespace elsa 12 : { 13 : /** 14 : * @brief Abstract base class representing a functional, i.e. a mapping from vectors to scalars. 15 : * 16 : * A functional is a mapping a vector to a scalar value (e.g. mapping the output of a Residual 17 : * to a scalar). Typical examples of functionals are norms or semi-norms, such as the L2 or L1 18 : * norms. 19 : * 20 : * Using LinearOperators, Residuals (e.g. LinearResidual) and a Functional (e.g. LeastSquares) 21 : * enables the formulation of typical terms in an OptimizationProblem. 22 : * 23 : * @tparam data_t data type for the domain of the residual of the functional, defaulting to 24 : * real_t 25 : * 26 : * @author 27 : * * Matthias Wieczorek - initial code 28 : * * Maximilian Hornung - modularization 29 : * * Tobias Lasser - rewrite 30 : * 31 : */ 32 : template <typename data_t = real_t> 33 : class Functional : public Cloneable<Functional<data_t>> 34 : { 35 : public: 36 : /** 37 : * @brief Constructor for the functional, mapping a domain vector to a scalar (without a 38 : * residual) 39 : * 40 : * @param[in] domainDescriptor describing the domain of the functional 41 : */ 42 : explicit Functional(const DataDescriptor& domainDescriptor); 43 : 44 : /// default destructor 45 1480 : ~Functional() override = default; 46 : 47 : /// return the domain descriptor 48 : const DataDescriptor& getDomainDescriptor() const; 49 : 50 : /** 51 : * @brief Indicate if a functional is differentiable. The default implementation returns 52 : * `false`. Functionals which are at least once differentiable should override this 53 : * functions. 54 : */ 55 : virtual bool isDifferentiable() const; 56 : 57 : /// @brief Indicate if the functional has a simple to compute proximal 58 : virtual bool isProxFriendly() const; 59 : 60 : /// @brief Indicate if the functional can compute the proximal of the dual 61 : virtual bool hasProxDual() const; 62 : 63 : /** 64 : * @brief evaluate the functional at x and return the result 65 : * 66 : * @param[in] x input DataContainer (in the domain of the functional) 67 : * 68 : * @returns result the scalar of the functional evaluated at x 69 : * 70 : * Please note: after evaluating the residual at x, this method calls the method 71 : * evaluateImpl that has to be overridden in derived classes to compute the functional's 72 : * value. 73 : */ 74 : data_t evaluate(const DataContainer<data_t>& x) const; 75 : 76 : /** 77 : * @brief compute the gradient of the functional at x and return the result 78 : * 79 : * @param[in] x input DataContainer (in the domain of the functional) 80 : * 81 : * @returns result DataContainer (in the domain of the functional) containing the result of 82 : * the gradient at x. 83 : * 84 : * Please note: this method uses getGradient(x, result) to perform the actual operation. 85 : */ 86 : DataContainer<data_t> getGradient(const DataContainer<data_t>& x) const; 87 : 88 : /** 89 : * @brief Compute the convex conjugate of the functional 90 : * 91 : * @param[in] x input DataContainer (in the domain of the functional) 92 : */ 93 : virtual data_t convexConjugate(const DataContainer<data_t>& x) const; 94 : 95 : /** 96 : * @brief compute the gradient of the functional at x and store in result 97 : * 98 : * @param[in] x input DataContainer (in the domain of the functional) 99 : * @param[out] result output DataContainer (in the domain of the functional) 100 : */ 101 : void getGradient(const DataContainer<data_t>& x, DataContainer<data_t>& result) const; 102 : 103 : /** 104 : * @brief return the Hessian of the functional at x 105 : * 106 : * @param[in] x input DataContainer (in the domain of the functional) 107 : * 108 : * @returns a LinearOperator (the Hessian) 109 : * 110 : * Note: some derived classes might decide to use only the diagonal of the Hessian as a fast 111 : * approximation! 112 : * 113 : * Please note: after evaluating the residual at x, this method calls the method 114 : * getHessianImpl that has to be overridden in derived classes to compute the functional's 115 : * Hessian, and after that the chain rule for the residual is applied (if necessary). 116 : */ 117 : LinearOperator<data_t> getHessian(const DataContainer<data_t>& x) const; 118 : 119 : /** 120 : * @brief compute the proximal of the given functional 121 : * 122 : * @param[in] v input DataContainer (in the domain of the functional) 123 : * @param[in] tau threshold/scaling parameter for proximal 124 : */ 125 : virtual DataContainer<data_t> proximal(const DataContainer<data_t>& v, 126 : SelfType_t<data_t> tau) const; 127 : 128 : /** 129 : * @brief compute the proximal of the given functional and write the result to the output 130 : * DataContainer 131 : * 132 : * @param[in] v input DataContainer (in the domain of the functional) 133 : * @param[in] tau threshold/scaling parameter for proximal 134 : * @param[out] out output DataContainer (in the domain of the functional) 135 : */ 136 : virtual void proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t, 137 : DataContainer<data_t>& out) const; 138 : 139 : /// @brief compute the proximal of the convex conjugate of the functional. 140 : /// This method can either be overridden, or by default it computes the 141 : /// proximal of the convex conjugate using the Moreau’s identity. It is 142 : /// given as: 143 : /// @f[ 144 : /// \operatorname{prox}_{\tau f^*}(x) = x - \tau \operatorname{prox}_{\tau^{-1}f}(\tau^{-1} 145 : /// x) 146 : /// @f] 147 : virtual DataContainer<data_t> proxdual(const DataContainer<data_t>& x, 148 : SelfType_t<data_t> tau) const; 149 : 150 : /// @brief compute the proximal of the convex conjugate of the functional 151 : virtual void proxdual(const DataContainer<data_t>& x, SelfType_t<data_t> tau, 152 : DataContainer<data_t>& out) const; 153 : 154 : protected: 155 : /// the data descriptor of the domain of the functional 156 : std::unique_ptr<DataDescriptor> _domainDescriptor; 157 : 158 : /// implement the polymorphic comparison operation 159 : bool isEqual(const Functional<data_t>& other) const override; 160 : 161 : /** 162 : * @brief the evaluateImpl method that has to be overridden in derived classes 163 : * 164 : * @param[in] Rx the residual evaluated at x 165 : * 166 : * @returns the evaluated functional 167 : * 168 : * Please note: the evaluation of the residual is already performed in evaluate, so this 169 : * method only has to compute the functional's value itself. 170 : */ 171 : virtual data_t evaluateImpl(const DataContainer<data_t>& Rx) const = 0; 172 : 173 : /** 174 : * @brief the getGradientImplt method that has to be overridden in derived classes 175 : * 176 : * @param[in] Rx the value to evaluate the gradient of the functional 177 : * @param[in,out] out the evaluated gradient of the functional 178 : * 179 : * Please note: the evaluation of the residual is already performed in getGradient, as well 180 : * as the application of the chain rule. This method here only has to compute the gradient 181 : * of the functional itself, in an in-place manner (to avoid unnecessary DataContainers). 182 : */ 183 : virtual void getGradientImpl(const DataContainer<data_t>& Rx, 184 : DataContainer<data_t>& out) const = 0; 185 : 186 : /** 187 : * @brief the getHessianImpl method that has to be overridden in derived classes 188 : * 189 : * @param[in] Rx the residual evaluated at x 190 : * 191 : * @returns the LinearOperator representing the Hessian of the functional 192 : * 193 : * Please note: the evaluation of the residual is already performed in getHessian, as well 194 : * as the application of the chain rule. This method here only has to compute the Hessian of 195 : * the functional itself. 196 : */ 197 : virtual LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const = 0; 198 : }; 199 : 200 : /** 201 : * @brief Class representing a sum of two functionals 202 : * \f[ 203 : * f(x) = h(x) + g(x) 204 : * \f] 205 : * 206 : * The gradient at \f$x\f$ is given as: 207 : * \f[ 208 : * \nabla f(x) = \nabla h(x) + \nabla g(x) 209 : * \f] 210 : * 211 : * and finally the hessian is given by: 212 : * \f[ 213 : * \nabla^2 f(x) = \nabla^2 h(x) \nabla^2 g(x) 214 : * \f] 215 : * 216 : * The gradient and hessian is only valid if the functional is (twice) 217 : * differentiable. The `operator+` is overloaded for, to conveniently create 218 : * this class. It should not be necessary to create it explicitly. 219 : */ 220 : template <class data_t> 221 : class FunctionalSum final : public Functional<data_t> 222 : { 223 : public: 224 : /// Construct from two functionals 225 : FunctionalSum(const Functional<data_t>& lhs, const Functional<data_t>& rhs); 226 : 227 : /// Make deletion of copy constructor explicit 228 : FunctionalSum(const FunctionalSum<data_t>&) = delete; 229 : 230 : /// Default Move constructor 231 : FunctionalSum(FunctionalSum<data_t>&& other) 232 : : Functional<data_t>(other.getDomainDescriptor()), 233 : lhs_(std::move(other.lhs_)), 234 : rhs_(std::move(other.rhs_)) 235 0 : { 236 0 : } 237 : 238 : /// Make deletion of copy assignment explicit 239 : FunctionalSum& operator=(const FunctionalSum<data_t>&) = delete; 240 : 241 : /// Default Move assignment 242 : FunctionalSum& operator=(FunctionalSum<data_t>&& other) noexcept 243 0 : { 244 0 : this->_domainDescriptor = std::move(other._domainDescriptor); 245 0 : lhs_ = std::move(other.lhs_); 246 0 : rhs_ = std::move(other.rhs_); 247 : 248 0 : return *this; 249 0 : } 250 : 251 : // Default destructor 252 88 : ~FunctionalSum() override = default; 253 : 254 : private: 255 : /// evaluate the functional as \f$g(x) + h(x)\f$ 256 : data_t evaluateImpl(const DataContainer<data_t>& Rx) const override; 257 : 258 : /// evaluate the gradient as: \f$\nabla g(x) + \nabla h(x)\f$ 259 : void getGradientImpl(const DataContainer<data_t>& Rx, 260 : DataContainer<data_t>& out) const override; 261 : 262 : /// construct the hessian as: \f$\nabla^2 g(x) + \nabla^2 h(x)\f$ 263 : LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override; 264 : 265 : /// Implement polymorphic clone 266 : FunctionalSum<data_t>* cloneImpl() const override; 267 : 268 : /// Implement polymorphic equality 269 : bool isEqual(const Functional<data_t>& other) const override; 270 : 271 : /// Store the left hand side functionl 272 : std::unique_ptr<Functional<data_t>> lhs_{}; 273 : 274 : /// Store the right hand side functional 275 : std::unique_ptr<Functional<data_t>> rhs_{}; 276 : }; 277 : 278 : /** 279 : * @brief Class representing a functional with a scalar multiplication: 280 : * \f[ 281 : * f(x) = \lambda * g(x) 282 : * \f] 283 : * 284 : * The gradient at \f$x\f$ is given as: 285 : * \f[ 286 : * \nabla f(x) = \lambda \nabla g(x) 287 : * \f] 288 : * 289 : * and finally the hessian is given by: 290 : * \f[ 291 : * \nabla^2 f(x) = \lambda \nabla^2 g(x) 292 : * \f] 293 : * 294 : * The gradient and hessian is only valid if the functional is differentiable. 295 : * The `operator*` is overloaded for scalar values with functionals, to 296 : * conveniently create this class. It should not be necessary to create it 297 : * explicitly. 298 : */ 299 : template <class data_t> 300 : class FunctionalScalarMul final : public Functional<data_t> 301 : { 302 : public: 303 : /// Construct functional from other functional and scalar 304 : FunctionalScalarMul(const Functional<data_t>& fn, SelfType_t<data_t> scalar); 305 : 306 : /// Make deletion of copy constructor explicit 307 : FunctionalScalarMul(const FunctionalScalarMul<data_t>&) = delete; 308 : 309 : /// Implement the move constructor 310 : FunctionalScalarMul(FunctionalScalarMul<data_t>&& other) 311 : : Functional<data_t>(other.getDomainDescriptor()), 312 : fn_(std::move(other.fn_)), 313 : scalar_(std::move(other.scalar_)) 314 0 : { 315 0 : } 316 : 317 : /// Make deletion of copy assignment explicit 318 : FunctionalScalarMul& operator=(const FunctionalScalarMul<data_t>&) = delete; 319 : 320 : /// Implement the move assignment operator 321 : FunctionalScalarMul& operator=(FunctionalScalarMul<data_t>&& other) noexcept 322 0 : { 323 0 : this->_domainDescriptor = std::move(other._domainDescriptor); 324 0 : fn_ = std::move(other.fn_); 325 0 : scalar_ = std::move(other.scalar_); 326 : 327 0 : return *this; 328 0 : } 329 : 330 : /// Default destructor 331 40 : ~FunctionalScalarMul() override = default; 332 : 333 : bool isProxFriendly() const override; 334 : 335 : /** 336 : * @brief The convex conjugate of a scaled function @f$f(x) = \lambda g(x)@f$ is given as: 337 : * @f[ 338 : * f^*(x) = \lambda g^*(\frac{x}{\lambda}) 339 : * @f] 340 : */ 341 : data_t convexConjugate(const DataContainer<data_t>& x) const override; 342 : 343 : DataContainer<data_t> proximal(const DataContainer<data_t>& v, 344 : SelfType_t<data_t> t) const override; 345 : 346 : void proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t, 347 : DataContainer<data_t>& out) const override; 348 : 349 : private: 350 : /// Evaluate as \f$\lambda * \nabla g(x)\f$ 351 : data_t evaluateImpl(const DataContainer<data_t>& Rx) const override; 352 : 353 : /// Evaluate gradient as: \f$\lambda * \nabla g(x)\f$ 354 : void getGradientImpl(const DataContainer<data_t>& Rx, 355 : DataContainer<data_t>& out) const override; 356 : 357 : /// Construct hessian as: \f$\lambda * \nabla^2 g(x)\f$ 358 : LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override; 359 : 360 : /// Implementation of polymorphic clone 361 : FunctionalScalarMul<data_t>* cloneImpl() const override; 362 : 363 : /// Implementation of polymorphic equality 364 : bool isEqual(const Functional<data_t>& other) const override; 365 : 366 : /// Store other functional \f$g\f$ 367 : std::unique_ptr<Functional<data_t>> fn_{}; 368 : 369 : /// The scalar 370 : data_t scalar_; 371 : }; 372 : 373 : template <class data_t> 374 : FunctionalScalarMul<data_t> operator*(SelfType_t<data_t> s, const Functional<data_t>& f) 375 8 : { 376 : // TODO: consider returning the ZeroFunctional, if s == 0, but then 377 : // it's necessary to return unique_ptr and I hate that 378 8 : return FunctionalScalarMul<data_t>(f, s); 379 8 : } 380 : 381 : template <class data_t> 382 : FunctionalScalarMul<data_t> operator*(const Functional<data_t>& f, SelfType_t<data_t> s) 383 4 : { 384 : // TODO: consider returning the ZeroFunctional, if s == 0, but then 385 : // it's necessary to return unique_ptr and I hate that 386 4 : return FunctionalScalarMul<data_t>(f, s); 387 4 : } 388 : 389 : template <class data_t> 390 : FunctionalSum<data_t> operator+(const Functional<data_t>& lhs, const Functional<data_t>& rhs) 391 24 : { 392 24 : return FunctionalSum<data_t>(lhs, rhs); 393 24 : } 394 : 395 : } // namespace elsa