Line data Source code
1 : #include "Functional.h" 2 : #include "LinearResidual.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <stdexcept> 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 : Functional<data_t>::Functional(const DataDescriptor& domainDescriptor) 11 : : _domainDescriptor{domainDescriptor.clone()}, 12 : _residual{std::make_unique<LinearResidual<data_t>>(domainDescriptor)} 13 836 : { 14 836 : } 15 : 16 : template <typename data_t> 17 : Functional<data_t>::Functional(const Residual<data_t>& residual) 18 : : _domainDescriptor{residual.getDomainDescriptor().clone()}, _residual{residual.clone()} 19 1864 : { 20 1864 : } 21 : 22 : template <typename data_t> 23 : const DataDescriptor& Functional<data_t>::getDomainDescriptor() const 24 11855 : { 25 11855 : return *_domainDescriptor; 26 11855 : } 27 : 28 : template <typename data_t> 29 : const Residual<data_t>& Functional<data_t>::getResidual() const 30 3547 : { 31 3547 : return *_residual; 32 3547 : } 33 : 34 : template <typename data_t> 35 : data_t Functional<data_t>::evaluate(const DataContainer<data_t>& x) 36 156 : { 37 156 : if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients()) 38 0 : throw InvalidArgumentError( 39 0 : "Functional::evaluate: argument size does not match functional"); 40 : 41 : // optimize for trivial LinearResiduals (no extra copy for residual result needed then) 42 156 : if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) { 43 156 : if (!linearResidual->hasOperator() && !linearResidual->hasDataVector()) 44 98 : return evaluateImpl(x); 45 58 : } 46 : 47 : // in all other cases: evaluate the residual first, then call our virtual evaluateImpl 48 58 : return evaluateImpl(_residual->evaluate(x)); 49 58 : } 50 : 51 : template <typename data_t> 52 : DataContainer<data_t> Functional<data_t>::getGradient(const DataContainer<data_t>& x) 53 122 : { 54 122 : DataContainer<data_t> result(_residual->getDomainDescriptor(), x.getDataHandlerType()); 55 122 : getGradient(x, result); 56 122 : return result; 57 122 : } 58 : 59 : template <typename data_t> 60 : void Functional<data_t>::getGradient(const DataContainer<data_t>& x, 61 : DataContainer<data_t>& result) 62 11119 : { 63 11119 : if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients() 64 11119 : || result.getSize() != _residual->getDomainDescriptor().getNumberOfCoefficients()) 65 0 : throw InvalidArgumentError( 66 0 : "Functional::getGradient: argument sizes do not match functional"); 67 : 68 : // optimize for trivial or simple LinearResiduals 69 11119 : if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) { 70 : // if trivial, no extra copy for residual result needed (and no chain rule) 71 11119 : if (!linearResidual->hasOperator() && !linearResidual->hasDataVector()) { 72 299 : result = x; 73 299 : getGradientInPlaceImpl(result); 74 299 : return; 75 299 : } 76 : 77 : // if no operator, no need for chain rule 78 10820 : if (!linearResidual->hasOperator()) { 79 0 : linearResidual->evaluate(x, 80 0 : result); // sizes of x and result will match in this case 81 0 : getGradientInPlaceImpl(result); 82 0 : return; 83 0 : } 84 10820 : } 85 : 86 : // the general case 87 10820 : auto temp = _residual->evaluate(x); 88 10820 : getGradientInPlaceImpl(temp); 89 10820 : _residual->getJacobian(x).applyAdjoint(temp, result); // apply the chain rule 90 10820 : } 91 : 92 : template <typename data_t> 93 : LinearOperator<data_t> Functional<data_t>::getHessian(const DataContainer<data_t>& x) 94 193 : { 95 : // optimize for trivial and simple LinearResiduals 96 193 : if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) { 97 : // if trivial, no extra copy for residual result needed (and no chain rule) 98 193 : if (!linearResidual->hasOperator() && !linearResidual->hasDataVector()) 99 102 : return getHessianImpl(x); 100 : 101 : // if no operator, no need for chain rule 102 91 : if (!linearResidual->hasOperator()) 103 0 : return getHessianImpl(_residual->evaluate(x)); 104 91 : } 105 : 106 : // the general case (with chain rule) 107 91 : auto jacobian = _residual->getJacobian(x); 108 91 : auto hessian = adjoint(jacobian) * (getHessianImpl(_residual->evaluate(x))) * (jacobian); 109 91 : return hessian; 110 91 : } 111 : 112 : template <typename data_t> 113 : bool Functional<data_t>::isEqual(const Functional<data_t>& other) const 114 183 : { 115 183 : return !static_cast<bool>(*_domainDescriptor != *other._domainDescriptor 116 183 : || *_residual != *other._residual); 117 183 : } 118 : 119 : // ------------------------------------------ 120 : // explicit template instantiation 121 : template class Functional<float>; 122 : template class Functional<double>; 123 : template class Functional<complex<float>>; 124 : template class Functional<complex<double>>; 125 : } // namespace elsa