Line data Source code
1 : #pragma once
2 :
3 : #include "Solver.h"
4 : #include "ProximityOperator.h"
5 : #include "SplittingProblem.h"
6 : #include "L0PseudoNorm.h"
7 : #include "L1Norm.h"
8 : #include "L2NormPow2.h"
9 : #include "LinearResidual.h"
10 : #include "Logger.h"
11 :
12 : namespace elsa
13 : {
14 : /**
15 : * @brief Class representing an Alternating Direction Method of Multipliers solver
16 : *
17 : * @author Andi Braimllari - initial code
18 : *
19 : * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
20 : * @tparam XSolver Solver type handling the x update
21 : * @tparam ZSolver ProximityOperator type handling the z update
22 : *
23 : * ADMM solves minimization splitting problems of the form
24 : * @f$ x \mapsto f(x) + g(z) @f$ such that @f$ Ax + Bz = c @f$.
25 : * Commonly regularized optimization problems can be rewritten in such a form by using variable
26 : * splitting.
27 : *
28 : * ADMM can be expressed in the following scaled form
29 : *
30 : * - @f$ x_{k+1} = argmin_{x}(f(x) + (\rho/2) ·\| Ax + Bz_{k} - c + u_{k}\|^2_2) @f$
31 : * - @f$ z_{k+1} = argmin_{z}(g(z) + (\rho/2) ·\| Ax_{k+1} + Bz - c + u_{k}\|^2_2) @f$
32 : * - @f$ u_{k+1} = u_{k} + Ax_{k+1} + Bz_{k+1} - c @f$
33 : *
34 : * References:
35 : * https://stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf
36 : */
37 : template <template <typename> class XSolver, template <typename> class ZSolver,
38 : typename data_t = real_t>
39 : class ADMM : public Solver<data_t>
40 : {
41 : public:
42 : /// Scalar alias
43 : using Scalar = typename Solver<data_t>::Scalar;
44 :
45 0 : ADMM(const SplittingProblem<data_t>& splittingProblem) : Solver<data_t>(splittingProblem)
46 : {
47 : static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
48 : "ADMM: XSolver must extend Solver");
49 :
50 : static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
51 : "ADMM: ZSolver must extend ProximityOperator");
52 0 : }
53 :
54 : ADMM(const SplittingProblem<data_t>& splittingProblem, index_t defaultXSolverIterations)
55 : : Solver<data_t>(splittingProblem), _defaultXSolverIterations{defaultXSolverIterations}
56 : {
57 : static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
58 : "ADMM: XSolver must extend Solver");
59 :
60 : static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
61 : "ADMM: ZSolver must extend ProximityOperator");
62 : }
63 :
64 : ADMM(const SplittingProblem<data_t>& splittingProblem, index_t defaultXSolverIterations,
65 : data_t epsilonAbs, data_t epsilonRel)
66 : : Solver<data_t>(splittingProblem),
67 : _defaultXSolverIterations{defaultXSolverIterations},
68 : _epsilonAbs{epsilonAbs},
69 : _epsilonRel{epsilonRel}
70 : {
71 : static_assert(std::is_base_of<Solver<data_t>, XSolver<data_t>>::value,
72 : "ADMM: XSolver must extend Solver");
73 :
74 : static_assert(std::is_base_of<ProximityOperator<data_t>, ZSolver<data_t>>::value,
75 : "ADMM: ZSolver must extend ProximityOperator");
76 : }
77 :
78 : /// default destructor
79 0 : ~ADMM() override = default;
80 :
81 0 : auto solveImpl(index_t iterations) -> DataContainer<data_t>& override
82 : {
83 0 : if (iterations == 0)
84 0 : iterations = _defaultIterations;
85 :
86 0 : auto& splittingProblem = downcast<SplittingProblem<data_t>>(*_problem);
87 :
88 0 : const auto& f = splittingProblem.getF();
89 0 : const auto& g = splittingProblem.getG();
90 :
91 0 : const auto& dataTerm = f;
92 :
93 0 : if (!is<L2NormPow2<data_t>>(dataTerm)) {
94 0 : throw std::invalid_argument(
95 : "ADMM::solveImpl: supported data term only of type L2NormPow2");
96 : }
97 :
98 : // Safe as long as only LinearResidual exits
99 0 : const auto& dataTermResidual = downcast<LinearResidual<data_t>>(f.getResidual());
100 :
101 0 : if (g.size() != 1) {
102 0 : throw std::invalid_argument(
103 : "ADMM::solveImpl: supported number of regularization terms is 1");
104 : }
105 :
106 0 : data_t regWeight = g[0].getWeight();
107 0 : Functional<data_t>& regularizationTerm = g[0].getFunctional();
108 :
109 0 : if (!is<L0PseudoNorm<data_t>>(regularizationTerm)
110 0 : && !is<L1Norm<data_t>>(regularizationTerm)) {
111 0 : throw std::invalid_argument("ADMM::solveImpl: supported regularization terms are "
112 : "of type L0PseudoNorm or L1Norm");
113 : }
114 :
115 0 : const auto& constraint = splittingProblem.getConstraint();
116 0 : const auto& A = constraint.getOperatorA();
117 0 : const auto& B = constraint.getOperatorB();
118 0 : const auto& c = constraint.getDataVectorC();
119 :
120 0 : DataContainer<data_t> x(A.getRangeDescriptor());
121 0 : x = 0;
122 :
123 0 : DataContainer<data_t> z(B.getRangeDescriptor());
124 0 : z = 0;
125 :
126 0 : DataContainer<data_t> u(c.getDataDescriptor());
127 0 : u = 0;
128 :
129 0 : Logger::get("ADMM")->info("{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}",
130 : "iteration", "xL2NormSq", "zL2NormSq", "uL2NormSq",
131 : "rkL2Norm", "skL2Norm");
132 :
133 0 : for (index_t iter = 0; iter < iterations; ++iter) {
134 0 : LinearResidual<data_t> xLinearResidual(A, c - B.apply(z) - u);
135 0 : RegularizationTerm xRegTerm(_rho / 2, L2NormPow2<data_t>(xLinearResidual));
136 0 : Problem<data_t> xUpdateProblem(dataTerm, xRegTerm, x);
137 :
138 0 : XSolver<data_t> xSolver(xUpdateProblem);
139 0 : x = xSolver.solve(_defaultXSolverIterations);
140 :
141 0 : DataContainer<data_t> rk = x;
142 0 : DataContainer<data_t> zPrev = z;
143 0 : data_t Axnorm = x.l2Norm();
144 :
145 : /// For future reference, below is listed the problem to be solved by the z update
146 : /// solver. Refer to the documentation of ADMM for further details.
147 : // LinearResidual<data_t> zLinearResidual(B, c - A.apply(x) - u);
148 : // RegularizationTerm zRegTerm(_rho / 2, L2NormPow2<data_t>(zLinearResidual));
149 : // Problem<data_t> zUpdateProblem(regularizationTerm, zRegTerm, z);
150 :
151 0 : ZSolver<data_t> zProxOp(A.getRangeDescriptor());
152 0 : z = zProxOp.apply(x + u, geometry::Threshold{regWeight / _rho});
153 :
154 0 : rk -= z;
155 0 : DataContainer<data_t> sk = zPrev - z;
156 0 : sk *= _rho;
157 :
158 0 : u += A.apply(x) + B.apply(z) - c;
159 :
160 0 : DataContainer<data_t> Atu = u;
161 0 : Atu *= _rho;
162 0 : data_t rkL2Norm = rk.l2Norm();
163 0 : data_t skL2Norm = sk.l2Norm();
164 :
165 0 : Logger::get("ADMM")->info("{:<19}| {:<19}| {:<19}| {:<19}| {:<19}| {:<19}", iter,
166 0 : x.squaredL2Norm(), z.squaredL2Norm(), u.squaredL2Norm(),
167 : rkL2Norm, skL2Norm);
168 :
169 : /// variables for the stopping criteria
170 0 : const data_t cL2Norm = !dataTermResidual.hasDataVector()
171 0 : ? static_cast<data_t>(0.0)
172 0 : : dataTermResidual.getDataVector().l2Norm();
173 0 : const data_t epsRelMax =
174 0 : _epsilonRel * std::max(std::max(Axnorm, z.l2Norm()), cL2Norm);
175 0 : const auto epsilonPri = (std::sqrt(rk.getSize()) * _epsilonAbs) + epsRelMax;
176 :
177 0 : const data_t epsRelL2Norm = _epsilonRel * Atu.l2Norm();
178 0 : const auto epsilonDual = (std::sqrt(sk.getSize()) * _epsilonAbs) + epsRelL2Norm;
179 :
180 0 : if (rkL2Norm <= epsilonPri && skL2Norm <= epsilonDual) {
181 0 : Logger::get("ADMM")->info("SUCCESS: Reached convergence at {}/{} iterations ",
182 : iter, iterations);
183 :
184 0 : getCurrentSolution() = x;
185 0 : return getCurrentSolution();
186 : }
187 :
188 : /// varying penalty parameter
189 0 : if (std::abs(_tauIncr - static_cast<data_t>(1.0))
190 0 : > std::numeric_limits<data_t>::epsilon()
191 0 : || std::abs(_tauDecr - static_cast<data_t>(1.0))
192 0 : > std::numeric_limits<data_t>::epsilon()) {
193 0 : if (rkL2Norm > _mu * skL2Norm) {
194 0 : _rho *= _tauIncr;
195 0 : u /= _tauIncr;
196 0 : } else if (skL2Norm > _mu * rkL2Norm) {
197 0 : _rho /= _tauDecr;
198 0 : u *= _tauDecr;
199 : }
200 : }
201 : }
202 :
203 0 : Logger::get("ADMM")->warn("Failed to reach convergence at {} iterations", iterations);
204 :
205 0 : getCurrentSolution() = x;
206 0 : return getCurrentSolution();
207 0 : }
208 :
209 : /// lift the base class method getCurrentSolution
210 : using Solver<data_t>::getCurrentSolution;
211 :
212 : protected:
213 : /// implement the polymorphic clone operation
214 0 : auto cloneImpl() const -> ADMM<XSolver, ZSolver, data_t>* override
215 : {
216 0 : return new ADMM<XSolver, ZSolver, data_t>(
217 0 : downcast<SplittingProblem<data_t>>(*_problem));
218 : }
219 :
220 : private:
221 : /// lift the base class variable _problem
222 : using Solver<data_t>::_problem;
223 :
224 : /// the default number of iterations for ADMM
225 : index_t _defaultIterations{100};
226 :
227 : /// the default number of iterations for the XSolver
228 : index_t _defaultXSolverIterations{5};
229 :
230 : /// @f$ \rho @f$ from the problem definition
231 : data_t _rho{1};
232 :
233 : /// variables for varying penalty parameter @f$ \rho @f$
234 : data_t _mu{10};
235 : data_t _tauIncr{2};
236 : data_t _tauDecr{2};
237 :
238 : /// variables for the stopping criteria
239 : data_t _epsilonAbs{1e-5f};
240 : data_t _epsilonRel{1e-5f};
241 : };
242 : } // namespace elsa
|