LCOV - code coverage report
Current view: top level - elsa/functionals - L2Squared.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 62 74 83.8 %
Date: 2024-05-16 04:22:26 Functions: 24 28 85.7 %

          Line data    Source code
       1             : #include "L2Squared.h"
       2             : 
       3             : #include "DataContainer.h"
       4             : #include "DataDescriptor.h"
       5             : #include "Identity.h"
       6             : #include "LinearOperator.h"
       7             : #include "TypeCasts.hpp"
       8             : #include "ProximalL2Squared.h"
       9             : 
      10             : namespace elsa
      11             : {
      12             :     template <typename data_t>
      13             :     L2Squared<data_t>::L2Squared(const DataDescriptor& domainDescriptor)
      14             :         : Functional<data_t>(domainDescriptor)
      15          96 :     {
      16          96 :     }
      17             : 
      18             :     template <typename data_t>
      19             :     L2Squared<data_t>::L2Squared(const DataContainer<data_t>& b)
      20             :         : Functional<data_t>(b.getDataDescriptor()), b_(b)
      21          16 :     {
      22          16 :     }
      23             : 
      24             :     template <typename data_t>
      25             :     bool L2Squared<data_t>::isDifferentiable() const
      26           0 :     {
      27           0 :         return true;
      28           0 :     }
      29             : 
      30             :     template <typename data_t>
      31             :     bool L2Squared<data_t>::isProxFriendly() const
      32           4 :     {
      33           4 :         return true;
      34           4 :     }
      35             : 
      36             :     template <typename data_t>
      37             :     bool L2Squared<data_t>::hasDataVector() const
      38         448 :     {
      39         448 :         return b_.has_value();
      40         448 :     }
      41             : 
      42             :     template <typename data_t>
      43             :     const DataContainer<data_t>& L2Squared<data_t>::getDataVector() const
      44           0 :     {
      45           0 :         if (!hasDataVector()) {
      46           0 :             throw Error("L2Squared: No data vector present");
      47           0 :         }
      48             : 
      49           0 :         return *b_;
      50           0 :     }
      51             : 
      52             :     template <typename data_t>
      53             :     data_t L2Squared<data_t>::evaluateImpl(const DataContainer<data_t>& x) const
      54         124 :     {
      55         124 :         if (!hasDataVector()) {
      56         122 :             return data_t{0.5} * x.squaredL2Norm();
      57         122 :         }
      58             : 
      59           2 :         return data_t{0.5} * (x - *b_).squaredL2Norm();
      60           2 :     }
      61             : 
      62             :     template <typename data_t>
      63             :     data_t L2Squared<data_t>::convexConjugate(const DataContainer<data_t>& x) const
      64           4 :     {
      65           4 :         auto res = 0.25 * x.squaredL2Norm();
      66             : 
      67           4 :         if (hasDataVector()) {
      68           2 :             res += x.dot(*b_);
      69           2 :         }
      70             : 
      71           4 :         return res;
      72           4 :     }
      73             : 
      74             :     template <typename data_t>
      75             :     void L2Squared<data_t>::getGradientImpl(const DataContainer<data_t>& x,
      76             :                                             DataContainer<data_t>& out) const
      77         244 :     {
      78         244 :         if (!hasDataVector()) {
      79         242 :             out = x;
      80         242 :         } else {
      81           2 :             out = x - *b_;
      82           2 :         }
      83         244 :     }
      84             : 
      85             :     template <typename data_t>
      86             :     LinearOperator<data_t> L2Squared<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const
      87           4 :     {
      88           4 :         return leaf(Identity<data_t>(Rx.getDataDescriptor()));
      89           4 :     }
      90             : 
      91             :     template <typename data_t>
      92             :     DataContainer<data_t> L2Squared<data_t>::proximal(const DataContainer<data_t>& v,
      93             :                                                       SelfType_t<data_t> tau) const
      94           8 :     {
      95           8 :         auto out = emptylike(v);
      96           8 :         proximal(v, tau, out);
      97           8 :         return out;
      98           8 :     }
      99             : 
     100             :     template <typename data_t>
     101             :     void L2Squared<data_t>::proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t,
     102             :                                      DataContainer<data_t>& out) const
     103           8 :     {
     104           8 :         if (!hasDataVector()) {
     105           4 :             ProximalL2Squared<data_t> prox;
     106           4 :             prox.apply(v, t, out);
     107           4 :         } else {
     108           4 :             ProximalL2Squared<data_t> prox(*b_);
     109           4 :             prox.apply(v, t, out);
     110           4 :         }
     111           8 :     }
     112             : 
     113             :     template <typename data_t>
     114             :     L2Squared<data_t>* L2Squared<data_t>::cloneImpl() const
     115          68 :     {
     116          68 :         if (!hasDataVector()) {
     117          66 :             return new L2Squared(this->getDomainDescriptor());
     118          66 :         }
     119           2 :         return new L2Squared(*b_);
     120           2 :     }
     121             : 
     122             :     template <typename data_t>
     123             :     bool L2Squared<data_t>::isEqual(const Functional<data_t>& other) const
     124           4 :     {
     125           4 :         if (!Functional<data_t>::isEqual(other))
     126           0 :             return false;
     127             : 
     128           4 :         auto fn = downcast_safe<L2Squared<data_t>>(&other);
     129           4 :         if (!fn) {
     130           0 :             return false;
     131           0 :         }
     132             : 
     133           4 :         if (b_ && fn->b_) {
     134           2 :             return *b_ == *fn->b_;
     135           2 :         }
     136             : 
     137           2 :         return b_ == fn->b_;
     138           2 :     }
     139             : 
     140             :     // ------------------------------------------
     141             :     // explicit template instantiation
     142             :     template class L2Squared<float>;
     143             :     template class L2Squared<double>;
     144             : } // namespace elsa

Generated by: LCOV version 1.14