Line data Source code
1 : #include "WeightedL2Squared.h" 2 : #include "DataContainer.h" 3 : #include "LinearOperator.h" 4 : #include "Scaling.h" 5 : #include "TypeCasts.hpp" 6 : 7 : #include <stdexcept> 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : WeightedL2Squared<data_t>::WeightedL2Squared(const DataContainer<data_t>& weights) 13 : : Functional<data_t>(weights.getDataDescriptor()), weights_{weights} 14 24 : { 15 24 : } 16 : 17 : template <typename data_t> 18 : bool WeightedL2Squared<data_t>::isDifferentiable() const 19 0 : { 20 0 : return true; 21 0 : } 22 : 23 : template <typename data_t> 24 : Scaling<data_t> WeightedL2Squared<data_t>::getWeightingOperator() const 25 8 : { 26 8 : return Scaling<data_t>(weights_); 27 8 : } 28 : 29 : template <typename data_t> 30 : data_t WeightedL2Squared<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 31 2 : { 32 2 : auto temp = weights_ * Rx; 33 2 : return static_cast<data_t>(0.5) * Rx.dot(temp); 34 2 : } 35 : 36 : template <typename data_t> 37 : void WeightedL2Squared<data_t>::getGradientImpl(const DataContainer<data_t>& Rx, 38 : DataContainer<data_t>& out) const 39 4 : { 40 4 : out = weights_ * Rx; 41 4 : } 42 : 43 : template <typename data_t> 44 : LinearOperator<data_t> 45 : WeightedL2Squared<data_t>::getHessianImpl(const DataContainer<data_t>&) const 46 4 : { 47 4 : return leaf(getWeightingOperator()); 48 4 : } 49 : 50 : template <typename data_t> 51 : WeightedL2Squared<data_t>* WeightedL2Squared<data_t>::cloneImpl() const 52 4 : { 53 4 : return new WeightedL2Squared(weights_); 54 4 : } 55 : 56 : template <typename data_t> 57 : bool WeightedL2Squared<data_t>::isEqual(const Functional<data_t>& other) const 58 4 : { 59 4 : if (!Functional<data_t>::isEqual(other)) 60 0 : return false; 61 : 62 4 : auto otherWL2 = downcast_safe<WeightedL2Squared>(&other); 63 4 : if (!otherWL2) 64 0 : return false; 65 : 66 4 : return weights_ == otherWL2->weights_; 67 4 : } 68 : 69 : // ------------------------------------------ 70 : // explicit template instantiation 71 : template class WeightedL2Squared<float>; 72 : template class WeightedL2Squared<double>; 73 : template class WeightedL2Squared<complex<float>>; 74 : template class WeightedL2Squared<complex<double>>; 75 : 76 : } // namespace elsa