Line data Source code
1 : #include "WeightedL1Norm.h" 2 : #include "DataContainer.h" 3 : #include "Error.h" 4 : #include "LinearOperator.h" 5 : 6 : namespace elsa 7 : { 8 : template <typename data_t> 9 : WeightedL1Norm<data_t>::WeightedL1Norm(const DataContainer<data_t>& weightingOp) 10 : : Functional<data_t>(weightingOp.getDataDescriptor()), _weightingOp{weightingOp} 11 10 : { 12 : // sanity check 13 10 : if (weightingOp.minElement() < 0) { 14 0 : throw InvalidArgumentError( 15 0 : "WeightedL1Norm: all weights in the w vector should be >= 0"); 16 0 : } 17 10 : } 18 : 19 : template <typename data_t> 20 : const DataContainer<data_t>& WeightedL1Norm<data_t>::getWeightingOperator() const 21 2 : { 22 2 : return _weightingOp; 23 2 : } 24 : 25 : template <typename data_t> 26 : data_t WeightedL1Norm<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 27 2 : { 28 2 : if (x.getDataDescriptor() != _weightingOp.getDataDescriptor()) { 29 0 : throw InvalidArgumentError("WeightedL1Norm: x is not of correct size"); 30 0 : } 31 : 32 2 : return _weightingOp.dot(cwiseAbs(x)); 33 2 : } 34 : 35 : template <typename data_t> 36 : void WeightedL1Norm<data_t>::getGradientImpl(const DataContainer<data_t>&, 37 : DataContainer<data_t>&) const 38 2 : { 39 2 : throw LogicError("WeightedL1Norm: not differentiable, so no gradient! (busted!)"); 40 2 : } 41 : 42 : template <typename data_t> 43 : LinearOperator<data_t> 44 : WeightedL1Norm<data_t>::getHessianImpl(const DataContainer<data_t>&) const 45 2 : { 46 2 : throw LogicError("WeightedL1Norm: not differentiable, so no Hessian! (busted!)"); 47 2 : } 48 : 49 : template <typename data_t> 50 : WeightedL1Norm<data_t>* WeightedL1Norm<data_t>::cloneImpl() const 51 2 : { 52 2 : return new WeightedL1Norm(_weightingOp); 53 2 : } 54 : 55 : template <typename data_t> 56 : bool WeightedL1Norm<data_t>::isEqual(const Functional<data_t>& other) const 57 2 : { 58 2 : if (!Functional<data_t>::isEqual(other)) 59 0 : return false; 60 : 61 2 : auto otherWL1 = dynamic_cast<const WeightedL1Norm*>(&other); 62 2 : if (!otherWL1) 63 0 : return false; 64 : 65 2 : return _weightingOp == otherWL1->_weightingOp; 66 2 : } 67 : 68 : // ------------------------------------------ 69 : // explicit template instantiation 70 : template class WeightedL1Norm<float>; 71 : template class WeightedL1Norm<double>; 72 : } // namespace elsa