Line data Source code
1 : #pragma once 2 : 3 : #include <memory> 4 : #include <type_traits> 5 : #include <optional> 6 : 7 : #include "Solver.h" 8 : #include "DataContainer.h" 9 : #include "LinearOperator.h" 10 : #include "ProximalOperator.h" 11 : #include "TypeTraits.hpp" 12 : #include "elsaDefines.h" 13 : 14 : namespace elsa 15 : { 16 : /** 17 : * @brief Class representing an Alternating Direction Method of Multipliers solver 18 : * for a specific subset of constraints 19 : * 20 : * The general form of ADMM solves the following optimization problem: 21 : * \f[ 22 : * \min f(x) + g(z) \\ 23 : * \text{s.t. } Ax + Bz = c 24 : * \f] 25 : * with \f$x \in \mathbb{R}^n\f$, \f$z \in \mathbb{R}^m\f$, \f$A \in \mathbb{R}^{p\times n}\f$, 26 : * \f$B \in \mathbb{R}^{p\times m}\f$ and \f$c \in \mathbb{R}^p\f$ 27 : * 28 : * This specific version solves the problem of the form: 29 : * \f[ 30 : * \min \frac{1}{2} || Op x - b ||_2^2 + g(z) \\ 31 : * \text{s.t. } Ax = z 32 : * \f] 33 : * with \f$B = Id\f$ and \f$c = 0\f$. Further: \f$f(x) = || Op x - b ||_2^2\f$. 34 : * 35 : * This version of ADMM is useful, as the proximal operator is not known for 36 : * the least squares functional, and this specifically implements and optimization of the 37 : * first update step of ADMM. In this implementation, this is done via CGLS. 38 : * 39 : * The update steps for ADMM are: 40 : * \f[ 41 : * x_{k+1} = \argmin_{x} \frac{1}{2}||Op x - b||_2^2 + \frac{1}{2\tau} ||Ax - z_k + u_k||_2^2 \\ 42 : * z_{k+1} = \prox_{\tau g}(Ax_{k+1} + u_{k}) \\ 43 : * u_{k+1} = u_k + Ax_{k+1} - z_{k+1} 44 : * \f] 45 : * 46 : * This is further useful to solve problems such as TV, by setting the \f$A = \nabla\f$. 47 : * And \f$ g = || \dot ||_1\f$ 48 : * 49 : * References: 50 : * - Distributed Optimization and Statistical Learning via the Alternating Direction Method of 51 : * Multipliers, by Boyd et al. 52 : * - Chapter 5.3 of "An introduction to continuous optimization for imaging", by Chambolle and 53 : * Pock 54 : */ 55 : template <typename data_t = real_t> 56 : class ADMML2 : public Solver<data_t> 57 : { 58 : public: 59 : /// Scalar alias 60 : using Scalar = typename Solver<data_t>::Scalar; 61 : 62 : ADMML2(const LinearOperator<data_t>& op, const DataContainer<data_t>& b, 63 : const LinearOperator<data_t>& A, const ProximalOperator<data_t>& proxg, 64 : std::optional<data_t> tau = std::nullopt, index_t ninneriters = 5); 65 : 66 : ADMML2(const LinearOperator<data_t>& op, const DataContainer<data_t>& b, 67 : const DataContainer<data_t>& W, const LinearOperator<data_t>& A, 68 : const ProximalOperator<data_t>& proxg, std::optional<data_t> tau = std::nullopt, 69 : index_t ninneriters = 5); 70 : 71 : /// default destructor 72 2 : ~ADMML2() override = default; 73 : 74 : DataContainer<data_t> 75 : solve(index_t iterations, 76 : std::optional<DataContainer<data_t>> x0 = std::nullopt) override; 77 : 78 : protected: 79 : /// implement the polymorphic clone operation 80 : ADMML2<data_t>* cloneImpl() const override; 81 : 82 : /// implement the polymorphic equality operation 83 : bool isEqual(const Solver<data_t>& other) const override; 84 : 85 : private: 86 : std::unique_ptr<LinearOperator<data_t>> op_; 87 : 88 : DataContainer<data_t> b_; 89 : 90 : std::unique_ptr<LinearOperator<data_t>> A_; 91 : 92 : std::optional<DataContainer<data_t>> W_ = std::nullopt; 93 : 94 : ProximalOperator<data_t> proxg_; 95 : 96 : /// @f$ \tau @f$ from the problem definition 97 : data_t tau_{1}; 98 : 99 : index_t ninneriters_; 100 : }; 101 : } // namespace elsa