LCOV - code coverage report
Current view: top level - functionals - Huber.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 45 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 14 0.0 %

          Line data    Source code
       1             : #include "Huber.h"
       2             : #include "Scaling.h"
       3             : #include "TypeCasts.hpp"
       4             : 
       5             : #include <cmath>
       6             : #include <stdexcept>
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11           0 :     Huber<data_t>::Huber(const DataDescriptor& domainDescriptor, real_t delta)
      12           0 :         : Functional<data_t>(domainDescriptor), _delta{delta}
      13             :     {
      14             :         // sanity check delta
      15           0 :         if (delta <= static_cast<real_t>(0.0))
      16           0 :             throw InvalidArgumentError("Huber: delta has to be positive.");
      17           0 :     }
      18             : 
      19             :     template <typename data_t>
      20           0 :     Huber<data_t>::Huber(const elsa::Residual<data_t>& residual, real_t delta)
      21           0 :         : Functional<data_t>(residual), _delta{delta}
      22             :     {
      23             :         // sanity check delta
      24           0 :         if (delta <= static_cast<real_t>(0.0))
      25           0 :             throw InvalidArgumentError("Huber: delta has to be positive.");
      26           0 :     }
      27             : 
      28             :     template <typename data_t>
      29           0 :     data_t Huber<data_t>::evaluateImpl(const DataContainer<data_t>& Rx)
      30             :     {
      31             :         // note: this is currently not a reduction in DataContainer, but implemented here "manually"
      32             : 
      33           0 :         auto result = static_cast<data_t>(0.0);
      34             : 
      35           0 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      36           0 :             data_t value = Rx[i];
      37           0 :             if (std::abs(value) <= _delta)
      38           0 :                 result += static_cast<data_t>(0.5) * value * value;
      39             :             else
      40           0 :                 result += _delta * (std::abs(value) - static_cast<real_t>(0.5) * _delta);
      41             :         }
      42             : 
      43           0 :         return result;
      44             :     }
      45             : 
      46             :     template <typename data_t>
      47           0 :     void Huber<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx)
      48             :     {
      49           0 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      50           0 :             data_t value = Rx[i];
      51           0 :             if (value > _delta)
      52           0 :                 Rx[i] = _delta;
      53           0 :             else if (value < -_delta)
      54           0 :                 Rx[i] = -_delta;
      55             :             // else Rx[i] = Rx[i], i.e. nothing to do for the quadratic case
      56             :         }
      57           0 :     }
      58             : 
      59             :     template <typename data_t>
      60           0 :     LinearOperator<data_t> Huber<data_t>::getHessianImpl(const DataContainer<data_t>& Rx)
      61             :     {
      62           0 :         DataContainer<data_t> scaleFactors(Rx.getDataDescriptor());
      63           0 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      64           0 :             if (std::abs(Rx[i]) <= _delta)
      65           0 :                 scaleFactors[i] = static_cast<data_t>(1);
      66             :             else
      67           0 :                 scaleFactors[i] = static_cast<data_t>(0);
      68             :         }
      69             : 
      70           0 :         return leaf(Scaling<data_t>(Rx.getDataDescriptor(), scaleFactors));
      71           0 :     }
      72             : 
      73             :     template <typename data_t>
      74           0 :     Huber<data_t>* Huber<data_t>::cloneImpl() const
      75             :     {
      76           0 :         return new Huber(this->getResidual(), _delta);
      77             :     }
      78             : 
      79             :     template <typename data_t>
      80           0 :     bool Huber<data_t>::isEqual(const Functional<data_t>& other) const
      81             :     {
      82           0 :         if (!Functional<data_t>::isEqual(other))
      83           0 :             return false;
      84             : 
      85           0 :         auto otherHuber = downcast_safe<Huber>(&other);
      86           0 :         if (!otherHuber)
      87           0 :             return false;
      88             : 
      89           0 :         if (_delta != otherHuber->_delta)
      90           0 :             return false;
      91             : 
      92           0 :         return true;
      93             :     }
      94             : 
      95             :     // ------------------------------------------
      96             :     // explicit template instantiation
      97             :     template class Huber<float>;
      98             :     template class Huber<double>;
      99             :     // no complex-number instantiations for Huber! (they would not really be useful)
     100             : 
     101             : } // namespace elsa

Generated by: LCOV version 1.14