LCOV - code coverage report
Current view: top level - elsa/functionals - Quadric.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 51 57 89.5 %
Date: 2024-05-16 04:22:26 Functions: 40 44 90.9 %

          Line data    Source code
       1             : #include "Quadric.h"
       2             : #include "DataContainer.h"
       3             : #include "Identity.h"
       4             : #include "TypeCasts.hpp"
       5             : 
       6             : #include <stdexcept>
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     Quadric<data_t>::Quadric(const LinearOperator<data_t>& A, const DataContainer<data_t>& b)
      12             :         : Functional<data_t>(A.getDomainDescriptor()), linResidual_{A, b}
      13          16 :     {
      14          16 :     }
      15             : 
      16             :     template <typename data_t>
      17             :     Quadric<data_t>::Quadric(const LinearOperator<data_t>& A)
      18             :         : Functional<data_t>(A.getDomainDescriptor()), linResidual_{A}
      19          16 :     {
      20          16 :     }
      21             : 
      22             :     template <typename data_t>
      23             :     Quadric<data_t>::Quadric(const DataContainer<data_t>& b)
      24             :         : Functional<data_t>(b.getDataDescriptor()), linResidual_{b}
      25          16 :     {
      26          16 :     }
      27             : 
      28             :     template <typename data_t>
      29             :     Quadric<data_t>::Quadric(const DataDescriptor& domainDescriptor)
      30             :         : Functional<data_t>(domainDescriptor), linResidual_{domainDescriptor}
      31          16 :     {
      32          16 :     }
      33             : 
      34             :     template <typename data_t>
      35             :     bool Quadric<data_t>::isDifferentiable() const
      36           0 :     {
      37           0 :         return true;
      38           0 :     }
      39             : 
      40             :     template <typename data_t>
      41             :     const LinearResidual<data_t>& Quadric<data_t>::getGradientExpression() const
      42          16 :     {
      43          16 :         return linResidual_;
      44          16 :     }
      45             : 
      46             :     template <typename data_t>
      47             :     data_t Quadric<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
      48          16 :     {
      49          16 :         data_t xtAx;
      50             : 
      51          16 :         if (linResidual_.hasOperator()) {
      52           8 :             auto temp = linResidual_.getOperator().apply(Rx);
      53           8 :             xtAx = Rx.dot(temp);
      54           8 :         } else {
      55           8 :             xtAx = Rx.squaredL2Norm();
      56           8 :         }
      57             : 
      58          16 :         if (linResidual_.hasDataVector()) {
      59           8 :             return static_cast<data_t>(0.5) * xtAx - Rx.dot(linResidual_.getDataVector());
      60           8 :         } else {
      61           8 :             return static_cast<data_t>(0.5) * xtAx;
      62           8 :         }
      63          16 :     }
      64             : 
      65             :     template <typename data_t>
      66             :     void Quadric<data_t>::getGradientImpl(const DataContainer<data_t>& Rx,
      67             :                                           DataContainer<data_t>& out) const
      68          16 :     {
      69          16 :         out = linResidual_.evaluate(Rx);
      70          16 :     }
      71             : 
      72             :     template <typename data_t>
      73             :     LinearOperator<data_t>
      74             :         Quadric<data_t>::getHessianImpl([[maybe_unused]] const DataContainer<data_t>& Rx) const
      75          16 :     {
      76          16 :         if (linResidual_.hasOperator())
      77           8 :             return leaf(linResidual_.getOperator());
      78           8 :         else
      79           8 :             return leaf(Identity<data_t>(this->getDomainDescriptor()));
      80          16 :     }
      81             : 
      82             :     template <typename data_t>
      83             :     Quadric<data_t>* Quadric<data_t>::cloneImpl() const
      84          16 :     {
      85          16 :         if (linResidual_.hasOperator() && linResidual_.hasDataVector())
      86           4 :             return new Quadric<data_t>(linResidual_.getOperator(), linResidual_.getDataVector());
      87          12 :         else if (linResidual_.hasOperator() && !linResidual_.hasDataVector())
      88           4 :             return new Quadric<data_t>(linResidual_.getOperator());
      89           8 :         else if (!linResidual_.hasOperator() && linResidual_.hasDataVector())
      90           4 :             return new Quadric<data_t>(linResidual_.getDataVector());
      91           4 :         else
      92           4 :             return new Quadric<data_t>(this->getDomainDescriptor());
      93          16 :     }
      94             : 
      95             :     template <typename data_t>
      96             :     bool Quadric<data_t>::isEqual(const Functional<data_t>& other) const
      97          16 :     {
      98          16 :         if (!Functional<data_t>::isEqual(other))
      99           0 :             return false;
     100             : 
     101          16 :         auto otherQuadric = downcast_safe<Quadric>(&other);
     102          16 :         if (!otherQuadric)
     103           0 :             return false;
     104             : 
     105          16 :         if (linResidual_ != otherQuadric->linResidual_)
     106           0 :             return false;
     107             : 
     108          16 :         return true;
     109          16 :     }
     110             : 
     111             :     // ------------------------------------------
     112             :     // explicit template instantiation
     113             :     template class Quadric<float>;
     114             :     template class Quadric<double>;
     115             :     template class Quadric<complex<float>>;
     116             :     template class Quadric<complex<double>>;
     117             : 
     118             : } // namespace elsa

Generated by: LCOV version 1.14