LCOV - code coverage report
Current view: top level - elsa/functionals - Huber.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 45 51 88.2 %
Date: 2022-08-25 03:05:39 Functions: 14 14 100.0 %

          Line data    Source code
       1             : #include "Huber.h"
       2             : #include "Scaling.h"
       3             : #include "TypeCasts.hpp"
       4             : 
       5             : #include <cmath>
       6             : #include <stdexcept>
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     Huber<data_t>::Huber(const DataDescriptor& domainDescriptor, real_t delta)
      12             :         : Functional<data_t>(domainDescriptor), _delta{delta}
      13          14 :     {
      14             :         // sanity check delta
      15          14 :         if (delta <= static_cast<real_t>(0.0))
      16           0 :             throw InvalidArgumentError("Huber: delta has to be positive.");
      17          14 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     Huber<data_t>::Huber(const elsa::Residual<data_t>& residual, real_t delta)
      21             :         : Functional<data_t>(residual), _delta{delta}
      22          22 :     {
      23             :         // sanity check delta
      24          22 :         if (delta <= static_cast<real_t>(0.0))
      25           0 :             throw InvalidArgumentError("Huber: delta has to be positive.");
      26          22 :     }
      27             : 
      28             :     template <typename data_t>
      29             :     data_t Huber<data_t>::evaluateImpl(const DataContainer<data_t>& Rx)
      30           4 :     {
      31             :         // note: this is currently not a reduction in DataContainer, but implemented here "manually"
      32             : 
      33           4 :         auto result = static_cast<data_t>(0.0);
      34             : 
      35        7534 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      36        7530 :             data_t value = Rx[i];
      37        7530 :             if (std::abs(value) <= _delta)
      38        7524 :                 result += static_cast<data_t>(0.5) * value * value;
      39           6 :             else
      40           6 :                 result += _delta * (std::abs(value) - static_cast<real_t>(0.5) * _delta);
      41        7530 :         }
      42             : 
      43           4 :         return result;
      44           4 :     }
      45             : 
      46             :     template <typename data_t>
      47             :     void Huber<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx)
      48           4 :     {
      49        7534 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      50        7530 :             data_t value = Rx[i];
      51        7530 :             if (value > _delta)
      52           6 :                 Rx[i] = _delta;
      53        7524 :             else if (value < -_delta)
      54           0 :                 Rx[i] = -_delta;
      55             :             // else Rx[i] = Rx[i], i.e. nothing to do for the quadratic case
      56        7530 :         }
      57           4 :     }
      58             : 
      59             :     template <typename data_t>
      60             :     LinearOperator<data_t> Huber<data_t>::getHessianImpl(const DataContainer<data_t>& Rx)
      61           4 :     {
      62           4 :         DataContainer<data_t> scaleFactors(Rx.getDataDescriptor());
      63        7534 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      64        7530 :             if (std::abs(Rx[i]) <= _delta)
      65        7524 :                 scaleFactors[i] = static_cast<data_t>(1);
      66           6 :             else
      67           6 :                 scaleFactors[i] = static_cast<data_t>(0);
      68        7530 :         }
      69             : 
      70           4 :         return leaf(Scaling<data_t>(Rx.getDataDescriptor(), scaleFactors));
      71           4 :     }
      72             : 
      73             :     template <typename data_t>
      74             :     Huber<data_t>* Huber<data_t>::cloneImpl() const
      75          12 :     {
      76          12 :         return new Huber(this->getResidual(), _delta);
      77          12 :     }
      78             : 
      79             :     template <typename data_t>
      80             :     bool Huber<data_t>::isEqual(const Functional<data_t>& other) const
      81           4 :     {
      82           4 :         if (!Functional<data_t>::isEqual(other))
      83           0 :             return false;
      84             : 
      85           4 :         auto otherHuber = downcast_safe<Huber>(&other);
      86           4 :         if (!otherHuber)
      87           0 :             return false;
      88             : 
      89           4 :         if (_delta != otherHuber->_delta)
      90           0 :             return false;
      91             : 
      92           4 :         return true;
      93           4 :     }
      94             : 
      95             :     // ------------------------------------------
      96             :     // explicit template instantiation
      97             :     template class Huber<float>;
      98             :     template class Huber<double>;
      99             :     // no complex-number instantiations for Huber! (they would not really be useful)
     100             : 
     101             : } // namespace elsa

Generated by: LCOV version 1.14