Line data Source code
#include <optional>

#include "DataContainer.h"
#include "Functional.h"
#include "Solver.h"
#include "LinearOperator.h"
#include "StrongTypes.h"
#include "MaybeUninitialized.hpp"
#include "ProximalOperator.h"
#include "LineSearchMethod.h"
#include "FixedStepSize.h"

namespace elsa
{
    /**
     * @brief Accelerated Proximal Gradient Descent (APGD)
     *
     * APGD minimizes a function of the same form as PGD. See the documentation there.
     *
     * This class represents an APGD solver with the following steps:
     *
     * - @f$ x_{k} = prox_h(y_k - \mu * A^T (Ay_k - b)) @f$
     * - @f$ t_{k+1} = \frac{1 + \sqrt{1 + 4 * t_{k}^2}}{2} @f$
     * - @f$ y_{k+1} = x_{k} + (\frac{t_{k} - 1}{t_{k+1}}) * (x_{k} - x_{k - 1}) @f$
     *
     * APGD has a worst-case complexity result of @f$ O(1/k^2) @f$.
     *
     * References:
     * http://www.cs.cmu.edu/afs/cs/Web/People/airg/readings/2012_02_21_a_fast_iterative_shrinkage-thresholding.pdf
     * https://arxiv.org/pdf/2008.02683.pdf
     *
     * @see For a more detailed discussion of the type of problem for this solver,
     * see PGD.
     *
     * @author
     * Andi Braimllari - initial code
     * David Frank - generalization to APGD
     *
     * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
     */
    template <typename data_t = real_t>
    class APGD : public Solver<data_t>
    {
    public:
        /// Scalar alias
        using Scalar = typename Solver<data_t>::Scalar;

        /// Construct APGD with a least squares data fidelity term
        ///
        /// @note The step length for least squares can be chosen to be dependent on the
        /// Lipschitz constant. Compute it using `powerIterations(adjoint(A) * A)`. Depending
        /// on the source, both \f$ \frac{2}{L} \f$ and \f$ \frac{1}{L} \f$ seem common.
        /// If mu is not given, the step length is chosen by default; note that the
        /// computation of the power method might be expensive.
        ///
        /// @param A the operator for the least squares data term
        /// @param b the measured data for the least squares data term
        /// @param h the prox-friendly functional (applied via its proximal operator)
        /// @param mu the step length (if absent, a default is computed via the power method)
        /// @param epsilon threshold for the stopping condition
        APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
             const Functional<data_t>& h, std::optional<data_t> mu = std::nullopt,
             data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct APGD with a weighted least squares data fidelity term
        ///
        /// @note The step length for least squares can be chosen to be dependent on the
        /// Lipschitz constant. Compute it using `powerIterations(adjoint(A) * A)`. Depending
        /// on the literature, both \f$ \frac{2}{L} \f$ and \f$ \frac{1}{L} \f$ are used.
        /// If mu is not given, the step length is chosen by default; note that the
        /// computation of the power method might be expensive.
        ///
        /// @param A the operator for the least squares data term
        /// @param b the measured data for the least squares data term
        /// @param W the weights (usually `counts / I0`)
        /// @param h the prox-friendly functional (applied via its proximal operator)
        /// @param mu the step length (if absent, a default is computed via the power method)
        /// @param epsilon threshold for the stopping condition
        APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
             const DataContainer<data_t>& W, const Functional<data_t>& h,
             std::optional<data_t> mu = std::nullopt,
             data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct APGD with a given data fidelity term
        ///
        /// @param g differentiable function of the LASSO problem
        /// @param h prox friendly functional of the LASSO problem
        /// @param mu the step length
        /// @param epsilon threshold for the stopping condition
        APGD(const Functional<data_t>& g, const Functional<data_t>& h, data_t mu,
             data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct APGD with a least squares data fidelity term
        ///
        /// @note The step length for least squares can be chosen to be dependent on the
        /// Lipschitz constant. Compute it using `powerIterations(adjoint(A) * A)`. Depending
        /// on the literature, both \f$ \frac{2}{L} \f$ and \f$ \frac{1}{L} \f$ are used.
        ///
        /// @param A the operator for the least squares data term
        /// @param b the measured data for the least squares data term
        /// @param h the prox-friendly functional (applied via its proximal operator)
        /// @param lineSearchMethod the line search method to find the step size at
        /// each iteration
        /// @param epsilon threshold for the stopping condition
        APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
             const Functional<data_t>& h, const LineSearchMethod<data_t>& lineSearchMethod,
             data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct APGD with a weighted least squares data fidelity term
        ///
        /// @note The step length for least squares can be chosen to be dependent on the
        /// Lipschitz constant. Compute it using `powerIterations(adjoint(A) * A)`. Depending
        /// on the literature, both \f$ \frac{2}{L} \f$ and \f$ \frac{1}{L} \f$ are used.
        ///
        /// @param A the operator for the least squares data term
        /// @param b the measured data for the least squares data term
        /// @param W the weights (usually `counts / I0`)
        /// @param h the prox-friendly functional (applied via its proximal operator)
        /// @param lineSearchMethod the line search method to find the step size at
        /// each iteration
        /// @param epsilon threshold for the stopping condition
        APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
             const DataContainer<data_t>& W, const Functional<data_t>& h,
             const LineSearchMethod<data_t>& lineSearchMethod,
             data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// Construct APGD with a given data fidelity term
        ///
        /// @param g differentiable function of the LASSO problem
        /// @param h prox friendly functional of the LASSO problem
        /// @param lineSearchMethod the line search method to find the step size at
        /// each iteration
        /// @param epsilon threshold for the stopping condition
        APGD(const Functional<data_t>& g, const Functional<data_t>& h,
             const LineSearchMethod<data_t>& lineSearchMethod,
             data_t epsilon = std::numeric_limits<data_t>::epsilon());

        /// default destructor
        ~APGD() override = default;

        /// Prepare the solver state; if `x` is given it is used as the initial iterate
        DataContainer<data_t> setup(std::optional<DataContainer<data_t>> x) override;

        /// Perform one APGD iteration starting from `x` and return the next iterate
        DataContainer<data_t> step(DataContainer<data_t> x) override;

        /// Check the stopping condition (controlled by `epsilon_`)
        bool shouldStop() const override;

        /// Format the header line for per-iteration logging
        std::string formatHeader() const override;

        /// Format one logging line for the current iterate `x`
        std::string formatStep(const DataContainer<data_t>& x) const override;

    protected:
        /// implement the polymorphic clone operation
        auto cloneImpl() const -> APGD<data_t>* override;

        /// implement the polymorphic comparison operation
        auto isEqual(const Solver<data_t>& other) const -> bool override;

    private:
        /// Differentiable part of the problem formulation
        std::unique_ptr<Functional<data_t>> g_;

        /// Prox-friendly part of the problem formulation
        std::unique_ptr<Functional<data_t>> h_;

        /// Previous iterate @f$ x_{k-1} @f$ (used in the momentum/extrapolation step)
        DataContainer<data_t> xPrev_;

        /// Extrapolated point @f$ y_k @f$ at which the gradient step is taken
        DataContainer<data_t> y_;

        /// Intermediate iterate storage used by the update step
        DataContainer<data_t> z_;

        /// Gradient of g evaluated during the step
        DataContainer<data_t> grad_;

        /// Momentum parameter @f$ t_k @f$ of the acceleration scheme (starts at 1)
        data_t tPrev_ = 1;

        /// the line search method
        std::unique_ptr<LineSearchMethod<data_t>> lineSearchMethod_;

        /// variable affecting the stopping condition
        data_t epsilon_;
    };
} // namespace elsa