LCOV - code coverage report
Current view: top level - elsa/functionals - Functional.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 71 127 55.9 %
Date: 2024-05-15 03:55:36 Functions: 76 124 61.3 %

          Line data    Source code
       1             : #include "Functional.h"
       2             : #include "DataContainer.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "VolumeDescriptor.h"
       5             : 
       6             : #include <stdexcept>
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     Functional<data_t>::Functional(const DataDescriptor& domainDescriptor)
      12             :         : _domainDescriptor{domainDescriptor.clone()}
      13        1480 :     {
      14        1480 :     }
      15             : 
      16             :     template <typename data_t>
      17             :     const DataDescriptor& Functional<data_t>::getDomainDescriptor() const
      18        9708 :     {
      19        9708 :         return *_domainDescriptor;
      20        9708 :     }
      21             : 
      22             :     template <typename data_t>
      23             :     bool Functional<data_t>::isDifferentiable() const
      24           0 :     {
      25           0 :         return false;
      26           0 :     }
      27             : 
      28             :     template <typename data_t>
      29             :     bool Functional<data_t>::isProxFriendly() const
      30           0 :     {
      31           0 :         return false;
      32           0 :     }
      33             : 
      34             :     template <typename data_t>
      35             :     bool Functional<data_t>::hasProxDual() const
      36           0 :     {
      37           0 :         return isProxFriendly();
      38           0 :     }
      39             : 
      40             :     template <typename data_t>
      41             :     data_t Functional<data_t>::evaluate(const DataContainer<data_t>& x) const
      42        2606 :     {
      43             :         // TODO: This should compare descriptors shouldn't it?
      44        2606 :         if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients()) {
      45           0 :             throw InvalidArgumentError(
      46           0 :                 "Functional::evaluate: argument size does not match functional");
      47           0 :         }
      48             : 
      49        2606 :         return evaluateImpl(x);
      50        2606 :     }
      51             : 
      52             :     template <class data_t>
      53             :     data_t Functional<data_t>::convexConjugate(const DataContainer<data_t>&) const
      54           0 :     {
      55           0 :         throw Error("Functional: No implementation of convex conjugate");
      56           0 :     }
      57             : 
      58             :     template <typename data_t>
      59             :     DataContainer<data_t> Functional<data_t>::getGradient(const DataContainer<data_t>& x) const
      60        1839 :     {
      61        1839 :         DataContainer<data_t> result(getDomainDescriptor());
      62        1839 :         getGradient(x, result);
      63        1839 :         return result;
      64        1839 :     }
      65             : 
      66             :     template <typename data_t>
      67             :     void Functional<data_t>::getGradient(const DataContainer<data_t>& x,
      68             :                                          DataContainer<data_t>& result) const
      69        4101 :     {
      70        4101 :         if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients()) {
      71           0 :             throw InvalidArgumentError(
      72           0 :                 "Functional::getGradient: argument sizes do not match functional");
      73           0 :         }
      74             : 
      75        4101 :         getGradientImpl(x, result);
      76        4101 :     }
      77             : 
      78             :     template <typename data_t>
      79             :     DataContainer<data_t> Functional<data_t>::proximal(const DataContainer<data_t>&,
      80             :                                                        SelfType_t<data_t>) const
      81           0 :     {
      82           0 :         throw Error("No proximal is implemented for this functional");
      83           0 :     }
      84             : 
      85             :     template <typename data_t>
      86             :     void Functional<data_t>::proximal(const DataContainer<data_t>&, SelfType_t<data_t>,
      87             :                                       DataContainer<data_t>&) const
      88           0 :     {
      89           0 :         throw Error("No proximal is implemented for this functional");
      90           0 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     DataContainer<data_t> Functional<data_t>::proxdual(const DataContainer<data_t>& x,
      94             :                                                        SelfType_t<data_t> tau) const
      95          44 :     {
      96          44 :         auto out = emptylike(x);
      97          44 :         proxdual(x, tau, out);
      98          44 :         return out;
      99          44 :     }
     100             : 
     101             :     template <typename data_t>
     102             :     void Functional<data_t>::proxdual(const DataContainer<data_t>& x, SelfType_t<data_t> tau,
     103             :                                       DataContainer<data_t>& out) const
     104          44 :     {
     105          44 :         if (!isProxFriendly()) {
     106           0 :             throw Error("Cannot compute proximal of convex conjugate via Moreau's identity");
     107           0 :         }
     108             : 
     109             :         // TODO: improve efficiency of this approach
     110          44 :         auto rtau = 1 / tau;
     111          44 :         out = x - tau * proximal(x * rtau, rtau);
     112          44 :     }
     113             : 
     114             :     template <typename data_t>
     115             :     LinearOperator<data_t> Functional<data_t>::getHessian(const DataContainer<data_t>& x) const
     116         116 :     {
     117         116 :         return getHessianImpl(x);
     118         116 :     }
     119             : 
     120             :     template <typename data_t>
     121             :     bool Functional<data_t>::isEqual(const Functional<data_t>& other) const
     122         278 :     {
     123         278 :         return !static_cast<bool>(*_domainDescriptor != *other._domainDescriptor);
     124         278 :     }
     125             : 
     126             :     // ------------------------------------------
     127             :     // FunctionalSum
     128             :     template <class data_t>
     129             :     FunctionalSum<data_t>::FunctionalSum(const Functional<data_t>& lhs,
     130             :                                          const Functional<data_t>& rhs)
     131             :         : Functional<data_t>(lhs.getDomainDescriptor()), lhs_(lhs.clone()), rhs_(rhs.clone())
     132          88 :     {
     133          88 :         if (lhs_->getDomainDescriptor() != rhs_->getDomainDescriptor()) {
     134           0 :             throw InvalidArgumentError("FunctionalSum: domain descriptors need to be the same");
     135           0 :         }
     136          88 :     }
     137             : 
     138             :     template <class data_t>
     139             :     data_t FunctionalSum<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
     140         144 :     {
     141         144 :         return lhs_->evaluate(Rx) + rhs_->evaluate(Rx);
     142         144 :     }
     143             : 
     144             :     template <class data_t>
     145             :     void FunctionalSum<data_t>::getGradientImpl(const DataContainer<data_t>& Rx,
     146             :                                                 DataContainer<data_t>& out) const
     147         258 :     {
     148         258 :         auto tmp = Rx;
     149         258 :         lhs_->getGradient(Rx, out);
     150         258 :         rhs_->getGradient(tmp, tmp);
     151         258 :         out += tmp;
     152         258 :     }
     153             : 
     154             :     template <class data_t>
     155             :     LinearOperator<data_t>
     156             :         FunctionalSum<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const
     157           0 :     {
     158           0 :         return lhs_->getHessian(Rx) + rhs_->getHessian(Rx);
     159           0 :     }
     160             : 
     161             :     template <class data_t>
     162             :     FunctionalSum<data_t>* FunctionalSum<data_t>::cloneImpl() const
     163          64 :     {
     164          64 :         return new FunctionalSum<data_t>(*lhs_, *rhs_);
     165          64 :     }
     166             : 
     167             :     template <class data_t>
     168             :     bool FunctionalSum<data_t>::isEqual(const Functional<data_t>& other) const
     169          12 :     {
     170          12 :         if (!Functional<data_t>::isEqual(other)) {
     171           0 :             return false;
     172           0 :         }
     173             : 
     174          12 :         auto* fn = downcast<FunctionalSum<data_t>>(&other);
     175          12 :         return static_cast<bool>(fn) && (*lhs_) == (*fn->lhs_) && (*rhs_) == (*fn->rhs_);
     176          12 :     }
     177             : 
     178             :     // ------------------------------------------
     179             :     // FunctionalScalarMul
     180             :     template <class data_t>
     181             :     FunctionalScalarMul<data_t>::FunctionalScalarMul(const Functional<data_t>& fn,
     182             :                                                      SelfType_t<data_t> scalar)
     183             :         : Functional<data_t>(fn.getDomainDescriptor()), fn_(fn.clone()), scalar_(scalar)
     184          40 :     {
     185          40 :     }
     186             : 
     187             :     template <class data_t>
     188             :     data_t FunctionalScalarMul<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
     189          16 :     {
     190          16 :         return scalar_ * fn_->evaluate(Rx);
     191          16 :     }
     192             : 
     193             :     template <class data_t>
     194             :     data_t FunctionalScalarMul<data_t>::convexConjugate(const DataContainer<data_t>& x) const
     195           0 :     {
     196           0 :         return scalar_ * fn_->evaluate(x / scalar_);
     197           0 :     }
     198             : 
     199             :     template <class data_t>
     200             :     void FunctionalScalarMul<data_t>::getGradientImpl(const DataContainer<data_t>& Rx,
     201             :                                                       DataContainer<data_t>& out) const
     202         116 :     {
     203         116 :         fn_->getGradient(Rx, out);
     204         116 :         out *= scalar_;
     205         116 :     }
     206             : 
     207             :     template <class data_t>
     208             :     LinearOperator<data_t>
     209             :         FunctionalScalarMul<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const
     210           0 :     {
     211           0 :         return scalar_ * fn_->getHessian(Rx);
     212           0 :     }
     213             : 
     214             :     template <typename data_t>
     215             :     bool FunctionalScalarMul<data_t>::isProxFriendly() const
     216           0 :     {
     217           0 :         return true;
     218           0 :     }
     219             : 
     220             :     template <typename data_t>
     221             :     DataContainer<data_t> FunctionalScalarMul<data_t>::proximal(const DataContainer<data_t>& v,
     222             :                                                                 SelfType_t<data_t> t) const
     223           0 :     {
     224             :         // If scalar is zero, this is equal to the zero functional, and hence the identity proximal
     225             :         // operator
     226           0 :         if (scalar_ == 0) {
     227           0 :             return v;
     228           0 :         }
     229           0 :         return fn_->proximal(v, t * scalar_);
     230           0 :     }
     231             : 
     232             :     template <typename data_t>
     233             :     void FunctionalScalarMul<data_t>::proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t,
     234             :                                                DataContainer<data_t>& out) const
     235           0 :     {
     236           0 :         if (scalar_ == 0) {
     237           0 :             return;
     238           0 :         }
     239           0 :         return fn_->proximal(v, t * scalar_, out);
     240           0 :     }
     241             : 
     242             :     template <class data_t>
     243             :     FunctionalScalarMul<data_t>* FunctionalScalarMul<data_t>::cloneImpl() const
     244          28 :     {
     245          28 :         return new FunctionalScalarMul<data_t>(*fn_, scalar_);
     246          28 :     }
     247             : 
     248             :     template <class data_t>
     249             :     bool FunctionalScalarMul<data_t>::isEqual(const Functional<data_t>& other) const
     250           8 :     {
     251           8 :         if (!Functional<data_t>::isEqual(other)) {
     252           0 :             return false;
     253           0 :         }
     254             : 
     255           8 :         auto* fn = downcast<FunctionalScalarMul<data_t>>(&other);
     256           8 :         return static_cast<bool>(fn) && (*fn_) == (*fn->fn_) && scalar_ == fn->scalar_;
     257           8 :     }
     258             : 
     259             :     // ------------------------------------------
     260             :     // explicit template instantiation
     261             :     template class Functional<float>;
     262             :     template class Functional<double>;
     263             :     template class Functional<complex<float>>;
     264             :     template class Functional<complex<double>>;
     265             : 
     266             :     template class FunctionalSum<float>;
     267             :     template class FunctionalSum<double>;
     268             :     template class FunctionalSum<complex<float>>;
     269             :     template class FunctionalSum<complex<double>>;
     270             : 
     271             :     template class FunctionalScalarMul<float>;
     272             :     template class FunctionalScalarMul<double>;
     273             :     template class FunctionalScalarMul<complex<float>>;
     274             :     template class FunctionalScalarMul<complex<double>>;
     275             : } // namespace elsa

Generated by: LCOV version 1.14