Line data Source code
#include "ShearletTransform.h"

#include <utility>

#include "FourierTransform.h"
#include "VolumeDescriptor.h"
#include "Timer.h"
#include "Math.hpp"
6 :
7 : namespace elsa
8 : {
9 : template <typename ret_t, typename data_t>
10 : ShearletTransform<ret_t, data_t>::ShearletTransform(IndexVector_t spatialDimensions)
11 : : ShearletTransform(spatialDimensions[0], spatialDimensions[1])
12 2 : {
13 2 : if (spatialDimensions.size() != 2) {
14 0 : throw LogicError("ShearletTransform: Only 2D shape supported");
15 0 : }
16 2 : }
17 :
18 : template <typename ret_t, typename data_t>
19 : ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height)
20 : : ShearletTransform(width, height, calculateNumOfScales(width, height))
21 6 : {
22 6 : }
23 :
24 : template <typename ret_t, typename data_t>
25 : ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height,
26 : index_t numOfScales)
27 : : ShearletTransform(width, height, numOfScales, std::nullopt)
28 12 : {
29 12 : }
30 :
31 : template <typename ret_t, typename data_t>
32 : ShearletTransform<ret_t, data_t>::ShearletTransform(
33 : index_t width, index_t height, index_t numOfScales,
34 : std::optional<DataContainer<data_t>> spectra)
35 : : LinearOperator<ret_t>(
36 : VolumeDescriptor{{width, height}},
37 : VolumeDescriptor{{width, height, calculateNumOfLayers(numOfScales)}}),
38 : _spectra{spectra},
39 : _width{width},
40 : _height{height},
41 : _numOfScales{numOfScales},
42 : _numOfLayers{calculateNumOfLayers(numOfScales)}
43 14 : {
44 14 : if (width < 0 || height < 0) {
45 0 : throw LogicError("ShearletTransform: negative width/height were provided");
46 0 : }
47 14 : if (numOfScales < 0) {
48 0 : throw LogicError("ShearletTransform: negative number of scales was provided");
49 0 : }
50 14 : }
51 :
52 : // TODO implement sumByAxis in DataContainer and remove me
53 : template <typename ret_t, typename data_t>
54 : DataContainer<elsa::complex<data_t>> ShearletTransform<ret_t, data_t>::sumByLastAxis(
55 : DataContainer<elsa::complex<data_t>> dc) const
56 2 : {
57 2 : auto coeffsPerDim = dc.getDataDescriptor().getNumberOfCoefficientsPerDimension();
58 2 : index_t width = coeffsPerDim[0];
59 2 : index_t height = coeffsPerDim[1];
60 2 : index_t layers = coeffsPerDim[2];
61 2 : DataContainer<elsa::complex<data_t>> summedDC(VolumeDescriptor{{width, height}});
62 :
63 66 : for (index_t j = 0; j < width; j++) {
64 2112 : for (index_t k = 0; k < height; k++) {
65 2048 : elsa::complex<data_t> currValue = 0;
66 126976 : for (index_t i = 0; i < layers; i++) {
67 124928 : currValue += dc(j, k, i);
68 124928 : }
69 2048 : summedDC(j, k) = currValue;
70 2048 : }
71 64 : }
72 :
73 2 : return summedDC;
74 2 : }
75 :
76 : template <typename ret_t, typename data_t>
77 : void ShearletTransform<ret_t, data_t>::applyImpl(const DataContainer<ret_t>& x,
78 : DataContainer<ret_t>& Ax) const
79 2 : {
80 2 : Timer timeguard("ShearletTransform", "apply");
81 :
82 2 : if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
83 2 : || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
84 0 : throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
85 0 : "match to that of this shearlet system");
86 0 : }
87 :
88 2 : Logger::get("ShearletTransform")
89 2 : ->info("Running the shearlet transform on a 2D signal of shape ({}, {}), on {} "
90 2 : "scales with an oversampling factor of {} and {} spectra",
91 2 : _width, _height, _numOfScales, _numOfLayers,
92 2 : isSpectraComputed() ? "precomputed" : "non-precomputed");
93 :
94 2 : if (!isSpectraComputed()) {
95 2 : computeSpectra();
96 2 : }
97 :
98 2 : FourierTransform<elsa::complex<data_t>> fourierTransform(x.getDataDescriptor());
99 :
100 2 : DataContainer<elsa::complex<data_t>> fftImg = fourierTransform.apply(x.asComplex());
101 :
102 124 : for (index_t i = 0; i < getNumOfLayers(); i++) {
103 122 : DataContainer<elsa::complex<data_t>> temp =
104 122 : getSpectra().slice(i).viewAs(x.getDataDescriptor()).asComplex() * fftImg;
105 122 : if constexpr (isComplex<ret_t>) {
106 0 : Ax.slice(i) = fourierTransform.applyAdjoint(temp);
107 0 : } else {
108 0 : Ax.slice(i) = real(fourierTransform.applyAdjoint(temp));
109 0 : }
110 122 : }
111 2 : }
112 :
113 : template <typename ret_t, typename data_t>
114 : void ShearletTransform<ret_t, data_t>::applyAdjointImpl(const DataContainer<ret_t>& y,
115 : DataContainer<ret_t>& Aty) const
116 2 : {
117 2 : Timer timeguard("ShearletTransform", "applyAdjoint");
118 :
119 2 : if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
120 2 : || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
121 0 : throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
122 0 : "match to that of this shearlet system");
123 0 : }
124 :
125 2 : Logger::get("ShearletTransform")
126 2 : ->info("Running the inverse shearlet transform on a 2D signal of shape ({}, {}), "
127 2 : "on {} "
128 2 : "scales with an oversampling factor of {} and {} spectra",
129 2 : _width, _height, _numOfScales, _numOfLayers,
130 2 : isSpectraComputed() ? "precomputed" : "non-precomputed");
131 :
132 2 : if (!isSpectraComputed()) {
133 0 : computeSpectra();
134 0 : }
135 :
136 2 : FourierTransform<elsa::complex<data_t>> fourierTransform(Aty.getDataDescriptor());
137 :
138 2 : DataContainer<elsa::complex<data_t>> intermRes(y.getDataDescriptor());
139 :
140 124 : for (index_t i = 0; i < getNumOfLayers(); i++) {
141 122 : DataContainer<elsa::complex<data_t>> temp =
142 122 : fourierTransform.apply(y.slice(i).viewAs(Aty.getDataDescriptor()).asComplex())
143 122 : * getSpectra().slice(i).viewAs(Aty.getDataDescriptor()).asComplex();
144 122 : intermRes.slice(i) = fourierTransform.applyAdjoint(temp);
145 122 : }
146 :
147 2 : if constexpr (isComplex<ret_t>) {
148 0 : Aty = sumByLastAxis(intermRes);
149 0 : } else {
150 0 : Aty = real(sumByLastAxis(intermRes));
151 0 : }
152 2 : }
153 :
154 : template <typename ret_t, typename data_t>
155 : void ShearletTransform<ret_t, data_t>::computeSpectra() const
156 6 : {
157 6 : if (isSpectraComputed()) {
158 0 : Logger::get("ShearletTransform")->warn("Spectra have already been computed!");
159 0 : }
160 :
161 6 : _spectra = DataContainer<data_t>(VolumeDescriptor{{_width, _height, _numOfLayers}});
162 :
163 6 : _computeSpectraAtLowFreq();
164 :
165 30 : for (index_t j = 0; j < _numOfScales; j++) {
166 24 : auto twoPowJ = static_cast<index_t>(std::pow(2, j));
167 24 : auto shearletsAtJ = static_cast<index_t>(std::pow(2, j + 2));
168 24 : index_t shearletsUpUntilJ = shearletsAtJ - 3;
169 24 : index_t index = 1;
170 :
171 24 : _computeSpectraAtSeamLines(j, -twoPowJ, shearletsUpUntilJ + twoPowJ);
172 180 : for (auto k = -twoPowJ + 1; k < twoPowJ; k++) {
173 : // modulo instead of remainder for negative numbers is needed here, therefore doing
174 : // "((a % b) + b) % b" instead of "a % b"
175 156 : index_t modIndex =
176 156 : (((twoPowJ - index + 1) % shearletsAtJ) + shearletsAtJ) % shearletsAtJ;
177 156 : if (modIndex == 0) {
178 18 : modIndex = shearletsAtJ - 1;
179 138 : } else {
180 138 : --modIndex;
181 138 : }
182 :
183 156 : _computeSpectraAtConicRegions(j, k, shearletsUpUntilJ + modIndex,
184 156 : shearletsUpUntilJ + twoPowJ + index);
185 156 : ++index;
186 156 : }
187 24 : _computeSpectraAtSeamLines(j, twoPowJ, shearletsUpUntilJ + twoPowJ + index);
188 24 : }
189 6 : }
190 :
191 : template <typename ret_t, typename data_t>
192 : void ShearletTransform<ret_t, data_t>::_computeSpectraAtLowFreq() const
193 6 : {
194 6 : DataContainer<data_t> sectionZero(VolumeDescriptor{{_width, _height}});
195 6 : sectionZero = 0;
196 :
197 6 : auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
198 6 : auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
199 6 : auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
200 6 : auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
201 :
202 : // TODO attempt to refactor the negative indexing
203 198 : for (auto w = negativeHalfWidth; w < halfWidth; w++) {
204 6336 : for (auto h = negativeHalfHeight; h < halfHeight; h++) {
205 6144 : sectionZero(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
206 6144 : shearlet::phiHat<data_t>(static_cast<data_t>(w), static_cast<data_t>(h));
207 6144 : }
208 192 : }
209 :
210 6 : _spectra.value().slice(0) = sectionZero;
211 6 : }
212 :
213 : template <typename ret_t, typename data_t>
214 : void ShearletTransform<ret_t, data_t>::_computeSpectraAtConicRegions(index_t j, index_t k,
215 : index_t hSliceIndex,
216 : index_t vSliceIndex) const
217 156 : {
218 156 : DataContainer<data_t> sectionh(VolumeDescriptor{{_width, _height}});
219 156 : sectionh = 0;
220 156 : DataContainer<data_t> sectionv(VolumeDescriptor{{_width, _height}});
221 156 : sectionv = 0;
222 :
223 156 : auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
224 156 : auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
225 156 : auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
226 156 : auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
227 :
228 : // TODO attempt to refactor the negative indexing
229 5148 : for (auto w = negativeHalfWidth; w < halfWidth; w++) {
230 164736 : for (auto h = negativeHalfHeight; h < halfHeight; h++) {
231 159744 : if (std::abs(h) <= std::abs(w)) {
232 84708 : sectionh(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
233 84708 : shearlet::psiHat<data_t>(std::pow(4, -j) * w,
234 84708 : std::pow(4, -j) * k * w + std::pow(2, -j) * h);
235 84708 : } else {
236 75036 : sectionv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
237 75036 : shearlet::psiHat<data_t>(std::pow(4, -j) * h,
238 75036 : std::pow(4, -j) * k * h + std::pow(2, -j) * w);
239 75036 : }
240 159744 : }
241 4992 : }
242 :
243 156 : _spectra.value().slice(hSliceIndex) = sectionh;
244 156 : _spectra.value().slice(vSliceIndex) = sectionv;
245 156 : }
246 :
247 : template <typename ret_t, typename data_t>
248 : void ShearletTransform<ret_t, data_t>::_computeSpectraAtSeamLines(index_t j, index_t k,
249 : index_t hxvSliceIndex) const
250 48 : {
251 48 : DataContainer<data_t> sectionhxv(VolumeDescriptor{{_width, _height}});
252 48 : sectionhxv = 0;
253 :
254 48 : auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
255 48 : auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
256 48 : auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
257 48 : auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
258 :
259 : // TODO attempt to refactor the negative indexing
260 1584 : for (auto w = negativeHalfWidth; w < halfWidth; w++) {
261 50688 : for (auto h = negativeHalfHeight; h < halfHeight; h++) {
262 49152 : if (std::abs(h) <= std::abs(w)) {
263 26064 : sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
264 26064 : shearlet::psiHat<data_t>(std::pow(4, -j) * w,
265 26064 : std::pow(4, -j) * k * w + std::pow(2, -j) * h);
266 26064 : } else {
267 23088 : sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
268 23088 : shearlet::psiHat<data_t>(std::pow(4, -j) * h,
269 23088 : std::pow(4, -j) * k * h + std::pow(2, -j) * w);
270 23088 : }
271 49152 : }
272 1536 : }
273 :
274 48 : _spectra.value().slice(hxvSliceIndex) = sectionhxv;
275 48 : }
276 :
277 : template <typename ret_t, typename data_t>
278 : auto ShearletTransform<ret_t, data_t>::getSpectra() const -> DataContainer<data_t>
279 248 : {
280 248 : if (!_spectra.has_value()) {
281 2 : throw LogicError(std::string("ShearletTransform: the spectra is not yet computed"));
282 2 : }
283 246 : return _spectra.value();
284 246 : }
285 :
286 : template <typename ret_t, typename data_t>
287 : bool ShearletTransform<ret_t, data_t>::isSpectraComputed() const
288 16 : {
289 16 : return _spectra.has_value();
290 16 : }
291 :
292 : template <typename ret_t, typename data_t>
293 : index_t ShearletTransform<ret_t, data_t>::calculateNumOfScales(index_t width, index_t height)
294 6 : {
295 6 : return static_cast<index_t>(std::log2(std::max(width, height)) / 2.0);
296 6 : }
297 :
298 : template <typename ret_t, typename data_t>
299 : index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t width, index_t height)
300 0 : {
301 0 : return static_cast<index_t>(std::pow(2, (calculateNumOfScales(width, height) + 2)) - 3);
302 0 : }
303 :
304 : template <typename ret_t, typename data_t>
305 : index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t numOfScales)
306 28 : {
307 28 : return static_cast<index_t>(std::pow(2, numOfScales + 2) - 3);
308 28 : }
309 :
310 : template <typename ret_t, typename data_t>
311 : auto ShearletTransform<ret_t, data_t>::getWidth() const -> index_t
312 2 : {
313 2 : return _width;
314 2 : }
315 :
316 : template <typename ret_t, typename data_t>
317 : auto ShearletTransform<ret_t, data_t>::getHeight() const -> index_t
318 2 : {
319 2 : return _height;
320 2 : }
321 :
322 : template <typename ret_t, typename data_t>
323 : auto ShearletTransform<ret_t, data_t>::getNumOfLayers() const -> index_t
324 250 : {
325 250 : return _numOfLayers;
326 250 : }
327 :
328 : template <typename ret_t, typename data_t>
329 : ShearletTransform<ret_t, data_t>* ShearletTransform<ret_t, data_t>::cloneImpl() const
330 2 : {
331 2 : return new ShearletTransform<ret_t, data_t>(_width, _height, _numOfScales, _spectra);
332 2 : }
333 :
334 : template <typename ret_t, typename data_t>
335 : bool ShearletTransform<ret_t, data_t>::isEqual(const LinearOperator<ret_t>& other) const
336 2 : {
337 2 : if (!LinearOperator<ret_t>::isEqual(other))
338 0 : return false;
339 :
340 2 : auto otherST = downcast_safe<ShearletTransform<ret_t, data_t>>(&other);
341 :
342 2 : if (!otherST)
343 0 : return false;
344 :
345 2 : if (_width != otherST->_width)
346 0 : return false;
347 :
348 2 : if (_height != otherST->_height)
349 0 : return false;
350 :
351 2 : if (_numOfScales != otherST->_numOfScales)
352 0 : return false;
353 :
354 2 : return true;
355 2 : }
356 :
357 : // ------------------------------------------
358 : // explicit template instantiation
359 : template class ShearletTransform<float, float>;
360 : template class ShearletTransform<elsa::complex<float>, float>;
361 : template class ShearletTransform<double, double>;
362 : template class ShearletTransform<elsa::complex<double>, double>;
363 : } // namespace elsa
|