Line data Source code
1 : #pragma once 2 : 3 : #include "DataContainer.h" 4 : #include "elsaDefines.h" 5 : 6 : namespace elsa 7 : { 8 : /** 9 : * @brief This is the proximal operator for the mixed L21 norm, or sometimes 10 : * also referred to as the group L1 Norm. The functional is specifically 11 : * important for isotropic TV 12 : * 13 : * A signal needs to be represented as different blocks. When viewing each 14 : * block as a 1D signal, you can create a matrix where each block is a row 15 : * in the matrix. Let \f$X \in \mathbb{R}^{m \times n}\f$, where \f$m\f$ 16 : * is the number of blocks, and \f$n\f$ the size of each block linearized to 17 : * an 1D signal. The L21 Norm is then defined as: 18 : * \f[ 19 : * ||X||_{2,1} = \sum_{j=0}^m || x_j ||_2 20 : * \f] 21 : * This is the sum (L1 norm) of the column-wise L2 norm. 22 : * 23 : * The proximal operator is then given by: 24 : * (1 - (tau * self.sigma) / np.maximum(aux, tau * self.sigma)) * x.ravel() 25 : * \f[ 26 : * prox_{\sigma ||\cdot||_{2,1}}(x_j) = (1 - \frac{sigma}{\max\{ ||x_j||, 0 \}}) x_j \quad 27 : * \forall j 28 : * \f] 29 : * The factor \f$(1 - \frac{sigma}{\max\{ ||x_j||, 0\}})\f$ can be computed 30 : * for each column, which results in an \f$n\f$ sized vector, which, with 31 : * correct broadcasting, can be multiplied directly to the input signal. 32 : */ 33 : template <typename data_t = real_t> 34 : class ProximalMixedL21Norm 35 : { 36 : public: 37 0 : ProximalMixedL21Norm() = default; 38 : 39 : ProximalMixedL21Norm(data_t sigma); 40 : 41 : ~ProximalMixedL21Norm() = default; 42 : 43 : /** 44 : * @brief apply the proximal operator of the l1 norm to an element in the operator's domain 45 : * 46 : * @param[in] v input DataContainer 47 : * @param[in] t input Threshold 48 : * @param[out] prox output DataContainer 49 : */ 50 : void apply(const DataContainer<data_t>& v, SelfType_t<data_t> t, 51 : DataContainer<data_t>& prox) const; 52 : 53 : DataContainer<data_t> apply(const DataContainer<data_t>& v, SelfType_t<data_t> t) const; 54 : 55 0 : data_t sigma() const { return sigma_; } 56 : 57 : private: 58 : data_t sigma_{1}; 59 : }; 60 : 61 : template <typename T> 62 : bool operator==(const ProximalMixedL21Norm<T>& lhs, const ProximalMixedL21Norm<T>& rhs) 63 : { 64 : return lhs.sigma() == rhs.sigma(); 65 : } 66 : 67 : template <typename T> 68 : bool operator!=(const ProximalMixedL21Norm<T>& lhs, const ProximalMixedL21Norm<T>& rhs) 69 : { 70 : return !(lhs == rhs); 71 : } 72 : } // namespace elsa