Line data Source code
1 : #include "LeastSquares.h" 2 : 3 : #include "DataContainer.h" 4 : #include "DataDescriptor.h" 5 : #include "LinearOperator.h" 6 : #include "TypeCasts.hpp" 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 : LeastSquares<data_t>::LeastSquares(const LinearOperator<data_t>& A, 12 : const DataContainer<data_t>& b) 13 : : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), b_(b) 14 358 : { 15 358 : } 16 : 17 : template <typename data_t> 18 : bool LeastSquares<data_t>::isDifferentiable() const 19 18 : { 20 18 : return true; 21 18 : } 22 : 23 : template <typename data_t> 24 : const LinearOperator<data_t>& LeastSquares<data_t>::getOperator() const 25 20 : { 26 20 : return *A_; 27 20 : } 28 : 29 : template <typename data_t> 30 : const DataContainer<data_t>& LeastSquares<data_t>::getDataVector() const 31 20 : { 32 20 : return b_; 33 20 : } 34 : 35 : template <typename data_t> 36 : data_t LeastSquares<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 37 1716 : { 38 1716 : auto Ax = A_->apply(x); 39 1716 : Ax -= b_; 40 : 41 1716 : return static_cast<data_t>(0.5) * Ax.squaredL2Norm(); 42 1716 : } 43 : 44 : template <typename data_t> 45 : void LeastSquares<data_t>::getGradientImpl(const DataContainer<data_t>& x, 46 : DataContainer<data_t>& out) const 47 3223 : { 48 3223 : auto temp = A_->apply(x); 49 3223 : temp -= b_; 50 : 51 : // Apply chain rule 52 3223 : A_->applyAdjoint(temp, out); 53 3223 : } 54 : 55 : template <typename data_t> 56 : LinearOperator<data_t> LeastSquares<data_t>::getHessianImpl(const DataContainer<data_t>&) const 57 20 : { 58 20 : return leaf(adjoint(*A_) * (*A_)); 59 20 : } 60 : 61 : template <typename data_t> 62 : LeastSquares<data_t>* LeastSquares<data_t>::cloneImpl() const 63 252 : { 64 252 : return new LeastSquares(*A_, b_); 65 252 : } 66 : 67 : template <typename data_t> 68 : bool LeastSquares<data_t>::isEqual(const Functional<data_t>& other) const 69 4 : { 70 4 : if (!Functional<data_t>::isEqual(other)) 71 0 : return false; 72 : 73 4 : auto fn = downcast_safe<LeastSquares<data_t>>(&other); 74 4 : return fn && *A_ == *fn->A_ && b_ == fn->b_; 75 4 : } 76 : 77 : // ------------------------------------------ 78 : // explicit template instantiation 79 : template class LeastSquares<float>; 80 : template class LeastSquares<double>; 81 : } // namespace elsa