Line data Source code
1 : #include "Huber.h" 2 : #include "DataContainer.h" 3 : #include "Scaling.h" 4 : #include "TypeCasts.hpp" 5 : 6 : #include <cmath> 7 : #include <stdexcept> 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : Huber<data_t>::Huber(const DataDescriptor& domainDescriptor, real_t delta) 13 : : Functional<data_t>(domainDescriptor), delta_{delta} 14 12 : { 15 : // sanity check delta 16 12 : if (delta <= static_cast<real_t>(0.0)) 17 0 : throw InvalidArgumentError("Huber: delta has to be positive."); 18 12 : } 19 : 20 : template <typename data_t> 21 : bool Huber<data_t>::isDifferentiable() const 22 0 : { 23 0 : return true; 24 0 : } 25 : 26 : template <typename data_t> 27 : data_t Huber<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 28 2 : { 29 : // note: this is currently not a reduction in DataContainer, but implemented here "manually" 30 2 : auto result = data_t{0.0}; 31 : 32 6498 : for (index_t i = 0; i < x.getSize(); ++i) { 33 6496 : data_t value = x[i]; 34 6496 : if (std::abs(value) <= delta_) 35 6490 : result += static_cast<data_t>(0.5) * value * value; 36 6 : else 37 6 : result += delta_ * (std::abs(value) - static_cast<real_t>(0.5) * delta_); 38 6496 : } 39 : 40 2 : return result; 41 2 : } 42 : 43 : template <typename data_t> 44 : void Huber<data_t>::getGradientImpl(const DataContainer<data_t>& x, 45 : DataContainer<data_t>& out) const 46 2 : { 47 6498 : for (index_t i = 0; i < x.getSize(); ++i) { 48 6496 : data_t value = x[i]; 49 6496 : if (value > delta_) { 50 6 : out[i] = delta_; 51 6490 : } else if (value < -delta_) { 52 0 : out[i] = -delta_; 53 6490 : } else { 54 6490 : out[i] = x[i]; 55 6490 : } 56 6496 : } 57 2 : } 58 : 59 : template <typename data_t> 60 : LinearOperator<data_t> Huber<data_t>::getHessianImpl(const DataContainer<data_t>& x) const 61 2 : { 62 2 : DataContainer<data_t> s(x.getDataDescriptor()); 63 6498 : for (index_t i = 0; i < x.getSize(); ++i) { 64 6496 : if (std::abs(x[i]) <= delta_) { 65 6490 : s[i] = data_t{1}; 66 6490 : } else { 67 6 : s[i] = data_t{0}; 68 6 : } 69 6496 : } 70 : 71 2 : return leaf(Scaling<data_t>(s)); 72 2 : } 73 : 74 : template <typename data_t> 75 : Huber<data_t>* Huber<data_t>::cloneImpl() const 76 2 : { 77 2 : return new Huber(this->getDomainDescriptor(), delta_); 78 2 : } 79 : 80 : template <typename data_t> 81 : bool Huber<data_t>::isEqual(const Functional<data_t>& other) const 82 2 : { 83 2 : if (!Functional<data_t>::isEqual(other)) 84 0 : return false; 85 : 86 2 : auto otherHuber = downcast_safe<Huber>(&other); 87 2 : if (!otherHuber) 88 0 : return false; 89 : 90 2 : if (delta_ != otherHuber->delta_) 91 0 : return false; 92 : 93 2 : return true; 94 2 : } 95 : 96 : // ------------------------------------------ 97 : // explicit template instantiation 98 : template class Huber<float>; 99 : template class Huber<double>; 100 : // no complex-number instantiations for Huber! (they would not really be useful) 101 : 102 : } // namespace elsa