LCOV - code coverage report
Current view: top level - elsa/solvers - PGD.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 48 88 54.5 %
Date: 2024-05-16 04:22:26 Functions: 12 18 66.7 %

          Line data    Source code
       1             : #include "PGD.h"
       2             : #include "DataContainer.h"
       3             : #include "Functional.h"
       4             : #include "LeastSquares.h"
       5             : #include "LinearOperator.h"
       6             : #include "LinearResidual.h"
       7             : #include "ProximalL1.h"
       8             : #include "Solver.h"
       9             : #include "TypeCasts.hpp"
      10             : #include "Logger.h"
      11             : #include "PowerIterations.h"
      12             : 
      13             : #include "WeightedLeastSquares.h"
      14             : #include "spdlog/stopwatch.h"
      15             : 
      16             : namespace elsa
      17             : {
      18             :     template <typename data_t>
      19             :     PGD<data_t>::PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      20             :                      const Functional<data_t>& h, std::optional<data_t> mu, data_t epsilon)
      21             :         : g_(LeastSquares<data_t>(A, b).clone()), h_(h.clone()), epsilon_(epsilon)
      22           4 :     {
      23           4 :         if (!h.isProxFriendly()) {
      24           0 :             throw Error("PGD: h must be prox friendly");
      25           0 :         }
      26             : 
      27           4 :         if (mu.has_value()) {
      28           4 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
      29           4 :         } else {
      30           0 :             Logger::get("PGD")->info("Computing Lipschitz constant for least squares...");
      31             :             // Chose it a little larger, to be safe
      32           0 :             auto L = 1.05 * powerIterations(adjoint(A) * A);
      33           0 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
      34           0 :             Logger::get("PGD")->info("Step length chosen to be: {}", 1 / L);
      35           0 :         }
      36           4 :     }
      37             : 
      38             :     template <typename data_t>
      39             :     PGD<data_t>::PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      40             :                      const DataContainer<data_t>& W, const Functional<data_t>& h,
      41             :                      std::optional<data_t> mu, data_t epsilon)
      42             :         : g_(WeightedLeastSquares<data_t>(A, b, W).clone()), h_(h.clone()), epsilon_(epsilon)
      43           2 :     {
      44           2 :         if (!h.isProxFriendly()) {
      45           0 :             throw Error("APGD: h must be prox friendly");
      46           0 :         }
      47             : 
      48           2 :         if (mu.has_value()) {
      49           2 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
      50           2 :         } else {
      51           0 :             Logger::get("PGD")->info("Computing Lipschitz constant for least squares...");
      52             :             // Chose it a little larger, to be safe
      53           0 :             auto L = 1.05 * powerIterations(adjoint(A) * A);
      54           0 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
      55           0 :             Logger::get("PGD")->info("Step length chosen to be: {}", 1 / L);
      56           0 :         }
      57           2 :     }
      58             : 
      59             :     template <typename data_t>
      60             :     PGD<data_t>::PGD(const Functional<data_t>& g, const Functional<data_t>& h, data_t mu,
      61             :                      data_t epsilon)
      62             :         : g_(g.clone()),
      63             :           h_(h.clone()),
      64             :           epsilon_(epsilon),
      65             :           lineSearchMethod_(FixedStepSize<data_t>(*g_, mu).clone())
      66           0 :     {
      67           0 :         if (!h.isProxFriendly()) {
      68           0 :             throw Error("PGD: h must be prox friendly");
      69           0 :         }
      70             : 
      71           0 :         if (!g.isDifferentiable()) {
      72           0 :             throw Error("PGD: g must be differentiable");
      73           0 :         }
      74           0 :     }
      75             : 
      76             :     template <typename data_t>
      77             :     PGD<data_t>::PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      78             :                      const Functional<data_t>& h, const LineSearchMethod<data_t>& lineSearchMethod,
      79             :                      data_t epsilon)
      80             :         : g_(LeastSquares<data_t>(A, b).clone()),
      81             :           h_(h.clone()),
      82             :           epsilon_(epsilon),
      83             :           lineSearchMethod_(lineSearchMethod.clone())
      84           0 :     {
      85           0 :         if (!h.isProxFriendly()) {
      86           0 :             throw Error("PGD: h must be prox friendly");
      87           0 :         }
      88           0 :     }
      89             : 
      90             :     template <typename data_t>
      91             :     PGD<data_t>::PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      92             :                      const DataContainer<data_t>& W, const Functional<data_t>& h,
      93             :                      const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
      94             :         : g_(WeightedLeastSquares<data_t>(A, b, W).clone()),
      95             :           h_(h.clone()),
      96             :           epsilon_(epsilon),
      97             :           lineSearchMethod_(lineSearchMethod.clone())
      98           0 :     {
      99           0 :         if (!h.isProxFriendly()) {
     100           0 :             throw Error("APGD: h must be prox friendly");
     101           0 :         }
     102           0 :     }
     103             : 
     104             :     template <typename data_t>
     105             :     PGD<data_t>::PGD(const Functional<data_t>& g, const Functional<data_t>& h,
     106             :                      const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
     107             :         : g_(g.clone()),
     108             :           h_(h.clone()),
     109             :           epsilon_(epsilon),
     110             :           lineSearchMethod_(lineSearchMethod.clone())
     111           6 :     {
     112           6 :         if (!h.isProxFriendly()) {
     113           0 :             throw Error("PGD: h must be prox friendly");
     114           0 :         }
     115             : 
     116           6 :         if (!g.isDifferentiable()) {
     117           0 :             throw Error("PGD: g must be differentiable");
     118           0 :         }
     119           6 :     }
     120             : 
     121             :     template <typename data_t>
     122             :     auto PGD<data_t>::solve(index_t iterations, std::optional<DataContainer<data_t>> x0)
     123             :         -> DataContainer<data_t>
     124           6 :     {
     125           6 :         spdlog::stopwatch aggregate_time;
     126             : 
     127           6 :         auto x = extract_or(x0, g_->getDomainDescriptor());
     128           6 :         auto grad = emptylike(x);
     129           6 :         auto y = emptylike(x);
     130             : 
     131           6 :         Logger::get("PGD")->info("| {:^6} | {:^12} | {:^12} | {:^9} |", "iter", "g", "gradient",
     132           6 :                                  "elapsed");
     133             : 
     134          12 :         for (index_t iter = 0; iter < iterations; ++iter) {
     135          12 :             g_->getGradient(x, grad);
     136          12 :             auto mu = lineSearchMethod_->solve(x, -grad);
     137             : 
     138             :             // y = x - mu_ * grad
     139          12 :             lincomb(1, x, -mu, grad, y);
     140             : 
     141             :             // apply proximal
     142          12 :             x = h_->proximal(y, mu);
     143             : 
     144          12 :             if (grad.squaredL2Norm() <= epsilon_) {
     145           6 :                 Logger::get("PGD")->info("SUCCESS: Reached convergence at {}/{} iteration",
     146           6 :                                          iter + 1, iterations);
     147           6 :                 return x;
     148           6 :             }
     149             : 
     150           6 :             Logger::get("PGD")->info("| {:>6} | {:>12.3} | {:>12.3} | {:>8.3}s |", iter,
     151           6 :                                      g_->evaluate(x), grad.squaredL2Norm(), aggregate_time);
     152           6 :         }
     153             : 
     154           6 :         Logger::get("PGD")->warn("Failed to reach convergence at {} iterations", iterations);
     155             : 
     156           0 :         return x;
     157           6 :     }
     158             : 
     159             :     template <typename data_t>
     160             :     auto PGD<data_t>::cloneImpl() const -> PGD<data_t>*
     161           6 :     {
     162           6 :         return new PGD<data_t>(*g_, *h_, *lineSearchMethod_, epsilon_);
     163           6 :     }
     164             : 
     165             :     template <typename data_t>
     166             :     auto PGD<data_t>::isEqual(const Solver<data_t>& other) const -> bool
     167           6 :     {
     168           6 :         auto otherPGD = downcast_safe<PGD>(&other);
     169           6 :         if (!otherPGD)
     170           0 :             return false;
     171             : 
     172           6 :         if (not lineSearchMethod_->isEqual(*(otherPGD->lineSearchMethod_)))
     173           0 :             return false;
     174             : 
     175           6 :         if (epsilon_ != otherPGD->epsilon_)
     176           0 :             return false;
     177             : 
     178           6 :         return true;
     179           6 :     }
     180             : 
     181             :     // ------------------------------------------
     182             :     // explicit template instantiation
     183             :     template class PGD<float>;
     184             :     template class PGD<double>;
     185             : } // namespace elsa

Generated by: LCOV version 1.14