LCOV - code coverage report
Current view: top level - operators - ShearletTransform.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 180 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 84 0.0 %

          Line data    Source code
       1             : #include "ShearletTransform.h"
       2             : #include "FourierTransform.h"
       3             : #include "VolumeDescriptor.h"
       4             : #include "Timer.h"
       5             : #include "Math.hpp"
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename ret_t, typename data_t>
      10           0 :     ShearletTransform<ret_t, data_t>::ShearletTransform(IndexVector_t spatialDimensions)
      11           0 :         : ShearletTransform(spatialDimensions[0], spatialDimensions[1])
      12             :     {
      13           0 :         if (spatialDimensions.size() != 2) {
      14           0 :             throw LogicError("ShearletTransform: Only 2D shape supported");
      15             :         }
      16           0 :     }
      17             : 
      18             :     template <typename ret_t, typename data_t>
      19           0 :     ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height)
      20           0 :         : ShearletTransform(width, height, calculateNumOfScales(width, height))
      21             :     {
      22           0 :     }
      23             : 
      24             :     template <typename ret_t, typename data_t>
      25           0 :     ShearletTransform<ret_t, data_t>::ShearletTransform(index_t width, index_t height,
      26             :                                                         index_t numOfScales)
      27           0 :         : ShearletTransform(width, height, numOfScales, std::nullopt)
      28             :     {
      29           0 :     }
      30             : 
      31             :     template <typename ret_t, typename data_t>
      32           0 :     ShearletTransform<ret_t, data_t>::ShearletTransform(
      33             :         index_t width, index_t height, index_t numOfScales,
      34             :         std::optional<DataContainer<data_t>> spectra)
      35             :         : LinearOperator<ret_t>(
      36           0 :             VolumeDescriptor{{width, height}},
      37           0 :             VolumeDescriptor{{width, height, calculateNumOfLayers(numOfScales)}}),
      38           0 :           _spectra{spectra},
      39           0 :           _width{width},
      40           0 :           _height{height},
      41           0 :           _numOfScales{numOfScales},
      42           0 :           _numOfLayers{calculateNumOfLayers(numOfScales)}
      43             :     {
      44           0 :         if (width < 0 || height < 0) {
      45           0 :             throw LogicError("ShearletTransform: negative width/height were provided");
      46             :         }
      47           0 :         if (numOfScales < 0) {
      48           0 :             throw LogicError("ShearletTransform: negative number of scales was provided");
      49             :         }
      50           0 :     }
      51             : 
      52             :     // TODO implement sumByAxis in DataContainer and remove me
      53             :     template <typename ret_t, typename data_t>
      54           0 :     DataContainer<elsa::complex<data_t>> ShearletTransform<ret_t, data_t>::sumByLastAxis(
      55             :         DataContainer<elsa::complex<data_t>> dc) const
      56             :     {
      57           0 :         auto coeffsPerDim = dc.getDataDescriptor().getNumberOfCoefficientsPerDimension();
      58           0 :         index_t width = coeffsPerDim[0];
      59           0 :         index_t height = coeffsPerDim[1];
      60           0 :         index_t layers = coeffsPerDim[2];
      61           0 :         DataContainer<elsa::complex<data_t>> summedDC(VolumeDescriptor{{width, height}});
      62             : 
      63           0 :         for (index_t j = 0; j < width; j++) {
      64           0 :             for (index_t k = 0; k < height; k++) {
      65           0 :                 elsa::complex<data_t> currValue = 0;
      66           0 :                 for (index_t i = 0; i < layers; i++) {
      67           0 :                     currValue += dc(j, k, i);
      68             :                 }
      69           0 :                 summedDC(j, k) = currValue;
      70             :             }
      71             :         }
      72             : 
      73           0 :         return summedDC;
      74           0 :     }
      75             : 
      76             :     template <typename ret_t, typename data_t>
      77           0 :     void ShearletTransform<ret_t, data_t>::applyImpl(const DataContainer<ret_t>& x,
      78             :                                                      DataContainer<ret_t>& Ax) const
      79             :     {
      80           0 :         Timer timeguard("ShearletTransform", "apply");
      81             : 
      82           0 :         if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
      83           0 :             || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
      84           0 :             throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
      85             :                                        "match to that of this shearlet system");
      86             :         }
      87             : 
      88             :         Logger::get("ShearletTransform")
      89           0 :             ->info("Running the shearlet transform on a 2D signal of shape ({}, {}), on {} "
      90             :                    "scales with an oversampling factor of {} and {} spectra",
      91           0 :                    _width, _height, _numOfScales, _numOfLayers,
      92           0 :                    isSpectraComputed() ? "precomputed" : "non-precomputed");
      93             : 
      94           0 :         if (!isSpectraComputed()) {
      95           0 :             computeSpectra();
      96             :         }
      97             : 
      98           0 :         FourierTransform<elsa::complex<data_t>> fourierTransform(x.getDataDescriptor());
      99             : 
     100           0 :         DataContainer<elsa::complex<data_t>> fftImg = fourierTransform.apply(x.asComplex());
     101             : 
     102           0 :         for (index_t i = 0; i < getNumOfLayers(); i++) {
     103           0 :             DataContainer<elsa::complex<data_t>> temp =
     104             :                 getSpectra().slice(i).viewAs(x.getDataDescriptor()).asComplex() * fftImg;
     105             :             if constexpr (isComplex<ret_t>) {
     106           0 :                 Ax.slice(i) = fourierTransform.applyAdjoint(temp);
     107             :             } else {
     108           0 :                 Ax.slice(i) = real(fourierTransform.applyAdjoint(temp));
     109             :             }
     110             :         }
     111           0 :     }
     112             : 
     113             :     template <typename ret_t, typename data_t>
     114           0 :     void ShearletTransform<ret_t, data_t>::applyAdjointImpl(const DataContainer<ret_t>& y,
     115             :                                                             DataContainer<ret_t>& Aty) const
     116             :     {
     117           0 :         Timer timeguard("ShearletTransform", "applyAdjoint");
     118             : 
     119           0 :         if (_width != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[0]
     120           0 :             || _height != this->getDomainDescriptor().getNumberOfCoefficientsPerDimension()[1]) {
     121           0 :             throw InvalidArgumentError("ShearletTransform: Width and height of the input do not "
     122             :                                        "match to that of this shearlet system");
     123             :         }
     124             : 
     125             :         Logger::get("ShearletTransform")
     126           0 :             ->info("Running the inverse shearlet transform on a 2D signal of shape ({}, {}), "
     127             :                    "on {} "
     128             :                    "scales with an oversampling factor of {} and {} spectra",
     129           0 :                    _width, _height, _numOfScales, _numOfLayers,
     130           0 :                    isSpectraComputed() ? "precomputed" : "non-precomputed");
     131             : 
     132           0 :         if (!isSpectraComputed()) {
     133           0 :             computeSpectra();
     134             :         }
     135             : 
     136           0 :         FourierTransform<elsa::complex<data_t>> fourierTransform(Aty.getDataDescriptor());
     137             : 
     138           0 :         DataContainer<elsa::complex<data_t>> intermRes(y.getDataDescriptor());
     139             : 
     140           0 :         for (index_t i = 0; i < getNumOfLayers(); i++) {
     141           0 :             DataContainer<elsa::complex<data_t>> temp =
     142             :                 fourierTransform.apply(y.slice(i).viewAs(Aty.getDataDescriptor()).asComplex())
     143             :                 * getSpectra().slice(i).viewAs(Aty.getDataDescriptor()).asComplex();
     144           0 :             intermRes.slice(i) = fourierTransform.applyAdjoint(temp);
     145             :         }
     146             : 
     147             :         if constexpr (isComplex<ret_t>) {
     148           0 :             Aty = sumByLastAxis(intermRes);
     149             :         } else {
     150           0 :             Aty = real(sumByLastAxis(intermRes));
     151             :         }
     152           0 :     }
     153             : 
     154             :     template <typename ret_t, typename data_t>
     155           0 :     void ShearletTransform<ret_t, data_t>::computeSpectra() const
     156             :     {
     157           0 :         if (isSpectraComputed()) {
     158           0 :             Logger::get("ShearletTransform")->warn("Spectra have already been computed!");
     159             :         }
     160             : 
     161           0 :         _spectra = DataContainer<data_t>(VolumeDescriptor{{_width, _height, _numOfLayers}});
     162             : 
     163           0 :         _computeSpectraAtLowFreq();
     164             : 
     165           0 :         for (index_t j = 0; j < _numOfScales; j++) {
     166           0 :             auto twoPowJ = static_cast<index_t>(std::pow(2, j));
     167           0 :             auto shearletsAtJ = static_cast<index_t>(std::pow(2, j + 2));
     168           0 :             index_t shearletsUpUntilJ = shearletsAtJ - 3;
     169           0 :             index_t index = 1;
     170             : 
     171           0 :             _computeSpectraAtSeamLines(j, -twoPowJ, shearletsUpUntilJ + twoPowJ);
     172           0 :             for (auto k = -twoPowJ + 1; k < twoPowJ; k++) {
     173             :                 // modulo instead of remainder for negative numbers is needed here, therefore doing
     174             :                 // "((a % b) + b) % b" instead of "a % b"
     175           0 :                 index_t modIndex =
     176           0 :                     (((twoPowJ - index + 1) % shearletsAtJ) + shearletsAtJ) % shearletsAtJ;
     177           0 :                 if (modIndex == 0) {
     178           0 :                     modIndex = shearletsAtJ - 1;
     179             :                 } else {
     180           0 :                     --modIndex;
     181             :                 }
     182             : 
     183           0 :                 _computeSpectraAtConicRegions(j, k, shearletsUpUntilJ + modIndex,
     184           0 :                                               shearletsUpUntilJ + twoPowJ + index);
     185           0 :                 ++index;
     186             :             }
     187           0 :             _computeSpectraAtSeamLines(j, twoPowJ, shearletsUpUntilJ + twoPowJ + index);
     188             :         }
     189           0 :     }
     190             : 
     191             :     template <typename ret_t, typename data_t>
     192           0 :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtLowFreq() const
     193             :     {
     194           0 :         DataContainer<data_t> sectionZero(VolumeDescriptor{{_width, _height}});
     195           0 :         sectionZero = 0;
     196             : 
     197           0 :         auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
     198           0 :         auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
     199           0 :         auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
     200           0 :         auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
     201             : 
     202             :         // TODO attempt to refactor the negative indexing
     203           0 :         for (auto w = negativeHalfWidth; w < halfWidth; w++) {
     204           0 :             for (auto h = negativeHalfHeight; h < halfHeight; h++) {
     205           0 :                 sectionZero(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     206           0 :                     shearlet::phiHat<data_t>(static_cast<data_t>(w), static_cast<data_t>(h));
     207             :             }
     208             :         }
     209             : 
     210           0 :         _spectra.value().slice(0) = sectionZero;
     211           0 :     }
     212             : 
     213             :     template <typename ret_t, typename data_t>
     214           0 :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtConicRegions(index_t j, index_t k,
     215             :                                                                          index_t hSliceIndex,
     216             :                                                                          index_t vSliceIndex) const
     217             :     {
     218           0 :         DataContainer<data_t> sectionh(VolumeDescriptor{{_width, _height}});
     219           0 :         sectionh = 0;
     220           0 :         DataContainer<data_t> sectionv(VolumeDescriptor{{_width, _height}});
     221           0 :         sectionv = 0;
     222             : 
     223           0 :         auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
     224           0 :         auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
     225           0 :         auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
     226           0 :         auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
     227             : 
     228             :         // TODO attempt to refactor the negative indexing
     229           0 :         for (auto w = negativeHalfWidth; w < halfWidth; w++) {
     230           0 :             for (auto h = negativeHalfHeight; h < halfHeight; h++) {
     231           0 :                 if (std::abs(h) <= std::abs(w)) {
     232           0 :                     sectionh(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     233           0 :                         shearlet::psiHat<data_t>(std::pow(4, -j) * w,
     234           0 :                                                  std::pow(4, -j) * k * w + std::pow(2, -j) * h);
     235             :                 } else {
     236           0 :                     sectionv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     237           0 :                         shearlet::psiHat<data_t>(std::pow(4, -j) * h,
     238           0 :                                                  std::pow(4, -j) * k * h + std::pow(2, -j) * w);
     239             :                 }
     240             :             }
     241             :         }
     242             : 
     243           0 :         _spectra.value().slice(hSliceIndex) = sectionh;
     244           0 :         _spectra.value().slice(vSliceIndex) = sectionv;
     245           0 :     }
     246             : 
     247             :     template <typename ret_t, typename data_t>
     248           0 :     void ShearletTransform<ret_t, data_t>::_computeSpectraAtSeamLines(index_t j, index_t k,
     249             :                                                                       index_t hxvSliceIndex) const
     250             :     {
     251           0 :         DataContainer<data_t> sectionhxv(VolumeDescriptor{{_width, _height}});
     252           0 :         sectionhxv = 0;
     253             : 
     254           0 :         auto negativeHalfWidth = static_cast<index_t>(-std::floor(_width / 2.0));
     255           0 :         auto halfWidth = static_cast<index_t>(std::ceil(_width / 2.0));
     256           0 :         auto negativeHalfHeight = static_cast<index_t>(-std::floor(_height / 2.0));
     257           0 :         auto halfHeight = static_cast<index_t>(std::ceil(_height / 2.0));
     258             : 
     259             :         // TODO attempt to refactor the negative indexing
     260           0 :         for (auto w = negativeHalfWidth; w < halfWidth; w++) {
     261           0 :             for (auto h = negativeHalfHeight; h < halfHeight; h++) {
     262           0 :                 if (std::abs(h) <= std::abs(w)) {
     263           0 :                     sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     264           0 :                         shearlet::psiHat<data_t>(std::pow(4, -j) * w,
     265           0 :                                                  std::pow(4, -j) * k * w + std::pow(2, -j) * h);
     266             :                 } else {
     267           0 :                     sectionhxv(w < 0 ? w + _width : w, h < 0 ? h + _height : h) =
     268           0 :                         shearlet::psiHat<data_t>(std::pow(4, -j) * h,
     269           0 :                                                  std::pow(4, -j) * k * h + std::pow(2, -j) * w);
     270             :                 }
     271             :             }
     272             :         }
     273             : 
     274           0 :         _spectra.value().slice(hxvSliceIndex) = sectionhxv;
     275           0 :     }
     276             : 
     277             :     template <typename ret_t, typename data_t>
     278           0 :     auto ShearletTransform<ret_t, data_t>::getSpectra() const -> DataContainer<data_t>
     279             :     {
     280           0 :         if (!_spectra.has_value()) {
     281           0 :             throw LogicError(std::string("ShearletTransform: the spectra is not yet computed"));
     282             :         }
     283           0 :         return _spectra.value();
     284             :     }
     285             : 
     286             :     template <typename ret_t, typename data_t>
     287           0 :     bool ShearletTransform<ret_t, data_t>::isSpectraComputed() const
     288             :     {
     289           0 :         return _spectra.has_value();
     290             :     }
     291             : 
     292             :     template <typename ret_t, typename data_t>
     293           0 :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfScales(index_t width, index_t height)
     294             :     {
     295           0 :         return static_cast<index_t>(std::log2(std::max(width, height)) / 2.0);
     296             :     }
     297             : 
     298             :     template <typename ret_t, typename data_t>
     299           0 :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t width, index_t height)
     300             :     {
     301           0 :         return static_cast<index_t>(std::pow(2, (calculateNumOfScales(width, height) + 2)) - 3);
     302             :     }
     303             : 
     304             :     template <typename ret_t, typename data_t>
     305           0 :     index_t ShearletTransform<ret_t, data_t>::calculateNumOfLayers(index_t numOfScales)
     306             :     {
     307           0 :         return static_cast<index_t>(std::pow(2, numOfScales + 2) - 3);
     308             :     }
     309             : 
     310             :     template <typename ret_t, typename data_t>
     311           0 :     auto ShearletTransform<ret_t, data_t>::getWidth() const -> index_t
     312             :     {
     313           0 :         return _width;
     314             :     }
     315             : 
     316             :     template <typename ret_t, typename data_t>
     317           0 :     auto ShearletTransform<ret_t, data_t>::getHeight() const -> index_t
     318             :     {
     319           0 :         return _height;
     320             :     }
     321             : 
     322             :     template <typename ret_t, typename data_t>
     323           0 :     auto ShearletTransform<ret_t, data_t>::getNumOfLayers() const -> index_t
     324             :     {
     325           0 :         return _numOfLayers;
     326             :     }
     327             : 
     328             :     template <typename ret_t, typename data_t>
     329           0 :     ShearletTransform<ret_t, data_t>* ShearletTransform<ret_t, data_t>::cloneImpl() const
     330             :     {
     331           0 :         return new ShearletTransform<ret_t, data_t>(_width, _height, _numOfScales, _spectra);
     332             :     }
     333             : 
     334             :     template <typename ret_t, typename data_t>
     335           0 :     bool ShearletTransform<ret_t, data_t>::isEqual(const LinearOperator<ret_t>& other) const
     336             :     {
     337           0 :         if (!LinearOperator<ret_t>::isEqual(other))
     338           0 :             return false;
     339             : 
     340           0 :         auto otherST = downcast_safe<ShearletTransform<ret_t, data_t>>(&other);
     341             : 
     342           0 :         if (!otherST)
     343           0 :             return false;
     344             : 
     345           0 :         if (_width != otherST->_width)
     346           0 :             return false;
     347             : 
     348           0 :         if (_height != otherST->_height)
     349           0 :             return false;
     350             : 
     351           0 :         if (_numOfScales != otherST->_numOfScales)
     352           0 :             return false;
     353             : 
     354           0 :         return true;
     355             :     }
     356             : 
     357             :     // ------------------------------------------
     358             :     // explicit template instantiation
     359             :     template class ShearletTransform<float, float>;
     360             :     template class ShearletTransform<elsa::complex<float>, float>;
     361             :     template class ShearletTransform<double, double>;
     362             :     template class ShearletTransform<elsa::complex<double>, double>;
     363             : } // namespace elsa

Generated by: LCOV version 1.14