LCOV - code coverage report
Current view: top level - elsa/operators/tests - test_ShearletTransform.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 102 102 100.0 %
Date: 2022-08-25 03:05:39 Functions: 6 6 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_ShearletTransform.cpp
       3             :  *
       4             :  * @brief Tests for the ShearletTransform class
       5             :  *
       6             :  * @author Andi Braimllari
       7             :  */
       8             : 
       9             : #include "ShearletTransform.h"
      10             : #include "VolumeDescriptor.h"
      11             : #include "TypeCasts.hpp"
      12             : 
      13             : #include <doctest/doctest.h>
      14             : #include <testHelpers.h>
      15             : 
      16             : using namespace elsa;
      17             : using namespace doctest;
      18             : 
      19             : TEST_SUITE_BEGIN("core");
      20             : 
      21             : TEST_CASE_TEMPLATE("ShearletTransform: Testing construction", TestType, float, double)
      22           4 : {
      23           4 :     GIVEN("a DataDescriptor")
      24           4 :     {
      25           4 :         IndexVector_t size(2);
      26           4 :         size << 64, 64;
      27           4 :         VolumeDescriptor volDescr(size);
      28             : 
      29           4 :         WHEN("instantiating a ShearletTransform operator")
      30           4 :         {
      31           2 :             ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1]);
      32             : 
      33           2 :             THEN("the DataDescriptors are equal")
      34           2 :             {
      35           2 :                 REQUIRE_EQ(shearletTransform.getDomainDescriptor(), volDescr);
      36           2 :             }
      37           2 :         }
      38             : 
      39           4 :         WHEN("cloning a ShearletTransform operator")
      40           4 :         {
      41           2 :             ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1]);
      42           2 :             auto shearletTransformClone = shearletTransform.clone();
      43             : 
      44           2 :             THEN("cloned ShearletTransform operator equals original ShearletTransform operator")
      45           2 :             {
      46           2 :                 REQUIRE_NE(shearletTransformClone.get(), &shearletTransform);
      47           2 :                 REQUIRE_EQ(*shearletTransformClone, shearletTransform);
      48           2 :             }
      49           2 :         }
      50           4 :     }
      51           4 : }
      52             : 
      53             : TEST_CASE_TEMPLATE("ShearletTransform: Testing reconstruction precision", TestType, float, double)
      54           2 : {
      55           2 :     GIVEN("a 2D signal")
      56           2 :     {
      57           2 :         IndexVector_t size(2);
      58           2 :         size << 32, 32;
      59           2 :         VolumeDescriptor volDescr(size);
      60             : 
      61           2 :         Vector_t<TestType> randomData(volDescr.getNumberOfCoefficients());
      62           2 :         randomData.setRandom();
      63           2 :         DataContainer<TestType> signal(volDescr, randomData);
      64           2 :         DataContainer<elsa::complex<TestType>> complexSignal(volDescr);
      65        2050 :         for (index_t i = 0; i < signal.getSize(); ++i) {
      66        2048 :             complexSignal[i] = elsa::complex<TestType>(signal[i], 0);
      67        2048 :         }
      68             : 
      69           2 :         WHEN("reconstructing the signal")
      70           2 :         {
      71           2 :             ShearletTransform<elsa::complex<TestType>, TestType> shearletTransform(size[0], size[1],
      72           2 :                                                                                    4);
      73             : 
      74           2 :             DataContainer<elsa::complex<TestType>> shearletCoefficients =
      75           2 :                 shearletTransform.apply(complexSignal);
      76             : 
      77           2 :             DataContainer<TestType> reconstruction =
      78           2 :                 real(shearletTransform.applyAdjoint(shearletCoefficients));
      79             : 
      80           2 :             THEN("the ground truth and the reconstruction match")
      81           2 :             {
      82           2 :                 REQUIRE_UNARY(isApprox(reconstruction, signal));
      83           2 :             }
      84           2 :         }
      85           2 :     }
      86           2 : }
      87             : 
      88             : TEST_CASE_TEMPLATE("ShearletTransform: Testing spectra's Parseval frame property", TestType, float,
      89             :                    double)
      90           6 : {
      91           6 :     GIVEN("a 2D signal")
      92           6 :     {
      93           6 :         IndexVector_t size(2);
      94           6 :         size << 32, 32;
      95           6 :         VolumeDescriptor volDescr(size);
      96             : 
      97           6 :         Vector_t<TestType> randomData(volDescr.getNumberOfCoefficients());
      98           6 :         randomData.setRandom();
      99           6 :         DataContainer<TestType> signal(volDescr, randomData);
     100             : 
     101           6 :         WHEN("not generating the spectra")
     102           6 :         {
     103           2 :             ShearletTransform<TestType, TestType> shearletTransform(size);
     104             : 
     105           2 :             THEN("an error is thrown when fetching it")
     106           2 :             {
     107           2 :                 REQUIRE_THROWS_AS(shearletTransform.getSpectra(), LogicError);
     108           2 :             }
     109           2 :         }
     110             : 
     111           6 :         WHEN("generating the spectra")
     112           6 :         {
     113           4 :             ShearletTransform<TestType, TestType> shearletTransform(size[0], size[1], 4);
     114             : 
     115           4 :             shearletTransform.computeSpectra();
     116             : 
     117           4 :             THEN("the spectra is reported as computed")
     118           4 :             {
     119           2 :                 REQUIRE(shearletTransform.isSpectraComputed());
     120           2 :             }
     121             : 
     122             :             /// If a matrix mxn A has rows that constitute Parseval frame, then AtA = I
     123             :             /// (Corollary 1.4.7 from An Introduction to Frames and Riesz Bases). Given that our
     124             :             /// spectra constitute a Parseval frame, we can utilize this property to check if
     125             :             /// they've been generated correctly.
     126           4 :             THEN("the spectra constitute a Parseval frame")
     127           4 :             {
     128           2 :                 DataContainer<TestType> spectra = shearletTransform.getSpectra();
     129           2 :                 index_t width = shearletTransform.getWidth();
     130           2 :                 index_t height = shearletTransform.getHeight();
     131           2 :                 index_t layers = shearletTransform.getNumOfLayers();
     132             : 
     133           2 :                 DataContainer<TestType> frameCorrectness(VolumeDescriptor{{width, height}});
     134             : 
     135          66 :                 for (index_t w1 = 0; w1 < width; w1++) {
     136        2112 :                     for (index_t w2 = 0; w2 < height; w2++) {
     137        2048 :                         TestType currFrameSum = 0;
     138      126976 :                         for (index_t i = 0; i < layers; i++) {
     139      124928 :                             currFrameSum += spectra(w1, w2, i) * spectra(w1, w2, i);
     140      124928 :                         }
     141        2048 :                         frameCorrectness(w1, w2) = currFrameSum - 1;
     142        2048 :                     }
     143          64 :                 }
     144             : 
     145           2 :                 DataContainer<TestType> zeroes(VolumeDescriptor{{width, height}});
     146           2 :                 zeroes = 0;
     147             : 
     148             :                 // spectra here is of shape (W, H, L), square its elements and get the sum by the
     149             :                 // last axis and subtract 1, the output will be of shape (W, H), its elements
     150             :                 // should be zeroes, or very close to it
     151             : 
     152           2 :                 REQUIRE_UNARY(frameCorrectness.squaredL2Norm() < 0.000000001);
     153           2 :             }
     154           4 :         }
     155           6 :     }
     156           6 : }
     157             : 
     158             : TEST_SUITE_END();

Generated by: LCOV version 1.14