LCOV - code coverage report
Current view: top level - elsa/operators - SphericalFunctionTransform.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 48 67 71.6 %
Date: 2024-07-03 03:56:50 Functions: 8 16 50.0 %

          Line data    Source code
       1             : #include "DataContainer.h"
       2             : #include "IdenticalBlocksDescriptor.h"
       3             : #include "Math.hpp"
       4             : #include "SphericalCoefficientsDescriptor.h"
       5             : #include "TypeCasts.hpp"
       6             : #include "VolumeDescriptor.h"
       7             : #include "elsaDefines.h"
       8             : #include "SphericalFunctionTransform.h"
       9             : 
      10             : namespace elsa::axdt
      11             : {
      12             : 
      13             :     template <typename data_t>
      14             :     Matrix_t<data_t> evalSphericalHarmonics(const Symmetry symmetry, const index_t maxL,
      15             :                                             const DirVecList<data_t>& dirs)
      16          32 :     {
      17          32 :         size_t directions = dirs.size();
      18          32 :         Matrix_t<data_t> sphericalHarmonicsBasis{
      19          32 :             directions, SphericalCoefficientsDescriptor::coefficientCount(symmetry, maxL)};
      20             : 
      21          32 :         data_t normFac = sqrt(as<data_t>(4) * pi_t / dirs.size());
      22             : 
      23          32 : #pragma omp parallel for
      24          32 :         for (index_t i = 0; i < asSigned(directions); ++i) {
      25           0 :             index_t j = 0;
      26             : 
      27           0 :             auto dir = dirs[asUnsigned(i)];
      28             : 
      29             :             // theta: atan2 returns the elevation, so to get theta = the inclination, we need:
      30             :             // inclination = pi/2-elevation azimuth = phi
      31           0 :             auto theta = pi<data_t> / 2 + atan2(dir[2], hypot(dir[0], dir[1]));
      32           0 :             auto phi = atan2(dir[1], dir[0]);
      33           0 :             auto sh_dir = SH_basis_real<data_t>(maxL, theta, phi);
      34             : 
      35        3692 :             for (int l = 0; l <= maxL; ++l) {
      36       16488 :                 for (int m = -l; m <= l; ++m) {
      37       12796 :                     if (symmetry == Symmetry::even && (l % 2) != 0) {
      38        1992 :                         continue;
      39        1992 :                     }
      40             : 
      41       10804 :                     sphericalHarmonicsBasis(i, j) = normFac * sh_dir(l * l + l + m);
      42       10804 :                     j++;
      43       10804 :                 }
      44        3692 :             }
      45           0 :         }
      46             : 
      47          32 :         return sphericalHarmonicsBasis;
      48          32 :     }
      49             : 
      50             :     template <typename data_t>
      51             :     SphericalFunctionTransform<data_t>::SphericalFunctionTransform(
      52             :         const SphericalCoefficientsDescriptor& domainDescriptor,
      53             :         const DirVecList<data_t>& samplingDirections)
      54             :         : LinearOperator<data_t>{domainDescriptor,
      55             :                                  IdenticalBlocksDescriptor{
      56             :                                      asSigned(samplingDirections.size()),
      57             :                                      domainDescriptor.getDescriptorOfBlock(0)}},
      58             :           _samplingDirections(samplingDirections),
      59             :           _basis{evalSphericalHarmonics<data_t>(domainDescriptor.symmetry, domainDescriptor.degree,
      60             :                                                 samplingDirections)}
      61          30 :     {
      62          30 :     }
      63             : 
      64             :     template <typename data_t>
      65             :     SphericalFunctionTransform<data_t>::SphericalFunctionTransform(
      66             :         const SphericalFunctionTransform& other)
      67             :         : LinearOperator<data_t>{other},
      68             :           _samplingDirections(other._samplingDirections),
      69             :           _basis(other._basis)
      70           0 :     {
      71           0 :     }
      72             : 
      73             :     template <typename data_t>
      74             :     SphericalFunctionTransform<data_t>* SphericalFunctionTransform<data_t>::cloneImpl() const
      75           0 :     {
      76           0 :         return new SphericalFunctionTransform<data_t>(*this);
      77           0 :     }
      78             : 
      79             :     template <typename data_t>
      80             :     bool SphericalFunctionTransform<data_t>::isEqual(const LinearOperator<data_t>& other) const
      81           0 :     {
      82           0 :         if (!LinearOperator<data_t>::isEqual(other))
      83           0 :             return false;
      84             : 
      85           0 :         if (!is<SphericalFunctionTransform<data_t>>(other))
      86           0 :             return false;
      87             : 
      88           0 :         const auto& otherSFT = downcast<const SphericalFunctionTransform<data_t>>(other);
      89             : 
      90           0 :         return _samplingDirections == otherSFT._samplingDirections;
      91           0 :     }
      92             : 
      93             :     template <typename data_t>
      94             :     void SphericalFunctionTransform<data_t>::applyImpl(const DataContainer<data_t>& x,
      95             :                                                        DataContainer<data_t>& Ax) const
      96           2 :     {
      97           2 :         assert(x.getDataDescriptor() == this->getDomainDescriptor());
      98           2 :         assert(Ax.getDataDescriptor() == this->getRangeDescriptor());
      99             : 
     100           2 :         const auto& sphDesc =
     101           2 :             downcast<const SphericalCoefficientsDescriptor>(this->getDomainDescriptor());
     102             : 
     103           2 :         const auto numCoeffs = sphDesc.getNumberOfBlocks();
     104           2 :         const index_t numDirs = _samplingDirections.size();
     105           2 :         const auto volSize = x.getSize() / x.getNumberOfBlocks();
     106             : 
     107             :         // Eigen is column-major by default
     108           2 :         Eigen::Map<const Matrix_t<data_t>> sphericalCoefficients{
     109           2 :             thrust::raw_pointer_cast(x.storage().data()), volSize, numCoeffs};
     110             : 
     111           2 :         Eigen::Map<Matrix_t<data_t>> sampledDirections{
     112           2 :             thrust::raw_pointer_cast(Ax.storage().data()), volSize, numDirs};
     113             : 
     114             :         // Hope that Eigen can cope with these huge matrices...
     115             :         // (volSize×dirs) = (volSize×coeffs) * (coeffs×dirs)
     116           2 :         sampledDirections = sphericalCoefficients * _basis.transpose();
     117           2 :     }
     118             : 
     119             :     template <typename data_t>
     120             :     void SphericalFunctionTransform<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
     121             :                                                               DataContainer<data_t>& Aty) const
     122          30 :     {
     123          30 :         assert(y.getDataDescriptor() == this->getRangeDescriptor());
     124          30 :         assert(Aty.getDataDescriptor() == this->getDomainDescriptor());
     125             : 
     126          30 :         const auto& sphDesc =
     127          30 :             downcast<const SphericalCoefficientsDescriptor>(this->getDomainDescriptor());
     128             : 
     129          30 :         const auto numCoeffs = sphDesc.getNumberOfBlocks();
     130          30 :         const index_t numDirs = _samplingDirections.size();
     131          30 :         const auto volSize = y.getSize() / y.getNumberOfBlocks();
     132             : 
     133             :         // Eigen is column-major by default
     134          30 :         Eigen::Map<const Matrix_t<data_t>> sampledDirections{
     135          30 :             thrust::raw_pointer_cast(y.storage().data()), volSize, numDirs};
     136             : 
     137          30 :         Eigen::Map<Matrix_t<data_t>> sphericalCoefficients{
     138          30 :             thrust::raw_pointer_cast(Aty.storage().data()), volSize, numCoeffs};
     139             : 
     140             :         // This approximates the scalar product on S2 and hence computes the coefficients in the
     141             :         // spherical harmonics basis.
     142             :         // (volSize×coeffs) = (volSize×dirs) * (dirs×coeffs)
     143          30 :         sphericalCoefficients = sampledDirections * _basis;
     144          30 :     }
     145             : 
     146             :     template Matrix_t<float> evalSphericalHarmonics<float>(const Symmetry, const index_t,
     147             :                                                            const DirVecList<float>&);
     148             : 
     149             :     template Matrix_t<double> evalSphericalHarmonics<double>(const Symmetry, const index_t,
     150             :                                                              const DirVecList<double>&);
     151             : 
     152             :     template class SphericalFunctionTransform<float>;
     153             :     template class SphericalFunctionTransform<double>;
     154             : 
     155             : } // namespace elsa::axdt

Generated by: LCOV version 1.14