LCOV - code coverage report
Current view: top level - elsa/operators - ShearletTransform.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 192 224 85.7 %
Date: 2022-08-25 03:05:39 Functions: 60 84 71.4 %

          Line data    Source code
#include "ShearletTransform.h"

#include <algorithm>
#include <cmath>
#include <optional>
#include <utility>

#include "FourierTransform.h"
#include "VolumeDescriptor.h"
#include "Timer.h"
#include "Math.hpp"
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename ret_t, typename data_t>
      10             :     ShearletTransform<ret_t, data_t>::ShearletTransform(IndexVector_t spatialDimensions)
      11             :         : ShearletTransform(spatialDimensions[0], spatialDimensions[1])
      12           2 :     {
      13           2 :         if (spatialDimensions.size() != 2) {
      14           0 :             throw LogicError("ShearletTransform: Only 2D shape supported");
      15           0 :         }
      16           2 :     }
      17             : 
      18             :     template <typename ret_t, typename data_t>
      19             :     ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height)
      20             :         : ShearletTransform(width, height, calculateNumOfScales(width, height))
      21           6 :     {
      22           6 :     }
      23             : 
      24             :     template <typename ret_t, typename data_t>
      25             :     ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height,
      26             :                                                         index_t numOfScales)
      27             :         : ShearletTransform(width, height, numOfScales, std::nullopt)
      28          12 :     {
      29          12 :     }
      30             : 
      31             :     template <typename ret_t, typename data_t>
      32             :     ShearletTransform<ret_t, data_t>::ShearletTransform(
      33             :         index_t width, index_t height, index_t numOfScales,
      34             :         std::optional<DataContainer<data_t>> spectra)
      35             :         : LinearOperator<ret_t>(
      36             :             VolumeDescriptor{{width, height}},
      37             :             VolumeDescriptor{{width, height, calculateNumOfLayers(numOfScales)}}),
      38             :           _spectra{spectra},
      39             :           _width{width},
      40             :           _height{height},
      41             :           _numOfScales{numOfScales},
      42             :           _numOfLayers{calculateNumOfLayers(numOfScales)}
      43          14 :     {
      44          14 :         if (width < 0 || height < 0) {
      45           0 :             throw LogicError("ShearletTransform: negative width/height were provided");
      46           0 :         }
      47          14 :         if (numOfScales < 0) {
      48           0 :             throw LogicError("ShearletTransform: negative number of scales was provided");
      49           0 :         }
      50          14 :     }
      51             : 
      52             :     // TODO implement sumByAxis in DataContainer and remove me
      53             :     template <typename ret_t, typename data_t>
      54             :     DataContainer<elsa::complex<data_t>> ShearletTransform<ret_t, data_t>::sumByLastAxis(
      55             :         DataContainer<elsa::complex<data_t>> dc) const
      56           2 :     {
      57           2 :         auto coeffsPerDim = dc.getDataDescriptor().getNumberOfCoefficientsPerDimension();
      58           2 :         index_t width = coeffsPerDim[0];
      59           2 :         index_t height = coeffsPerDim[1];
      60           2 :         index_t layers = coeffsPerDim[2];
      61           2 :         DataContainer<elsa::complex<data_t>> summedDC(VolumeDescriptor{{width, height}});
      62             : 
      63          66 :         for (index_t j = 0; j < width; j++) {
      64        2112 :             for (index_t k = 0; k < height; k++) {
      65        2048 :                 elsa::complex<data_t> currValue = 0;
      66      126976 :                 for (index_t i = 0; i < layers; i++) {
      67      124928 :                     currValue += dc(j, k, i);
      68      124928 :                 }
      69        2048 :                 summedDC(j, k) = currValue;
      70        2048 :             }
      71          64 :         }
      72             : 
      73           2 :         return summedDC;
      74           2 :     }
      75             : 
      76             :     template <typename ret_t, typename data_t>
      77             :     void ShearletTransform<ret_t, data_t>::applyImpl(const DataContainer<ret_t>& x,
      78             :                                                      DataContainer<ret_t>& Ax) const
      79           2 :     {
      80           2 :         Timer timeguard("ShearletTransform", "apply");
      81             : 
      82           2 :         if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
      83           2 :             || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
      84           0 :             throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
      85           0 :                                        "match to that of this shearlet system");
      86           0 :         }
      87             : 
      88           2 :         Logger::get("ShearletTransform")
      89           2 :             ->info("Running the shearlet transform on a 2D signal of shape ({}, {}), on {} "
      90           2 :                    "scales with an oversampling factor of {} and {} spectra",
      91           2 :                    _width, _height, _numOfScales, _numOfLayers,
      92           2 :                    isSpectraComputed() ? "precomputed" : "non-precomputed");
      93             : 
      94           2 :         if (!isSpectraComputed()) {
      95           2 :             computeSpectra();
      96           2 :         }
      97             : 
      98           2 :         FourierTransform<elsa::complex<data_t>> fourierTransform(x.getDataDescriptor());
      99             : 
     100           2 :         DataContainer<elsa::complex<data_t>> fftImg = fourierTransform.apply(x.asComplex());
     101             : 
     102         124 :         for (index_t i = 0; i < getNumOfLayers(); i++) {
     103         122 :             DataContainer<elsa::complex<data_t>> temp =
     104         122 :                 getSpectra().slice(i).viewAs(x.getDataDescriptor()).asComplex() * fftImg;
     105         122 :             if constexpr (isComplex<ret_t>) {
     106           0 :                 Ax.slice(i) = fourierTransform.applyAdjoint(temp);
     107           0 :             } else {
     108           0 :                 Ax.slice(i) = real(fourierTransform.applyAdjoint(temp));
     109           0 :             }
     110         122 :         }
     111           2 :     }
     112             : 
     113             :     template <typename ret_t, typename data_t>
     114             :     void ShearletTransform<ret_t, data_t>::applyAdjointImpl(const DataContainer<ret_t>& y,
     115             :                                                             DataContainer<ret_t>& Aty) const
     116           2 :     {
     117           2 :         Timer timeguard("ShearletTransform", "applyAdjoint");
     118             : 
     119           2 :         if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
     120           2 :             || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
     121           0 :             throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
     122           0 :                                        "match to that of this shearlet system");
     123           0 :         }
     124             : 
     125           2 :         Logger::get("ShearletTransform")
     126           2 :             ->info("Running the inverse shearlet transform on a 2D signal of shape ({}, {}), "
     127           2 :                    "on {} "
     128           2 :                    "scales with an oversampling factor of {} and {} spectra",
     129           2 :                    _width, _height, _numOfScales, _numOfLayers,
     130           2 :                    isSpectraComputed() ? "precomputed" : "non-precomputed");
     131             : 
     132           2 :         if (!isSpectraComputed()) {
     133           0 :             computeSpectra();
     134           0 :         }
     135             : 
     136           2 :         FourierTransform<elsa::complex<data_t>> fourierTransform(Aty.getDataDescriptor());
     137             : 
     138           2 :         DataContainer<elsa::complex<data_t>> intermRes(y.getDataDescriptor());
     139             : 
     140         124 :         for (index_t i = 0; i < getNumOfLayers(); i++) {
     141         122 :             DataContainer<elsa::complex<data_t>> temp =
     142         122 :                 fourierTransform.apply(y.slice(i).viewAs(Aty.getDataDescriptor()).asComplex())
     143         122 :                 * getSpectra().slice(i).viewAs(Aty.getDataDescriptor()).asComplex();
     144         122 :             intermRes.slice(i) = fourierTransform.applyAdjoint(temp);
     145         122 :         }
     146             : 
     147           2 :         if constexpr (isComplex<ret_t>) {
     148           0 :             Aty = sumByLastAxis(intermRes);
     149           0 :         } else {
     150           0 :             Aty = real(sumByLastAxis(intermRes));
     151           0 :         }
     152           2 :     }
     153             : 
     154             :     template <typename ret_t, typename data_t>
     155             :     void ShearletTransform<ret_t, data_t>::computeSpectra() const
     156           6 :     {
     157           6 :         if (isSpectraComputed()) {
     158           0 :             Logger::get("ShearletTransform")->warn("Spectra have already been computed!");
     159           0 :         }
     160             : 
     161           6 :         _spectra = DataContainer<data_t>(VolumeDescriptor{{_width, _height, _numOfLayers}});
     162             : 
     163           6 :         _computeSpectraAtLowFreq();
     164             : 
     165          30 :         for (index_t j = 0; j < _numOfScales; j++) {
     166          24 :             auto twoPowJ = static_cast<index_t>(std::pow(2, j));
     167          24 :             auto shearletsAtJ = static_cast<index_t>(std::pow(2, j + 2));
     168          24 :             index_t shearletsUpUntilJ = shearletsAtJ - 3;
     169          24 :             index_t index = 1;
     170             : 
     171          24 :             _computeSpectraAtSeamLines(j, -twoPowJ, shearletsUpUntilJ + twoPowJ);
     172         180 :             for (auto k = -twoPowJ + 1; k < twoPowJ; k++) {
     173             :                 // modulo instead of remainder for negative numbers is needed here, therefore doing
     174             :                 // "((a % b) + b) % b" instead of "a % b"
     175         156 :                 index_t modIndex =
     176         156 :                     (((twoPowJ - index + 1) % shearletsAtJ) + shearletsAtJ) % shearletsAtJ;
     177         156 :                 if (modIndex == 0) {
     178          18 :                     modIndex = shearletsAtJ - 1;
     179         138 :                 } else {
     180         138 :                     --modIndex;
     181         138 :                 }
     182             : 
     183         156 :                 _computeSpectraAtConicRegions(j, k, shearletsUpUntilJ + modIndex,
     184         156 :                                               shearletsUpUntilJ + twoPowJ + index);
     185         156 :                 ++index;
     186         156 :             }
     187          24 :             _computeSpectraAtSeamLines(j, twoPowJ, shearletsUpUntilJ + twoPowJ + index);
     188          24 :         }
     189           6 :     }
     190             : 
     191             :     template <typename ret_t, typename data_t>
     192             :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtLowFreq() const
     193           6 :     {
     194           6 :         DataContainer<data_t> sectionZero(VolumeDescriptor{{_width, _height}});
     195           6 :         sectionZero = 0;
     196             : 
     197           6 :         auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
     198           6 :         auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
     199           6 :         auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
     200           6 :         auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
     201             : 
     202             :         // TODO attempt to refactor the negative indexing
     203         198 :         for (auto w = negativeHalfWidth; w < halfWidth; w++) {
     204        6336 :             for (auto h = negativeHalfHeight; h < halfHeight; h++) {
     205        6144 :                 sectionZero(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     206        6144 :                     shearlet::phiHat<data_t>(static_cast<data_t>(w), static_cast<data_t>(h));
     207        6144 :             }
     208         192 :         }
     209             : 
     210           6 :         _spectra.value().slice(0) = sectionZero;
     211           6 :     }
     212             : 
     213             :     template <typename ret_t, typename data_t>
     214             :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtConicRegions(index_t j, index_t k,
     215             :                                                                          index_t hSliceIndex,
     216             :                                                                          index_t vSliceIndex) const
     217         156 :     {
     218         156 :         DataContainer<data_t> sectionh(VolumeDescriptor{{_width, _height}});
     219         156 :         sectionh = 0;
     220         156 :         DataContainer<data_t> sectionv(VolumeDescriptor{{_width, _height}});
     221         156 :         sectionv = 0;
     222             : 
     223         156 :         auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
     224         156 :         auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
     225         156 :         auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
     226         156 :         auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
     227             : 
     228             :         // TODO attempt to refactor the negative indexing
     229        5148 :         for (auto w = negativeHalfWidth; w < halfWidth; w++) {
     230      164736 :             for (auto h = negativeHalfHeight; h < halfHeight; h++) {
     231      159744 :                 if (std::abs(h) <= std::abs(w)) {
     232       84708 :                     sectionh(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     233       84708 :                         shearlet::psiHat<data_t>(std::pow(4, -j) * w,
     234       84708 :                                                  std::pow(4, -j) * k * w + std::pow(2, -j) * h);
     235       84708 :                 } else {
     236       75036 :                     sectionv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     237       75036 :                         shearlet::psiHat<data_t>(std::pow(4, -j) * h,
     238       75036 :                                                  std::pow(4, -j) * k * h + std::pow(2, -j) * w);
     239       75036 :                 }
     240      159744 :             }
     241        4992 :         }
     242             : 
     243         156 :         _spectra.value().slice(hSliceIndex) = sectionh;
     244         156 :         _spectra.value().slice(vSliceIndex) = sectionv;
     245         156 :     }
     246             : 
     247             :     template <typename ret_t, typename data_t>
     248             :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtSeamLines(index_t j, index_t k,
     249             :                                                                       index_t hxvSliceIndex) const
     250          48 :     {
     251          48 :         DataContainer<data_t> sectionhxv(VolumeDescriptor{{_width, _height}});
     252          48 :         sectionhxv = 0;
     253             : 
     254          48 :         auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
     255          48 :         auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
     256          48 :         auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
     257          48 :         auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
     258             : 
     259             :         // TODO attempt to refactor the negative indexing
     260        1584 :         for (auto w = negativeHalfWidth; w < halfWidth; w++) {
     261       50688 :             for (auto h = negativeHalfHeight; h < halfHeight; h++) {
     262       49152 :                 if (std::abs(h) <= std::abs(w)) {
     263       26064 :                     sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     264       26064 :                         shearlet::psiHat<data_t>(std::pow(4, -j) * w,
     265       26064 :                                                  std::pow(4, -j) * k * w + std::pow(2, -j) * h);
     266       26064 :                 } else {
     267       23088 :                     sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     268       23088 :                         shearlet::psiHat<data_t>(std::pow(4, -j) * h,
     269       23088 :                                                  std::pow(4, -j) * k * h + std::pow(2, -j) * w);
     270       23088 :                 }
     271       49152 :             }
     272        1536 :         }
     273             : 
     274          48 :         _spectra.value().slice(hxvSliceIndex) = sectionhxv;
     275          48 :     }
     276             : 
     277             :     template <typename ret_t, typename data_t>
     278             :     auto ShearletTransform<ret_t, data_t>::getSpectra() const -> DataContainer<data_t>
     279         248 :     {
     280         248 :         if (!_spectra.has_value()) {
     281           2 :             throw LogicError(std::string("ShearletTransform: the spectra is not yet computed"));
     282           2 :         }
     283         246 :         return _spectra.value();
     284         246 :     }
     285             : 
     286             :     template <typename ret_t, typename data_t>
     287             :     bool ShearletTransform<ret_t, data_t>::isSpectraComputed() const
     288          16 :     {
     289          16 :         return _spectra.has_value();
     290          16 :     }
     291             : 
     292             :     template <typename ret_t, typename data_t>
     293             :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfScales(index_t width, index_t height)
     294           6 :     {
     295           6 :         return static_cast<index_t>(std::log2(std::max(width, height)) / 2.0);
     296           6 :     }
     297             : 
     298             :     template <typename ret_t, typename data_t>
     299             :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t width, index_t height)
     300           0 :     {
     301           0 :         return static_cast<index_t>(std::pow(2, (calculateNumOfScales(width, height) + 2)) - 3);
     302           0 :     }
     303             : 
     304             :     template <typename ret_t, typename data_t>
     305             :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t numOfScales)
     306          28 :     {
     307          28 :         return static_cast<index_t>(std::pow(2, numOfScales + 2) - 3);
     308          28 :     }
     309             : 
     310             :     template <typename ret_t, typename data_t>
     311             :     auto ShearletTransform<ret_t, data_t>::getWidth() const -> index_t
     312           2 :     {
     313           2 :         return _width;
     314           2 :     }
     315             : 
     316             :     template <typename ret_t, typename data_t>
     317             :     auto ShearletTransform<ret_t, data_t>::getHeight() const -> index_t
     318           2 :     {
     319           2 :         return _height;
     320           2 :     }
     321             : 
     322             :     template <typename ret_t, typename data_t>
     323             :     auto ShearletTransform<ret_t, data_t>::getNumOfLayers() const -> index_t
     324         250 :     {
     325         250 :         return _numOfLayers;
     326         250 :     }
     327             : 
     328             :     template <typename ret_t, typename data_t>
     329             :     ShearletTransform<ret_t, data_t>* ShearletTransform<ret_t, data_t>::cloneImpl() const
     330           2 :     {
     331           2 :         return new ShearletTransform<ret_t, data_t>(_width, _height, _numOfScales, _spectra);
     332           2 :     }
     333             : 
     334             :     template <typename ret_t, typename data_t>
     335             :     bool ShearletTransform<ret_t, data_t>::isEqual(const LinearOperator<ret_t>& other) const
     336           2 :     {
     337           2 :         if (!LinearOperator<ret_t>::isEqual(other))
     338           0 :             return false;
     339             : 
     340           2 :         auto otherST = downcast_safe<ShearletTransform<ret_t, data_t>>(&other);
     341             : 
     342           2 :         if (!otherST)
     343           0 :             return false;
     344             : 
     345           2 :         if (_width != otherST->_width)
     346           0 :             return false;
     347             : 
     348           2 :         if (_height != otherST->_height)
     349           0 :             return false;
     350             : 
     351           2 :         if (_numOfScales != otherST->_numOfScales)
     352           0 :             return false;
     353             : 
     354           2 :         return true;
     355           2 :     }
     356             : 
     357             :     // ------------------------------------------
     358             :     // explicit template instantiation
     359             :     template class ShearletTransform<float, float>;
     360             :     template class ShearletTransform<elsa::complex<float>, float>;
     361             :     template class ShearletTransform<double, double>;
     362             :     template class ShearletTransform<elsa::complex<double>, double>;
     363             : } // namespace elsa

Generated by: LCOV version 1.14