Line data Source code
1 : #include "Functional.h" 2 : #include "LinearResidual.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <stdexcept> 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 0 : Functional<data_t>::Functional(const DataDescriptor& domainDescriptor) 11 0 : : _domainDescriptor{domainDescriptor.clone()}, 12 0 : _residual{std::make_unique<LinearResidual<data_t>>(domainDescriptor)} 13 : { 14 0 : } 15 : 16 : template <typename data_t> 17 0 : Functional<data_t>::Functional(const Residual<data_t>& residual) 18 0 : : _domainDescriptor{residual.getDomainDescriptor().clone()}, _residual{residual.clone()} 19 : { 20 0 : } 21 : 22 : template <typename data_t> 23 0 : const DataDescriptor& Functional<data_t>::getDomainDescriptor() const 24 : { 25 0 : return *_domainDescriptor; 26 : } 27 : 28 : template <typename data_t> 29 0 : const Residual<data_t>& Functional<data_t>::getResidual() const 30 : { 31 0 : return *_residual; 32 : } 33 : 34 : template <typename data_t> 35 0 : data_t Functional<data_t>::evaluate(const DataContainer<data_t>& x) 36 : { 37 0 : if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients()) 38 0 : throw InvalidArgumentError( 39 : "Functional::evaluate: argument size does not match functional"); 40 : 41 : // optimize for trivial LinearResiduals (no extra copy for residual result needed then) 42 0 : if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) { 43 0 : if (!linearResidual->hasOperator() && !linearResidual->hasDataVector()) 44 0 : return evaluateImpl(x); 45 : } 46 : 47 : // in all other cases: evaluate the residual first, then call our virtual evaluateImpl 48 0 : return evaluateImpl(_residual->evaluate(x)); 49 : } 50 : 51 : template <typename data_t> 52 0 : DataContainer<data_t> Functional<data_t>::getGradient(const DataContainer<data_t>& x) 53 : { 54 0 : DataContainer<data_t> result(_residual->getRangeDescriptor(), x.getDataHandlerType()); 55 0 : getGradient(x, result); 56 0 : return result; 57 0 : } 58 : 59 : template <typename data_t> 60 0 : void Functional<data_t>::getGradient(const DataContainer<data_t>& x, 61 : DataContainer<data_t>& result) 62 : { 63 0 : if (x.getSize() != getDomainDescriptor().getNumberOfCoefficients() 64 0 : || result.getSize() != _residual->getDomainDescriptor().getNumberOfCoefficients()) 65 0 : throw InvalidArgumentError( 66 : "Functional::getGradient: argument sizes do not match functional"); 67 : 68 : // optimize for trivial or simple LinearResiduals 69 0 : if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) { 70 : // if trivial, no extra copy for residual result needed (and no chain rule) 71 0 : if (!linearResidual->hasOperator() && !linearResidual->hasDataVector()) { 72 0 : result = x; 73 0 : getGradientInPlaceImpl(result); 74 0 : return; 75 : } 76 : 77 : // if no operator, no need for chain rule 78 0 : if (!linearResidual->hasOperator()) { 79 0 : linearResidual->evaluate(x, 80 : result); // sizes of x and result will match in this case 81 0 : getGradientInPlaceImpl(result); 82 0 : return; 83 : } 84 : } 85 : 86 : // the general case 87 0 : auto temp = _residual->evaluate(x); 88 0 : getGradientInPlaceImpl(temp); 89 0 : _residual->getJacobian(x).applyAdjoint(temp, result); // apply the chain rule 90 0 : } 91 : 92 : template <typename data_t> 93 0 : LinearOperator<data_t> Functional<data_t>::getHessian(const DataContainer<data_t>& x) 94 : { 95 : // optimize for trivial and simple LinearResiduals 96 0 : if (auto* linearResidual = downcast_safe<LinearResidual<data_t>>(_residual.get())) { 97 : // if trivial, no extra copy for residual result needed (and no chain rule) 98 0 : if (!linearResidual->hasOperator() && !linearResidual->hasDataVector()) 99 0 : return getHessianImpl(x); 100 : 101 : // if no operator, no need for chain rule 102 0 : if (!linearResidual->hasOperator()) 103 0 : return getHessianImpl(_residual->evaluate(x)); 104 : } 105 : 106 : // the general case (with chain rule) 107 0 : auto jacobian = _residual->getJacobian(x); 108 0 : auto hessian = adjoint(jacobian) * (getHessianImpl(_residual->evaluate(x))) * (jacobian); 109 0 : return hessian; 110 0 : } 111 : 112 : template <typename data_t> 113 0 : bool Functional<data_t>::isEqual(const Functional<data_t>& other) const 114 : { 115 0 : return !static_cast<bool>(*_domainDescriptor != *other._domainDescriptor 116 0 : || *_residual != *other._residual); 117 : } 118 : 119 : // ------------------------------------------ 120 : // explicit template instantiation 121 : template class Functional<float>; 122 : template class Functional<double>; 123 : template class Functional<complex<float>>; 124 : template class Functional<complex<double>>; 125 : } // namespace elsa