#pragma once

#include <limits>
#include <memory>
#include <optional>

#include "DataContainer.h"
#include "Functional.h"
#include "LinearOperator.h"
#include "MaybeUninitialized.hpp"
#include "Solver.h"
#include "StrongTypes.h"
#include "ProximalOperator.h"
#include "LineSearchMethod.h"
#include "FixedStepSize.h"

namespace elsa
{
    /**
     * @brief Proximal Gradient Descent (PGD)
     *
     * PGD minimizes functions of the form:
     *
     * @f[
     * \min_x g(x) + h(x)
     * @f]
     *
     * where @f$g: \mathbb{R}^n \to \mathbb{R}@f$ is convex and differentiable,
     * and @f$h: \mathbb{R}^n \to \mathbb{R} \cup \{-\infty, \infty\}@f$ is closed
     * and convex. Importantly, @f$h@f$ need not be differentiable, but it needs
     * a proximal operator. Usually, the proximal operator is assumed to be simple
     * and to have an analytical solution.
     *
     * This class currently implements the special case of @f$ g(x) = \frac{1}{2}
     * ||A x - b||_2^2@f$. However, @f$h@f$ can be chosen freely.
     *
     * Given @f$g@f$ defined as above and a convex set @f$\mathcal{C}@f$, one can
     * define a constrained optimization problem:
     * @f[
     * \min_{x \in \mathcal{C}} g(x)
     * @f]
     * Such constraints can take the form of non-negativity or box constraints.
     * This can be reformulated as an unconstrained problem:
     * @f[
     * \min_{x} g(x) + \mathcal{I}_{\mathcal{C}}(x)
     * @f]
     * where @f$\mathcal{I}_{\mathcal{C}}(x)@f$ is the indicator function of the
     * convex set @f$\mathcal{C}@f$, defined as:
     *
     * @f[
     * \mathcal{I}_{\mathcal{C}}(x) =
     * \begin{cases}
     * 0,      & \text{if } x \in \mathcal{C} \\
     * \infty, & \text{if } x \notin \mathcal{C}
     * \end{cases}
     * @f]
     *
     * References:
     * - http://www.cs.cmu.edu/afs/cs/Web/People/airg/readings/2012_02_21_a_fast_iterative_shrinkage-thresholding.pdf
     * - https://arxiv.org/pdf/2008.02683.pdf
     *
     * @note PGD has a worst-case convergence rate of @f$ O(1/k) @f$.
     *
     * @note A special case of this problem is of the form:
     * @f[
     * \min_{x} \frac{1}{2} || A x - b ||_2^2 + ||x||_1
     * @f]
     * often referred to as @f$\ell_1@f$-regularization. In this case, the proximal operator
     * for the @f$\ell_1@f$-regularization is the soft-thresholding operator (ProximalL1). This
     * can also be extended with constraints, such as non-negativity constraints.
     *
     * @see APGD for an accelerated version of proximal gradient descent.
     *
     * @author
     * - Andi Braimllari - initial code
     * - David Frank - generalization
     *
     * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
     */
    template <typename data_t = real_t>
    class PGD : public Solver<data_t>
    {
    public:
        /// Scalar alias
        using Scalar = typename Solver<data_t>::Scalar;

        /// Construct PGD with a least squares data fidelity term
        ///
        /// @note The step length for least squares can be chosen dependent on the Lipschitz
        /// constant @f$ L @f$. Compute it using `powerIterations(adjoint(A) * A)`. Depending
        /// on the literature, both \f$ \frac{2}{L} \f$ and \f$ \frac{1}{L} \f$ are used as
        /// step length. If `mu` is not given, the step length is chosen this way by default;
        /// note that the power method might be expensive to compute.
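        ///
        /// A minimal illustrative sketch of passing an explicit step length (here `A`, `b`
        /// and `h` are the constructor arguments below, and `powerIterations`/`adjoint` are
        /// the helpers mentioned above, assumed to be available to the calling code):
        /// @code
        /// // estimate the Lipschitz constant of the gradient of the least squares term
        /// auto L = powerIterations(adjoint(A) * A);
        /// // pass 1 / L explicitly to skip the (possibly expensive) default estimate
        /// PGD<real_t> pgd(A, b, h, real_t{1} / L);
        /// @endcode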
        ///
        /// @param A the operator for the least squares data term
        /// @param b the measured data for the least squares data term
        /// @param h prox-friendly function
        /// @param mu the step length
        PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
            const Functional<data_t>& h, std::optional<data_t> mu = std::nullopt,
            data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct PGD with a weighted least squares data fidelity term
        ///
        /// @note The step length for least squares can be chosen dependent on the Lipschitz
        /// constant @f$ L @f$. Compute it using `powerIterations(adjoint(A) * A)`. Depending
        /// on the literature, both \f$ \frac{2}{L} \f$ and \f$ \frac{1}{L} \f$ are used as
        /// step length. If `mu` is not given, the step length is chosen this way by default;
        /// note that the power method might be expensive to compute.
        ///
        /// @param A the operator for the least squares data term
        /// @param b the measured data for the least squares data term
        /// @param W the weights (usually `counts / I0`)
        /// @param h prox-friendly function
        /// @param mu the step length
        PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
            const DataContainer<data_t>& W, const Functional<data_t>& h,
            std::optional<data_t> mu = std::nullopt,
            data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct PGD with a given data fidelity term
        ///
        /// @param g differentiable function
        /// @param h prox-friendly function
        /// @param mu the step length
        PGD(const Functional<data_t>& g, const Functional<data_t>& h, data_t mu,
            data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct PGD with a least squares data fidelity term
        ///
        /// @note The step length for least squares can be chosen dependent on the Lipschitz
        /// constant @f$ L @f$. Compute it using `powerIterations(adjoint(A) * A)`. Depending
        /// on the literature, both \f$ \frac{2}{L} \f$ and \f$ \frac{1}{L} \f$ are used as
        /// step length.
        ///
        /// @param A the operator for the least squares data term
        /// @param b the measured data for the least squares data term
        /// @param h prox-friendly function
        /// @param lineSearchMethod the line search method to find the step size at
        /// each iteration
        PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
            const Functional<data_t>& h, const LineSearchMethod<data_t>& lineSearchMethod,
            data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct PGD with a weighted least squares data fidelity term
        ///
        /// @note The step length for least squares can be chosen dependent on the Lipschitz
        /// constant @f$ L @f$. Compute it using `powerIterations(adjoint(A) * A)`. Depending
        /// on the literature, both \f$ \frac{2}{L} \f$ and \f$ \frac{1}{L} \f$ are used as
        /// step length.
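        ///
        /// A sketch of calling this overload (illustrative only; `lineSearch` stands for any
        /// already constructed LineSearchMethod, e.g. a FixedStepSize instance, whose own
        /// constructor arguments are not shown here):
        /// @code
        /// // W typically holds per-measurement weights such as counts / I0
        /// PGD<real_t> pgd(A, b, W, h, lineSearch);
        /// auto reconstruction = pgd.solve(50);
        /// @endcode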
        ///
        /// @param A the operator for the least squares data term
        /// @param b the measured data for the least squares data term
        /// @param W the weights (usually `counts / I0`)
        /// @param h prox-friendly function
        /// @param lineSearchMethod the line search method to find the step size at
        /// each iteration
        PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
            const DataContainer<data_t>& W, const Functional<data_t>& h,
            const LineSearchMethod<data_t>& lineSearchMethod,
            data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct PGD with a given data fidelity term
        ///
        /// @param g differentiable function
        /// @param h prox-friendly function
        /// @param lineSearchMethod the line search method to find the step size at
        /// each iteration
        PGD(const Functional<data_t>& g, const Functional<data_t>& h,
            const LineSearchMethod<data_t>& lineSearchMethod,
            data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// make copy constructor deletion explicit
        PGD(const PGD<data_t>&) = delete;

        /// default destructor
        ~PGD() override = default;

        /**
         * @brief Solve the optimization problem, i.e. apply `iterations` iterations of PGD
         *
         * @param[in] iterations number of iterations to execute
         * @param[in] x0 optional initial solution; the initial solution is set to zero if
         * not present
         *
         * @returns the approximated solution
         */
        DataContainer<data_t>
            solve(index_t iterations,
                  std::optional<DataContainer<data_t>> x0 = std::nullopt) override;

    protected:
        /// implement the polymorphic clone operation
        auto cloneImpl() const -> PGD<data_t>* override;

        /// implement the polymorphic comparison operation
        auto isEqual(const Solver<data_t>& other) const -> bool override;

    private:
        /// differentiable function of the problem formulation
        std::unique_ptr<Functional<data_t>> g_;

        /// prox-friendly function of the problem formulation
        std::unique_ptr<Functional<data_t>> h_;

        /// variable affecting the stopping condition
        data_t epsilon_;

        /// the line search method
        std::unique_ptr<LineSearchMethod<data_t>> lineSearchMethod_;
    };

    template <class data_t = real_t>
    using ISTA = PGD<data_t>;
} // namespace elsa
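
// Illustrative usage sketch (not part of this header): the l1-regularized least squares
// problem from the class documentation, solved via the ISTA alias. `L1Norm` and
// `getDomainDescriptor` are assumed to be provided by the rest of the elsa library, and
// the operator `A` and data `b` by user code.
//
//     elsa::L1Norm<elsa::real_t> h(A.getDomainDescriptor());
//     elsa::ISTA<elsa::real_t> ista(A, b, h);
//     elsa::DataContainer<elsa::real_t> reco = ista.solve(100);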