LCOV - code coverage report
Current view: top level - elsa/operators - BlockLinearOperator.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 180 188 95.7 %
Date: 2024-05-16 04:22:26 Functions: 28 56 50.0 %

          Line data    Source code
       1             : #include "BlockLinearOperator.h"
       2             : #include "PartitionDescriptor.h"
       3             : #include "RandomBlocksDescriptor.h"
       4             : #include "DescriptorUtils.h"
       5             : #include "TypeCasts.hpp"
       6             : 
       7             : #include <algorithm>
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     BlockLinearOperator<data_t>::BlockLinearOperator(const OperatorList& ops, BlockType blockType)
      13             :         : LinearOperator<data_t>{*determineDomainDescriptor(ops, blockType),
      14             :                                  *determineRangeDescriptor(ops, blockType)},
      15             :           _operatorList(0),
      16             :           _blockType{blockType}
      17          78 :     {
      18          78 :         for (const auto& op : ops)
      19         170 :             _operatorList.push_back(op->clone());
      20          78 :     }
      21             : 
      22             :     template <typename data_t>
      23             :     BlockLinearOperator<data_t>::BlockLinearOperator(const DataDescriptor& domainDescriptor,
      24             :                                                      const DataDescriptor& rangeDescriptor,
      25             :                                                      const OperatorList& ops, BlockType blockType)
      26             :         : LinearOperator<data_t>{domainDescriptor, rangeDescriptor},
      27             :           _operatorList(0),
      28             :           _blockType{blockType}
      29          80 :     {
      30          80 :         if (_blockType == BlockType::COL) {
      31          34 :             const auto* trueDomainDesc = downcast_safe<BlockDescriptor>(&domainDescriptor);
      32             : 
      33          34 :             if (trueDomainDesc == nullptr)
      34           4 :                 throw InvalidArgumentError(
      35           4 :                     "BlockLinearOperator: domain descriptor is not a BlockDescriptor");
      36             : 
      37          30 :             if (trueDomainDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size()))
      38           2 :                 throw InvalidArgumentError("BlockLinearOperator: domain descriptor number of "
      39           2 :                                            "blocks does not match operator list size");
      40             : 
      41         264 :             for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) {
      42         244 :                 const auto& op = ops[static_cast<std::size_t>(i)];
      43         244 :                 if (op->getRangeDescriptor().getNumberOfCoefficients()
      44         244 :                     != _rangeDescriptor->getNumberOfCoefficients())
      45           4 :                     throw InvalidArgumentError(
      46           4 :                         "BlockLinearOperator: the range descriptor of a COL BlockLinearOperator "
      47           4 :                         "must have the same size as the range of every operator in the list");
      48             : 
      49         240 :                 if (op->getDomainDescriptor().getNumberOfCoefficients()
      50         240 :                     != trueDomainDesc->getDescriptorOfBlock(i).getNumberOfCoefficients())
      51           4 :                     throw InvalidArgumentError(
      52           4 :                         "BlockLinearOperator: block of incorrect size in domain descriptor");
      53         240 :             }
      54          28 :         }
      55             : 
      56          46 :         else if (_blockType == BlockType::ROW) {
      57          46 :             const auto* trueRangeDesc = downcast_safe<BlockDescriptor>(&rangeDescriptor);
      58             : 
      59          46 :             if (trueRangeDesc == nullptr)
      60           4 :                 throw InvalidArgumentError(
      61           4 :                     "BlockLinearOperator: range descriptor is not a BlockDescriptor");
      62             : 
      63          42 :             if (trueRangeDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size()))
      64           2 :                 throw InvalidArgumentError("BlockLinearOperator: range descriptor number of "
      65           2 :                                            "blocks does not match operator list size");
      66             : 
      67         104 :             for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) {
      68          72 :                 const auto& op = ops[static_cast<std::size_t>(i)];
      69          72 :                 if (op->getDomainDescriptor().getNumberOfCoefficients()
      70          72 :                     != _domainDescriptor->getNumberOfCoefficients())
      71           4 :                     throw InvalidArgumentError(
      72           4 :                         "BlockLinearOperator: the domain descriptor of a ROW BlockLinearOperator "
      73           4 :                         "must have the same size as the domain of every operator in the list");
      74             : 
      75          68 :                 if (op->getRangeDescriptor().getNumberOfCoefficients()
      76          68 :                     != trueRangeDesc->getDescriptorOfBlock(i).getNumberOfCoefficients())
      77           4 :                     throw InvalidArgumentError(
      78           4 :                         "BlockLinearOperator: block of incorrect size in range descriptor");
      79          68 :             }
      80          40 :         }
      81             : 
      82          80 :         for (const auto& op : ops)
      83         300 :             _operatorList.push_back(op->clone());
      84          52 :     }
      85             : 
      86             :     template <typename data_t>
      87             :     const LinearOperator<data_t>& BlockLinearOperator<data_t>::getIthOperator(index_t index) const
      88          40 :     {
      89          40 :         return *_operatorList.at(static_cast<std::size_t>(index));
      90          40 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     index_t BlockLinearOperator<data_t>::numberOfOps() const
      94          20 :     {
      95          20 :         return static_cast<index_t>(_operatorList.size());
      96          20 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     BlockLinearOperator<data_t>::BlockLinearOperator(const BlockLinearOperator<data_t>& other)
     100             :         : LinearOperator<data_t>{*other._domainDescriptor, *other._rangeDescriptor},
     101             :           _operatorList(0),
     102             :           _blockType{other._blockType}
     103         226 :     {
     104         226 :         for (const auto& op : other._operatorList)
     105         458 :             _operatorList.push_back(op->clone());
     106         226 :     }
     107             : 
     108             :     template <typename data_t>
     109             :     void BlockLinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x,
     110             :                                                 DataContainer<data_t>& Ax) const
     111         128 :     {
     112         128 :         switch (_blockType) {
     113          18 :             case BlockType::COL: {
     114             : 
     115          18 :                 auto acc = zeroslike(Ax);
     116          18 :                 auto img = emptylike(Ax);
     117          18 :                 auto xView = x.viewAs(*_domainDescriptor);
     118             : 
     119         200 :                 for (size_t i = 0; i < _operatorList.size(); ++i) {
     120         182 :                     _operatorList[i]->apply(xView.getBlock(i), img);
     121         182 :                     acc += img;
     122         182 :                 }
     123          18 :                 Ax = acc;
     124          18 :                 break;
     125           0 :             }
     126         110 :             case BlockType::ROW: {
     127             : 
     128         110 :                 auto AxView = Ax.viewAs(*_rangeDescriptor);
     129             : 
     130         336 :                 for (size_t i = 0; i < _operatorList.size(); ++i) {
     131         226 :                     auto block = AxView.getBlock(i);
     132         226 :                     _operatorList[i]->apply(x, block);
     133         226 :                 }
     134             : 
     135         110 :                 break;
     136           0 :             }
     137         128 :         }
     138         128 :     }
     139             : 
     140             :     template <typename data_t>
     141             :     void BlockLinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
     142             :                                                        DataContainer<data_t>& Aty) const
     143         144 :     {
     144         144 :         switch (_blockType) {
     145          10 :             case BlockType::COL: {
     146          10 :                 index_t i = 0;
     147             : 
     148          10 :                 auto AtyView = Aty.viewAs(*_domainDescriptor);
     149          78 :                 for (const auto& op : _operatorList) {
     150          78 :                     auto&& blk = AtyView.getBlock(i);
     151          78 :                     op->applyAdjoint(y, blk);
     152          78 :                     ++i;
     153          78 :                 }
     154             : 
     155          10 :                 break;
     156           0 :             }
     157         134 :             case BlockType::ROW: {
     158         134 :                 Aty = 0;
     159         134 :                 auto tmpAty = DataContainer<data_t>(Aty.getDataDescriptor());
     160             : 
     161         134 :                 auto yView = y.viewAs(*_rangeDescriptor);
     162         134 :                 index_t i = 0;
     163         270 :                 for (const auto& op : _operatorList) {
     164         270 :                     op->applyAdjoint(materialize(yView.getBlock(i)), tmpAty);
     165         270 :                     Aty += tmpAty;
     166         270 :                     ++i;
     167         270 :                 }
     168             : 
     169         134 :                 break;
     170           0 :             }
     171         144 :         }
     172         144 :     }
     173             : 
     174             :     template <typename data_t>
     175             :     BlockLinearOperator<data_t>* BlockLinearOperator<data_t>::cloneImpl() const
     176         226 :     {
     177         226 :         return new BlockLinearOperator<data_t>(*this);
     178         226 :     }
     179             : 
     180             :     template <typename data_t>
     181             :     bool BlockLinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const
     182          18 :     {
     183          18 :         if (!LinearOperator<data_t>::isEqual(other))
     184           6 :             return false;
     185             : 
     186             :         // static_cast as type checked in base comparison
     187          12 :         auto otherBlockOp = static_cast<const BlockLinearOperator<data_t>*>(&other);
     188             : 
     189          40 :         for (std::size_t i = 0; i < _operatorList.size(); i++)
     190          32 :             if (*_operatorList[i] != *otherBlockOp->_operatorList[i])
     191           4 :                 return false;
     192             : 
     193          12 :         return true;
     194          12 :     }
     195             : 
     196             :     template <typename data_t>
     197             :     std::unique_ptr<DataDescriptor>
     198             :         BlockLinearOperator<data_t>::determineDomainDescriptor(const OperatorList& operatorList,
     199             :                                                                BlockType blockType)
     200          78 :     {
     201          78 :         std::vector<const DataDescriptor*> vec(operatorList.size());
     202         256 :         for (std::size_t i = 0; i < vec.size(); i++)
     203         178 :             vec[i] = &operatorList[i]->getDomainDescriptor();
     204             : 
     205          78 :         switch (blockType) {
     206          42 :             case BlockType::ROW:
     207          42 :                 return bestCommon(vec);
     208             : 
     209          36 :             case BlockType::COL:
     210          36 :                 return bestBlockDescriptor(vec);
     211             : 
     212           0 :             default:
     213           0 :                 throw InvalidArgumentError("BlockLinearOpearator: unsupported block type");
     214          78 :         }
     215          78 :     }
     216             : 
     217             :     template <typename data_t>
     218             :     std::unique_ptr<DataDescriptor>
     219             :         BlockLinearOperator<data_t>::determineRangeDescriptor(const OperatorList& operatorList,
     220             :                                                               BlockType blockType)
     221          72 :     {
     222          72 :         std::vector<const DataDescriptor*> vec(operatorList.size());
     223         246 :         for (std::size_t i = 0; i < vec.size(); i++)
     224         174 :             vec[i] = &operatorList[i]->getRangeDescriptor();
     225             : 
     226          72 :         switch (blockType) {
     227          38 :             case BlockType::ROW:
     228          38 :                 return bestBlockDescriptor(vec);
     229             : 
     230          34 :             case BlockType::COL:
     231          34 :                 return bestCommon(vec);
     232             : 
     233           0 :             default:
     234           0 :                 throw InvalidArgumentError("BlockLinearOpearator: unsupported block type");
     235          72 :         }
     236          72 :     }
     237             : 
     238             :     template <typename data_t>
     239             :     std::unique_ptr<BlockDescriptor> BlockLinearOperator<data_t>::bestBlockDescriptor(
     240             :         const std::vector<const DataDescriptor*>& descList)
     241          74 :     {
     242          74 :         auto numBlocks = descList.size();
     243          74 :         if (numBlocks == 0)
     244           2 :             throw InvalidArgumentError("BlockLinearOperator: operator list cannot be empty");
     245             : 
     246          72 :         const auto& firstDesc = *descList[0];
     247          72 :         auto numDims = firstDesc.getNumberOfDimensions();
     248          72 :         auto coeffs = firstDesc.getNumberOfCoefficientsPerDimension();
     249             : 
     250          72 :         bool allNumDimSame = true;
     251          72 :         bool allButLastDimSame = true;
     252          72 :         IndexVector_t lastDimSplit(numBlocks);
     253         174 :         for (std::size_t i = 1; i < numBlocks && allButLastDimSame; i++) {
     254         102 :             lastDimSplit[static_cast<index_t>(i - 1)] =
     255         102 :                 descList[i - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1];
     256             : 
     257         102 :             allNumDimSame = allNumDimSame && descList[i]->getNumberOfDimensions() == numDims;
     258         102 :             allButLastDimSame =
     259         102 :                 allNumDimSame && allButLastDimSame
     260         102 :                 && descList[i]->getNumberOfCoefficientsPerDimension().head(numDims - 1)
     261          94 :                        == coeffs.head(numDims - 1);
     262         102 :         }
     263             : 
     264          72 :         if (allButLastDimSame) {
     265          64 :             lastDimSplit[static_cast<index_t>(numBlocks) - 1] =
     266          64 :                 descList[numBlocks - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1];
     267             : 
     268          64 :             auto spacing = firstDesc.getSpacingPerDimension();
     269          64 :             bool allSameSpacing =
     270         158 :                 all_of(descList.begin(), descList.end(), [&spacing](const DataDescriptor* d) {
     271         158 :                     return d->getSpacingPerDimension() == spacing;
     272         158 :                 });
     273             : 
     274          64 :             coeffs[numDims - 1] = lastDimSplit.sum();
     275          64 :             if (allSameSpacing) {
     276          56 :                 VolumeDescriptor tmp(coeffs, spacing);
     277          56 :                 return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit);
     278          56 :             } else {
     279           8 :                 VolumeDescriptor tmp(coeffs);
     280           8 :                 return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit);
     281           8 :             }
     282           8 :         }
     283             : 
     284           8 :         std::vector<std::unique_ptr<DataDescriptor>> tmp(numBlocks);
     285           8 :         auto it = descList.begin();
     286          16 :         std::generate(tmp.begin(), tmp.end(), [&it]() { return (*it++)->clone(); });
     287             : 
     288           8 :         return std::make_unique<RandomBlocksDescriptor>(std::move(tmp));
     289           8 :     }
     290             : 
     291             :     // ----------------------------------------------
     292             :     // explicit template instantiation
     293             :     template class BlockLinearOperator<float>;
     294             :     template class BlockLinearOperator<double>;
     295             :     template class BlockLinearOperator<complex<float>>;
     296             :     template class BlockLinearOperator<complex<double>>;
     297             : } // namespace elsa

Generated by: LCOV version 1.14