Line data Source code
1 : #include "OrthogonalMatchingPursuit.h" 2 : #include "TypeCasts.hpp" 3 : #include "CGLS.h" 4 : 5 : #include <iostream> 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 : OrthogonalMatchingPursuit<data_t>::OrthogonalMatchingPursuit(const Dictionary<data_t>& D, 11 : const DataContainer<data_t>& y, 12 : data_t epsilon) 13 : : Solver<data_t>(), dict_(D), signal_(y), epsilon_{epsilon} 14 3 : { 15 3 : } 16 : 17 : template <typename data_t> 18 : DataContainer<data_t> 19 : OrthogonalMatchingPursuit<data_t>::solve(index_t iterations, 20 : std::optional<DataContainer<data_t>> x0) 21 1 : { 22 2 : auto eval = [&](auto x) { return 0.5 * (dict_.apply(x) - signal_).squaredL2Norm(); }; 23 : 24 1 : auto x = extract_or(x0, dict_.getDomainDescriptor()); 25 : 26 1 : IndexVector_t support(0); // the atoms used for the representation 27 : 28 1 : index_t i = 0; 29 3 : while (i < iterations && eval(x) > epsilon_) { 30 2 : auto residual = dict_.apply(x) - signal_; 31 2 : index_t k = mostCorrelatedAtom(dict_, residual); 32 : 33 2 : support.conservativeResize(support.size() + 1); 34 2 : support[support.size() - 1] = k; 35 2 : Dictionary<data_t> purgedDict = dict_.getSupportedDictionary(support); 36 : 37 2 : CGLS<data_t> cgSolver(purgedDict, signal_); 38 2 : const auto wlsSolution = cgSolver.solve(10); 39 : 40 : // wlsSolution has only non-zero coefficients, copy those to the full solution with zero 41 : // coefficients 42 2 : index_t j = 0; 43 3 : for (const auto& atomIndex : support) { 44 3 : x[atomIndex] = wlsSolution[j]; 45 3 : ++j; 46 3 : } 47 : 48 2 : ++i; 49 2 : } 50 : 51 1 : return x; 52 1 : } 53 : 54 : template <typename data_t> 55 : index_t OrthogonalMatchingPursuit<data_t>::mostCorrelatedAtom( 56 : const Dictionary<data_t>& dict, const DataContainer<data_t>& evaluatedResidual) 57 2 : { 58 : // for this to work atom has to be L2-normalized 59 2 : data_t maxCorrelation = 0; 60 2 : index_t argMaxCorrelation = 0; 61 : 62 8 : for (index_t j = 0; j < dict.getNumberOfAtoms(); ++j) { 63 6 : const auto& atom = dict.getAtom(j); 64 6 : data_t correlation_j = std::abs(atom.dot(evaluatedResidual)); 65 : 66 6 : if (correlation_j > maxCorrelation) { 67 3 : maxCorrelation = correlation_j; 68 3 : argMaxCorrelation = j; 69 3 : } 70 6 : } 71 : 72 2 : return argMaxCorrelation; 73 2 : } 74 : 75 : template <typename data_t> 76 : OrthogonalMatchingPursuit<data_t>* OrthogonalMatchingPursuit<data_t>::cloneImpl() const 77 1 : { 78 1 : return new OrthogonalMatchingPursuit(dict_, signal_, epsilon_); 79 1 : } 80 : 81 : template <typename data_t> 82 : bool OrthogonalMatchingPursuit<data_t>::isEqual(const Solver<data_t>& other) const 83 1 : { 84 1 : auto otherOMP = downcast_safe<OrthogonalMatchingPursuit>(&other); 85 1 : if (!otherOMP) 86 0 : return false; 87 : 88 1 : if (epsilon_ != otherOMP->epsilon_) 89 0 : return false; 90 : 91 1 : return true; 92 1 : } 93 : 94 : // ------------------------------------------ 95 : // explicit template instantiation 96 : template class OrthogonalMatchingPursuit<float>; 97 : template class OrthogonalMatchingPursuit<double>; 98 : 99 : } // namespace elsa