LCOV - code coverage report
Current view: top level - elsa/solvers - APGD.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 60 102 58.8 %
Date: 2024-12-21 07:37:52 Functions: 20 26 76.9 %

          Line data    Source code
       1             : #include "APGD.h"
       2             : #include "DataContainer.h"
       3             : #include "Error.h"
       4             : #include "Functional.h"
       5             : #include "LeastSquares.h"
       6             : #include "LinearOperator.h"
       7             : #include "LinearResidual.h"
       8             : #include "ProximalL1.h"
       9             : #include "TypeCasts.hpp"
      10             : #include "Logger.h"
      11             : #include "PowerIterations.h"
      12             : 
      13             : #include "WeightedLeastSquares.h"
      14             : #include "spdlog/stopwatch.h"
      15             : 
      16             : namespace elsa
      17             : {
      18             :     template <typename data_t>
      19             :     APGD<data_t>::APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      20             :                        const Functional<data_t>& h, std::optional<data_t> mu, data_t epsilon)
      21             :         : g_(LeastSquares<data_t>(A, b).clone()),
      22             :           h_(h.clone()),
      23             :           xPrev_(empty<data_t>(g_->getDomainDescriptor())),
      24             :           y_(empty<data_t>(g_->getDomainDescriptor())),
      25             :           z_(empty<data_t>(g_->getDomainDescriptor())),
      26             :           grad_(empty<data_t>(g_->getDomainDescriptor())),
      27             :           epsilon_(epsilon)
      28           4 :     {
      29           4 :         if (!h.isProxFriendly()) {
      30           0 :             throw Error("APGD: h must be prox friendly");
      31           0 :         }
      32             : 
      33           4 :         if (mu.has_value()) {
      34           4 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
      35           4 :         } else {
      36           0 :             Logger::get("APGD")->info("Computing Lipschitz constant for least squares...");
      37             :             // Chose it a little larger, to be safe
      38           0 :             auto L = 1.05 * powerIterations(adjoint(A) * A);
      39           0 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
      40           0 :             Logger::get("APGD")->info("Step length chosen to be: {}", 1 / L);
      41           0 :         }
      42             : 
      43           4 :         this->name_ = "APGD";
      44           4 :     }
      45             : 
      46             :     template <typename data_t>
      47             :     APGD<data_t>::APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      48             :                        const DataContainer<data_t>& W, const Functional<data_t>& h,
      49             :                        std::optional<data_t> mu, data_t epsilon)
      50             :         : g_(WeightedLeastSquares<data_t>(A, b, W).clone()),
      51             :           h_(h.clone()),
      52             :           xPrev_(empty<data_t>(g_->getDomainDescriptor())),
      53             :           y_(empty<data_t>(g_->getDomainDescriptor())),
      54             :           z_(empty<data_t>(g_->getDomainDescriptor())),
      55             :           grad_(empty<data_t>(g_->getDomainDescriptor())),
      56             :           epsilon_(epsilon)
      57           2 :     {
      58           2 :         if (!h.isProxFriendly()) {
      59           0 :             throw Error("APGD: h must be prox friendly");
      60           0 :         }
      61             : 
      62           2 :         if (mu.has_value()) {
      63           2 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
      64           2 :         } else {
      65           0 :             Logger::get("APGD")->info("Computing Lipschitz constant for least squares...");
      66             :             // Chose it a little larger, to be safe
      67           0 :             auto L = 1.05 * powerIterations(adjoint(A) * A);
      68           0 :             lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
      69           0 :             Logger::get("APGD")->info("Step length chosen to be: {}", 1 / L);
      70           0 :         }
      71             : 
      72           2 :         this->name_ = "APGD";
      73           2 :     }
      74             : 
      75             :     template <typename data_t>
      76             :     APGD<data_t>::APGD(const Functional<data_t>& g, const Functional<data_t>& h, data_t mu,
      77             :                        data_t epsilon)
      78             :         : g_(g.clone()),
      79             :           h_(h.clone()),
      80             :           xPrev_(empty<data_t>(g_->getDomainDescriptor())),
      81             :           y_(empty<data_t>(g_->getDomainDescriptor())),
      82             :           z_(empty<data_t>(g_->getDomainDescriptor())),
      83             :           grad_(empty<data_t>(g_->getDomainDescriptor())),
      84             :           lineSearchMethod_(FixedStepSize<data_t>(*g_, mu).clone()),
      85             :           epsilon_(epsilon)
      86           0 :     {
      87           0 :         if (!h.isProxFriendly()) {
      88           0 :             throw Error("APGD: h must be prox friendly");
      89           0 :         }
      90             : 
      91           0 :         if (!g.isDifferentiable()) {
      92           0 :             throw Error("APGD: g must be differentiable");
      93           0 :         }
      94             : 
      95           0 :         this->name_ = "APGD";
      96           0 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     APGD<data_t>::APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
     100             :                        const Functional<data_t>& h,
     101             :                        const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
     102             :         : g_(LeastSquares<data_t>(A, b).clone()),
     103             :           h_(h.clone()),
     104             :           xPrev_(empty<data_t>(g_->getDomainDescriptor())),
     105             :           y_(empty<data_t>(g_->getDomainDescriptor())),
     106             :           z_(empty<data_t>(g_->getDomainDescriptor())),
     107             :           grad_(empty<data_t>(g_->getDomainDescriptor())),
     108             :           lineSearchMethod_(lineSearchMethod.clone()),
     109             :           epsilon_(epsilon)
     110           0 :     {
     111           0 :         if (!h.isProxFriendly()) {
     112           0 :             throw Error("APGD: h must be prox friendly");
     113           0 :         }
     114             : 
     115           0 :         this->name_ = "APGD";
     116           0 :     }
     117             : 
     118             :     template <typename data_t>
     119             :     APGD<data_t>::APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
     120             :                        const DataContainer<data_t>& W, const Functional<data_t>& h,
     121             :                        const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
     122             :         : g_(WeightedLeastSquares<data_t>(A, b, W).clone()),
     123             :           h_(h.clone()),
     124             :           xPrev_(empty<data_t>(g_->getDomainDescriptor())),
     125             :           y_(empty<data_t>(g_->getDomainDescriptor())),
     126             :           z_(empty<data_t>(g_->getDomainDescriptor())),
     127             :           grad_(empty<data_t>(g_->getDomainDescriptor())),
     128             :           lineSearchMethod_(lineSearchMethod.clone()),
     129             :           epsilon_(epsilon)
     130           0 :     {
     131           0 :         if (!h.isProxFriendly()) {
     132           0 :             throw Error("APGD: h must be prox friendly");
     133           0 :         }
     134             : 
     135           0 :         this->name_ = "APGD";
     136           0 :     }
     137             : 
     138             :     template <typename data_t>
     139             :     APGD<data_t>::APGD(const Functional<data_t>& g, const Functional<data_t>& h,
     140             :                        const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
     141             :         : g_(g.clone()),
     142             :           h_(h.clone()),
     143             :           xPrev_(empty<data_t>(g_->getDomainDescriptor())),
     144             :           y_(empty<data_t>(g_->getDomainDescriptor())),
     145             :           z_(empty<data_t>(g_->getDomainDescriptor())),
     146             :           grad_(empty<data_t>(g_->getDomainDescriptor())),
     147             :           lineSearchMethod_(lineSearchMethod.clone()),
     148             :           epsilon_(epsilon)
     149           6 :     {
     150           6 :         if (!h.isProxFriendly()) {
     151           0 :             throw Error("APGD: h must be prox friendly");
     152           0 :         }
     153             : 
     154           6 :         if (!g.isDifferentiable()) {
     155           0 :             throw Error("APGD: g must be differentiable");
     156           0 :         }
     157             : 
     158           6 :         this->name_ = "APGD";
     159           6 :     }
     160             : 
     161             :     template <typename data_t>
     162             :     DataContainer<data_t> APGD<data_t>::setup(std::optional<DataContainer<data_t>> x0)
     163           6 :     {
     164           6 :         auto x = extract_or(x0, g_->getDomainDescriptor());
     165             : 
     166           6 :         xPrev_ = x;
     167           6 :         y_ = x;
     168           6 :         z_ = x;
     169             : 
     170           6 :         tPrev_ = 1;
     171             : 
     172             :         // update gradient
     173           6 :         g_->getGradient(x, grad_);
     174             : 
     175             :         // setup done!
     176           6 :         this->configured_ = true;
     177             : 
     178           6 :         return x;
     179           6 :     }
     180             : 
     181             :     template <typename data_t>
     182             :     DataContainer<data_t> APGD<data_t>::step(DataContainer<data_t> x)
     183           6 :     {
     184           6 :         auto mu = lineSearchMethod_->solve(x, -grad_);
     185             :         // z = y - mu_ * grad
     186           6 :         lincomb(1, y_, -mu, grad_, z_);
     187             : 
     188             :         // x_{k+1} = prox_{mu * g}(y - mu * grad)
     189             :         // x = prox_.apply(z, mu_);
     190           6 :         x = h_->proximal(z_, mu);
     191             : 
     192             :         // t_{k+1} = \frac{\sqrt{1 + 4t_k^2} + 1}{2}
     193           6 :         data_t t = (1 + std::sqrt(1 + 4 * tPrev_ * tPrev_)) / 2;
     194             : 
     195             :         // y_{k+1} = x_k + \frac{t_{k-1} - 1}{t_k}(x_k - x_{k-1})
     196           6 :         lincomb(1, x, (tPrev_ - 1) / t, x - xPrev_, y_); // 1 temporary
     197             : 
     198           6 :         xPrev_ = x;
     199           6 :         tPrev_ = t;
     200             : 
     201             :         // update gradient last
     202           6 :         g_->getGradient(x, grad_);
     203             : 
     204           6 :         return x;
     205           6 :     }
     206             : 
     207             :     template <typename data_t>
     208             :     bool APGD<data_t>::shouldStop() const
     209          12 :     {
     210          12 :         return grad_.squaredL2Norm() <= epsilon_;
     211          12 :     }
     212             : 
     213             :     template <typename data_t>
     214             :     std::string APGD<data_t>::formatHeader() const
     215           6 :     {
     216           6 :         return fmt::format("{:^12} | {:^12}", "objective", "gradient");
     217           6 :     }
     218             : 
     219             :     template <typename data_t>
     220             :     std::string APGD<data_t>::formatStep(const DataContainer<data_t>& x) const
     221           6 :     {
     222           6 :         return fmt::format("{:>12.3} | {:>12.3}", g_->evaluate(x) + h_->evaluate(x),
     223           6 :                            grad_.squaredL2Norm());
     224           6 :     }
     225             : 
     226             :     template <typename data_t>
     227             :     auto APGD<data_t>::cloneImpl() const -> APGD<data_t>*
     228           6 :     {
     229           6 :         return new APGD<data_t>(*g_, *h_, *lineSearchMethod_, epsilon_);
     230           6 :     }
     231             : 
     232             :     template <typename data_t>
     233             :     auto APGD<data_t>::isEqual(const Solver<data_t>& other) const -> bool
     234           6 :     {
     235           6 :         auto otherAPGD = downcast_safe<APGD>(&other);
     236           6 :         if (!otherAPGD)
     237           0 :             return false;
     238             : 
     239           6 :         if (not lineSearchMethod_->isEqual(*(otherAPGD->lineSearchMethod_)))
     240           0 :             return false;
     241             : 
     242           6 :         if (epsilon_ != otherAPGD->epsilon_)
     243           0 :             return false;
     244             : 
     245           6 :         return true;
     246           6 :     }
     247             : 
     248             :     // ------------------------------------------
     249             :     // explicit template instantiation
     250             :     template class APGD<float>;
     251             :     template class APGD<double>;
     252             : } // namespace elsa

Generated by: LCOV version 1.14