LCOV - code coverage report
Current view: top level - elsa/problems - WLSProblem.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 121 150 80.7 %
Date: 2022-08-25 03:05:39 Functions: 16 28 57.1 %

          Line data    Source code
       1             : #include "WLSProblem.h"
       2             : #include "L2NormPow2.h"
       3             : #include "WeightedL2NormPow2.h"
       4             : #include "RandomBlocksDescriptor.h"
       5             : #include "BlockLinearOperator.h"
       6             : #include "Identity.h"
       7             : #include "TypeCasts.hpp"
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     WLSProblem<data_t>::WLSProblem(const Scaling<data_t>& W, const LinearOperator<data_t>& A,
      13             :                                    const DataContainer<data_t>& b, const DataContainer<data_t>& x0,
      14             :                                    const std::optional<data_t> lipschitzConstant)
      15             :         : Problem<data_t>{WeightedL2NormPow2<data_t>{LinearResidual<data_t>{A, b}, W}, x0,
      16             :                           lipschitzConstant}
      17           4 :     {
      18             :         // sanity checks are done in the member constructors already
      19           4 :     }
      20             : 
      21             :     template <typename data_t>
      22             :     WLSProblem<data_t>::WLSProblem(const Scaling<data_t>& W, const LinearOperator<data_t>& A,
      23             :                                    const DataContainer<data_t>& b,
      24             :                                    const std::optional<data_t> lipschitzConstant)
      25             :         : Problem<data_t>{WeightedL2NormPow2<data_t>{LinearResidual<data_t>{A, b}, W},
      26             :                           lipschitzConstant}
      27           4 :     {
      28             :         // sanity checks are done in the member constructors already
      29           4 :     }
      30             : 
      31             :     template <typename data_t>
      32             :     WLSProblem<data_t>::WLSProblem(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      33             :                                    const DataContainer<data_t>& x0,
      34             :                                    const std::optional<data_t> lipschitzConstant)
      35             :         : Problem<data_t>{L2NormPow2<data_t>{LinearResidual<data_t>{A, b}}, x0, lipschitzConstant}
      36           4 :     {
      37             :         // sanity checks are done in the member constructors already
      38           4 :     }
      39             : 
      40             :     template <typename data_t>
      41             :     WLSProblem<data_t>::WLSProblem(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      42             :                                    const std::optional<data_t> lipschitzConstant)
      43             :         : Problem<data_t>{L2NormPow2<data_t>{LinearResidual<data_t>{A, b}}, lipschitzConstant}
      44         134 :     {
      45             :         // sanity checks are done in the member constructors already
      46         134 :     }
      47             : 
      48             :     template <typename data_t>
      49             :     WLSProblem<data_t>::WLSProblem(const Problem<data_t>& problem)
      50             :         : Problem<data_t>{*wlsFromProblem(problem), problem.getCurrentSolution()}
      51          26 :     {
      52          26 :     }
      53             : 
      54             :     template <typename data_t>
      55             :     WLSProblem<data_t>* WLSProblem<data_t>::cloneImpl() const
      56          94 :     {
      57          94 :         return new WLSProblem(*this);
      58          94 :     }
      59             : 
      60             :     template <typename data_t>
      61             :     std::unique_ptr<Functional<data_t>>
      62             :         WLSProblem<data_t>::wlsFromProblem(const Problem<data_t>& problem)
      63          26 :     {
      64          26 :         const auto& dataTerm = problem.getDataTerm();
      65          26 :         const auto& regTerms = problem.getRegularizationTerms();
      66             : 
      67          26 :         if (!is<WeightedL2NormPow2<data_t>>(dataTerm) && !is<L2NormPow2<data_t>>(dataTerm))
      68           2 :             throw LogicError("WLSProblem: conversion failed - data term is not "
      69           2 :                              "of type (Weighted)L2NormPow2");
      70             : 
      71          24 :         const auto dataTermResidual =
      72          24 :             downcast_safe<LinearResidual<data_t>>(&dataTerm.getResidual());
      73             : 
      74          24 :         if (!dataTermResidual)
      75           0 :             throw LogicError("WLSProblem: conversion failed - data term is non-linear");
      76             : 
      77             :         // data term is of convertible type
      78             : 
      79             :         // no conversion needed if no regTerms
      80          24 :         if (regTerms.empty())
      81           8 :             return dataTerm.clone();
      82             : 
      83             :         // else regTerms present, determine data vector descriptor
      84          16 :         std::vector<std::unique_ptr<DataDescriptor>> rangeDescList(0);
      85          16 :         rangeDescList.push_back(dataTermResidual->getRangeDescriptor().clone());
      86          16 :         for (const auto& regTerm : regTerms) {
      87          16 :             if (!is<WeightedL2NormPow2<data_t>>(regTerm.getFunctional())
      88          16 :                 && !is<L2NormPow2<data_t>>(regTerm.getFunctional()))
      89           2 :                 throw LogicError("WLSProblem: conversion failed - regularization term is not "
      90           2 :                                  "of type (Weighted)L2NormPow2");
      91             : 
      92             :             //
      93          14 :             const auto regTermResidual =
      94          14 :                 downcast_safe<LinearResidual<data_t>>(&regTerm.getFunctional().getResidual());
      95             : 
      96          14 :             if (!regTermResidual)
      97           0 :                 throw LogicError(
      98           0 :                     "WLSProblem: conversion failed - regularization term is non-linear");
      99             : 
     100          14 :             rangeDescList.push_back(regTermResidual->getRangeDescriptor().clone());
     101          14 :         }
     102             : 
     103             :         // problem is convertible, allocate memory, set to zero and start building block op
     104          16 :         RandomBlocksDescriptor dataVecDesc{rangeDescList};
     105           0 :         DataContainer<data_t> dataVec{dataVecDesc};
     106           0 :         dataVec = 0;
     107           0 :         std::vector<std::unique_ptr<LinearOperator<data_t>>> opList(0);
     108             : 
     109             :         // add block corresponding to data term
     110          14 :         if (is<WeightedL2NormPow2<data_t>>(dataTerm)) {
     111           4 :             const auto& trueFunc = downcast<WeightedL2NormPow2<data_t>>(dataTerm);
     112           4 :             const auto& scaling = trueFunc.getWeightingOperator();
     113           4 :             const auto& desc = scaling.getDomainDescriptor();
     114             : 
     115           4 :             std::unique_ptr<Scaling<data_t>> sqrtW{};
     116           4 :             if (scaling.isIsotropic()) {
     117           2 :                 auto fac = scaling.getScaleFactor();
     118             : 
     119           2 :                 if constexpr (std::is_floating_point_v<data_t>) {
     120           2 :                     if (fac < 0) {
     121           2 :                         throw LogicError("WLSProblem: conversion failed - negative weighting "
     122           2 :                                          "factor in WeightedL2NormPow2 term");
     123           2 :                     }
     124           0 :                 }
     125             : 
     126           0 :                 sqrtW = std::make_unique<Scaling<data_t>>(desc, sqrt(fac));
     127           2 :             } else {
     128           2 :                 const auto& fac = scaling.getScaleFactors();
     129             : 
     130           2 :                 if constexpr (std::is_floating_point_v<data_t>) {
     131         514 :                     for (const auto& w : fac) {
     132         514 :                         if (w < 0) {
     133           2 :                             throw LogicError("WLSProblem: conversion failed - negative weighting "
     134           2 :                                              "factor in WeightedL2NormPow2 term");
     135           2 :                         }
     136         514 :                     }
     137           2 :                 }
     138             : 
     139           2 :                 sqrtW = std::make_unique<Scaling<data_t>>(desc, sqrt(fac));
     140           0 :             }
     141             : 
     142           4 :             if (dataTermResidual->hasDataVector())
     143           0 :                 dataVec.getBlock(0) = sqrtW->apply(dataTermResidual->getDataVector());
     144             : 
     145           0 :             if (dataTermResidual->hasOperator()) {
     146           0 :                 const auto composite = *sqrtW * dataTermResidual->getOperator();
     147           0 :                 opList.emplace_back(composite.clone());
     148           0 :             } else {
     149           0 :                 opList.push_back(std::move(sqrtW));
     150           0 :             }
     151             : 
     152          10 :         } else {
     153          10 :             if (dataTermResidual->hasOperator()) {
     154           2 :                 opList.push_back(dataTermResidual->getOperator().clone());
     155           8 :             } else {
     156           8 :                 opList.push_back(
     157           8 :                     std::make_unique<Identity<data_t>>(dataTermResidual->getDomainDescriptor()));
     158           8 :             }
     159             : 
     160          10 :             if (dataTermResidual->hasDataVector())
     161           2 :                 dataVec.getBlock(0) = dataTermResidual->getDataVector();
     162          10 :         }
     163             : 
     164             :         // add blocks corresponding to reg terms
     165          10 :         index_t blockNum = 1;
     166          10 :         for (const auto& regTerm : regTerms) {
     167          10 :             const data_t lambda = regTerm.getWeight();
     168          10 :             const auto& func = regTerm.getFunctional();
     169          10 :             const auto residual = static_cast<const LinearResidual<data_t>*>(&func.getResidual());
     170             : 
     171          10 :             if (is<WeightedL2NormPow2<data_t>>(func)) {
     172           4 :                 const auto& trueFunc = downcast<WeightedL2NormPow2<data_t>>(func);
     173           4 :                 const auto& scaling = trueFunc.getWeightingOperator();
     174           4 :                 const auto& desc = scaling.getDomainDescriptor();
     175             : 
     176           4 :                 std::unique_ptr<Scaling<data_t>> sqrtLambdaW{};
     177           4 :                 if (scaling.isIsotropic()) {
     178           2 :                     auto fac = scaling.getScaleFactor();
     179             : 
     180           2 :                     if constexpr (std::is_floating_point_v<data_t>) {
     181           2 :                         if (lambda * fac < 0) {
     182           2 :                             throw LogicError("WLSProblem: conversion failed - negative weighting "
     183           2 :                                              "factor in WeightedL2NormPow2 term");
     184           2 :                         }
     185           0 :                     }
     186             : 
     187           0 :                     sqrtLambdaW = std::make_unique<Scaling<data_t>>(desc, sqrt(lambda * fac));
     188           2 :                 } else {
     189           2 :                     const auto& fac = scaling.getScaleFactors();
     190             : 
     191           2 :                     if constexpr (std::is_floating_point_v<data_t>) {
     192         514 :                         for (const auto& w : fac) {
     193         514 :                             if (lambda * w < 0) {
     194           2 :                                 throw LogicError(
     195           2 :                                     "WLSProblem: conversion failed - negative weighting "
     196           2 :                                     "factor in WeightedL2NormPow2 term");
     197           2 :                             }
     198         514 :                         }
     199           2 :                     }
     200             : 
     201           2 :                     sqrtLambdaW = std::make_unique<Scaling<data_t>>(desc, sqrt(lambda * fac));
     202           0 :                 }
     203             : 
     204           4 :                 if (residual->hasDataVector())
     205           0 :                     dataVec.getBlock(blockNum) = sqrtLambdaW->apply(residual->getDataVector());
     206             : 
     207           0 :                 if (residual->hasOperator()) {
     208           0 :                     const auto composite = *sqrtLambdaW * residual->getOperator();
     209           0 :                     opList.push_back(composite.clone());
     210           0 :                 } else {
     211           0 :                     opList.push_back(std::move(sqrtLambdaW));
     212           0 :                 }
     213             : 
     214           6 :             } else {
     215           6 :                 if constexpr (std::is_floating_point_v<data_t>) {
     216           6 :                     if (lambda < 0) {
     217           2 :                         throw LogicError(
     218           2 :                             "WLSProblem: conversion failed - negative regularization term weight");
     219           2 :                     }
     220           4 :                 }
     221             : 
     222           4 :                 auto sqrtLambdaScaling =
     223           4 :                     std::make_unique<Scaling<data_t>>(residual->getRangeDescriptor(), sqrt(lambda));
     224             : 
     225           4 :                 if (residual->hasDataVector())
     226           0 :                     dataVec.getBlock(blockNum) = sqrt(lambda) * residual->getDataVector();
     227             : 
     228           4 :                 if (residual->hasOperator()) {
     229           0 :                     const auto composite = *sqrtLambdaScaling * residual->getOperator();
     230           0 :                     opList.emplace_back(composite.clone());
     231           4 :                 } else {
     232           4 :                     opList.push_back(std::move(sqrtLambdaScaling));
     233           4 :                 }
     234           4 :             }
     235             : 
     236          10 :             blockNum++;
     237           4 :         }
     238             : 
     239          10 :         BlockLinearOperator<data_t> blockOp{opList, BlockLinearOperator<data_t>::BlockType::ROW};
     240             : 
     241           4 :         return std::make_unique<L2NormPow2<data_t>>(LinearResidual<data_t>{blockOp, dataVec});
     242          10 :     }
     243             : 
     244             :     // ------------------------------------------
     245             :     // explicit template instantiation
     246             :     template class WLSProblem<float>;
     247             :     template class WLSProblem<double>;
     248             :     template class WLSProblem<complex<float>>;
     249             :     template class WLSProblem<complex<double>>;
     250             : 
     251             : } // namespace elsa

Generated by: LCOV version 1.14