Line data Source code
1 : #include "WeightedL2NormPow2.h" 2 : #include "LinearOperator.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <stdexcept> 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 0 : WeightedL2NormPow2<data_t>::WeightedL2NormPow2(const Scaling<data_t>& weightingOp) 11 : : Functional<data_t>(weightingOp.getDomainDescriptor()), 12 0 : _weightingOp{static_cast<Scaling<data_t>*>(weightingOp.clone().release())} 13 : { 14 0 : } 15 : 16 : template <typename data_t> 17 0 : WeightedL2NormPow2<data_t>::WeightedL2NormPow2(const Residual<data_t>& residual, 18 : const Scaling<data_t>& weightingOp) 19 : : Functional<data_t>(residual), 20 0 : _weightingOp{static_cast<Scaling<data_t>*>(weightingOp.clone().release())} 21 : { 22 : // sanity check 23 0 : if (residual.getRangeDescriptor().getNumberOfCoefficients() 24 0 : != weightingOp.getDomainDescriptor().getNumberOfCoefficients()) 25 0 : throw InvalidArgumentError( 26 : "WeightedL2NormPow2: sizes of residual and weighting operator do not match"); 27 0 : } 28 : 29 : template <typename data_t> 30 0 : const Scaling<data_t>& WeightedL2NormPow2<data_t>::getWeightingOperator() const 31 : { 32 0 : return *_weightingOp; 33 : } 34 : 35 : template <typename data_t> 36 0 : data_t WeightedL2NormPow2<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) 37 : { 38 0 : auto temp = _weightingOp->apply(Rx); 39 : 40 0 : return static_cast<data_t>(0.5) * Rx.dot(temp); 41 0 : } 42 : 43 : template <typename data_t> 44 0 : void WeightedL2NormPow2<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx) 45 : { 46 0 : auto temp = _weightingOp->apply(Rx); 47 0 : Rx = temp; 48 0 : } 49 : 50 : template <typename data_t> 51 : LinearOperator<data_t> 52 0 : WeightedL2NormPow2<data_t>::getHessianImpl([[maybe_unused]] const DataContainer<data_t>& Rx) 53 : { 54 0 : return leaf(*_weightingOp); 55 : } 56 : 57 : template <typename data_t> 58 0 : WeightedL2NormPow2<data_t>* WeightedL2NormPow2<data_t>::cloneImpl() const 59 : { 60 : // this ugly cast has to go away at some point.. 61 : // Still not nice, but still safe as _weightingOp is allways of type Scaling 62 0 : const auto& scaling = downcast<Scaling<data_t>>(*_weightingOp); 63 0 : return new WeightedL2NormPow2(this->getResidual(), scaling); 64 : } 65 : 66 : template <typename data_t> 67 0 : bool WeightedL2NormPow2<data_t>::isEqual(const Functional<data_t>& other) const 68 : { 69 0 : if (!Functional<data_t>::isEqual(other)) 70 0 : return false; 71 : 72 0 : auto otherWL2 = downcast_safe<WeightedL2NormPow2>(&other); 73 0 : if (!otherWL2) 74 0 : return false; 75 : 76 0 : if (*_weightingOp != *otherWL2->_weightingOp) 77 0 : return false; 78 : 79 0 : return true; 80 : } 81 : 82 : // ------------------------------------------ 83 : // explicit template instantiation 84 : template class WeightedL2NormPow2<float>; 85 : template class WeightedL2NormPow2<double>; 86 : template class WeightedL2NormPow2<complex<float>>; 87 : template class WeightedL2NormPow2<complex<double>>; 88 : 89 : } // namespace elsa