Line data Source code
1 : #include "ProximalL2Squared.h" 2 : #include "DataContainer.h" 3 : 4 : namespace elsa 5 : { 6 : template <class data_t> 7 : ProximalL2Squared<data_t>::ProximalL2Squared(data_t sigma) : sigma_(sigma) 8 0 : { 9 0 : } 10 : 11 : template <class data_t> 12 : ProximalL2Squared<data_t>::ProximalL2Squared(const DataContainer<data_t>& b) : b_(b) 13 10 : { 14 10 : } 15 : 16 : template <class data_t> 17 : ProximalL2Squared<data_t>::ProximalL2Squared(const DataContainer<data_t>& b, 18 : SelfType_t<data_t> sigma) 19 : : sigma_(sigma), b_(b) 20 0 : { 21 0 : } 22 : 23 : template <class data_t> 24 : DataContainer<data_t> ProximalL2Squared<data_t>::apply(const DataContainer<data_t>& v, 25 : SelfType_t<data_t> t) const 26 14 : { 27 14 : auto out = DataContainer<data_t>(v.getDataDescriptor()); 28 14 : apply(v, t, out); 29 14 : return out; 30 14 : } 31 : 32 : template <class data_t> 33 : void ProximalL2Squared<data_t>::apply(const DataContainer<data_t>& v, SelfType_t<data_t> lambda, 34 : DataContainer<data_t>& prox) const 35 28 : { 36 28 : const auto mult = 1 / (data_t{1} + 2 * (lambda * sigma_)); 37 : 38 28 : if (b_.has_value()) { 39 12 : const auto factor = 2 * (sigma_ * lambda) * mult; 40 12 : lincomb(mult, v, factor, *b_, prox); 41 16 : } else { 42 16 : prox = v * mult; 43 16 : } 44 28 : } 45 : 46 : // ------------------------------------------ 47 : // explicit template instantiation 48 : template class ProximalL2Squared<float>; 49 : template class ProximalL2Squared<double>; 50 : } // namespace elsa