LCOV - code coverage report
Current view: top level - core - LinearOperator.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 205 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 64 0.0 %

          Line data    Source code
       1             : #include "LinearOperator.h"
       2             : 
       3             : #include <stdexcept>
       4             : #include <typeinfo>
       5             : 
       6             : #include "DescriptorUtils.h"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11           0 :     LinearOperator<data_t>::LinearOperator(const DataDescriptor& domainDescriptor,
      12             :                                            const DataDescriptor& rangeDescriptor)
      13           0 :         : _domainDescriptor{domainDescriptor.clone()}, _rangeDescriptor{rangeDescriptor.clone()}
      14             :     {
      15           0 :     }
      16             : 
      17             :     template <typename data_t>
      18           0 :     LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& other)
      19             :         : Cloneable<LinearOperator<data_t>>(),
      20           0 :           _domainDescriptor{other._domainDescriptor->clone()},
      21           0 :           _rangeDescriptor{other._rangeDescriptor->clone()},
      22           0 :           _scalar{other._scalar},
      23           0 :           _isLeaf{other._isLeaf},
      24           0 :           _isAdjoint{other._isAdjoint},
      25           0 :           _isComposite{other._isComposite},
      26           0 :           _mode{other._mode}
      27             :     {
      28           0 :         if (_isLeaf)
      29           0 :             _lhs = other._lhs->clone();
      30             : 
      31           0 :         if (_isComposite) {
      32           0 :             if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
      33           0 :                 _lhs = other._lhs->clone();
      34           0 :                 _rhs = other._rhs->clone();
      35             :             }
      36             : 
      37           0 :             if (_mode == CompositeMode::SCALAR_MULT) {
      38           0 :                 _rhs = other._rhs->clone();
      39             :             }
      40             :         }
      41           0 :     }
      42             : 
      43             :     template <typename data_t>
      44           0 :     LinearOperator<data_t>& LinearOperator<data_t>::operator=(const LinearOperator<data_t>& other)
      45             :     {
      46           0 :         if (*this != other) {
      47           0 :             _domainDescriptor = other._domainDescriptor->clone();
      48           0 :             _rangeDescriptor = other._rangeDescriptor->clone();
      49           0 :             _scalar = other._scalar;
      50           0 :             _isLeaf = other._isLeaf;
      51           0 :             _isAdjoint = other._isAdjoint;
      52           0 :             _isComposite = other._isComposite;
      53           0 :             _mode = other._mode;
      54             : 
      55           0 :             if (_isLeaf)
      56           0 :                 _lhs = other._lhs->clone();
      57             : 
      58           0 :             if (_isComposite) {
      59           0 :                 if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
      60           0 :                     _lhs = other._lhs->clone();
      61           0 :                     _rhs = other._rhs->clone();
      62             :                 }
      63             : 
      64           0 :                 if (_mode == CompositeMode::SCALAR_MULT) {
      65           0 :                     _rhs = other._rhs->clone();
      66             :                 }
      67             :             }
      68             :         }
      69             : 
      70           0 :         return *this;
      71             :     }
      72             : 
      73             :     template <typename data_t>
      74           0 :     const DataDescriptor& LinearOperator<data_t>::getDomainDescriptor() const
      75             :     {
      76           0 :         return *_domainDescriptor;
      77             :     }
      78             : 
      79             :     template <typename data_t>
      80           0 :     const DataDescriptor& LinearOperator<data_t>::getRangeDescriptor() const
      81             :     {
      82           0 :         return *_rangeDescriptor;
      83             :     }
      84             : 
      85             :     template <typename data_t>
      86           0 :     DataContainer<data_t> LinearOperator<data_t>::apply(const DataContainer<data_t>& x) const
      87             :     {
      88           0 :         DataContainer<data_t> result(*_rangeDescriptor, x.getDataHandlerType());
      89           0 :         apply(x, result);
      90           0 :         return result;
      91           0 :     }
      92             : 
      93             :     template <typename data_t>
      94           0 :     void LinearOperator<data_t>::apply(const DataContainer<data_t>& x,
      95             :                                        DataContainer<data_t>& Ax) const
      96             :     {
      97           0 :         applyImpl(x, Ax);
      98           0 :     }
      99             : 
     100             :     template <typename data_t>
     101           0 :     void LinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x,
     102             :                                            DataContainer<data_t>& Ax) const
     103             :     {
     104           0 :         if (_isLeaf) {
     105           0 :             if (_isAdjoint) {
     106             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     107           0 :                 if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != x.getSize()
     108           0 :                     || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Ax.getSize())
     109           0 :                     throw InvalidArgumentError(
     110             :                         "LinearOperator::apply: incorrect input/output sizes for adjoint leaf");
     111             : 
     112           0 :                 _lhs->applyAdjoint(x, Ax);
     113           0 :                 return;
     114             :             } else {
     115             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     116           0 :                 if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
     117           0 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
     118           0 :                     throw InvalidArgumentError(
     119             :                         "LinearOperator::apply: incorrect input/output sizes for leaf");
     120             : 
     121           0 :                 _lhs->apply(x, Ax);
     122           0 :                 return;
     123             :             }
     124             :         }
     125             : 
     126           0 :         if (_isComposite) {
     127           0 :             if (_mode == CompositeMode::ADD) {
     128             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     129           0 :                 if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
     130           0 :                     || _rhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize()
     131           0 :                     || _lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
     132           0 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
     133           0 :                     throw InvalidArgumentError(
     134             :                         "LinearOperator::apply: incorrect input/output sizes for add leaf");
     135             : 
     136           0 :                 _rhs->apply(x, Ax);
     137           0 :                 Ax += _lhs->apply(x);
     138           0 :                 return;
     139             :             }
     140             : 
     141           0 :             if (_mode == CompositeMode::MULT) {
     142             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     143           0 :                 if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
     144           0 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
     145           0 :                     throw InvalidArgumentError(
     146             :                         "LinearOperator::apply: incorrect input/output sizes for mult leaf");
     147             : 
     148           0 :                 DataContainer<data_t> temp(_rhs->getRangeDescriptor(), x.getDataHandlerType());
     149           0 :                 _rhs->apply(x, temp);
     150           0 :                 _lhs->apply(temp, Ax);
     151           0 :                 return;
     152           0 :             }
     153             : 
     154           0 :             if (_mode == CompositeMode::SCALAR_MULT) {
     155             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     156           0 :                 if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize())
     157           0 :                     throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
     158             :                                                "sizes for scalar mult. leaf");
     159             :                 // sanity check the scalar in the optional
     160           0 :                 if (!_scalar.has_value())
     161           0 :                     throw InvalidArgumentError(
     162             :                         "LinearOperator::apply: no value found in the scalar optional");
     163             : 
     164           0 :                 _rhs->apply(x, Ax);
     165           0 :                 Ax *= _scalar.value();
     166           0 :                 return;
     167             :             }
     168             :         }
     169             : 
     170           0 :         throw LogicError("LinearOperator: apply called on ill-formed object");
     171             :     }
     172             : 
     173             :     template <typename data_t>
     174           0 :     DataContainer<data_t> LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y) const
     175             :     {
     176           0 :         DataContainer<data_t> result(*_domainDescriptor, y.getDataHandlerType());
     177           0 :         applyAdjoint(y, result);
     178           0 :         return result;
     179           0 :     }
     180             : 
     181             :     template <typename data_t>
     182           0 :     void LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y,
     183             :                                               DataContainer<data_t>& Aty) const
     184             :     {
     185           0 :         applyAdjointImpl(y, Aty);
     186           0 :     }
     187             : 
     188             :     template <typename data_t>
     189           0 :     void LinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
     190             :                                                   DataContainer<data_t>& Aty) const
     191             :     {
     192           0 :         if (_isLeaf) {
     193           0 :             if (_isAdjoint) {
     194             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     195           0 :                 if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != y.getSize()
     196           0 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Aty.getSize())
     197           0 :                     throw InvalidArgumentError("LinearOperator::applyAdjoint: incorrect "
     198             :                                                "input/output sizes for adjoint leaf");
     199             : 
     200           0 :                 _lhs->apply(y, Aty);
     201           0 :                 return;
     202             :             } else {
     203             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     204           0 :                 if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
     205           0 :                     || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
     206           0 :                     throw InvalidArgumentError(
     207             :                         "LinearOperator::applyAdjoint: incorrect input/output sizes for leaf");
     208             : 
     209           0 :                 _lhs->applyAdjoint(y, Aty);
     210           0 :                 return;
     211             :             }
     212             :         }
     213             : 
     214           0 :         if (_isComposite) {
     215           0 :             if (_mode == CompositeMode::ADD) {
     216             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     217           0 :                 if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
     218           0 :                     || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize()
     219           0 :                     || _lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
     220           0 :                     || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
     221           0 :                     throw InvalidArgumentError(
     222             :                         "LinearOperator::applyAdjoint: incorrect input/output sizes for add leaf");
     223             : 
     224           0 :                 _rhs->applyAdjoint(y, Aty);
     225           0 :                 Aty += _lhs->applyAdjoint(y);
     226           0 :                 return;
     227             :             }
     228             : 
     229           0 :             if (_mode == CompositeMode::MULT) {
     230             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     231           0 :                 if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
     232           0 :                     || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
     233           0 :                     throw InvalidArgumentError(
     234             :                         "LinearOperator::applyAdjoint: incorrect input/output sizes for mult leaf");
     235             : 
     236           0 :                 DataContainer<data_t> temp(_lhs->getDomainDescriptor(), y.getDataHandlerType());
     237           0 :                 _lhs->applyAdjoint(y, temp);
     238           0 :                 _rhs->applyAdjoint(temp, Aty);
     239           0 :                 return;
     240           0 :             }
     241             : 
     242           0 :             if (_mode == CompositeMode::SCALAR_MULT) {
     243             :                 // sanity check the arguments for the intended evaluation tree leaf operation
     244           0 :                 if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize())
     245           0 :                     throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
     246             :                                                "sizes for scalar mult. leaf");
     247             :                 // sanity check the scalar in the optional
     248           0 :                 if (!_scalar.has_value())
     249           0 :                     throw InvalidArgumentError(
     250             :                         "LinearOperator::apply: no value found in the scalar optional");
     251             : 
     252           0 :                 _rhs->applyAdjoint(y, Aty);
     253           0 :                 Aty *= _scalar.value();
     254           0 :                 return;
     255             :             }
     256             :         }
     257             : 
     258           0 :         throw LogicError("LinearOperator: applyAdjoint called on ill-formed object");
     259             :     }
     260             : 
     261             :     template <typename data_t>
     262           0 :     LinearOperator<data_t>* LinearOperator<data_t>::cloneImpl() const
     263             :     {
     264           0 :         if (_isLeaf)
     265           0 :             return new LinearOperator<data_t>(*_lhs, _isAdjoint);
     266             : 
     267           0 :         if (_isComposite) {
     268           0 :             if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
     269           0 :                 return new LinearOperator<data_t>(*_lhs, *_rhs, _mode);
     270             :             }
     271             : 
     272           0 :             if (_mode == CompositeMode::SCALAR_MULT) {
     273           0 :                 return new LinearOperator<data_t>(*this);
     274             :             }
     275             :         }
     276             : 
     277           0 :         return new LinearOperator<data_t>(*_domainDescriptor, *_rangeDescriptor);
     278             :     }
     279             : 
     280             :     template <typename data_t>
     281           0 :     bool LinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const
     282             :     {
     283           0 :         if (typeid(other) != typeid(*this))
     284           0 :             return false;
     285             : 
     286           0 :         if (*_domainDescriptor != *other._domainDescriptor
     287           0 :             || *_rangeDescriptor != *other._rangeDescriptor)
     288           0 :             return false;
     289             : 
     290           0 :         if (_isLeaf ^ other._isLeaf || _isComposite ^ other._isComposite)
     291           0 :             return false;
     292             : 
     293           0 :         if (_isLeaf)
     294           0 :             return (_isAdjoint == other._isAdjoint) && (*_lhs == *other._lhs);
     295             : 
     296           0 :         if (_isComposite) {
     297           0 :             if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
     298           0 :                 return _mode == other._mode && (*_lhs == *other._lhs) && (*_rhs == *other._rhs);
     299             :             }
     300             : 
     301           0 :             if (_mode == CompositeMode::SCALAR_MULT) {
     302           0 :                 return (_isAdjoint == other._isAdjoint) && (*_rhs == *other._rhs);
     303             :             }
     304             :         }
     305             : 
     306           0 :         return true;
     307             :     }
     308             : 
     309             :     template <typename data_t>
     310           0 :     LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& op, bool isAdjoint)
     311           0 :         : _domainDescriptor{(isAdjoint) ? op.getRangeDescriptor().clone()
     312           0 :                                         : op.getDomainDescriptor().clone()},
     313           0 :           _rangeDescriptor{(isAdjoint) ? op.getDomainDescriptor().clone()
     314           0 :                                        : op.getRangeDescriptor().clone()},
     315           0 :           _lhs{op.clone()},
     316           0 :           _scalar{op._scalar},
     317           0 :           _isLeaf{true},
     318           0 :           _isAdjoint{isAdjoint}
     319             :     {
     320           0 :     }
     321             : 
     322             :     template <typename data_t>
     323           0 :     LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& lhs,
     324             :                                            const LinearOperator<data_t>& rhs, CompositeMode mode)
     325           0 :         : _domainDescriptor{mode == CompositeMode::MULT
     326           0 :                                 ? rhs.getDomainDescriptor().clone()
     327           0 :                                 : bestCommon(*lhs._domainDescriptor, *rhs._domainDescriptor)},
     328           0 :           _rangeDescriptor{mode == CompositeMode::MULT
     329           0 :                                ? lhs.getRangeDescriptor().clone()
     330           0 :                                : bestCommon(*lhs._rangeDescriptor, *rhs._rangeDescriptor)},
     331           0 :           _lhs{lhs.clone()},
     332           0 :           _rhs{rhs.clone()},
     333           0 :           _isComposite{true},
     334           0 :           _mode{mode}
     335             :     {
     336             :         // sanity check the descriptors
     337           0 :         switch (_mode) {
     338           0 :             case CompositeMode::ADD:
     339             :                 /// feasibility checked by bestCommon()
     340           0 :                 break;
     341             : 
     342           0 :             case CompositeMode::MULT:
     343             :                 // for multiplication, domain of _lhs should match range of _rhs
     344           0 :                 if (_lhs->getDomainDescriptor().getNumberOfCoefficients()
     345           0 :                     != _rhs->getRangeDescriptor().getNumberOfCoefficients())
     346           0 :                     throw InvalidArgumentError(
     347             :                         "LinearOperator: composite mult domain/range mismatch");
     348           0 :                 break;
     349             : 
     350           0 :             default:
     351           0 :                 throw LogicError("LinearOperator: unknown composition mode");
     352             :         }
     353           0 :     }
     354             : 
     355             :     template <typename data_t>
     356           0 :     LinearOperator<data_t>::LinearOperator(data_t scalar, const LinearOperator<data_t>& rhs)
     357           0 :         : _domainDescriptor{rhs.getDomainDescriptor().clone()},
     358           0 :           _rangeDescriptor{rhs.getRangeDescriptor().clone()},
     359           0 :           _rhs{rhs.clone()},
     360           0 :           _scalar{scalar},
     361           0 :           _isComposite{true},
     362           0 :           _mode{CompositeMode::SCALAR_MULT}
     363             :     {
     364           0 :     }
     365             : 
     366             :     // ------------------------------------------
     367             :     // explicit template instantiation
     368             :     template class LinearOperator<float>;
     369             :     template class LinearOperator<complex<float>>;
     370             :     template class LinearOperator<double>;
     371             :     template class LinearOperator<complex<double>>;
     372             : 
     373             : } // namespace elsa

Generated by: LCOV version 1.14