LCOV - code coverage report
Current view: top level - elsa/functionals - ConstantFunctional.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 44 71 62.0 %
Date: 2024-05-16 04:22:26 Functions: 56 92 60.9 %

          Line data    Source code
       1             : #include "ConstantFunctional.h"
       2             : #include "Error.h"
       3             : #include "Functional.h"
       4             : #include "TypeCasts.hpp"
       5             : #include "elsaDefines.h"
       6             : #include <unistd.h>
       7             : #include <utility>
       8             : 
       9             : namespace elsa
      10             : {
      11             : 
      12             :     template <typename data_t>
      13             :     ConstantFunctional<data_t>::ConstantFunctional(const DataDescriptor& descriptor,
      14             :                                                    SelfType_t<data_t> constant)
      15             :         : Functional<data_t>(descriptor), constant_(constant)
      16         112 :     {
      17         112 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     bool ConstantFunctional<data_t>::isDifferentiable() const
      21           0 :     {
      22           0 :         return true;
      23           0 :     }
      24             : 
      25             :     template <typename data_t>
      26             :     bool ConstantFunctional<data_t>::isProxFriendly() const
      27          12 :     {
      28          12 :         return true;
      29          12 :     }
      30             : 
      31             :     template <typename data_t>
      32             :     data_t ConstantFunctional<data_t>::getConstant() const
      33           0 :     {
      34           0 :         return constant_;
      35           0 :     }
      36             : 
      37             :     template <typename data_t>
      38             :     data_t ConstantFunctional<data_t>::evaluateImpl(const DataContainer<data_t>&) const
      39          20 :     {
      40          20 :         return constant_;
      41          20 :     }
      42             : 
      43             :     template <class data_t>
      44             :     data_t ConstantFunctional<data_t>::convexConjugate(const DataContainer<data_t>& x) const
      45          12 :     {
      46          12 :         return elsa::maximum(x, 0).sum();
      47          12 :     }
      48             : 
      49             :     template <typename data_t>
      50             :     void ConstantFunctional<data_t>::getGradientImpl(const DataContainer<data_t>&,
      51             :                                                      DataContainer<data_t>& out) const
      52           0 :     {
      53           0 :         out = 0;
      54           0 :     }
      55             : 
      56             :     template <typename data_t>
      57             :     LinearOperator<data_t>
      58             :         ConstantFunctional<data_t>::getHessianImpl(const DataContainer<data_t>&) const
      59           0 :     {
      60           0 :         throw NotImplementedError("ConstantFunctional: not twice differentiable");
      61           0 :     }
      62             : 
      63             :     template <typename data_t>
      64             :     DataContainer<data_t> ConstantFunctional<data_t>::proximal(const DataContainer<data_t>& v,
      65             :                                                                SelfType_t<data_t>) const
      66          12 :     {
      67          12 :         return v;
      68          12 :     }
      69             : 
      70             :     template <typename data_t>
      71             :     void ConstantFunctional<data_t>::proximal(const DataContainer<data_t>& v, SelfType_t<data_t>,
      72             :                                               DataContainer<data_t>& out) const
      73           0 :     {
      74           0 :         out = v;
      75           0 :     }
      76             : 
      77             :     template <typename data_t>
      78             :     ConstantFunctional<data_t>* ConstantFunctional<data_t>::cloneImpl() const
      79          56 :     {
      80          56 :         return new ConstantFunctional<data_t>(this->getDomainDescriptor(), constant_);
      81          56 :     }
      82             : 
      83             :     template <typename data_t>
      84             :     bool ConstantFunctional<data_t>::isEqual(const Functional<data_t>& other) const
      85          40 :     {
      86          40 :         if (!Functional<data_t>::isEqual(other)) {
      87           8 :             return false;
      88           8 :         }
      89             : 
      90          32 :         auto* fn = downcast<ConstantFunctional<data_t>>(&other);
      91          32 :         return static_cast<bool>(fn) && constant_ == fn->constant_;
      92          32 :     }
      93             : 
      94             :     // ------------------------------------------
      95             :     // Zero Functional
      96             :     template <typename data_t>
      97             :     ZeroFunctional<data_t>::ZeroFunctional(const DataDescriptor& descriptor)
      98             :         : Functional<data_t>(descriptor)
      99          80 :     {
     100          80 :     }
     101             : 
     102             :     template <typename data_t>
     103             :     bool ZeroFunctional<data_t>::isDifferentiable() const
     104           0 :     {
     105           0 :         return true;
     106           0 :     }
     107             : 
     108             :     template <typename data_t>
     109             :     bool ZeroFunctional<data_t>::isProxFriendly() const
     110          36 :     {
     111          36 :         return true;
     112          36 :     }
     113             : 
     114             :     template <typename data_t>
     115             :     data_t ZeroFunctional<data_t>::evaluateImpl(const DataContainer<data_t>&) const
     116         216 :     {
     117         216 :         return 0;
     118         216 :     }
     119             : 
     120             :     template <class data_t>
     121             :     data_t ZeroFunctional<data_t>::convexConjugate(const DataContainer<data_t>& x) const
     122          12 :     {
     123          12 :         return elsa::maximum(x, 0).sum();
     124          12 :     }
     125             : 
     126             :     template <typename data_t>
     127             :     void ZeroFunctional<data_t>::getGradientImpl(const DataContainer<data_t>&,
     128             :                                                  DataContainer<data_t>& out) const
     129           0 :     {
     130           0 :         out = 0;
     131           0 :     }
     132             : 
     133             :     template <typename data_t>
     134             :     LinearOperator<data_t>
     135             :         ZeroFunctional<data_t>::getHessianImpl(const DataContainer<data_t>&) const
     136           0 :     {
     137           0 :         throw NotImplementedError("ZeroFunctional: not twice differentiable");
     138           0 :     }
     139             : 
     140             :     template <typename data_t>
     141             :     DataContainer<data_t>
     142             :         ZeroFunctional<data_t>::proximal(const DataContainer<data_t>& v,
     143             :                                          [[maybe_unused]] SelfType_t<data_t>) const
     144         224 :     {
     145         224 :         return v;
     146         224 :     }
     147             : 
     148             :     template <typename data_t>
     149             :     void ZeroFunctional<data_t>::proximal(const DataContainer<data_t>& v,
     150             :                                           [[maybe_unused]] SelfType_t<data_t>,
     151             :                                           DataContainer<data_t>& out) const
     152           0 :     {
     153           0 :         out = v;
     154           0 :     }
     155             : 
     156             :     template <typename data_t>
     157             :     ZeroFunctional<data_t>* ZeroFunctional<data_t>::cloneImpl() const
     158          40 :     {
     159          40 :         return new ZeroFunctional<data_t>(this->getDomainDescriptor());
     160          40 :     }
     161             : 
     162             :     template <typename data_t>
     163             :     bool ZeroFunctional<data_t>::isEqual(const Functional<data_t>& other) const
     164          32 :     {
     165          32 :         return Functional<data_t>::isEqual(other) && is<ZeroFunctional<data_t>>(other);
     166          32 :     }
     167             : 
     168             :     // ------------------------------------------
     169             :     // explicit template instantiation
     170             :     template class ConstantFunctional<float>;
     171             :     template class ConstantFunctional<double>;
     172             :     template class ConstantFunctional<complex<float>>;
     173             :     template class ConstantFunctional<complex<double>>;
     174             : 
     175             :     template class ZeroFunctional<float>;
     176             :     template class ZeroFunctional<double>;
     177             :     template class ZeroFunctional<complex<float>>;
     178             :     template class ZeroFunctional<complex<double>>;
     179             : } // namespace elsa

Generated by: LCOV version 1.14