Line data Source code
1 : #include <Identity.h> 2 : #include <Scaling.h> 3 : #include "SQS.h" 4 : #include "Logger.h" 5 : #include "SubsetProblem.h" 6 : #include "TypeCasts.hpp" 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 : SQS<data_t>::SQS(const Problem<data_t>& problem, bool momentumAcceleration, data_t epsilon) 12 : : Solver<data_t>(), 13 : _problem(problem.clone()), 14 : _epsilon{epsilon}, 15 : _momentumAcceleration{momentumAcceleration} 16 14 : { 17 14 : if (is<SubsetProblem<data_t>>(problem)) { 18 4 : Logger::get("SQS")->info( 19 4 : "SQS did received a SubsetProblem, running in ordered subset mode"); 20 4 : _subsetMode = true; 21 10 : } else { 22 10 : Logger::get("SQS")->info("SQS did not receive a SubsetProblem, running in normal mode"); 23 10 : } 24 14 : } 25 : 26 : template <typename data_t> 27 : SQS<data_t>::SQS(const Problem<data_t>& problem, const LinearOperator<data_t>& preconditioner, 28 : bool momentumAcceleration, data_t epsilon) 29 : : Solver<data_t>(), 30 : _problem(problem.clone()), 31 : _epsilon{epsilon}, 32 : _preconditioner{preconditioner.clone()}, 33 : _momentumAcceleration{momentumAcceleration} 34 8 : { 35 8 : if (is<SubsetProblem<data_t>>(problem)) { 36 0 : Logger::get("SQS")->info( 37 0 : "SQS did received a SubsetProblem, running in ordered subset mode"); 38 0 : _subsetMode = true; 39 8 : } else { 40 8 : Logger::get("SQS")->info("SQS did not receive a SubsetProblem, running in normal mode"); 41 8 : } 42 : 43 : // check that preconditioner is compatible with problem 44 8 : if (_preconditioner->getDomainDescriptor().getNumberOfCoefficients() 45 8 : != _problem->getCurrentSolution().getSize() 46 8 : || _preconditioner->getRangeDescriptor().getNumberOfCoefficients() 47 8 : != _problem->getCurrentSolution().getSize()) { 48 0 : throw InvalidArgumentError("SQS: incorrect size of preconditioner"); 49 0 : } 50 8 : } 51 : 52 : template <typename data_t> 53 : DataContainer<data_t>& SQS<data_t>::solveImpl(index_t iterations) 54 11 : { 55 11 : if (iterations == 0) 56 0 : iterations = _defaultIterations; 57 : 58 11 : auto convergenceThreshold = _problem->getGradient().squaredL2Norm() * _epsilon * _epsilon; 59 : 60 11 : auto hessian = _problem->getHessian(); 61 : 62 11 : auto& solutionDesc = _problem->getCurrentSolution().getDataDescriptor(); 63 : 64 11 : auto ones = DataContainer<data_t>(solutionDesc); 65 11 : ones = 1; 66 11 : auto diagVector = hessian.apply(ones); 67 11 : diagVector = static_cast<data_t>(1.0) / diagVector; 68 11 : auto diag = Scaling<data_t>(hessian.getDomainDescriptor(), diagVector); 69 : 70 11 : data_t prevT = 1; 71 11 : data_t t = 1; 72 11 : data_t nextT = 0; 73 11 : auto& z = _problem->getCurrentSolution(); 74 11 : DataContainer<data_t> x = _problem->getCurrentSolution(); 75 11 : DataContainer<data_t> prevX = x; 76 11 : DataContainer<data_t> gradient(solutionDesc); 77 : 78 11 : index_t nSubsets = 1; 79 11 : if (_subsetMode) { 80 2 : const auto& subsetProblem = static_cast<const SubsetProblem<data_t>*>(_problem.get()); 81 2 : nSubsets = subsetProblem->getNumberOfSubsets(); 82 2 : } 83 : 84 915 : for (index_t i = 0; i < iterations; i++) { 85 909 : Logger::get("SQS")->info("iteration {} of {}", i + 1, iterations); 86 : 87 1873 : for (index_t m = 0; m < nSubsets; m++) { 88 969 : if (_subsetMode) { 89 80 : gradient = 90 80 : static_cast<SubsetProblem<data_t>*>(_problem.get())->getSubsetGradient(m); 91 889 : } else { 92 889 : gradient = _problem->getGradient(); 93 889 : } 94 : 95 969 : if (_preconditioner) 96 837 : gradient = _preconditioner->apply(gradient); 97 : 98 : // TODO: element wise relu 99 969 : if (_momentumAcceleration) { 100 969 : nextT = as<data_t>(1) 101 969 : + std::sqrt(as<data_t>(1) + as<data_t>(4) * t * t) / as<data_t>(2); 102 : 103 969 : x = z - nSubsets * diag.apply(gradient); 104 969 : z = x + prevT / nextT * (x - prevX); 105 969 : } else { 106 0 : z = z - nSubsets * diag.apply(gradient); 107 0 : } 108 : 109 : // if the gradient is too small we stop 110 969 : if (gradient.squaredL2Norm() <= convergenceThreshold) { 111 5 : if (!_subsetMode 112 5 : || _problem->getGradient().squaredL2Norm() <= convergenceThreshold) { 113 5 : Logger::get("SQS")->info("SUCCESS: Reached convergence at {}/{} iteration", 114 5 : i + 1, iterations); 115 : 116 : // TODO: make return more sane 117 5 : if (_momentumAcceleration) { 118 5 : z = x; 119 5 : } 120 5 : return _problem->getCurrentSolution(); 121 5 : } 122 964 : } 123 : 124 964 : if (_momentumAcceleration) { 125 964 : prevT = t; 126 964 : t = nextT; 127 964 : prevX = x; 128 964 : } 129 964 : } 130 909 : } 131 : 132 11 : Logger::get("SQS")->warn("Failed to reach convergence at {} iterations", iterations); 133 : 134 : // TODO: make return more sane 135 6 : if (_momentumAcceleration) { 136 6 : z = x; 137 6 : } 138 6 : return _problem->getCurrentSolution(); 139 11 : } 140 : 141 : template <typename data_t> 142 : SQS<data_t>* SQS<data_t>::cloneImpl() const 143 11 : { 144 11 : if (_preconditioner) 145 4 : return new SQS(*_problem, *_preconditioner, _momentumAcceleration, _epsilon); 146 : 147 7 : return new SQS(*_problem, _momentumAcceleration, _epsilon); 148 7 : } 149 : 150 : template <typename data_t> 151 : bool SQS<data_t>::isEqual(const Solver<data_t>& other) const 152 11 : { 153 11 : auto otherSQS = downcast_safe<SQS>(&other); 154 11 : if (!otherSQS) 155 0 : return false; 156 : 157 11 : if (_epsilon != otherSQS->_epsilon) 158 0 : return false; 159 : 160 11 : if ((_preconditioner && !otherSQS->_preconditioner) 161 11 : || (!_preconditioner && otherSQS->_preconditioner)) 162 0 : return false; 163 : 164 11 : if (_preconditioner && otherSQS->_preconditioner) 165 4 : if (*_preconditioner != *otherSQS->_preconditioner) 166 0 : return false; 167 : 168 11 : return true; 169 11 : } 170 : 171 : // ------------------------------------------ 172 : // explicit template instantiation 173 : template class SQS<float>; 174 : template class SQS<double>; 175 : 176 : } // namespace elsa