LCOV - code coverage report
Current view: top level - elsa/solvers - POGM.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 74 118 62.7 %
Date: 2024-05-16 04:22:26 Functions: 12 18 66.7 %

          Line data    Source code
       1             : #include "POGM.h"
       2             : #include "DataContainer.h"
       3             : #include "Error.h"
       4             : #include "Functional.h"
       5             : #include "LeastSquares.h"
       6             : #include "LinearOperator.h"
       7             : #include "LinearResidual.h"
       8             : #include "ProximalL1.h"
       9             : #include "TypeCasts.hpp"
      10             : #include "Logger.h"
      11             : #include "PowerIterations.h"
      12             : 
      13             : #include "WeightedLeastSquares.h"
      14             : #include "spdlog/stopwatch.h"
      15             : #include <cmath>
      16             : 
      17             : namespace elsa
      18             : {
      19             :     template <typename data_t>
      20             :     POGM<data_t>::POGM(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      21             :                        const Functional<data_t>& h, std::optional<data_t> mu, data_t epsilon)
      22             :         : g_(LeastSquares<data_t>(A, b).clone()), h_(h.clone()), epsilon_(epsilon)
      23           2 :     {
      24           2 :         if (!h.isProxFriendly()) {
      25           0 :             throw Error("POGM: h must be prox friendly");
      26           0 :         }
      27             : 
      28           2 :         if (mu.has_value()) {
      29           2 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
      30           2 :         } else {
      31           0 :             Logger::get("POGM")->info("Computing Lipschitz constant for least squares...");
      32             :             // Chose it a little larger, to be safe
      33           0 :             auto L = 1.05 * powerIterations(adjoint(A) * A);
      34           0 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
      35           0 :             Logger::get("POGM")->info("Step length chosen to be: {}", 1 / L);
      36           0 :         }
      37           2 :     }
      38             : 
      39             :     template <typename data_t>
      40             :     POGM<data_t>::POGM(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      41             :                        const DataContainer<data_t>& W, const Functional<data_t>& h,
      42             :                        std::optional<data_t> mu, data_t epsilon)
      43             :         : g_(WeightedLeastSquares<data_t>(A, b, W).clone()), h_(h.clone()), epsilon_(epsilon)
      44           2 :     {
      45           2 :         if (!h.isProxFriendly()) {
      46           0 :             throw Error("APGD: h must be prox friendly");
      47           0 :         }
      48             : 
      49           2 :         if (mu.has_value()) {
      50           2 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
      51           2 :         } else {
      52           0 :             Logger::get("POGM")->info("Computing Lipschitz constant for least squares...");
      53             :             // Chose it a little larger, to be safe
      54           0 :             auto L = 1.05 * powerIterations(adjoint(A) * A);
      55           0 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
      56           0 :             Logger::get("POGM")->info("Step length chosen to be: {}", 1 / L);
      57           0 :         }
      58           2 :     }
      59             : 
      60             :     template <typename data_t>
      61             :     POGM<data_t>::POGM(const Functional<data_t>& g, const Functional<data_t>& h, data_t mu,
      62             :                        data_t epsilon)
      63             :         : g_(g.clone()),
      64             :           h_(h.clone()),
      65             :           lineSearchMethod_(FixedStepSize<data_t>(*g_, mu).clone()),
      66             :           epsilon_(epsilon)
      67           0 :     {
      68           0 :         if (!h.isProxFriendly()) {
      69           0 :             throw Error("POGM: h must be prox friendly");
      70           0 :         }
      71             : 
      72           0 :         if (!g.isDifferentiable()) {
      73           0 :             throw Error("POGM: g must be differentiable");
      74           0 :         }
      75           0 :     }
      76             : 
      77             :     template <typename data_t>
      78             :     POGM<data_t>::POGM(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      79             :                        const Functional<data_t>& h,
      80             :                        const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
      81             :         : g_(LeastSquares<data_t>(A, b).clone()),
      82             :           h_(h.clone()),
      83             :           lineSearchMethod_(lineSearchMethod.clone()),
      84             :           epsilon_(epsilon)
      85           0 :     {
      86           0 :         if (!h.isProxFriendly()) {
      87           0 :             throw Error("POGM: h must be prox friendly");
      88           0 :         }
      89           0 :     }
      90             : 
      91             :     template <typename data_t>
      92             :     POGM<data_t>::POGM(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      93             :                        const DataContainer<data_t>& W, const Functional<data_t>& h,
      94             :                        const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
      95             :         : g_(WeightedLeastSquares<data_t>(A, b, W).clone()),
      96             :           h_(h.clone()),
      97             :           lineSearchMethod_(lineSearchMethod.clone()),
      98             :           epsilon_(epsilon)
      99           0 :     {
     100           0 :         if (!h.isProxFriendly()) {
     101           0 :             throw Error("APGD: h must be prox friendly");
     102           0 :         }
     103           0 :     }
     104             : 
     105             :     template <typename data_t>
     106             :     POGM<data_t>::POGM(const Functional<data_t>& g, const Functional<data_t>& h,
     107             :                        const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
     108             :         : g_(g.clone()),
     109             :           h_(h.clone()),
     110             :           lineSearchMethod_(lineSearchMethod.clone()),
     111             :           epsilon_(epsilon)
     112           4 :     {
     113           4 :         if (!h.isProxFriendly()) {
     114           0 :             throw Error("POGM: h must be prox friendly");
     115           0 :         }
     116             : 
     117           4 :         if (!g.isDifferentiable()) {
     118           0 :             throw Error("POGM: g must be differentiable");
     119           0 :         }
     120           4 :     }
     121             : 
     122             :     template <typename data_t>
     123             :     auto POGM<data_t>::solve(index_t iterations, std::optional<DataContainer<data_t>> x0)
     124             :         -> DataContainer<data_t>
     125           4 :     {
     126           4 :         spdlog::stopwatch aggregate_time;
     127             : 
     128           4 :         auto x = DataContainer<data_t>(g_->getDomainDescriptor());
     129           4 :         if (x0.has_value()) {
     130           0 :             x = *x0;
     131           4 :         } else {
     132           4 :             x = 0;
     133           4 :         }
     134             : 
     135           4 :         auto w = x;
     136           4 :         auto wPrev = x;
     137           4 :         auto z = x;
     138             : 
     139           4 :         data_t theta = 1;
     140           4 :         data_t thetaPrev = 1;
     141             : 
     142           4 :         data_t gamma = 1;
     143           4 :         data_t gammaPrev = 1;
     144             : 
     145           4 :         auto grad = DataContainer<data_t>(g_->getDomainDescriptor());
     146             : 
     147           4 :         Logger::get("POGM")->info("| {:^6} | {:^12} | {:^12} | {:^9} |", "iter", "objective",
     148           4 :                                   "gradient", "elapsed");
     149             : 
     150         204 :         for (index_t iter = 0; iter < iterations; ++iter) {
     151         200 :             if (iter != iterations - 1) {
     152             :                 // \frac{1}{2}(1 + \sqrt{4 \theta_{k-1}^2 + 1})
     153         196 :                 theta = 0.5 * (1 + std::sqrt(4 * std::pow(thetaPrev, 2) + 1));
     154         196 :             } else {
     155             :                 // \frac{1}{2}(1 + \sqrt{8 \theta_{k-1}^2 + 1})
     156           4 :                 theta = 0.5 * (1 + std::sqrt(8 * std::pow(thetaPrev, 2) + 1));
     157           4 :             }
     158             : 
     159         200 :             auto mu = lineSearchMethod_->solve(x, -grad);
     160             : 
     161         200 :             gamma = mu * ((2 * thetaPrev) + theta - 1) / theta;
     162             : 
     163             :             // Compute gradient
     164         200 :             g_->getGradient(x, grad);
     165             : 
     166             :             // w = x - mu_ * grad
     167         200 :             lincomb(1, x, -mu, grad, w);
     168             : 
     169             :             // POGM term: w_k + (\theta_{k-1} - 1) / (L * \gamma_{k-1} * \theta_k) (z_{k-1} -
     170             :             // x_{k-1}) Use the fact the our mu should be (close to) 1 / L. Start with this term to
     171             :             // reuse z.
     172         200 :             data_t weight3 = mu * ((thetaPrev - 1) / (gammaPrev * theta));
     173         200 :             lincomb(1, w, weight3, z, z);
     174         200 :             lincomb(1, z, -weight3, x, z);
     175             : 
     176             :             // Nesterov momentum: (\theta_{k-1} - 1) / \theta_k (w_k - w_{k-1})
     177         200 :             auto weight1 = (thetaPrev - 1) / theta;
     178         200 :             lincomb(1, z, weight1, w, z);
     179         200 :             lincomb(1, z, -weight1, wPrev, z);
     180             : 
     181             :             // OGM mementum term: (\theta_{k-1}  / \theta) (w_k - x_{k-1})
     182         200 :             data_t weight2 = thetaPrev / theta;
     183         200 :             lincomb(1, z, weight2, w, z);
     184         200 :             lincomb(1, z, -weight2, x, z);
     185             : 
     186             :             // x_{k+1} = prox_{gamma * g}(z)
     187         200 :             x = h_->proximal(z, gamma);
     188             : 
     189         200 :             wPrev = w;
     190             : 
     191         200 :             thetaPrev = theta;
     192         200 :             gammaPrev = gamma;
     193             : 
     194         200 :             if (grad.squaredL2Norm() <= epsilon_) {
     195           0 :                 Logger::get("POGM")->info("SUCCESS: Reached convergence at {}/{} iteration",
     196           0 :                                           iter + 1, iterations);
     197           0 :                 return x;
     198           0 :             }
     199             : 
     200         200 :             Logger::get("POGM")->info("| {:>6} | {:>12.3} | {:>12.3} | {:>8.3}s |", iter,
     201         200 :                                       g_->evaluate(x) + h_->evaluate(x), grad.squaredL2Norm(),
     202         200 :                                       aggregate_time);
     203         200 :         }
     204             : 
     205           4 :         Logger::get("POGM")->warn("Failed to reach convergence at {} iterations", iterations);
     206             : 
     207           4 :         return x;
     208           4 :     }
     209             : 
     210             :     template <typename data_t>
     211             :     auto POGM<data_t>::cloneImpl() const -> POGM<data_t>*
     212           4 :     {
     213           4 :         return new POGM<data_t>(*g_, *h_, *lineSearchMethod_, epsilon_);
     214           4 :     }
     215             : 
     216             :     template <typename data_t>
     217             :     auto POGM<data_t>::isEqual(const Solver<data_t>& other) const -> bool
     218           4 :     {
     219           4 :         auto otherPOGM = downcast_safe<POGM>(&other);
     220           4 :         if (!otherPOGM)
     221           0 :             return false;
     222             : 
     223           4 :         if (not lineSearchMethod_->isEqual(*(otherPOGM->lineSearchMethod_)))
     224           0 :             return false;
     225             : 
     226           4 :         if (epsilon_ != otherPOGM->epsilon_)
     227           0 :             return false;
     228             : 
     229           4 :         return true;
     230           4 :     }
     231             : 
     232             :     // ------------------------------------------
     233             :     // explicit template instantiation
     234             :     template class POGM<float>;
     235             :     template class POGM<double>;
     236             : } // namespace elsa

Generated by: LCOV version 1.14