Line data Source code
1 : #include "Functional.h" 2 : #include "DataContainer.h" 3 : #include "TypeCasts.hpp" 4 : #include "VolumeDescriptor.h" 5 : 6 : #include <stdexcept> 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 : Functional<data_t>::Functional(const DataDescriptor& domainDescriptor) 12 : : _domainDescriptor{domainDescriptor.clone()} 13 1480 : { 14 1480 : } 15 : 16 : template <typename data_t> 17 : const DataDescriptor& Functional<data_t>::getDomainDescriptor() const 18 9708 : { 19 9708 : return *_domainDescriptor; 20 9708 : } 21 : 22 : template <typename data_t> 23 : bool Functional<data_t>::isDifferentiable() const 24 0 : { 25 0 : return false; 26 0 : } 27 : 28 : template <typename data_t> 29 : bool Functional<data_t>::isProxFriendly() const 30 0 : { 31 0 : return false; 32 0 : } 33 : 34 : template <typename data_t> 35 : bool Functional<data_t>::hasProxDual() const 36 0 : { 37 0 : return isProxFriendly(); 38 0 : } 39 : 40 : template <typename data_t> 41 : data_t Functional<data_t>::evaluate(const DataContainer<data_t>& x) const 42 2606 : { 43 : // TODO: This should compare descriptors shouldn't it? 44 2606 : if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients()) { 45 0 : throw InvalidArgumentError( 46 0 : "Functional::evaluate: argument size does not match functional"); 47 0 : } 48 : 49 2606 : return evaluateImpl(x); 50 2606 : } 51 : 52 : template <class data_t> 53 : data_t Functional<data_t>::convexConjugate(const DataContainer<data_t>&) const 54 0 : { 55 0 : throw Error("Functional: No implementation of convex conjugate"); 56 0 : } 57 : 58 : template <typename data_t> 59 : DataContainer<data_t> Functional<data_t>::getGradient(const DataContainer<data_t>& x) const 60 1839 : { 61 1839 : DataContainer<data_t> result(getDomainDescriptor()); 62 1839 : getGradient(x, result); 63 1839 : return result; 64 1839 : } 65 : 66 : template <typename data_t> 67 : void Functional<data_t>::getGradient(const DataContainer<data_t>& x, 68 : DataContainer<data_t>& result) const 69 4101 : { 70 4101 : if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients()) { 71 0 : throw InvalidArgumentError( 72 0 : "Functional::getGradient: argument sizes do not match functional"); 73 0 : } 74 : 75 4101 : getGradientImpl(x, result); 76 4101 : } 77 : 78 : template <typename data_t> 79 : DataContainer<data_t> Functional<data_t>::proximal(const DataContainer<data_t>&, 80 : SelfType_t<data_t>) const 81 0 : { 82 0 : throw Error("No proximal is implemented for this functional"); 83 0 : } 84 : 85 : template <typename data_t> 86 : void Functional<data_t>::proximal(const DataContainer<data_t>&, SelfType_t<data_t>, 87 : DataContainer<data_t>&) const 88 0 : { 89 0 : throw Error("No proximal is implemented for this functional"); 90 0 : } 91 : 92 : template <typename data_t> 93 : DataContainer<data_t> Functional<data_t>::proxdual(const DataContainer<data_t>& x, 94 : SelfType_t<data_t> tau) const 95 44 : { 96 44 : auto out = emptylike(x); 97 44 : proxdual(x, tau, out); 98 44 : return out; 99 44 : } 100 : 101 : template <typename data_t> 102 : void Functional<data_t>::proxdual(const DataContainer<data_t>& x, SelfType_t<data_t> tau, 103 : DataContainer<data_t>& out) const 104 44 : { 105 44 : if (!isProxFriendly()) { 106 0 : throw Error("Cannot compute proximal of convex conjugate via Moreau's identity"); 107 0 : } 108 : 109 : // TODO: improve efficiency of this approach 110 44 : auto rtau = 1 / tau; 111 44 : out = x - tau * proximal(x * rtau, rtau); 112 44 : } 113 : 114 : template <typename data_t> 115 : LinearOperator<data_t> Functional<data_t>::getHessian(const DataContainer<data_t>& x) const 116 116 : { 117 116 : return getHessianImpl(x); 118 116 : } 119 : 120 : template <typename data_t> 121 : bool Functional<data_t>::isEqual(const Functional<data_t>& other) const 122 278 : { 123 278 : return !static_cast<bool>(*_domainDescriptor != *other._domainDescriptor); 124 278 : } 125 : 126 : // ------------------------------------------ 127 : // FunctionalSum 128 : template <class data_t> 129 : FunctionalSum<data_t>::FunctionalSum(const Functional<data_t>& lhs, 130 : const Functional<data_t>& rhs) 131 : : Functional<data_t>(lhs.getDomainDescriptor()), lhs_(lhs.clone()), rhs_(rhs.clone()) 132 88 : { 133 88 : if (lhs_->getDomainDescriptor() != rhs_->getDomainDescriptor()) { 134 0 : throw InvalidArgumentError("FunctionalSum: domain descriptors need to be the same"); 135 0 : } 136 88 : } 137 : 138 : template <class data_t> 139 : data_t FunctionalSum<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 140 144 : { 141 144 : return lhs_->evaluate(Rx) + rhs_->evaluate(Rx); 142 144 : } 143 : 144 : template <class data_t> 145 : void FunctionalSum<data_t>::getGradientImpl(const DataContainer<data_t>& Rx, 146 : DataContainer<data_t>& out) const 147 258 : { 148 258 : auto tmp = Rx; 149 258 : lhs_->getGradient(Rx, out); 150 258 : rhs_->getGradient(tmp, tmp); 151 258 : out += tmp; 152 258 : } 153 : 154 : template <class data_t> 155 : LinearOperator<data_t> 156 : FunctionalSum<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const 157 0 : { 158 0 : return lhs_->getHessian(Rx) + rhs_->getHessian(Rx); 159 0 : } 160 : 161 : template <class data_t> 162 : FunctionalSum<data_t>* FunctionalSum<data_t>::cloneImpl() const 163 64 : { 164 64 : return new FunctionalSum<data_t>(*lhs_, *rhs_); 165 64 : } 166 : 167 : template <class data_t> 168 : bool FunctionalSum<data_t>::isEqual(const Functional<data_t>& other) const 169 12 : { 170 12 : if (!Functional<data_t>::isEqual(other)) { 171 0 : return false; 172 0 : } 173 : 174 12 : auto* fn = downcast<FunctionalSum<data_t>>(&other); 175 12 : return static_cast<bool>(fn) && (*lhs_) == (*fn->lhs_) && (*rhs_) == (*fn->rhs_); 176 12 : } 177 : 178 : // ------------------------------------------ 179 : // FunctionalScalarMul 180 : template <class data_t> 181 : FunctionalScalarMul<data_t>::FunctionalScalarMul(const Functional<data_t>& fn, 182 : SelfType_t<data_t> scalar) 183 : : Functional<data_t>(fn.getDomainDescriptor()), fn_(fn.clone()), scalar_(scalar) 184 40 : { 185 40 : } 186 : 187 : template <class data_t> 188 : data_t FunctionalScalarMul<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 189 16 : { 190 16 : return scalar_ * fn_->evaluate(Rx); 191 16 : } 192 : 193 : template <class data_t> 194 : data_t FunctionalScalarMul<data_t>::convexConjugate(const DataContainer<data_t>& x) const 195 0 : { 196 0 : return scalar_ * fn_->evaluate(x / scalar_); 197 0 : } 198 : 199 : template <class data_t> 200 : void FunctionalScalarMul<data_t>::getGradientImpl(const DataContainer<data_t>& Rx, 201 : DataContainer<data_t>& out) const 202 116 : { 203 116 : fn_->getGradient(Rx, out); 204 116 : out *= scalar_; 205 116 : } 206 : 207 : template <class data_t> 208 : LinearOperator<data_t> 209 : FunctionalScalarMul<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const 210 0 : { 211 0 : return scalar_ * fn_->getHessian(Rx); 212 0 : } 213 : 214 : template <typename data_t> 215 : bool FunctionalScalarMul<data_t>::isProxFriendly() const 216 0 : { 217 0 : return true; 218 0 : } 219 : 220 : template <typename data_t> 221 : DataContainer<data_t> FunctionalScalarMul<data_t>::proximal(const DataContainer<data_t>& v, 222 : SelfType_t<data_t> t) const 223 0 : { 224 : // If scalar is zero, this is equal to the zero functional, and hence the identity proximal 225 : // operator 226 0 : if (scalar_ == 0) { 227 0 : return v; 228 0 : } 229 0 : return fn_->proximal(v, t * scalar_); 230 0 : } 231 : 232 : template <typename data_t> 233 : void FunctionalScalarMul<data_t>::proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t, 234 : DataContainer<data_t>& out) const 235 0 : { 236 0 : if (scalar_ == 0) { 237 0 : return; 238 0 : } 239 0 : return fn_->proximal(v, t * scalar_, out); 240 0 : } 241 : 242 : template <class data_t> 243 : FunctionalScalarMul<data_t>* FunctionalScalarMul<data_t>::cloneImpl() const 244 28 : { 245 28 : return new FunctionalScalarMul<data_t>(*fn_, scalar_); 246 28 : } 247 : 248 : template <class data_t> 249 : bool FunctionalScalarMul<data_t>::isEqual(const Functional<data_t>& other) const 250 8 : { 251 8 : if (!Functional<data_t>::isEqual(other)) { 252 0 : return false; 253 0 : } 254 : 255 8 : auto* fn = downcast<FunctionalScalarMul<data_t>>(&other); 256 8 : return static_cast<bool>(fn) && (*fn_) == (*fn->fn_) && scalar_ == fn->scalar_; 257 8 : } 258 : 259 : // ------------------------------------------ 260 : // explicit template instantiation 261 : template class Functional<float>; 262 : template class Functional<double>; 263 : template class Functional<complex<float>>; 264 : template class Functional<complex<double>>; 265 : 266 : template class FunctionalSum<float>; 267 : template class FunctionalSum<double>; 268 : template class FunctionalSum<complex<float>>; 269 : template class FunctionalSum<complex<double>>; 270 : 271 : template class FunctionalScalarMul<float>; 272 : template class FunctionalScalarMul<double>; 273 : template class FunctionalScalarMul<complex<float>>; 274 : template class FunctionalScalarMul<complex<double>>; 275 : } // namespace elsa