Line data Source code
1 : #include "BinaryMethod.h" 2 : #include "Timer.h" 3 : #include "TraverseAABB.h" 4 : #include "TypeCasts.hpp" 5 : 6 : #include <stdexcept> 7 : #include <type_traits> 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 0 : BinaryMethod<data_t>::BinaryMethod(const VolumeDescriptor& domainDescriptor, 13 : const DetectorDescriptor& rangeDescriptor) 14 : : LinearOperator<data_t>(domainDescriptor, rangeDescriptor), 15 0 : _boundingBox{domainDescriptor.getNumberOfCoefficientsPerDimension()}, 16 0 : _detectorDescriptor(static_cast<DetectorDescriptor&>(*_rangeDescriptor)), 17 0 : _volumeDescriptor(static_cast<VolumeDescriptor&>(*_domainDescriptor)) 18 : { 19 : // sanity checks 20 0 : auto dim = _domainDescriptor->getNumberOfDimensions(); 21 0 : if (dim < 2 || dim > 3) 22 0 : throw InvalidArgumentError("BinaryMethod: only supporting 2d/3d operations"); 23 : 24 0 : if (dim != _rangeDescriptor->getNumberOfDimensions()) 25 0 : throw InvalidArgumentError("BinaryMethod: domain and range dimension need to match"); 26 : 27 0 : if (_detectorDescriptor.getNumberOfGeometryPoses() == 0) 28 0 : throw InvalidArgumentError("BinaryMethod: rangeDescriptor without any geometry"); 29 0 : } 30 : 31 : template <typename data_t> 32 0 : void BinaryMethod<data_t>::applyImpl(const DataContainer<data_t>& x, 33 : DataContainer<data_t>& Ax) const 34 : { 35 0 : Timer t("BinaryMethod", "apply"); 36 0 : traverseVolume<false>(x, Ax); 37 0 : } 38 : 39 : template <typename data_t> 40 0 : void BinaryMethod<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 41 : DataContainer<data_t>& Aty) const 42 : { 43 0 : Timer t("BinaryMethod", "applyAdjoint"); 44 0 : traverseVolume<true>(y, Aty); 45 0 : } 46 : 47 : template <typename data_t> 48 0 : BinaryMethod<data_t>* BinaryMethod<data_t>::cloneImpl() const 49 : { 50 0 : return new BinaryMethod(_volumeDescriptor, _detectorDescriptor); 51 : } 52 : 53 : template <typename data_t> 54 0 : bool BinaryMethod<data_t>::isEqual(const LinearOperator<data_t>& other) const 55 : { 56 0 : if (!LinearOperator<data_t>::isEqual(other)) 57 0 : return false; 58 : 59 0 : auto otherBM = downcast_safe<BinaryMethod>(&other); 60 0 : if (!otherBM) 61 0 : return false; 62 : 63 0 : return true; 64 : } 65 : 66 : template <typename data_t> 67 : template <bool adjoint> 68 0 : void BinaryMethod<data_t>::traverseVolume(const DataContainer<data_t>& vector, 69 : DataContainer<data_t>& result) const 70 : { 71 0 : const index_t maxIterations = adjoint ? vector.getSize() : result.getSize(); 72 : 73 : if constexpr (adjoint) { 74 0 : result = 0; // initialize volume to 0, because we are not going to hit every voxel! 75 : } 76 : 77 : // --> loop either over every voxel that should be updated or every detector 78 : // cell that should be calculated 79 0 : #pragma omp parallel for 80 : for (index_t rangeIndex = 0; rangeIndex < maxIterations; ++rangeIndex) { 81 : 82 : // --> get the current ray to the detector center (from reference to DetectorDescriptor) 83 : auto ray = _detectorDescriptor.computeRayFromDetectorCoord(rangeIndex); 84 : 85 : // --> setup traversal algorithm 86 : TraverseAABB traverse(_boundingBox, ray); 87 : 88 : if constexpr (!adjoint) 89 : result[rangeIndex] = 0; 90 : 91 : while (traverse.isInBoundingBox()) { 92 : // --> initial index to access the data vector 93 : auto dataIndexForCurrentVoxel = 94 : _domainDescriptor->getIndexFromCoordinate(traverse.getCurrentVoxel()); 95 : 96 : // --> update result depending on the operation performed 97 : if constexpr (adjoint) 98 : #pragma omp atomic 99 : result[dataIndexForCurrentVoxel] += vector[rangeIndex]; 100 : else 101 : result[rangeIndex] += vector[dataIndexForCurrentVoxel]; 102 : 103 : traverse.updateTraverse(); 104 : } 105 : } // end for 106 0 : } 107 : 108 : // ------------------------------------------ 109 : // explicit template instantiation 110 : template class BinaryMethod<float>; 111 : template class BinaryMethod<double>; 112 : 113 : } // namespace elsa