Line data Source code
1 : #include "OrthogonalMatchingPursuit.h" 2 : #include "TypeCasts.hpp" 3 : #include "WLSProblem.h" 4 : #include "CG.h" 5 : 6 : namespace elsa 7 : { 8 : template <typename data_t> 9 : OrthogonalMatchingPursuit<data_t>::OrthogonalMatchingPursuit( 10 : const RepresentationProblem<data_t>& problem, data_t epsilon) 11 : : Solver<data_t>(), _problem(problem), _epsilon{epsilon} 12 3 : { 13 3 : } 14 : 15 : template <typename data_t> 16 : DataContainer<data_t>& OrthogonalMatchingPursuit<data_t>::solveImpl(index_t iterations) 17 1 : { 18 1 : const auto& dict = _problem.getDictionary(); 19 1 : const auto& residual = _problem.getDataTerm().getResidual(); 20 1 : auto& currentRepresentation = _problem.getCurrentSolution(); 21 : 22 1 : IndexVector_t support(0); // the atoms used for the representation 23 1 : currentRepresentation = 0; 24 : 25 1 : index_t i = 0; 26 3 : while (i < iterations && _problem.evaluate() > _epsilon) { 27 2 : index_t k = mostCorrelatedAtom(dict, residual.evaluate(currentRepresentation)); 28 : 29 2 : support.conservativeResize(support.size() + 1); 30 2 : support[support.size() - 1] = k; 31 2 : Dictionary<data_t> purgedDict = dict.getSupportedDictionary(support); 32 : 33 2 : WLSProblem<data_t> wls(purgedDict, _problem.getSignal()); 34 : 35 2 : CG cgSolver(wls); 36 2 : const auto& wlsSolution = cgSolver.solve(10); 37 : 38 : // wlsSolution has only non-zero coefficients, copy those to the full solution with zero 39 : // coefficients 40 2 : index_t j = 0; 41 3 : for (const auto& atomIndex : support) { 42 3 : currentRepresentation[atomIndex] = wlsSolution[j]; 43 3 : ++j; 44 3 : } 45 : 46 2 : ++i; 47 2 : } 48 : 49 1 : return _problem.getCurrentSolution(); 50 1 : } 51 : 52 : template <typename data_t> 53 : index_t OrthogonalMatchingPursuit<data_t>::mostCorrelatedAtom( 54 : const Dictionary<data_t>& dict, const DataContainer<data_t>& evaluatedResidual) 55 2 : { 56 : // for this to work atom has to be L2-normalized 57 2 : data_t maxCorrelation = 0; 58 2 : index_t argMaxCorrelation = 0; 59 : 60 8 : for (index_t j = 0; j < dict.getNumberOfAtoms(); ++j) { 61 6 : const auto& atom = dict.getAtom(j); 62 6 : data_t correlation_j = std::abs(atom.dot(evaluatedResidual)); 63 : 64 6 : if (correlation_j > maxCorrelation) { 65 3 : maxCorrelation = correlation_j; 66 3 : argMaxCorrelation = j; 67 3 : } 68 6 : } 69 2 : return argMaxCorrelation; 70 2 : } 71 : 72 : template <typename data_t> 73 : OrthogonalMatchingPursuit<data_t>* OrthogonalMatchingPursuit<data_t>::cloneImpl() const 74 1 : { 75 1 : return new OrthogonalMatchingPursuit(_problem, _epsilon); 76 1 : } 77 : 78 : template <typename data_t> 79 : bool OrthogonalMatchingPursuit<data_t>::isEqual(const Solver<data_t>& other) const 80 1 : { 81 1 : auto otherOMP = downcast_safe<OrthogonalMatchingPursuit>(&other); 82 1 : if (!otherOMP) 83 0 : return false; 84 : 85 1 : if (_epsilon != otherOMP->_epsilon) 86 0 : return false; 87 : 88 1 : return true; 89 1 : } 90 : 91 : // ------------------------------------------ 92 : // explicit template instantiation 93 : template class OrthogonalMatchingPursuit<float>; 94 : template class OrthogonalMatchingPursuit<double>; 95 : 96 : } // namespace elsa