Line data Source code
1 : #include "WeightedL2NormPow2.h" 2 : #include "LinearOperator.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <stdexcept> 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 : WeightedL2NormPow2<data_t>::WeightedL2NormPow2(const Scaling<data_t>& weightingOp) 11 : : Functional<data_t>(weightingOp.getDomainDescriptor()), 12 : _weightingOp{static_cast<Scaling<data_t>*>(weightingOp.clone().release())} 13 46 : { 14 46 : } 15 : 16 : template <typename data_t> 17 : WeightedL2NormPow2<data_t>::WeightedL2NormPow2(const Residual<data_t>& residual, 18 : const Scaling<data_t>& weightingOp) 19 : : Functional<data_t>(residual), 20 : _weightingOp{static_cast<Scaling<data_t>*>(weightingOp.clone().release())} 21 198 : { 22 : // sanity check 23 198 : if (residual.getRangeDescriptor().getNumberOfCoefficients() 24 198 : != weightingOp.getDomainDescriptor().getNumberOfCoefficients()) 25 0 : throw InvalidArgumentError( 26 0 : "WeightedL2NormPow2: sizes of residual and weighting operator do not match"); 27 198 : } 28 : 29 : template <typename data_t> 30 : const Scaling<data_t>& WeightedL2NormPow2<data_t>::getWeightingOperator() const 31 16 : { 32 16 : return *_weightingOp; 33 16 : } 34 : 35 : template <typename data_t> 36 : data_t WeightedL2NormPow2<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) 37 8 : { 38 8 : auto temp = _weightingOp->apply(Rx); 39 : 40 8 : return static_cast<data_t>(0.5) * Rx.dot(temp); 41 8 : } 42 : 43 : template <typename data_t> 44 : void WeightedL2NormPow2<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx) 45 12 : { 46 12 : auto temp = _weightingOp->apply(Rx); 47 12 : Rx = temp; 48 12 : } 49 : 50 : template <typename data_t> 51 : LinearOperator<data_t> 52 : WeightedL2NormPow2<data_t>::getHessianImpl([[maybe_unused]] const DataContainer<data_t>& Rx) 53 12 : { 54 12 : return leaf(*_weightingOp); 55 12 : } 56 : 57 : template <typename data_t> 58 : WeightedL2NormPow2<data_t>* WeightedL2NormPow2<data_t>::cloneImpl() const 59 62 : { 60 : // this ugly cast has to go away at some point.. 61 : // Still not nice, but still safe as _weightingOp is allways of type Scaling 62 62 : const auto& scaling = downcast<Scaling<data_t>>(*_weightingOp); 63 62 : return new WeightedL2NormPow2(this->getResidual(), scaling); 64 62 : } 65 : 66 : template <typename data_t> 67 : bool WeightedL2NormPow2<data_t>::isEqual(const Functional<data_t>& other) const 68 18 : { 69 18 : if (!Functional<data_t>::isEqual(other)) 70 0 : return false; 71 : 72 18 : auto otherWL2 = downcast_safe<WeightedL2NormPow2>(&other); 73 18 : if (!otherWL2) 74 0 : return false; 75 : 76 18 : if (*_weightingOp != *otherWL2->_weightingOp) 77 0 : return false; 78 : 79 18 : return true; 80 18 : } 81 : 82 : // ------------------------------------------ 83 : // explicit template instantiation 84 : template class WeightedL2NormPow2<float>; 85 : template class WeightedL2NormPow2<double>; 86 : template class WeightedL2NormPow2<complex<float>>; 87 : template class WeightedL2NormPow2<complex<double>>; 88 : 89 : } // namespace elsa