Line data Source code
1 : #include "JosephsMethod.h"
2 : #include "Timer.h"
3 : #include "DrivingDirectionTraversal.h"
4 : #include "Error.h"
5 : #include "TypeCasts.hpp"
6 :
7 : #include <type_traits>
8 :
9 : namespace elsa
10 : {
11 : template <typename data_t>
12 : JosephsMethod<data_t>::JosephsMethod(const VolumeDescriptor& domainDescriptor,
13 : const DetectorDescriptor& rangeDescriptor)
14 : : base_type(domainDescriptor, rangeDescriptor)
15 62 : {
16 62 : auto dim = domainDescriptor.getNumberOfDimensions();
17 62 : if (dim != 2 && dim != 3) {
18 0 : throw InvalidArgumentError("JosephsMethod:only supporting 2d/3d operations");
19 0 : }
20 :
21 62 : if (dim != rangeDescriptor.getNumberOfDimensions()) {
22 0 : throw InvalidArgumentError("JosephsMethod: domain and range dimension need to match");
23 0 : }
24 :
25 62 : if (rangeDescriptor.getNumberOfGeometryPoses() == 0) {
26 0 : throw InvalidArgumentError("JosephsMethod: geometry list was empty");
27 0 : }
28 62 : }
29 :
30 : template <int dim>
31 : bool isInAABB(const IndexArray_t<dim>& indices, const IndexArray_t<dim>& aabbMin,
32 : const IndexArray_t<dim>& aabbMax)
33 2980865 : {
34 2980865 : return (indices >= aabbMin && indices < aabbMax).all();
35 2980865 : }
36 :
37 : template <int dim>
38 : std::pair<RealArray_t<dim>, RealArray_t<dim>>
39 : getLinearInterpolationWeights(const RealArray_t<dim>& currentPos,
40 : const IndexArray_t<dim>& voxelFloor,
41 : const index_t drivingDirection)
42 1818199 : {
43 : // subtract 0.5 because the weight calculation assumes indices that refer to the center of
44 : // the voxels, while elsa works with the lower corners of the indices.
45 1818199 : RealArray_t<dim> complement_weight = currentPos - voxelFloor.template cast<real_t>() - 0.5f;
46 1818199 : RealArray_t<dim> weight = RealArray_t<dim>{1} - complement_weight;
47 : // set weights along drivingDirection to 1 so that the interpolation does not have to handle
48 : // the drivingDirection as a special case
49 1818199 : weight(drivingDirection) = 1;
50 1818199 : complement_weight(drivingDirection) = 1;
51 1818199 : return std::make_pair(weight, complement_weight);
52 1818199 : }
53 :
54 : template <int dim>
55 : index_t coord2Idx(IndexArray_t<dim> coord, IndexArray_t<dim> strides)
56 3406546 : {
57 3406546 : return (coord * strides).sum();
58 3406546 : }
59 :
60 : template <typename data_t, int dim, class Fn>
61 : void doInterpolation(const IndexArray_t<dim>& voxelFloor, const IndexArray_t<dim>& voxelCeil,
62 : const RealArray_t<dim>& weight, const RealArray_t<dim>& complement_weight,
63 : const IndexArray_t<dim>& aabbMin, const IndexArray_t<dim>& aabbMax, Fn fn)
64 1870844 : {
65 3109894 : auto clip = [](auto coord, auto lower, auto upper) { return coord.min(upper).max(lower); };
66 1870844 : IndexArray_t<dim> tempIndices;
67 1870844 : if constexpr (dim == 2) {
68 3297145 : auto interpol = [&](auto v1, auto v2, auto w1, auto w2) {
69 3297145 : tempIndices << v1, v2;
70 :
71 3297145 : bool is_in_aab = isInAABB(tempIndices, aabbMin, aabbMax);
72 3297145 : tempIndices = clip(tempIndices, aabbMin, (aabbMax - 1));
73 3297145 : auto weight = is_in_aab * w1 * w2;
74 3297145 : fn(tempIndices, weight);
75 3297145 : };
76 :
77 97 : interpol(voxelFloor[0], voxelCeil[1], weight[0], complement_weight[1]);
78 97 : interpol(voxelCeil[0], voxelFloor[1], complement_weight[0], weight[1]);
79 97 : } else {
80 379 : auto interpol = [&](auto v1, auto v2, auto v3, auto w1, auto w2, auto w3) {
81 379 : tempIndices << v1, v2, v3;
82 :
83 379 : bool is_in_aab = isInAABB(tempIndices, aabbMin, aabbMax);
84 379 : tempIndices = clip(tempIndices, aabbMin, (aabbMax - 1));
85 379 : auto weight = is_in_aab * w1 * w2 * w3;
86 379 : fn(tempIndices, weight);
87 379 : };
88 :
89 97 : interpol(voxelFloor[0], voxelFloor[1], voxelFloor[2], weight[0], weight[1], weight[2]);
90 97 : interpol(voxelFloor[0], voxelCeil[1], voxelCeil[2], weight[0], complement_weight[1],
91 97 : complement_weight[2]);
92 97 : interpol(voxelCeil[0], voxelFloor[1], voxelCeil[2], complement_weight[0], weight[1],
93 97 : complement_weight[2]);
94 97 : interpol(voxelCeil[0], voxelCeil[1], voxelFloor[2], complement_weight[0],
95 97 : complement_weight[1], weight[2]);
96 97 : }
97 1870844 : }
98 :
99 : template <typename data_t>
100 : void JosephsMethod<data_t>::forward(const BoundingBox& aabb, const DataContainer<data_t>& x,
101 : DataContainer<data_t>& Ax) const
102 210 : {
103 210 : Timer timeguard("JosephsMethod", "apply");
104 210 : if (aabb.dim() == 2) {
105 194 : traverseVolume<false, 2>(aabb, x, Ax);
106 194 : } else if (aabb.dim() == 3) {
107 16 : traverseVolume<false, 3>(aabb, x, Ax);
108 16 : }
109 210 : }
110 :
111 : template <typename data_t>
112 : void JosephsMethod<data_t>::backward(const BoundingBox& aabb, const DataContainer<data_t>& y,
113 : DataContainer<data_t>& Aty) const
114 144 : {
115 144 : Timer timeguard("JosephsMethod", "applyAdjoint");
116 144 : if (aabb.dim() == 2) {
117 135 : traverseVolume<true, 2>(aabb, y, Aty);
118 135 : } else if (aabb.dim() == 3) {
119 9 : traverseVolume<true, 3>(aabb, y, Aty);
120 9 : }
121 144 : }
122 :
123 : template <typename data_t>
124 : JosephsMethod<data_t>* JosephsMethod<data_t>::_cloneImpl() const
125 23 : {
126 23 : return new self_type(downcast<VolumeDescriptor>(*this->_domainDescriptor),
127 23 : downcast<DetectorDescriptor>(*this->_rangeDescriptor));
128 23 : }
129 :
130 : template <typename data_t>
131 : bool JosephsMethod<data_t>::_isEqual(const LinearOperator<data_t>& other) const
132 4 : {
133 4 : if (!LinearOperator<data_t>::isEqual(other))
134 0 : return false;
135 :
136 4 : auto otherJM = downcast_safe<JosephsMethod>(&other);
137 4 : return static_cast<bool>(otherJM);
138 4 : }
139 :
140 : template <typename data_t>
141 : template <bool adjoint, int dim>
142 : void JosephsMethod<data_t>::traverseVolume(const BoundingBox& aabb,
143 : const DataContainer<data_t>& vector,
144 : DataContainer<data_t>& result) const
145 354 : {
146 354 : if constexpr (adjoint)
147 144 : result = 0;
148 :
149 354 : const auto& domain = adjoint ? result.getDataDescriptor() : vector.getDataDescriptor();
150 354 : const auto& range = downcast<DetectorDescriptor>(adjoint ? vector.getDataDescriptor()
151 354 : : result.getDataDescriptor());
152 :
153 354 : const IndexArray_t<dim> strides = domain.getProductOfCoefficientsPerDimension();
154 354 : const auto sizeOfRange = range.getNumberOfCoefficients();
155 :
156 354 : const IndexArray_t<dim> aabbMin = aabb.min().template cast<index_t>();
157 354 : const IndexArray_t<dim> aabbMax = aabb.max().template cast<index_t>();
158 :
159 : // iterate over all rays
160 354 : #pragma omp parallel for
161 354 : for (index_t ir = 0; ir < sizeOfRange; ir++) {
162 0 : const auto ray = range.computeRayFromDetectorCoord(ir);
163 :
164 : // --> setup traversal algorithm
165 :
166 0 : DrivingDirectionTraversal<dim> traverse(aabb, ray);
167 0 : const index_t drivingDirection = traverse.getDrivingDirection();
168 0 : const data_t intersection = traverse.getIntersectionLength();
169 :
170 0 : if constexpr (!adjoint)
171 99350 : result[ir] = 0;
172 :
173 : // Make steps through the volume
174 1813720 : while (traverse.isInBoundingBox()) {
175 1813720 : const IndexArray_t<dim> voxelFloor = traverse.getCurrentVoxelFloor();
176 1813720 : const IndexArray_t<dim> voxelCeil = traverse.getCurrentVoxelCeil();
177 1813720 : const auto [weight, complement_weight] = getLinearInterpolationWeights(
178 1813720 : traverse.getCurrentPos(), voxelFloor, drivingDirection);
179 :
180 1813720 : if constexpr (adjoint) {
181 1043741 : doInterpolation<data_t>(voxelFloor, voxelCeil, weight, complement_weight,
182 1458028 : aabbMin, aabbMax, [&](const auto& coord, auto wght) {
183 1458028 : #pragma omp atomic
184 1458028 : result[coord2Idx(coord, strides)] +=
185 1458028 : vector[ir] * intersection * wght;
186 1458028 : });
187 1043741 : } else {
188 1043741 : doInterpolation<data_t>(voxelFloor, voxelCeil, weight, complement_weight,
189 1963012 : aabbMin, aabbMax, [&](const auto& coord, auto wght) {
190 1963012 : result[ir] += vector[coord2Idx(coord, strides)]
191 1963012 : * intersection * wght;
192 1963012 : });
193 1043741 : }
194 : // update Traverse
195 1813720 : traverse.updateTraverse();
196 1813720 : }
197 0 : }
198 354 : }
199 :
200 : // ------------------------------------------
201 : // explicit template instantiation
202 : template class JosephsMethod<float>;
203 : template class JosephsMethod<double>;
204 :
205 : } // namespace elsa
|