LCOV - code coverage report
Current view: top level - elsa/solvers - OrthogonalMatchingPursuit.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 48 50 96.0 %
Date: 2022-08-25 03:05:39 Functions: 5 10 50.0 %

          Line data    Source code
       1             : #include "OrthogonalMatchingPursuit.h"
       2             : #include "TypeCasts.hpp"
       3             : #include "WLSProblem.h"
       4             : #include "CG.h"
       5             : 
       6             : namespace elsa
       7             : {
       8             :     template <typename data_t>
       9             :     OrthogonalMatchingPursuit<data_t>::OrthogonalMatchingPursuit(
      10             :         const RepresentationProblem<data_t>& problem, data_t epsilon)
      11             :         : Solver<data_t>(), _problem(problem), _epsilon{epsilon}
      12           3 :     {
      13           3 :     }
      14             : 
      15             :     template <typename data_t>
      16             :     DataContainer<data_t>& OrthogonalMatchingPursuit<data_t>::solveImpl(index_t iterations)
      17           1 :     {
      18           1 :         const auto& dict = _problem.getDictionary();
      19           1 :         const auto& residual = _problem.getDataTerm().getResidual();
      20           1 :         auto& currentRepresentation = _problem.getCurrentSolution();
      21             : 
      22           1 :         IndexVector_t support(0); // the atoms used for the representation
      23           1 :         currentRepresentation = 0;
      24             : 
      25           1 :         index_t i = 0;
      26           3 :         while (i < iterations && _problem.evaluate() > _epsilon) {
      27           2 :             index_t k = mostCorrelatedAtom(dict, residual.evaluate(currentRepresentation));
      28             : 
      29           2 :             support.conservativeResize(support.size() + 1);
      30           2 :             support[support.size() - 1] = k;
      31           2 :             Dictionary<data_t> purgedDict = dict.getSupportedDictionary(support);
      32             : 
      33           2 :             WLSProblem<data_t> wls(purgedDict, _problem.getSignal());
      34             : 
      35           2 :             CG cgSolver(wls);
      36           2 :             const auto& wlsSolution = cgSolver.solve(10);
      37             : 
      38             :             // wlsSolution has only non-zero coefficients, copy those to the full solution with zero
      39             :             // coefficients
      40           2 :             index_t j = 0;
      41           3 :             for (const auto& atomIndex : support) {
      42           3 :                 currentRepresentation[atomIndex] = wlsSolution[j];
      43           3 :                 ++j;
      44           3 :             }
      45             : 
      46           2 :             ++i;
      47           2 :         }
      48             : 
      49           1 :         return _problem.getCurrentSolution();
      50           1 :     }
      51             : 
      52             :     template <typename data_t>
      53             :     index_t OrthogonalMatchingPursuit<data_t>::mostCorrelatedAtom(
      54             :         const Dictionary<data_t>& dict, const DataContainer<data_t>& evaluatedResidual)
      55           2 :     {
      56             :         // for this to work atom has to be L2-normalized
      57           2 :         data_t maxCorrelation = 0;
      58           2 :         index_t argMaxCorrelation = 0;
      59             : 
      60           8 :         for (index_t j = 0; j < dict.getNumberOfAtoms(); ++j) {
      61           6 :             const auto& atom = dict.getAtom(j);
      62           6 :             data_t correlation_j = std::abs(atom.dot(evaluatedResidual));
      63             : 
      64           6 :             if (correlation_j > maxCorrelation) {
      65           3 :                 maxCorrelation = correlation_j;
      66           3 :                 argMaxCorrelation = j;
      67           3 :             }
      68           6 :         }
      69           2 :         return argMaxCorrelation;
      70           2 :     }
      71             : 
      72             :     template <typename data_t>
      73             :     OrthogonalMatchingPursuit<data_t>* OrthogonalMatchingPursuit<data_t>::cloneImpl() const
      74           1 :     {
      75           1 :         return new OrthogonalMatchingPursuit(_problem, _epsilon);
      76           1 :     }
      77             : 
      78             :     template <typename data_t>
      79             :     bool OrthogonalMatchingPursuit<data_t>::isEqual(const Solver<data_t>& other) const
      80           1 :     {
      81           1 :         auto otherOMP = downcast_safe<OrthogonalMatchingPursuit>(&other);
      82           1 :         if (!otherOMP)
      83           0 :             return false;
      84             : 
      85           1 :         if (_epsilon != otherOMP->_epsilon)
      86           0 :             return false;
      87             : 
      88           1 :         return true;
      89           1 :     }
      90             : 
      91             :     // ------------------------------------------
      92             :     // explicit template instantiation
      93             :     template class OrthogonalMatchingPursuit<float>;
      94             :     template class OrthogonalMatchingPursuit<double>;
      95             : 
      96             : } // namespace elsa

Generated by: LCOV version 1.14