LCOV - code coverage report
Current view: top level - elsa/core - LinearOperator.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 211 219 96.3 %
Date: 2025-01-02 06:42:49 Functions: 64 64 100.0 %

          Line data    Source code
       1             : #include "LinearOperator.h"
       2             : 
       3             : #include <stdexcept>
       4             : #include <typeinfo>
       5             : 
       6             : #include "DescriptorUtils.h"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     LinearOperator<data_t>::LinearOperator(const DataDescriptor& domainDescriptor,
      12             :                                            const DataDescriptor& rangeDescriptor)
      13             :         : _domainDescriptor{domainDescriptor.clone()}, _rangeDescriptor{rangeDescriptor.clone()}
      14       11595 :     {
      15       11595 :     }
      16             : 
      17             :     template <typename data_t>
      18             :     LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& other)
      19             :         : Cloneable<LinearOperator<data_t>>(),
      20             :           _domainDescriptor{other._domainDescriptor->clone()},
      21             :           _rangeDescriptor{other._rangeDescriptor->clone()},
      22             :           _scalar{other._scalar},
      23             :           _isLeaf{other._isLeaf},
      24             :           _isAdjoint{other._isAdjoint},
      25             :           _isComposite{other._isComposite},
      26             :           _mode{other._mode}
      27        1093 :     {
      28        1093 :         if (_isLeaf)
      29           8 :             _lhs = other._lhs->clone();
      30             : 
      31        1093 :         if (_isComposite) {
      32         224 :             if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
      33          12 :                 _lhs = other._lhs->clone();
      34          12 :                 _rhs = other._rhs->clone();
      35          12 :             }
      36             : 
      37         224 :             if (_mode == CompositeMode::SCALAR_MULT) {
      38         212 :                 _rhs = other._rhs->clone();
      39         212 :             }
      40         224 :         }
      41        1093 :     }
      42             : 
      43             :     template <typename data_t>
      44             :     LinearOperator<data_t>& LinearOperator<data_t>::operator=(const LinearOperator<data_t>& other)
      45          24 :     {
      46          24 :         if (*this != other) {
      47          24 :             _domainDescriptor = other._domainDescriptor->clone();
      48          24 :             _rangeDescriptor = other._rangeDescriptor->clone();
      49          24 :             _scalar = other._scalar;
      50          24 :             _isLeaf = other._isLeaf;
      51          24 :             _isAdjoint = other._isAdjoint;
      52          24 :             _isComposite = other._isComposite;
      53          24 :             _mode = other._mode;
      54             : 
      55          24 :             if (_isLeaf)
      56           8 :                 _lhs = other._lhs->clone();
      57             : 
      58          24 :             if (_isComposite) {
      59          16 :                 if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
      60          12 :                     _lhs = other._lhs->clone();
      61          12 :                     _rhs = other._rhs->clone();
      62          12 :                 }
      63             : 
      64          16 :                 if (_mode == CompositeMode::SCALAR_MULT) {
      65           4 :                     _rhs = other._rhs->clone();
      66           4 :                 }
      67          16 :             }
      68          24 :         }
      69             : 
      70          24 :         return *this;
      71          24 :     }
      72             : 
      73             :     template <typename data_t>
      74             :     const DataDescriptor& LinearOperator<data_t>::getDomainDescriptor() const
      75       10163 :     {
      76       10163 :         return *_domainDescriptor;
      77       10163 :     }
      78             : 
      79             :     template <typename data_t>
      80             :     const DataDescriptor& LinearOperator<data_t>::getRangeDescriptor() const
      81        7316 :     {
      82        7316 :         return *_rangeDescriptor;
      83        7316 :     }
      84             : 
      85             :     template <typename data_t>
      86             :     DataContainer<data_t> LinearOperator<data_t>::apply(const DataContainer<data_t>& x) const
      87       10586 :     {
      88       10586 :         DataContainer<data_t> result(*_rangeDescriptor);
      89       10586 :         apply(x, result);
      90       10586 :         return result;
      91       10586 :     }
      92             : 
      93             :     template <typename data_t>
      94             :     void LinearOperator<data_t>::apply(const DataContainer<data_t>& x,
      95             :                                        DataContainer<data_t>& Ax) const
      96       19741 :     {
      97       19741 :         applyImpl(x, Ax);
      98       19741 :     }
      99             : 
     100             :     template <typename data_t>
     101             :     void LinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x,
     102             :                                            DataContainer<data_t>& Ax) const
     103         854 :     {
     104         854 :         if (_isLeaf) {
     105         380 :             if (_isAdjoint) {
     106             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     107         328 :                 if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != x.getSize()
     108         328 :                     || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Ax.getSize())
     109           4 :                     throw InvalidArgumentError(
     110           4 :                         "LinearOperator::apply: incorrect input/output sizes for adjoint leaf");
     111             : 
     112         324 :                 _lhs->applyAdjoint(x, Ax);
     113         324 :                 return;
     114         324 :             } else {
     115             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     116          52 :                 if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
     117          52 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
     118           4 :                     throw InvalidArgumentError(
     119           4 :                         "LinearOperator::apply: incorrect input/output sizes for leaf");
     120             : 
     121          48 :                 _lhs->apply(x, Ax);
     122          48 :                 return;
     123          48 :             }
     124         380 :         }
     125             : 
     126         474 :         if (_isComposite) {
     127         470 :             if (_mode == CompositeMode::ADD) {
     128             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     129          12 :                 if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
     130          12 :                     || _rhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize()
     131          12 :                     || _lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
     132          12 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
     133           4 :                     throw InvalidArgumentError(
     134           4 :                         "LinearOperator::apply: incorrect input/output sizes for add leaf");
     135             : 
     136           8 :                 _rhs->apply(x, Ax);
     137           8 :                 Ax += _lhs->apply(x);
     138           8 :                 return;
     139           8 :             }
     140             : 
     141         458 :             if (_mode == CompositeMode::MULT) {
     142             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     143         402 :                 if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
     144         402 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
     145           8 :                     throw InvalidArgumentError(
     146           8 :                         "LinearOperator::apply: incorrect input/output sizes for mult leaf");
     147             : 
     148         394 :                 DataContainer<data_t> temp(_rhs->getRangeDescriptor());
     149         394 :                 _rhs->apply(x, temp);
     150         394 :                 _lhs->apply(temp, Ax);
     151         394 :                 return;
     152         394 :             }
     153             : 
     154          56 :             if (_mode == CompositeMode::SCALAR_MULT) {
     155             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     156          56 :                 if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize())
     157           4 :                     throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
     158           4 :                                                "sizes for scalar mult. leaf");
     159             :                 // sanity check the scalar in the optional
     160          52 :                 if (!_scalar.has_value())
     161           0 :                     throw InvalidArgumentError(
     162           0 :                         "LinearOperator::apply: no value found in the scalar optional");
     163             : 
     164          52 :                 _rhs->apply(x, Ax);
     165          52 :                 Ax *= _scalar.value();
     166          52 :                 return;
     167          52 :             }
     168          56 :         }
     169             : 
     170           4 :         throw LogicError("LinearOperator: apply called on ill-formed object");
     171           4 :     }
     172             : 
     173             :     template <typename data_t>
     174             :     DataContainer<data_t> LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y) const
     175         731 :     {
     176         731 :         DataContainer<data_t> result(*_domainDescriptor);
     177         731 :         applyAdjoint(y, result);
     178         731 :         return result;
     179         731 :     }
     180             : 
     181             :     template <typename data_t>
     182             :     void LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y,
     183             :                                               DataContainer<data_t>& Aty) const
     184        5241 :     {
     185        5241 :         applyAdjointImpl(y, Aty);
     186        5241 :     }
     187             : 
     188             :     template <typename data_t>
     189             :     void LinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
     190             :                                                   DataContainer<data_t>& Aty) const
     191         198 :     {
     192         198 :         if (_isLeaf) {
     193          20 :             if (_isAdjoint) {
     194             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     195          12 :                 if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != y.getSize()
     196          12 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Aty.getSize())
     197           4 :                     throw InvalidArgumentError("LinearOperator::applyAdjoint: incorrect "
     198           4 :                                                "input/output sizes for adjoint leaf");
     199             : 
     200           8 :                 _lhs->apply(y, Aty);
     201           8 :                 return;
     202           8 :             } else {
     203             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     204           8 :                 if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
     205           8 :                     || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
     206           4 :                     throw InvalidArgumentError(
     207           4 :                         "LinearOperator::applyAdjoint: incorrect input/output sizes for leaf");
     208             : 
     209           4 :                 _lhs->applyAdjoint(y, Aty);
     210           4 :                 return;
     211           4 :             }
     212          20 :         }
     213             : 
     214         178 :         if (_isComposite) {
     215         174 :             if (_mode == CompositeMode::ADD) {
     216             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     217          12 :                 if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
     218          12 :                     || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize()
     219          12 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
     220          12 :                     || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
     221           4 :                     throw InvalidArgumentError(
     222           4 :                         "LinearOperator::applyAdjoint: incorrect input/output sizes for add leaf");
     223             : 
     224           8 :                 _rhs->applyAdjoint(y, Aty);
     225           8 :                 Aty += _lhs->applyAdjoint(y);
     226           8 :                 return;
     227           8 :             }
     228             : 
     229         162 :             if (_mode == CompositeMode::MULT) {
     230             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     231          82 :                 if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
     232          82 :                     || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
     233           8 :                     throw InvalidArgumentError(
     234           8 :                         "LinearOperator::applyAdjoint: incorrect input/output sizes for mult leaf");
     235             : 
     236          74 :                 DataContainer<data_t> temp(_lhs->getDomainDescriptor());
     237          74 :                 _lhs->applyAdjoint(y, temp);
     238          74 :                 _rhs->applyAdjoint(temp, Aty);
     239          74 :                 return;
     240          74 :             }
     241             : 
     242          80 :             if (_mode == CompositeMode::SCALAR_MULT) {
     243             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     244          80 :                 if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize())
     245           4 :                     throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
     246           4 :                                                "sizes for scalar mult. leaf");
     247             :                 // sanity check the scalar in the optional
     248          76 :                 if (!_scalar.has_value())
     249           0 :                     throw InvalidArgumentError(
     250           0 :                         "LinearOperator::apply: no value found in the scalar optional");
     251             : 
     252          76 :                 _rhs->applyAdjoint(y, Aty);
     253          76 :                 Aty *= _scalar.value();
     254          76 :                 return;
     255          76 :             }
     256          80 :         }
     257             : 
     258           4 :         throw LogicError("LinearOperator: applyAdjoint called on ill-formed object");
     259           4 :     }
     260             : 
     261             :     template <typename data_t>
     262             :     LinearOperator<data_t>* LinearOperator<data_t>::cloneImpl() const
     263        1578 :     {
     264        1578 :         if (_isLeaf)
     265         534 :             return new LinearOperator<data_t>(*_lhs, _isAdjoint);
     266             : 
     267        1044 :         if (_isComposite) {
     268        1016 :             if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
     269         808 :                 return new LinearOperator<data_t>(*_lhs, *_rhs, _mode);
     270         808 :             }
     271             : 
     272         208 :             if (_mode == CompositeMode::SCALAR_MULT) {
     273         208 :                 return new LinearOperator<data_t>(*this);
     274         208 :             }
     275          28 :         }
     276             : 
     277          28 :         return new LinearOperator<data_t>(*_domainDescriptor, *_rangeDescriptor);
     278          28 :     }
     279             : 
     280             :     template <typename data_t>
     281             :     bool LinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const
     282         722 :     {
     283         722 :         if (typeid(other) != typeid(*this))
     284           8 :             return false;
     285             : 
     286         714 :         if (*_domainDescriptor != *other._domainDescriptor
     287         714 :             || *_rangeDescriptor != *other._rangeDescriptor)
     288          24 :             return false;
     289             : 
     290         690 :         if (_isLeaf ^ other._isLeaf || _isComposite ^ other._isComposite)
     291           0 :             return false;
     292             : 
     293         690 :         if (_isLeaf)
     294         148 :             return (_isAdjoint == other._isAdjoint) && (*_lhs == *other._lhs);
     295             : 
     296         542 :         if (_isComposite) {
     297          96 :             if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
     298          76 :                 return _mode == other._mode && (*_lhs == *other._lhs) && (*_rhs == *other._rhs);
     299          76 :             }
     300             : 
     301          20 :             if (_mode == CompositeMode::SCALAR_MULT) {
     302          20 :                 return (_isAdjoint == other._isAdjoint) && (*_rhs == *other._rhs);
     303          20 :             }
     304         446 :         }
     305             : 
     306         446 :         return true;
     307         446 :     }
     308             : 
     309             :     template <typename data_t>
     310             :     LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& op, bool isAdjoint)
     311             :         : _domainDescriptor{(isAdjoint) ? op.getRangeDescriptor().clone()
     312             :                                         : op.getDomainDescriptor().clone()},
     313             :           _rangeDescriptor{(isAdjoint) ? op.getDomainDescriptor().clone()
     314             :                                        : op.getRangeDescriptor().clone()},
     315             :           _lhs{op.clone()},
     316             :           _scalar{op._scalar},
     317             :           _isLeaf{true},
     318             :           _isAdjoint{isAdjoint}
     319         920 :     {
     320         920 :     }
     321             : 
     322             :     template <typename data_t>
     323             :     LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& lhs,
     324             :                                            const LinearOperator<data_t>& rhs, CompositeMode mode)
     325             :         : _domainDescriptor{mode == CompositeMode::MULT
     326             :                                 ? rhs.getDomainDescriptor().clone()
     327             :                                 : bestCommon(*lhs._domainDescriptor, *rhs._domainDescriptor)},
     328             :           _rangeDescriptor{mode == CompositeMode::MULT
     329             :                                ? lhs.getRangeDescriptor().clone()
     330             :                                : bestCommon(*lhs._rangeDescriptor, *rhs._rangeDescriptor)},
     331             :           _lhs{lhs.clone()},
     332             :           _rhs{rhs.clone()},
     333             :           _isComposite{true},
     334             :           _mode{mode}
     335        1248 :     {
     336             :         // sanity check the descriptors
     337        1248 :         switch (_mode) {
     338          72 :             case CompositeMode::ADD:
     339             :                 /// feasibility checked by bestCommon()
     340          72 :                 break;
     341             : 
     342        1176 :             case CompositeMode::MULT:
     343             :                 // for multiplication, domain of _lhs should match range of _rhs
     344        1176 :                 if (_lhs->getDomainDescriptor().getNumberOfCoefficients()
     345        1176 :                     != _rhs->getRangeDescriptor().getNumberOfCoefficients())
     346           0 :                     throw InvalidArgumentError(
     347           0 :                         "LinearOperator: composite mult domain/range mismatch");
     348        1176 :                 break;
     349             : 
     350        1176 :             default:
     351           0 :                 throw LogicError("LinearOperator: unknown composition mode");
     352        1248 :         }
     353        1248 :     }
     354             : 
     355             :     template <typename data_t>
     356             :     LinearOperator<data_t>::LinearOperator(data_t scalar, const LinearOperator<data_t>& rhs)
     357             :         : _domainDescriptor{rhs.getDomainDescriptor().clone()},
     358             :           _rangeDescriptor{rhs.getRangeDescriptor().clone()},
     359             :           _rhs{rhs.clone()},
     360             :           _scalar{scalar},
     361             :           _isComposite{true},
     362             :           _mode{CompositeMode::SCALAR_MULT}
     363          48 :     {
     364          48 :     }
     365             : 
     366             :     // ------------------------------------------
     367             :     // explicit template instantiation
     368             :     template class LinearOperator<float>;
     369             :     template class LinearOperator<complex<float>>;
     370             :     template class LinearOperator<double>;
     371             :     template class LinearOperator<complex<double>>;
     372             : 
     373             : } // namespace elsa

Generated by: LCOV version 1.14