LCOV - code coverage report
Current view: top level - elsa/problems - LASSOProblem.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 44 46 95.7 %
Date: 2022-08-25 03:05:39 Functions: 14 16 87.5 %

          Line data    Source code
       1             : #include "LASSOProblem.h"
       2             : #include "LinearOperator.h"
       3             : #include "DataContainer.h"
       4             : #include "Error.h"
       5             : #include "Identity.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     LASSOProblem<data_t>::LASSOProblem(const LinearOperator<data_t>& A,
      12             :                                        const DataContainer<data_t>& b, const real_t lambda)
      13             :         : LASSOProblem<data_t>(
      14             :             WLSProblem<data_t>(A, b),
      15             :             RegularizationTerm<data_t>(lambda, L1Norm<data_t>{A.getDomainDescriptor()}))
      16           0 :     {
      17           0 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     LASSOProblem<data_t>::LASSOProblem(WLSProblem<data_t> wlsProblem,
      21             :                                        const RegularizationTerm<data_t>& regTerm)
      22             :         : Problem<data_t>{wlsProblem.getDataTerm(),
      23             :                           std::vector<RegularizationTerm<data_t>>{regTerm},
      24             :                           wlsProblem.getCurrentSolution()},
      25             :           _wlsProblem{wlsProblem}
      26          22 :     {
      27          22 :         if (regTerm.getWeight() < 0) {
      28           2 :             throw InvalidArgumentError(
      29           2 :                 "LASSOProblem: regularization term must have a non-negative weight");
      30           2 :         }
      31          20 :         if (!is<L1Norm<data_t>>(regTerm.getFunctional())) {
      32           2 :             throw InvalidArgumentError("LASSOProblem: regularization term must be type L1Norm");
      33           2 :         }
      34          20 :     }
      35             : 
      36             :     template <typename data_t>
      37             :     LASSOProblem<data_t>::LASSOProblem(const Problem<data_t>& problem)
      38             :         : LASSOProblem<data_t>{wlsFromProblem(problem), regTermFromProblem(problem)}
      39           8 :     {
      40           8 :     }
      41             : 
      42             :     template <typename data_t>
      43             :     auto LASSOProblem<data_t>::cloneImpl() const -> LASSOProblem<data_t>*
      44           4 :     {
      45           4 :         return new LASSOProblem<data_t>(*this);
      46           4 :     }
      47             : 
      48             :     template <typename data_t>
      49             :     auto LASSOProblem<data_t>::getLipschitzConstantImpl(index_t nIterations) const -> data_t
      50           8 :     {
      51             :         // compute the Lipschitz Constant of the WLSProblem as the reg. term is not differentiable
      52           8 :         return _wlsProblem.getLipschitzConstant(nIterations);
      53           8 :     }
      54             : 
      55             :     template <typename data_t>
      56             :     auto LASSOProblem<data_t>::wlsFromProblem(const Problem<data_t>& problem) -> WLSProblem<data_t>
      57           8 :     {
      58             :         // All residuals are LinearResidual, so it's safe
      59           8 :         auto& linResid = downcast<LinearResidual<data_t>>(problem.getDataTerm().getResidual());
      60             : 
      61           8 :         std::unique_ptr<LinearOperator<data_t>> dataTermOp;
      62             : 
      63           8 :         if (linResid.hasOperator()) {
      64           6 :             dataTermOp = linResid.getOperator().clone();
      65           6 :         } else {
      66           2 :             dataTermOp = std::make_unique<Identity<data_t>>(linResid.getDomainDescriptor());
      67           2 :         }
      68             : 
      69           8 :         const DataContainer<data_t> dataVec = [&] {
      70           8 :             if (linResid.hasDataVector()) {
      71           6 :                 return DataContainer<data_t>(linResid.getDataVector());
      72           6 :             } else {
      73           2 :                 Eigen::Matrix<data_t, Eigen::Dynamic, 1> zeroes(
      74           2 :                     linResid.getRangeDescriptor().getNumberOfCoefficients());
      75           2 :                 zeroes.setZero();
      76             : 
      77           2 :                 return DataContainer<data_t>(linResid.getRangeDescriptor(), zeroes);
      78           2 :             }
      79           8 :         }();
      80             : 
      81           8 :         return WLSProblem<data_t>(*dataTermOp, dataVec);
      82           8 :     }
      83             : 
      84             :     template <typename data_t>
      85             :     auto LASSOProblem<data_t>::regTermFromProblem(const Problem<data_t>& problem)
      86             :         -> RegularizationTerm<data_t>
      87           8 :     {
      88           8 :         const auto& regTerms = problem.getRegularizationTerms();
      89             : 
      90           8 :         if (regTerms.size() != 1) {
      91           4 :             throw InvalidArgumentError("LASSOProblem: exactly one regularization term is required");
      92           4 :         }
      93             : 
      94           4 :         return regTerms[0];
      95           4 :     }
      96             : 
      97             :     // ------------------------------------------
      98             :     // explicit template instantiation
      99             :     template class LASSOProblem<float>;
     100             :     template class LASSOProblem<double>;
     101             : } // namespace elsa

Generated by: LCOV version 1.14