LCOV - code coverage report
Current view: top level - problems - LASSOProblem.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 40 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 16 0.0 %

          Line data    Source code
#include "LASSOProblem.h"

#include <utility>

#include "LinearOperator.h"
#include "DataContainer.h"
#include "Error.h"
#include "Identity.h"
#include "TypeCasts.hpp"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11           0 :     LASSOProblem<data_t>::LASSOProblem(const LinearOperator<data_t>& A,
      12             :                                        const DataContainer<data_t>& b, const real_t lambda)
      13             :         : LASSOProblem<data_t>(
      14           0 :             WLSProblem<data_t>(A, b),
      15           0 :             RegularizationTerm<data_t>(lambda, L1Norm<data_t>{A.getDomainDescriptor()}))
      16             :     {
      17           0 :     }
      18             : 
      19             :     template <typename data_t>
      20           0 :     LASSOProblem<data_t>::LASSOProblem(WLSProblem<data_t> wlsProblem,
      21             :                                        const RegularizationTerm<data_t>& regTerm)
      22             :         : Problem<data_t>{wlsProblem.getDataTerm(),
      23             :                           std::vector<RegularizationTerm<data_t>>{regTerm},
      24           0 :                           wlsProblem.getCurrentSolution()},
      25           0 :           _wlsProblem{wlsProblem}
      26             :     {
      27           0 :         if (regTerm.getWeight() < 0) {
      28           0 :             throw InvalidArgumentError(
      29             :                 "LASSOProblem: regularization term must have a non-negative weight");
      30             :         }
      31           0 :         if (!is<L1Norm<data_t>>(regTerm.getFunctional())) {
      32           0 :             throw InvalidArgumentError("LASSOProblem: regularization term must be type L1Norm");
      33             :         }
      34           0 :     }
      35             : 
      36             :     template <typename data_t>
      37           0 :     LASSOProblem<data_t>::LASSOProblem(const Problem<data_t>& problem)
      38           0 :         : LASSOProblem<data_t>{wlsFromProblem(problem), regTermFromProblem(problem)}
      39             :     {
      40           0 :     }
      41             : 
      42             :     template <typename data_t>
      43           0 :     auto LASSOProblem<data_t>::cloneImpl() const -> LASSOProblem<data_t>*
      44             :     {
      45           0 :         return new LASSOProblem<data_t>(*this);
      46             :     }
      47             : 
      48             :     template <typename data_t>
      49           0 :     auto LASSOProblem<data_t>::getLipschitzConstantImpl(index_t nIterations) const -> data_t
      50             :     {
      51             :         // compute the Lipschitz Constant of the WLSProblem as the reg. term is not differentiable
      52           0 :         return _wlsProblem.getLipschitzConstant(nIterations);
      53             :     }
      54             : 
      55             :     template <typename data_t>
      56           0 :     auto LASSOProblem<data_t>::wlsFromProblem(const Problem<data_t>& problem) -> WLSProblem<data_t>
      57             :     {
      58             :         // All residuals are LinearResidual, so it's safe
      59           0 :         auto& linResid = downcast<LinearResidual<data_t>>(problem.getDataTerm().getResidual());
      60             : 
      61           0 :         std::unique_ptr<LinearOperator<data_t>> dataTermOp;
      62             : 
      63           0 :         if (linResid.hasOperator()) {
      64           0 :             dataTermOp = linResid.getOperator().clone();
      65             :         } else {
      66           0 :             dataTermOp = std::make_unique<Identity<data_t>>(linResid.getDomainDescriptor());
      67             :         }
      68             : 
      69           0 :         const DataContainer<data_t> dataVec = [&] {
      70           0 :             if (linResid.hasDataVector()) {
      71           0 :                 return DataContainer<data_t>(linResid.getDataVector());
      72             :             } else {
      73           0 :                 Eigen::Matrix<data_t, Eigen::Dynamic, 1> zeroes(
      74           0 :                     linResid.getRangeDescriptor().getNumberOfCoefficients());
      75           0 :                 zeroes.setZero();
      76             : 
      77           0 :                 return DataContainer<data_t>(linResid.getRangeDescriptor(), zeroes);
      78           0 :             }
      79             :         }();
      80             : 
      81           0 :         return WLSProblem<data_t>(*dataTermOp, dataVec);
      82           0 :     }
      83             : 
      84             :     template <typename data_t>
      85           0 :     auto LASSOProblem<data_t>::regTermFromProblem(const Problem<data_t>& problem)
      86             :         -> RegularizationTerm<data_t>
      87             :     {
      88           0 :         const auto& regTerms = problem.getRegularizationTerms();
      89             : 
      90           0 :         if (regTerms.size() != 1) {
      91           0 :             throw InvalidArgumentError("LASSOProblem: exactly one regularization term is required");
      92             :         }
      93             : 
      94           0 :         return regTerms[0];
      95             :     }
      96             : 
    // ------------------------------------------
    // explicit template instantiation
    // The member definitions live in this .cpp file, so the float and double variants of
    // LASSOProblem are instantiated here for use by other translation units.
    template class LASSOProblem<float>;
    template class LASSOProblem<double>;
     101             : } // namespace elsa

Generated by: LCOV version 1.14