LCOV - code coverage report
Current view: top level - elsa/solvers - OrthogonalMatchingPursuit.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 46 48 95.8 %
Date: 2024-05-16 04:22:26 Functions: 6 12 50.0 %

          Line data    Source code
       1             : #include "OrthogonalMatchingPursuit.h"
       2             : #include "TypeCasts.hpp"
       3             : #include "CGLS.h"
       4             : 
       5             : #include <iostream>
       6             : 
       7             : namespace elsa
       8             : {
       9             :     template <typename data_t>
      10             :     OrthogonalMatchingPursuit<data_t>::OrthogonalMatchingPursuit(const Dictionary<data_t>& D,
      11             :                                                                  const DataContainer<data_t>& y,
      12             :                                                                  data_t epsilon)
      13             :         : Solver<data_t>(), dict_(D), signal_(y), epsilon_{epsilon}
      14           3 :     {
      15           3 :     }
      16             : 
      17             :     template <typename data_t>
      18             :     DataContainer<data_t>
      19             :         OrthogonalMatchingPursuit<data_t>::solve(index_t iterations,
      20             :                                                  std::optional<DataContainer<data_t>> x0)
      21           1 :     {
      22           2 :         auto eval = [&](auto x) { return 0.5 * (dict_.apply(x) - signal_).squaredL2Norm(); };
      23             : 
      24           1 :         auto x = extract_or(x0, dict_.getDomainDescriptor());
      25             : 
      26           1 :         IndexVector_t support(0); // the atoms used for the representation
      27             : 
      28           1 :         index_t i = 0;
      29           3 :         while (i < iterations && eval(x) > epsilon_) {
      30           2 :             auto residual = dict_.apply(x) - signal_;
      31           2 :             index_t k = mostCorrelatedAtom(dict_, residual);
      32             : 
      33           2 :             support.conservativeResize(support.size() + 1);
      34           2 :             support[support.size() - 1] = k;
      35           2 :             Dictionary<data_t> purgedDict = dict_.getSupportedDictionary(support);
      36             : 
      37           2 :             CGLS<data_t> cgSolver(purgedDict, signal_);
      38           2 :             const auto wlsSolution = cgSolver.solve(10);
      39             : 
      40             :             // wlsSolution has only non-zero coefficients, copy those to the full solution with zero
      41             :             // coefficients
      42           2 :             index_t j = 0;
      43           3 :             for (const auto& atomIndex : support) {
      44           3 :                 x[atomIndex] = wlsSolution[j];
      45           3 :                 ++j;
      46           3 :             }
      47             : 
      48           2 :             ++i;
      49           2 :         }
      50             : 
      51           1 :         return x;
      52           1 :     }
      53             : 
      54             :     template <typename data_t>
      55             :     index_t OrthogonalMatchingPursuit<data_t>::mostCorrelatedAtom(
      56             :         const Dictionary<data_t>& dict, const DataContainer<data_t>& evaluatedResidual)
      57           2 :     {
      58             :         // for this to work atom has to be L2-normalized
      59           2 :         data_t maxCorrelation = 0;
      60           2 :         index_t argMaxCorrelation = 0;
      61             : 
      62           8 :         for (index_t j = 0; j < dict.getNumberOfAtoms(); ++j) {
      63           6 :             const auto& atom = dict.getAtom(j);
      64           6 :             data_t correlation_j = std::abs(atom.dot(evaluatedResidual));
      65             : 
      66           6 :             if (correlation_j > maxCorrelation) {
      67           3 :                 maxCorrelation = correlation_j;
      68           3 :                 argMaxCorrelation = j;
      69           3 :             }
      70           6 :         }
      71             : 
      72           2 :         return argMaxCorrelation;
      73           2 :     }
      74             : 
      75             :     template <typename data_t>
      76             :     OrthogonalMatchingPursuit<data_t>* OrthogonalMatchingPursuit<data_t>::cloneImpl() const
      77           1 :     {
      78           1 :         return new OrthogonalMatchingPursuit(dict_, signal_, epsilon_);
      79           1 :     }
      80             : 
      81             :     template <typename data_t>
      82             :     bool OrthogonalMatchingPursuit<data_t>::isEqual(const Solver<data_t>& other) const
      83           1 :     {
      84           1 :         auto otherOMP = downcast_safe<OrthogonalMatchingPursuit>(&other);
      85           1 :         if (!otherOMP)
      86           0 :             return false;
      87             : 
      88           1 :         if (epsilon_ != otherOMP->epsilon_)
      89           0 :             return false;
      90             : 
      91           1 :         return true;
      92           1 :     }
      93             : 
    // ------------------------------------------
    // explicit template instantiation
    // The member definitions live in this .cpp file, so the solver is only usable with the
    // element types instantiated here (float and double).
    template class OrthogonalMatchingPursuit<float>;
    template class OrthogonalMatchingPursuit<double>;
      98             : 
      99             : } // namespace elsa

Generated by: LCOV version 1.14