LCOV - code coverage report
Current view: top level - functionals - Functional.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 53 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 36 0.0 %

          Line data    Source code
       1             : #include "Functional.h"
       2             : #include "LinearResidual.h"
       3             : #include "TypeCasts.hpp"
       4             : 
       5             : #include <stdexcept>
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename data_t>
      10           0 :     Functional<data_t>::Functional(const DataDescriptor& domainDescriptor)
      11           0 :         : _domainDescriptor{domainDescriptor.clone()},
      12           0 :           _residual{std::make_unique<LinearResidual<data_t>>(domainDescriptor)}
      13             :     {
      14           0 :     }
      15             : 
      16             :     template <typename data_t>
      17           0 :     Functional<data_t>::Functional(const Residual<data_t>& residual)
      18           0 :         : _domainDescriptor{residual.getDomainDescriptor().clone()}, _residual{residual.clone()}
      19             :     {
      20           0 :     }
      21             : 
      22             :     template <typename data_t>
      23           0 :     const DataDescriptor& Functional<data_t>::getDomainDescriptor() const
      24             :     {
      25           0 :         return *_domainDescriptor;
      26             :     }
      27             : 
      28             :     template <typename data_t>
      29           0 :     const Residual<data_t>& Functional<data_t>::getResidual() const
      30             :     {
      31           0 :         return *_residual;
      32             :     }
      33             : 
      34             :     template <typename data_t>
      35           0 :     data_t Functional<data_t>::evaluate(const DataContainer<data_t>& x)
      36             :     {
      37           0 :         if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients())
      38           0 :             throw InvalidArgumentError(
      39             :                 "Functional::evaluate: argument size does not match functional");
      40             : 
      41             :         // optimize for trivial LinearResiduals (no extra copy for residual result needed then)
      42           0 :         if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) {
      43           0 :             if (!linearResidual->hasOperator() && !linearResidual->hasDataVector())
      44           0 :                 return evaluateImpl(x);
      45             :         }
      46             : 
      47             :         // in all other cases: evaluate the residual first, then call our virtual evaluateImpl
      48           0 :         return evaluateImpl(_residual->evaluate(x));
      49             :     }
      50             : 
      51             :     template <typename data_t>
      52           0 :     DataContainer<data_t> Functional<data_t>::getGradient(const DataContainer<data_t>& x)
      53             :     {
      54           0 :         DataContainer<data_t> result(_residual->getRangeDescriptor(), x.getDataHandlerType());
      55           0 :         getGradient(x, result);
      56           0 :         return result;
      57           0 :     }
      58             : 
      59             :     template <typename data_t>
      60           0 :     void Functional<data_t>::getGradient(const DataContainer<data_t>& x,
      61             :                                          DataContainer<data_t>& result)
      62             :     {
      63           0 :         if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients()
      64           0 :             || result.getSize() != _residual->getDomainDescriptor().getNumberOfCoefficients())
      65           0 :             throw InvalidArgumentError(
      66             :                 "Functional::getGradient: argument sizes do not match functional");
      67             : 
      68             :         // optimize for trivial or simple LinearResiduals
      69           0 :         if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) {
      70             :             // if trivial, no extra copy for residual result needed (and no chain rule)
      71           0 :             if (!linearResidual->hasOperator() && !linearResidual->hasDataVector()) {
      72           0 :                 result = x;
      73           0 :                 getGradientInPlaceImpl(result);
      74           0 :                 return;
      75             :             }
      76             : 
      77             :             // if no operator, no need for chain rule
      78           0 :             if (!linearResidual->hasOperator()) {
      79           0 :                 linearResidual->evaluate(x,
      80             :                                          result); // sizes of x and result will match in this case
      81           0 :                 getGradientInPlaceImpl(result);
      82           0 :                 return;
      83             :             }
      84             :         }
      85             : 
      86             :         // the general case
      87           0 :         auto temp = _residual->evaluate(x);
      88           0 :         getGradientInPlaceImpl(temp);
      89           0 :         _residual->getJacobian(x).applyAdjoint(temp, result); // apply the chain rule
      90           0 :     }
      91             : 
      92             :     template <typename data_t>
      93           0 :     LinearOperator<data_t> Functional<data_t>::getHessian(const DataContainer<data_t>& x)
      94             :     {
      95             :         // optimize for trivial and simple LinearResiduals
      96           0 :         if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) {
      97             :             // if trivial, no extra copy for residual result needed (and no chain rule)
      98           0 :             if (!linearResidual->hasOperator() && !linearResidual->hasDataVector())
      99           0 :                 return getHessianImpl(x);
     100             : 
     101             :             // if no operator, no need for chain rule
     102           0 :             if (!linearResidual->hasOperator())
     103           0 :                 return getHessianImpl(_residual->evaluate(x));
     104             :         }
     105             : 
     106             :         // the general case (with chain rule)
     107           0 :         auto jacobian = _residual->getJacobian(x);
     108           0 :         auto hessian = adjoint(jacobian) * (getHessianImpl(_residual->evaluate(x))) * (jacobian);
     109           0 :         return hessian;
     110           0 :     }
     111             : 
     112             :     template <typename data_t>
     113           0 :     bool Functional<data_t>::isEqual(const Functional<data_t>& other) const
     114             :     {
     115           0 :         return !static_cast<bool>(*_domainDescriptor != *other._domainDescriptor
     116           0 :                                   || *_residual != *other._residual);
     117             :     }
     118             : 
     119             :     // ------------------------------------------
     120             :     // explicit template instantiation
     121             :     template class Functional<float>;
     122             :     template class Functional<double>;
     123             :     template class Functional<complex<float>>;
     124             :     template class Functional<complex<double>>;
     125             : } // namespace elsa

Generated by: LCOV version 1.14