Line data Source code
1 : #include "ShearletTransform.h"
2 : #include "FourierTransform.h"
3 : #include "VolumeDescriptor.h"
4 : #include "Timer.h"
5 : #include "Math.hpp"
6 :
7 : namespace elsa
8 : {
9 : template <typename ret_t, typename data_t>
10 : ShearletTransform<ret_t, data_t>::ShearletTransform(IndexVector_t spatialDimensions)
11 : : ShearletTransform(spatialDimensions[0], spatialDimensions[1])
12 2 : {
13 2 : if (spatialDimensions.size() != 2) {
14 0 : throw LogicError("ShearletTransform: Only 2D shape supported");
15 0 : }
16 2 : }
17 :
18 : template <typename ret_t, typename data_t>
19 : ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height)
20 : : ShearletTransform(width, height, calculateNumOfScales(width, height))
21 6 : {
22 6 : }
23 :
24 : template <typename ret_t, typename data_t>
25 : ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height,
26 : index_t numOfScales)
27 : : ShearletTransform(width, height, numOfScales, std::nullopt)
28 12 : {
29 12 : }
30 :
31 : template <typename ret_t, typename data_t>
32 : ShearletTransform<ret_t, data_t>::ShearletTransform(
33 : index_t width, index_t height, index_t numOfScales,
34 : std::optional<DataContainer<data_t>> spectra)
35 : : LinearOperator<ret_t>(
36 : VolumeDescriptor{{width, height}},
37 : VolumeDescriptor{{width, height, calculateNumOfLayers(numOfScales)}}),
38 : _spectra{spectra},
39 : _width{width},
40 : _height{height},
41 : _numOfScales{numOfScales},
42 : _numOfLayers{calculateNumOfLayers(numOfScales)}
43 14 : {
44 14 : if (width < 0 || height < 0) {
45 0 : throw LogicError("ShearletTransform: negative width/height were provided");
46 0 : }
47 14 : if (numOfScales < 0) {
48 0 : throw LogicError("ShearletTransform: negative number of scales was provided");
49 0 : }
50 14 : }
51 :
52 : // TODO implement sumByAxis in DataContainer and remove me
53 : template <typename ret_t, typename data_t>
54 : DataContainer<elsa::complex<data_t>> ShearletTransform<ret_t, data_t>::sumByLastAxis(
55 : DataContainer<elsa::complex<data_t>> dc) const
56 2 : {
57 2 : auto coeffsPerDim = dc.getDataDescriptor().getNumberOfCoefficientsPerDimension();
58 2 : index_t width = coeffsPerDim[0];
59 2 : index_t height = coeffsPerDim[1];
60 2 : index_t layers = coeffsPerDim[2];
61 2 : DataContainer<elsa::complex<data_t>> summedDC(VolumeDescriptor{{width, height}});
62 :
63 34 : for (index_t j = 0; j < width; j++) {
64 544 : for (index_t k = 0; k < height; k++) {
65 512 : elsa::complex<data_t> currValue = 0;
66 31744 : for (index_t i = 0; i < layers; i++) {
67 31232 : currValue += dc(j, k, i);
68 31232 : }
69 512 : summedDC(j, k) = currValue;
70 512 : }
71 32 : }
72 :
73 2 : return summedDC;
74 2 : }
75 :
76 : template <typename ret_t, typename data_t>
77 : void ShearletTransform<ret_t, data_t>::applyImpl(const DataContainer<ret_t>& x,
78 : DataContainer<ret_t>& Ax) const
79 2 : {
80 2 : Timer timeguard("ShearletTransform", "apply");
81 :
82 2 : if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
83 2 : || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
84 0 : throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
85 0 : "match to that of this shearlet system");
86 0 : }
87 :
88 2 : Logger::get("ShearletTransform")
89 2 : ->info("Running the shearlet transform on a 2D signal of shape ({}, {}), on {} "
90 2 : "scales with an oversampling factor of {} and {} spectra",
91 2 : _width, _height, _numOfScales, _numOfLayers,
92 2 : isSpectraComputed() ? "precomputed" : "non-precomputed");
93 :
94 2 : if (!isSpectraComputed()) {
95 2 : computeSpectra();
96 2 : }
97 :
98 2 : FourierTransform<elsa::complex<data_t>> fourierTransform(x.getDataDescriptor());
99 :
100 2 : DataContainer<elsa::complex<data_t>> fftImg = fourierTransform.apply(x.asComplex());
101 :
102 124 : for (index_t i = 0; i < getNumOfLayers(); i++) {
103 122 : DataContainer<elsa::complex<data_t>> temp =
104 122 : getSpectra().slice(i).viewAs(x.getDataDescriptor()).asComplex() * fftImg;
105 122 : if constexpr (isComplex<ret_t>) {
106 0 : Ax.slice(i) = fourierTransform.applyAdjoint(temp);
107 0 : } else {
108 0 : Ax.slice(i) = real(fourierTransform.applyAdjoint(temp));
109 0 : }
110 122 : }
111 2 : }
112 :
113 : template <typename ret_t, typename data_t>
114 : void ShearletTransform<ret_t, data_t>::applyAdjointImpl(const DataContainer<ret_t>& y,
115 : DataContainer<ret_t>& Aty) const
116 2 : {
117 2 : Timer timeguard("ShearletTransform", "applyAdjoint");
118 :
119 2 : if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
120 2 : || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
121 0 : throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
122 0 : "match to that of this shearlet system");
123 0 : }
124 :
125 2 : Logger::get("ShearletTransform")
126 2 : ->info("Running the inverse shearlet transform on a 2D signal of shape ({}, {}), "
127 2 : "on {} "
128 2 : "scales with an oversampling factor of {} and {} spectra",
129 2 : _width, _height, _numOfScales, _numOfLayers,
130 2 : isSpectraComputed() ? "precomputed" : "non-precomputed");
131 :
132 2 : if (!isSpectraComputed()) {
133 0 : computeSpectra();
134 0 : }
135 :
136 2 : FourierTransform<elsa::complex<data_t>> fourierTransform(Aty.getDataDescriptor());
137 :
138 2 : DataContainer<elsa::complex<data_t>> intermRes(y.getDataDescriptor());
139 :
140 124 : for (index_t i = 0; i < getNumOfLayers(); i++) {
141 122 : DataContainer<elsa::complex<data_t>> temp =
142 122 : fourierTransform.apply(y.slice(i).viewAs(Aty.getDataDescriptor()).asComplex())
143 122 : * getSpectra().slice(i).viewAs(Aty.getDataDescriptor()).asComplex();
144 122 : intermRes.slice(i) = fourierTransform.applyAdjoint(temp);
145 122 : }
146 :
147 2 : if constexpr (isComplex<ret_t>) {
148 0 : Aty = sumByLastAxis(intermRes);
149 0 : } else {
150 0 : Aty = real(sumByLastAxis(intermRes));
151 0 : }
152 2 : }
153 :
154 : template <typename ret_t, typename data_t>
155 : void ShearletTransform<ret_t, data_t>::computeSpectra() const
156 6 : {
157 6 : if (isSpectraComputed()) {
158 0 : Logger::get("ShearletTransform")->warn("Spectra have already been computed!");
159 0 : }
160 :
161 6 : _spectra = DataContainer<data_t>(VolumeDescriptor{{_width, _height, _numOfLayers}});
162 :
163 6 : _computeSpectraAtLowFreq();
164 :
165 30 : for (index_t j = 0; j < _numOfScales; j++) {
166 24 : auto twoPowJ = static_cast<index_t>(std::pow(2, j));
167 24 : auto shearletsAtJ = static_cast<index_t>(std::pow(2, j + 2));
168 24 : index_t shearletsUpUntilJ = shearletsAtJ - 3;
169 24 : index_t index = 1;
170 :
171 24 : _computeSpectraAtSeamLines(j, -twoPowJ, shearletsUpUntilJ + twoPowJ);
172 180 : for (auto k = -twoPowJ + 1; k < twoPowJ; k++) {
173 : // modulo instead of remainder for negative numbers is needed here, therefore doing
174 : // "((a % b) + b) % b" instead of "a % b"
175 156 : index_t modIndex =
176 156 : (((twoPowJ - index + 1) % shearletsAtJ) + shearletsAtJ) % shearletsAtJ;
177 156 : if (modIndex == 0) {
178 18 : modIndex = shearletsAtJ - 1;
179 138 : } else {
180 138 : --modIndex;
181 138 : }
182 :
183 156 : _computeSpectraAtConicRegions(j, k, shearletsUpUntilJ + modIndex,
184 156 : shearletsUpUntilJ + twoPowJ + index);
185 156 : ++index;
186 156 : }
187 24 : _computeSpectraAtSeamLines(j, twoPowJ, shearletsUpUntilJ + twoPowJ + index);
188 24 : }
189 6 : }
190 :
191 : template <typename ret_t, typename data_t>
192 : void ShearletTransform<ret_t, data_t>::_computeSpectraAtLowFreq() const
193 6 : {
194 6 : DataContainer<data_t> sectionZero(VolumeDescriptor{{_width, _height}});
195 6 : sectionZero = 0;
196 :
197 6 : auto shape = getShapeFractions();
198 :
199 : // TODO attempt to refactor the negative indexing
200 166 : for (auto w = shape.negativeHalfWidth; w < shape.halfWidth; w++) {
201 4768 : for (auto h = shape.negativeHalfHeight; h < shape.halfHeight; h++) {
202 4608 : sectionZero(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
203 4608 : shearlet::phiHat<data_t>(static_cast<data_t>(w), static_cast<data_t>(h));
204 4608 : }
205 160 : }
206 :
207 6 : _spectra.value().slice(0) = sectionZero;
208 6 : }
209 :
210 : template <typename ret_t, typename data_t>
211 : void ShearletTransform<ret_t, data_t>::_computeSpectraAtConicRegions(index_t j, index_t k,
212 : index_t hSliceIndex,
213 : index_t vSliceIndex) const
214 156 : {
215 156 : DataContainer<data_t> sectionh(VolumeDescriptor{{_width, _height}});
216 156 : sectionh = 0;
217 156 : DataContainer<data_t> sectionv(VolumeDescriptor{{_width, _height}});
218 156 : sectionv = 0;
219 :
220 156 : auto shape = getShapeFractions();
221 156 : auto jr = static_cast<data_t>(j);
222 156 : auto kr = static_cast<data_t>(k);
223 :
224 : // TODO attempt to refactor the negative indexing
225 4316 : for (auto w = shape.negativeHalfWidth; w < shape.halfWidth; w++) {
226 4160 : auto wr = static_cast<data_t>(w);
227 123968 : for (auto h = shape.negativeHalfHeight; h < shape.halfHeight; h++) {
228 119808 : auto hr = static_cast<data_t>(h);
229 119808 : if (std::abs(h) <= std::abs(w)) {
230 63908 : sectionh(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
231 63908 : shearlet::psiHat<data_t>(std::pow(4.f, -jr) * wr,
232 63908 : std::pow(4.f, -jr) * kr * wr
233 63908 : + std::pow(2.f, -jr) * hr);
234 63908 : } else {
235 55900 : sectionv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
236 55900 : shearlet::psiHat<data_t>(std::pow(4.f, -jr) * hr,
237 55900 : std::pow(4.f, -jr) * kr * hr
238 55900 : + std::pow(2.f, -jr) * wr);
239 55900 : }
240 119808 : }
241 4160 : }
242 :
243 156 : _spectra.value().slice(hSliceIndex) = sectionh;
244 156 : _spectra.value().slice(vSliceIndex) = sectionv;
245 156 : }
246 :
247 : template <typename ret_t, typename data_t>
248 : void ShearletTransform<ret_t, data_t>::_computeSpectraAtSeamLines(index_t j, index_t k,
249 : index_t hxvSliceIndex) const
250 48 : {
251 48 : DataContainer<data_t> sectionhxv(VolumeDescriptor{{_width, _height}});
252 48 : sectionhxv = 0;
253 :
254 48 : auto shape = getShapeFractions();
255 48 : auto jr = static_cast<data_t>(j);
256 48 : auto kr = static_cast<data_t>(k);
257 :
258 : // TODO attempt to refactor the negative indexing
259 1328 : for (auto w = shape.negativeHalfWidth; w < shape.halfWidth; w++) {
260 1280 : auto wr = static_cast<data_t>(w);
261 38144 : for (auto h = shape.negativeHalfHeight; h < shape.halfHeight; h++) {
262 36864 : auto hr = static_cast<data_t>(h);
263 36864 : if (std::abs(h) <= std::abs(w)) {
264 19664 : sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
265 19664 : shearlet::psiHat<data_t>(std::pow(4.f, -jr) * wr,
266 19664 : std::pow(4.f, -jr) * kr * wr
267 19664 : + std::pow(2.f, -jr) * hr);
268 19664 : } else {
269 17200 : sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
270 17200 : shearlet::psiHat<data_t>(std::pow(4.f, -jr) * hr,
271 17200 : std::pow(4.f, -jr) * kr * hr
272 17200 : + std::pow(2.f, -jr) * wr);
273 17200 : }
274 36864 : }
275 1280 : }
276 :
277 48 : _spectra.value().slice(hxvSliceIndex) = sectionhxv;
278 48 : }
279 :
280 : /**
281 : * helper function to calculate input data fractions.
282 : */
283 : template <typename ret_t, typename data_t>
284 : auto ShearletTransform<ret_t, data_t>::getShapeFractions() const -> shape_fractions
285 210 : {
286 210 : shape_fractions ret;
287 210 : auto width = static_cast<real_t>(_width);
288 210 : auto height = static_cast<real_t>(_height);
289 :
290 210 : ret.negativeHalfWidth = static_cast<index_t>(-std::floor(width / 2.0));
291 210 : ret.halfWidth = static_cast<index_t>(std::ceil(width / 2.0));
292 210 : ret.negativeHalfHeight = static_cast<index_t>(-std::floor(height / 2.0));
293 210 : ret.halfHeight = static_cast<index_t>(std::ceil(height / 2.0));
294 :
295 210 : return ret;
296 210 : }
297 :
298 : template <typename ret_t, typename data_t>
299 : auto ShearletTransform<ret_t, data_t>::getSpectra() const -> DataContainer<data_t>
300 248 : {
301 248 : if (!_spectra.has_value()) {
302 2 : throw LogicError(std::string("ShearletTransform: the spectra is not yet computed"));
303 2 : }
304 246 : return _spectra.value();
305 246 : }
306 :
307 : template <typename ret_t, typename data_t>
308 : bool ShearletTransform<ret_t, data_t>::isSpectraComputed() const
309 16 : {
310 16 : return _spectra.has_value();
311 16 : }
312 :
313 : template <typename ret_t, typename data_t>
314 : index_t ShearletTransform<ret_t, data_t>::calculateNumOfScales(index_t width, index_t height)
315 6 : {
316 6 : return static_cast<index_t>(std::log2(std::max(width, height)) / 2.0);
317 6 : }
318 :
319 : template <typename ret_t, typename data_t>
320 : index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t width, index_t height)
321 0 : {
322 0 : return static_cast<index_t>(std::pow(2, (calculateNumOfScales(width, height) + 2)) - 3);
323 0 : }
324 :
325 : template <typename ret_t, typename data_t>
326 : index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t numOfScales)
327 28 : {
328 28 : return static_cast<index_t>(std::pow(2, numOfScales + 2) - 3);
329 28 : }
330 :
331 : template <typename ret_t, typename data_t>
332 : auto ShearletTransform<ret_t, data_t>::getWidth() const -> index_t
333 2 : {
334 2 : return _width;
335 2 : }
336 :
337 : template <typename ret_t, typename data_t>
338 : auto ShearletTransform<ret_t, data_t>::getHeight() const -> index_t
339 2 : {
340 2 : return _height;
341 2 : }
342 :
343 : template <typename ret_t, typename data_t>
344 : auto ShearletTransform<ret_t, data_t>::getNumOfLayers() const -> index_t
345 250 : {
346 250 : return _numOfLayers;
347 250 : }
348 :
349 : template <typename ret_t, typename data_t>
350 : ShearletTransform<ret_t, data_t>* ShearletTransform<ret_t, data_t>::cloneImpl() const
351 2 : {
352 2 : return new ShearletTransform<ret_t, data_t>(_width, _height, _numOfScales, _spectra);
353 2 : }
354 :
355 : template <typename ret_t, typename data_t>
356 : bool ShearletTransform<ret_t, data_t>::isEqual(const LinearOperator<ret_t>& other) const
357 2 : {
358 2 : if (!LinearOperator<ret_t>::isEqual(other))
359 0 : return false;
360 :
361 2 : auto otherST = downcast_safe<ShearletTransform<ret_t, data_t>>(&other);
362 :
363 2 : if (!otherST)
364 0 : return false;
365 :
366 2 : if (_width != otherST->_width)
367 0 : return false;
368 :
369 2 : if (_height != otherST->_height)
370 0 : return false;
371 :
372 2 : if (_numOfScales != otherST->_numOfScales)
373 0 : return false;
374 :
375 2 : return true;
376 2 : }
377 :
378 : // ------------------------------------------
379 : // explicit template instantiation
380 : template class ShearletTransform<float, float>;
381 : template class ShearletTransform<elsa::complex<float>, float>;
382 : template class ShearletTransform<double, double>;
383 : template class ShearletTransform<elsa::complex<double>, double>;
384 : } // namespace elsa
|