LCOV - code coverage report
Current view: top level - elsa/operators - BlockLinearOperator.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 183 191 95.8 %
Date: 2022-08-25 03:05:39 Functions: 28 56 50.0 %

          Line data    Source code
       1             : #include "BlockLinearOperator.h"
       2             : #include "PartitionDescriptor.h"
       3             : #include "RandomBlocksDescriptor.h"
       4             : #include "DescriptorUtils.h"
       5             : #include "TypeCasts.hpp"
       6             : 
       7             : #include <algorithm>
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     BlockLinearOperator<data_t>::BlockLinearOperator(const OperatorList& ops, BlockType blockType)
      13             :         : LinearOperator<data_t>{*determineDomainDescriptor(ops, blockType),
      14             :                                  *determineRangeDescriptor(ops, blockType)},
      15             :           _operatorList(0),
      16             :           _blockType{blockType}
      17          68 :     {
      18          68 :         for (const auto& op : ops)
      19         150 :             _operatorList.push_back(op->clone());
      20          68 :     }
      21             : 
      22             :     template <typename data_t>
      23             :     BlockLinearOperator<data_t>::BlockLinearOperator(const DataDescriptor& domainDescriptor,
      24             :                                                      const DataDescriptor& rangeDescriptor,
      25             :                                                      const OperatorList& ops, BlockType blockType)
      26             :         : LinearOperator<data_t>{domainDescriptor, rangeDescriptor},
      27             :           _operatorList(0),
      28             :           _blockType{blockType}
      29          44 :     {
      30          44 :         if (_blockType == COL) {
      31          22 :             const auto* trueDomainDesc = downcast_safe<BlockDescriptor>(&domainDescriptor);
      32             : 
      33          22 :             if (trueDomainDesc == nullptr)
      34           4 :                 throw InvalidArgumentError(
      35           4 :                     "BlockLinearOperator: domain descriptor is not a BlockDescriptor");
      36             : 
      37          18 :             if (trueDomainDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size()))
      38           2 :                 throw InvalidArgumentError("BlockLinearOperator: domain descriptor number of "
      39           2 :                                            "blocks does not match operator list size");
      40             : 
      41          32 :             for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) {
      42          24 :                 const auto& op = ops[static_cast<std::size_t>(i)];
      43          24 :                 if (op->getRangeDescriptor().getNumberOfCoefficients()
      44          24 :                     != _rangeDescriptor->getNumberOfCoefficients())
      45           4 :                     throw InvalidArgumentError(
      46           4 :                         "BlockLinearOperator: the range descriptor of a COL BlockLinearOperator "
      47           4 :                         "must have the same size as the range of every operator in the list");
      48             : 
      49          20 :                 if (op->getDomainDescriptor().getNumberOfCoefficients()
      50          20 :                     != trueDomainDesc->getDescriptorOfBlock(i).getNumberOfCoefficients())
      51           4 :                     throw InvalidArgumentError(
      52           4 :                         "BlockLinearOperator: block of incorrect size in domain descriptor");
      53          20 :             }
      54          16 :         }
      55             : 
      56          44 :         if (_blockType == ROW) {
      57          22 :             const auto* trueRangeDesc = downcast_safe<BlockDescriptor>(&rangeDescriptor);
      58             : 
      59          22 :             if (trueRangeDesc == nullptr)
      60           4 :                 throw InvalidArgumentError(
      61           4 :                     "BlockLinearOperator: range descriptor is not a BlockDescriptor");
      62             : 
      63          18 :             if (trueRangeDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size()))
      64           2 :                 throw InvalidArgumentError("BlockLinearOperator: range descriptor number of "
      65           2 :                                            "blocks does not match operator list size");
      66             : 
      67          32 :             for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) {
      68          24 :                 const auto& op = ops[static_cast<std::size_t>(i)];
      69          24 :                 if (op->getDomainDescriptor().getNumberOfCoefficients()
      70          24 :                     != _domainDescriptor->getNumberOfCoefficients())
      71           4 :                     throw InvalidArgumentError(
      72           4 :                         "BlockLinearOperator: the domain descriptor of a ROW BlockLinearOperator "
      73           4 :                         "must have the same size as the domain of every operator in the list");
      74             : 
      75          20 :                 if (op->getRangeDescriptor().getNumberOfCoefficients()
      76          20 :                     != trueRangeDesc->getDescriptorOfBlock(i).getNumberOfCoefficients())
      77           4 :                     throw InvalidArgumentError(
      78           4 :                         "BlockLinearOperator: block of incorrect size in range descriptor");
      79          20 :             }
      80          16 :         }
      81             : 
      82          30 :         for (const auto& op : ops)
      83          32 :             _operatorList.push_back(op->clone());
      84          16 :     }
      85             : 
      86             :     template <typename data_t>
      87             :     const LinearOperator<data_t>& BlockLinearOperator<data_t>::getIthOperator(index_t index) const
      88          40 :     {
      89          40 :         return *_operatorList.at(static_cast<std::size_t>(index));
      90          40 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     index_t BlockLinearOperator<data_t>::numberOfOps() const
      94          20 :     {
      95          20 :         return static_cast<index_t>(_operatorList.size());
      96          20 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     BlockLinearOperator<data_t>::BlockLinearOperator(const BlockLinearOperator<data_t>& other)
     100             :         : LinearOperator<data_t>{*other._domainDescriptor, *other._rangeDescriptor},
     101             :           _operatorList(0),
     102             :           _blockType{other._blockType}
     103          24 :     {
     104          24 :         for (const auto& op : other._operatorList)
     105          54 :             _operatorList.push_back(op->clone());
     106          24 :     }
     107             : 
     108             :     template <typename data_t>
     109             :     void BlockLinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x,
     110             :                                                 DataContainer<data_t>& Ax) const
     111          12 :     {
     112          12 :         switch (_blockType) {
     113           2 :             case BlockType::COL: {
     114           2 :                 Ax = 0;
     115           2 :                 auto tmpAx = DataContainer<data_t>(Ax.getDataDescriptor());
     116             : 
     117           2 :                 auto xView = x.viewAs(*_domainDescriptor);
     118           2 :                 index_t i = 0;
     119           6 :                 for (const auto& op : _operatorList) {
     120           6 :                     op->apply(xView.getBlock(i), tmpAx);
     121           6 :                     Ax += tmpAx;
     122           6 :                     ++i;
     123           6 :                 }
     124             : 
     125           2 :                 break;
     126           0 :             }
     127          10 :             case BlockType::ROW: {
     128          10 :                 index_t i = 0;
     129             : 
     130          10 :                 auto AxView = Ax.viewAs(*_rangeDescriptor);
     131          26 :                 for (const auto& op : _operatorList) {
     132          26 :                     auto blk = AxView.getBlock(i);
     133          26 :                     op->apply(x, blk);
     134          26 :                     ++i;
     135          26 :                 }
     136             : 
     137          10 :                 break;
     138           0 :             }
     139          12 :         }
     140          12 :     }
     141             : 
     142             :     template <typename data_t>
     143             :     void BlockLinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
     144             :                                                        DataContainer<data_t>& Aty) const
     145          10 :     {
     146          10 :         switch (_blockType) {
     147           6 :             case BlockType::COL: {
     148           6 :                 index_t i = 0;
     149             : 
     150           6 :                 auto AtyView = Aty.viewAs(*_domainDescriptor);
     151          18 :                 for (const auto& op : _operatorList) {
     152          18 :                     auto&& blk = AtyView.getBlock(i);
     153          18 :                     op->applyAdjoint(y, blk);
     154          18 :                     ++i;
     155          18 :                 }
     156             : 
     157           6 :                 break;
     158           0 :             }
     159           4 :             case BlockType::ROW: {
     160           4 :                 Aty = 0;
     161           4 :                 auto tmpAty = DataContainer<data_t>(Aty.getDataDescriptor());
     162             : 
     163           4 :                 auto yView = y.viewAs(*_rangeDescriptor);
     164           4 :                 index_t i = 0;
     165          10 :                 for (const auto& op : _operatorList) {
     166          10 :                     op->applyAdjoint(yView.getBlock(i), tmpAty);
     167          10 :                     Aty += tmpAty;
     168          10 :                     ++i;
     169          10 :                 }
     170             : 
     171           4 :                 break;
     172           0 :             }
     173          10 :         }
     174          10 :     }
     175             : 
     176             :     template <typename data_t>
     177             :     BlockLinearOperator<data_t>* BlockLinearOperator<data_t>::cloneImpl() const
     178          24 :     {
     179          24 :         return new BlockLinearOperator<data_t>(*this);
     180          24 :     }
     181             : 
     182             :     template <typename data_t>
     183             :     bool BlockLinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const
     184          16 :     {
     185          16 :         if (!LinearOperator<data_t>::isEqual(other))
     186           6 :             return false;
     187             : 
     188             :         // static_cast as type checked in base comparison
     189          10 :         auto otherBlockOp = static_cast<const BlockLinearOperator<data_t>*>(&other);
     190             : 
     191          34 :         for (std::size_t i = 0; i < _operatorList.size(); i++)
     192          28 :             if (*_operatorList[i] != *otherBlockOp->_operatorList[i])
     193           4 :                 return false;
     194             : 
     195          10 :         return true;
     196          10 :     }
     197             : 
     198             :     template <typename data_t>
     199             :     std::unique_ptr<DataDescriptor>
     200             :         BlockLinearOperator<data_t>::determineDomainDescriptor(const OperatorList& operatorList,
     201             :                                                                BlockType blockType)
     202          68 :     {
     203          68 :         std::vector<const DataDescriptor*> vec(operatorList.size());
     204         226 :         for (std::size_t i = 0; i < vec.size(); i++)
     205         158 :             vec[i] = &operatorList[i]->getDomainDescriptor();
     206             : 
     207          68 :         switch (blockType) {
     208          40 :             case BlockType::ROW:
     209          40 :                 return bestCommon(vec);
     210             : 
     211          28 :             case BlockType::COL:
     212          28 :                 return bestBlockDescriptor(vec);
     213             : 
     214           0 :             default:
     215           0 :                 throw InvalidArgumentError("BlockLinearOpearator: unsupported block type");
     216          68 :         }
     217          68 :     }
     218             : 
     219             :     template <typename data_t>
     220             :     std::unique_ptr<DataDescriptor>
     221             :         BlockLinearOperator<data_t>::determineRangeDescriptor(const OperatorList& operatorList,
     222             :                                                               BlockType blockType)
     223          62 :     {
     224          62 :         std::vector<const DataDescriptor*> vec(operatorList.size());
     225         216 :         for (std::size_t i = 0; i < vec.size(); i++)
     226         154 :             vec[i] = &operatorList[i]->getRangeDescriptor();
     227             : 
     228          62 :         switch (blockType) {
     229          36 :             case BlockType::ROW:
     230          36 :                 return bestBlockDescriptor(vec);
     231             : 
     232          26 :             case BlockType::COL:
     233          26 :                 return bestCommon(vec);
     234             : 
     235           0 :             default:
     236           0 :                 throw InvalidArgumentError("BlockLinearOpearator: unsupported block type");
     237          62 :         }
     238          62 :     }
     239             : 
     240             :     template <typename data_t>
     241             :     std::unique_ptr<BlockDescriptor> BlockLinearOperator<data_t>::bestBlockDescriptor(
     242             :         const std::vector<const DataDescriptor*>& descList)
     243          64 :     {
     244          64 :         auto numBlocks = descList.size();
     245          64 :         if (numBlocks == 0)
     246           2 :             throw InvalidArgumentError("BlockLinearOperator: operator list cannot be empty");
     247             : 
     248          62 :         const auto& firstDesc = *descList[0];
     249          62 :         auto numDims = firstDesc.getNumberOfDimensions();
     250          62 :         auto coeffs = firstDesc.getNumberOfCoefficientsPerDimension();
     251             : 
     252          62 :         bool allNumDimSame = true;
     253          62 :         bool allButLastDimSame = true;
     254          62 :         IndexVector_t lastDimSplit(numBlocks);
     255         154 :         for (std::size_t i = 1; i < numBlocks && allButLastDimSame; i++) {
     256          92 :             lastDimSplit[static_cast<index_t>(i - 1)] =
     257          92 :                 descList[i - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1];
     258             : 
     259          92 :             allNumDimSame = allNumDimSame && descList[i]->getNumberOfDimensions() == numDims;
     260          92 :             allButLastDimSame =
     261          92 :                 allNumDimSame && allButLastDimSame
     262          92 :                 && descList[i]->getNumberOfCoefficientsPerDimension().head(numDims - 1)
     263          84 :                        == coeffs.head(numDims - 1);
     264          92 :         }
     265             : 
     266          62 :         if (allButLastDimSame) {
     267          54 :             lastDimSplit[static_cast<index_t>(numBlocks) - 1] =
     268          54 :                 descList[numBlocks - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1];
     269             : 
     270          54 :             auto spacing = firstDesc.getSpacingPerDimension();
     271          54 :             bool allSameSpacing =
     272         138 :                 all_of(descList.begin(), descList.end(), [&spacing](const DataDescriptor* d) {
     273         138 :                     return d->getSpacingPerDimension() == spacing;
     274         138 :                 });
     275             : 
     276          54 :             coeffs[numDims - 1] = lastDimSplit.sum();
     277          54 :             if (allSameSpacing) {
     278          46 :                 VolumeDescriptor tmp(coeffs, spacing);
     279          46 :                 return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit);
     280          46 :             } else {
     281           8 :                 VolumeDescriptor tmp(coeffs);
     282           8 :                 return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit);
     283           8 :             }
     284           8 :         }
     285             : 
     286           8 :         std::vector<std::unique_ptr<DataDescriptor>> tmp(numBlocks);
     287           8 :         auto it = descList.begin();
     288          16 :         std::generate(tmp.begin(), tmp.end(), [&it]() { return (*it++)->clone(); });
     289             : 
     290           8 :         return std::make_unique<RandomBlocksDescriptor>(std::move(tmp));
     291           8 :     }
     292             : 
    // ----------------------------------------------
    // explicit template instantiation
    // The member definitions above live in this .cpp file rather than the header, so every
    // specialization other translation units link against must be instantiated here:
    // float, double, complex<float> and complex<double>.
    template class BlockLinearOperator<float>;
    template class BlockLinearOperator<double>;
    template class BlockLinearOperator<complex<float>>;
    template class BlockLinearOperator<complex<double>>;
     299             : } // namespace elsa

Generated by: LCOV version 1.14