LCOV - code coverage report
Current view: top level - problems - WLSProblem.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 113 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 28 0.0 %

          Line data    Source code
       1             : #include "WLSProblem.h"
       2             : #include "L2NormPow2.h"
       3             : #include "WeightedL2NormPow2.h"
       4             : #include "RandomBlocksDescriptor.h"
       5             : #include "BlockLinearOperator.h"
       6             : #include "Identity.h"
       7             : #include "TypeCasts.hpp"
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
                           // Constructs a weighted least-squares problem with an explicit start value.
                           // The data term handed to the Problem base class is a WeightedL2NormPow2
                           // functional over LinearResidual{A, b} with weighting operator W; x0 and the
                           // optional lipschitzConstant are forwarded to the base class unchanged.
      12           0 :     WLSProblem<data_t>::WLSProblem(const Scaling<data_t>& W, const LinearOperator<data_t>& A,
      13             :                                    const DataContainer<data_t>& b, const DataContainer<data_t>& x0,
      14             :                                    const std::optional<data_t> lipschitzConstant)
      15             :         : Problem<data_t>{WeightedL2NormPow2<data_t>{LinearResidual<data_t>{A, b}, W}, x0,
      16           0 :                           lipschitzConstant}
      17             :     {
      18             :         // sanity checks are done in the member constructors already, so the body is empty
      19           0 :     }
      20             : 
      21             :     template <typename data_t>
                           // Constructs a weighted least-squares problem without a start value.
                           // Same data term as the overload taking x0 (WeightedL2NormPow2 over
                           // LinearResidual{A, b}, weighted by W), but only the functional and the optional
                           // lipschitzConstant are passed to the Problem base class.
      22           0 :     WLSProblem<data_t>::WLSProblem(const Scaling<data_t>& W, const LinearOperator<data_t>& A,
      23             :                                    const DataContainer<data_t>& b,
      24             :                                    const std::optional<data_t> lipschitzConstant)
      25             :         : Problem<data_t>{WeightedL2NormPow2<data_t>{LinearResidual<data_t>{A, b}, W},
      26           0 :                           lipschitzConstant}
      27             :     {
      28             :         // sanity checks are done in the member constructors already, so the body is empty
      29           0 :     }
      30             : 
      31             :     template <typename data_t>
                           // Constructs an unweighted least-squares problem with an explicit start value.
                           // The data term is an L2NormPow2 functional over LinearResidual{A, b}; x0 and the
                           // optional lipschitzConstant are forwarded to the Problem base class unchanged.
      32           0 :     WLSProblem<data_t>::WLSProblem(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      33             :                                    const DataContainer<data_t>& x0,
      34             :                                    const std::optional<data_t> lipschitzConstant)
      35           0 :         : Problem<data_t>{L2NormPow2<data_t>{LinearResidual<data_t>{A, b}}, x0, lipschitzConstant}
      36             :     {
      37             :         // sanity checks are done in the member constructors already, so the body is empty
      38           0 :     }
      39             : 
      40             :     template <typename data_t>
                           // Constructs an unweighted least-squares problem without a start value.
                           // The data term is an L2NormPow2 functional over LinearResidual{A, b}; only the
                           // functional and the optional lipschitzConstant are passed to the Problem base class.
      41           0 :     WLSProblem<data_t>::WLSProblem(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
      42             :                                    const std::optional<data_t> lipschitzConstant)
      43           0 :         : Problem<data_t>{L2NormPow2<data_t>{LinearResidual<data_t>{A, b}}, lipschitzConstant}
      44             :     {
      45             :         // sanity checks are done in the member constructors already, so the body is empty
      46           0 :     }
      47             : 
      48             :     template <typename data_t>
                           // Converting constructor: rewrites a general Problem as a WLSProblem.
                           // wlsFromProblem(problem) builds the single least-squares data term; it throws
                           // LogicError when a term of `problem` cannot be converted, and that exception is
                           // not caught here. The current solution of `problem` is passed on as the second
                           // base-class argument. No Lipschitz constant is handed to the base class here.
      49           0 :     WLSProblem<data_t>::WLSProblem(const Problem<data_t>& problem)
      50           0 :         : Problem<data_t>{*wlsFromProblem(problem), problem.getCurrentSolution()}
      51             :     {
      52           0 :     }
      53             : 
      54             :     template <typename data_t>
                           // Returns a new heap-allocated WLSProblem constructed from *this.
                           // The result is a raw pointer allocated with `new`, so the caller is responsible
                           // for it; presumably the base-class clone() wraps it in a smart pointer — TODO
                           // confirm against the Problem / Cloneable interface.
      55           0 :     WLSProblem<data_t>* WLSProblem<data_t>::cloneImpl() const
      56             :     {
      57           0 :         return new WLSProblem(*this);
      58             :     }
      59             : 
      60             :     template <typename data_t>
                           // Converts `problem` into a single least-squares functional.
                           //
                           // The data term and every regularization term must be an L2NormPow2 or a
                           // WeightedL2NormPow2 functional over a LinearResidual; otherwise LogicError is thrown.
                           // Without regularization terms a clone of the data term is returned. Otherwise each
                           // term contributes one block (data term = block 0, regularization terms in order):
                           //   - its operator, pre-multiplied by a Scaling holding the square root of the
                           //     weights (and of the regularization weight), goes into a BlockLinearOperator
                           //     of BlockType::ROW;
                           //   - its data vector, scaled the same way, goes into the matching block of one
                           //     block DataContainer (blocks of terms without a data vector stay zero).
                           // The result is an L2NormPow2 over LinearResidual{blockOp, dataVec}.
                           //
                           // Negative weights are rejected with LogicError, but only when data_t is a
                           // floating-point type: the checks sit inside
                           // `if constexpr (std::is_floating_point_v<data_t>)`.
      61             :     std::unique_ptr<Functional<data_t>>
      62           0 :         WLSProblem<data_t>::wlsFromProblem(const Problem<data_t>& problem)
      63             :     {
      64           0 :         const auto& dataTerm = problem.getDataTerm();
      65           0 :         const auto& regTerms = problem.getRegularizationTerms();
      66             : 
                               // only (weighted) squared L2 norm data terms can be converted
      67           0 :         if (!is<WeightedL2NormPow2<data_t>>(dataTerm) && !is<L2NormPow2<data_t>>(dataTerm))
      68           0 :             throw LogicError("WLSProblem: conversion failed - data term is not "
      69             :                              "of type (Weighted)L2NormPow2");
      70             : 
                               // the result is tested for null right below: a residual that is not a
                               // LinearResidual is reported as "non-linear"
      71             :         const auto dataTermResidual =
      72           0 :             downcast_safe<LinearResidual<data_t>>(&dataTerm.getResidual());
      73             : 
      74           0 :         if (!dataTermResidual)
      75           0 :             throw LogicError("WLSProblem: conversion failed - data term is non-linear");
      76             : 
      77             :         // data term is of convertible type
      78             : 
      79             :         // no conversion needed if no regTerms: the data term alone is returned as a clone
      80           0 :         if (regTerms.empty())
      81           0 :             return dataTerm.clone();
      82             : 
      83             :         // else regTerms present, determine data vector descriptor
                               // (one range descriptor per term, the data term's descriptor first)
      84           0 :         std::vector<std::unique_ptr<DataDescriptor>> rangeDescList(0);
      85           0 :         rangeDescList.push_back(dataTermResidual->getRangeDescriptor().clone());
      86           0 :         for (const auto& regTerm : regTerms) {
      87           0 :             if (!is<WeightedL2NormPow2<data_t>>(regTerm.getFunctional())
      88           0 :                 && !is<L2NormPow2<data_t>>(regTerm.getFunctional()))
      89           0 :                 throw LogicError("WLSProblem: conversion failed - regularization term is not "
      90             :                                  "of type (Weighted)L2NormPow2");
      91             : 
      92             :             // the residual of every regularization term has to be a LinearResidual as well
      93             :             const auto regTermResidual =
      94           0 :                 downcast_safe<LinearResidual<data_t>>(&regTerm.getFunctional().getResidual());
      95             : 
      96           0 :             if (!regTermResidual)
      97           0 :                 throw LogicError(
      98             :                     "WLSProblem: conversion failed - regularization term is non-linear");
      99             : 
     100           0 :             rangeDescList.push_back(regTermResidual->getRangeDescriptor().clone());
     101             :         }
     102             : 
     103             :         // problem is convertible, allocate memory, set to zero and start building block op
                               // (zero-initialising matters: blocks are only overwritten below for terms
                               // whose residual has a data vector)
     104           0 :         RandomBlocksDescriptor dataVecDesc{rangeDescList};
     105           0 :         DataContainer<data_t> dataVec{dataVecDesc};
     106           0 :         dataVec = 0;
     107           0 :         std::vector<std::unique_ptr<LinearOperator<data_t>>> opList(0);
     108             : 
     109             :         // add block corresponding to data term (block 0)
     110           0 :         if (is<WeightedL2NormPow2<data_t>>(dataTerm)) {
     111           0 :             const auto& trueFunc = downcast<WeightedL2NormPow2<data_t>>(dataTerm);
     112           0 :             const auto& scaling = trueFunc.getWeightingOperator();
     113           0 :             const auto& desc = scaling.getDomainDescriptor();
     114             : 
                                   // sqrtW holds the element-wise square root of the weights; it is applied to
                                   // both the operator and the data vector of this term
     115           0 :             std::unique_ptr<Scaling<data_t>> sqrtW{};
     116           0 :             if (scaling.isIsotropic()) {
                                       // isotropic case: a single scalar weight
     117           0 :                 auto fac = scaling.getScaleFactor();
     118             : 
                                       // sign check is only compiled for floating-point data_t
     119             :                 if constexpr (std::is_floating_point_v<data_t>) {
     120           0 :                     if (fac < 0) {
     121           0 :                         throw LogicError("WLSProblem: conversion failed - negative weighting "
     122             :                                          "factor in WeightedL2NormPow2 term");
     123             :                     }
     124             :                 }
     125             : 
     126           0 :                 sqrtW = std::make_unique<Scaling<data_t>>(desc, sqrt(fac));
     127             :             } else {
                                       // non-isotropic case: a container of weights, each one is checked
     128           0 :                 const auto& fac = scaling.getScaleFactors();
     129             : 
                                       // sign check is only compiled for floating-point data_t
     130             :                 if constexpr (std::is_floating_point_v<data_t>) {
     131           0 :                     for (const auto& w : fac) {
     132           0 :                         if (w < 0) {
     133           0 :                             throw LogicError("WLSProblem: conversion failed - negative weighting "
     134             :                                              "factor in WeightedL2NormPow2 term");
     135             :                         }
     136             :                     }
     137             :                 }
     138             : 
     139           0 :                 sqrtW = std::make_unique<Scaling<data_t>>(desc, sqrt(fac));
     140             :             }
     141             : 
                                   // scaled data vector goes into block 0; without a data vector the block stays 0
     142           0 :             if (dataTermResidual->hasDataVector())
     143           0 :                 dataVec.getBlock(0) = sqrtW->apply(dataTermResidual->getDataVector());
     144             : 
                                   // block operator: sqrtW composed with the residual's operator, or sqrtW alone
                                   // when the residual has no operator
     145           0 :             if (dataTermResidual->hasOperator()) {
     146           0 :                 const auto composite = *sqrtW * dataTermResidual->getOperator();
     147           0 :                 opList.emplace_back(composite.clone());
     148           0 :             } else {
     149           0 :                 opList.push_back(std::move(sqrtW));
     150             :             }
     151             : 
     152           0 :         } else {
                                   // unweighted data term: the operator (or an Identity on the residual's domain
                                   // when there is none) and the data vector are taken over unscaled
     153           0 :             if (dataTermResidual->hasOperator()) {
     154           0 :                 opList.push_back(dataTermResidual->getOperator().clone());
     155             :             } else {
     156           0 :                 opList.push_back(
     157             :                     std::make_unique<Identity<data_t>>(dataTermResidual->getDomainDescriptor()));
     158             :             }
     159             : 
     160           0 :             if (dataTermResidual->hasDataVector())
     161           0 :                 dataVec.getBlock(0) = dataTermResidual->getDataVector();
     162             :         }
     163             : 
     164             :         // add blocks corresponding to reg terms (block 0 is the data term, so start at 1)
     165           0 :         index_t blockNum = 1;
     166           0 :         for (const auto& regTerm : regTerms) {
     167           0 :             const data_t lambda = regTerm.getWeight();
     168           0 :             const auto& func = regTerm.getFunctional();
                                   // static_cast without a check: the validation loop above already verified via
                                   // downcast_safe that every regularization residual is a LinearResidual
     169           0 :             const auto residual = static_cast<const LinearResidual<data_t>*>(&func.getResidual());
     170             : 
     171           0 :             if (is<WeightedL2NormPow2<data_t>>(func)) {
     172           0 :                 const auto& trueFunc = downcast<WeightedL2NormPow2<data_t>>(func);
     173           0 :                 const auto& scaling = trueFunc.getWeightingOperator();
     174           0 :                 const auto& desc = scaling.getDomainDescriptor();
     175             : 
                                       // regularization weight and term weights are merged: sqrt(lambda * W)
     176           0 :                 std::unique_ptr<Scaling<data_t>> sqrtLambdaW{};
     177           0 :                 if (scaling.isIsotropic()) {
     178           0 :                     auto fac = scaling.getScaleFactor();
     179             : 
                                           // the product is checked, not lambda and fac individually;
                                           // only compiled for floating-point data_t
     180             :                     if constexpr (std::is_floating_point_v<data_t>) {
     181           0 :                         if (lambda * fac < 0) {
     182           0 :                             throw LogicError("WLSProblem: conversion failed - negative weighting "
     183             :                                              "factor in WeightedL2NormPow2 term");
     184             :                         }
     185             :                     }
     186             : 
     187           0 :                     sqrtLambdaW = std::make_unique<Scaling<data_t>>(desc, sqrt(lambda * fac));
     188             :                 } else {
     189           0 :                     const auto& fac = scaling.getScaleFactors();
     190             : 
                                           // every product lambda * w is checked;
                                           // only compiled for floating-point data_t
     191             :                     if constexpr (std::is_floating_point_v<data_t>) {
     192           0 :                         for (const auto& w : fac) {
     193           0 :                             if (lambda * w < 0) {
     194           0 :                                 throw LogicError(
     195             :                                     "WLSProblem: conversion failed - negative weighting "
     196             :                                     "factor in WeightedL2NormPow2 term");
     197             :                             }
     198             :                         }
     199             :                     }
     200             : 
     201           0 :                     sqrtLambdaW = std::make_unique<Scaling<data_t>>(desc, sqrt(lambda * fac));
     202             :                 }
     203             : 
                                       // scaled data vector goes into this term's block (left at 0 if absent)
     204           0 :                 if (residual->hasDataVector())
     205           0 :                     dataVec.getBlock(blockNum) = sqrtLambdaW->apply(residual->getDataVector());
     206             : 
                                       // block operator: scaling composed with the operator, or the scaling alone
     207           0 :                 if (residual->hasOperator()) {
     208           0 :                     const auto composite = *sqrtLambdaW * residual->getOperator();
     209           0 :                     opList.push_back(composite.clone());
     210           0 :                 } else {
     211           0 :                     opList.push_back(std::move(sqrtLambdaW));
     212             :                 }
     213             : 
     214           0 :             } else {
                                       // unweighted regularization term: only sqrt(lambda) has to be folded in;
                                       // the sign check is again only compiled for floating-point data_t
     215             :                 if constexpr (std::is_floating_point_v<data_t>) {
     216           0 :                     if (lambda < 0) {
     217           0 :                         throw LogicError(
     218             :                             "WLSProblem: conversion failed - negative regularization term weight");
     219             :                     }
     220             :                 }
     221             : 
                                       // isotropic Scaling by sqrt(lambda) on the residual's range descriptor
     222           0 :                 auto sqrtLambdaScaling =
     223           0 :                     std::make_unique<Scaling<data_t>>(residual->getRangeDescriptor(), sqrt(lambda));
     224             : 
                                       // the data vector is multiplied by sqrt(lambda) directly
     225           0 :                 if (residual->hasDataVector())
     226           0 :                     dataVec.getBlock(blockNum) = sqrt(lambda) * residual->getDataVector();
     227             : 
     228           0 :                 if (residual->hasOperator()) {
     229           0 :                     const auto composite = *sqrtLambdaScaling * residual->getOperator();
     230           0 :                     opList.emplace_back(composite.clone());
     231           0 :                 } else {
     232           0 :                     opList.push_back(std::move(sqrtLambdaScaling));
     233             :                 }
     234           0 :             }
     235             : 
                                   // advance to the block of the next regularization term
     236           0 :             blockNum++;
     237             :         }
     238             : 
                               // combine all collected operators into one block operator of BlockType::ROW
     239           0 :         BlockLinearOperator<data_t> blockOp{opList, BlockLinearOperator<data_t>::BlockType::ROW};
     240             : 
                               // a single unweighted L2NormPow2 term over the stacked linear residual
     241           0 :         return std::make_unique<L2NormPow2<data_t>>(LinearResidual<data_t>{blockOp, dataVec});
     242           0 :     }
     243             : 
     244             :     // ------------------------------------------
     245             :     // explicit template instantiation
                           // WLSProblem is instantiated for real and complex element types in single and
                           // double precision. For the two complex instantiations
                           // std::is_floating_point_v<data_t> is false, so the negative-weight checks in
                           // wlsFromProblem are compiled out for them.
                           // NOTE(review): `complex` is not declared in this file — presumably an elsa alias
                           // provided through the included headers; confirm.
     246             :     template class WLSProblem<float>;
     247             :     template class WLSProblem<double>;
     248             :     template class WLSProblem<complex<float>>;
     249             :     template class WLSProblem<complex<double>>;
     250             : 
     251             : } // namespace elsa

Generated by: LCOV version 1.14