Line data Source code
1 : #include "BlockLinearOperator.h" 2 : #include "PartitionDescriptor.h" 3 : #include "RandomBlocksDescriptor.h" 4 : #include "DescriptorUtils.h" 5 : #include "TypeCasts.hpp" 6 : 7 : #include <algorithm> 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : BlockLinearOperator<data_t>::BlockLinearOperator(const OperatorList& ops, BlockType blockType) 13 : : LinearOperator<data_t>{*determineDomainDescriptor(ops, blockType), 14 : *determineRangeDescriptor(ops, blockType)}, 15 : _operatorList(0), 16 : _blockType{blockType} 17 68 : { 18 68 : for (const auto& op : ops) 19 150 : _operatorList.push_back(op->clone()); 20 68 : } 21 : 22 : template <typename data_t> 23 : BlockLinearOperator<data_t>::BlockLinearOperator(const DataDescriptor& domainDescriptor, 24 : const DataDescriptor& rangeDescriptor, 25 : const OperatorList& ops, BlockType blockType) 26 : : LinearOperator<data_t>{domainDescriptor, rangeDescriptor}, 27 : _operatorList(0), 28 : _blockType{blockType} 29 44 : { 30 44 : if (_blockType == COL) { 31 22 : const auto* trueDomainDesc = downcast_safe<BlockDescriptor>(&domainDescriptor); 32 : 33 22 : if (trueDomainDesc == nullptr) 34 4 : throw InvalidArgumentError( 35 4 : "BlockLinearOperator: domain descriptor is not a BlockDescriptor"); 36 : 37 18 : if (trueDomainDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size())) 38 2 : throw InvalidArgumentError("BlockLinearOperator: domain descriptor number of " 39 2 : "blocks does not match operator list size"); 40 : 41 32 : for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) { 42 24 : const auto& op = ops[static_cast<std::size_t>(i)]; 43 24 : if (op->getRangeDescriptor().getNumberOfCoefficients() 44 24 : != _rangeDescriptor->getNumberOfCoefficients()) 45 4 : throw InvalidArgumentError( 46 4 : "BlockLinearOperator: the range descriptor of a COL BlockLinearOperator " 47 4 : "must have the same size as the range of every operator in the list"); 48 : 49 20 : if (op->getDomainDescriptor().getNumberOfCoefficients() 50 20 : != trueDomainDesc->getDescriptorOfBlock(i).getNumberOfCoefficients()) 51 4 : throw InvalidArgumentError( 52 4 : "BlockLinearOperator: block of incorrect size in domain descriptor"); 53 20 : } 54 16 : } 55 : 56 44 : if (_blockType == ROW) { 57 22 : const auto* trueRangeDesc = downcast_safe<BlockDescriptor>(&rangeDescriptor); 58 : 59 22 : if (trueRangeDesc == nullptr) 60 4 : throw InvalidArgumentError( 61 4 : "BlockLinearOperator: range descriptor is not a BlockDescriptor"); 62 : 63 18 : if (trueRangeDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size())) 64 2 : throw InvalidArgumentError("BlockLinearOperator: range descriptor number of " 65 2 : "blocks does not match operator list size"); 66 : 67 32 : for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) { 68 24 : const auto& op = ops[static_cast<std::size_t>(i)]; 69 24 : if (op->getDomainDescriptor().getNumberOfCoefficients() 70 24 : != _domainDescriptor->getNumberOfCoefficients()) 71 4 : throw InvalidArgumentError( 72 4 : "BlockLinearOperator: the domain descriptor of a ROW BlockLinearOperator " 73 4 : "must have the same size as the domain of every operator in the list"); 74 : 75 20 : if (op->getRangeDescriptor().getNumberOfCoefficients() 76 20 : != trueRangeDesc->getDescriptorOfBlock(i).getNumberOfCoefficients()) 77 4 : throw InvalidArgumentError( 78 4 : "BlockLinearOperator: block of incorrect size in range descriptor"); 79 20 : } 80 16 : } 81 : 82 30 : for (const auto& op : ops) 83 32 : _operatorList.push_back(op->clone()); 84 16 : } 85 : 86 : template <typename data_t> 87 : const LinearOperator<data_t>& BlockLinearOperator<data_t>::getIthOperator(index_t index) const 88 40 : { 89 40 : return *_operatorList.at(static_cast<std::size_t>(index)); 90 40 : } 91 : 92 : template <typename data_t> 93 : index_t BlockLinearOperator<data_t>::numberOfOps() const 94 20 : { 95 20 : return static_cast<index_t>(_operatorList.size()); 96 20 : } 97 : 98 : template <typename data_t> 99 : BlockLinearOperator<data_t>::BlockLinearOperator(const BlockLinearOperator<data_t>& other) 100 : : LinearOperator<data_t>{*other._domainDescriptor, *other._rangeDescriptor}, 101 : _operatorList(0), 102 : _blockType{other._blockType} 103 24 : { 104 24 : for (const auto& op : other._operatorList) 105 54 : _operatorList.push_back(op->clone()); 106 24 : } 107 : 108 : template <typename data_t> 109 : void BlockLinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x, 110 : DataContainer<data_t>& Ax) const 111 12 : { 112 12 : switch (_blockType) { 113 2 : case BlockType::COL: { 114 2 : Ax = 0; 115 2 : auto tmpAx = DataContainer<data_t>(Ax.getDataDescriptor()); 116 : 117 2 : auto xView = x.viewAs(*_domainDescriptor); 118 2 : index_t i = 0; 119 6 : for (const auto& op : _operatorList) { 120 6 : op->apply(xView.getBlock(i), tmpAx); 121 6 : Ax += tmpAx; 122 6 : ++i; 123 6 : } 124 : 125 2 : break; 126 0 : } 127 10 : case BlockType::ROW: { 128 10 : index_t i = 0; 129 : 130 10 : auto AxView = Ax.viewAs(*_rangeDescriptor); 131 26 : for (const auto& op : _operatorList) { 132 26 : auto blk = AxView.getBlock(i); 133 26 : op->apply(x, blk); 134 26 : ++i; 135 26 : } 136 : 137 10 : break; 138 0 : } 139 12 : } 140 12 : } 141 : 142 : template <typename data_t> 143 : void BlockLinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 144 : DataContainer<data_t>& Aty) const 145 10 : { 146 10 : switch (_blockType) { 147 6 : case BlockType::COL: { 148 6 : index_t i = 0; 149 : 150 6 : auto AtyView = Aty.viewAs(*_domainDescriptor); 151 18 : for (const auto& op : _operatorList) { 152 18 : auto&& blk = AtyView.getBlock(i); 153 18 : op->applyAdjoint(y, blk); 154 18 : ++i; 155 18 : } 156 : 157 6 : break; 158 0 : } 159 4 : case BlockType::ROW: { 160 4 : Aty = 0; 161 4 : auto tmpAty = DataContainer<data_t>(Aty.getDataDescriptor()); 162 : 163 4 : auto yView = y.viewAs(*_rangeDescriptor); 164 4 : index_t i = 0; 165 10 : for (const auto& op : _operatorList) { 166 10 : op->applyAdjoint(yView.getBlock(i), tmpAty); 167 10 : Aty += tmpAty; 168 10 : ++i; 169 10 : } 170 : 171 4 : break; 172 0 : } 173 10 : } 174 10 : } 175 : 176 : template <typename data_t> 177 : BlockLinearOperator<data_t>* BlockLinearOperator<data_t>::cloneImpl() const 178 24 : { 179 24 : return new BlockLinearOperator<data_t>(*this); 180 24 : } 181 : 182 : template <typename data_t> 183 : bool BlockLinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const 184 16 : { 185 16 : if (!LinearOperator<data_t>::isEqual(other)) 186 6 : return false; 187 : 188 : // static_cast as type checked in base comparison 189 10 : auto otherBlockOp = static_cast<const BlockLinearOperator<data_t>*>(&other); 190 : 191 34 : for (std::size_t i = 0; i < _operatorList.size(); i++) 192 28 : if (*_operatorList[i] != *otherBlockOp->_operatorList[i]) 193 4 : return false; 194 : 195 10 : return true; 196 10 : } 197 : 198 : template <typename data_t> 199 : std::unique_ptr<DataDescriptor> 200 : BlockLinearOperator<data_t>::determineDomainDescriptor(const OperatorList& operatorList, 201 : BlockType blockType) 202 68 : { 203 68 : std::vector<const DataDescriptor*> vec(operatorList.size()); 204 226 : for (std::size_t i = 0; i < vec.size(); i++) 205 158 : vec[i] = &operatorList[i]->getDomainDescriptor(); 206 : 207 68 : switch (blockType) { 208 40 : case BlockType::ROW: 209 40 : return bestCommon(vec); 210 : 211 28 : case BlockType::COL: 212 28 : return bestBlockDescriptor(vec); 213 : 214 0 : default: 215 0 : throw InvalidArgumentError("BlockLinearOpearator: unsupported block type"); 216 68 : } 217 68 : } 218 : 219 : template <typename data_t> 220 : std::unique_ptr<DataDescriptor> 221 : BlockLinearOperator<data_t>::determineRangeDescriptor(const OperatorList& operatorList, 222 : BlockType blockType) 223 62 : { 224 62 : std::vector<const DataDescriptor*> vec(operatorList.size()); 225 216 : for (std::size_t i = 0; i < vec.size(); i++) 226 154 : vec[i] = &operatorList[i]->getRangeDescriptor(); 227 : 228 62 : switch (blockType) { 229 36 : case BlockType::ROW: 230 36 : return bestBlockDescriptor(vec); 231 : 232 26 : case BlockType::COL: 233 26 : return bestCommon(vec); 234 : 235 0 : default: 236 0 : throw InvalidArgumentError("BlockLinearOpearator: unsupported block type"); 237 62 : } 238 62 : } 239 : 240 : template <typename data_t> 241 : std::unique_ptr<BlockDescriptor> BlockLinearOperator<data_t>::bestBlockDescriptor( 242 : const std::vector<const DataDescriptor*>& descList) 243 64 : { 244 64 : auto numBlocks = descList.size(); 245 64 : if (numBlocks == 0) 246 2 : throw InvalidArgumentError("BlockLinearOperator: operator list cannot be empty"); 247 : 248 62 : const auto& firstDesc = *descList[0]; 249 62 : auto numDims = firstDesc.getNumberOfDimensions(); 250 62 : auto coeffs = firstDesc.getNumberOfCoefficientsPerDimension(); 251 : 252 62 : bool allNumDimSame = true; 253 62 : bool allButLastDimSame = true; 254 62 : IndexVector_t lastDimSplit(numBlocks); 255 154 : for (std::size_t i = 1; i < numBlocks && allButLastDimSame; i++) { 256 92 : lastDimSplit[static_cast<index_t>(i - 1)] = 257 92 : descList[i - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1]; 258 : 259 92 : allNumDimSame = allNumDimSame && descList[i]->getNumberOfDimensions() == numDims; 260 92 : allButLastDimSame = 261 92 : allNumDimSame && allButLastDimSame 262 92 : && descList[i]->getNumberOfCoefficientsPerDimension().head(numDims - 1) 263 84 : == coeffs.head(numDims - 1); 264 92 : } 265 : 266 62 : if (allButLastDimSame) { 267 54 : lastDimSplit[static_cast<index_t>(numBlocks) - 1] = 268 54 : descList[numBlocks - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1]; 269 : 270 54 : auto spacing = firstDesc.getSpacingPerDimension(); 271 54 : bool allSameSpacing = 272 138 : all_of(descList.begin(), descList.end(), [&spacing](const DataDescriptor* d) { 273 138 : return d->getSpacingPerDimension() == spacing; 274 138 : }); 275 : 276 54 : coeffs[numDims - 1] = lastDimSplit.sum(); 277 54 : if (allSameSpacing) { 278 46 : VolumeDescriptor tmp(coeffs, spacing); 279 46 : return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit); 280 46 : } else { 281 8 : VolumeDescriptor tmp(coeffs); 282 8 : return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit); 283 8 : } 284 8 : } 285 : 286 8 : std::vector<std::unique_ptr<DataDescriptor>> tmp(numBlocks); 287 8 : auto it = descList.begin(); 288 16 : std::generate(tmp.begin(), tmp.end(), [&it]() { return (*it++)->clone(); }); 289 : 290 8 : return std::make_unique<RandomBlocksDescriptor>(std::move(tmp)); 291 8 : } 292 : 293 : // ---------------------------------------------- 294 : // explicit template instantiation 295 : template class BlockLinearOperator<float>; 296 : template class BlockLinearOperator<double>; 297 : template class BlockLinearOperator<complex<float>>; 298 : template class BlockLinearOperator<complex<double>>; 299 : } // namespace elsa