Line data Source code
1 : #pragma once 2 : 3 : #include "elsaDefines.h" 4 : #include "Timer.h" 5 : #include "SliceTraversal.h" 6 : #include "LinearOperator.h" 7 : #include "VolumeDescriptor.h" 8 : #include "DetectorDescriptor.h" 9 : #include "DataContainer.h" 10 : #include "BoundingBox.h" 11 : #include "Logger.h" 12 : #include "Blobs.h" 13 : #include "BSplines.h" 14 : #include "CartesianIndices.h" 15 : 16 : #include "XrayProjector.h" 17 : 18 : #include "spdlog/fmt/fmt.h" 19 : #include "spdlog/fmt/ostr.h" 20 : 21 : namespace elsa 22 : { 23 : template <typename data_t, typename Derived> 24 : class LutProjector; 25 : 26 : template <typename data_t = real_t> 27 : class BlobProjector; 28 : 29 : template <typename data_t = real_t> 30 : class BSplineProjector; 31 : 32 : template <typename data_t> 33 : struct XrayProjectorInnerTypes<BlobProjector<data_t>> { 34 : using value_type = data_t; 35 : using forward_tag = ray_driven_tag; 36 : using backward_tag = ray_driven_tag; 37 : }; 38 : 39 : template <typename data_t> 40 : struct XrayProjectorInnerTypes<BSplineProjector<data_t>> { 41 : using value_type = data_t; 42 : using forward_tag = ray_driven_tag; 43 : using backward_tag = ray_driven_tag; 44 : }; 45 : 46 : template <typename data_t, typename Derived> 47 : class LutProjector : public XrayProjector<Derived> 48 : { 49 : public: 50 : using self_type = LutProjector<data_t, Derived>; 51 : using base_type = XrayProjector<Derived>; 52 : using value_type = typename base_type::value_type; 53 : using forward_tag = typename base_type::forward_tag; 54 : using backward_tag = typename base_type::backward_tag; 55 : 56 : LutProjector(const VolumeDescriptor& domainDescriptor, 57 : const DetectorDescriptor& rangeDescriptor) 58 : : base_type(domainDescriptor, rangeDescriptor) 59 871 : { 60 : // sanity checks 61 871 : auto dim = domainDescriptor.getNumberOfDimensions(); 62 871 : if (dim < 2 || dim > 3) { 63 0 : throw InvalidArgumentError("LutProjector: only supporting 2d/3d operations"); 64 0 : } 65 : 66 871 : if (dim != rangeDescriptor.getNumberOfDimensions()) { 67 0 : throw InvalidArgumentError( 68 0 : "LutProjector: domain and range dimension need to match"); 69 0 : } 70 : 71 871 : if (rangeDescriptor.getNumberOfGeometryPoses() == 0) { 72 0 : throw InvalidArgumentError("LutProjector: rangeDescriptor without any geometry"); 73 0 : } 74 871 : } 75 : 76 : /// default destructor 77 871 : ~LutProjector() override = default; 78 : 79 : private: 80 : /// apply the binary method (i.e. forward projection) 81 : data_t traverseRayForward(const BoundingBox& boundingbox, const RealRay_t& ray, 82 : const DataContainer<data_t>& x) const 83 1611 : { 84 1611 : const IndexVector_t lower = boundingbox.min().template cast<index_t>(); 85 1611 : const IndexVector_t upper = boundingbox.max().template cast<index_t>(); 86 1611 : const auto support = this->self().support(); 87 : 88 1611 : index_t leadingdir = 0; 89 1611 : ray.direction().array().cwiseAbs().maxCoeff(&leadingdir); 90 : 91 1611 : IndexVector_t distvec = IndexVector_t::Constant(lower.size(), support); 92 1611 : distvec[leadingdir] = 0; 93 : 94 1611 : auto rangeVal = data_t(0); 95 : 96 : // Expand bounding box as rays have larger support now 97 1611 : auto aabb = boundingbox; 98 1611 : aabb.min().array() -= static_cast<real_t>(support); 99 1611 : aabb.min()[leadingdir] += static_cast<real_t>(support); 100 : 101 1611 : aabb.max().array() += static_cast<real_t>(support); 102 1611 : aabb.max()[leadingdir] -= static_cast<real_t>(support); 103 : 104 : // Keep this here, as it saves us a couple of allocations on clang 105 1611 : CartesianIndices neighbours(upper); 106 : 107 : // --> setup traversal algorithm 108 1611 : SliceTraversal traversal(boundingbox, ray); 109 : 110 7856 : for (const auto& curVoxel : traversal) { 111 7856 : neighbours = neighbours_in_slice(curVoxel, distvec, lower, upper); 112 49784 : for (auto neighbour : neighbours) { 113 : // Correct position, such that the distance is still correct 114 49784 : const auto correctedPos = neighbour.template cast<real_t>().array() + 0.5; 115 49784 : const auto distance = ray.distance(correctedPos); 116 49784 : const auto weight = this->self().weight(distance); 117 : 118 49784 : rangeVal += weight * x(neighbour); 119 49784 : } 120 7856 : } 121 : 122 1611 : return rangeVal; 123 1611 : } 124 : 125 : void traverseRayBackward(const BoundingBox& boundingbox, const RealRay_t& ray, 126 : const value_type& detectorValue, DataContainer<data_t>& Aty) const 127 15 : { 128 15 : const IndexVector_t lower = boundingbox.min().template cast<index_t>(); 129 15 : const IndexVector_t upper = boundingbox.max().template cast<index_t>(); 130 15 : const auto support = this->self().support(); 131 : 132 15 : index_t leadingdir = 0; 133 15 : ray.direction().array().cwiseAbs().maxCoeff(&leadingdir); 134 : 135 15 : IndexVector_t distvec = IndexVector_t::Constant(lower.size(), support); 136 15 : distvec[leadingdir] = 0; 137 : 138 : // Expand bounding box as rays have larger support now 139 15 : auto aabb = boundingbox; 140 15 : aabb.min().array() -= static_cast<real_t>(support); 141 15 : aabb.min()[leadingdir] += static_cast<real_t>(support); 142 : 143 15 : aabb.max().array() += static_cast<real_t>(support); 144 15 : aabb.max()[leadingdir] -= static_cast<real_t>(support); 145 : 146 : // Keep this here, as it saves us a couple of allocations on clang 147 15 : CartesianIndices neighbours(upper); 148 : 149 : // --> setup traversal algorithm 150 15 : SliceTraversal traversal(aabb, ray); 151 : 152 75 : for (const auto& curVoxel : traversal) { 153 75 : neighbours = neighbours_in_slice(curVoxel, distvec, lower, upper); 154 357 : for (auto neighbour : neighbours) { 155 : // Correct position, such that the distance is still correct 156 357 : const auto correctedPos = neighbour.template cast<real_t>().array() + 0.5; 157 357 : const auto distance = ray.distance(correctedPos); 158 357 : const auto weight = this->self().weight(distance); 159 : 160 357 : #pragma omp atomic 161 357 : Aty(neighbour) += weight * detectorValue; 162 357 : } 163 75 : } 164 15 : } 165 : 166 : /// implement the polymorphic clone operation 167 : LutProjector<data_t, Derived>* _cloneImpl() const 168 : { 169 : return new LutProjector(downcast<VolumeDescriptor>(*this->_domainDescriptor), 170 : downcast<DetectorDescriptor>(*this->_rangeDescriptor)); 171 : } 172 : 173 : /// implement the polymorphic comparison operation 174 : bool _isEqual(const LinearOperator<data_t>& other) const 175 : { 176 : if (!LinearOperator<data_t>::isEqual(other)) 177 : return false; 178 : 179 : auto otherOp = downcast_safe<LutProjector>(&other); 180 : return static_cast<bool>(otherOp); 181 : } 182 : 183 : friend class XrayProjector<Derived>; 184 : }; 185 : 186 : template <typename data_t> 187 : class BlobProjector : public LutProjector<data_t, BlobProjector<data_t>> 188 : { 189 : public: 190 : using self_type = BlobProjector<data_t>; 191 : 192 : BlobProjector(const VolumeDescriptor& domainDescriptor, 193 : const DetectorDescriptor& rangeDescriptor, 194 : data_t radius = blobs::DEFAULT_RADIUS, data_t alpha = blobs::DEFAULT_ALPHA, 195 : index_t order = blobs::DEFAULT_ORDER); 196 : 197 51141 : data_t weight(data_t distance) const { return blob_.get_lut()(distance); } 198 : 199 1623 : index_t support() const { return static_cast<index_t>(std::ceil(blob_.radius())); } 200 : 201 : /// implement the polymorphic clone operation 202 : BlobProjector<data_t>* _cloneImpl() const; 203 : 204 : /// implement the polymorphic comparison operation 205 : bool _isEqual(const LinearOperator<data_t>& other) const; 206 : 207 : private: 208 : ProjectedBlob<data_t> blob_; 209 : 210 : using Base = LutProjector<data_t, BlobProjector<data_t>>; 211 : 212 : friend class XrayProjector<self_type>; 213 : }; 214 : 215 : template <typename data_t> 216 : class BSplineProjector : public LutProjector<data_t, BSplineProjector<data_t>> 217 : { 218 : public: 219 : using self_type = BlobProjector<data_t>; 220 : 221 : BSplineProjector(const VolumeDescriptor& domainDescriptor, 222 : const DetectorDescriptor& rangeDescriptor, 223 : index_t order = bspline::DEFAULT_ORDER); 224 : 225 : data_t weight(data_t distance) const; 226 : 227 : index_t support() const; 228 : 229 : /// implement the polymorphic clone operation 230 : BSplineProjector<data_t>* _cloneImpl() const; 231 : 232 : /// implement the polymorphic comparison operation 233 : bool _isEqual(const LinearOperator<data_t>& other) const; 234 : 235 : private: 236 : ProjectedBSpline<data_t> bspline_; 237 : 238 : using Base = LutProjector<data_t, BSplineProjector<data_t>>; 239 : 240 : friend class XrayProjector<self_type>; 241 : }; 242 : } // namespace elsa