LCOV - code coverage report
Current view: top level - elsa/functionals - Functional.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 52 62 83.9 %
Date: 2022-08-25 03:05:39 Functions: 36 36 100.0 %

          Line data    Source code
       1             : #include "Functional.h"
       2             : #include "LinearResidual.h"
       3             : #include "TypeCasts.hpp"
       4             : 
       5             : #include <stdexcept>
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename data_t>
      10             :     Functional<data_t>::Functional(const DataDescriptor& domainDescriptor)
      11             :         : _domainDescriptor{domainDescriptor.clone()},
      12             :           _residual{std::make_unique<LinearResidual<data_t>>(domainDescriptor)}
      13         836 :     {
      14         836 :     }
      15             : 
      16             :     template <typename data_t>
      17             :     Functional<data_t>::Functional(const Residual<data_t>& residual)
      18             :         : _domainDescriptor{residual.getDomainDescriptor().clone()}, _residual{residual.clone()}
      19        1864 :     {
      20        1864 :     }
      21             : 
      22             :     template <typename data_t>
      23             :     const DataDescriptor& Functional<data_t>::getDomainDescriptor() const
      24       11855 :     {
      25       11855 :         return *_domainDescriptor;
      26       11855 :     }
      27             : 
      28             :     template <typename data_t>
      29             :     const Residual<data_t>& Functional<data_t>::getResidual() const
      30        3547 :     {
      31        3547 :         return *_residual;
      32        3547 :     }
      33             : 
      34             :     template <typename data_t>
      35             :     data_t Functional<data_t>::evaluate(const DataContainer<data_t>& x)
      36         156 :     {
      37         156 :         if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients())
      38           0 :             throw InvalidArgumentError(
      39           0 :                 "Functional::evaluate: argument size does not match functional");
      40             : 
      41             :         // optimize for trivial LinearResiduals (no extra copy for residual result needed then)
      42         156 :         if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) {
      43         156 :             if (!linearResidual->hasOperator() && !linearResidual->hasDataVector())
      44          98 :                 return evaluateImpl(x);
      45          58 :         }
      46             : 
      47             :         // in all other cases: evaluate the residual first, then call our virtual evaluateImpl
      48          58 :         return evaluateImpl(_residual->evaluate(x));
      49          58 :     }
      50             : 
      51             :     template <typename data_t>
      52             :     DataContainer<data_t> Functional<data_t>::getGradient(const DataContainer<data_t>& x)
      53         122 :     {
      54         122 :         DataContainer<data_t> result(_residual->getDomainDescriptor(), x.getDataHandlerType());
      55         122 :         getGradient(x, result);
      56         122 :         return result;
      57         122 :     }
      58             : 
      59             :     template <typename data_t>
      60             :     void Functional<data_t>::getGradient(const DataContainer<data_t>& x,
      61             :                                          DataContainer<data_t>& result)
      62       11119 :     {
      63       11119 :         if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients()
      64       11119 :             || result.getSize() != _residual->getDomainDescriptor().getNumberOfCoefficients())
      65           0 :             throw InvalidArgumentError(
      66           0 :                 "Functional::getGradient: argument sizes do not match functional");
      67             : 
      68             :         // optimize for trivial or simple LinearResiduals
      69       11119 :         if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) {
      70             :             // if trivial, no extra copy for residual result needed (and no chain rule)
      71       11119 :             if (!linearResidual->hasOperator() && !linearResidual->hasDataVector()) {
      72         299 :                 result = x;
      73         299 :                 getGradientInPlaceImpl(result);
      74         299 :                 return;
      75         299 :             }
      76             : 
      77             :             // if no operator, no need for chain rule
      78       10820 :             if (!linearResidual->hasOperator()) {
      79           0 :                 linearResidual->evaluate(x,
      80           0 :                                          result); // sizes of x and result will match in this case
      81           0 :                 getGradientInPlaceImpl(result);
      82           0 :                 return;
      83           0 :             }
      84       10820 :         }
      85             : 
      86             :         // the general case
      87       10820 :         auto temp = _residual->evaluate(x);
      88       10820 :         getGradientInPlaceImpl(temp);
      89       10820 :         _residual->getJacobian(x).applyAdjoint(temp, result); // apply the chain rule
      90       10820 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     LinearOperator<data_t> Functional<data_t>::getHessian(const DataContainer<data_t>& x)
      94         193 :     {
      95             :         // optimize for trivial and simple LinearResiduals
      96         193 :         if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) {
      97             :             // if trivial, no extra copy for residual result needed (and no chain rule)
      98         193 :             if (!linearResidual->hasOperator() && !linearResidual->hasDataVector())
      99         102 :                 return getHessianImpl(x);
     100             : 
     101             :             // if no operator, no need for chain rule
     102          91 :             if (!linearResidual->hasOperator())
     103           0 :                 return getHessianImpl(_residual->evaluate(x));
     104          91 :         }
     105             : 
     106             :         // the general case (with chain rule)
     107          91 :         auto jacobian = _residual->getJacobian(x);
     108          91 :         auto hessian = adjoint(jacobian) * (getHessianImpl(_residual->evaluate(x))) * (jacobian);
     109          91 :         return hessian;
     110          91 :     }
     111             : 
     112             :     template <typename data_t>
     113             :     bool Functional<data_t>::isEqual(const Functional<data_t>& other) const
     114         183 :     {
     115         183 :         return !static_cast<bool>(*_domainDescriptor != *other._domainDescriptor
     116         183 :                                   || *_residual != *other._residual);
     117         183 :     }
     118             : 
     119             :     // ------------------------------------------
     120             :     // explicit template instantiation
     121             :     template class Functional<float>;
     122             :     template class Functional<double>;
     123             :     template class Functional<complex<float>>;
     124             :     template class Functional<complex<double>>;
     125             : } // namespace elsa

Generated by: LCOV version 1.14