Line data Source code
1 : #include <cmath> 2 : #include <optional> 3 : 4 : #include "ProximalL1.h" 5 : #include "DataContainer.h" 6 : #include "Error.h" 7 : #include "TypeCasts.hpp" 8 : #include "Math.hpp" 9 : #include "elsaDefines.h" 10 : 11 : namespace elsa 12 : { 13 : template <class data_t> 14 : DataContainer<data_t> softthreshold(const DataContainer<data_t>& x, SelfType_t<data_t> thresh) 15 104 : { 16 104 : return elsa::maximum(elsa::cwiseAbs(x) - thresh, 0) * sign(x); 17 104 : } 18 : 19 : template <typename data_t> 20 : ProximalL1<data_t>::ProximalL1(data_t sigma) : sigma_(sigma), b_(std::nullopt) 21 0 : { 22 0 : } 23 : 24 : template <class data_t> 25 : ProximalL1<data_t>::ProximalL1(const DataContainer<data_t>& b) : sigma_(1.0), b_(b) 26 8 : { 27 8 : } 28 : 29 : template <class data_t> 30 : ProximalL1<data_t>::ProximalL1(const DataContainer<data_t>& b, SelfType_t<data_t> sigma) 31 : : sigma_(sigma), b_(b) 32 0 : { 33 0 : } 34 : 35 : template <class data_t> 36 : ProximalL1<data_t>::ProximalL1(ProximalL1<data_t>&& other) noexcept 37 : : sigma_(other.sigma_), b_(std::move(other.b_)) 38 46 : { 39 46 : } 40 : 41 : template <class data_t> 42 : ProximalL1<data_t>& ProximalL1<data_t>::operator=(ProximalL1<data_t>&& other) noexcept 43 0 : { 44 0 : sigma_ = other.sigma_; 45 0 : b_ = std::move(other.b_); 46 : 47 0 : return *this; 48 0 : } 49 : 50 : template <typename data_t> 51 : DataContainer<data_t> ProximalL1<data_t>::apply(const DataContainer<data_t>& v, 52 : SelfType_t<data_t> t) const 53 90 : { 54 90 : DataContainer<data_t> out{v.getDataDescriptor()}; 55 90 : apply(v, t, out); 56 90 : return out; 57 90 : } 58 : 59 : template <typename data_t> 60 : void ProximalL1<data_t>::apply(const DataContainer<data_t>& v, SelfType_t<data_t> t, 61 : DataContainer<data_t>& prox) const 62 106 : { 63 106 : if (v.getSize() != prox.getSize()) { 64 2 : throw LogicError("ProximalL1: sizes of v and prox must match"); 65 2 : } 66 : 67 104 : if (b_.has_value()) { 68 8 : prox.assign(softthreshold(v - *b_, t * sigma_) + *b_); 69 96 : } else { 70 96 : prox.assign(softthreshold(v, t * sigma_)); 71 96 : } 72 104 : } 73 : 74 : // ------------------------------------------ 75 : // explicit template instantiation 76 : template class ProximalL1<float>; 77 : template class ProximalL1<double>; 78 : } // namespace elsa