LCOV - code coverage report
Current view: top level - elsa/operators - ShearletTransform.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 205 237 86.5 %
Date: 2025-01-02 06:42:49 Functions: 64 88 72.7 %

          Line data    Source code
       1             : #include "ShearletTransform.h"
       2             : #include "FourierTransform.h"
       3             : #include "VolumeDescriptor.h"
       4             : #include "Timer.h"
       5             : #include "Math.hpp"
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename ret_t, typename data_t>
      10             :     ShearletTransform<ret_t, data_t>::ShearletTransform(IndexVector_t spatialDimensions)
      11             :         : ShearletTransform(spatialDimensions[0], spatialDimensions[1])
      12           2 :     {
      13           2 :         if (spatialDimensions.size() != 2) {
      14           0 :             throw LogicError("ShearletTransform: Only 2D shape supported");
      15           0 :         }
      16           2 :     }
      17             : 
      18             :     template <typename ret_t, typename data_t>
      19             :     ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height)
      20             :         : ShearletTransform(width, height, calculateNumOfScales(width, height))
      21           6 :     {
      22           6 :     }
      23             : 
      24             :     template <typename ret_t, typename data_t>
      25             :     ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height,
      26             :                                                         index_t numOfScales)
      27             :         : ShearletTransform(width, height, numOfScales, std::nullopt)
      28          12 :     {
      29          12 :     }
      30             : 
      31             :     template <typename ret_t, typename data_t>
      32             :     ShearletTransform<ret_t, data_t>::ShearletTransform(
      33             :         index_t width, index_t height, index_t numOfScales,
      34             :         std::optional<DataContainer<data_t>> spectra)
      35             :         : LinearOperator<ret_t>(
      36             :             VolumeDescriptor{{width, height}},
      37             :             VolumeDescriptor{{width, height, calculateNumOfLayers(numOfScales)}}),
      38             :           _spectra{spectra},
      39             :           _width{width},
      40             :           _height{height},
      41             :           _numOfScales{numOfScales},
      42             :           _numOfLayers{calculateNumOfLayers(numOfScales)}
      43          14 :     {
      44          14 :         if (width < 0 || height < 0) {
      45           0 :             throw LogicError("ShearletTransform: negative width/height were provided");
      46           0 :         }
      47          14 :         if (numOfScales < 0) {
      48           0 :             throw LogicError("ShearletTransform: negative number of scales was provided");
      49           0 :         }
      50          14 :     }
      51             : 
      52             :     // TODO implement sumByAxis in DataContainer and remove me
      53             :     template <typename ret_t, typename data_t>
      54             :     DataContainer<elsa::complex<data_t>> ShearletTransform<ret_t, data_t>::sumByLastAxis(
      55             :         DataContainer<elsa::complex<data_t>> dc) const
      56           2 :     {
      57           2 :         auto coeffsPerDim = dc.getDataDescriptor().getNumberOfCoefficientsPerDimension();
      58           2 :         index_t width = coeffsPerDim[0];
      59           2 :         index_t height = coeffsPerDim[1];
      60           2 :         index_t layers = coeffsPerDim[2];
      61           2 :         DataContainer<elsa::complex<data_t>> summedDC(VolumeDescriptor{{width, height}});
      62             : 
      63          34 :         for (index_t j = 0; j < width; j++) {
      64         544 :             for (index_t k = 0; k < height; k++) {
      65         512 :                 elsa::complex<data_t> currValue = 0;
      66       31744 :                 for (index_t i = 0; i < layers; i++) {
      67       31232 :                     currValue += dc(j, k, i);
      68       31232 :                 }
      69         512 :                 summedDC(j, k) = currValue;
      70         512 :             }
      71          32 :         }
      72             : 
      73           2 :         return summedDC;
      74           2 :     }
      75             : 
      76             :     template <typename ret_t, typename data_t>
      77             :     void ShearletTransform<ret_t, data_t>::applyImpl(const DataContainer<ret_t>& x,
      78             :                                                      DataContainer<ret_t>& Ax) const
      79           2 :     {
      80           2 :         Timer timeguard("ShearletTransform", "apply");
      81             : 
      82           2 :         if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
      83           2 :             || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
      84           0 :             throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
      85           0 :                                        "match to that of this shearlet system");
      86           0 :         }
      87             : 
      88           2 :         Logger::get("ShearletTransform")
      89           2 :             ->info("Running the shearlet transform on a 2D signal of shape ({}, {}), on {} "
      90           2 :                    "scales with an oversampling factor of {} and {} spectra",
      91           2 :                    _width, _height, _numOfScales, _numOfLayers,
      92           2 :                    isSpectraComputed() ? "precomputed" : "non-precomputed");
      93             : 
      94           2 :         if (!isSpectraComputed()) {
      95           2 :             computeSpectra();
      96           2 :         }
      97             : 
      98           2 :         FourierTransform<elsa::complex<data_t>> fourierTransform(x.getDataDescriptor());
      99             : 
     100           2 :         DataContainer<elsa::complex<data_t>> fftImg = fourierTransform.apply(x.asComplex());
     101             : 
     102         124 :         for (index_t i = 0; i < getNumOfLayers(); i++) {
     103         122 :             DataContainer<elsa::complex<data_t>> temp =
     104         122 :                 getSpectra().slice(i).viewAs(x.getDataDescriptor()).asComplex() * fftImg;
     105         122 :             if constexpr (isComplex<ret_t>) {
     106           0 :                 Ax.slice(i) = fourierTransform.applyAdjoint(temp);
     107           0 :             } else {
     108           0 :                 Ax.slice(i) = real(fourierTransform.applyAdjoint(temp));
     109           0 :             }
     110         122 :         }
     111           2 :     }
     112             : 
     113             :     template <typename ret_t, typename data_t>
     114             :     void ShearletTransform<ret_t, data_t>::applyAdjointImpl(const DataContainer<ret_t>& y,
     115             :                                                             DataContainer<ret_t>& Aty) const
     116           2 :     {
     117           2 :         Timer timeguard("ShearletTransform", "applyAdjoint");
     118             : 
     119           2 :         if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
     120           2 :             || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
     121           0 :             throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
     122           0 :                                        "match to that of this shearlet system");
     123           0 :         }
     124             : 
     125           2 :         Logger::get("ShearletTransform")
     126           2 :             ->info("Running the inverse shearlet transform on a 2D signal of shape ({}, {}), "
     127           2 :                    "on {} "
     128           2 :                    "scales with an oversampling factor of {} and {} spectra",
     129           2 :                    _width, _height, _numOfScales, _numOfLayers,
     130           2 :                    isSpectraComputed() ? "precomputed" : "non-precomputed");
     131             : 
     132           2 :         if (!isSpectraComputed()) {
     133           0 :             computeSpectra();
     134           0 :         }
     135             : 
     136           2 :         FourierTransform<elsa::complex<data_t>> fourierTransform(Aty.getDataDescriptor());
     137             : 
     138           2 :         DataContainer<elsa::complex<data_t>> intermRes(y.getDataDescriptor());
     139             : 
     140         124 :         for (index_t i = 0; i < getNumOfLayers(); i++) {
     141         122 :             DataContainer<elsa::complex<data_t>> temp =
     142         122 :                 fourierTransform.apply(y.slice(i).viewAs(Aty.getDataDescriptor()).asComplex())
     143         122 :                 * getSpectra().slice(i).viewAs(Aty.getDataDescriptor()).asComplex();
     144         122 :             intermRes.slice(i) = fourierTransform.applyAdjoint(temp);
     145         122 :         }
     146             : 
     147           2 :         if constexpr (isComplex<ret_t>) {
     148           0 :             Aty = sumByLastAxis(intermRes);
     149           0 :         } else {
     150           0 :             Aty = real(sumByLastAxis(intermRes));
     151           0 :         }
     152           2 :     }
     153             : 
     154             :     template <typename ret_t, typename data_t>
     155             :     void ShearletTransform<ret_t, data_t>::computeSpectra() const
     156           6 :     {
     157           6 :         if (isSpectraComputed()) {
     158           0 :             Logger::get("ShearletTransform")->warn("Spectra have already been computed!");
     159           0 :         }
     160             : 
     161           6 :         _spectra = DataContainer<data_t>(VolumeDescriptor{{_width, _height, _numOfLayers}});
     162             : 
     163           6 :         _computeSpectraAtLowFreq();
     164             : 
     165          30 :         for (index_t j = 0; j < _numOfScales; j++) {
     166          24 :             auto twoPowJ = static_cast<index_t>(std::pow(2, j));
     167          24 :             auto shearletsAtJ = static_cast<index_t>(std::pow(2, j + 2));
     168          24 :             index_t shearletsUpUntilJ = shearletsAtJ - 3;
     169          24 :             index_t index = 1;
     170             : 
     171          24 :             _computeSpectraAtSeamLines(j, -twoPowJ, shearletsUpUntilJ + twoPowJ);
     172         180 :             for (auto k = -twoPowJ + 1; k < twoPowJ; k++) {
     173             :                 // modulo instead of remainder for negative numbers is needed here, therefore doing
     174             :                 // "((a % b) + b) % b" instead of "a % b"
     175         156 :                 index_t modIndex =
     176         156 :                     (((twoPowJ - index + 1) % shearletsAtJ) + shearletsAtJ) % shearletsAtJ;
     177         156 :                 if (modIndex == 0) {
     178          18 :                     modIndex = shearletsAtJ - 1;
     179         138 :                 } else {
     180         138 :                     --modIndex;
     181         138 :                 }
     182             : 
     183         156 :                 _computeSpectraAtConicRegions(j, k, shearletsUpUntilJ + modIndex,
     184         156 :                                               shearletsUpUntilJ + twoPowJ + index);
     185         156 :                 ++index;
     186         156 :             }
     187          24 :             _computeSpectraAtSeamLines(j, twoPowJ, shearletsUpUntilJ + twoPowJ + index);
     188          24 :         }
     189           6 :     }
     190             : 
     191             :     template <typename ret_t, typename data_t>
     192             :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtLowFreq() const
     193           6 :     {
     194           6 :         DataContainer<data_t> sectionZero(VolumeDescriptor{{_width, _height}});
     195           6 :         sectionZero = 0;
     196             : 
     197           6 :         auto shape = getShapeFractions();
     198             : 
     199             :         // TODO attempt to refactor the negative indexing
     200         166 :         for (auto w = shape.negativeHalfWidth; w < shape.halfWidth; w++) {
     201        4768 :             for (auto h = shape.negativeHalfHeight; h < shape.halfHeight; h++) {
     202        4608 :                 sectionZero(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     203        4608 :                     shearlet::phiHat<data_t>(static_cast<data_t>(w), static_cast<data_t>(h));
     204        4608 :             }
     205         160 :         }
     206             : 
     207           6 :         _spectra.value().slice(0) = sectionZero;
     208           6 :     }
     209             : 
     210             :     template <typename ret_t, typename data_t>
     211             :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtConicRegions(index_t j, index_t k,
     212             :                                                                          index_t hSliceIndex,
     213             :                                                                          index_t vSliceIndex) const
     214         156 :     {
     215         156 :         DataContainer<data_t> sectionh(VolumeDescriptor{{_width, _height}});
     216         156 :         sectionh = 0;
     217         156 :         DataContainer<data_t> sectionv(VolumeDescriptor{{_width, _height}});
     218         156 :         sectionv = 0;
     219             : 
     220         156 :         auto shape = getShapeFractions();
     221         156 :         auto jr = static_cast<data_t>(j);
     222         156 :         auto kr = static_cast<data_t>(k);
     223             : 
     224             :         // TODO attempt to refactor the negative indexing
     225        4316 :         for (auto w = shape.negativeHalfWidth; w < shape.halfWidth; w++) {
     226        4160 :             auto wr = static_cast<data_t>(w);
     227      123968 :             for (auto h = shape.negativeHalfHeight; h < shape.halfHeight; h++) {
     228      119808 :                 auto hr = static_cast<data_t>(h);
     229      119808 :                 if (std::abs(h) <= std::abs(w)) {
     230       63908 :                     sectionh(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     231       63908 :                         shearlet::psiHat<data_t>(std::pow(4.f, -jr) * wr,
     232       63908 :                                                  std::pow(4.f, -jr) * kr * wr
     233       63908 :                                                      + std::pow(2.f, -jr) * hr);
     234       63908 :                 } else {
     235       55900 :                     sectionv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     236       55900 :                         shearlet::psiHat<data_t>(std::pow(4.f, -jr) * hr,
     237       55900 :                                                  std::pow(4.f, -jr) * kr * hr
     238       55900 :                                                      + std::pow(2.f, -jr) * wr);
     239       55900 :                 }
     240      119808 :             }
     241        4160 :         }
     242             : 
     243         156 :         _spectra.value().slice(hSliceIndex) = sectionh;
     244         156 :         _spectra.value().slice(vSliceIndex) = sectionv;
     245         156 :     }
     246             : 
     247             :     template <typename ret_t, typename data_t>
     248             :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtSeamLines(index_t j, index_t k,
     249             :                                                                       index_t hxvSliceIndex) const
     250          48 :     {
     251          48 :         DataContainer<data_t> sectionhxv(VolumeDescriptor{{_width, _height}});
     252          48 :         sectionhxv = 0;
     253             : 
     254          48 :         auto shape = getShapeFractions();
     255          48 :         auto jr = static_cast<data_t>(j);
     256          48 :         auto kr = static_cast<data_t>(k);
     257             : 
     258             :         // TODO attempt to refactor the negative indexing
     259        1328 :         for (auto w = shape.negativeHalfWidth; w < shape.halfWidth; w++) {
     260        1280 :             auto wr = static_cast<data_t>(w);
     261       38144 :             for (auto h = shape.negativeHalfHeight; h < shape.halfHeight; h++) {
     262       36864 :                 auto hr = static_cast<data_t>(h);
     263       36864 :                 if (std::abs(h) <= std::abs(w)) {
     264       19664 :                     sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     265       19664 :                         shearlet::psiHat<data_t>(std::pow(4.f, -jr) * wr,
     266       19664 :                                                  std::pow(4.f, -jr) * kr * wr
     267       19664 :                                                      + std::pow(2.f, -jr) * hr);
     268       19664 :                 } else {
     269       17200 :                     sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     270       17200 :                         shearlet::psiHat<data_t>(std::pow(4.f, -jr) * hr,
     271       17200 :                                                  std::pow(4.f, -jr) * kr * hr
     272       17200 :                                                      + std::pow(2.f, -jr) * wr);
     273       17200 :                 }
     274       36864 :             }
     275        1280 :         }
     276             : 
     277          48 :         _spectra.value().slice(hxvSliceIndex) = sectionhxv;
     278          48 :     }
     279             : 
     280             :     /**
     281             :      * helper function to calculate input data fractions.
     282             :      */
     283             :     template <typename ret_t, typename data_t>
     284             :     auto ShearletTransform<ret_t, data_t>::getShapeFractions() const -> shape_fractions
     285         210 :     {
     286         210 :         shape_fractions ret;
     287         210 :         auto width = static_cast<real_t>(_width);
     288         210 :         auto height = static_cast<real_t>(_height);
     289             : 
     290         210 :         ret.negativeHalfWidth = static_cast<index_t>(-std::floor(width / 2.0));
     291         210 :         ret.halfWidth = static_cast<index_t>(std::ceil(width / 2.0));
     292         210 :         ret.negativeHalfHeight = static_cast<index_t>(-std::floor(height / 2.0));
     293         210 :         ret.halfHeight = static_cast<index_t>(std::ceil(height / 2.0));
     294             : 
     295         210 :         return ret;
     296         210 :     }
     297             : 
     298             :     template <typename ret_t, typename data_t>
     299             :     auto ShearletTransform<ret_t, data_t>::getSpectra() const -> DataContainer<data_t>
     300         248 :     {
     301         248 :         if (!_spectra.has_value()) {
     302           2 :             throw LogicError(std::string("ShearletTransform: the spectra is not yet computed"));
     303           2 :         }
     304         246 :         return _spectra.value();
     305         246 :     }
     306             : 
     307             :     template <typename ret_t, typename data_t>
     308             :     bool ShearletTransform<ret_t, data_t>::isSpectraComputed() const
     309          16 :     {
     310          16 :         return _spectra.has_value();
     311          16 :     }
     312             : 
     313             :     template <typename ret_t, typename data_t>
     314             :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfScales(index_t width, index_t height)
     315           6 :     {
     316           6 :         return static_cast<index_t>(std::log2(std::max(width, height)) / 2.0);
     317           6 :     }
     318             : 
     319             :     template <typename ret_t, typename data_t>
     320             :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t width, index_t height)
     321           0 :     {
     322           0 :         return static_cast<index_t>(std::pow(2, (calculateNumOfScales(width, height) + 2)) - 3);
     323           0 :     }
     324             : 
     325             :     template <typename ret_t, typename data_t>
     326             :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t numOfScales)
     327          28 :     {
     328          28 :         return static_cast<index_t>(std::pow(2, numOfScales + 2) - 3);
     329          28 :     }
     330             : 
     331             :     template <typename ret_t, typename data_t>
     332             :     auto ShearletTransform<ret_t, data_t>::getWidth() const -> index_t
     333           2 :     {
     334           2 :         return _width;
     335           2 :     }
     336             : 
     337             :     template <typename ret_t, typename data_t>
     338             :     auto ShearletTransform<ret_t, data_t>::getHeight() const -> index_t
     339           2 :     {
     340           2 :         return _height;
     341           2 :     }
     342             : 
     343             :     template <typename ret_t, typename data_t>
     344             :     auto ShearletTransform<ret_t, data_t>::getNumOfLayers() const -> index_t
     345         250 :     {
     346         250 :         return _numOfLayers;
     347         250 :     }
     348             : 
     349             :     template <typename ret_t, typename data_t>
     350             :     ShearletTransform<ret_t, data_t>* ShearletTransform<ret_t, data_t>::cloneImpl() const
     351           2 :     {
     352           2 :         return new ShearletTransform<ret_t, data_t>(_width, _height, _numOfScales, _spectra);
     353           2 :     }
     354             : 
     355             :     template <typename ret_t, typename data_t>
     356             :     bool ShearletTransform<ret_t, data_t>::isEqual(const LinearOperator<ret_t>& other) const
     357           2 :     {
     358           2 :         if (!LinearOperator<ret_t>::isEqual(other))
     359           0 :             return false;
     360             : 
     361           2 :         auto otherST = downcast_safe<ShearletTransform<ret_t, data_t>>(&other);
     362             : 
     363           2 :         if (!otherST)
     364           0 :             return false;
     365             : 
     366           2 :         if (_width != otherST->_width)
     367           0 :             return false;
     368             : 
     369           2 :         if (_height != otherST->_height)
     370           0 :             return false;
     371             : 
     372           2 :         if (_numOfScales != otherST->_numOfScales)
     373           0 :             return false;
     374             : 
     375           2 :         return true;
     376           2 :     }
     377             : 
     378             :     // ------------------------------------------
     379             :     // explicit template instantiation
     380             :     template class ShearletTransform<float, float>;
     381             :     template class ShearletTransform<elsa::complex<float>, float>;
     382             :     template class ShearletTransform<double, double>;
     383             :     template class ShearletTransform<elsa::complex<double>, double>;
     384             : } // namespace elsa

Generated by: LCOV version 1.14