Line data Source code
1 : #pragma once
2 :
3 : #include <memory>
4 :
5 : #include "Solver.h"
6 : #include "ProximityOperator.h"
7 : #include "SplittingProblem.h"
8 : #include "L0PseudoNorm.h"
9 : #include "L1Norm.h"
10 : #include "L2NormPow2.h"
11 : #include "LinearResidual.h"
12 : #include "Logger.h"
13 :
14 : namespace elsa
15 : {
16 : /**
17 : * @brief Class representing an Alternating Direction Method of Multipliers solver
18 : *
19 : * @author Andi Braimllari - initial code
20 : *
21 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
22 : * @tparam XSolver Solver type handling the x update
23 : * @tparam ZSolver ProximityOperator type handling the z update
24 : *
25 : * ADMM solves minimization splitting problems of the form
26 : * @f$ x \mapsto f(x) + g(z) @f$ such that @f$ Ax + Bz = c @f$.
27 : * Commonly regularized optimization problems can be rewritten in such a form by using variable
28 : * splitting.
29 : *
30 : * ADMM can be expressed in the following scaled form
31 : *
32 : * - @f$ x_{k+1} = argmin_{x}(f(x) + (\rho/2) ·\| Ax + Bz_{k} - c + u_{k}\|^2_2) @f$
33 : * - @f$ z_{k+1} = argmin_{z}(g(z) + (\rho/2) ·\| Ax_{k+1} + Bz - c + u_{k}\|^2_2) @f$
34 : * - @f$ u_{k+1} = u_{k} + Ax_{k+1} + Bz_{k+1} - c @f$
35 : *
36 : * References:
37 : * https://stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf
38 : */
39 : template <template <typename> class XSolver, template <typename> class ZSolver,
40 : typename data_t = real_t>
41 : class ADMM : public Solver<data_t>
42 : {
43 : public:
44 : /// Scalar alias
45 : using Scalar = typename Solver<data_t>::Scalar;
46 :
47 : ADMM(const SplittingProblem<data_t>& splittingProblem)
48 : : Solver<data_t>(),
49 : _problem(static_cast<SplittingProblem<data_t>*>(splittingProblem.clone().release()))
50 4 : {
51 4 : static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
52 4 : "ADMM: XSolver must extend Solver");
53 :
54 4 : static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
55 4 : "ADMM: ZSolver must extend ProximityOperator");
56 4 : }
57 :
58 : ADMM(const SplittingProblem<data_t>& splittingProblem, index_t defaultXSolverIterations)
59 : : Solver<data_t>(),
60 : _problem(static_cast<SplittingProblem<data_t>*>(splittingProblem.clone().release())),
61 : _defaultXSolverIterations{defaultXSolverIterations}
62 : {
63 : static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
64 : "ADMM: XSolver must extend Solver");
65 :
66 : static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
67 : "ADMM: ZSolver must extend ProximityOperator");
68 : }
69 :
70 : ADMM(const SplittingProblem<data_t>& splittingProblem, index_t defaultXSolverIterations,
71 : data_t epsilonAbs, data_t epsilonRel)
72 : : Solver<data_t>(),
73 : _problem(splittingProblem),
74 : _defaultXSolverIterations{defaultXSolverIterations},
75 : _epsilonAbs{epsilonAbs},
76 : _epsilonRel{epsilonRel}
77 : {
78 : static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
79 : "ADMM: XSolver must extend Solver");
80 :
81 : static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
82 : "ADMM: ZSolver must extend ProximityOperator");
83 : }
84 :
85 : /// default destructor
86 4 : ~ADMM() override = default;
87 :
88 : auto solveImpl(index_t iterations) -> DataContainer<data_t>& override
89 8 : {
90 8 : if (iterations == 0)
91 0 : iterations = _defaultIterations;
92 :
93 8 : const auto& f = _problem->getF();
94 8 : const auto& g = _problem->getG();
95 :
96 8 : const auto& dataTerm = f;
97 :
98 8 : if (!is<L2NormPow2<data_t>>(dataTerm)) {
99 0 : throw std::invalid_argument(
100 0 : "ADMM::solveImpl: supported data term only of type L2NormPow2");
101 0 : }
102 :
103 : // Safe as long as only LinearResidual exits
104 8 : const auto& dataTermResidual = downcast<LinearResidual<data_t>>(f.getResidual());
105 :
106 8 : if (g.size() != 1) {
107 0 : throw std::invalid_argument(
108 0 : "ADMM::solveImpl: supported number of regularization terms is 1");
109 0 : }
110 :
111 8 : data_t regWeight = g[0].getWeight();
112 8 : Functional<data_t>& regularizationTerm = g[0].getFunctional();
113 :
114 8 : if (!is<L0PseudoNorm<data_t>>(regularizationTerm)
115 8 : && !is<L1Norm<data_t>>(regularizationTerm)) {
116 0 : throw std::invalid_argument("ADMM::solveImpl: supported regularization terms are "
117 0 : "of type L0PseudoNorm or L1Norm");
118 0 : }
119 :
120 8 : const auto& constraint = _problem->getConstraint();
121 8 : const auto& A = constraint.getOperatorA();
122 8 : const auto& B = constraint.getOperatorB();
123 8 : const auto& c = constraint.getDataVectorC();
124 :
125 8 : DataContainer<data_t> x(A.getRangeDescriptor());
126 8 : x = 0;
127 :
128 8 : DataContainer<data_t> z(B.getRangeDescriptor());
129 8 : z = 0;
130 :
131 8 : DataContainer<data_t> u(c.getDataDescriptor());
132 8 : u = 0;
133 :
134 8 : Logger::get("ADMM")->info("{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}",
135 8 : "iteration", "xL2NormSq", "zL2NormSq", "uL2NormSq",
136 8 : "rkL2Norm", "skL2Norm");
137 :
138 34 : for (index_t iter = 0; iter < iterations; ++iter) {
139 34 : LinearResidual<data_t> xLinearResidual(A, c - B.apply(z) - u);
140 34 : RegularizationTerm xRegTerm(_rho / 2, L2NormPow2<data_t>(xLinearResidual));
141 34 : Problem<data_t> xUpdateProblem(dataTerm, xRegTerm, x);
142 :
143 34 : XSolver<data_t> xSolver(xUpdateProblem);
144 34 : x = xSolver.solve(_defaultXSolverIterations);
145 :
146 34 : DataContainer<data_t> rk = x;
147 34 : DataContainer<data_t> zPrev = z;
148 34 : data_t Axnorm = x.l2Norm();
149 :
150 : /// For future reference, below is listed the problem to be solved by the z update
151 : /// solver. Refer to the documentation of ADMM for further details.
152 : // LinearResidual<data_t> zLinearResidual(B, c - A.apply(x) - u);
153 : // RegularizationTerm zRegTerm(_rho / 2, L2NormPow2<data_t>(zLinearResidual));
154 : // Problem<data_t> zUpdateProblem(regularizationTerm, zRegTerm, z);
155 :
156 34 : ZSolver<data_t> zProxOp(A.getRangeDescriptor());
157 34 : z = zProxOp.apply(x + u, geometry::Threshold{regWeight / _rho});
158 :
159 34 : rk -= z;
160 34 : DataContainer<data_t> sk = zPrev - z;
161 34 : sk *= _rho;
162 :
163 34 : u += A.apply(x) + B.apply(z) - c;
164 :
165 34 : DataContainer<data_t> Atu = u;
166 34 : Atu *= _rho;
167 34 : data_t rkL2Norm = rk.l2Norm();
168 34 : data_t skL2Norm = sk.l2Norm();
169 :
170 34 : Logger::get("ADMM")->info("{:<19}| {:<19}| {:<19}| {:<19}| {:<19}| {:<19}", iter,
171 34 : x.squaredL2Norm(), z.squaredL2Norm(), u.squaredL2Norm(),
172 34 : rkL2Norm, skL2Norm);
173 :
174 : /// variables for the stopping criteria
175 34 : const data_t cL2Norm = !dataTermResidual.hasDataVector()
176 34 : ? static_cast<data_t>(0.0)
177 34 : : dataTermResidual.getDataVector().l2Norm();
178 34 : const data_t epsRelMax =
179 34 : _epsilonRel * std::max(std::max(Axnorm, z.l2Norm()), cL2Norm);
180 34 : const auto epsilonPri = (std::sqrt(rk.getSize()) * _epsilonAbs) + epsRelMax;
181 :
182 34 : const data_t epsRelL2Norm = _epsilonRel * Atu.l2Norm();
183 34 : const auto epsilonDual = (std::sqrt(sk.getSize()) * _epsilonAbs) + epsRelL2Norm;
184 :
185 34 : if (rkL2Norm <= epsilonPri && skL2Norm <= epsilonDual) {
186 8 : Logger::get("ADMM")->info("SUCCESS: Reached convergence at {}/{} iterations ",
187 8 : iter, iterations);
188 :
189 8 : _problem->getCurrentSolution() = x;
190 8 : return _problem->getCurrentSolution();
191 8 : }
192 :
193 : /// varying penalty parameter
194 26 : if (std::abs(_tauIncr - static_cast<data_t>(1.0))
195 26 : > std::numeric_limits<data_t>::epsilon()
196 26 : || std::abs(_tauDecr - static_cast<data_t>(1.0))
197 26 : > std::numeric_limits<data_t>::epsilon()) {
198 26 : if (rkL2Norm > _mu * skL2Norm) {
199 0 : _rho *= _tauIncr;
200 0 : u /= _tauIncr;
201 26 : } else if (skL2Norm > _mu * rkL2Norm) {
202 26 : _rho /= _tauDecr;
203 26 : u *= _tauDecr;
204 26 : }
205 26 : }
206 26 : }
207 :
208 8 : Logger::get("ADMM")->warn("Failed to reach convergence at {} iterations", iterations);
209 :
210 0 : _problem->getCurrentSolution() = x;
211 0 : return _problem->getCurrentSolution();
212 8 : }
213 :
214 : protected:
215 : /// implement the polymorphic clone operation
216 : auto cloneImpl() const -> ADMM<XSolver, ZSolver, data_t>* override
217 0 : {
218 0 : return new ADMM<XSolver, ZSolver, data_t>(
219 0 : downcast<SplittingProblem<data_t>>(*_problem));
220 0 : }
221 :
222 : bool isEqual(const Solver<data_t>& other) const
223 0 : {
224 0 : auto otherADMM = downcast_safe<ADMM>(&other);
225 0 : if (!otherADMM)
226 0 : return false;
227 :
228 0 : if (_problem != otherADMM->_problem)
229 0 : return false;
230 :
231 0 : return _rho == otherADMM->_rho && _mu == otherADMM->_mu
232 0 : && _tauIncr == otherADMM->_tauIncr && _tauDecr == otherADMM->_tauDecr;
233 0 : }
234 :
235 : private:
236 : /// Splitting problem to solve
237 : /// TODO: Remove requirement of unique_ptr
238 : std::unique_ptr<SplittingProblem<data_t>> _problem;
239 :
240 : /// the default number of iterations for ADMM
241 : index_t _defaultIterations{100};
242 :
243 : /// the default number of iterations for the XSolver
244 : index_t _defaultXSolverIterations{5};
245 :
246 : /// @f$ \rho @f$ from the problem definition
247 : data_t _rho{1};
248 :
249 : /// variables for varying penalty parameter @f$ \rho @f$
250 : data_t _mu{10};
251 : data_t _tauIncr{2};
252 : data_t _tauDecr{2};
253 :
254 : /// variables for the stopping criteria
255 : data_t _epsilonAbs{1e-5f};
256 : data_t _epsilonRel{1e-5f};
257 : };
258 : } // namespace elsa
|