LCOV - code coverage report
Current view: top level - elsa/solvers - RegularizedInversion.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 44 48 91.7 %
Date: 2024-05-16 04:22:26 Functions: 6 6 100.0 %

          Line data    Source code
       1             : #include "RegularizedInversion.h"
       2             : #include "BlockLinearOperator.h"
       3             : #include "RandomBlocksDescriptor.h"
       4             : #include "CGNE.h"
       5             : #include "Solver.h"
       6             : #include "Scaling.h"
       7             : 
       8             : #include <variant>
       9             : #include <vector>
      10             : 
      11             : namespace elsa
      12             : {
      13             :     template <class data_t>
      14             :     DataContainer<data_t>
      15             :         reguarlizedInversion(const LinearOperator<data_t>& op, const DataContainer<data_t>& b,
      16             :                              const std::vector<std::unique_ptr<LinearOperator<data_t>>>& regOps,
      17             :                              const std::vector<DataContainer<data_t>>& regData,
      18             :                              std::variant<data_t, std::vector<data_t>> lambda, index_t niters,
      19             :                              std::optional<DataContainer<data_t>> W,
      20             :                              std::optional<DataContainer<data_t>> x0)
      21          24 :     {
      22          24 :         index_t size = 1 + asSigned(regOps.size());
      23             : 
      24          24 :         auto x = extract_or(x0, op.getDomainDescriptor());
      25             : 
      26             :         // Setup a block problem, where K = [Op; regOps..], and w = [b; c - Bz - u]
      27          24 :         std::vector<std::unique_ptr<DataDescriptor>> descs;
      28          24 :         descs.emplace_back(b.getDataDescriptor().clone());
      29          48 :         for (size_t i = 0; i < regData.size(); ++i) {
      30          24 :             descs.emplace_back(regData[i].getDataDescriptor().clone());
      31          24 :         }
      32          24 :         RandomBlocksDescriptor blockDesc(descs);
      33             : 
      34          24 :         std::vector<std::unique_ptr<LinearOperator<data_t>>> opList;
      35          24 :         opList.reserve(size);
      36             : 
      37          24 :         if (W.has_value()) {
      38           2 :             opList.emplace_back((Scaling<data_t>(*W) * op).clone());
      39          22 :         } else {
      40          22 :             opList.emplace_back(op.clone());
      41          22 :         }
      42             : 
      43          48 :         for (size_t i = 0; i < regOps.size(); ++i) {
      44          24 :             auto& regOp = *regOps[i];
      45             : 
      46          24 :             auto regParam = [&]() {
      47          24 :                 if (std::holds_alternative<data_t>(lambda)) {
      48          24 :                     return std::get<data_t>(lambda);
      49          24 :                 } else {
      50           0 :                     return std::get<std::vector<data_t>>(lambda)[i];
      51           0 :                 }
      52          24 :             }();
      53          24 :             opList.emplace_back((regParam * regOp).clone());
      54          24 :         }
      55             : 
      56          24 :         BlockLinearOperator K(op.getDomainDescriptor(), blockDesc, opList,
      57          24 :                               BlockLinearOperator<data_t>::BlockType::ROW);
      58             : 
      59          24 :         DataContainer<data_t> w(blockDesc);
      60          24 :         if (W.has_value()) {
      61           2 :             w.getBlock(0) = *W * b;
      62          22 :         } else {
      63          22 :             w.getBlock(0) = b;
      64          22 :         }
      65             : 
      66          48 :         for (index_t i = 1; i < size; ++i) {
      67          24 :             auto regParam = [&]() {
      68          24 :                 if (std::holds_alternative<data_t>(lambda)) {
      69          24 :                     return std::get<data_t>(lambda);
      70          24 :                 } else {
      71           0 :                     return std::get<std::vector<data_t>>(lambda)[i];
      72           0 :                 }
      73          24 :             }();
      74          24 :             w.getBlock(i) = regParam * regData[i - 1];
      75          24 :         }
      76             : 
      77          24 :         CGNE<data_t> cg(K, w);
      78          24 :         return cg.solve(niters, x);
      79          24 :     }
      80             : 
      81             :     template DataContainer<float> reguarlizedInversion<float>(
      82             :         const LinearOperator<float>& op, const DataContainer<float>& b,
      83             :         const std::vector<std::unique_ptr<LinearOperator<float>>>& regOps,
      84             :         const std::vector<DataContainer<float>>& regData,
      85             :         std::variant<float, std::vector<float>> lambda, index_t niters,
      86             :         std::optional<DataContainer<float>> W, std::optional<DataContainer<float>> x0);
      87             : 
      88             :     template DataContainer<double>
      89             :         reguarlizedInversion(const LinearOperator<double>& op, const DataContainer<double>& b,
      90             :                              const std::vector<std::unique_ptr<LinearOperator<double>>>& regOps,
      91             :                              const std::vector<DataContainer<double>>& regData,
      92             :                              std::variant<double, std::vector<double>> lambda, index_t niters,
      93             :                              std::optional<DataContainer<double>> W,
      94             :                              std::optional<DataContainer<double>> x0);
      95             : } // namespace elsa

Generated by: LCOV version 1.14