LCOV - code coverage report
Current view: top level - solvers - ADMM.h (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 81 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 20 0.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "Solver.h"
       4             : #include "ProximityOperator.h"
       5             : #include "SplittingProblem.h"
       6             : #include "L0PseudoNorm.h"
       7             : #include "L1Norm.h"
       8             : #include "L2NormPow2.h"
       9             : #include "LinearResidual.h"
      10             : #include "Logger.h"
      11             : 
      12             : namespace elsa
      13             : {
      14             :     /**
      15             :      * @brief Class representing an Alternating Direction Method of Multipliers solver
      16             :      *
      17             :      * @author Andi Braimllari - initial code
      18             :      *
      19             :      * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
      20             :      * @tparam XSolver Solver type handling the x update
      21             :      * @tparam ZSolver ProximityOperator type handling the z update
      22             :      *
      23             :      * ADMM solves minimization splitting problems of the form
      24             :      * @f$ (x, z) \mapsto f(x) + g(z) @f$ such that @f$ Ax + Bz = c @f$.
      25             :      * Commonly regularized optimization problems can be rewritten in such a form by using variable
      26             :      * splitting.
      27             :      *
      28             :      * ADMM can be expressed in the following scaled form
      29             :      *
      30             :      *  - @f$ x_{k+1} = argmin_{x}(f(x) + (\rho/2) ·\| Ax + Bz_{k} - c + u_{k}\|^2_2) @f$
      31             :      *  - @f$ z_{k+1} = argmin_{z}(g(z) + (\rho/2) ·\| Ax_{k+1} + Bz - c + u_{k}\|^2_2) @f$
      32             :      *  - @f$ u_{k+1} = u_{k} + Ax_{k+1} + Bz_{k+1} - c @f$
      33             :      *
      34             :      * References:
      35             :      * https://stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf
      36             :      */
      37             :     template <template <typename> class XSolver, template <typename> class ZSolver,
      38             :               typename data_t = real_t>
      39             :     class ADMM : public Solver<data_t>
      40             :     {
      41             :     public:
      42             :         /// Scalar alias
      43             :         using Scalar = typename Solver<data_t>::Scalar;
      44             : 
      45           0 :         ADMM(const SplittingProblem<data_t>& splittingProblem) : Solver<data_t>(splittingProblem)
      46             :         {
      47             :             static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
      48             :                           "ADMM: XSolver must extend Solver");
      49             : 
      50             :             static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
      51             :                           "ADMM: ZSolver must extend ProximityOperator");
      52           0 :         }
      53             : 
      54             :         ADMM(const SplittingProblem<data_t>& splittingProblem, index_t defaultXSolverIterations)
      55             :             : Solver<data_t>(splittingProblem), _defaultXSolverIterations{defaultXSolverIterations}
      56             :         {
      57             :             static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
      58             :                           "ADMM: XSolver must extend Solver");
      59             : 
      60             :             static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
      61             :                           "ADMM: ZSolver must extend ProximityOperator");
      62             :         }
      63             : 
      64             :         ADMM(const SplittingProblem<data_t>& splittingProblem, index_t defaultXSolverIterations,
      65             :              data_t epsilonAbs, data_t epsilonRel)
      66             :             : Solver<data_t>(splittingProblem),
      67             :               _defaultXSolverIterations{defaultXSolverIterations},
      68             :               _epsilonAbs{epsilonAbs},
      69             :               _epsilonRel{epsilonRel}
      70             :         {
      71             :             static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
      72             :                           "ADMM: XSolver must extend Solver");
      73             : 
      74             :             static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
      75             :                           "ADMM: ZSolver must extend ProximityOperator");
      76             :         }
      77             : 
      78             :         /// default destructor
      79           0 :         ~ADMM() override = default;
      80             : 
      81           0 :         auto solveImpl(index_t iterations) -> DataContainer<data_t>& override
      82             :         {
      83           0 :             if (iterations == 0)
      84           0 :                 iterations = _defaultIterations;
      85             : 
      86           0 :             auto& splittingProblem = downcast<SplittingProblem<data_t>>(*_problem);
      87             : 
      88           0 :             const auto& f = splittingProblem.getF();
      89           0 :             const auto& g = splittingProblem.getG();
      90             : 
      91           0 :             const auto& dataTerm = f;
      92             : 
      93           0 :             if (!is<L2NormPow2<data_t>>(dataTerm)) {
      94           0 :                 throw std::invalid_argument(
      95             :                     "ADMM::solveImpl: supported data term only of type L2NormPow2");
      96             :             }
      97             : 
      98             :             // Safe as long as only LinearResidual exists
      99           0 :             const auto& dataTermResidual = downcast<LinearResidual<data_t>>(f.getResidual());
     100             : 
     101           0 :             if (g.size() != 1) {
     102           0 :                 throw std::invalid_argument(
     103             :                     "ADMM::solveImpl: supported number of regularization terms is 1");
     104             :             }
     105             : 
     106           0 :             data_t regWeight = g[0].getWeight();
     107           0 :             Functional<data_t>& regularizationTerm = g[0].getFunctional();
     108             : 
     109           0 :             if (!is<L0PseudoNorm<data_t>>(regularizationTerm)
     110           0 :                 && !is<L1Norm<data_t>>(regularizationTerm)) {
     111           0 :                 throw std::invalid_argument("ADMM::solveImpl: supported regularization terms are "
     112             :                                             "of type L0PseudoNorm or L1Norm");
     113             :             }
     114             : 
     115           0 :             const auto& constraint = splittingProblem.getConstraint();
     116           0 :             const auto& A = constraint.getOperatorA();
     117           0 :             const auto& B = constraint.getOperatorB();
     118           0 :             const auto& c = constraint.getDataVectorC();
     119             : 
     120           0 :             DataContainer<data_t> x(A.getRangeDescriptor());
     121           0 :             x = 0;
     122             : 
     123           0 :             DataContainer<data_t> z(B.getRangeDescriptor());
     124           0 :             z = 0;
     125             : 
     126           0 :             DataContainer<data_t> u(c.getDataDescriptor());
     127           0 :             u = 0;
     128             : 
     129           0 :             Logger::get("ADMM")->info("{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}",
     130             :                                       "iteration", "xL2NormSq", "zL2NormSq", "uL2NormSq",
     131             :                                       "rkL2Norm", "skL2Norm");
     132             : 
     133           0 :             for (index_t iter = 0; iter < iterations; ++iter) {
     134           0 :                 LinearResidual<data_t> xLinearResidual(A, c - B.apply(z) - u);
     135           0 :                 RegularizationTerm xRegTerm(_rho / 2, L2NormPow2<data_t>(xLinearResidual));
     136           0 :                 Problem<data_t> xUpdateProblem(dataTerm, xRegTerm, x);
     137             : 
     138           0 :                 XSolver<data_t> xSolver(xUpdateProblem);
     139           0 :                 x = xSolver.solve(_defaultXSolverIterations);
     140             : 
     141           0 :                 DataContainer<data_t> rk = x;
     142           0 :                 DataContainer<data_t> zPrev = z;
     143           0 :                 data_t Axnorm = x.l2Norm();
     144             : 
     145             :                 /// For future reference, below is listed the problem to be solved by the z update
     146             :                 /// solver. Refer to the documentation of ADMM for further details.
     147             :                 // LinearResidual<data_t> zLinearResidual(B, c - A.apply(x) - u);
     148             :                 // RegularizationTerm zRegTerm(_rho / 2, L2NormPow2<data_t>(zLinearResidual));
     149             :                 // Problem<data_t> zUpdateProblem(regularizationTerm, zRegTerm, z);
     150             : 
     151           0 :                 ZSolver<data_t> zProxOp(A.getRangeDescriptor());
     152           0 :                 z = zProxOp.apply(x + u, geometry::Threshold{regWeight / _rho});
     153             : 
     154           0 :                 rk -= z;
     155           0 :                 DataContainer<data_t> sk = zPrev - z;
     156           0 :                 sk *= _rho;
     157             : 
     158           0 :                 u += A.apply(x) + B.apply(z) - c;
     159             : 
     160           0 :                 DataContainer<data_t> Atu = u;
     161           0 :                 Atu *= _rho;
     162           0 :                 data_t rkL2Norm = rk.l2Norm();
     163           0 :                 data_t skL2Norm = sk.l2Norm();
     164             : 
     165           0 :                 Logger::get("ADMM")->info("{:<19}| {:<19}| {:<19}| {:<19}| {:<19}| {:<19}", iter,
     166           0 :                                           x.squaredL2Norm(), z.squaredL2Norm(), u.squaredL2Norm(),
     167             :                                           rkL2Norm, skL2Norm);
     168             : 
     169             :                 /// variables for the stopping criteria
     170           0 :                 const data_t cL2Norm = !dataTermResidual.hasDataVector()
     171           0 :                                            ? static_cast<data_t>(0.0)
     172           0 :                                            : dataTermResidual.getDataVector().l2Norm();
     173           0 :                 const data_t epsRelMax =
     174           0 :                     _epsilonRel * std::max(std::max(Axnorm, z.l2Norm()), cL2Norm);
     175           0 :                 const auto epsilonPri = (std::sqrt(rk.getSize()) * _epsilonAbs) + epsRelMax;
     176             : 
     177           0 :                 const data_t epsRelL2Norm = _epsilonRel * Atu.l2Norm();
     178           0 :                 const auto epsilonDual = (std::sqrt(sk.getSize()) * _epsilonAbs) + epsRelL2Norm;
     179             : 
     180           0 :                 if (rkL2Norm <= epsilonPri && skL2Norm <= epsilonDual) {
     181           0 :                     Logger::get("ADMM")->info("SUCCESS: Reached convergence at {}/{} iterations ",
     182             :                                               iter, iterations);
     183             : 
     184           0 :                     getCurrentSolution() = x;
     185           0 :                     return getCurrentSolution();
     186             :                 }
     187             : 
     188             :                 /// varying penalty parameter
     189           0 :                 if (std::abs(_tauIncr - static_cast<data_t>(1.0))
     190           0 :                         > std::numeric_limits<data_t>::epsilon()
     191           0 :                     || std::abs(_tauDecr - static_cast<data_t>(1.0))
     192           0 :                            > std::numeric_limits<data_t>::epsilon()) {
     193           0 :                     if (rkL2Norm > _mu * skL2Norm) {
     194           0 :                         _rho *= _tauIncr;
     195           0 :                         u /= _tauIncr;
     196           0 :                     } else if (skL2Norm > _mu * rkL2Norm) {
     197           0 :                         _rho /= _tauDecr;
     198           0 :                         u *= _tauDecr;
     199             :                     }
     200             :                 }
     201             :             }
     202             : 
     203           0 :             Logger::get("ADMM")->warn("Failed to reach convergence at {} iterations", iterations);
     204             : 
     205           0 :             getCurrentSolution() = x;
     206           0 :             return getCurrentSolution();
     207           0 :         }
     208             : 
     209             :         /// lift the base class method getCurrentSolution
     210             :         using Solver<data_t>::getCurrentSolution;
     211             : 
     212             :     protected:
     213             :         /// implement the polymorphic clone operation
     214           0 :         auto cloneImpl() const -> ADMM<XSolver, ZSolver, data_t>* override
     215             :         {
     216           0 :             return new ADMM<XSolver, ZSolver, data_t>(
     217           0 :                 downcast<SplittingProblem<data_t>>(*_problem));
     218             :         }
     219             : 
     220             :     private:
     221             :         /// lift the base class variable _problem
     222             :         using Solver<data_t>::_problem;
     223             : 
     224             :         /// the default number of iterations for ADMM
     225             :         index_t _defaultIterations{100};
     226             : 
     227             :         /// the default number of iterations for the XSolver
     228             :         index_t _defaultXSolverIterations{5};
     229             : 
     230             :         /// @f$ \rho @f$ from the problem definition
     231             :         data_t _rho{1};
     232             : 
     233             :         /// variables for varying penalty parameter @f$ \rho @f$
     234             :         data_t _mu{10};
     235             :         data_t _tauIncr{2};
     236             :         data_t _tauDecr{2};
     237             : 
     238             :         /// variables for the stopping criteria
     239             :         data_t _epsilonAbs{1e-5f};
     240             :         data_t _epsilonRel{1e-5f};
     241             :     };
     242             : } // namespace elsa

Generated by: LCOV version 1.14