LCOV - code coverage report
Current view: top level - elsa/functionals - IndicatorFunctionals.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 50 67 74.6 %
Date: 2024-05-16 04:22:26 Functions: 28 38 73.7 %

          Line data    Source code
       1             : #include "IndicatorFunctionals.h"
       2             : #include "ProximalBoxConstraint.h"
       3             : #include "DataDescriptor.h"
       4             : #include "DataContainer.h"
       5             : #include "Error.h"
       6             : #include "Functional.h"
       7             : #include "elsaDefines.h"
       8             : #include <limits>
       9             : 
      10             : namespace elsa
      11             : {
      12             :     // ------------------------------------------
      13             :     // IndicatorBox
      14             :     template <class data_t>
      15             :     IndicatorBox<data_t>::IndicatorBox(const DataDescriptor& desc) : Functional<data_t>(desc)
      16           2 :     {
      17           2 :     }
      18             : 
      19             :     template <class data_t>
      20             :     IndicatorBox<data_t>::IndicatorBox(const DataDescriptor& desc, SelfType_t<data_t> lower,
      21             :                                        SelfType_t<data_t> upper)
      22             :         : Functional<data_t>(desc), lower_(lower), upper_(upper)
      23          14 :     {
      24          14 :     }
      25             : 
      26             :     template <class data_t>
      27             :     data_t IndicatorBox<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
      28          46 :     {
      29             :         // Project the input value onto the box set
      30          46 :         auto projected = ::elsa::clip(Rx, lower_, upper_);
      31             : 
      32             :         // Check if anything changed, by computing the distance
      33          46 :         return (projected - Rx).l2Norm() > 0 ? std::numeric_limits<data_t>::infinity() : 0;
      34          46 :     }
      35             : 
      36             :     template <class data_t>
      37             :     data_t IndicatorBox<data_t>::convexConjugate(const DataContainer<data_t>& x) const
      38          10 :     {
      39          10 :         return elsa::maximum(x, 0).sum();
      40          10 :     }
      41             : 
      42             :     template <class data_t>
      43             :     void IndicatorBox<data_t>::getGradientImpl(const DataContainer<data_t>&,
      44             :                                                DataContainer<data_t>&) const
      45           0 :     {
      46           0 :         throw NotImplementedError("IndicatorBox: Not differentiable");
      47           0 :     }
      48             : 
      49             :     template <class data_t>
      50             :     LinearOperator<data_t> IndicatorBox<data_t>::getHessianImpl(const DataContainer<data_t>&) const
      51           0 :     {
      52           0 :         throw NotImplementedError("IndicatorBox: Not differentiable");
      53           0 :     }
      54             : 
      55             :     template <class data_t>
      56             :     DataContainer<data_t>
      57             :         IndicatorBox<data_t>::proximal(const DataContainer<data_t>& v,
      58             :                                        [[maybe_unused]] SelfType_t<data_t> t) const
      59           8 :     {
      60           8 :         DataContainer<data_t> out(v.getDataDescriptor());
      61           8 :         proximal(v, t, out);
      62           8 :         return out;
      63           8 :     }
      64             : 
      65             :     template <class data_t>
      66             :     void IndicatorBox<data_t>::proximal(const DataContainer<data_t>& v,
      67             :                                         [[maybe_unused]] SelfType_t<data_t> t,
      68             :                                         DataContainer<data_t>& out) const
      69           8 :     {
      70           8 :         ProximalBoxConstraint<data_t> prox(lower_, upper_);
      71           8 :         prox.apply(v, t, out);
      72           8 :     }
      73             : 
      74             :     template <class data_t>
      75             :     IndicatorBox<data_t>* IndicatorBox<data_t>::cloneImpl() const
      76           4 :     {
      77           4 :         return new IndicatorBox<data_t>(this->getDomainDescriptor(), lower_, upper_);
      78           4 :     }
      79             : 
      80             :     template <class data_t>
      81             :     bool IndicatorBox<data_t>::isEqual(const Functional<data_t>& other) const
      82           4 :     {
      83           4 :         if (!Functional<data_t>::isEqual(other)) {
      84           0 :             return false;
      85           0 :         }
      86             : 
      87           4 :         auto* fn = downcast<IndicatorBox<data_t>>(&other);
      88           4 :         return static_cast<bool>(fn) && lower_ == fn->lower_ && upper_ == fn->upper_;
      89           4 :     }
      90             : 
      91             :     // ------------------------------------------
      92             :     // IndicatorNonNegativity
      93             :     template <class data_t>
      94             :     IndicatorNonNegativity<data_t>::IndicatorNonNegativity(const DataDescriptor& desc)
      95             :         : Functional<data_t>(desc)
      96          24 :     {
      97          24 :     }
      98             : 
      99             :     template <class data_t>
     100             :     data_t IndicatorNonNegativity<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
     101          22 :     {
     102          22 :         constexpr auto infinity = std::numeric_limits<data_t>::infinity();
     103             : 
     104             :         // Project the input value onto the box set
     105          22 :         auto projected = ::elsa::clip(Rx, data_t{0}, infinity);
     106             : 
     107             :         // Check if anything changed, by computing the distance
     108          22 :         return (projected - Rx).l2Norm() > 0 ? infinity : 0;
     109          22 :     }
     110             : 
     111             :     template <class data_t>
     112             :     data_t IndicatorNonNegativity<data_t>::convexConjugate(const DataContainer<data_t>& x) const
     113           0 :     {
     114           0 :         return elsa::maximum(x, 0).sum();
     115           0 :     }
     116             : 
     117             :     template <class data_t>
     118             :     void IndicatorNonNegativity<data_t>::getGradientImpl(const DataContainer<data_t>&,
     119             :                                                          DataContainer<data_t>&) const
     120           0 :     {
     121           0 :         throw NotImplementedError("IndicatorNonNegativity: Not differentiable");
     122           0 :     }
     123             : 
     124             :     template <class data_t>
     125             :     LinearOperator<data_t>
     126             :         IndicatorNonNegativity<data_t>::getHessianImpl(const DataContainer<data_t>&) const
     127           0 :     {
     128           0 :         throw NotImplementedError("IndicatorNonNegativity: Not differentiable");
     129           0 :     }
     130             : 
     131             :     template <class data_t>
     132             :     DataContainer<data_t>
     133             :         IndicatorNonNegativity<data_t>::proximal(const DataContainer<data_t>& v,
     134             :                                                  [[maybe_unused]] SelfType_t<data_t> t) const
     135          10 :     {
     136          10 :         DataContainer<data_t> out(v.getDataDescriptor());
     137          10 :         proximal(v, t, out);
     138          10 :         return out;
     139          10 :     }
     140             : 
     141             :     template <class data_t>
     142             :     void IndicatorNonNegativity<data_t>::proximal(const DataContainer<data_t>& v,
     143             :                                                   [[maybe_unused]] SelfType_t<data_t> t,
     144             :                                                   DataContainer<data_t>& out) const
     145          10 :     {
     146          10 :         ProximalBoxConstraint<data_t> prox(0);
     147          10 :         prox.apply(v, t, out);
     148          10 :     }
     149             : 
     150             :     template <class data_t>
     151             :     IndicatorNonNegativity<data_t>* IndicatorNonNegativity<data_t>::cloneImpl() const
     152          12 :     {
     153          12 :         return new IndicatorNonNegativity<data_t>(this->getDomainDescriptor());
     154          12 :     }
     155             : 
     156             :     template <class data_t>
     157             :     bool IndicatorNonNegativity<data_t>::isEqual(const Functional<data_t>& other) const
     158           4 :     {
     159           4 :         return Functional<data_t>::isEqual(other) && is<IndicatorNonNegativity<data_t>>(&other);
     160           4 :     }
     161             : 
     162             :     // ------------------------------------------
     163             :     // explicit template instantiation
     164             :     template class IndicatorBox<float>;
     165             :     template class IndicatorBox<double>;
     166             : 
     167             :     template class IndicatorNonNegativity<float>;
     168             :     template class IndicatorNonNegativity<double>;
     169             : } // namespace elsa

Generated by: LCOV version 1.14