Line data Source code
1 : #pragma once 2 : 3 : #include "elsaDefines.h" 4 : #include "Cloneable.h" 5 : #include "DataContainer.h" 6 : #include "DataDescriptor.h" 7 : #include "StrongTypes.h" 8 : 9 : namespace elsa 10 : { 11 : /// Customization point for ProximalOperators. For an object to be type erased by the 12 : /// ProximalOperator interface, you either need to provide a member function `apply` for the 13 : /// type, with the given parameters. Or you can overload this function. 14 : template <class T, class data_t> 15 : void applyProximal(const T& proximal, const DataContainer<data_t>& v, SelfType_t<data_t> tau, 16 : DataContainer<data_t>& out) 17 20 : { 18 20 : proximal.apply(v, tau, out); 19 20 : } 20 : 21 : template <class T, class data_t> 22 : DataContainer<data_t> applyProximal(const T& proximal, const DataContainer<data_t>& v, 23 : SelfType_t<data_t> tau) 24 116 : { 25 116 : return proximal.apply(v, tau); 26 116 : } 27 : 28 : /** 29 : * @brief Base class representing a proximal operator prox. 30 : * 31 : * This class represents a proximal operator prox, expressed through its apply methods, 32 : * which implement the proximal operator of f with penalty r i.e. 33 : * @f$ prox_{f,\rho}(v) = argmin_{x}(f(x) + (\rho/2)ยท\| x - v \|^2_2). @f$ 34 : * 35 : * The class implements type erasure. Classes, that bind to this wrapper should overload 36 : * either the `applyProximal`, or implement the `apply` member function, with the signature 37 : * given in this class. Base types also are assumed to be Semi-Regular. 38 : * 39 : * References: 40 : * https://stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf 41 : * 42 : * @tparam data_t data type for the values of the operator, defaulting to real_t 43 : * 44 : * @author 45 : * - Andi Braimllari - initial code 46 : * - David Frank - Type Erasure 47 : */ 48 : template <typename data_t = real_t> 49 : class ProximalOperator 50 : { 51 : private: 52 : /// Concept for proximal operators, they should be cloneable, and have an apply method 53 : struct ProxConcept { 54 86 : virtual ~ProxConcept() = default; 55 : virtual std::unique_ptr<ProxConcept> clone() const = 0; 56 : virtual DataContainer<data_t> apply(const DataContainer<data_t>& v, 57 : SelfType_t<data_t> t) const = 0; 58 : virtual void apply(const DataContainer<data_t>& v, SelfType_t<data_t> t, 59 : DataContainer<data_t>& out) const = 0; 60 : }; 61 : 62 : /// Bridge wrapper for concrete types 63 : template <class T> 64 : struct ProxModel : public ProxConcept { 65 86 : ProxModel(T self) : self_(std::move(self)) {} 66 : 67 : /// Just clone the type (assumes regularity, i.e. copy constructible) 68 : std::unique_ptr<ProxConcept> clone() const override 69 52 : { 70 52 : return std::make_unique<ProxModel<T>>(self_); 71 52 : } 72 : 73 : /// Apply proximal by calling `apply_proximal`, this enables flexible extension, without 74 : /// classes as well. The default implementation of `apply_proximal`, 75 : /// just calls the member function 76 : void apply(const DataContainer<data_t>& v, SelfType_t<data_t> t, 77 : DataContainer<data_t>& out) const override 78 20 : { 79 20 : applyProximal(self_, v, t, out); 80 20 : } 81 : 82 : DataContainer<data_t> apply(const DataContainer<data_t>& v, 83 : SelfType_t<data_t> t) const override 84 116 : { 85 116 : return applyProximal(self_, v, t); 86 116 : } 87 : 88 : private: 89 : T self_; 90 : }; 91 : 92 : public: 93 : /// defaulted default constructor for base-class (will point to nothing) 94 : ProximalOperator() = default; 95 : 96 : /// Type erasure constructor, taking everything that kan bind to the above provided 97 : /// interface 98 : template <typename T> 99 : ProximalOperator(T proxOp) : ptr_(std::make_unique<ProxModel<T>>(std::move(proxOp))) 100 34 : { 101 34 : } 102 : 103 : template <typename T> 104 : ProximalOperator& operator=(T proxOp) 105 : { 106 : ptr_ = std::make_unique<ProxModel<T>>(std::move(proxOp)); 107 : return *this; 108 : } 109 : 110 : /// Copy constructor 111 : ProximalOperator(const ProximalOperator& other); 112 : 113 : /// Default move constructor 114 0 : ProximalOperator(ProximalOperator&& other) noexcept = default; 115 : 116 : /// Copy assignment 117 : ProximalOperator& operator=(const ProximalOperator& other); 118 : 119 : /// Default move assignment 120 : ProximalOperator& operator=(ProximalOperator&& other) noexcept = default; 121 : 122 : /// default destructor 123 86 : ~ProximalOperator() = default; 124 : 125 : /** 126 : * @brief apply the proximal operator to an element in the operator's domain 127 : * 128 : * @param[in] v input DataContainer 129 : * @param[in] t input Threshold 130 : * 131 : * @returns prox DataContainer containing the application of the proximal operator to 132 : * data v, i.e. in the range of the operator 133 : * 134 : * Please note: this method uses apply(v, t, prox(v)) to perform the actual operation. 135 : */ 136 : auto apply(const DataContainer<data_t>& v, SelfType_t<data_t> t) const 137 : -> DataContainer<data_t>; 138 : 139 : /** 140 : * @brief apply the proximal operator to an element in the operator's domain 141 : * 142 : * @param[in] v input DataContainer 143 : * @param[in] t input Threshold 144 : * @param[out] prox output DataContainer 145 : * 146 : * Please note: this method calls the method applyImpl that has to be overridden in derived 147 : * classes. (Why is this method not virtual itself? Because you cannot have a non-virtual 148 : * function overloading a virtual one [apply with one vs. two arguments]). 149 : */ 150 : void apply(const DataContainer<data_t>& v, SelfType_t<data_t> t, 151 : DataContainer<data_t>& prox) const; 152 : 153 : private: 154 : std::unique_ptr<ProxConcept> ptr_ = {}; 155 : }; 156 : } // namespace elsa