Line data Source code
1 : #include "IndicatorFunctionals.h" 2 : #include "ProximalBoxConstraint.h" 3 : #include "DataDescriptor.h" 4 : #include "DataContainer.h" 5 : #include "Error.h" 6 : #include "Functional.h" 7 : #include "elsaDefines.h" 8 : #include <limits> 9 : 10 : namespace elsa 11 : { 12 : // ------------------------------------------ 13 : // IndicatorBox 14 : template <class data_t> 15 : IndicatorBox<data_t>::IndicatorBox(const DataDescriptor& desc) : Functional<data_t>(desc) 16 2 : { 17 2 : } 18 : 19 : template <class data_t> 20 : IndicatorBox<data_t>::IndicatorBox(const DataDescriptor& desc, SelfType_t<data_t> lower, 21 : SelfType_t<data_t> upper) 22 : : Functional<data_t>(desc), lower_(lower), upper_(upper) 23 14 : { 24 14 : } 25 : 26 : template <class data_t> 27 : data_t IndicatorBox<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 28 46 : { 29 : // Project the input value onto the box set 30 46 : auto projected = ::elsa::clip(Rx, lower_, upper_); 31 : 32 : // Check if anything changed, by computing the distance 33 46 : return (projected - Rx).l2Norm() > 0 ? std::numeric_limits<data_t>::infinity() : 0; 34 46 : } 35 : 36 : template <class data_t> 37 : data_t IndicatorBox<data_t>::convexConjugate(const DataContainer<data_t>& x) const 38 10 : { 39 10 : return elsa::maximum(x, 0).sum(); 40 10 : } 41 : 42 : template <class data_t> 43 : void IndicatorBox<data_t>::getGradientImpl(const DataContainer<data_t>&, 44 : DataContainer<data_t>&) const 45 0 : { 46 0 : throw NotImplementedError("IndicatorBox: Not differentiable"); 47 0 : } 48 : 49 : template <class data_t> 50 : LinearOperator<data_t> IndicatorBox<data_t>::getHessianImpl(const DataContainer<data_t>&) const 51 0 : { 52 0 : throw NotImplementedError("IndicatorBox: Not differentiable"); 53 0 : } 54 : 55 : template <class data_t> 56 : DataContainer<data_t> 57 : IndicatorBox<data_t>::proximal(const DataContainer<data_t>& v, 58 : [[maybe_unused]] SelfType_t<data_t> t) const 59 8 : { 60 8 : DataContainer<data_t> out(v.getDataDescriptor()); 61 8 : proximal(v, t, out); 62 8 : return out; 63 8 : } 64 : 65 : template <class data_t> 66 : void IndicatorBox<data_t>::proximal(const DataContainer<data_t>& v, 67 : [[maybe_unused]] SelfType_t<data_t> t, 68 : DataContainer<data_t>& out) const 69 8 : { 70 8 : ProximalBoxConstraint<data_t> prox(lower_, upper_); 71 8 : prox.apply(v, t, out); 72 8 : } 73 : 74 : template <class data_t> 75 : IndicatorBox<data_t>* IndicatorBox<data_t>::cloneImpl() const 76 4 : { 77 4 : return new IndicatorBox<data_t>(this->getDomainDescriptor(), lower_, upper_); 78 4 : } 79 : 80 : template <class data_t> 81 : bool IndicatorBox<data_t>::isEqual(const Functional<data_t>& other) const 82 4 : { 83 4 : if (!Functional<data_t>::isEqual(other)) { 84 0 : return false; 85 0 : } 86 : 87 4 : auto* fn = downcast<IndicatorBox<data_t>>(&other); 88 4 : return static_cast<bool>(fn) && lower_ == fn->lower_ && upper_ == fn->upper_; 89 4 : } 90 : 91 : // ------------------------------------------ 92 : // IndicatorNonNegativity 93 : template <class data_t> 94 : IndicatorNonNegativity<data_t>::IndicatorNonNegativity(const DataDescriptor& desc) 95 : : Functional<data_t>(desc) 96 24 : { 97 24 : } 98 : 99 : template <class data_t> 100 : data_t IndicatorNonNegativity<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 101 22 : { 102 22 : constexpr auto infinity = std::numeric_limits<data_t>::infinity(); 103 : 104 : // Project the input value onto the box set 105 22 : auto projected = ::elsa::clip(Rx, data_t{0}, infinity); 106 : 107 : // Check if anything changed, by computing the distance 108 22 : return (projected - Rx).l2Norm() > 0 ? infinity : 0; 109 22 : } 110 : 111 : template <class data_t> 112 : data_t IndicatorNonNegativity<data_t>::convexConjugate(const DataContainer<data_t>& x) const 113 0 : { 114 0 : return elsa::maximum(x, 0).sum(); 115 0 : } 116 : 117 : template <class data_t> 118 : void IndicatorNonNegativity<data_t>::getGradientImpl(const DataContainer<data_t>&, 119 : DataContainer<data_t>&) const 120 0 : { 121 0 : throw NotImplementedError("IndicatorNonNegativity: Not differentiable"); 122 0 : } 123 : 124 : template <class data_t> 125 : LinearOperator<data_t> 126 : IndicatorNonNegativity<data_t>::getHessianImpl(const DataContainer<data_t>&) const 127 0 : { 128 0 : throw NotImplementedError("IndicatorNonNegativity: Not differentiable"); 129 0 : } 130 : 131 : template <class data_t> 132 : DataContainer<data_t> 133 : IndicatorNonNegativity<data_t>::proximal(const DataContainer<data_t>& v, 134 : [[maybe_unused]] SelfType_t<data_t> t) const 135 10 : { 136 10 : DataContainer<data_t> out(v.getDataDescriptor()); 137 10 : proximal(v, t, out); 138 10 : return out; 139 10 : } 140 : 141 : template <class data_t> 142 : void IndicatorNonNegativity<data_t>::proximal(const DataContainer<data_t>& v, 143 : [[maybe_unused]] SelfType_t<data_t> t, 144 : DataContainer<data_t>& out) const 145 10 : { 146 10 : ProximalBoxConstraint<data_t> prox(0); 147 10 : prox.apply(v, t, out); 148 10 : } 149 : 150 : template <class data_t> 151 : IndicatorNonNegativity<data_t>* IndicatorNonNegativity<data_t>::cloneImpl() const 152 12 : { 153 12 : return new IndicatorNonNegativity<data_t>(this->getDomainDescriptor()); 154 12 : } 155 : 156 : template <class data_t> 157 : bool IndicatorNonNegativity<data_t>::isEqual(const Functional<data_t>& other) const 158 4 : { 159 4 : return Functional<data_t>::isEqual(other) && is<IndicatorNonNegativity<data_t>>(&other); 160 4 : } 161 : 162 : // ------------------------------------------ 163 : // explicit template instantiation 164 : template class IndicatorBox<float>; 165 : template class IndicatorBox<double>; 166 : 167 : template class IndicatorNonNegativity<float>; 168 : template class IndicatorNonNegativity<double>; 169 : } // namespace elsa