Line data Source code
1 : /**
2 : * @file test_ShearletTransform.cpp
3 : *
4 : * @brief Tests for the ShearletTransform class
5 : *
6 : * @author Andi Braimllari
7 : */
8 :
9 : #include "ShearletTransform.h"
10 : #include "VolumeDescriptor.h"
11 : #include "TypeCasts.hpp"
12 :
13 : #include <doctest/doctest.h>
14 : #include <testHelpers.h>
15 :
16 : using namespace elsa;
17 : using namespace doctest;
18 :
19 : TEST_SUITE_BEGIN("core");
20 :
21 22 : TEST_CASE_TEMPLATE("ShearletTransform: Testing construction", TestType, float, double)
22 : {
23 8 : GIVEN("a DataDescriptor")
24 : {
25 8 : IndexVector_t size(2);
26 4 : size << 64, 64;
27 8 : VolumeDescriptor volDescr(size);
28 :
29 6 : WHEN("instantiating a ShearletTransform operator")
30 : {
31 4 : ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1]);
32 :
33 4 : THEN("the DataDescriptors are equal")
34 : {
35 2 : REQUIRE_EQ(shearletTransform.getDomainDescriptor(), volDescr);
36 : }
37 : }
38 :
39 6 : WHEN("cloning a ShearletTransform operator")
40 : {
41 4 : ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1]);
42 4 : auto shearletTransformClone = shearletTransform.clone();
43 :
44 4 : THEN("cloned ShearletTransform operator equals original ShearletTransform operator")
45 : {
46 2 : REQUIRE_NE(shearletTransformClone.get(), &shearletTransform);
47 2 : REQUIRE_EQ(*shearletTransformClone, shearletTransform);
48 : }
49 : }
50 : }
51 4 : }
52 :
53 20 : TEST_CASE_TEMPLATE("ShearletTransform: Testing reconstruction precision", TestType, float, double)
54 : {
55 4 : GIVEN("a 2D signal")
56 : {
57 4 : IndexVector_t size(2);
58 2 : size << 127, 127;
59 4 : VolumeDescriptor volDescr(size);
60 :
61 4 : Vector_t<TestType> randomData(volDescr.getNumberOfCoefficients());
62 2 : randomData.setRandom();
63 4 : DataContainer<TestType> signal(volDescr, randomData);
64 4 : DataContainer<elsa::complex<TestType>> complexSignal(volDescr);
65 32260 : for (index_t i = 0; i < signal.getSize(); ++i) {
66 32258 : complexSignal[i] = elsa::complex<TestType>(signal[i], 0);
67 : }
68 :
69 4 : WHEN("reconstructing the signal")
70 : {
71 4 : ShearletTransform<elsa::complex<TestType>, TestType> shearletTransform(size[0], size[1],
72 : 4);
73 :
74 4 : DataContainer<elsa::complex<TestType>> shearletCoefficients =
75 : shearletTransform.apply(complexSignal);
76 :
77 4 : DataContainer<TestType> reconstruction =
78 : real(shearletTransform.applyAdjoint(shearletCoefficients));
79 :
80 4 : THEN("the ground truth and the reconstruction match")
81 : {
82 2 : REQUIRE_UNARY(isApprox(reconstruction, signal));
83 : }
84 : }
85 : }
86 2 : }
87 :
88 24 : TEST_CASE_TEMPLATE("ShearletTransform: Testing spectra's Parseval frame property", TestType, float,
89 : double)
90 : {
91 12 : GIVEN("a 2D signal")
92 : {
93 12 : IndexVector_t size(2);
94 6 : size << 127, 127;
95 12 : VolumeDescriptor volDescr(size);
96 :
97 12 : Vector_t<TestType> randomData(volDescr.getNumberOfCoefficients());
98 6 : randomData.setRandom();
99 12 : DataContainer<TestType> signal(volDescr, randomData);
100 :
101 8 : WHEN("not generating the spectra")
102 : {
103 4 : ShearletTransform<TestType, TestType> shearletTransform(size);
104 :
105 4 : THEN("an error is thrown when fetching it")
106 : {
107 4 : REQUIRE_THROWS_AS(shearletTransform.getSpectra(), LogicError);
108 : }
109 : }
110 :
111 10 : WHEN("generating the spectra")
112 : {
113 8 : ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1], 4);
114 :
115 4 : shearletTransform.computeSpectra();
116 :
117 6 : THEN("the spectra is reported as computed")
118 : {
119 2 : REQUIRE(shearletTransform.isSpectraComputed());
120 : }
121 :
122 : /// If a matrix mxn A has rows that constitute Parseval frame, then AtA = I
123 : /// (Corollary 1.4.7 from An Introduction to Frames and Riesz Bases). Given that our
124 : /// spectra constitute a Parseval frame, we can utilize this property to check if
125 : /// they've been generated correctly.
126 6 : THEN("the spectra constitute a Parseval frame")
127 : {
128 4 : DataContainer<TestType> spectra = shearletTransform.getSpectra();
129 2 : index_t width = shearletTransform.getWidth();
130 2 : index_t height = shearletTransform.getHeight();
131 2 : index_t layers = shearletTransform.getNumOfLayers();
132 :
133 4 : DataContainer<TestType> frameCorrectness(VolumeDescriptor{{width, height}});
134 :
135 256 : for (index_t w1 = 0; w1 < width; w1++) {
136 32512 : for (index_t w2 = 0; w2 < height; w2++) {
137 32258 : TestType currFrameSum = 0;
138 1999996 : for (index_t i = 0; i < layers; i++) {
139 1967738 : currFrameSum += spectra(w1, w2, i) * spectra(w1, w2, i);
140 : }
141 32258 : frameCorrectness(w1, w2) = currFrameSum - 1;
142 : }
143 : }
144 :
145 2 : DataContainer<TestType> zeroes(VolumeDescriptor{{width, height}});
146 2 : zeroes = 0;
147 :
148 : // spectra here is of shape (W, H, L), square its elements and get the sum by the
149 : // last axis and subtract 1, the output will be of shape (W, H), its elements
150 : // should be zeroes, or very close to it
151 :
152 2 : REQUIRE_UNARY(frameCorrectness.squaredL2Norm() < 0.000000001);
153 : }
154 : }
155 : }
156 6 : }
157 :
158 : TEST_SUITE_END();
|