Line data Source code
1 : #include "RegularizedInversion.h" 2 : #include "BlockLinearOperator.h" 3 : #include "RandomBlocksDescriptor.h" 4 : #include "CGNE.h" 5 : #include "Solver.h" 6 : #include "Scaling.h" 7 : 8 : #include <variant> 9 : #include <vector> 10 : 11 : namespace elsa 12 : { 13 : template <class data_t> 14 : DataContainer<data_t> 15 : reguarlizedInversion(const LinearOperator<data_t>& op, const DataContainer<data_t>& b, 16 : const std::vector<std::unique_ptr<LinearOperator<data_t>>>& regOps, 17 : const std::vector<DataContainer<data_t>>& regData, 18 : std::variant<data_t, std::vector<data_t>> lambda, index_t niters, 19 : std::optional<DataContainer<data_t>> W, 20 : std::optional<DataContainer<data_t>> x0) 21 24 : { 22 24 : index_t size = 1 + asSigned(regOps.size()); 23 : 24 24 : auto x = extract_or(x0, op.getDomainDescriptor()); 25 : 26 : // Setup a block problem, where K = [Op; regOps..], and w = [b; c - Bz - u] 27 24 : std::vector<std::unique_ptr<DataDescriptor>> descs; 28 24 : descs.emplace_back(b.getDataDescriptor().clone()); 29 48 : for (size_t i = 0; i < regData.size(); ++i) { 30 24 : descs.emplace_back(regData[i].getDataDescriptor().clone()); 31 24 : } 32 24 : RandomBlocksDescriptor blockDesc(descs); 33 : 34 24 : std::vector<std::unique_ptr<LinearOperator<data_t>>> opList; 35 24 : opList.reserve(size); 36 : 37 24 : if (W.has_value()) { 38 2 : opList.emplace_back((Scaling<data_t>(*W) * op).clone()); 39 22 : } else { 40 22 : opList.emplace_back(op.clone()); 41 22 : } 42 : 43 48 : for (size_t i = 0; i < regOps.size(); ++i) { 44 24 : auto& regOp = *regOps[i]; 45 : 46 24 : auto regParam = [&]() { 47 24 : if (std::holds_alternative<data_t>(lambda)) { 48 24 : return std::get<data_t>(lambda); 49 24 : } else { 50 0 : return std::get<std::vector<data_t>>(lambda)[i]; 51 0 : } 52 24 : }(); 53 24 : opList.emplace_back((regParam * regOp).clone()); 54 24 : } 55 : 56 24 : BlockLinearOperator K(op.getDomainDescriptor(), blockDesc, opList, 57 24 : BlockLinearOperator<data_t>::BlockType::ROW); 58 : 59 24 : DataContainer<data_t> w(blockDesc); 60 24 : if (W.has_value()) { 61 2 : w.getBlock(0) = *W * b; 62 22 : } else { 63 22 : w.getBlock(0) = b; 64 22 : } 65 : 66 48 : for (index_t i = 1; i < size; ++i) { 67 24 : auto regParam = [&]() { 68 24 : if (std::holds_alternative<data_t>(lambda)) { 69 24 : return std::get<data_t>(lambda); 70 24 : } else { 71 0 : return std::get<std::vector<data_t>>(lambda)[i]; 72 0 : } 73 24 : }(); 74 24 : w.getBlock(i) = regParam * regData[i - 1]; 75 24 : } 76 : 77 24 : CGNE<data_t> cg(K, w); 78 24 : return cg.solve(niters, x); 79 24 : } 80 : 81 : template DataContainer<float> reguarlizedInversion<float>( 82 : const LinearOperator<float>& op, const DataContainer<float>& b, 83 : const std::vector<std::unique_ptr<LinearOperator<float>>>& regOps, 84 : const std::vector<DataContainer<float>>& regData, 85 : std::variant<float, std::vector<float>> lambda, index_t niters, 86 : std::optional<DataContainer<float>> W, std::optional<DataContainer<float>> x0); 87 : 88 : template DataContainer<double> 89 : reguarlizedInversion(const LinearOperator<double>& op, const DataContainer<double>& b, 90 : const std::vector<std::unique_ptr<LinearOperator<double>>>& regOps, 91 : const std::vector<DataContainer<double>>& regData, 92 : std::variant<double, std::vector<double>> lambda, index_t niters, 93 : std::optional<DataContainer<double>> W, 94 : std::optional<DataContainer<double>> x0); 95 : } // namespace elsa