#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <stdexcept>
#include <typeinfo>
#include <vector>

#include "SubsetSampler.h"
#include "PartitionDescriptor.h"
       5             : 
       6             : namespace elsa
       7             : {
       8             :     template <typename DetectorDescriptor_t, typename data_t>
       9             :     SubsetSampler<DetectorDescriptor_t, data_t>::SubsetSampler(
      10             :         const VolumeDescriptor& volumeDescriptor, const DetectorDescriptor_t& detectorDescriptor,
      11             :         index_t nSubsets, SamplingStrategy samplingStrategy)
      12             :         : _indexMapping(static_cast<std::size_t>(nSubsets)),
      13             :           _volumeDescriptor(volumeDescriptor),
      14             :           _fullDetectorDescriptor(detectorDescriptor),
      15             :           _nSubsets{nSubsets}
      16          12 :     {
      17          12 :         if (nSubsets <= 1) {
      18           0 :             throw std::invalid_argument("SubsetSampler: nSubsets must be >= 2");
      19           0 :         }
      20             : 
      21             :         // the mapping of data indices to subsets
      22             : 
      23          12 :         const auto numCoeffsPerDim = detectorDescriptor.getNumberOfCoefficientsPerDimension();
      24          12 :         const index_t nDimensions = detectorDescriptor.getNumberOfDimensions();
      25          12 :         const auto numElements = numCoeffsPerDim[nDimensions - 1];
      26          12 :         if (samplingStrategy == SamplingStrategy::ROUND_ROBIN) {
      27           6 :             std::vector<index_t> indices(static_cast<std::size_t>(numElements));
      28           6 :             std::iota(indices.begin(), indices.end(), 0);
      29           6 :             _indexMapping = splitRoundRobin(indices, _nSubsets);
      30           6 :         } else if (samplingStrategy == SamplingStrategy::ROTATIONAL_CLUSTERING) {
      31           6 :             _indexMapping = splitRotationalClustering(detectorDescriptor, _nSubsets);
      32           6 :         } else {
      33           0 :             throw std::invalid_argument("SubsetSampler: unsupported sampling strategy");
      34           0 :         }
      35             : 
      36             :         // create the detector descriptors that correspond to each subset
      37          48 :         for (const auto& blockIndices : _indexMapping) {
      38          48 :             std::vector<Geometry> geometry;
      39          48 :             geometry.reserve(blockIndices.size());
      40         216 :             for (auto index : blockIndices) {
      41         216 :                 geometry.emplace_back(detectorDescriptor.getGeometryAt(index));
      42         216 :             }
      43          48 :             IndexVector_t numOfCoeffsPerDim =
      44          48 :                 detectorDescriptor.getNumberOfCoefficientsPerDimension();
      45          48 :             numOfCoeffsPerDim[numOfCoeffsPerDim.size() - 1] =
      46          48 :                 static_cast<index_t>(blockIndices.size());
      47             : 
      48          48 :             _detectorDescriptors.emplace_back(DetectorDescriptor_t(numOfCoeffsPerDim, geometry));
      49          48 :         }
      50          12 :     }
      51             : 
      52             :     template <typename DetectorDescriptor_t, typename data_t>
      53             :     std::vector<std::vector<index_t>> SubsetSampler<DetectorDescriptor_t, data_t>::splitRoundRobin(
      54             :         const std::vector<index_t>& indices, index_t nSubsets)
      55          18 :     {
      56          18 :         std::vector<std::vector<index_t>> subsetIndices(static_cast<std::size_t>(nSubsets));
      57             : 
      58             :         // determine the mapping of indices to subsets
      59         426 :         for (std::size_t i = 0; i < indices.size(); ++i) {
      60         408 :             const auto subset = i % static_cast<std::size_t>(nSubsets);
      61         408 :             subsetIndices[subset].template emplace_back(indices[i]);
      62         408 :         }
      63             : 
      64          18 :         return subsetIndices;
      65          18 :     }
      66             : 
      67             :     template <typename DetectorDescriptor_t, typename data_t>
      68             :     std::vector<std::vector<index_t>>
      69             :         SubsetSampler<DetectorDescriptor_t, data_t>::splitRotationalClustering(
      70             :             const DetectorDescriptor_t& detectorDescriptor, index_t nSubsets)
      71           8 :     {
      72             : 
      73           8 :         const auto numCoeffsPerDim = detectorDescriptor.getNumberOfCoefficientsPerDimension();
      74           8 :         const index_t nDimensions = detectorDescriptor.getNumberOfDimensions();
      75           8 :         const auto numElements = numCoeffsPerDim[nDimensions - 1];
      76           8 :         std::vector<index_t> indices(static_cast<std::size_t>(numElements));
      77           8 :         std::iota(indices.begin(), indices.end(), 0);
      78           8 :         const auto geometry = detectorDescriptor.getGeometry();
      79             : 
      80             :         // angle between two rotation matrices used as a distance measure
      81        1660 :         auto dist = [nDimensions](auto& g1, auto& g2) {
      82        1660 :             const auto& r1 = g1.getRotationMatrix();
      83        1660 :             const auto& r2 = g2.getRotationMatrix();
      84        1660 :             auto product = r1 * r2.transpose();
      85        1660 :             if (nDimensions == 2) { // the 2D case
      86         938 :                 return static_cast<double>(std::atan2(product(1, 0), product(0, 0)));
      87         938 :             } else { // the 3D case
      88         722 :                 return std::acos((product.trace() - 1.0) / 2.0);
      89         722 :             }
      90        1660 :         };
      91             : 
      92           8 :         const auto first = geometry.front();
      93           8 :         std::sort(std::begin(indices), std::end(indices),
      94         830 :                   [dist, first, &geometry](auto lhs, auto rhs) {
      95         830 :                       return dist(first, geometry[static_cast<std::size_t>(lhs)])
      96         830 :                              < dist(first, geometry[static_cast<std::size_t>(rhs)]);
      97         830 :                   });
      98             : 
      99           8 :         return splitRoundRobin(indices, nSubsets);
     100           8 :     }
     101             : 
     102             :     template <typename DetectorDescriptor_t, typename data_t>
     103             :     SubsetSampler<DetectorDescriptor_t, data_t>::SubsetSampler(
     104             :         const SubsetSampler<DetectorDescriptor_t, data_t>& other)
     105             :         : _indexMapping{other._indexMapping},
     106             :           _volumeDescriptor(other._volumeDescriptor),
     107             :           _fullDetectorDescriptor(other._fullDetectorDescriptor),
     108             :           _nSubsets{other._nSubsets}
     109           4 :     {
     110          16 :         for (const auto& detectorDescriptor : other._detectorDescriptors) {
     111          16 :             _detectorDescriptors.emplace_back(detectorDescriptor);
     112          16 :         }
     113           4 :     }
     114             : 
     115             :     template <typename DetectorDescriptor_t, typename data_t>
     116             :     DataContainer<data_t> SubsetSampler<DetectorDescriptor_t, data_t>::getPartitionedData(
     117             :         const DataContainer<data_t>& sinogram)
     118           8 :     {
     119             :         // save the number of entries per subset
     120           8 :         IndexVector_t slicesInBlock(_indexMapping.size());
     121          40 :         for (unsigned long i = 0; i < _indexMapping.size(); i++) {
     122          32 :             slicesInBlock[static_cast<index_t>(i)] = static_cast<index_t>(_indexMapping[i].size());
     123          32 :         }
     124           8 :         PartitionDescriptor dataDescriptor(sinogram.getDataDescriptor(), slicesInBlock);
     125             : 
     126           8 :         const auto numCoeffsPerDim =
     127           8 :             sinogram.getDataDescriptor().getNumberOfCoefficientsPerDimension();
     128           8 :         auto partitionedData = DataContainer<data_t>(dataDescriptor);
     129             : 
     130             :         // resort the actual measurement partitionedData
     131           8 :         index_t coeffsPerRow = numCoeffsPerDim.head(numCoeffsPerDim.size() - 1).prod();
     132          40 :         for (index_t i = 0; i < _nSubsets; i++) {
     133             :             // the indices of the partitionedData rows belonging to this subset
     134          32 :             std::vector<index_t> indices = _indexMapping[static_cast<std::size_t>(i)];
     135             : 
     136          32 :             auto block = partitionedData.getBlock(i);
     137             : 
     138         176 :             for (std::size_t j = 0; j < indices.size(); j++) {
     139        9648 :                 for (int k = 0; k < coeffsPerRow; k++) {
     140        9504 :                     block[static_cast<index_t>(j) * coeffsPerRow + k] =
     141        9504 :                         sinogram[indices[j] * coeffsPerRow + k];
     142        9504 :                 }
     143         144 :             }
     144          32 :         }
     145           8 :         return partitionedData;
     146           8 :     }
     147             : 
     148             :     template <typename DetectorDescriptor_t, typename data_t>
     149             :     SubsetSampler<DetectorDescriptor_t, data_t>*
     150             :         SubsetSampler<DetectorDescriptor_t, data_t>::cloneImpl() const
     151           4 :     {
     152           4 :         return new SubsetSampler<DetectorDescriptor_t, data_t>(*this);
     153           4 :     }
     154             : 
     155             :     template <typename DetectorDescriptor_t, typename data_t>
     156             :     bool SubsetSampler<DetectorDescriptor_t, data_t>::isEqual(
     157             :         const SubsetSampler<DetectorDescriptor_t, data_t>& other) const
     158           4 :     {
     159           4 :         if (typeid(*this) != typeid(other))
     160           0 :             return false;
     161             : 
     162           4 :         if (_indexMapping != other._indexMapping)
     163           0 :             return false;
     164             : 
     165           4 :         if (_volumeDescriptor != other._volumeDescriptor)
     166           0 :             return false;
     167             : 
     168           4 :         if (_fullDetectorDescriptor != other._fullDetectorDescriptor)
     169           0 :             return false;
     170             : 
     171           4 :         if (_nSubsets != other._nSubsets)
     172           0 :             return false;
     173             : 
     174             :         // we do not need to check if the vector of detector descriptors is equal as this is
     175             :         // implied by the equality of _data in combination with the full detector descriptor
     176             : 
     177           4 :         return true;
     178           4 :     }
     179             : 
    // ------------------------------------------
    // explicit template instantiation
    // The member definitions above live only in this translation unit. Any
    // (detector descriptor, data type) combination used elsewhere must therefore be
    // instantiated here. Currently only PlanarDetectorDescriptor combinations are
    // provided, for real and complex float/double data.
    template class SubsetSampler<PlanarDetectorDescriptor, float>;
    template class SubsetSampler<PlanarDetectorDescriptor, double>;
    template class SubsetSampler<PlanarDetectorDescriptor, complex<float>>;
    template class SubsetSampler<PlanarDetectorDescriptor, complex<double>>;
     186             : } // namespace elsa

