LCOV - code coverage report
Current view: top level - elsa/problems - QuadricProblem.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 116 194 59.8 %
Date: 2022-08-25 03:05:39 Functions: 16 16 100.0 %

          Line data    Source code
       1             : #include "QuadricProblem.h"
       2             : #include "L2NormPow2.h"
       3             : #include "WeightedL2NormPow2.h"
       4             : #include "LinearOperator.h"
       5             : #include "Identity.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : namespace elsa
       9             : {
      10             : 
      11             :     template <typename data_t>
      12             :     QuadricProblem<data_t>::QuadricProblem(const LinearOperator<data_t>& A,
      13             :                                            const DataContainer<data_t>& b,
      14             :                                            const DataContainer<data_t>& x0, bool spdA)
      15             :         : QuadricProblem<data_t>{spdA ? Quadric{A, b} : Quadric{adjoint(A) * A, A.applyAdjoint(b)},
      16             :                                  x0}
      17           8 :     {
      18             :         // sanity checks are done in the member constructors already
      19           8 :     }
      20             : 
      21             :     template <typename data_t>
      22             :     QuadricProblem<data_t>::QuadricProblem(const LinearOperator<data_t>& A,
      23             :                                            const DataContainer<data_t>& b, bool spdA)
      24             :         : QuadricProblem<data_t>{spdA ? Quadric{A, b} : Quadric{adjoint(A) * A, A.applyAdjoint(b)}}
      25          16 :     {
      26             :         // sanity checks are done in the member constructors already
      27          16 :     }
      28             : 
      29             :     template <typename data_t>
      30             :     QuadricProblem<data_t>::QuadricProblem(const Quadric<data_t>& quadric,
      31             :                                            const DataContainer<data_t>& x0)
      32             :         : Problem<data_t>{quadric, x0}
      33          18 :     {
      34             :         // sanity checks are done in the member constructors already
      35          18 :     }
      36             : 
      37             :     template <typename data_t>
      38             :     QuadricProblem<data_t>::QuadricProblem(const Quadric<data_t>& quadric)
      39             :         : Problem<data_t>{quadric}
      40          22 :     {
      41             :         // sanity checks are done in the member constructors already
      42          22 :     }
      43             : 
      44             :     template <typename data_t>
      45             :     QuadricProblem<data_t>::QuadricProblem(const Problem<data_t>& problem)
      46             :         : Problem<data_t>{*quadricFromProblem(problem), problem.getCurrentSolution()}
      47          72 :     {
      48             :         // sanity checks are done in the member constructors already
      49          72 :     }
      50             : 
      51             :     template <typename data_t>
      52             :     QuadricProblem<data_t>* QuadricProblem<data_t>::cloneImpl() const
      53          26 :     {
      54          26 :         return new QuadricProblem(*this);
      55          26 :     }
      56             : 
      57             :     template <typename data_t>
      58             :     LinearResidual<data_t>
      59             :         QuadricProblem<data_t>::getGradientExpression(const RegularizationTerm<data_t>& regTerm)
      60          42 :     {
      61          42 :         const auto lambda = regTerm.getWeight();
      62          42 :         const Scaling<data_t> lambdaOp{regTerm.getFunctional().getDomainDescriptor(), lambda};
      63             : 
      64          42 :         if (is<L2NormPow2<data_t>>(regTerm.getFunctional())) {
      65          38 :             const auto& regFunc = downcast<L2NormPow2<data_t>>(regTerm.getFunctional());
      66             : 
      67          38 :             if (!is<LinearResidual<data_t>>(regFunc.getResidual())) {
      68           0 :                 throw LogicError("QuadricProblem: cannot convert a non-linear regularization "
      69           0 :                                  "term to quadric form");
      70           0 :             }
      71          38 :             const auto& regTermResidual = downcast<LinearResidual<data_t>>(regFunc.getResidual());
      72             : 
      73          38 :             if (regTermResidual.hasOperator()) {
      74          34 :                 const auto& regTermOp = regTermResidual.getOperator();
      75          34 :                 LinearOperator<data_t> A = lambdaOp * adjoint(regTermOp) * regTermOp;
      76             : 
      77          34 :                 if (regTermResidual.hasDataVector()) {
      78          34 :                     DataContainer<data_t> b =
      79          34 :                         lambda * regTermOp.applyAdjoint(regTermResidual.getDataVector());
      80          34 :                     return LinearResidual<data_t>{A, b};
      81          34 :                 } else {
      82           0 :                     return LinearResidual<data_t>{A};
      83           0 :                 }
      84           4 :             } else {
      85           4 :                 if (regTermResidual.hasDataVector()) {
      86           0 :                     DataContainer<data_t> b = lambda * regTermResidual.getDataVector();
      87           0 :                     return LinearResidual<data_t>{lambdaOp, b};
      88           4 :                 } else {
      89           4 :                     return LinearResidual<data_t>{lambdaOp};
      90           4 :                 }
      91           4 :             }
      92           4 :         } else if (is<WeightedL2NormPow2<data_t>>(regTerm.getFunctional())) {
      93           0 :             const auto& regFunc = downcast<WeightedL2NormPow2<data_t>>(regTerm.getFunctional());
      94             : 
      95           0 :             if (!is<LinearResidual<data_t>>(regFunc.getResidual()))
      96           0 :                 throw LogicError("QuadricProblem: cannot convert a non-linear regularization "
      97           0 :                                  "term to quadric form");
      98             : 
      99           0 :             const auto& regTermResidual = downcast<LinearResidual<data_t>>(regFunc.getResidual());
     100             : 
     101           0 :             const auto& W = regFunc.getWeightingOperator();
     102           0 :             std::unique_ptr<Scaling<data_t>> lambdaWPtr;
     103             : 
     104           0 :             if (W.isIsotropic())
     105           0 :                 lambdaWPtr = std::make_unique<Scaling<data_t>>(W.getDomainDescriptor(),
     106           0 :                                                                lambda * W.getScaleFactor());
     107           0 :             else
     108           0 :                 lambdaWPtr = std::make_unique<Scaling<data_t>>(W.getDomainDescriptor(),
     109           0 :                                                                lambda * W.getScaleFactors());
     110             : 
     111           0 :             const auto& lambdaW = *lambdaWPtr;
     112             : 
     113           0 :             if (regTermResidual.hasOperator()) {
     114           0 :                 auto& regTermOp = regTermResidual.getOperator();
     115           0 :                 LinearOperator<data_t> A = adjoint(regTermOp) * lambdaW * regTermOp;
     116             : 
     117           0 :                 if (regTermResidual.hasDataVector()) {
     118           0 :                     DataContainer<data_t> b =
     119           0 :                         regTermOp.applyAdjoint(lambdaW.apply(regTermResidual.getDataVector()));
     120           0 :                     return LinearResidual<data_t>{A, b};
     121           0 :                 } else {
     122           0 :                     return LinearResidual<data_t>{A};
     123           0 :                 }
     124           0 :             } else {
     125           0 :                 if (regTermResidual.hasDataVector()) {
     126           0 :                     DataContainer<data_t> b = lambdaW.apply(regTermResidual.getDataVector());
     127           0 :                     return LinearResidual<data_t>{lambdaW, b};
     128           0 :                 } else {
     129           0 :                     return LinearResidual<data_t>{lambdaW};
     130           0 :                 }
     131           4 :             }
     132           4 :         } else if (is<Quadric<data_t>>(regTerm.getFunctional())) {
     133           2 :             const auto& regFunc = downcast<Quadric<data_t>>(regTerm.getFunctional());
     134           2 :             const auto& quadricResidual = regFunc.getGradientExpression();
     135           2 :             if (quadricResidual.hasOperator()) {
     136           0 :                 LinearOperator<data_t> A = lambdaOp * quadricResidual.getOperator();
     137             : 
     138           0 :                 if (quadricResidual.hasDataVector()) {
     139           0 :                     const DataContainer<data_t>& b = quadricResidual.getDataVector();
     140           0 :                     return LinearResidual<data_t>{A, lambda * b};
     141           0 :                 } else {
     142           0 :                     return LinearResidual<data_t>{A};
     143           0 :                 }
     144           2 :             } else {
     145           2 :                 if (quadricResidual.hasDataVector()) {
     146           0 :                     const DataContainer<data_t>& b = quadricResidual.getDataVector();
     147           0 :                     return LinearResidual<data_t>{lambdaOp, lambda * b};
     148           2 :                 } else {
     149           2 :                     return LinearResidual<data_t>{lambdaOp};
     150           2 :                 }
     151           2 :             }
     152           2 :         } else {
     153           2 :             throw InvalidArgumentError("QuadricProblem: regularization terms should be of type "
     154           2 :                                        "(Weighted)L2NormPow2 or Quadric");
     155           2 :         }
     156          42 :     }
     157             : 
     158             :     template <typename data_t>
     159             :     std::unique_ptr<Quadric<data_t>>
     160             :         QuadricProblem<data_t>::quadricFromProblem(const Problem<data_t>& problem)
     161          72 :     {
     162          72 :         const auto& functional = problem.getDataTerm();
     163          72 :         if (is<Quadric<data_t>>(functional) && problem.getRegularizationTerms().empty()) {
     164          21 :             return downcast<Quadric<data_t>>(functional.clone());
     165          51 :         } else {
     166          51 :             std::unique_ptr<LinearOperator<data_t>> dataTermOp;
     167          51 :             std::unique_ptr<DataContainer<data_t>> quadricVec;
     168             : 
     169             :             // convert data term
     170          51 :             if (is<Quadric<data_t>>(functional)) {
     171           2 :                 const auto& trueFunctional = downcast<Quadric<data_t>>(functional);
     172           2 :                 const LinearResidual<data_t>& residual = trueFunctional.getGradientExpression();
     173             : 
     174           2 :                 if (residual.hasOperator()) {
     175           0 :                     dataTermOp = residual.getOperator().clone();
     176           2 :                 } else {
     177           2 :                     dataTermOp = std::make_unique<Identity<data_t>>(residual.getDomainDescriptor());
     178           2 :                 }
     179             : 
     180           2 :                 if (residual.hasDataVector()) {
     181           0 :                     quadricVec = std::make_unique<DataContainer<data_t>>(residual.getDataVector());
     182           0 :                 }
     183          49 :             } else if (is<L2NormPow2<data_t>>(functional)) {
     184          47 :                 const auto& trueFunctional = downcast<L2NormPow2<data_t>>(functional);
     185             : 
     186          47 :                 if (!is<LinearResidual<data_t>>(trueFunctional.getResidual()))
     187           0 :                     throw LogicError(
     188           0 :                         "QuadricProblem: cannot convert a non-linear term to quadric form");
     189             : 
     190          47 :                 const auto& residual =
     191          47 :                     downcast<LinearResidual<data_t>>(trueFunctional.getResidual());
     192             : 
     193          47 :                 if (residual.hasOperator()) {
     194          41 :                     const auto& A = residual.getOperator();
     195          41 :                     dataTermOp = std::make_unique<LinearOperator<data_t>>(adjoint(A) * A);
     196             : 
     197          41 :                     if (residual.hasDataVector()) {
     198          41 :                         quadricVec = std::make_unique<DataContainer<data_t>>(
     199          41 :                             A.applyAdjoint(residual.getDataVector()));
     200          41 :                     }
     201          41 :                 } else {
     202           6 :                     dataTermOp = std::make_unique<Identity<data_t>>(residual.getDomainDescriptor());
     203             : 
     204           6 :                     if (residual.hasDataVector()) {
     205           0 :                         quadricVec =
     206           0 :                             std::make_unique<DataContainer<data_t>>(residual.getDataVector());
     207           0 :                     }
     208           6 :                 }
     209          47 :             } else if (is<WeightedL2NormPow2<data_t>>(functional)) {
     210           0 :                 const auto& trueFunctional = downcast<WeightedL2NormPow2<data_t>>(functional);
     211             : 
     212           0 :                 if (!is<LinearResidual<data_t>>(trueFunctional.getResidual()))
     213           0 :                     throw LogicError(
     214           0 :                         "QuadricProblem: cannot convert a non-linear term to quadric form");
     215           0 :                 const auto& residual =
     216           0 :                     downcast<LinearResidual<data_t>>(trueFunctional.getResidual());
     217             : 
     218           0 :                 const auto& W = trueFunctional.getWeightingOperator();
     219             : 
     220           0 :                 if (residual.hasOperator()) {
     221           0 :                     const auto& A = residual.getOperator();
     222           0 :                     dataTermOp = std::make_unique<LinearOperator<data_t>>(adjoint(A) * W * A);
     223             : 
     224           0 :                     if (residual.hasDataVector()) {
     225           0 :                         quadricVec = std::make_unique<DataContainer<data_t>>(
     226           0 :                             A.applyAdjoint(W.apply(residual.getDataVector())));
     227           0 :                     }
     228           0 :                 } else {
     229           0 :                     dataTermOp = W.clone();
     230             : 
     231           0 :                     if (residual.hasDataVector()) {
     232           0 :                         quadricVec = std::make_unique<DataContainer<data_t>>(
     233           0 :                             W.apply(residual.getDataVector()));
     234           0 :                     }
     235           0 :                 }
     236           2 :             } else {
     237           2 :                 throw LogicError("QuadricProblem: can only convert functionals of type "
     238           2 :                                  "(Weighted)L2NormPow2 to Quadric");
     239           2 :             }
     240             : 
     241          49 :             if (problem.getRegularizationTerms().empty()) {
     242           7 :                 if (!quadricVec) {
     243           4 :                     return std::make_unique<Quadric<data_t>>(*dataTermOp);
     244           4 :                 } else {
     245           3 :                     return std::make_unique<Quadric<data_t>>(*dataTermOp, *quadricVec);
     246           3 :                 }
     247          42 :             }
     248             : 
     249             :             // add regularization terms
     250          42 :             LinearOperator<data_t> quadricOp{dataTermOp->getDomainDescriptor(),
     251          42 :                                              dataTermOp->getRangeDescriptor()};
     252             : 
     253          84 :             for (std::size_t i = 0; i < problem.getRegularizationTerms().size(); i++) {
     254          42 :                 const auto& regTerm = problem.getRegularizationTerms()[i];
     255          42 :                 LinearResidual<data_t> residual = getGradientExpression(regTerm);
     256             : 
     257          42 :                 if (i == 0)
     258          40 :                     quadricOp = (*dataTermOp + residual.getOperator());
     259           2 :                 else
     260           2 :                     quadricOp = quadricOp + residual.getOperator();
     261             : 
     262          42 :                 if (residual.hasDataVector()) {
     263          34 :                     if (!quadricVec)
     264           0 :                         quadricVec =
     265           0 :                             std::make_unique<DataContainer<data_t>>(residual.getDataVector());
     266          34 :                     else
     267          34 :                         *quadricVec += residual.getDataVector();
     268          34 :                 }
     269          42 :             }
     270             : 
     271          42 :             if (!quadricVec) {
     272           2 :                 return std::make_unique<Quadric<data_t>>(quadricOp);
     273          40 :             } else {
     274          40 :                 return std::make_unique<Quadric<data_t>>(quadricOp, *quadricVec);
     275          40 :             }
     276          42 :         }
     277          72 :     }
     278             : 
     279             :     // ------------------------------------------
     280             :     // explicit template instantiation
     281             :     template class QuadricProblem<float>;
     282             :     template class QuadricProblem<double>;
     283             : 
     284             : } // namespace elsa

Generated by: LCOV version 1.14