Line data Source code
1 : #include "LASSOProblem.h" 2 : #include "LinearOperator.h" 3 : #include "DataContainer.h" 4 : #include "Error.h" 5 : #include "Identity.h" 6 : #include "TypeCasts.hpp" 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 0 : LASSOProblem<data_t>::LASSOProblem(const LinearOperator<data_t>& A, 12 : const DataContainer<data_t>& b, const real_t lambda) 13 : : LASSOProblem<data_t>( 14 0 : WLSProblem<data_t>(A, b), 15 0 : RegularizationTerm<data_t>(lambda, L1Norm<data_t>{A.getDomainDescriptor()})) 16 : { 17 0 : } 18 : 19 : template <typename data_t> 20 0 : LASSOProblem<data_t>::LASSOProblem(WLSProblem<data_t> wlsProblem, 21 : const RegularizationTerm<data_t>& regTerm) 22 : : Problem<data_t>{wlsProblem.getDataTerm(), 23 : std::vector<RegularizationTerm<data_t>>{regTerm}, 24 0 : wlsProblem.getCurrentSolution()}, 25 0 : _wlsProblem{wlsProblem} 26 : { 27 0 : if (regTerm.getWeight() < 0) { 28 0 : throw InvalidArgumentError( 29 : "LASSOProblem: regularization term must have a non-negative weight"); 30 : } 31 0 : if (!is<L1Norm<data_t>>(regTerm.getFunctional())) { 32 0 : throw InvalidArgumentError("LASSOProblem: regularization term must be type L1Norm"); 33 : } 34 0 : } 35 : 36 : template <typename data_t> 37 0 : LASSOProblem<data_t>::LASSOProblem(const Problem<data_t>& problem) 38 0 : : LASSOProblem<data_t>{wlsFromProblem(problem), regTermFromProblem(problem)} 39 : { 40 0 : } 41 : 42 : template <typename data_t> 43 0 : auto LASSOProblem<data_t>::cloneImpl() const -> LASSOProblem<data_t>* 44 : { 45 0 : return new LASSOProblem<data_t>(*this); 46 : } 47 : 48 : template <typename data_t> 49 0 : auto LASSOProblem<data_t>::getLipschitzConstantImpl(index_t nIterations) const -> data_t 50 : { 51 : // compute the Lipschitz Constant of the WLSProblem as the reg. term is not differentiable 52 0 : return _wlsProblem.getLipschitzConstant(nIterations); 53 : } 54 : 55 : template <typename data_t> 56 0 : auto LASSOProblem<data_t>::wlsFromProblem(const Problem<data_t>& problem) -> WLSProblem<data_t> 57 : { 58 : // All residuals are LinearResidual, so it's safe 59 0 : auto& linResid = downcast<LinearResidual<data_t>>(problem.getDataTerm().getResidual()); 60 : 61 0 : std::unique_ptr<LinearOperator<data_t>> dataTermOp; 62 : 63 0 : if (linResid.hasOperator()) { 64 0 : dataTermOp = linResid.getOperator().clone(); 65 : } else { 66 0 : dataTermOp = std::make_unique<Identity<data_t>>(linResid.getDomainDescriptor()); 67 : } 68 : 69 0 : const DataContainer<data_t> dataVec = [&] { 70 0 : if (linResid.hasDataVector()) { 71 0 : return DataContainer<data_t>(linResid.getDataVector()); 72 : } else { 73 0 : Eigen::Matrix<data_t, Eigen::Dynamic, 1> zeroes( 74 0 : linResid.getRangeDescriptor().getNumberOfCoefficients()); 75 0 : zeroes.setZero(); 76 : 77 0 : return DataContainer<data_t>(linResid.getRangeDescriptor(), zeroes); 78 0 : } 79 : }(); 80 : 81 0 : return WLSProblem<data_t>(*dataTermOp, dataVec); 82 0 : } 83 : 84 : template <typename data_t> 85 0 : auto LASSOProblem<data_t>::regTermFromProblem(const Problem<data_t>& problem) 86 : -> RegularizationTerm<data_t> 87 : { 88 0 : const auto& regTerms = problem.getRegularizationTerms(); 89 : 90 0 : if (regTerms.size() != 1) { 91 0 : throw InvalidArgumentError("LASSOProblem: exactly one regularization term is required"); 92 : } 93 : 94 0 : return regTerms[0]; 95 : } 96 : 97 : // ------------------------------------------ 98 : // explicit template instantiation 99 : template class LASSOProblem<float>; 100 : template class LASSOProblem<double>; 101 : } // namespace elsa