Line data Source code
1 : #pragma once 2 : 3 : #include "elsaDefines.h" 4 : #include "Timer.h" 5 : #include "Luts.hpp" 6 : #include "SliceTraversal.h" 7 : #include "LinearOperator.h" 8 : #include "VolumeDescriptor.h" 9 : #include "DetectorDescriptor.h" 10 : #include "DataContainer.h" 11 : #include "BoundingBox.h" 12 : #include "Logger.h" 13 : #include "Blobs.h" 14 : #include "CartesianIndices.h" 15 : 16 : #include "XrayProjector.h" 17 : 18 : #include "spdlog/fmt/bundled/core.h" 19 : #include "spdlog/fmt/fmt.h" 20 : #include "spdlog/fmt/ostr.h" 21 : 22 : namespace elsa 23 : { 24 : template <typename data_t, typename Derived> 25 : class LutProjector; 26 : 27 : template <typename data_t = real_t> 28 : class BlobProjector; 29 : 30 : template <typename data_t = real_t> 31 : class BSplineProjector; 32 : 33 : template <typename data_t> 34 : struct XrayProjectorInnerTypes<BlobProjector<data_t>> { 35 : using value_type = data_t; 36 : using forward_tag = ray_driven_tag; 37 : using backward_tag = ray_driven_tag; 38 : }; 39 : 40 : template <typename data_t> 41 : struct XrayProjectorInnerTypes<BSplineProjector<data_t>> { 42 : using value_type = data_t; 43 : using forward_tag = ray_driven_tag; 44 : using backward_tag = ray_driven_tag; 45 : }; 46 : 47 : template <typename data_t, typename Derived> 48 : class LutProjector : public XrayProjector<Derived> 49 : { 50 : public: 51 : using self_type = LutProjector<data_t, Derived>; 52 : using base_type = XrayProjector<Derived>; 53 : using value_type = typename base_type::value_type; 54 : using forward_tag = typename base_type::forward_tag; 55 : using backward_tag = typename base_type::backward_tag; 56 : 57 : LutProjector(const VolumeDescriptor& domainDescriptor, 58 : const DetectorDescriptor& rangeDescriptor) 59 : : base_type(domainDescriptor, rangeDescriptor) 60 873 : { 61 : // sanity checks 62 873 : auto dim = domainDescriptor.getNumberOfDimensions(); 63 873 : if (dim < 2 || dim > 3) { 64 0 : throw InvalidArgumentError("LutProjector: only supporting 2d/3d operations"); 65 0 : } 66 : 67 873 : if (dim != rangeDescriptor.getNumberOfDimensions()) { 68 0 : throw InvalidArgumentError( 69 0 : "LutProjector: domain and range dimension need to match"); 70 0 : } 71 : 72 873 : if (rangeDescriptor.getNumberOfGeometryPoses() == 0) { 73 0 : throw InvalidArgumentError("LutProjector: rangeDescriptor without any geometry"); 74 0 : } 75 873 : } 76 : 77 : /// default destructor 78 873 : ~LutProjector() override = default; 79 : 80 : private: 81 : /// apply the binary method (i.e. forward projection) 82 : data_t traverseRayForward(const BoundingBox& boundingbox, const RealRay_t& ray, 83 : const DataContainer<data_t>& x) const 84 1607 : { 85 1607 : const IndexVector_t lower = boundingbox._min.template cast<index_t>(); 86 1607 : const IndexVector_t upper = boundingbox._max.template cast<index_t>(); 87 1607 : const auto support = this->self().support(); 88 : 89 1607 : index_t leadingdir = 0; 90 1607 : ray.direction().array().cwiseAbs().maxCoeff(&leadingdir); 91 : 92 1607 : IndexVector_t distvec = IndexVector_t::Constant(lower.size(), support); 93 1607 : distvec[leadingdir] = 0; 94 : 95 1607 : auto rangeVal = data_t(0); 96 : 97 : // Expand bounding box as rays have larger support now 98 1607 : auto aabb = boundingbox; 99 1607 : aabb._min.array() -= static_cast<real_t>(support); 100 1607 : aabb._min[leadingdir] += static_cast<real_t>(support); 101 : 102 1607 : aabb._max.array() += static_cast<real_t>(support); 103 1607 : aabb._max[leadingdir] -= static_cast<real_t>(support); 104 : 105 : // Keep this here, as it saves us a couple of allocations on clang 106 1607 : CartesianIndices neighbours(upper); 107 : 108 : // --> setup traversal algorithm 109 1607 : SliceTraversal traversal(boundingbox, ray); 110 : 111 7864 : for (const auto& curVoxel : traversal) { 112 7864 : neighbours = neighbours_in_slice(curVoxel, distvec, lower, upper); 113 49862 : for (auto neighbour : neighbours) { 114 : // Correct position, such that the distance is still correct 115 49862 : const auto correctedPos = neighbour.template cast<real_t>().array() + 0.5; 116 49862 : const auto distance = ray.distance(correctedPos); 117 49862 : const auto weight = this->self().weight(distance); 118 : 119 49862 : rangeVal += weight * x(neighbour); 120 49862 : } 121 7864 : } 122 : 123 1607 : return rangeVal; 124 1607 : } 125 : 126 : void traverseRayBackward(const BoundingBox& boundingbox, const RealRay_t& ray, 127 : const value_type& detectorValue, DataContainer<data_t>& Aty) const 128 15 : { 129 15 : const IndexVector_t lower = boundingbox._min.template cast<index_t>(); 130 15 : const IndexVector_t upper = boundingbox._max.template cast<index_t>(); 131 15 : const auto support = this->self().support(); 132 : 133 15 : index_t leadingdir = 0; 134 15 : ray.direction().array().cwiseAbs().maxCoeff(&leadingdir); 135 : 136 15 : IndexVector_t distvec = IndexVector_t::Constant(lower.size(), support); 137 15 : distvec[leadingdir] = 0; 138 : 139 : // Expand bounding box as rays have larger support now 140 15 : auto aabb = boundingbox; 141 15 : aabb._min.array() -= static_cast<real_t>(support); 142 15 : aabb._min[leadingdir] += static_cast<real_t>(support); 143 : 144 15 : aabb._max.array() += static_cast<real_t>(support); 145 15 : aabb._max[leadingdir] -= static_cast<real_t>(support); 146 : 147 : // Keep this here, as it saves us a couple of allocations on clang 148 15 : CartesianIndices neighbours(upper); 149 : 150 : // --> setup traversal algorithm 151 15 : SliceTraversal traversal(aabb, ray); 152 : 153 75 : for (const auto& curVoxel : traversal) { 154 75 : neighbours = neighbours_in_slice(curVoxel, distvec, lower, upper); 155 357 : for (auto neighbour : neighbours) { 156 : // Correct position, such that the distance is still correct 157 357 : const auto correctedPos = neighbour.template cast<real_t>().array() + 0.5; 158 357 : const auto distance = ray.distance(correctedPos); 159 357 : const auto weight = this->self().weight(distance); 160 : 161 357 : #pragma omp atomic 162 357 : Aty(neighbour) += weight * detectorValue; 163 357 : } 164 75 : } 165 15 : } 166 : 167 : /// implement the polymorphic clone operation 168 : LutProjector<data_t, Derived>* _cloneImpl() const 169 : { 170 : return new LutProjector(downcast<VolumeDescriptor>(*this->_domainDescriptor), 171 : downcast<DetectorDescriptor>(*this->_rangeDescriptor)); 172 : } 173 : 174 : /// implement the polymorphic comparison operation 175 : bool _isEqual(const LinearOperator<data_t>& other) const 176 : { 177 : if (!LinearOperator<data_t>::isEqual(other)) 178 : return false; 179 : 180 : auto otherOp = downcast_safe<LutProjector>(&other); 181 : return static_cast<bool>(otherOp); 182 : } 183 : 184 : friend class XrayProjector<Derived>; 185 : }; 186 : 187 : template <typename data_t> 188 : class BlobProjector : public LutProjector<data_t, BlobProjector<data_t>> 189 : { 190 : public: 191 : using self_type = BlobProjector<data_t>; 192 : 193 : BlobProjector(data_t radius, data_t alpha, data_t order, 194 : const VolumeDescriptor& domainDescriptor, 195 : const DetectorDescriptor& rangeDescriptor); 196 : 197 : BlobProjector(const VolumeDescriptor& domainDescriptor, 198 : const DetectorDescriptor& rangeDescriptor); 199 : 200 51439 : data_t weight(data_t distance) const { return lut_(distance); } 201 : 202 1610 : index_t support() const { return static_cast<index_t>(std::ceil(lut_.radius())); } 203 : 204 : /// implement the polymorphic clone operation 205 : BlobProjector<data_t>* _cloneImpl() const; 206 : 207 : /// implement the polymorphic comparison operation 208 : bool _isEqual(const LinearOperator<data_t>& other) const; 209 : 210 : private: 211 : ProjectedBlobLut<data_t, 100> lut_; 212 : 213 : using Base = LutProjector<data_t, BlobProjector<data_t>>; 214 : 215 : friend class XrayProjector<self_type>; 216 : }; 217 : 218 : template <typename data_t> 219 : class BSplineProjector : public LutProjector<data_t, BSplineProjector<data_t>> 220 : { 221 : public: 222 : using self_type = BlobProjector<data_t>; 223 : 224 : BSplineProjector(data_t degree, const VolumeDescriptor& domainDescriptor, 225 : const DetectorDescriptor& rangeDescriptor); 226 : 227 : BSplineProjector(const VolumeDescriptor& domainDescriptor, 228 : const DetectorDescriptor& rangeDescriptor); 229 : 230 : data_t weight(data_t distance) const; 231 : 232 : index_t support() const; 233 : 234 : /// implement the polymorphic clone operation 235 : BSplineProjector<data_t>* _cloneImpl() const; 236 : 237 : /// implement the polymorphic comparison operation 238 : bool _isEqual(const LinearOperator<data_t>& other) const; 239 : 240 : private: 241 : ProjectedBSplineLut<data_t, 100> lut_; 242 : 243 : using Base = LutProjector<data_t, BSplineProjector<data_t>>; 244 : 245 : friend class XrayProjector<self_type>; 246 : }; 247 : } // namespace elsa