Line data Source code
1 : #include "WeightedLeastSquares.h" 2 : 3 : #include "DataContainer.h" 4 : #include "DataDescriptor.h" 5 : #include "Error.h" 6 : #include "LinearOperator.h" 7 : #include "TypeCasts.hpp" 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : WeightedLeastSquares<data_t>::WeightedLeastSquares(const LinearOperator<data_t>& A, 13 : const DataContainer<data_t>& b, 14 : const DataContainer<data_t>& weights) 15 : : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), b_(b), W_(weights) 16 36 : { 17 36 : if (A.getDomainDescriptor().getNumberOfCoefficientsPerDimension() 18 36 : != weights.getDataDescriptor().getNumberOfCoefficientsPerDimension()) { 19 0 : throw InvalidArgumentError("Domain of A and weights need to match"); 20 0 : } 21 36 : } 22 : 23 : template <typename data_t> 24 : bool WeightedLeastSquares<data_t>::isDifferentiable() const 25 6 : { 26 6 : return true; 27 6 : } 28 : 29 : template <typename data_t> 30 : const LinearOperator<data_t>& WeightedLeastSquares<data_t>::getOperator() const 31 0 : { 32 0 : return *A_; 33 0 : } 34 : 35 : template <typename data_t> 36 : const DataContainer<data_t>& WeightedLeastSquares<data_t>::getDataVector() const 37 0 : { 38 0 : return b_; 39 0 : } 40 : 41 : template <typename data_t> 42 : data_t WeightedLeastSquares<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 43 104 : { 44 : // Evaluate A(x) - b 45 104 : auto temp = A_->apply(x); 46 104 : temp -= b_; 47 : 48 : // evaluate weighted l2 norm 49 104 : W_.apply(x, temp); 50 104 : return static_cast<data_t>(0.5) * x.dot(temp); 51 104 : } 52 : 53 : template <typename data_t> 54 : void WeightedLeastSquares<data_t>::getGradientImpl(const DataContainer<data_t>& x, 55 : DataContainer<data_t>& out) const 56 108 : { 57 : // Evaluate A(x) - b 58 108 : auto temp = A_->apply(x); 59 108 : temp -= b_; 60 : 61 108 : W_.apply(temp, temp); 62 : 63 : // Apply chain rule 64 108 : A_->applyAdjoint(temp, out); 65 108 : } 66 : 67 : template <typename data_t> 68 : LinearOperator<data_t> 69 : WeightedLeastSquares<data_t>::getHessianImpl(const DataContainer<data_t>&) const 70 0 : { 71 0 : return leaf(adjoint(*A_) * W_ * (*A_)); 72 0 : } 73 : 74 : template <typename data_t> 75 : WeightedLeastSquares<data_t>* WeightedLeastSquares<data_t>::cloneImpl() const 76 30 : { 77 30 : return new WeightedLeastSquares(*A_, b_, W_.getScaleFactors()); 78 30 : } 79 : 80 : template <typename data_t> 81 : bool WeightedLeastSquares<data_t>::isEqual(const Functional<data_t>& other) const 82 0 : { 83 0 : if (!Functional<data_t>::isEqual(other)) 84 0 : return false; 85 : 86 0 : auto fn = downcast_safe<WeightedLeastSquares<data_t>>(&other); 87 0 : return fn && *A_ == *fn->A_ && b_ == fn->b_ && W_ == fn->W_; 88 0 : } 89 : 90 : // ------------------------------------------ 91 : // explicit template instantiation 92 : template class WeightedLeastSquares<float>; 93 : template class WeightedLeastSquares<double>; 94 : } // namespace elsa