Line data Source code
1 : #include "L1Loss.h" 2 : 3 : #include "DataContainer.h" 4 : #include "DataDescriptor.h" 5 : #include "LinearOperator.h" 6 : #include "TypeCasts.hpp" 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 : L1Loss<data_t>::L1Loss(const LinearOperator<data_t>& A, const DataContainer<data_t>& b) 12 : : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), b_(b) 13 8 : { 14 8 : } 15 : 16 : template <typename data_t> 17 : bool L1Loss<data_t>::isDifferentiable() const 18 0 : { 19 0 : return true; 20 0 : } 21 : 22 : template <typename data_t> 23 : const LinearOperator<data_t>& L1Loss<data_t>::getOperator() const 24 0 : { 25 0 : return *A_; 26 0 : } 27 : 28 : template <typename data_t> 29 : const DataContainer<data_t>& L1Loss<data_t>::getDataVector() const 30 0 : { 31 0 : return b_; 32 0 : } 33 : 34 : template <typename data_t> 35 : data_t L1Loss<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 36 2 : { 37 2 : auto Ax = A_->apply(x); 38 2 : Ax -= b_; 39 : 40 2 : return Ax.l1Norm(); 41 2 : } 42 : 43 : template <typename data_t> 44 : void L1Loss<data_t>::getGradientImpl(const DataContainer<data_t>&, DataContainer<data_t>&) const 45 2 : { 46 2 : throw LogicError("L1Loss: not differentiable, so no gradient! (busted!)"); 47 2 : } 48 : 49 : template <typename data_t> 50 : LinearOperator<data_t> L1Loss<data_t>::getHessianImpl(const DataContainer<data_t>&) const 51 2 : { 52 2 : throw LogicError("L1Loss: not differentiable, so no Hessian! (busted!)"); 53 2 : } 54 : 55 : template <typename data_t> 56 : L1Loss<data_t>* L1Loss<data_t>::cloneImpl() const 57 2 : { 58 2 : return new L1Loss(*A_, b_); 59 2 : } 60 : 61 : template <typename data_t> 62 : bool L1Loss<data_t>::isEqual(const Functional<data_t>& other) const 63 2 : { 64 2 : if (!Functional<data_t>::isEqual(other)) 65 0 : return false; 66 : 67 2 : auto fn = downcast_safe<L1Loss<data_t>>(&other); 68 2 : return fn && *A_ == *fn->A_ && b_ == fn->b_; 69 2 : } 70 : 71 : // ------------------------------------------ 72 : // explicit template instantiation 73 : template class L1Loss<float>; 74 : template class L1Loss<double>; 75 : } // namespace elsa