Line data Source code
1 : #include "BlockLinearOperator.h" 2 : #include "PartitionDescriptor.h" 3 : #include "RandomBlocksDescriptor.h" 4 : #include "DescriptorUtils.h" 5 : #include "TypeCasts.hpp" 6 : 7 : #include <algorithm> 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 0 : BlockLinearOperator<data_t>::BlockLinearOperator(const OperatorList& ops, BlockType blockType) 13 0 : : LinearOperator<data_t>{*determineDomainDescriptor(ops, blockType), 14 0 : *determineRangeDescriptor(ops, blockType)}, 15 0 : _operatorList(0), 16 0 : _blockType{blockType} 17 : { 18 0 : for (const auto& op : ops) 19 0 : _operatorList.push_back(op->clone()); 20 0 : } 21 : 22 : template <typename data_t> 23 0 : BlockLinearOperator<data_t>::BlockLinearOperator(const DataDescriptor& domainDescriptor, 24 : const DataDescriptor& rangeDescriptor, 25 : const OperatorList& ops, BlockType blockType) 26 : : LinearOperator<data_t>{domainDescriptor, rangeDescriptor}, 27 0 : _operatorList(0), 28 0 : _blockType{blockType} 29 : { 30 0 : if (_blockType == COL) { 31 0 : const auto* trueDomainDesc = downcast_safe<BlockDescriptor>(&domainDescriptor); 32 : 33 0 : if (trueDomainDesc == nullptr) 34 0 : throw InvalidArgumentError( 35 : "BlockLinearOperator: domain descriptor is not a BlockDescriptor"); 36 : 37 0 : if (trueDomainDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size())) 38 0 : throw InvalidArgumentError("BlockLinearOperator: domain descriptor number of " 39 : "blocks does not match operator list size"); 40 : 41 0 : for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) { 42 0 : const auto& op = ops[static_cast<std::size_t>(i)]; 43 0 : if (op->getRangeDescriptor().getNumberOfCoefficients() 44 0 : != _rangeDescriptor->getNumberOfCoefficients()) 45 0 : throw InvalidArgumentError( 46 : "BlockLinearOperator: the range descriptor of a COL BlockLinearOperator " 47 : "must have the same size as the range of every operator in the list"); 48 : 49 0 : if (op->getDomainDescriptor().getNumberOfCoefficients() 50 0 : != trueDomainDesc->getDescriptorOfBlock(i).getNumberOfCoefficients()) 51 0 : throw InvalidArgumentError( 52 : "BlockLinearOperator: block of incorrect size in domain descriptor"); 53 : } 54 : } 55 : 56 0 : if (_blockType == ROW) { 57 0 : const auto* trueRangeDesc = downcast_safe<BlockDescriptor>(&rangeDescriptor); 58 : 59 0 : if (trueRangeDesc == nullptr) 60 0 : throw InvalidArgumentError( 61 : "BlockLinearOperator: range descriptor is not a BlockDescriptor"); 62 : 63 0 : if (trueRangeDesc->getNumberOfBlocks() != static_cast<index_t>(ops.size())) 64 0 : throw InvalidArgumentError("BlockLinearOperator: range descriptor number of " 65 : "blocks does not match operator list size"); 66 : 67 0 : for (index_t i = 0; i < static_cast<index_t>(ops.size()); i++) { 68 0 : const auto& op = ops[static_cast<std::size_t>(i)]; 69 0 : if (op->getDomainDescriptor().getNumberOfCoefficients() 70 0 : != _domainDescriptor->getNumberOfCoefficients()) 71 0 : throw InvalidArgumentError( 72 : "BlockLinearOperator: the domain descriptor of a ROW BlockLinearOperator " 73 : "must have the same size as the domain of every operator in the list"); 74 : 75 0 : if (op->getRangeDescriptor().getNumberOfCoefficients() 76 0 : != trueRangeDesc->getDescriptorOfBlock(i).getNumberOfCoefficients()) 77 0 : throw InvalidArgumentError( 78 : "BlockLinearOperator: block of incorrect size in range descriptor"); 79 : } 80 : } 81 : 82 0 : for (const auto& op : ops) 83 0 : _operatorList.push_back(op->clone()); 84 0 : } 85 : 86 : template <typename data_t> 87 0 : const LinearOperator<data_t>& BlockLinearOperator<data_t>::getIthOperator(index_t index) const 88 : { 89 0 : return *_operatorList.at(static_cast<std::size_t>(index)); 90 : } 91 : 92 : template <typename data_t> 93 0 : index_t BlockLinearOperator<data_t>::numberOfOps() const 94 : { 95 0 : return static_cast<index_t>(_operatorList.size()); 96 : } 97 : 98 : template <typename data_t> 99 0 : BlockLinearOperator<data_t>::BlockLinearOperator(const BlockLinearOperator<data_t>& other) 100 0 : : LinearOperator<data_t>{*other._domainDescriptor, *other._rangeDescriptor}, 101 0 : _operatorList(0), 102 0 : _blockType{other._blockType} 103 : { 104 0 : for (const auto& op : other._operatorList) 105 0 : _operatorList.push_back(op->clone()); 106 0 : } 107 : 108 : template <typename data_t> 109 0 : void BlockLinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x, 110 : DataContainer<data_t>& Ax) const 111 : { 112 0 : switch (_blockType) { 113 0 : case BlockType::COL: { 114 0 : Ax = 0; 115 0 : auto tmpAx = DataContainer<data_t>(Ax.getDataDescriptor()); 116 : 117 0 : auto xView = x.viewAs(*_domainDescriptor); 118 0 : index_t i = 0; 119 0 : for (const auto& op : _operatorList) { 120 0 : op->apply(xView.getBlock(i), tmpAx); 121 0 : Ax += tmpAx; 122 0 : ++i; 123 : } 124 : 125 0 : break; 126 0 : } 127 0 : case BlockType::ROW: { 128 0 : index_t i = 0; 129 : 130 0 : auto AxView = Ax.viewAs(*_rangeDescriptor); 131 0 : for (const auto& op : _operatorList) { 132 0 : auto blk = AxView.getBlock(i); 133 0 : op->apply(x, blk); 134 0 : ++i; 135 : } 136 : 137 0 : break; 138 0 : } 139 : } 140 0 : } 141 : 142 : template <typename data_t> 143 0 : void BlockLinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 144 : DataContainer<data_t>& Aty) const 145 : { 146 0 : switch (_blockType) { 147 0 : case BlockType::COL: { 148 0 : index_t i = 0; 149 : 150 0 : auto AtyView = Aty.viewAs(*_domainDescriptor); 151 0 : for (const auto& op : _operatorList) { 152 0 : auto&& blk = AtyView.getBlock(i); 153 0 : op->applyAdjoint(y, blk); 154 0 : ++i; 155 : } 156 : 157 0 : break; 158 0 : } 159 0 : case BlockType::ROW: { 160 0 : Aty = 0; 161 0 : auto tmpAty = DataContainer<data_t>(Aty.getDataDescriptor()); 162 : 163 0 : auto yView = y.viewAs(*_rangeDescriptor); 164 0 : index_t i = 0; 165 0 : for (const auto& op : _operatorList) { 166 0 : op->applyAdjoint(yView.getBlock(i), tmpAty); 167 0 : Aty += tmpAty; 168 0 : ++i; 169 : } 170 : 171 0 : break; 172 0 : } 173 : } 174 0 : } 175 : 176 : template <typename data_t> 177 0 : BlockLinearOperator<data_t>* BlockLinearOperator<data_t>::cloneImpl() const 178 : { 179 0 : return new BlockLinearOperator<data_t>(*this); 180 : } 181 : 182 : template <typename data_t> 183 0 : bool BlockLinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const 184 : { 185 0 : if (!LinearOperator<data_t>::isEqual(other)) 186 0 : return false; 187 : 188 : // static_cast as type checked in base comparison 189 0 : auto otherBlockOp = static_cast<const BlockLinearOperator<data_t>*>(&other); 190 : 191 0 : for (std::size_t i = 0; i < _operatorList.size(); i++) 192 0 : if (*_operatorList[i] != *otherBlockOp->_operatorList[i]) 193 0 : return false; 194 : 195 0 : return true; 196 : } 197 : 198 : template <typename data_t> 199 : std::unique_ptr<DataDescriptor> 200 0 : BlockLinearOperator<data_t>::determineDomainDescriptor(const OperatorList& operatorList, 201 : BlockType blockType) 202 : { 203 0 : std::vector<const DataDescriptor*> vec(operatorList.size()); 204 0 : for (std::size_t i = 0; i < vec.size(); i++) 205 0 : vec[i] = &operatorList[i]->getDomainDescriptor(); 206 : 207 0 : switch (blockType) { 208 0 : case BlockType::ROW: 209 0 : return bestCommon(vec); 210 : 211 0 : case BlockType::COL: 212 0 : return bestBlockDescriptor(vec); 213 : 214 0 : default: 215 0 : throw InvalidArgumentError("BlockLinearOpearator: unsupported block type"); 216 : } 217 0 : } 218 : 219 : template <typename data_t> 220 : std::unique_ptr<DataDescriptor> 221 0 : BlockLinearOperator<data_t>::determineRangeDescriptor(const OperatorList& operatorList, 222 : BlockType blockType) 223 : { 224 0 : std::vector<const DataDescriptor*> vec(operatorList.size()); 225 0 : for (std::size_t i = 0; i < vec.size(); i++) 226 0 : vec[i] = &operatorList[i]->getRangeDescriptor(); 227 : 228 0 : switch (blockType) { 229 0 : case BlockType::ROW: 230 0 : return bestBlockDescriptor(vec); 231 : 232 0 : case BlockType::COL: 233 0 : return bestCommon(vec); 234 : 235 0 : default: 236 0 : throw InvalidArgumentError("BlockLinearOpearator: unsupported block type"); 237 : } 238 0 : } 239 : 240 : template <typename data_t> 241 0 : std::unique_ptr<BlockDescriptor> BlockLinearOperator<data_t>::bestBlockDescriptor( 242 : const std::vector<const DataDescriptor*>& descList) 243 : { 244 0 : auto numBlocks = descList.size(); 245 0 : if (numBlocks == 0) 246 0 : throw InvalidArgumentError("BlockLinearOperator: operator list cannot be empty"); 247 : 248 0 : const auto& firstDesc = *descList[0]; 249 0 : auto numDims = firstDesc.getNumberOfDimensions(); 250 0 : auto coeffs = firstDesc.getNumberOfCoefficientsPerDimension(); 251 : 252 0 : bool allNumDimSame = true; 253 0 : bool allButLastDimSame = true; 254 0 : IndexVector_t lastDimSplit(numBlocks); 255 0 : for (std::size_t i = 1; i < numBlocks && allButLastDimSame; i++) { 256 0 : lastDimSplit[static_cast<index_t>(i - 1)] = 257 0 : descList[i - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1]; 258 : 259 0 : allNumDimSame = allNumDimSame && descList[i]->getNumberOfDimensions() == numDims; 260 0 : allButLastDimSame = 261 0 : allNumDimSame && allButLastDimSame 262 0 : && descList[i]->getNumberOfCoefficientsPerDimension().head(numDims - 1) 263 0 : == coeffs.head(numDims - 1); 264 : } 265 : 266 0 : if (allButLastDimSame) { 267 0 : lastDimSplit[static_cast<index_t>(numBlocks) - 1] = 268 0 : descList[numBlocks - 1]->getNumberOfCoefficientsPerDimension()[numDims - 1]; 269 : 270 0 : auto spacing = firstDesc.getSpacingPerDimension(); 271 : bool allSameSpacing = 272 0 : all_of(descList.begin(), descList.end(), [&spacing](const DataDescriptor* d) { 273 0 : return d->getSpacingPerDimension() == spacing; 274 : }); 275 : 276 0 : coeffs[numDims - 1] = lastDimSplit.sum(); 277 0 : if (allSameSpacing) { 278 0 : VolumeDescriptor tmp(coeffs, spacing); 279 0 : return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit); 280 0 : } else { 281 0 : VolumeDescriptor tmp(coeffs); 282 0 : return std::make_unique<PartitionDescriptor>(tmp, lastDimSplit); 283 0 : } 284 0 : } 285 : 286 0 : std::vector<std::unique_ptr<DataDescriptor>> tmp(numBlocks); 287 0 : auto it = descList.begin(); 288 0 : std::generate(tmp.begin(), tmp.end(), [&it]() { return (*it++)->clone(); }); 289 : 290 0 : return std::make_unique<RandomBlocksDescriptor>(std::move(tmp)); 291 0 : } 292 : 293 : // ---------------------------------------------- 294 : // explicit template instantiation 295 : template class BlockLinearOperator<float>; 296 : template class BlockLinearOperator<double>; 297 : template class BlockLinearOperator<complex<float>>; 298 : template class BlockLinearOperator<complex<double>>; 299 : } // namespace elsa