Line data Source code
1 : #include "BlockLinearOperator.h" 2 : #include "PartitionDescriptor.h" 3 : #include "RandomBlocksDescriptor.h" 4 : #include "DescriptorUtils.h" 5 : #include "TypeCasts.hpp" 6 : 7 : #include <algorithm> 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : BlockLinearOperator<data_t>::BlockLinearOperator(const OperatorList& ops, BlockType blockType) 13 : : LinearOperator<data_t>{*determineDomainDescriptor(ops, blockType), 14 : *determineRangeDescriptor(ops, blockType)}, 15 : _operatorList(0), 16 : _blockType{blockType} 17 78 : { 18 78 : for (const auto& op : ops) 19 170 : _operatorList.push_back(op->clone()); 20 78 : } 21 : 22 : template <typename data_t> 23 : BlockLinearOperator<data_t>::BlockLinearOperator(const DataDescriptor& domainDescriptor, 24 : const DataDescriptor& rangeDescriptor, 25 : const OperatorList& ops, BlockType blockType) 26 : : LinearOperator<data_t>{domainDescriptor, rangeDescriptor}, 27 : _operatorList(0), 28 : _blockType{blockType} 29 80 : { 30 80 : if (_blockType == BlockType::COL) { 31 34 : const auto* trueDomainDesc = downcast_safe<BlockDescriptor>(&domainDescriptor); 32 : 33 34 : if (trueDomainDesc == nullptr) 34 4 : throw InvalidArgumentError( 35 4 : "BlockLinearOperator: domain descriptor is not a BlockDescriptor"); 36 : 37 30 : if (trueDomainDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size())) 38 2 : throw InvalidArgumentError("BlockLinearOperator: domain descriptor number of " 39 2 : "blocks does not match operator list size"); 40 : 41 264 : for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) { 42 244 : const auto& op = ops[static_cast<std::size_t>(i)]; 43 244 : if (op->getRangeDescriptor().getNumberOfCoefficients() 44 244 : != _rangeDescriptor->getNumberOfCoefficients()) 45 4 : throw InvalidArgumentError( 46 4 : "BlockLinearOperator: the range descriptor of a COL BlockLinearOperator " 47 4 : "must have the same size as the range of every operator in the list"); 48 : 49 240 : if (op->getDomainDescriptor().getNumberOfCoefficients() 50 240 : != trueDomainDesc->getDescriptorOfBlock(i).getNumberOfCoefficients()) 51 4 : throw InvalidArgumentError( 52 4 : "BlockLinearOperator: block of incorrect size in domain descriptor"); 53 240 : } 54 28 : } 55 : 56 46 : else if (_blockType == BlockType::ROW) { 57 46 : const auto* trueRangeDesc = downcast_safe<BlockDescriptor>(&rangeDescriptor); 58 : 59 46 : if (trueRangeDesc == nullptr) 60 4 : throw InvalidArgumentError( 61 4 : "BlockLinearOperator: range descriptor is not a BlockDescriptor"); 62 : 63 42 : if (trueRangeDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size())) 64 2 : throw InvalidArgumentError("BlockLinearOperator: range descriptor number of " 65 2 : "blocks does not match operator list size"); 66 : 67 104 : for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) { 68 72 : const auto& op = ops[static_cast<std::size_t>(i)]; 69 72 : if (op->getDomainDescriptor().getNumberOfCoefficients() 70 72 : != _domainDescriptor->getNumberOfCoefficients()) 71 4 : throw InvalidArgumentError( 72 4 : "BlockLinearOperator: the domain descriptor of a ROW BlockLinearOperator " 73 4 : "must have the same size as the domain of every operator in the list"); 74 : 75 68 : if (op->getRangeDescriptor().getNumberOfCoefficients() 76 68 : != trueRangeDesc->getDescriptorOfBlock(i).getNumberOfCoefficients()) 77 4 : throw InvalidArgumentError( 78 4 : "BlockLinearOperator: block of incorrect size in range descriptor"); 79 68 : } 80 40 : } 81 : 82 80 : for (const auto& op : ops) 83 300 : _operatorList.push_back(op->clone()); 84 52 : } 85 : 86 : template <typename data_t> 87 : const LinearOperator<data_t>& BlockLinearOperator<data_t>::getIthOperator(index_t index) const 88 40 : { 89 40 : return *_operatorList.at(static_cast<std::size_t>(index)); 90 40 : } 91 : 92 : template <typename data_t> 93 : index_t BlockLinearOperator<data_t>::numberOfOps() const 94 20 : { 95 20 : return static_cast<index_t>(_operatorList.size()); 96 20 : } 97 : 98 : template <typename data_t> 99 : BlockLinearOperator<data_t>::BlockLinearOperator(const BlockLinearOperator<data_t>& other) 100 : : LinearOperator<data_t>{*other._domainDescriptor, *other._rangeDescriptor}, 101 : _operatorList(0), 102 : _blockType{other._blockType} 103 226 : { 104 226 : for (const auto& op : other._operatorList) 105 458 : _operatorList.push_back(op->clone()); 106 226 : } 107 : 108 : template <typename data_t> 109 : void BlockLinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x, 110 : DataContainer<data_t>& Ax) const 111 128 : { 112 128 : switch (_blockType) { 113 18 : case BlockType::COL: { 114 : 115 18 : auto acc = zeroslike(Ax); 116 18 : auto img = emptylike(Ax); 117 18 : const auto xView = x.viewAs(*_domainDescriptor); 118 : 119 200 : for (size_t i = 0; i < _operatorList.size(); ++i) { 120 182 : _operatorList[i]->apply(xView.getBlock(i), img); 121 182 : acc += img; 122 182 : } 123 18 : Ax = acc; 124 18 : break; 125 0 : } 126 110 : case BlockType::ROW: { 127 : 128 110 : auto AxView = Ax.viewAs(*_rangeDescriptor); 129 : 130 336 : for (size_t i = 0; i < _operatorList.size(); ++i) { 131 226 : auto block = AxView.getBlock(i); 132 226 : _operatorList[i]->apply(x, block); 133 226 : } 134 : 135 110 : break; 136 0 : } 137 128 : } 138 128 : } 139 : 140 : template <typename data_t> 141 : void BlockLinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 142 : DataContainer<data_t>& Aty) const 143 144 : { 144 144 : switch (_blockType) { 145 10 : case BlockType::COL: { 146 10 : index_t i = 0; 147 : 148 10 : auto AtyView = Aty.viewAs(*_domainDescriptor); 149 78 : for (const auto& op : _operatorList) { 150 78 : auto&& blk = AtyView.getBlock(i); 151 78 : op->applyAdjoint(y, blk); 152 78 : ++i; 153 78 : } 154 : 155 10 : break; 156 0 : } 157 134 : case BlockType::ROW: { 158 134 : Aty = 0; 159 134 : auto tmpAty = DataContainer<data_t>(Aty.getDataDescriptor()); 160 : 161 134 : auto yView = y.viewAs(*_rangeDescriptor); 162 134 : index_t i = 0; 163 270 : for (const auto& op : _operatorList) { 164 270 : op->applyAdjoint(materialize(yView.getBlock(i)), tmpAty); 165 270 : Aty += tmpAty; 166 270 : ++i; 167 270 : } 168 : 169 134 : break; 170 0 : } 171 144 : } 172 144 : } 173 : 174 : template <typename data_t> 175 : BlockLinearOperator<data_t>* BlockLinearOperator<data_t>::cloneImpl() const 176 226 : { 177 226 : return new BlockLinearOperator<data_t>(*this); 178 226 : } 179 : 180 : template <typename data_t> 181 : bool BlockLinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const 182 18 : { 183 18 : if (!LinearOperator<data_t>::isEqual(other)) 184 6 : return false; 185 : 186 : // static_cast as type checked in base comparison 187 12 : auto otherBlockOp = static_cast<const BlockLinearOperator<data_t>*>(&other); 188 : 189 40 : for (std::size_t i = 0; i < _operatorList.size(); i++) 190 32 : if (*_operatorList[i] != *otherBlockOp->_operatorList[i]) 191 4 : return false; 192 : 193 12 : return true; 194 12 : } 195 : 196 : template <typename data_t> 197 : std::unique_ptr<DataDescriptor> 198 : BlockLinearOperator<data_t>::determineDomainDescriptor(const OperatorList& operatorList, 199 : BlockType blockType) 200 78 : { 201 78 : std::vector<const DataDescriptor*> vec(operatorList.size()); 202 256 : for (std::size_t i = 0; i < vec.size(); i++) 203 178 : vec[i] = &operatorList[i]->getDomainDescriptor(); 204 : 205 78 : switch (blockType) { 206 42 : case BlockType::ROW: 207 42 : return bestCommon(vec); 208 : 209 36 : case BlockType::COL: 210 36 : return bestBlockDescriptor(vec); 211 : 212 0 : default: 213 0 : throw InvalidArgumentError("BlockLinearOpearator: unsupported block type"); 214 78 : } 215 78 : } 216 : 217 : template <typename data_t> 218 : std::unique_ptr<DataDescriptor> 219 : BlockLinearOperator<data_t>::determineRangeDescriptor(const OperatorList& operatorList, 220 : BlockType blockType) 221 72 : { 222 72 : std::vector<const DataDescriptor*> vec(operatorList.size()); 223 246 : for (std::size_t i = 0; i < vec.size(); i++) 224 174 : vec[i] = &operatorList[i]->getRangeDescriptor(); 225 : 226 72 : switch (blockType) { 227 38 : case BlockType::ROW: 228 38 : return bestBlockDescriptor(vec); 229 : 230 34 : case BlockType::COL: 231 34 : return bestCommon(vec); 232 : 233 0 : default: 234 0 : throw InvalidArgumentError("BlockLinearOpearator: unsupported block type"); 235 72 : } 236 72 : } 237 : 238 : template <typename data_t> 239 : std::unique_ptr<BlockDescriptor> BlockLinearOperator<data_t>::bestBlockDescriptor( 240 : const std::vector<const DataDescriptor*>& descList) 241 74 : { 242 74 : auto numBlocks = descList.size(); 243 74 : if (numBlocks == 0) 244 2 : throw InvalidArgumentError("BlockLinearOperator: operator list cannot be empty"); 245 : 246 72 : const auto& firstDesc = *descList[0]; 247 72 : auto numDims = firstDesc.getNumberOfDimensions(); 248 72 : auto coeffs = firstDesc.getNumberOfCoefficientsPerDimension(); 249 : 250 72 : bool allNumDimSame = true; 251 72 : bool allButLastDimSame = true; 252 72 : IndexVector_t lastDimSplit(numBlocks); 253 174 : for (std::size_t i = 1; i < numBlocks && allButLastDimSame; i++) { 254 102 : lastDimSplit[static_cast<index_t>(i - 1)] = 255 102 : descList[i - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1]; 256 : 257 102 : allNumDimSame = allNumDimSame && descList[i]->getNumberOfDimensions() == numDims; 258 102 : allButLastDimSame = 259 102 : allNumDimSame && allButLastDimSame 260 102 : && descList[i]->getNumberOfCoefficientsPerDimension().head(numDims - 1) 261 94 : == coeffs.head(numDims - 1); 262 102 : } 263 : 264 72 : if (allButLastDimSame) { 265 64 : lastDimSplit[static_cast<index_t>(numBlocks) - 1] = 266 64 : descList[numBlocks - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1]; 267 : 268 64 : auto spacing = firstDesc.getSpacingPerDimension(); 269 64 : bool allSameSpacing = 270 158 : all_of(descList.begin(), descList.end(), [&spacing](const DataDescriptor* d) { 271 158 : return d->getSpacingPerDimension() == spacing; 272 158 : }); 273 : 274 64 : coeffs[numDims - 1] = lastDimSplit.sum(); 275 64 : if (allSameSpacing) { 276 56 : VolumeDescriptor tmp(coeffs, spacing); 277 56 : return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit); 278 56 : } else { 279 8 : VolumeDescriptor tmp(coeffs); 280 8 : return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit); 281 8 : } 282 8 : } 283 : 284 8 : std::vector<std::unique_ptr<DataDescriptor>> tmp(numBlocks); 285 8 : auto it = descList.begin(); 286 16 : std::generate(tmp.begin(), tmp.end(), [&it]() { return (*it++)->clone(); }); 287 : 288 8 : return std::make_unique<RandomBlocksDescriptor>(std::move(tmp)); 289 8 : } 290 : 291 : // ---------------------------------------------- 292 : // explicit template instantiation 293 : template class BlockLinearOperator<float>; 294 : template class BlockLinearOperator<double>; 295 : template class BlockLinearOperator<complex<float>>; 296 : template class BlockLinearOperator<complex<double>>; 297 : } // namespace elsa