Line data Source code
1 : #include "Huber.h" 2 : #include "Scaling.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <cmath> 6 : #include <stdexcept> 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 : Huber<data_t>::Huber(const DataDescriptor& domainDescriptor, real_t delta) 12 : : Functional<data_t>(domainDescriptor), _delta{delta} 13 14 : { 14 : // sanity check delta 15 14 : if (delta <= static_cast<real_t>(0.0)) 16 0 : throw InvalidArgumentError("Huber: delta has to be positive."); 17 14 : } 18 : 19 : template <typename data_t> 20 : Huber<data_t>::Huber(const elsa::Residual<data_t>& residual, real_t delta) 21 : : Functional<data_t>(residual), _delta{delta} 22 22 : { 23 : // sanity check delta 24 22 : if (delta <= static_cast<real_t>(0.0)) 25 0 : throw InvalidArgumentError("Huber: delta has to be positive."); 26 22 : } 27 : 28 : template <typename data_t> 29 : data_t Huber<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) 30 4 : { 31 : // note: this is currently not a reduction in DataContainer, but implemented here "manually" 32 : 33 4 : auto result = static_cast<data_t>(0.0); 34 : 35 7534 : for (index_t i = 0; i < Rx.getSize(); ++i) { 36 7530 : data_t value = Rx[i]; 37 7530 : if (std::abs(value) <= _delta) 38 7524 : result += static_cast<data_t>(0.5) * value * value; 39 6 : else 40 6 : result += _delta * (std::abs(value) - static_cast<real_t>(0.5) * _delta); 41 7530 : } 42 : 43 4 : return result; 44 4 : } 45 : 46 : template <typename data_t> 47 : void Huber<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx) 48 4 : { 49 7534 : for (index_t i = 0; i < Rx.getSize(); ++i) { 50 7530 : data_t value = Rx[i]; 51 7530 : if (value > _delta) 52 6 : Rx[i] = _delta; 53 7524 : else if (value < -_delta) 54 0 : Rx[i] = -_delta; 55 : // else Rx[i] = Rx[i], i.e. nothing to do for the quadratic case 56 7530 : } 57 4 : } 58 : 59 : template <typename data_t> 60 : LinearOperator<data_t> Huber<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) 61 4 : { 62 4 : DataContainer<data_t> scaleFactors(Rx.getDataDescriptor()); 63 7534 : for (index_t i = 0; i < Rx.getSize(); ++i) { 64 7530 : if (std::abs(Rx[i]) <= _delta) 65 7524 : scaleFactors[i] = static_cast<data_t>(1); 66 6 : else 67 6 : scaleFactors[i] = static_cast<data_t>(0); 68 7530 : } 69 : 70 4 : return leaf(Scaling<data_t>(Rx.getDataDescriptor(), scaleFactors)); 71 4 : } 72 : 73 : template <typename data_t> 74 : Huber<data_t>* Huber<data_t>::cloneImpl() const 75 12 : { 76 12 : return new Huber(this->getResidual(), _delta); 77 12 : } 78 : 79 : template <typename data_t> 80 : bool Huber<data_t>::isEqual(const Functional<data_t>& other) const 81 4 : { 82 4 : if (!Functional<data_t>::isEqual(other)) 83 0 : return false; 84 : 85 4 : auto otherHuber = downcast_safe<Huber>(&other); 86 4 : if (!otherHuber) 87 0 : return false; 88 : 89 4 : if (_delta != otherHuber->_delta) 90 0 : return false; 91 : 92 4 : return true; 93 4 : } 94 : 95 : // ------------------------------------------ 96 : // explicit template instantiation 97 : template class Huber<float>; 98 : template class Huber<double>; 99 : // no complex-number instantiations for Huber! (they would not really be useful) 100 : 101 : } // namespace elsa