LCOV - code coverage report
Current view: top level - problems - QuadricProblem.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 143 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 16 0.0 %

          Line data    Source code
       1             : #include "QuadricProblem.h"
       2             : #include "L2NormPow2.h"
       3             : #include "WeightedL2NormPow2.h"
       4             : #include "LinearOperator.h"
       5             : #include "Identity.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : namespace elsa
       9             : {
      10             : 
      11             :     template <typename data_t>
      12           0 :     QuadricProblem<data_t>::QuadricProblem(const LinearOperator<data_t>& A,
      13             :                                            const DataContainer<data_t>& b,
      14             :                                            const DataContainer<data_t>& x0, bool spdA)
      15           0 :         : QuadricProblem<data_t>{spdA ? Quadric{A, b} : Quadric{adjoint(A) * A, A.applyAdjoint(b)},
      16           0 :                                  x0}
      17             :     {
      18             :         // sanity checks are done in the member constructors already
      19           0 :     }
      20             : 
      21             :     template <typename data_t>
      22           0 :     QuadricProblem<data_t>::QuadricProblem(const LinearOperator<data_t>& A,
      23             :                                            const DataContainer<data_t>& b, bool spdA)
      24           0 :         : QuadricProblem<data_t>{spdA ? Quadric{A, b} : Quadric{adjoint(A) * A, A.applyAdjoint(b)}}
      25             :     {
      26             :         // sanity checks are done in the member constructors already
      27           0 :     }
      28             : 
      29             :     template <typename data_t>
      30           0 :     QuadricProblem<data_t>::QuadricProblem(const Quadric<data_t>& quadric,
      31             :                                            const DataContainer<data_t>& x0)
      32           0 :         : Problem<data_t>{quadric, x0}
      33             :     {
      34             :         // sanity checks are done in the member constructors already
      35           0 :     }
      36             : 
      37             :     template <typename data_t>
      38           0 :     QuadricProblem<data_t>::QuadricProblem(const Quadric<data_t>& quadric)
      39           0 :         : Problem<data_t>{quadric}
      40             :     {
      41             :         // sanity checks are done in the member constructors already
      42           0 :     }
      43             : 
      44             :     template <typename data_t>
      45           0 :     QuadricProblem<data_t>::QuadricProblem(const Problem<data_t>& problem)
      46           0 :         : Problem<data_t>{*quadricFromProblem(problem), problem.getCurrentSolution()}
      47             :     {
      48             :         // sanity checks are done in the member constructors already
      49           0 :     }
      50             : 
      51             :     template <typename data_t>
      52           0 :     QuadricProblem<data_t>* QuadricProblem<data_t>::cloneImpl() const
      53             :     {
      54           0 :         return new QuadricProblem(*this);
      55             :     }
      56             : 
      57             :     template <typename data_t>
      58             :     LinearResidual<data_t>
      59           0 :         QuadricProblem<data_t>::getGradientExpression(const RegularizationTerm<data_t>& regTerm)
      60             :     {
      61           0 :         const auto lambda = regTerm.getWeight();
      62           0 :         const Scaling<data_t> lambdaOp{regTerm.getFunctional().getDomainDescriptor(), lambda};
      63             : 
      64           0 :         if (is<L2NormPow2<data_t>>(regTerm.getFunctional())) {
      65           0 :             const auto& regFunc = downcast<L2NormPow2<data_t>>(regTerm.getFunctional());
      66             : 
      67           0 :             if (!is<LinearResidual<data_t>>(regFunc.getResidual())) {
      68           0 :                 throw LogicError("QuadricProblem: cannot convert a non-linear regularization "
      69             :                                  "term to quadric form");
      70             :             }
      71           0 :             const auto& regTermResidual = downcast<LinearResidual<data_t>>(regFunc.getResidual());
      72             : 
      73           0 :             if (regTermResidual.hasOperator()) {
      74           0 :                 const auto& regTermOp = regTermResidual.getOperator();
      75           0 :                 LinearOperator<data_t> A = lambdaOp * adjoint(regTermOp) * regTermOp;
      76             : 
      77           0 :                 if (regTermResidual.hasDataVector()) {
      78           0 :                     DataContainer<data_t> b =
      79             :                         lambda * regTermOp.applyAdjoint(regTermResidual.getDataVector());
      80           0 :                     return LinearResidual<data_t>{A, b};
      81           0 :                 } else {
      82           0 :                     return LinearResidual<data_t>{A};
      83             :                 }
      84           0 :             } else {
      85           0 :                 if (regTermResidual.hasDataVector()) {
      86           0 :                     DataContainer<data_t> b = lambda * regTermResidual.getDataVector();
      87           0 :                     return LinearResidual<data_t>{lambdaOp, b};
      88           0 :                 } else {
      89           0 :                     return LinearResidual<data_t>{lambdaOp};
      90             :                 }
      91             :             }
      92           0 :         } else if (is<WeightedL2NormPow2<data_t>>(regTerm.getFunctional())) {
      93           0 :             const auto& regFunc = downcast<WeightedL2NormPow2<data_t>>(regTerm.getFunctional());
      94             : 
      95           0 :             if (!is<LinearResidual<data_t>>(regFunc.getResidual()))
      96           0 :                 throw LogicError("QuadricProblem: cannot convert a non-linear regularization "
      97             :                                  "term to quadric form");
      98             : 
      99           0 :             const auto& regTermResidual = downcast<LinearResidual<data_t>>(regFunc.getResidual());
     100             : 
     101           0 :             const auto& W = regFunc.getWeightingOperator();
     102           0 :             std::unique_ptr<Scaling<data_t>> lambdaWPtr;
     103             : 
     104           0 :             if (W.isIsotropic())
     105           0 :                 lambdaWPtr = std::make_unique<Scaling<data_t>>(W.getDomainDescriptor(),
     106           0 :                                                                lambda * W.getScaleFactor());
     107             :             else
     108           0 :                 lambdaWPtr = std::make_unique<Scaling<data_t>>(W.getDomainDescriptor(),
     109             :                                                                lambda * W.getScaleFactors());
     110             : 
     111           0 :             const auto& lambdaW = *lambdaWPtr;
     112             : 
     113           0 :             if (regTermResidual.hasOperator()) {
     114           0 :                 auto& regTermOp = regTermResidual.getOperator();
     115           0 :                 LinearOperator<data_t> A = adjoint(regTermOp) * lambdaW * regTermOp;
     116             : 
     117           0 :                 if (regTermResidual.hasDataVector()) {
     118           0 :                     DataContainer<data_t> b =
     119             :                         regTermOp.applyAdjoint(lambdaW.apply(regTermResidual.getDataVector()));
     120           0 :                     return LinearResidual<data_t>{A, b};
     121           0 :                 } else {
     122           0 :                     return LinearResidual<data_t>{A};
     123             :                 }
     124           0 :             } else {
     125           0 :                 if (regTermResidual.hasDataVector()) {
     126           0 :                     DataContainer<data_t> b = lambdaW.apply(regTermResidual.getDataVector());
     127           0 :                     return LinearResidual<data_t>{lambdaW, b};
     128           0 :                 } else {
     129           0 :                     return LinearResidual<data_t>{lambdaW};
     130             :                 }
     131             :             }
     132           0 :         } else if (is<Quadric<data_t>>(regTerm.getFunctional())) {
     133           0 :             const auto& regFunc = downcast<Quadric<data_t>>(regTerm.getFunctional());
     134           0 :             const auto& quadricResidual = regFunc.getGradientExpression();
     135           0 :             if (quadricResidual.hasOperator()) {
     136           0 :                 LinearOperator<data_t> A = lambdaOp * quadricResidual.getOperator();
     137             : 
     138           0 :                 if (quadricResidual.hasDataVector()) {
     139           0 :                     const DataContainer<data_t>& b = quadricResidual.getDataVector();
     140           0 :                     return LinearResidual<data_t>{A, lambda * b};
     141             :                 } else {
     142           0 :                     return LinearResidual<data_t>{A};
     143             :                 }
     144           0 :             } else {
     145           0 :                 if (quadricResidual.hasDataVector()) {
     146           0 :                     const DataContainer<data_t>& b = quadricResidual.getDataVector();
     147           0 :                     return LinearResidual<data_t>{lambdaOp, lambda * b};
     148             :                 } else {
     149           0 :                     return LinearResidual<data_t>{lambdaOp};
     150             :                 }
     151             :             }
     152             :         } else {
     153           0 :             throw InvalidArgumentError("QuadricProblem: regularization terms should be of type "
     154             :                                        "(Weighted)L2NormPow2 or Quadric");
     155             :         }
     156           0 :     }
     157             : 
     158             :     template <typename data_t>
     159             :     std::unique_ptr<Quadric<data_t>>
     160           0 :         QuadricProblem<data_t>::quadricFromProblem(const Problem<data_t>& problem)
     161             :     {
     162           0 :         const auto& functional = problem.getDataTerm();
     163           0 :         if (is<Quadric<data_t>>(functional) && problem.getRegularizationTerms().empty()) {
     164           0 :             return downcast<Quadric<data_t>>(functional.clone());
     165             :         } else {
     166           0 :             std::unique_ptr<LinearOperator<data_t>> dataTermOp;
     167           0 :             std::unique_ptr<DataContainer<data_t>> quadricVec;
     168             : 
     169             :             // convert data term
     170           0 :             if (is<Quadric<data_t>>(functional)) {
     171           0 :                 const auto& trueFunctional = downcast<Quadric<data_t>>(functional);
     172           0 :                 const LinearResidual<data_t>& residual = trueFunctional.getGradientExpression();
     173             : 
     174           0 :                 if (residual.hasOperator()) {
     175           0 :                     dataTermOp = residual.getOperator().clone();
     176             :                 } else {
     177           0 :                     dataTermOp = std::make_unique<Identity<data_t>>(residual.getDomainDescriptor());
     178             :                 }
     179             : 
     180           0 :                 if (residual.hasDataVector()) {
     181           0 :                     quadricVec = std::make_unique<DataContainer<data_t>>(residual.getDataVector());
     182             :                 }
     183           0 :             } else if (is<L2NormPow2<data_t>>(functional)) {
     184           0 :                 const auto& trueFunctional = downcast<L2NormPow2<data_t>>(functional);
     185             : 
     186           0 :                 if (!is<LinearResidual<data_t>>(trueFunctional.getResidual()))
     187           0 :                     throw LogicError(
     188             :                         "QuadricProblem: cannot convert a non-linear term to quadric form");
     189             : 
     190             :                 const auto& residual =
     191           0 :                     downcast<LinearResidual<data_t>>(trueFunctional.getResidual());
     192             : 
     193           0 :                 if (residual.hasOperator()) {
     194           0 :                     const auto& A = residual.getOperator();
     195           0 :                     dataTermOp = std::make_unique<LinearOperator<data_t>>(adjoint(A) * A);
     196             : 
     197           0 :                     if (residual.hasDataVector()) {
     198           0 :                         quadricVec = std::make_unique<DataContainer<data_t>>(
     199             :                             A.applyAdjoint(residual.getDataVector()));
     200             :                     }
     201             :                 } else {
     202           0 :                     dataTermOp = std::make_unique<Identity<data_t>>(residual.getDomainDescriptor());
     203             : 
     204           0 :                     if (residual.hasDataVector()) {
     205           0 :                         quadricVec =
     206             :                             std::make_unique<DataContainer<data_t>>(residual.getDataVector());
     207             :                     }
     208             :                 }
     209           0 :             } else if (is<WeightedL2NormPow2<data_t>>(functional)) {
     210           0 :                 const auto& trueFunctional = downcast<WeightedL2NormPow2<data_t>>(functional);
     211             : 
     212           0 :                 if (!is<LinearResidual<data_t>>(trueFunctional.getResidual()))
     213           0 :                     throw LogicError(
     214             :                         "QuadricProblem: cannot convert a non-linear term to quadric form");
     215             :                 const auto& residual =
     216           0 :                     downcast<LinearResidual<data_t>>(trueFunctional.getResidual());
     217             : 
     218           0 :                 const auto& W = trueFunctional.getWeightingOperator();
     219             : 
     220           0 :                 if (residual.hasOperator()) {
     221           0 :                     const auto& A = residual.getOperator();
     222           0 :                     dataTermOp = std::make_unique<LinearOperator<data_t>>(adjoint(A) * W * A);
     223             : 
     224           0 :                     if (residual.hasDataVector()) {
     225           0 :                         quadricVec = std::make_unique<DataContainer<data_t>>(
     226             :                             A.applyAdjoint(W.apply(residual.getDataVector())));
     227             :                     }
     228             :                 } else {
     229           0 :                     dataTermOp = W.clone();
     230             : 
     231           0 :                     if (residual.hasDataVector()) {
     232           0 :                         quadricVec = std::make_unique<DataContainer<data_t>>(
     233             :                             W.apply(residual.getDataVector()));
     234             :                     }
     235             :                 }
     236             :             } else {
     237           0 :                 throw LogicError("QuadricProblem: can only convert functionals of type "
     238             :                                  "(Weighted)L2NormPow2 to Quadric");
     239             :             }
     240             : 
     241           0 :             if (problem.getRegularizationTerms().empty()) {
     242           0 :                 if (!quadricVec) {
     243           0 :                     return std::make_unique<Quadric<data_t>>(*dataTermOp);
     244             :                 } else {
     245           0 :                     return std::make_unique<Quadric<data_t>>(*dataTermOp, *quadricVec);
     246             :                 }
     247             :             }
     248             : 
     249             :             // add regularization terms
     250           0 :             LinearOperator<data_t> quadricOp{dataTermOp->getDomainDescriptor(),
     251             :                                              dataTermOp->getRangeDescriptor()};
     252             : 
     253           0 :             for (std::size_t i = 0; i < problem.getRegularizationTerms().size(); i++) {
     254           0 :                 const auto& regTerm = problem.getRegularizationTerms()[i];
     255           0 :                 LinearResidual<data_t> residual = getGradientExpression(regTerm);
     256             : 
     257           0 :                 if (i == 0)
     258           0 :                     quadricOp = (*dataTermOp + residual.getOperator());
     259             :                 else
     260           0 :                     quadricOp = quadricOp + residual.getOperator();
     261             : 
     262           0 :                 if (residual.hasDataVector()) {
     263           0 :                     if (!quadricVec)
     264           0 :                         quadricVec =
     265             :                             std::make_unique<DataContainer<data_t>>(residual.getDataVector());
     266             :                     else
     267           0 :                         *quadricVec += residual.getDataVector();
     268             :                 }
     269             :             }
     270             : 
     271           0 :             if (!quadricVec) {
     272           0 :                 return std::make_unique<Quadric<data_t>>(quadricOp);
     273             :             } else {
     274           0 :                 return std::make_unique<Quadric<data_t>>(quadricOp, *quadricVec);
     275             :             }
     276           0 :         }
     277             :     }
     278             : 
     279             :     // ------------------------------------------
     280             :     // explicit template instantiation
     281             :     template class QuadricProblem<float>;
     282             :     template class QuadricProblem<double>;
     283             : 
     284             : } // namespace elsa

Generated by: LCOV version 1.14