Line data Source code
1 : #include "Quadric.h" 2 : #include "Identity.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <stdexcept> 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 : Quadric<data_t>::Quadric(const LinearOperator<data_t>& A, const DataContainer<data_t>& b) 11 : : Functional<data_t>(A.getDomainDescriptor()), _linearResidual{A, b} 12 304 : { 13 304 : } 14 : 15 : template <typename data_t> 16 : Quadric<data_t>::Quadric(const LinearOperator<data_t>& A) 17 : : Functional<data_t>(A.getDomainDescriptor()), _linearResidual{A} 18 84 : { 19 84 : } 20 : 21 : template <typename data_t> 22 : Quadric<data_t>::Quadric(const DataContainer<data_t>& b) 23 : : Functional<data_t>(b.getDataDescriptor()), _linearResidual{b} 24 68 : { 25 68 : } 26 : 27 : template <typename data_t> 28 : Quadric<data_t>::Quadric(const DataDescriptor& domainDescriptor) 29 : : Functional<data_t>(domainDescriptor), _linearResidual{domainDescriptor} 30 92 : { 31 92 : } 32 : 33 : template <typename data_t> 34 : const LinearResidual<data_t>& Quadric<data_t>::getGradientExpression() const 35 65 : { 36 65 : return _linearResidual; 37 65 : } 38 : 39 : template <typename data_t> 40 : data_t Quadric<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) 41 46 : { 42 46 : data_t xtAx; 43 : 44 46 : if (_linearResidual.hasOperator()) { 45 34 : auto temp = _linearResidual.getOperator().apply(Rx); 46 34 : xtAx = Rx.dot(temp); 47 34 : } else { 48 12 : xtAx = Rx.squaredL2Norm(); 49 12 : } 50 : 51 46 : if (_linearResidual.hasDataVector()) { 52 28 : return static_cast<data_t>(0.5) * xtAx - Rx.dot(_linearResidual.getDataVector()); 53 28 : } else { 54 18 : return static_cast<data_t>(0.5) * xtAx; 55 18 : } 56 46 : } 57 : 58 : template <typename data_t> 59 : void Quadric<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx) 60 225 : { 61 225 : Rx = _linearResidual.evaluate(Rx); 62 225 : } 63 : 64 : template <typename data_t> 65 : LinearOperator<data_t> 66 : Quadric<data_t>::getHessianImpl([[maybe_unused]] const DataContainer<data_t>& Rx) 67 46 : { 68 46 : if (_linearResidual.hasOperator()) 69 34 : return leaf(_linearResidual.getOperator()); 70 12 : else 71 12 : return leaf(Identity<data_t>(*_domainDescriptor)); 72 46 : } 73 : 74 : template <typename data_t> 75 : Quadric<data_t>* Quadric<data_t>::cloneImpl() const 76 197 : { 77 197 : if (_linearResidual.hasOperator() && _linearResidual.hasDataVector()) 78 157 : return new Quadric<data_t>(_linearResidual.getOperator(), 79 157 : _linearResidual.getDataVector()); 80 40 : else if (_linearResidual.hasOperator() && !_linearResidual.hasDataVector()) 81 14 : return new Quadric<data_t>(_linearResidual.getOperator()); 82 26 : else if (!_linearResidual.hasOperator() && _linearResidual.hasDataVector()) 83 4 : return new Quadric<data_t>(_linearResidual.getDataVector()); 84 22 : else 85 22 : return new Quadric<data_t>(*_domainDescriptor); 86 197 : } 87 : 88 : template <typename data_t> 89 : bool Quadric<data_t>::isEqual(const Functional<data_t>& other) const 90 42 : { 91 42 : if (!Functional<data_t>::isEqual(other)) 92 0 : return false; 93 : 94 42 : auto otherQuadric = downcast_safe<Quadric>(&other); 95 42 : if (!otherQuadric) 96 0 : return false; 97 : 98 42 : if (_linearResidual != otherQuadric->_linearResidual) 99 0 : return false; 100 : 101 42 : return true; 102 42 : } 103 : 104 : // ------------------------------------------ 105 : // explicit template instantiation 106 : template class Quadric<float>; 107 : template class Quadric<double>; 108 : template class Quadric<complex<float>>; 109 : template class Quadric<complex<double>>; 110 : 111 : } // namespace elsa