Line data Source code
1 : #include "Quadric.h" 2 : #include "DataContainer.h" 3 : #include "Identity.h" 4 : #include "TypeCasts.hpp" 5 : 6 : #include <stdexcept> 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 : Quadric<data_t>::Quadric(const LinearOperator<data_t>& A, const DataContainer<data_t>& b) 12 : : Functional<data_t>(A.getDomainDescriptor()), linResidual_{A, b} 13 16 : { 14 16 : } 15 : 16 : template <typename data_t> 17 : Quadric<data_t>::Quadric(const LinearOperator<data_t>& A) 18 : : Functional<data_t>(A.getDomainDescriptor()), linResidual_{A} 19 16 : { 20 16 : } 21 : 22 : template <typename data_t> 23 : Quadric<data_t>::Quadric(const DataContainer<data_t>& b) 24 : : Functional<data_t>(b.getDataDescriptor()), linResidual_{b} 25 16 : { 26 16 : } 27 : 28 : template <typename data_t> 29 : Quadric<data_t>::Quadric(const DataDescriptor& domainDescriptor) 30 : : Functional<data_t>(domainDescriptor), linResidual_{domainDescriptor} 31 16 : { 32 16 : } 33 : 34 : template <typename data_t> 35 : bool Quadric<data_t>::isDifferentiable() const 36 0 : { 37 0 : return true; 38 0 : } 39 : 40 : template <typename data_t> 41 : const LinearResidual<data_t>& Quadric<data_t>::getGradientExpression() const 42 16 : { 43 16 : return linResidual_; 44 16 : } 45 : 46 : template <typename data_t> 47 : data_t Quadric<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 48 16 : { 49 16 : data_t xtAx; 50 : 51 16 : if (linResidual_.hasOperator()) { 52 8 : auto temp = linResidual_.getOperator().apply(Rx); 53 8 : xtAx = Rx.dot(temp); 54 8 : } else { 55 8 : xtAx = Rx.squaredL2Norm(); 56 8 : } 57 : 58 16 : if (linResidual_.hasDataVector()) { 59 8 : return static_cast<data_t>(0.5) * xtAx - Rx.dot(linResidual_.getDataVector()); 60 8 : } else { 61 8 : return static_cast<data_t>(0.5) * xtAx; 62 8 : } 63 16 : } 64 : 65 : template <typename data_t> 66 : void Quadric<data_t>::getGradientImpl(const DataContainer<data_t>& Rx, 67 : DataContainer<data_t>& out) const 68 16 : { 69 16 : out = linResidual_.evaluate(Rx); 70 16 : } 71 : 72 : template <typename data_t> 73 : LinearOperator<data_t> 74 : Quadric<data_t>::getHessianImpl([[maybe_unused]] const DataContainer<data_t>& Rx) const 75 16 : { 76 16 : if (linResidual_.hasOperator()) 77 8 : return leaf(linResidual_.getOperator()); 78 8 : else 79 8 : return leaf(Identity<data_t>(this->getDomainDescriptor())); 80 16 : } 81 : 82 : template <typename data_t> 83 : Quadric<data_t>* Quadric<data_t>::cloneImpl() const 84 16 : { 85 16 : if (linResidual_.hasOperator() && linResidual_.hasDataVector()) 86 4 : return new Quadric<data_t>(linResidual_.getOperator(), linResidual_.getDataVector()); 87 12 : else if (linResidual_.hasOperator() && !linResidual_.hasDataVector()) 88 4 : return new Quadric<data_t>(linResidual_.getOperator()); 89 8 : else if (!linResidual_.hasOperator() && linResidual_.hasDataVector()) 90 4 : return new Quadric<data_t>(linResidual_.getDataVector()); 91 4 : else 92 4 : return new Quadric<data_t>(this->getDomainDescriptor()); 93 16 : } 94 : 95 : template <typename data_t> 96 : bool Quadric<data_t>::isEqual(const Functional<data_t>& other) const 97 16 : { 98 16 : if (!Functional<data_t>::isEqual(other)) 99 0 : return false; 100 : 101 16 : auto otherQuadric = downcast_safe<Quadric>(&other); 102 16 : if (!otherQuadric) 103 0 : return false; 104 : 105 16 : if (linResidual_ != otherQuadric->linResidual_) 106 0 : return false; 107 : 108 16 : return true; 109 16 : } 110 : 111 : // ------------------------------------------ 112 : // explicit template instantiation 113 : template class Quadric<float>; 114 : template class Quadric<double>; 115 : template class Quadric<complex<float>>; 116 : template class Quadric<complex<double>>; 117 : 118 : } // namespace elsa