LCOV - code coverage report
Current view: top level - elsa/functionals - Huber.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 46 54 85.2 %
Date: 2024-05-16 04:22:26 Functions: 12 14 85.7 %

          Line data    Source code
       1             : #include "Huber.h"
       2             : #include "DataContainer.h"
       3             : #include "Scaling.h"
       4             : #include "TypeCasts.hpp"
       5             : 
       6             : #include <cmath>
       7             : #include <stdexcept>
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     Huber<data_t>::Huber(const DataDescriptor& domainDescriptor, real_t delta)
      13             :         : Functional<data_t>(domainDescriptor), delta_{delta}
      14          12 :     {
      15             :         // sanity check delta
      16          12 :         if (delta <= static_cast<real_t>(0.0))
      17           0 :             throw InvalidArgumentError("Huber: delta has to be positive.");
      18          12 :     }
      19             : 
      20             :     template <typename data_t>
      21             :     bool Huber<data_t>::isDifferentiable() const
      22           0 :     {
      23           0 :         return true;
      24           0 :     }
      25             : 
      26             :     template <typename data_t>
      27             :     data_t Huber<data_t>::evaluateImpl(const DataContainer<data_t>& x) const
      28           2 :     {
      29             :         // note: this is currently not a reduction in DataContainer, but implemented here "manually"
      30           2 :         auto result = data_t{0.0};
      31             : 
      32        6498 :         for (index_t i = 0; i < x.getSize(); ++i) {
      33        6496 :             data_t value = x[i];
      34        6496 :             if (std::abs(value) <= delta_)
      35        6490 :                 result += static_cast<data_t>(0.5) * value * value;
      36           6 :             else
      37           6 :                 result += delta_ * (std::abs(value) - static_cast<real_t>(0.5) * delta_);
      38        6496 :         }
      39             : 
      40           2 :         return result;
      41           2 :     }
      42             : 
      43             :     template <typename data_t>
      44             :     void Huber<data_t>::getGradientImpl(const DataContainer<data_t>& x,
      45             :                                         DataContainer<data_t>& out) const
      46           2 :     {
      47        6498 :         for (index_t i = 0; i < x.getSize(); ++i) {
      48        6496 :             data_t value = x[i];
      49        6496 :             if (value > delta_) {
      50           6 :                 out[i] = delta_;
      51        6490 :             } else if (value < -delta_) {
      52           0 :                 out[i] = -delta_;
      53        6490 :             } else {
      54        6490 :                 out[i] = x[i];
      55        6490 :             }
      56        6496 :         }
      57           2 :     }
      58             : 
      59             :     template <typename data_t>
      60             :     LinearOperator<data_t> Huber<data_t>::getHessianImpl(const DataContainer<data_t>& x) const
      61           2 :     {
      62           2 :         DataContainer<data_t> s(x.getDataDescriptor());
      63        6498 :         for (index_t i = 0; i < x.getSize(); ++i) {
      64        6496 :             if (std::abs(x[i]) <= delta_) {
      65        6490 :                 s[i] = data_t{1};
      66        6490 :             } else {
      67           6 :                 s[i] = data_t{0};
      68           6 :             }
      69        6496 :         }
      70             : 
      71           2 :         return leaf(Scaling<data_t>(s));
      72           2 :     }
      73             : 
      74             :     template <typename data_t>
      75             :     Huber<data_t>* Huber<data_t>::cloneImpl() const
      76           2 :     {
      77           2 :         return new Huber(this->getDomainDescriptor(), delta_);
      78           2 :     }
      79             : 
      80             :     template <typename data_t>
      81             :     bool Huber<data_t>::isEqual(const Functional<data_t>& other) const
      82           2 :     {
      83           2 :         if (!Functional<data_t>::isEqual(other))
      84           0 :             return false;
      85             : 
      86           2 :         auto otherHuber = downcast_safe<Huber>(&other);
      87           2 :         if (!otherHuber)
      88           0 :             return false;
      89             : 
      90           2 :         if (delta_ != otherHuber->delta_)
      91           0 :             return false;
      92             : 
      93           2 :         return true;
      94           2 :     }
      95             : 
    // ------------------------------------------
    // explicit template instantiation
    // Only real-valued element types are instantiated: the member functions compare elements
    // against delta_ with ordering operators (<=, <, >), which complex numbers do not provide.
    template class Huber<float>;
    template class Huber<double>;
    // no complex-number instantiations for Huber! (they would not really be useful)
     101             : 
     102             : } // namespace elsa

Generated by: LCOV version 1.14