Line data Source code
1 : #include <Identity.h> 2 : #include <Scaling.h> 3 : #include "SQS.h" 4 : #include "Logger.h" 5 : #include "SubsetProblem.h" 6 : #include "TypeCasts.hpp" 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 0 : SQS<data_t>::SQS(const Problem<data_t>& problem, bool momentumAcceleration, data_t epsilon) 12 0 : : Solver<data_t>(problem), _epsilon{epsilon}, _momentumAcceleration{momentumAcceleration} 13 : { 14 0 : if (is<SubsetProblem<data_t>>(problem)) { 15 0 : Logger::get("SQS")->info( 16 : "SQS did received a SubsetProblem, running in ordered subset mode"); 17 0 : _subsetMode = true; 18 : } else { 19 0 : Logger::get("SQS")->info("SQS did not receive a SubsetProblem, running in normal mode"); 20 : } 21 0 : } 22 : 23 : template <typename data_t> 24 0 : SQS<data_t>::SQS(const Problem<data_t>& problem, const LinearOperator<data_t>& preconditioner, 25 : bool momentumAcceleration, data_t epsilon) 26 : : Solver<data_t>(problem), 27 0 : _epsilon{epsilon}, 28 0 : _preconditioner{preconditioner.clone()}, 29 0 : _momentumAcceleration{momentumAcceleration} 30 : { 31 0 : if (is<SubsetProblem<data_t>>(problem)) { 32 0 : Logger::get("SQS")->info( 33 : "SQS did received a SubsetProblem, running in ordered subset mode"); 34 0 : _subsetMode = true; 35 : } else { 36 0 : Logger::get("SQS")->info("SQS did not receive a SubsetProblem, running in normal mode"); 37 : } 38 : 39 : // check that preconditioner is compatible with problem 40 0 : if (_preconditioner->getDomainDescriptor().getNumberOfCoefficients() 41 0 : != _problem->getCurrentSolution().getSize() 42 0 : || _preconditioner->getRangeDescriptor().getNumberOfCoefficients() 43 0 : != _problem->getCurrentSolution().getSize()) { 44 0 : throw InvalidArgumentError("SQS: incorrect size of preconditioner"); 45 : } 46 0 : } 47 : 48 : template <typename data_t> 49 0 : DataContainer<data_t>& SQS<data_t>::solveImpl(index_t iterations) 50 : { 51 0 : if (iterations == 0) 52 0 : iterations = _defaultIterations; 53 : 54 0 : auto convergenceThreshold = _problem->getGradient().squaredL2Norm() * _epsilon * _epsilon; 55 : 56 0 : auto hessian = _problem->getHessian(); 57 : 58 0 : auto ones = DataContainer<data_t>(getCurrentSolution().getDataDescriptor()); 59 0 : ones = 1; 60 0 : auto diagVector = hessian.apply(ones); 61 0 : diagVector = static_cast<data_t>(1.0) / diagVector; 62 0 : auto diag = Scaling<data_t>(hessian.getDomainDescriptor(), diagVector); 63 : 64 0 : data_t prevT = 1; 65 0 : data_t t = 1; 66 0 : data_t nextT = 0; 67 0 : auto& z = getCurrentSolution(); 68 0 : DataContainer<data_t> x = DataContainer<data_t>(getCurrentSolution()); 69 0 : DataContainer<data_t> prevX = x; 70 0 : DataContainer<data_t> gradient(getCurrentSolution().getDataDescriptor()); 71 : 72 0 : index_t nSubsets = 1; 73 0 : if (_subsetMode) { 74 0 : const auto& subsetProblem = static_cast<const SubsetProblem<data_t>*>(_problem.get()); 75 0 : nSubsets = subsetProblem->getNumberOfSubsets(); 76 : } 77 : 78 0 : for (index_t i = 0; i < iterations; i++) { 79 0 : Logger::get("SQS")->info("iteration {} of {}", i + 1, iterations); 80 : 81 0 : for (index_t m = 0; m < nSubsets; m++) { 82 0 : if (_subsetMode) { 83 0 : gradient = 84 0 : static_cast<SubsetProblem<data_t>*>(_problem.get())->getSubsetGradient(m); 85 : } else { 86 0 : gradient = _problem->getGradient(); 87 : } 88 : 89 0 : if (_preconditioner) 90 0 : gradient = _preconditioner->apply(gradient); 91 : 92 : // TODO: element wise relu 93 0 : if (_momentumAcceleration) { 94 0 : nextT = as<data_t>(1) 95 0 : + std::sqrt(as<data_t>(1) + as<data_t>(4) * t * t) / as<data_t>(2); 96 : 97 0 : x = z - nSubsets * diag.apply(gradient); 98 0 : z = x + prevT / nextT * (x - prevX); 99 : } else { 100 0 : z = z - nSubsets * diag.apply(gradient); 101 : } 102 : 103 : // if the gradient is too small we stop 104 0 : if (gradient.squaredL2Norm() <= convergenceThreshold) { 105 0 : if (!_subsetMode 106 0 : || _problem->getGradient().squaredL2Norm() <= convergenceThreshold) { 107 0 : Logger::get("SQS")->info("SUCCESS: Reached convergence at {}/{} iteration", 108 0 : i + 1, iterations); 109 : 110 : // TODO: make return more sane 111 0 : if (_momentumAcceleration) { 112 0 : z = x; 113 : } 114 0 : return getCurrentSolution(); 115 : } 116 : } 117 : 118 0 : if (_momentumAcceleration) { 119 0 : prevT = t; 120 0 : t = nextT; 121 0 : prevX = x; 122 : } 123 : } 124 : } 125 : 126 0 : Logger::get("SQS")->warn("Failed to reach convergence at {} iterations", iterations); 127 : 128 : // TODO: make return more sane 129 0 : if (_momentumAcceleration) { 130 0 : z = x; 131 : } 132 0 : return getCurrentSolution(); 133 0 : } 134 : 135 : template <typename data_t> 136 0 : SQS<data_t>* SQS<data_t>::cloneImpl() const 137 : { 138 0 : if (_preconditioner) 139 0 : return new SQS(*_problem, *_preconditioner, _momentumAcceleration, _epsilon); 140 : 141 0 : return new SQS(*_problem, _momentumAcceleration, _epsilon); 142 : } 143 : 144 : template <typename data_t> 145 0 : bool SQS<data_t>::isEqual(const Solver<data_t>& other) const 146 : { 147 0 : if (!Solver<data_t>::isEqual(other)) 148 0 : return false; 149 : 150 0 : auto otherSQS = downcast_safe<SQS>(&other); 151 0 : if (!otherSQS) 152 0 : return false; 153 : 154 0 : if (_epsilon != otherSQS->_epsilon) 155 0 : return false; 156 : 157 0 : if ((_preconditioner && !otherSQS->_preconditioner) 158 0 : || (!_preconditioner && otherSQS->_preconditioner)) 159 0 : return false; 160 : 161 0 : if (_preconditioner && otherSQS->_preconditioner) 162 0 : if (*_preconditioner != *otherSQS->_preconditioner) 163 0 : return false; 164 : 165 0 : return true; 166 : } 167 : 168 : // ------------------------------------------ 169 : // explicit template instantiation 170 : template class SQS<float>; 171 : template class SQS<double>; 172 : 173 : } // namespace elsa