LCOV - code coverage report
Current view: top level - elsa/functionals - LinearResidual.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 60 85 70.6 %
Date: 2025-01-14 06:38:49 Functions: 60 72 83.3 %

          Line data    Source code
       1             : #include "LinearResidual.h"
       2             : #include "DataContainer.h"
       3             : #include "DataDescriptor.h"
       4             : #include "Identity.h"
       5             : #include "LinearOperator.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : #include <optional>
       9             : #include <stdexcept>
      10             : 
      11             : namespace elsa
      12             : {
      13             :     template <typename data_t>
      14             :     LinearResidual<data_t>::LinearResidual(const DataDescriptor& descriptor)
      15             :         : domainDesc_(descriptor.clone()), rangeDesc_(descriptor.clone())
      16          28 :     {
      17          28 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     LinearResidual<data_t>::LinearResidual(const DataContainer<data_t>& b)
      21             :         : domainDesc_(b.getDataDescriptor().clone()),
      22             :           rangeDesc_(b.getDataDescriptor().clone()),
      23             :           _dataVector(b)
      24          28 :     {
      25          28 :     }
      26             : 
      27             :     template <typename data_t>
      28             :     LinearResidual<data_t>::LinearResidual(const LinearOperator<data_t>& A)
      29             :         : domainDesc_(A.getDomainDescriptor().clone()),
      30             :           rangeDesc_(A.getRangeDescriptor().clone()),
      31             :           _operator(A.clone())
      32          28 :     {
      33          28 :     }
      34             : 
      35             :     template <typename data_t>
      36             :     LinearResidual<data_t>::LinearResidual(const LinearOperator<data_t>& A,
      37             :                                            const DataContainer<data_t>& b)
      38             :         : domainDesc_(A.getDomainDescriptor().clone()),
      39             :           rangeDesc_(A.getRangeDescriptor().clone()),
      40             :           _operator(A.clone()),
      41             :           _dataVector{b}
      42          28 :     {
      43          28 :         if (A.getRangeDescriptor() != b.getDataDescriptor())
      44           0 :             throw InvalidArgumentError("LinearResidual: A and b do not match");
      45          28 :     }
      46             : 
      47             :     namespace detail
      48             :     {
      49             :         template <class data_t>
      50             :         std::unique_ptr<LinearOperator<data_t>> extractOp(LinearOperator<data_t>* op)
      51          16 :         {
      52          16 :             if (op) {
      53           8 :                 return op->clone();
      54           8 :             } else {
      55           8 :                 return nullptr;
      56           8 :             }
      57          16 :         }
      58             :     } // namespace detail
      59             : 
      60             :     template <typename data_t>
      61             :     LinearResidual<data_t>::LinearResidual(const LinearResidual<data_t>& other)
      62             :         : domainDesc_(other.domainDesc_->clone()),
      63             :           rangeDesc_(other.rangeDesc_->clone()),
      64             :           _operator(detail::extractOp(other._operator.get())),
      65             :           _dataVector(other._dataVector)
      66          16 :     {
      67          16 :     }
      68             : 
      69             :     template <typename data_t>
      70             :     LinearResidual<data_t>& LinearResidual<data_t>::operator=(const LinearResidual<data_t>& other)
      71           0 :     {
      72           0 :         domainDesc_ = other.domainDesc_->clone();
      73           0 :         rangeDesc_ = other.rangeDesc_->clone();
      74             : 
      75           0 :         if (other.hasOperator()) {
      76           0 :             _operator = other._operator->clone();
      77           0 :         } else {
      78           0 :             _operator = nullptr;
      79           0 :         }
      80             : 
      81           0 :         if (other.hasDataVector()) {
      82           0 :             _dataVector = other.getDataVector();
      83           0 :         } else {
      84           0 :             _dataVector = std::nullopt;
      85           0 :         }
      86           0 :         return *this;
      87           0 :     }
      88             : 
      89             :     template <typename data_t>
      90             :     LinearResidual<data_t>::LinearResidual(LinearResidual<data_t>&& other) noexcept
      91             :         : domainDesc_(std::move(other.domainDesc_)),
      92             :           rangeDesc_(std::move(other.rangeDesc_)),
      93             :           _operator(std::move(other._operator)),
      94             :           _dataVector(std::move(other._dataVector))
      95           0 :     {
      96           0 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     LinearResidual<data_t>&
     100             :         LinearResidual<data_t>::operator=(LinearResidual<data_t>&& other) noexcept
     101           0 :     {
     102           0 :         domainDesc_ = std::move(other.domainDesc_);
     103           0 :         rangeDesc_ = std::move(other.rangeDesc_);
     104           0 :         _operator = std::move(other._operator);
     105           0 :         _dataVector = std::move(other._dataVector);
     106             : 
     107           0 :         return *this;
     108           0 :     }
     109             : 
     110             :     template <typename data_t>
     111             :     const DataDescriptor& LinearResidual<data_t>::getDomainDescriptor() const
     112          80 :     {
     113          80 :         return *domainDesc_;
     114          80 :     }
     115             : 
     116             :     template <typename data_t>
     117             :     const DataDescriptor& LinearResidual<data_t>::getRangeDescriptor() const
     118         120 :     {
     119         120 :         return *rangeDesc_;
     120         120 :     }
     121             : 
     122             :     template <typename data_t>
     123             :     bool LinearResidual<data_t>::hasOperator() const
     124         188 :     {
     125         188 :         return static_cast<bool>(_operator);
     126         188 :     }
     127             : 
     128             :     template <typename data_t>
     129             :     bool LinearResidual<data_t>::hasDataVector() const
     130         140 :     {
     131         140 :         return _dataVector.has_value();
     132         140 :     }
     133             : 
     134             :     template <typename data_t>
     135             :     const LinearOperator<data_t>& LinearResidual<data_t>::getOperator() const
     136          80 :     {
     137          80 :         if (!_operator)
     138           8 :             throw Error("LinearResidual::getOperator: operator not present");
     139             : 
     140          72 :         return *_operator;
     141          72 :     }
     142             : 
     143             :     template <typename data_t>
     144             :     const DataContainer<data_t>& LinearResidual<data_t>::getDataVector() const
     145          72 :     {
     146          72 :         if (!_dataVector)
     147           8 :             throw Error("LinearResidual::getDataVector: data vector not present");
     148             : 
     149          64 :         return *_dataVector;
     150          64 :     }
     151             : 
     152             :     template <typename data_t>
     153             :     DataContainer<data_t> LinearResidual<data_t>::evaluate(const DataContainer<data_t>& x) const
     154          32 :     {
     155          32 :         DataContainer<data_t> out(this->getRangeDescriptor());
     156          32 :         evaluate(x, out);
     157          32 :         return out;
     158          32 :     }
     159             : 
     160             :     template <typename data_t>
     161             :     void LinearResidual<data_t>::evaluate(const DataContainer<data_t>& x,
     162             :                                           DataContainer<data_t>& result) const
     163          32 :     {
     164          32 :         if (hasOperator())
     165          16 :             _operator->apply(x, result);
     166          16 :         else
     167          16 :             result = x;
     168             : 
     169          32 :         if (hasDataVector()) {
     170          16 :             result -= *_dataVector;
     171          16 :         }
     172          32 :     }
     173             : 
     174             :     template <typename data_t>
     175             :     LinearOperator<data_t>
     176             :         LinearResidual<data_t>::getJacobian([[maybe_unused]] const DataContainer<data_t>& x)
     177          16 :     {
     178          16 :         if (hasOperator())
     179           8 :             return leaf(*_operator);
     180           8 :         else
     181           8 :             return leaf(Identity<data_t>(this->getRangeDescriptor()));
     182          16 :     }
     183             : 
     184             :     // ------------------------------------------
     185             :     // explicit template instantiation
     186             :     template class LinearResidual<float>;
     187             :     template class LinearResidual<double>;
     188             :     template class LinearResidual<complex<float>>;
     189             :     template class LinearResidual<complex<double>>;
     190             : 
     191             : } // namespace elsa

Generated by: LCOV version 1.14