Line data Source code
1 : #pragma once 2 : 3 : #include "DataContainer.h" 4 : #include "Functional.h" 5 : 6 : namespace elsa 7 : { 8 : /** 9 : * @brief Class representing the Huber loss. 10 : * 11 : * The Huber loss evaluates to \f$ \sum_{i=1}^n \begin{cases} \frac{1}{2} x_i^2 & \text{for } 12 : * |x_i| \leq \delta \\ \delta\left(|x_i| - \frac{1}{2}\delta\right) & \text{else} \end{cases} 13 : * \f$ for \f$ x=(x_i)_{i=1}^n \f$ and a cut-off parameter \f$ \delta \f$. 14 : * 15 : * Reference: https://doi.org/10.1214%2Faoms%2F1177703732 16 : * 17 : * @tparam data_t data type for the domain of the residual of the functional, defaulting to 18 : * real_t 19 : * 20 : * @author 21 : * * Matthias Wieczorek - initial code 22 : * * Maximilian Hornung - modularization 23 : * * Tobias Lasser - modernization 24 : * 25 : */ 26 : template <typename data_t = real_t> 27 : class Huber : public Functional<data_t> 28 : { 29 : public: 30 : /** 31 : * @brief Constructor for the Huber functional, mapping domain vector to scalar (without a 32 : * residual) 33 : * 34 : * @param[in] domainDescriptor describing the domain of the functional 35 : * @param[in] delta parameter for linear/square cutoff (defaults to 1e-6) 36 : */ 37 : explicit Huber(const DataDescriptor& domainDescriptor, 38 : real_t delta = static_cast<real_t>(1e-6)); 39 : 40 : /// make copy constructor deletion explicit 41 : Huber(const Huber<data_t>&) = delete; 42 : 43 : /// default destructor 44 12 : ~Huber() override = default; 45 : 46 : bool isDifferentiable() const override; 47 : 48 : protected: 49 : /// the evaluation of the Huber loss 50 : data_t evaluateImpl(const DataContainer<data_t>& Rx) const override; 51 : 52 : /// the computation of the gradient (in place) 53 : void getGradientImpl(const DataContainer<data_t>& Rx, 54 : DataContainer<data_t>& out) const override; 55 : 56 : /// the computation of the Hessian 57 : LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override; 58 : 59 : /// implement the polymorphic clone operation 60 : Huber<data_t>* cloneImpl() const override; 61 : 62 : /// implement the polymorphic comparison operation 63 : bool isEqual(const Functional<data_t>& other) const override; 64 : 65 : private: 66 : /// the cut-off delta 67 : data_t delta_; 68 : }; 69 : 70 : } // namespace elsa