LCOV - code coverage report
Current view: top level - elsa/solvers - SQS.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 86 108 79.6 %
Date: 2024-05-16 04:22:26 Functions: 14 14 100.0 %

          Line data    Source code
       1             : #include "SQS.h"
       2             : #include "Identity.h"
       3             : #include "Scaling.h"
       4             : #include "Logger.h"
       5             : #include "Solver.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     SQS<data_t>::SQS(const LeastSquares<data_t>& problem,
      12             :                      std::vector<std::unique_ptr<LeastSquares<data_t>>>&& subsets,
      13             :                      bool momentumAcceleration, data_t epsilon)
      14             :         : Solver<data_t>(),
      15             :           fullProblem_(downcast<LeastSquares<data_t>>(problem.clone())),
      16             :           subsets_(std::move(subsets)),
      17             :           epsilon_{epsilon},
      18             :           momentumAcceleration_{momentumAcceleration},
      19             :           subsetMode_(!subsets.empty())
      20             : 
      21           4 :     {
      22           4 :         Logger::get("SQS")->info("SQS running in ordered subset mode");
      23           4 :     }
      24             : 
      25             :     template <typename data_t>
      26             :     SQS<data_t>::SQS(const LeastSquares<data_t>& problem,
      27             :                      std::vector<std::unique_ptr<LeastSquares<data_t>>>&& subsets,
      28             :                      const LinearOperator<data_t>& preconditioner, bool momentumAcceleration,
      29             :                      data_t epsilon)
      30             :         : Solver<data_t>(),
      31             :           fullProblem_(downcast<LeastSquares<data_t>>(problem.clone())),
      32             :           subsets_(std::move(subsets)),
      33             :           epsilon_{epsilon},
      34             :           preconditioner_{preconditioner.clone()},
      35             :           momentumAcceleration_{momentumAcceleration},
      36             :           subsetMode_(!subsets.empty())
      37           4 :     {
      38           4 :         Logger::get("SQS")->info("SQS running in ordered subset mode");
      39             : 
      40             :         // check that preconditioner is compatible with problem
      41           4 :         if (preconditioner_->getDomainDescriptor().getNumberOfCoefficients()
      42           4 :                 != fullProblem_->getDomainDescriptor().getNumberOfCoefficients()
      43           4 :             || preconditioner_->getRangeDescriptor().getNumberOfCoefficients()
      44           4 :                    != fullProblem_->getDomainDescriptor().getNumberOfCoefficients()) {
      45           0 :             throw InvalidArgumentError("SQS: incorrect size of preconditioner");
      46           0 :         }
      47           4 :     }
      48             : 
      49             :     template <typename data_t>
      50             :     SQS<data_t>::SQS(const LeastSquares<data_t>& problem, bool momentumAcceleration, data_t epsilon)
      51             :         : Solver<data_t>(),
      52             :           fullProblem_(downcast<LeastSquares<data_t>>(problem.clone())),
      53             :           epsilon_{epsilon},
      54             :           momentumAcceleration_{momentumAcceleration},
      55             :           subsetMode_(false)
      56             : 
      57           8 :     {
      58           8 :         Logger::get("SQS")->info("SQS running in normal mode");
      59           8 :     }
      60             : 
      61             :     template <typename data_t>
      62             :     SQS<data_t>::SQS(const LeastSquares<data_t>& problem,
      63             :                      const LinearOperator<data_t>& preconditioner, bool momentumAcceleration,
      64             :                      data_t epsilon)
      65             :         : Solver<data_t>(),
      66             :           fullProblem_(downcast<LeastSquares<data_t>>(problem.clone())),
      67             :           epsilon_{epsilon},
      68             :           preconditioner_{preconditioner.clone()},
      69             :           momentumAcceleration_{momentumAcceleration},
      70             :           subsetMode_(false)
      71           4 :     {
      72           4 :         Logger::get("SQS")->info("SQS running in normal mode");
      73             : 
      74             :         // check that preconditioner is compatible with problem
      75           4 :         if (preconditioner_->getDomainDescriptor().getNumberOfCoefficients()
      76           4 :                 != fullProblem_->getDomainDescriptor().getNumberOfCoefficients()
      77           4 :             || preconditioner_->getRangeDescriptor().getNumberOfCoefficients()
      78           4 :                    != fullProblem_->getDomainDescriptor().getNumberOfCoefficients()) {
      79           0 :             throw InvalidArgumentError("SQS: incorrect size of preconditioner");
      80           0 :         }
      81           4 :     }
      82             : 
      83             :     template <typename data_t>
      84             :     DataContainer<data_t> SQS<data_t>::solve(index_t iterations,
      85             :                                              std::optional<DataContainer<data_t>> x0)
      86           8 :     {
      87           8 :         auto& domain = fullProblem_->getDomainDescriptor();
      88           8 :         auto x = extract_or(x0, domain);
      89             : 
      90           8 :         auto convergenceThreshold =
      91           8 :             fullProblem_->getGradient(x).squaredL2Norm() * epsilon_ * epsilon_;
      92             : 
      93           8 :         auto hessian = fullProblem_->getHessian(x);
      94             : 
      95           8 :         auto rowsum = hessian.apply(ones<data_t>(domain));
      96           8 :         rowsum = static_cast<data_t>(1.0) / rowsum;
      97           8 :         auto diag = Scaling<data_t>(rowsum);
      98             : 
      99           8 :         data_t tOld = 1;
     100           8 :         data_t t = 1;
     101           8 :         data_t tNew = 0;
     102             : 
     103           8 :         auto& z = x;
     104             : 
     105           8 :         DataContainer<data_t> xOld = x;
     106           8 :         auto gradient = empty<data_t>(domain);
     107             : 
     108           8 :         index_t nSubsets = subsetMode_ ? subsets_.size() : 1;
     109             : 
     110        1008 :         for (index_t i = 0; i < iterations; i++) {
     111        1000 :             Logger::get("SQS")->info("iteration {} of {}", i + 1, iterations);
     112             : 
     113        2000 :             for (index_t m = 0; m < nSubsets; m++) {
     114        1000 :                 if (subsetMode_) {
     115           0 :                     subsets_[m]->getGradient(x, gradient);
     116        1000 :                 } else {
     117        1000 :                     fullProblem_->getGradient(x, gradient);
     118        1000 :                 }
     119             : 
     120        1000 :                 if (preconditioner_) {
     121         800 :                     preconditioner_->apply(gradient, gradient);
     122         800 :                 }
     123             : 
     124             :                 // TODO: element wise relu
     125        1000 :                 if (momentumAcceleration_) {
     126        1000 :                     tNew = as<data_t>(1)
     127        1000 :                            + std::sqrt(as<data_t>(1) + as<data_t>(4) * t * t) / as<data_t>(2);
     128             : 
     129        1000 :                     lincomb(1, z, -nSubsets, diag.apply(gradient), x);
     130        1000 :                     lincomb(1, x, tOld / tNew, (x - xOld), z);
     131        1000 :                 } else {
     132           0 :                     lincomb(1, z, -nSubsets, diag.apply(gradient), z);
     133           0 :                 }
     134             : 
     135             :                 // if the gradient is too small we stop
     136        1000 :                 if (gradient.squaredL2Norm() <= convergenceThreshold) {
     137           0 :                     if (!subsetMode_
     138           0 :                         || fullProblem_->getGradient(x).squaredL2Norm() <= convergenceThreshold) {
     139           0 :                         Logger::get("SQS")->info("SUCCESS: Reached convergence at {}/{} iteration",
     140           0 :                                                  i + 1, iterations);
     141             : 
     142             :                         // TODO: make return more sane
     143           0 :                         if (momentumAcceleration_) {
     144           0 :                             z = x;
     145           0 :                         }
     146           0 :                         return x;
     147           0 :                     }
     148        1000 :                 }
     149             : 
     150        1000 :                 if (momentumAcceleration_) {
     151        1000 :                     tOld = t;
     152        1000 :                     t = tNew;
     153        1000 :                     xOld = x;
     154        1000 :                 }
     155        1000 :             }
     156        1000 :         }
     157             : 
     158           8 :         Logger::get("SQS")->warn("Failed to reach convergence at {} iterations", iterations);
     159             : 
     160             :         // TODO: make return more sane
     161           8 :         if (momentumAcceleration_) {
     162           8 :             z = x;
     163           8 :         }
     164           8 :         return x;
     165           8 :     }
     166             : 
     167             :     template <typename data_t>
     168             :     SQS<data_t>* SQS<data_t>::cloneImpl() const
     169           8 :     {
     170           8 :         std::vector<std::unique_ptr<LeastSquares<data_t>>> newsubsets;
     171           8 :         newsubsets.reserve(subsets_.size());
     172           8 :         for (const auto& ptr : subsets_) {
     173           0 :             newsubsets.emplace_back(downcast<LeastSquares<data_t>>(ptr->clone()));
     174           0 :         }
     175             : 
     176           8 :         if (preconditioner_)
     177           4 :             return new SQS(*fullProblem_, std::move(newsubsets), *preconditioner_,
     178           4 :                            momentumAcceleration_, epsilon_);
     179             : 
     180           4 :         return new SQS(*fullProblem_, std::move(newsubsets), momentumAcceleration_, epsilon_);
     181           4 :     }
     182             : 
     183             :     template <typename data_t>
     184             :     bool SQS<data_t>::isEqual(const Solver<data_t>& other) const
     185           8 :     {
     186           8 :         auto otherSQS = downcast_safe<SQS>(&other);
     187           8 :         if (!otherSQS)
     188           0 :             return false;
     189             : 
     190           8 :         if (epsilon_ != otherSQS->epsilon_)
     191           0 :             return false;
     192             : 
     193           8 :         if ((preconditioner_ && !otherSQS->preconditioner_)
     194           8 :             || (!preconditioner_ && otherSQS->preconditioner_))
     195           0 :             return false;
     196             : 
     197           8 :         if (preconditioner_ && otherSQS->preconditioner_)
     198           4 :             if (*preconditioner_ != *otherSQS->preconditioner_)
     199           0 :                 return false;
     200             : 
     201           8 :         return true;
     202           8 :     }
     203             : 
     204             :     // ------------------------------------------
     205             :     // explicit template instantiation
     206             :     template class SQS<float>;
     207             :     template class SQS<double>;
     208             : 
     209             : } // namespace elsa

Generated by: LCOV version 1.14