Line data Source code
1 : #include "Quadric.h" 2 : #include "Identity.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <stdexcept> 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 0 : Quadric<data_t>::Quadric(const LinearOperator<data_t>& A, const DataContainer<data_t>& b) 11 0 : : Functional<data_t>(A.getDomainDescriptor()), _linearResidual{A, b} 12 : { 13 0 : } 14 : 15 : template <typename data_t> 16 0 : Quadric<data_t>::Quadric(const LinearOperator<data_t>& A) 17 0 : : Functional<data_t>(A.getDomainDescriptor()), _linearResidual{A} 18 : { 19 0 : } 20 : 21 : template <typename data_t> 22 0 : Quadric<data_t>::Quadric(const DataContainer<data_t>& b) 23 0 : : Functional<data_t>(b.getDataDescriptor()), _linearResidual{b} 24 : { 25 0 : } 26 : 27 : template <typename data_t> 28 0 : Quadric<data_t>::Quadric(const DataDescriptor& domainDescriptor) 29 0 : : Functional<data_t>(domainDescriptor), _linearResidual{domainDescriptor} 30 : { 31 0 : } 32 : 33 : template <typename data_t> 34 0 : const LinearResidual<data_t>& Quadric<data_t>::getGradientExpression() const 35 : { 36 0 : return _linearResidual; 37 : } 38 : 39 : template <typename data_t> 40 0 : data_t Quadric<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) 41 : { 42 0 : data_t xtAx; 43 : 44 0 : if (_linearResidual.hasOperator()) { 45 0 : auto temp = _linearResidual.getOperator().apply(Rx); 46 0 : xtAx = Rx.dot(temp); 47 0 : } else { 48 0 : xtAx = Rx.squaredL2Norm(); 49 : } 50 : 51 0 : if (_linearResidual.hasDataVector()) { 52 0 : return static_cast<data_t>(0.5) * xtAx - Rx.dot(_linearResidual.getDataVector()); 53 : } else { 54 0 : return static_cast<data_t>(0.5) * xtAx; 55 : } 56 : } 57 : 58 : template <typename data_t> 59 0 : void Quadric<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx) 60 : { 61 0 : Rx = _linearResidual.evaluate(Rx); 62 0 : } 63 : 64 : template <typename data_t> 65 : LinearOperator<data_t> 66 0 : Quadric<data_t>::getHessianImpl([[maybe_unused]] const DataContainer<data_t>& Rx) 67 : { 68 0 : if (_linearResidual.hasOperator()) 69 0 : return leaf(_linearResidual.getOperator()); 70 : else 71 0 : return leaf(Identity<data_t>(*_domainDescriptor)); 72 : } 73 : 74 : template <typename data_t> 75 0 : Quadric<data_t>* Quadric<data_t>::cloneImpl() const 76 : { 77 0 : if (_linearResidual.hasOperator() && _linearResidual.hasDataVector()) 78 0 : return new Quadric<data_t>(_linearResidual.getOperator(), 79 0 : _linearResidual.getDataVector()); 80 0 : else if (_linearResidual.hasOperator() && !_linearResidual.hasDataVector()) 81 0 : return new Quadric<data_t>(_linearResidual.getOperator()); 82 0 : else if (!_linearResidual.hasOperator() && _linearResidual.hasDataVector()) 83 0 : return new Quadric<data_t>(_linearResidual.getDataVector()); 84 : else 85 0 : return new Quadric<data_t>(*_domainDescriptor); 86 : } 87 : 88 : template <typename data_t> 89 0 : bool Quadric<data_t>::isEqual(const Functional<data_t>& other) const 90 : { 91 0 : if (!Functional<data_t>::isEqual(other)) 92 0 : return false; 93 : 94 0 : auto otherQuadric = downcast_safe<Quadric>(&other); 95 0 : if (!otherQuadric) 96 0 : return false; 97 : 98 0 : if (_linearResidual != otherQuadric->_linearResidual) 99 0 : return false; 100 : 101 0 : return true; 102 : } 103 : 104 : // ------------------------------------------ 105 : // explicit template instantiation 106 : template class Quadric<float>; 107 : template class Quadric<double>; 108 : template class Quadric<complex<float>>; 109 : template class Quadric<complex<double>>; 110 : 111 : } // namespace elsa