LCOV - code coverage report
Current view: top level - elsa/solvers - ADMM.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 81 108 75.0 %
Date: 2022-08-25 03:05:39 Functions: 12 20 60.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <stdexcept>
#include <type_traits>

#include "Solver.h"
#include "ProximityOperator.h"
#include "SplittingProblem.h"
#include "L0PseudoNorm.h"
#include "L1Norm.h"
#include "L2NormPow2.h"
#include "LinearResidual.h"
#include "Logger.h"
      13             : 
      14             : namespace elsa
      15             : {
      16             :     /**
      17             :      * @brief Class representing an Alternating Direction Method of Multipliers solver
      18             :      *
      19             :      * @author Andi Braimllari - initial code
      20             :      *
      21             :      * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
      22             :      * @tparam XSolver Solver type handling the x update
      23             :      * @tparam ZSolver ProximityOperator type handling the z update
      24             :      *
      25             :      * ADMM solves minimization splitting problems of the form
      26             :      * @f$ x \mapsto f(x) + g(z) @f$ such that @f$ Ax + Bz = c @f$.
      27             :      * Commonly regularized optimization problems can be rewritten in such a form by using variable
      28             :      * splitting.
      29             :      *
      30             :      * ADMM can be expressed in the following scaled form
      31             :      *
      32             :      *  - @f$ x_{k+1} = argmin_{x}(f(x) + (\rho/2) ·\| Ax + Bz_{k} - c + u_{k}\|^2_2) @f$
      33             :      *  - @f$ z_{k+1} = argmin_{z}(g(z) + (\rho/2) ·\| Ax_{k+1} + Bz - c + u_{k}\|^2_2) @f$
      34             :      *  - @f$ u_{k+1} = u_{k} + Ax_{k+1} + Bz_{k+1} - c @f$
      35             :      *
      36             :      * References:
      37             :      * https://stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf
      38             :      */
      39             :     template <template <typename> class XSolver, template <typename> class ZSolver,
      40             :               typename data_t = real_t>
      41             :     class ADMM : public Solver<data_t>
      42             :     {
      43             :     public:
      44             :         /// Scalar alias
      45             :         using Scalar = typename Solver<data_t>::Scalar;
      46             : 
      47             :         ADMM(const SplittingProblem<data_t>& splittingProblem)
      48             :             : Solver<data_t>(),
      49             :               _problem(static_cast<SplittingProblem<data_t>*>(splittingProblem.clone().release()))
      50           4 :         {
      51           4 :             static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
      52           4 :                           "ADMM: XSolver must extend Solver");
      53             : 
      54           4 :             static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
      55           4 :                           "ADMM: ZSolver must extend ProximityOperator");
      56           4 :         }
      57             : 
      58             :         ADMM(const SplittingProblem<data_t>& splittingProblem, index_t defaultXSolverIterations)
      59             :             : Solver<data_t>(),
      60             :               _problem(static_cast<SplittingProblem<data_t>*>(splittingProblem.clone().release())),
      61             :               _defaultXSolverIterations{defaultXSolverIterations}
      62             :         {
      63             :             static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
      64             :                           "ADMM: XSolver must extend Solver");
      65             : 
      66             :             static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
      67             :                           "ADMM: ZSolver must extend ProximityOperator");
      68             :         }
      69             : 
      70             :         ADMM(const SplittingProblem<data_t>& splittingProblem, index_t defaultXSolverIterations,
      71             :              data_t epsilonAbs, data_t epsilonRel)
      72             :             : Solver<data_t>(),
      73             :               _problem(splittingProblem),
      74             :               _defaultXSolverIterations{defaultXSolverIterations},
      75             :               _epsilonAbs{epsilonAbs},
      76             :               _epsilonRel{epsilonRel}
      77             :         {
      78             :             static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
      79             :                           "ADMM: XSolver must extend Solver");
      80             : 
      81             :             static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
      82             :                           "ADMM: ZSolver must extend ProximityOperator");
      83             :         }
      84             : 
      85             :         /// default destructor
      86           4 :         ~ADMM() override = default;
      87             : 
      88             :         auto solveImpl(index_t iterations) -> DataContainer<data_t>& override
      89           8 :         {
      90           8 :             if (iterations == 0)
      91           0 :                 iterations = _defaultIterations;
      92             : 
      93           8 :             const auto& f = _problem->getF();
      94           8 :             const auto& g = _problem->getG();
      95             : 
      96           8 :             const auto& dataTerm = f;
      97             : 
      98           8 :             if (!is<L2NormPow2<data_t>>(dataTerm)) {
      99           0 :                 throw std::invalid_argument(
     100           0 :                     "ADMM::solveImpl: supported data term only of type L2NormPow2");
     101           0 :             }
     102             : 
     103             :             // Safe as long as only LinearResidual exits
     104           8 :             const auto& dataTermResidual = downcast<LinearResidual<data_t>>(f.getResidual());
     105             : 
     106           8 :             if (g.size() != 1) {
     107           0 :                 throw std::invalid_argument(
     108           0 :                     "ADMM::solveImpl: supported number of regularization terms is 1");
     109           0 :             }
     110             : 
     111           8 :             data_t regWeight = g[0].getWeight();
     112           8 :             Functional<data_t>& regularizationTerm = g[0].getFunctional();
     113             : 
     114           8 :             if (!is<L0PseudoNorm<data_t>>(regularizationTerm)
     115           8 :                 && !is<L1Norm<data_t>>(regularizationTerm)) {
     116           0 :                 throw std::invalid_argument("ADMM::solveImpl: supported regularization terms are "
     117           0 :                                             "of type L0PseudoNorm or L1Norm");
     118           0 :             }
     119             : 
     120           8 :             const auto& constraint = _problem->getConstraint();
     121           8 :             const auto& A = constraint.getOperatorA();
     122           8 :             const auto& B = constraint.getOperatorB();
     123           8 :             const auto& c = constraint.getDataVectorC();
     124             : 
     125           8 :             DataContainer<data_t> x(A.getRangeDescriptor());
     126           8 :             x = 0;
     127             : 
     128           8 :             DataContainer<data_t> z(B.getRangeDescriptor());
     129           8 :             z = 0;
     130             : 
     131           8 :             DataContainer<data_t> u(c.getDataDescriptor());
     132           8 :             u = 0;
     133             : 
     134           8 :             Logger::get("ADMM")->info("{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}",
     135           8 :                                       "iteration", "xL2NormSq", "zL2NormSq", "uL2NormSq",
     136           8 :                                       "rkL2Norm", "skL2Norm");
     137             : 
     138          34 :             for (index_t iter = 0; iter < iterations; ++iter) {
     139          34 :                 LinearResidual<data_t> xLinearResidual(A, c - B.apply(z) - u);
     140          34 :                 RegularizationTerm xRegTerm(_rho / 2, L2NormPow2<data_t>(xLinearResidual));
     141          34 :                 Problem<data_t> xUpdateProblem(dataTerm, xRegTerm, x);
     142             : 
     143          34 :                 XSolver<data_t> xSolver(xUpdateProblem);
     144          34 :                 x = xSolver.solve(_defaultXSolverIterations);
     145             : 
     146          34 :                 DataContainer<data_t> rk = x;
     147          34 :                 DataContainer<data_t> zPrev = z;
     148          34 :                 data_t Axnorm = x.l2Norm();
     149             : 
     150             :                 /// For future reference, below is listed the problem to be solved by the z update
     151             :                 /// solver. Refer to the documentation of ADMM for further details.
     152             :                 // LinearResidual<data_t> zLinearResidual(B, c - A.apply(x) - u);
     153             :                 // RegularizationTerm zRegTerm(_rho / 2, L2NormPow2<data_t>(zLinearResidual));
     154             :                 // Problem<data_t> zUpdateProblem(regularizationTerm, zRegTerm, z);
     155             : 
     156          34 :                 ZSolver<data_t> zProxOp(A.getRangeDescriptor());
     157          34 :                 z = zProxOp.apply(x + u, geometry::Threshold{regWeight / _rho});
     158             : 
     159          34 :                 rk -= z;
     160          34 :                 DataContainer<data_t> sk = zPrev - z;
     161          34 :                 sk *= _rho;
     162             : 
     163          34 :                 u += A.apply(x) + B.apply(z) - c;
     164             : 
     165          34 :                 DataContainer<data_t> Atu = u;
     166          34 :                 Atu *= _rho;
     167          34 :                 data_t rkL2Norm = rk.l2Norm();
     168          34 :                 data_t skL2Norm = sk.l2Norm();
     169             : 
     170          34 :                 Logger::get("ADMM")->info("{:<19}| {:<19}| {:<19}| {:<19}| {:<19}| {:<19}", iter,
     171          34 :                                           x.squaredL2Norm(), z.squaredL2Norm(), u.squaredL2Norm(),
     172          34 :                                           rkL2Norm, skL2Norm);
     173             : 
     174             :                 /// variables for the stopping criteria
     175          34 :                 const data_t cL2Norm = !dataTermResidual.hasDataVector()
     176          34 :                                            ? static_cast<data_t>(0.0)
     177          34 :                                            : dataTermResidual.getDataVector().l2Norm();
     178          34 :                 const data_t epsRelMax =
     179          34 :                     _epsilonRel * std::max(std::max(Axnorm, z.l2Norm()), cL2Norm);
     180          34 :                 const auto epsilonPri = (std::sqrt(rk.getSize()) * _epsilonAbs) + epsRelMax;
     181             : 
     182          34 :                 const data_t epsRelL2Norm = _epsilonRel * Atu.l2Norm();
     183          34 :                 const auto epsilonDual = (std::sqrt(sk.getSize()) * _epsilonAbs) + epsRelL2Norm;
     184             : 
     185          34 :                 if (rkL2Norm <= epsilonPri && skL2Norm <= epsilonDual) {
     186           8 :                     Logger::get("ADMM")->info("SUCCESS: Reached convergence at {}/{} iterations ",
     187           8 :                                               iter, iterations);
     188             : 
     189           8 :                     _problem->getCurrentSolution() = x;
     190           8 :                     return _problem->getCurrentSolution();
     191           8 :                 }
     192             : 
     193             :                 /// varying penalty parameter
     194          26 :                 if (std::abs(_tauIncr - static_cast<data_t>(1.0))
     195          26 :                         > std::numeric_limits<data_t>::epsilon()
     196          26 :                     || std::abs(_tauDecr - static_cast<data_t>(1.0))
     197          26 :                            > std::numeric_limits<data_t>::epsilon()) {
     198          26 :                     if (rkL2Norm > _mu * skL2Norm) {
     199           0 :                         _rho *= _tauIncr;
     200           0 :                         u /= _tauIncr;
     201          26 :                     } else if (skL2Norm > _mu * rkL2Norm) {
     202          26 :                         _rho /= _tauDecr;
     203          26 :                         u *= _tauDecr;
     204          26 :                     }
     205          26 :                 }
     206          26 :             }
     207             : 
     208           8 :             Logger::get("ADMM")->warn("Failed to reach convergence at {} iterations", iterations);
     209             : 
     210           0 :             _problem->getCurrentSolution() = x;
     211           0 :             return _problem->getCurrentSolution();
     212           8 :         }
     213             : 
     214             :     protected:
     215             :         /// implement the polymorphic clone operation
     216             :         auto cloneImpl() const -> ADMM<XSolver, ZSolver, data_t>* override
     217           0 :         {
     218           0 :             return new ADMM<XSolver, ZSolver, data_t>(
     219           0 :                 downcast<SplittingProblem<data_t>>(*_problem));
     220           0 :         }
     221             : 
     222             :         bool isEqual(const Solver<data_t>& other) const
     223           0 :         {
     224           0 :             auto otherADMM = downcast_safe<ADMM>(&other);
     225           0 :             if (!otherADMM)
     226           0 :                 return false;
     227             : 
     228           0 :             if (_problem != otherADMM->_problem)
     229           0 :                 return false;
     230             : 
     231           0 :             return _rho == otherADMM->_rho && _mu == otherADMM->_mu
     232           0 :                    && _tauIncr == otherADMM->_tauIncr && _tauDecr == otherADMM->_tauDecr;
     233           0 :         }
     234             : 
     235             :     private:
     236             :         /// Splitting problem to solve
     237             :         /// TODO: Remove requirement of unique_ptr
     238             :         std::unique_ptr<SplittingProblem<data_t>> _problem;
     239             : 
     240             :         /// the default number of iterations for ADMM
     241             :         index_t _defaultIterations{100};
     242             : 
     243             :         /// the default number of iterations for the XSolver
     244             :         index_t _defaultXSolverIterations{5};
     245             : 
     246             :         /// @f$ \rho @f$ from the problem definition
     247             :         data_t _rho{1};
     248             : 
     249             :         /// variables for varying penalty parameter @f$ \rho @f$
     250             :         data_t _mu{10};
     251             :         data_t _tauIncr{2};
     252             :         data_t _tauDecr{2};
     253             : 
     254             :         /// variables for the stopping criteria
     255             :         data_t _epsilonAbs{1e-5f};
     256             :         data_t _epsilonRel{1e-5f};
     257             :     };
     258             : } // namespace elsa

Generated by: LCOV version 1.14