LCOV - code coverage report
Current view: top level - problems - Problem.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 117 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 88 0.0 %

          Line data    Source code
       1             : #include "Problem.h"
       2             : #include "Scaling.h"
       3             : #include "Logger.h"
       4             : #include "Timer.h"
       5             : 
       6             : namespace elsa
       7             : {
       8             :     template <typename data_t>
       9           0 :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      10             :                              const std::vector<RegularizationTerm<data_t>>& regTerms,
      11             :                              const DataContainer<data_t>& x0,
      12             :                              const std::optional<data_t> lipschitzConstant)
      13           0 :         : _dataTerm{dataTerm.clone()},
      14           0 :           _regTerms{regTerms},
      15           0 :           _currentSolution{x0},
      16           0 :           _lipschitzConstant{lipschitzConstant}
      17             :     {
      18             :         // sanity checks
      19           0 :         if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
      20           0 :             != this->_currentSolution.getSize())
      21           0 :             throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
      22           0 :     }
      23             : 
      24             :     template <typename data_t>
      25           0 :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      26             :                              const std::vector<RegularizationTerm<data_t>>& regTerms,
      27             :                              const std::optional<data_t> lipschitzConstant)
      28           0 :         : _dataTerm{dataTerm.clone()},
      29           0 :           _regTerms{regTerms},
      30           0 :           _currentSolution{dataTerm.getDomainDescriptor()},
      31           0 :           _lipschitzConstant{lipschitzConstant}
      32             :     {
      33           0 :         _currentSolution = 0;
      34           0 :     }
      35             : 
      36             :     template <typename data_t>
      37           0 :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      38             :                              const RegularizationTerm<data_t>& regTerm,
      39             :                              const DataContainer<data_t>& x0,
      40             :                              const std::optional<data_t> lipschitzConstant)
      41           0 :         : _dataTerm{dataTerm.clone()},
      42           0 :           _regTerms{regTerm},
      43           0 :           _currentSolution{x0},
      44           0 :           _lipschitzConstant{lipschitzConstant}
      45             :     {
      46             :         // sanity checks
      47           0 :         if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
      48           0 :             != this->_currentSolution.getSize())
      49           0 :             throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
      50           0 :     }
      51             : 
      52             :     template <typename data_t>
      53           0 :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      54             :                              const RegularizationTerm<data_t>& regTerm,
      55             :                              const std::optional<data_t> lipschitzConstant)
      56           0 :         : _dataTerm{dataTerm.clone()},
      57           0 :           _regTerms{regTerm},
      58           0 :           _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType},
      59           0 :           _lipschitzConstant{lipschitzConstant}
      60             :     {
      61           0 :         _currentSolution = 0;
      62           0 :     }
      63             : 
      64             :     template <typename data_t>
      65           0 :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm, const DataContainer<data_t>& x0,
      66             :                              const std::optional<data_t> lipschitzConstant)
      67           0 :         : _dataTerm{dataTerm.clone()}, _currentSolution{x0}, _lipschitzConstant{lipschitzConstant}
      68             :     {
      69             :         // sanity check
      70           0 :         if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
      71           0 :             != this->_currentSolution.getSize())
      72           0 :             throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
      73           0 :     }
      74             : 
      75             :     template <typename data_t>
      76           0 :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      77             :                              const std::optional<data_t> lipschitzConstant)
      78           0 :         : _dataTerm{dataTerm.clone()},
      79           0 :           _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType},
      80           0 :           _lipschitzConstant{lipschitzConstant}
      81             :     {
      82           0 :         _currentSolution = 0;
      83           0 :     }
      84             : 
      85             :     template <typename data_t>
      86           0 :     Problem<data_t>::Problem(const Problem<data_t>& problem)
      87             :         : Cloneable<Problem<data_t>>(),
      88           0 :           _dataTerm{problem._dataTerm->clone()},
      89           0 :           _regTerms{problem._regTerms},
      90           0 :           _currentSolution{problem._currentSolution},
      91           0 :           _lipschitzConstant{problem._lipschitzConstant}
      92             :     {
      93           0 :     }
      94             : 
      95             :     template <typename data_t>
      96           0 :     const Functional<data_t>& Problem<data_t>::getDataTerm() const
      97             :     {
      98           0 :         return *_dataTerm;
      99             :     }
     100             : 
     101             :     template <typename data_t>
     102           0 :     const std::vector<RegularizationTerm<data_t>>& Problem<data_t>::getRegularizationTerms() const
     103             :     {
     104           0 :         return _regTerms;
     105             :     }
     106             : 
     107             :     template <typename data_t>
     108           0 :     const DataContainer<data_t>& Problem<data_t>::getCurrentSolution() const
     109             :     {
     110           0 :         return _currentSolution;
     111             :     }
     112             : 
     113             :     template <typename data_t>
     114           0 :     DataContainer<data_t>& Problem<data_t>::getCurrentSolution()
     115             :     {
     116           0 :         return _currentSolution;
     117             :     }
     118             : 
     119             :     template <typename data_t>
     120           0 :     data_t Problem<data_t>::evaluateImpl()
     121             :     {
     122           0 :         data_t result = _dataTerm->evaluate(_currentSolution);
     123             : 
     124           0 :         for (auto& regTerm : _regTerms)
     125           0 :             result += regTerm.getWeight() * regTerm.getFunctional().evaluate(_currentSolution);
     126             : 
     127           0 :         return result;
     128             :     }
     129             : 
     130             :     template <typename data_t>
     131           0 :     void Problem<data_t>::getGradientImpl(DataContainer<data_t>& result)
     132             :     {
     133           0 :         _dataTerm->getGradient(_currentSolution, result);
     134             : 
     135           0 :         for (auto& regTerm : _regTerms)
     136           0 :             result += regTerm.getWeight() * regTerm.getFunctional().getGradient(_currentSolution);
     137           0 :     }
     138             : 
     139             :     template <typename data_t>
     140           0 :     LinearOperator<data_t> Problem<data_t>::getHessianImpl() const
     141             :     {
     142           0 :         auto hessian = _dataTerm->getHessian(_currentSolution);
     143             : 
     144           0 :         for (auto& regTerm : _regTerms) {
     145           0 :             Scaling weight(_currentSolution.getDataDescriptor(), regTerm.getWeight());
     146           0 :             hessian = hessian + (weight * regTerm.getFunctional().getHessian(_currentSolution));
     147             :         }
     148             : 
     149           0 :         return hessian;
     150           0 :     }
     151             : 
     152             :     template <typename data_t>
     153           0 :     data_t Problem<data_t>::getLipschitzConstantImpl(index_t nIterations) const
     154             :     {
     155           0 :         Timer guard("Problem", "Calculating Lipschitz constant");
     156           0 :         Logger::get("Problem")->info("Calculating Lipschitz constant");
     157             : 
     158           0 :         if (_lipschitzConstant.has_value()) {
     159           0 :             return _lipschitzConstant.value();
     160             :         }
     161             :         // compute the Lipschitz Constant as the largest eigenvalue of the Hessian
     162           0 :         const auto hessian = getHessian();
     163           0 :         DataContainer<data_t> dcB(hessian.getDomainDescriptor());
     164           0 :         dcB = 1;
     165           0 :         for (index_t i = 0; i < nIterations; i++) {
     166           0 :             dcB = hessian.apply(dcB);
     167           0 :             dcB = dcB / dcB.l2Norm();
     168             :         }
     169             : 
     170           0 :         return dcB.dot(hessian.apply(dcB)) / dcB.l2Norm();
     171           0 :     }
     172             : 
     173             :     template <typename data_t>
     174           0 :     Problem<data_t>* Problem<data_t>::cloneImpl() const
     175             :     {
     176           0 :         return new Problem(*this);
     177             :     }
     178             : 
     179             :     template <typename data_t>
     180           0 :     bool Problem<data_t>::isEqual(const Problem<data_t>& other) const
     181             :     {
     182           0 :         if (typeid(*this) != typeid(other))
     183           0 :             return false;
     184             : 
     185           0 :         if (_currentSolution != other._currentSolution)
     186           0 :             return false;
     187             : 
     188           0 :         if (*_dataTerm != *other._dataTerm)
     189           0 :             return false;
     190             : 
     191           0 :         if (_regTerms.size() != other._regTerms.size())
     192           0 :             return false;
     193             : 
     194           0 :         for (std::size_t i = 0; i < _regTerms.size(); ++i)
     195           0 :             if (_regTerms.at(i) != other._regTerms.at(i))
     196           0 :                 return false;
     197             : 
     198           0 :         return true;
     199             :     }
     200             : 
     201             :     template <typename data_t>
     202           0 :     data_t Problem<data_t>::evaluate()
     203             :     {
     204           0 :         return evaluateImpl();
     205             :     }
     206             : 
     207             :     template <typename data_t>
     208           0 :     DataContainer<data_t> Problem<data_t>::getGradient()
     209             :     {
     210           0 :         DataContainer<data_t> result(_currentSolution.getDataDescriptor(),
     211             :                                      _currentSolution.getDataHandlerType());
     212           0 :         getGradient(result);
     213           0 :         return result;
     214           0 :     }
     215             : 
     216             :     template <typename data_t>
     217           0 :     void Problem<data_t>::getGradient(DataContainer<data_t>& result)
     218             :     {
     219           0 :         getGradientImpl(result);
     220           0 :     }
     221             : 
     222             :     template <typename data_t>
     223           0 :     LinearOperator<data_t> Problem<data_t>::getHessian() const
     224             :     {
     225           0 :         return getHessianImpl();
     226             :     }
     227             : 
     228             :     template <typename data_t>
     229           0 :     data_t Problem<data_t>::getLipschitzConstant(index_t nIterations) const
     230             :     {
     231           0 :         return getLipschitzConstantImpl(nIterations);
     232             :     }
     233             : 
     234             :     // ------------------------------------------
     235             :     // explicit template instantiation
     236             :     template class Problem<float>;
     237             : 
     238             :     template class Problem<double>;
     239             : 
     240             :     template class Problem<complex<float>>;
     241             : 
     242             :     template class Problem<complex<double>>;
     243             : 
     244             : } // namespace elsa

Generated by: LCOV version 1.14