LCOV - code coverage report
Current view: top level - operators - BlockLinearOperator.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 171 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 56 0.0 %

          Line data    Source code
       1             : #include "BlockLinearOperator.h"
       2             : #include "PartitionDescriptor.h"
       3             : #include "RandomBlocksDescriptor.h"
       4             : #include "DescriptorUtils.h"
       5             : #include "TypeCasts.hpp"
       6             : 
       7             : #include <algorithm>
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12           0 :     BlockLinearOperator<data_t>::BlockLinearOperator(const OperatorList& ops, BlockType blockType)
      13           0 :         : LinearOperator<data_t>{*determineDomainDescriptor(ops, blockType),
      14           0 :                                  *determineRangeDescriptor(ops, blockType)},
      15           0 :           _operatorList(0),
      16           0 :           _blockType{blockType}
      17             :     {
      18           0 :         for (const auto& op : ops)
      19           0 :             _operatorList.push_back(op->clone());
      20           0 :     }
      21             : 
      22             :     template <typename data_t>
      23           0 :     BlockLinearOperator<data_t>::BlockLinearOperator(const DataDescriptor& domainDescriptor,
      24             :                                                      const DataDescriptor& rangeDescriptor,
      25             :                                                      const OperatorList& ops, BlockType blockType)
      26             :         : LinearOperator<data_t>{domainDescriptor, rangeDescriptor},
      27           0 :           _operatorList(0),
      28           0 :           _blockType{blockType}
      29             :     {
      30           0 :         if (_blockType == COL) {
      31           0 :             const auto* trueDomainDesc = downcast_safe<BlockDescriptor>(&domainDescriptor);
      32             : 
      33           0 :             if (trueDomainDesc == nullptr)
      34           0 :                 throw InvalidArgumentError(
      35             :                     "BlockLinearOperator: domain descriptor is not a BlockDescriptor");
      36             : 
      37           0 :             if (trueDomainDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size()))
      38           0 :                 throw InvalidArgumentError("BlockLinearOperator: domain descriptor number of "
      39             :                                            "blocks does not match operator list size");
      40             : 
      41           0 :             for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) {
      42           0 :                 const auto& op = ops[static_cast<std::size_t>(i)];
      43           0 :                 if (op->getRangeDescriptor().getNumberOfCoefficients()
      44           0 :                     != _rangeDescriptor->getNumberOfCoefficients())
      45           0 :                     throw InvalidArgumentError(
      46             :                         "BlockLinearOperator: the range descriptor of a COL BlockLinearOperator "
      47             :                         "must have the same size as the range of every operator in the list");
      48             : 
      49           0 :                 if (op->getDomainDescriptor().getNumberOfCoefficients()
      50           0 :                     != trueDomainDesc->getDescriptorOfBlock(i).getNumberOfCoefficients())
      51           0 :                     throw InvalidArgumentError(
      52             :                         "BlockLinearOperator: block of incorrect size in domain descriptor");
      53             :             }
      54             :         }
      55             : 
      56           0 :         if (_blockType == ROW) {
      57           0 :             const auto* trueRangeDesc = downcast_safe<BlockDescriptor>(&rangeDescriptor);
      58             : 
      59           0 :             if (trueRangeDesc == nullptr)
      60           0 :                 throw InvalidArgumentError(
      61             :                     "BlockLinearOperator: range descriptor is not a BlockDescriptor");
      62             : 
      63           0 :             if (trueRangeDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size()))
      64           0 :                 throw InvalidArgumentError("BlockLinearOperator: range descriptor number of "
      65             :                                            "blocks does not match operator list size");
      66             : 
      67           0 :             for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) {
      68           0 :                 const auto& op = ops[static_cast<std::size_t>(i)];
      69           0 :                 if (op->getDomainDescriptor().getNumberOfCoefficients()
      70           0 :                     != _domainDescriptor->getNumberOfCoefficients())
      71           0 :                     throw InvalidArgumentError(
      72             :                         "BlockLinearOperator: the domain descriptor of a ROW BlockLinearOperator "
      73             :                         "must have the same size as the domain of every operator in the list");
      74             : 
      75           0 :                 if (op->getRangeDescriptor().getNumberOfCoefficients()
      76           0 :                     != trueRangeDesc->getDescriptorOfBlock(i).getNumberOfCoefficients())
      77           0 :                     throw InvalidArgumentError(
      78             :                         "BlockLinearOperator: block of incorrect size in range descriptor");
      79             :             }
      80             :         }
      81             : 
      82           0 :         for (const auto& op : ops)
      83           0 :             _operatorList.push_back(op->clone());
      84           0 :     }
      85             : 
      86             :     template <typename data_t>
      87           0 :     const LinearOperator<data_t>& BlockLinearOperator<data_t>::getIthOperator(index_t index) const
      88             :     {
      89           0 :         return *_operatorList.at(static_cast<std::size_t>(index));
      90             :     }
      91             : 
      92             :     template <typename data_t>
      93           0 :     index_t BlockLinearOperator<data_t>::numberOfOps() const
      94             :     {
      95           0 :         return static_cast<index_t>(_operatorList.size());
      96             :     }
      97             : 
      98             :     template <typename data_t>
      99           0 :     BlockLinearOperator<data_t>::BlockLinearOperator(const BlockLinearOperator<data_t>& other)
     100           0 :         : LinearOperator<data_t>{*other._domainDescriptor, *other._rangeDescriptor},
     101           0 :           _operatorList(0),
     102           0 :           _blockType{other._blockType}
     103             :     {
     104           0 :         for (const auto& op : other._operatorList)
     105           0 :             _operatorList.push_back(op->clone());
     106           0 :     }
     107             : 
     108             :     template <typename data_t>
     109           0 :     void BlockLinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x,
     110             :                                                 DataContainer<data_t>& Ax) const
     111             :     {
     112           0 :         switch (_blockType) {
     113           0 :             case BlockType::COL: {
     114           0 :                 Ax = 0;
     115           0 :                 auto tmpAx = DataContainer<data_t>(Ax.getDataDescriptor());
     116             : 
     117           0 :                 auto xView = x.viewAs(*_domainDescriptor);
     118           0 :                 index_t i = 0;
     119           0 :                 for (const auto& op : _operatorList) {
     120           0 :                     op->apply(xView.getBlock(i), tmpAx);
     121           0 :                     Ax += tmpAx;
     122           0 :                     ++i;
     123             :                 }
     124             : 
     125           0 :                 break;
     126           0 :             }
     127           0 :             case BlockType::ROW: {
     128           0 :                 index_t i = 0;
     129             : 
     130           0 :                 auto AxView = Ax.viewAs(*_rangeDescriptor);
     131           0 :                 for (const auto& op : _operatorList) {
     132           0 :                     auto blk = AxView.getBlock(i);
     133           0 :                     op->apply(x, blk);
     134           0 :                     ++i;
     135             :                 }
     136             : 
     137           0 :                 break;
     138           0 :             }
     139             :         }
     140           0 :     }
     141             : 
     142             :     template <typename data_t>
     143           0 :     void BlockLinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
     144             :                                                        DataContainer<data_t>& Aty) const
     145             :     {
     146           0 :         switch (_blockType) {
     147           0 :             case BlockType::COL: {
     148           0 :                 index_t i = 0;
     149             : 
     150           0 :                 auto AtyView = Aty.viewAs(*_domainDescriptor);
     151           0 :                 for (const auto& op : _operatorList) {
     152           0 :                     auto&& blk = AtyView.getBlock(i);
     153           0 :                     op->applyAdjoint(y, blk);
     154           0 :                     ++i;
     155             :                 }
     156             : 
     157           0 :                 break;
     158           0 :             }
     159           0 :             case BlockType::ROW: {
     160           0 :                 Aty = 0;
     161           0 :                 auto tmpAty = DataContainer<data_t>(Aty.getDataDescriptor());
     162             : 
     163           0 :                 auto yView = y.viewAs(*_rangeDescriptor);
     164           0 :                 index_t i = 0;
     165           0 :                 for (const auto& op : _operatorList) {
     166           0 :                     op->applyAdjoint(yView.getBlock(i), tmpAty);
     167           0 :                     Aty += tmpAty;
     168           0 :                     ++i;
     169             :                 }
     170             : 
     171           0 :                 break;
     172           0 :             }
     173             :         }
     174           0 :     }
     175             : 
     176             :     template <typename data_t>
     177           0 :     BlockLinearOperator<data_t>* BlockLinearOperator<data_t>::cloneImpl() const
     178             :     {
     179           0 :         return new BlockLinearOperator<data_t>(*this);
     180             :     }
     181             : 
     182             :     template <typename data_t>
     183           0 :     bool BlockLinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const
     184             :     {
     185           0 :         if (!LinearOperator<data_t>::isEqual(other))
     186           0 :             return false;
     187             : 
     188             :         // static_cast as type checked in base comparison
     189           0 :         auto otherBlockOp = static_cast<const BlockLinearOperator<data_t>*>(&other);
     190             : 
     191           0 :         for (std::size_t i = 0; i < _operatorList.size(); i++)
     192           0 :             if (*_operatorList[i] != *otherBlockOp->_operatorList[i])
     193           0 :                 return false;
     194             : 
     195           0 :         return true;
     196             :     }
     197             : 
     198             :     template <typename data_t>
     199             :     std::unique_ptr<DataDescriptor>
     200           0 :         BlockLinearOperator<data_t>::determineDomainDescriptor(const OperatorList& operatorList,
     201             :                                                                BlockType blockType)
     202             :     {
     203           0 :         std::vector<const DataDescriptor*> vec(operatorList.size());
     204           0 :         for (std::size_t i = 0; i < vec.size(); i++)
     205           0 :             vec[i] = &operatorList[i]->getDomainDescriptor();
     206             : 
     207           0 :         switch (blockType) {
     208           0 :             case BlockType::ROW:
     209           0 :                 return bestCommon(vec);
     210             : 
     211           0 :             case BlockType::COL:
     212           0 :                 return bestBlockDescriptor(vec);
     213             : 
     214           0 :             default:
     215           0 :                 throw InvalidArgumentError("BlockLinearOpearator: unsupported block type");
     216             :         }
     217           0 :     }
     218             : 
     219             :     template <typename data_t>
     220             :     std::unique_ptr<DataDescriptor>
     221           0 :         BlockLinearOperator<data_t>::determineRangeDescriptor(const OperatorList& operatorList,
     222             :                                                               BlockType blockType)
     223             :     {
     224           0 :         std::vector<const DataDescriptor*> vec(operatorList.size());
     225           0 :         for (std::size_t i = 0; i < vec.size(); i++)
     226           0 :             vec[i] = &operatorList[i]->getRangeDescriptor();
     227             : 
     228           0 :         switch (blockType) {
     229           0 :             case BlockType::ROW:
     230           0 :                 return bestBlockDescriptor(vec);
     231             : 
     232           0 :             case BlockType::COL:
     233           0 :                 return bestCommon(vec);
     234             : 
     235           0 :             default:
     236           0 :                 throw InvalidArgumentError("BlockLinearOpearator: unsupported block type");
     237             :         }
     238           0 :     }
     239             : 
     240             :     template <typename data_t>
     241           0 :     std::unique_ptr<BlockDescriptor> BlockLinearOperator<data_t>::bestBlockDescriptor(
     242             :         const std::vector<const DataDescriptor*>& descList)
     243             :     {
     244           0 :         auto numBlocks = descList.size();
     245           0 :         if (numBlocks == 0)
     246           0 :             throw InvalidArgumentError("BlockLinearOperator: operator list cannot be empty");
     247             : 
     248           0 :         const auto& firstDesc = *descList[0];
     249           0 :         auto numDims = firstDesc.getNumberOfDimensions();
     250           0 :         auto coeffs = firstDesc.getNumberOfCoefficientsPerDimension();
     251             : 
     252           0 :         bool allNumDimSame = true;
     253           0 :         bool allButLastDimSame = true;
     254           0 :         IndexVector_t lastDimSplit(numBlocks);
     255           0 :         for (std::size_t i = 1; i < numBlocks && allButLastDimSame; i++) {
     256           0 :             lastDimSplit[static_cast<index_t>(i - 1)] =
     257           0 :                 descList[i - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1];
     258             : 
     259           0 :             allNumDimSame = allNumDimSame && descList[i]->getNumberOfDimensions() == numDims;
     260           0 :             allButLastDimSame =
     261           0 :                 allNumDimSame && allButLastDimSame
     262           0 :                 && descList[i]->getNumberOfCoefficientsPerDimension().head(numDims - 1)
     263           0 :                        == coeffs.head(numDims - 1);
     264             :         }
     265             : 
     266           0 :         if (allButLastDimSame) {
     267           0 :             lastDimSplit[static_cast<index_t>(numBlocks) - 1] =
     268           0 :                 descList[numBlocks - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1];
     269             : 
     270           0 :             auto spacing = firstDesc.getSpacingPerDimension();
     271             :             bool allSameSpacing =
     272           0 :                 all_of(descList.begin(), descList.end(), [&spacing](const DataDescriptor* d) {
     273           0 :                     return d->getSpacingPerDimension() == spacing;
     274             :                 });
     275             : 
     276           0 :             coeffs[numDims - 1] = lastDimSplit.sum();
     277           0 :             if (allSameSpacing) {
     278           0 :                 VolumeDescriptor tmp(coeffs, spacing);
     279           0 :                 return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit);
     280           0 :             } else {
     281           0 :                 VolumeDescriptor tmp(coeffs);
     282           0 :                 return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit);
     283           0 :             }
     284           0 :         }
     285             : 
     286           0 :         std::vector<std::unique_ptr<DataDescriptor>> tmp(numBlocks);
     287           0 :         auto it = descList.begin();
     288           0 :         std::generate(tmp.begin(), tmp.end(), [&it]() { return (*it++)->clone(); });
     289             : 
     290           0 :         return std::make_unique<RandomBlocksDescriptor>(std::move(tmp));
     291           0 :     }
     292             : 
    // ----------------------------------------------
    // explicit template instantiation
    // The member definitions above live in this .cpp file. Presumably the header only
    // declares them (TODO confirm), so BlockLinearOperator is usable only for the scalar
    // types instantiated here.
    template class BlockLinearOperator<float>;
    template class BlockLinearOperator<double>;
    template class BlockLinearOperator<complex<float>>;
    template class BlockLinearOperator<complex<double>>;
     299             : } // namespace elsa

Generated by: LCOV version 1.14