Line data Source code
1 : /**
2 : * @file test_ShearletTransform.cpp
3 : *
4 : * @brief Tests for the ShearletTransform class
5 : *
6 : * @author Andi Braimllari
7 : */
8 :
9 : #include "ShearletTransform.h"
10 : #include "VolumeDescriptor.h"
11 : #include "TypeCasts.hpp"
12 :
13 : #include <doctest/doctest.h>
14 : #include <testHelpers.h>
15 :
16 : using namespace elsa;
17 : using namespace doctest;
18 :
19 : TEST_SUITE_BEGIN("core");
20 :
21 : TEST_CASE_TEMPLATE("ShearletTransform: Testing construction", TestType, float, double)
22 4 : {
23 4 : GIVEN("a DataDescriptor")
24 4 : {
25 4 : IndexVector_t size(2);
26 4 : size << 64, 64;
27 4 : VolumeDescriptor volDescr(size);
28 :
29 4 : WHEN("instantiating a ShearletTransform operator")
30 4 : {
31 2 : ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1]);
32 :
33 2 : THEN("the DataDescriptors are equal")
34 2 : {
35 2 : REQUIRE_EQ(shearletTransform.getDomainDescriptor(), volDescr);
36 2 : }
37 2 : }
38 :
39 4 : WHEN("cloning a ShearletTransform operator")
40 4 : {
41 2 : ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1]);
42 2 : auto shearletTransformClone = shearletTransform.clone();
43 :
44 2 : THEN("cloned ShearletTransform operator equals original ShearletTransform operator")
45 2 : {
46 2 : REQUIRE_NE(shearletTransformClone.get(), &shearletTransform);
47 2 : REQUIRE_EQ(*shearletTransformClone, shearletTransform);
48 2 : }
49 2 : }
50 4 : }
51 4 : }
52 :
53 : TEST_CASE_TEMPLATE("ShearletTransform: Testing reconstruction precision", TestType, float, double)
54 2 : {
55 2 : GIVEN("a 2D signal")
56 2 : {
57 2 : IndexVector_t size(2);
58 2 : size << 32, 32;
59 2 : VolumeDescriptor volDescr(size);
60 :
61 2 : Vector_t<TestType> randomData(volDescr.getNumberOfCoefficients());
62 2 : randomData.setRandom();
63 2 : DataContainer<TestType> signal(volDescr, randomData);
64 2 : DataContainer<elsa::complex<TestType>> complexSignal(volDescr);
65 2050 : for (index_t i = 0; i < signal.getSize(); ++i) {
66 2048 : complexSignal[i] = elsa::complex<TestType>(signal[i], 0);
67 2048 : }
68 :
69 2 : WHEN("reconstructing the signal")
70 2 : {
71 2 : ShearletTransform<elsa::complex<TestType>, TestType> shearletTransform(size[0], size[1],
72 2 : 4);
73 :
74 2 : DataContainer<elsa::complex<TestType>> shearletCoefficients =
75 2 : shearletTransform.apply(complexSignal);
76 :
77 2 : DataContainer<TestType> reconstruction =
78 2 : real(shearletTransform.applyAdjoint(shearletCoefficients));
79 :
80 2 : THEN("the ground truth and the reconstruction match")
81 2 : {
82 2 : REQUIRE_UNARY(isApprox(reconstruction, signal));
83 2 : }
84 2 : }
85 2 : }
86 2 : }
87 :
88 : TEST_CASE_TEMPLATE("ShearletTransform: Testing spectra's Parseval frame property", TestType, float,
89 : double)
90 6 : {
91 6 : GIVEN("a 2D signal")
92 6 : {
93 6 : IndexVector_t size(2);
94 6 : size << 32, 32;
95 6 : VolumeDescriptor volDescr(size);
96 :
97 6 : Vector_t<TestType> randomData(volDescr.getNumberOfCoefficients());
98 6 : randomData.setRandom();
99 6 : DataContainer<TestType> signal(volDescr, randomData);
100 :
101 6 : WHEN("not generating the spectra")
102 6 : {
103 2 : ShearletTransform<TestType, TestType> shearletTransform(size);
104 :
105 2 : THEN("an error is thrown when fetching it")
106 2 : {
107 2 : REQUIRE_THROWS_AS(shearletTransform.getSpectra(), LogicError);
108 2 : }
109 2 : }
110 :
111 6 : WHEN("generating the spectra")
112 6 : {
113 4 : ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1], 4);
114 :
115 4 : shearletTransform.computeSpectra();
116 :
117 4 : THEN("the spectra is reported as computed")
118 4 : {
119 2 : REQUIRE(shearletTransform.isSpectraComputed());
120 2 : }
121 :
122 : /// If a matrix mxn A has rows that constitute Parseval frame, then AtA = I
123 : /// (Corollary 1.4.7 from An Introduction to Frames and Riesz Bases). Given that our
124 : /// spectra constitute a Parseval frame, we can utilize this property to check if
125 : /// they've been generated correctly.
126 4 : THEN("the spectra constitute a Parseval frame")
127 4 : {
128 2 : DataContainer<TestType> spectra = shearletTransform.getSpectra();
129 2 : index_t width = shearletTransform.getWidth();
130 2 : index_t height = shearletTransform.getHeight();
131 2 : index_t layers = shearletTransform.getNumOfLayers();
132 :
133 2 : DataContainer<TestType> frameCorrectness(VolumeDescriptor{{width, height}});
134 :
135 66 : for (index_t w1 = 0; w1 < width; w1++) {
136 2112 : for (index_t w2 = 0; w2 < height; w2++) {
137 2048 : TestType currFrameSum = 0;
138 126976 : for (index_t i = 0; i < layers; i++) {
139 124928 : currFrameSum += spectra(w1, w2, i) * spectra(w1, w2, i);
140 124928 : }
141 2048 : frameCorrectness(w1, w2) = currFrameSum - 1;
142 2048 : }
143 64 : }
144 :
145 2 : DataContainer<TestType> zeroes(VolumeDescriptor{{width, height}});
146 2 : zeroes = 0;
147 :
148 : // spectra here is of shape (W, H, L), square its elements and get the sum by the
149 : // last axis and subtract 1, the output will be of shape (W, H), its elements
150 : // should be zeroes, or very close to it
151 :
152 2 : REQUIRE_UNARY(frameCorrectness.squaredL2Norm() < 0.000000001);
153 2 : }
154 4 : }
155 6 : }
156 6 : }
157 :
158 : TEST_SUITE_END();
|