Line data Source code
1 : #include "L2Reg.h" 2 : 3 : #include "DataContainer.h" 4 : #include "DataDescriptor.h" 5 : #include "Identity.h" 6 : #include "LinearOperator.h" 7 : #include "TypeCasts.hpp" 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : L2Reg<data_t>::L2Reg(const DataDescriptor& domainDescriptor) 13 : : Functional<data_t>(domainDescriptor) 14 24 : { 15 24 : } 16 : 17 : template <typename data_t> 18 : L2Reg<data_t>::L2Reg(const LinearOperator<data_t>& A) 19 : : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()) 20 24 : { 21 24 : } 22 : 23 : template <typename data_t> 24 : bool L2Reg<data_t>::isDifferentiable() const 25 0 : { 26 0 : return true; 27 0 : } 28 : 29 : template <typename data_t> 30 : bool L2Reg<data_t>::hasOperator() const 31 32 : { 32 32 : return static_cast<bool>(A_); 33 32 : } 34 : 35 : template <typename data_t> 36 : const LinearOperator<data_t>& L2Reg<data_t>::getOperator() const 37 0 : { 38 0 : if (!hasOperator()) { 39 0 : throw Error("L2Reg: No operator present"); 40 0 : } 41 : 42 0 : return *A_; 43 0 : } 44 : 45 : template <typename data_t> 46 : data_t L2Reg<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 47 8 : { 48 : // If we have an operator, apply it 49 8 : if (hasOperator()) { 50 4 : auto tmp = DataContainer<data_t>(A_->getRangeDescriptor()); 51 4 : A_->apply(x, tmp); 52 : 53 4 : return data_t{0.5} * tmp.squaredL2Norm(); 54 4 : } 55 : 56 4 : return data_t{0.5} * x.squaredL2Norm(); 57 4 : } 58 : 59 : template <typename data_t> 60 : void L2Reg<data_t>::getGradientImpl(const DataContainer<data_t>& x, 61 : DataContainer<data_t>& out) const 62 8 : { 63 8 : if (hasOperator()) { 64 4 : auto temp = A_->apply(x); 65 : 66 : // Apply chain rule 67 4 : A_->applyAdjoint(temp, out); 68 4 : } else { 69 4 : out = x; 70 4 : } 71 8 : } 72 : 73 : template <typename data_t> 74 : LinearOperator<data_t> L2Reg<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const 75 8 : { 76 8 : if (hasOperator()) { 77 4 : return leaf(adjoint(*A_) * (*A_)); 78 4 : } 79 4 : return leaf(Identity<data_t>(Rx.getDataDescriptor())); 80 4 : } 81 : 82 : template <typename data_t> 83 : L2Reg<data_t>* L2Reg<data_t>::cloneImpl() const 84 8 : { 85 8 : if (hasOperator()) { 86 4 : return new L2Reg(*A_); 87 4 : } 88 4 : return new L2Reg(this->getDomainDescriptor()); 89 4 : } 90 : 91 : template <typename data_t> 92 : bool L2Reg<data_t>::isEqual(const Functional<data_t>& other) const 93 8 : { 94 8 : if (!Functional<data_t>::isEqual(other)) 95 0 : return false; 96 : 97 8 : auto fn = downcast_safe<L2Reg<data_t>>(&other); 98 8 : if (!fn) { 99 0 : return false; 100 0 : } 101 : 102 8 : if (A_ && fn->A_) { 103 4 : return *A_ == *fn->A_; 104 4 : } 105 : 106 4 : return A_ == fn->A_; 107 4 : } 108 : 109 : // ------------------------------------------ 110 : // explicit template instantiation 111 : template class L2Reg<float>; 112 : template class L2Reg<double>; 113 : template class L2Reg<complex<double>>; 114 : template class L2Reg<complex<float>>; 115 : } // namespace elsa