Line data Source code
1 : #include "ConstantFunctional.h" 2 : #include "Error.h" 3 : #include "Functional.h" 4 : #include "TypeCasts.hpp" 5 : #include "elsaDefines.h" 6 : #include <unistd.h> 7 : #include <utility> 8 : 9 : namespace elsa 10 : { 11 : 12 : template <typename data_t> 13 : ConstantFunctional<data_t>::ConstantFunctional(const DataDescriptor& descriptor, 14 : SelfType_t<data_t> constant) 15 : : Functional<data_t>(descriptor), constant_(constant) 16 112 : { 17 112 : } 18 : 19 : template <typename data_t> 20 : bool ConstantFunctional<data_t>::isDifferentiable() const 21 0 : { 22 0 : return true; 23 0 : } 24 : 25 : template <typename data_t> 26 : bool ConstantFunctional<data_t>::isProxFriendly() const 27 12 : { 28 12 : return true; 29 12 : } 30 : 31 : template <typename data_t> 32 : data_t ConstantFunctional<data_t>::getConstant() const 33 0 : { 34 0 : return constant_; 35 0 : } 36 : 37 : template <typename data_t> 38 : data_t ConstantFunctional<data_t>::evaluateImpl(const DataContainer<data_t>&) const 39 20 : { 40 20 : return constant_; 41 20 : } 42 : 43 : template <class data_t> 44 : data_t ConstantFunctional<data_t>::convexConjugate(const DataContainer<data_t>& x) const 45 12 : { 46 12 : return elsa::maximum(x, 0).sum(); 47 12 : } 48 : 49 : template <typename data_t> 50 : void ConstantFunctional<data_t>::getGradientImpl(const DataContainer<data_t>&, 51 : DataContainer<data_t>& out) const 52 0 : { 53 0 : out = 0; 54 0 : } 55 : 56 : template <typename data_t> 57 : LinearOperator<data_t> 58 : ConstantFunctional<data_t>::getHessianImpl(const DataContainer<data_t>&) const 59 0 : { 60 0 : throw NotImplementedError("ConstantFunctional: not twice differentiable"); 61 0 : } 62 : 63 : template <typename data_t> 64 : DataContainer<data_t> ConstantFunctional<data_t>::proximal(const DataContainer<data_t>& v, 65 : SelfType_t<data_t>) const 66 12 : { 67 12 : return v; 68 12 : } 69 : 70 : template <typename data_t> 71 : void ConstantFunctional<data_t>::proximal(const DataContainer<data_t>& v, SelfType_t<data_t>, 72 : DataContainer<data_t>& out) const 73 0 : { 74 0 : out = v; 75 0 : } 76 : 77 : template <typename data_t> 78 : ConstantFunctional<data_t>* ConstantFunctional<data_t>::cloneImpl() const 79 56 : { 80 56 : return new ConstantFunctional<data_t>(this->getDomainDescriptor(), constant_); 81 56 : } 82 : 83 : template <typename data_t> 84 : bool ConstantFunctional<data_t>::isEqual(const Functional<data_t>& other) const 85 40 : { 86 40 : if (!Functional<data_t>::isEqual(other)) { 87 8 : return false; 88 8 : } 89 : 90 32 : auto* fn = downcast<ConstantFunctional<data_t>>(&other); 91 32 : return static_cast<bool>(fn) && constant_ == fn->constant_; 92 32 : } 93 : 94 : // ------------------------------------------ 95 : // Zero Functional 96 : template <typename data_t> 97 : ZeroFunctional<data_t>::ZeroFunctional(const DataDescriptor& descriptor) 98 : : Functional<data_t>(descriptor) 99 80 : { 100 80 : } 101 : 102 : template <typename data_t> 103 : bool ZeroFunctional<data_t>::isDifferentiable() const 104 0 : { 105 0 : return true; 106 0 : } 107 : 108 : template <typename data_t> 109 : bool ZeroFunctional<data_t>::isProxFriendly() const 110 36 : { 111 36 : return true; 112 36 : } 113 : 114 : template <typename data_t> 115 : data_t ZeroFunctional<data_t>::evaluateImpl(const DataContainer<data_t>&) const 116 216 : { 117 216 : return 0; 118 216 : } 119 : 120 : template <class data_t> 121 : data_t ZeroFunctional<data_t>::convexConjugate(const DataContainer<data_t>& x) const 122 12 : { 123 12 : return elsa::maximum(x, 0).sum(); 124 12 : } 125 : 126 : template <typename data_t> 127 : void ZeroFunctional<data_t>::getGradientImpl(const DataContainer<data_t>&, 128 : DataContainer<data_t>& out) const 129 0 : { 130 0 : out = 0; 131 0 : } 132 : 133 : template <typename data_t> 134 : LinearOperator<data_t> 135 : ZeroFunctional<data_t>::getHessianImpl(const DataContainer<data_t>&) const 136 0 : { 137 0 : throw NotImplementedError("ZeroFunctional: not twice differentiable"); 138 0 : } 139 : 140 : template <typename data_t> 141 : DataContainer<data_t> 142 : ZeroFunctional<data_t>::proximal(const DataContainer<data_t>& v, 143 : [[maybe_unused]] SelfType_t<data_t>) const 144 224 : { 145 224 : return v; 146 224 : } 147 : 148 : template <typename data_t> 149 : void ZeroFunctional<data_t>::proximal(const DataContainer<data_t>& v, 150 : [[maybe_unused]] SelfType_t<data_t>, 151 : DataContainer<data_t>& out) const 152 0 : { 153 0 : out = v; 154 0 : } 155 : 156 : template <typename data_t> 157 : ZeroFunctional<data_t>* ZeroFunctional<data_t>::cloneImpl() const 158 40 : { 159 40 : return new ZeroFunctional<data_t>(this->getDomainDescriptor()); 160 40 : } 161 : 162 : template <typename data_t> 163 : bool ZeroFunctional<data_t>::isEqual(const Functional<data_t>& other) const 164 32 : { 165 32 : return Functional<data_t>::isEqual(other) && is<ZeroFunctional<data_t>>(other); 166 32 : } 167 : 168 : // ------------------------------------------ 169 : // explicit template instantiation 170 : template class ConstantFunctional<float>; 171 : template class ConstantFunctional<double>; 172 : template class ConstantFunctional<complex<float>>; 173 : template class ConstantFunctional<complex<double>>; 174 : 175 : template class ZeroFunctional<float>; 176 : template class ZeroFunctional<double>; 177 : template class ZeroFunctional<complex<float>>; 178 : template class ZeroFunctional<complex<double>>; 179 : } // namespace elsa