Line data Source code
1 : #include "BinaryMethod.h" 2 : #include "Timer.h" 3 : #include "TraverseAABB.h" 4 : #include "TypeCasts.hpp" 5 : 6 : #include <stdexcept> 7 : #include <type_traits> 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : BinaryMethod<data_t>::BinaryMethod(const VolumeDescriptor& domainDescriptor, 13 : const DetectorDescriptor& rangeDescriptor) 14 : : LinearOperator<data_t>(domainDescriptor, rangeDescriptor), 15 : _boundingBox{domainDescriptor.getNumberOfCoefficientsPerDimension()}, 16 : _detectorDescriptor(static_cast<DetectorDescriptor&>(*_rangeDescriptor)), 17 : _volumeDescriptor(static_cast<VolumeDescriptor&>(*_domainDescriptor)) 18 46 : { 19 : // sanity checks 20 46 : auto dim = _domainDescriptor->getNumberOfDimensions(); 21 46 : if (dim < 2 || dim > 3) 22 0 : throw InvalidArgumentError("BinaryMethod: only supporting 2d/3d operations"); 23 : 24 46 : if (dim != _rangeDescriptor->getNumberOfDimensions()) 25 0 : throw InvalidArgumentError("BinaryMethod: domain and range dimension need to match"); 26 : 27 46 : if (_detectorDescriptor.getNumberOfGeometryPoses() == 0) 28 0 : throw InvalidArgumentError("BinaryMethod: rangeDescriptor without any geometry"); 29 46 : } 30 : 31 : template <typename data_t> 32 : void BinaryMethod<data_t>::applyImpl(const DataContainer<data_t>& x, 33 : DataContainer<data_t>& Ax) const 34 71 : { 35 71 : Timer t("BinaryMethod", "apply"); 36 71 : traverseVolume<false>(x, Ax); 37 71 : } 38 : 39 : template <typename data_t> 40 : void BinaryMethod<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 41 : DataContainer<data_t>& Aty) const 42 41 : { 43 41 : Timer t("BinaryMethod", "applyAdjoint"); 44 41 : traverseVolume<true>(y, Aty); 45 41 : } 46 : 47 : template <typename data_t> 48 : BinaryMethod<data_t>* BinaryMethod<data_t>::cloneImpl() const 49 2 : { 50 2 : return new BinaryMethod(_volumeDescriptor, _detectorDescriptor); 51 2 : } 52 : 53 : template <typename data_t> 54 : bool BinaryMethod<data_t>::isEqual(const LinearOperator<data_t>& other) const 55 0 : { 56 0 : if (!LinearOperator<data_t>::isEqual(other)) 57 0 : return false; 58 : 59 0 : auto otherBM = downcast_safe<BinaryMethod>(&other); 60 0 : if (!otherBM) 61 0 : return false; 62 : 63 0 : return true; 64 0 : } 65 : 66 : template <typename data_t> 67 : template <bool adjoint> 68 : void BinaryMethod<data_t>::traverseVolume(const DataContainer<data_t>& vector, 69 : DataContainer<data_t>& result) const 70 112 : { 71 112 : if (_domainDescriptor->getNumberOfDimensions() == 2) { 72 86 : return doTraverseVolume<adjoint, 2>(vector, result); 73 86 : } else if (_domainDescriptor->getNumberOfDimensions() == 3) { 74 26 : return doTraverseVolume<adjoint, 3>(vector, result); 75 26 : } 76 112 : } 77 : 78 : template <typename data_t> 79 : template <bool adjoint, int dim> 80 : void BinaryMethod<data_t>::doTraverseVolume(const DataContainer<data_t>& vector, 81 : DataContainer<data_t>& result) const 82 112 : { 83 112 : const index_t maxIterations = adjoint ? vector.getSize() : result.getSize(); 84 : 85 112 : if constexpr (adjoint) { 86 41 : result = 0; // initialize volume to 0, because we are not going to hit every voxel! 87 41 : } 88 : 89 : // --> loop either over every voxel that should be updated or every detector 90 : // cell that should be calculated 91 112 : #pragma omp parallel for 92 112 : for (index_t rangeIndex = 0; rangeIndex < maxIterations; ++rangeIndex) { 93 : 94 : // --> get the current ray to the detector center (from reference to DetectorDescriptor) 95 0 : auto ray = _detectorDescriptor.computeRayFromDetectorCoord(rangeIndex); 96 : 97 : // --> setup traversal algorithm 98 0 : TraverseAABB<dim> traverse(_boundingBox, ray, 99 0 : _domainDescriptor->getNumberOfCoefficientsPerDimension()); 100 : 101 0 : if constexpr (!adjoint) 102 5085 : result[rangeIndex] = 0; 103 : 104 519867 : while (traverse.isInBoundingBox()) { 105 : // --> initial index to access the data vector 106 519867 : auto dataIndexForCurrentVoxel = 107 519867 : _domainDescriptor->getIndexFromCoordinate(traverse.getCurrentVoxel()); 108 : 109 : // --> update result depending on the operation performed 110 519867 : if constexpr (adjoint) 111 280190 : #pragma omp atomic 112 262225 : result[dataIndexForCurrentVoxel] += vector[rangeIndex]; 113 262225 : else 114 262225 : result[rangeIndex] += vector[dataIndexForCurrentVoxel]; 115 : 116 519867 : traverse.updateTraverse(); 117 519867 : } 118 0 : } // end for 119 112 : } 120 : 121 : // ------------------------------------------ 122 : // explicit template instantiation 123 : template class BinaryMethod<float>; 124 : template class BinaryMethod<double>; 125 : 126 : } // namespace elsa