LCOV - code coverage report
Current view: top level - solvers - OrthogonalMatchingPursuit.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 48 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 10 0.0 %

          Line data    Source code
       1             : #include "OrthogonalMatchingPursuit.h"
       2             : #include "TypeCasts.hpp"
       3             : 
       4             : namespace elsa
       5             : {
       6             :     template <typename data_t>
       7           0 :     OrthogonalMatchingPursuit<data_t>::OrthogonalMatchingPursuit(
       8             :         const RepresentationProblem<data_t>& problem, data_t epsilon)
       9           0 :         : Solver<data_t>(problem), _epsilon{epsilon}
      10             :     {
      11           0 :     }
      12             : 
      13             :     template <typename data_t>
      14           0 :     DataContainer<data_t>& OrthogonalMatchingPursuit<data_t>::solveImpl(index_t iterations)
      15             :     {
      16             :         // Safe, as it's the only possible input
      17           0 :         const auto& reprProblem = downcast<RepresentationProblem<data_t>>(*_problem);
      18             : 
      19           0 :         const auto& dict = reprProblem.getDictionary();
      20           0 :         const auto& residual = _problem->getDataTerm().getResidual();
      21           0 :         auto& currentRepresentation = _problem->getCurrentSolution();
      22             : 
      23           0 :         IndexVector_t support(0); // the atoms used for the representation
      24           0 :         currentRepresentation = 0;
      25             : 
      26           0 :         index_t i = 0;
      27           0 :         while (i < iterations && _problem->evaluate() > _epsilon) {
      28           0 :             index_t k = mostCorrelatedAtom(dict, residual.evaluate(currentRepresentation));
      29             : 
      30           0 :             support.conservativeResize(support.size() + 1);
      31           0 :             support[support.size() - 1] = k;
      32           0 :             Dictionary<data_t> purgedDict = dict.getSupportedDictionary(support);
      33             : 
      34           0 :             WLSProblem<data_t> wls(purgedDict, reprProblem.getSignal());
      35             : 
      36           0 :             CG cgSolver(wls);
      37           0 :             const auto& wlsSolution = cgSolver.solve(10);
      38             : 
      39             :             // wlsSolution has only non-zero coefficients, copy those to the full solution with zero
      40             :             // coefficients
      41           0 :             index_t j = 0;
      42           0 :             for (const auto& atomIndex : support) {
      43           0 :                 currentRepresentation[atomIndex] = wlsSolution[j];
      44           0 :                 ++j;
      45             :             }
      46             : 
      47           0 :             ++i;
      48             :         }
      49             : 
      50           0 :         return getCurrentSolution();
      51           0 :     }
      52             : 
      53             :     template <typename data_t>
      54           0 :     index_t OrthogonalMatchingPursuit<data_t>::mostCorrelatedAtom(
      55             :         const Dictionary<data_t>& dict, const DataContainer<data_t>& evaluatedResidual)
      56             :     {
      57             :         // for this to work atom has to be L2-normalized
      58           0 :         data_t maxCorrelation = 0;
      59           0 :         index_t argMaxCorrelation = 0;
      60             : 
      61           0 :         for (index_t j = 0; j < dict.getNumberOfAtoms(); ++j) {
      62           0 :             const auto& atom = dict.getAtom(j);
      63           0 :             data_t correlation_j = std::abs(atom.dot(evaluatedResidual));
      64             : 
      65           0 :             if (correlation_j > maxCorrelation) {
      66           0 :                 maxCorrelation = correlation_j;
      67           0 :                 argMaxCorrelation = j;
      68             :             }
      69             :         }
      70           0 :         return argMaxCorrelation;
      71             :     }
      72             : 
      73             :     template <typename data_t>
      74           0 :     OrthogonalMatchingPursuit<data_t>* OrthogonalMatchingPursuit<data_t>::cloneImpl() const
      75             :     {
      76           0 :         return new OrthogonalMatchingPursuit(downcast<RepresentationProblem<data_t>>(*_problem),
      77           0 :                                              _epsilon);
      78             :     }
      79             : 
      80             :     template <typename data_t>
      81           0 :     bool OrthogonalMatchingPursuit<data_t>::isEqual(const Solver<data_t>& other) const
      82             :     {
      83           0 :         if (!Solver<data_t>::isEqual(other))
      84           0 :             return false;
      85             : 
      86           0 :         auto otherOMP = downcast_safe<OrthogonalMatchingPursuit>(&other);
      87           0 :         if (!otherOMP)
      88           0 :             return false;
      89             : 
      90           0 :         if (_epsilon != otherOMP->_epsilon)
      91           0 :             return false;
      92             : 
      93           0 :         return true;
      94             :     }
      95             : 
      96             :     // ------------------------------------------
      97             :     // explicit template instantiation
      98             :     template class OrthogonalMatchingPursuit<float>;
      99             :     template class OrthogonalMatchingPursuit<double>;
     100             : 
     101             : } // namespace elsa

Generated by: LCOV version 1.14