Line data Source code
1 : #include "OrthogonalMatchingPursuit.h" 2 : #include "TypeCasts.hpp" 3 : 4 : namespace elsa 5 : { 6 : template <typename data_t> 7 0 : OrthogonalMatchingPursuit<data_t>::OrthogonalMatchingPursuit( 8 : const RepresentationProblem<data_t>& problem, data_t epsilon) 9 0 : : Solver<data_t>(problem), _epsilon{epsilon} 10 : { 11 0 : } 12 : 13 : template <typename data_t> 14 0 : DataContainer<data_t>& OrthogonalMatchingPursuit<data_t>::solveImpl(index_t iterations) 15 : { 16 : // Safe, as it's the only possible input 17 0 : const auto& reprProblem = downcast<RepresentationProblem<data_t>>(*_problem); 18 : 19 0 : const auto& dict = reprProblem.getDictionary(); 20 0 : const auto& residual = _problem->getDataTerm().getResidual(); 21 0 : auto& currentRepresentation = _problem->getCurrentSolution(); 22 : 23 0 : IndexVector_t support(0); // the atoms used for the representation 24 0 : currentRepresentation = 0; 25 : 26 0 : index_t i = 0; 27 0 : while (i < iterations && _problem->evaluate() > _epsilon) { 28 0 : index_t k = mostCorrelatedAtom(dict, residual.evaluate(currentRepresentation)); 29 : 30 0 : support.conservativeResize(support.size() + 1); 31 0 : support[support.size() - 1] = k; 32 0 : Dictionary<data_t> purgedDict = dict.getSupportedDictionary(support); 33 : 34 0 : WLSProblem<data_t> wls(purgedDict, reprProblem.getSignal()); 35 : 36 0 : CG cgSolver(wls); 37 0 : const auto& wlsSolution = cgSolver.solve(10); 38 : 39 : // wlsSolution has only non-zero coefficients, copy those to the full solution with zero 40 : // coefficients 41 0 : index_t j = 0; 42 0 : for (const auto& atomIndex : support) { 43 0 : currentRepresentation[atomIndex] = wlsSolution[j]; 44 0 : ++j; 45 : } 46 : 47 0 : ++i; 48 : } 49 : 50 0 : return getCurrentSolution(); 51 0 : } 52 : 53 : template <typename data_t> 54 0 : index_t OrthogonalMatchingPursuit<data_t>::mostCorrelatedAtom( 55 : const Dictionary<data_t>& dict, const DataContainer<data_t>& evaluatedResidual) 56 : { 57 : // for this to work atom has to be L2-normalized 58 0 : data_t maxCorrelation = 0; 59 0 : index_t argMaxCorrelation = 0; 60 : 61 0 : for (index_t j = 0; j < dict.getNumberOfAtoms(); ++j) { 62 0 : const auto& atom = dict.getAtom(j); 63 0 : data_t correlation_j = std::abs(atom.dot(evaluatedResidual)); 64 : 65 0 : if (correlation_j > maxCorrelation) { 66 0 : maxCorrelation = correlation_j; 67 0 : argMaxCorrelation = j; 68 : } 69 : } 70 0 : return argMaxCorrelation; 71 : } 72 : 73 : template <typename data_t> 74 0 : OrthogonalMatchingPursuit<data_t>* OrthogonalMatchingPursuit<data_t>::cloneImpl() const 75 : { 76 0 : return new OrthogonalMatchingPursuit(downcast<RepresentationProblem<data_t>>(*_problem), 77 0 : _epsilon); 78 : } 79 : 80 : template <typename data_t> 81 0 : bool OrthogonalMatchingPursuit<data_t>::isEqual(const Solver<data_t>& other) const 82 : { 83 0 : if (!Solver<data_t>::isEqual(other)) 84 0 : return false; 85 : 86 0 : auto otherOMP = downcast_safe<OrthogonalMatchingPursuit>(&other); 87 0 : if (!otherOMP) 88 0 : return false; 89 : 90 0 : if (_epsilon != otherOMP->_epsilon) 91 0 : return false; 92 : 93 0 : return true; 94 : } 95 : 96 : // ------------------------------------------ 97 : // explicit template instantiation 98 : template class OrthogonalMatchingPursuit<float>; 99 : template class OrthogonalMatchingPursuit<double>; 100 : 101 : } // namespace elsa