Line data Source code
1 : #include "Problem.h" 2 : #include "Scaling.h" 3 : #include "Logger.h" 4 : #include "Timer.h" 5 : 6 : namespace elsa 7 : { 8 : template <typename data_t> 9 0 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 10 : const std::vector<RegularizationTerm<data_t>>& regTerms, 11 : const DataContainer<data_t>& x0, 12 : const std::optional<data_t> lipschitzConstant) 13 0 : : _dataTerm{dataTerm.clone()}, 14 0 : _regTerms{regTerms}, 15 0 : _currentSolution{x0}, 16 0 : _lipschitzConstant{lipschitzConstant} 17 : { 18 : // sanity checks 19 0 : if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients() 20 0 : != this->_currentSolution.getSize()) 21 0 : throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match"); 22 0 : } 23 : 24 : template <typename data_t> 25 0 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 26 : const std::vector<RegularizationTerm<data_t>>& regTerms, 27 : const std::optional<data_t> lipschitzConstant) 28 0 : : _dataTerm{dataTerm.clone()}, 29 0 : _regTerms{regTerms}, 30 0 : _currentSolution{dataTerm.getDomainDescriptor()}, 31 0 : _lipschitzConstant{lipschitzConstant} 32 : { 33 0 : _currentSolution = 0; 34 0 : } 35 : 36 : template <typename data_t> 37 0 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 38 : const RegularizationTerm<data_t>& regTerm, 39 : const DataContainer<data_t>& x0, 40 : const std::optional<data_t> lipschitzConstant) 41 0 : : _dataTerm{dataTerm.clone()}, 42 0 : _regTerms{regTerm}, 43 0 : _currentSolution{x0}, 44 0 : _lipschitzConstant{lipschitzConstant} 45 : { 46 : // sanity checks 47 0 : if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients() 48 0 : != this->_currentSolution.getSize()) 49 0 : throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match"); 50 0 : } 51 : 52 : template <typename data_t> 53 0 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 54 : const RegularizationTerm<data_t>& regTerm, 55 : const std::optional<data_t> lipschitzConstant) 56 0 : : _dataTerm{dataTerm.clone()}, 57 0 : _regTerms{regTerm}, 58 0 : _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType}, 59 0 : _lipschitzConstant{lipschitzConstant} 60 : { 61 0 : _currentSolution = 0; 62 0 : } 63 : 64 : template <typename data_t> 65 0 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, const DataContainer<data_t>& x0, 66 : const std::optional<data_t> lipschitzConstant) 67 0 : : _dataTerm{dataTerm.clone()}, _currentSolution{x0}, _lipschitzConstant{lipschitzConstant} 68 : { 69 : // sanity check 70 0 : if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients() 71 0 : != this->_currentSolution.getSize()) 72 0 : throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match"); 73 0 : } 74 : 75 : template <typename data_t> 76 0 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 77 : const std::optional<data_t> lipschitzConstant) 78 0 : : _dataTerm{dataTerm.clone()}, 79 0 : _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType}, 80 0 : _lipschitzConstant{lipschitzConstant} 81 : { 82 0 : _currentSolution = 0; 83 0 : } 84 : 85 : template <typename data_t> 86 0 : Problem<data_t>::Problem(const Problem<data_t>& problem) 87 : : Cloneable<Problem<data_t>>(), 88 0 : _dataTerm{problem._dataTerm->clone()}, 89 0 : _regTerms{problem._regTerms}, 90 0 : _currentSolution{problem._currentSolution}, 91 0 : _lipschitzConstant{problem._lipschitzConstant} 92 : { 93 0 : } 94 : 95 : template <typename data_t> 96 0 : const Functional<data_t>& Problem<data_t>::getDataTerm() const 97 : { 98 0 : return *_dataTerm; 99 : } 100 : 101 : template <typename data_t> 102 0 : const std::vector<RegularizationTerm<data_t>>& Problem<data_t>::getRegularizationTerms() const 103 : { 104 0 : return _regTerms; 105 : } 106 : 107 : template <typename data_t> 108 0 : const DataContainer<data_t>& Problem<data_t>::getCurrentSolution() const 109 : { 110 0 : return _currentSolution; 111 : } 112 : 113 : template <typename data_t> 114 0 : DataContainer<data_t>& Problem<data_t>::getCurrentSolution() 115 : { 116 0 : return _currentSolution; 117 : } 118 : 119 : template <typename data_t> 120 0 : data_t Problem<data_t>::evaluateImpl() 121 : { 122 0 : data_t result = _dataTerm->evaluate(_currentSolution); 123 : 124 0 : for (auto& regTerm : _regTerms) 125 0 : result += regTerm.getWeight() * regTerm.getFunctional().evaluate(_currentSolution); 126 : 127 0 : return result; 128 : } 129 : 130 : template <typename data_t> 131 0 : void Problem<data_t>::getGradientImpl(DataContainer<data_t>& result) 132 : { 133 0 : _dataTerm->getGradient(_currentSolution, result); 134 : 135 0 : for (auto& regTerm : _regTerms) 136 0 : result += regTerm.getWeight() * regTerm.getFunctional().getGradient(_currentSolution); 137 0 : } 138 : 139 : template <typename data_t> 140 0 : LinearOperator<data_t> Problem<data_t>::getHessianImpl() const 141 : { 142 0 : auto hessian = _dataTerm->getHessian(_currentSolution); 143 : 144 0 : for (auto& regTerm : _regTerms) { 145 0 : Scaling weight(_currentSolution.getDataDescriptor(), regTerm.getWeight()); 146 0 : hessian = hessian + (weight * regTerm.getFunctional().getHessian(_currentSolution)); 147 : } 148 : 149 0 : return hessian; 150 0 : } 151 : 152 : template <typename data_t> 153 0 : data_t Problem<data_t>::getLipschitzConstantImpl(index_t nIterations) const 154 : { 155 0 : Timer guard("Problem", "Calculating Lipschitz constant"); 156 0 : Logger::get("Problem")->info("Calculating Lipschitz constant"); 157 : 158 0 : if (_lipschitzConstant.has_value()) { 159 0 : return _lipschitzConstant.value(); 160 : } 161 : // compute the Lipschitz Constant as the largest eigenvalue of the Hessian 162 0 : const auto hessian = getHessian(); 163 0 : DataContainer<data_t> dcB(hessian.getDomainDescriptor()); 164 0 : dcB = 1; 165 0 : for (index_t i = 0; i < nIterations; i++) { 166 0 : dcB = hessian.apply(dcB); 167 0 : dcB = dcB / dcB.l2Norm(); 168 : } 169 : 170 0 : return dcB.dot(hessian.apply(dcB)) / dcB.l2Norm(); 171 0 : } 172 : 173 : template <typename data_t> 174 0 : Problem<data_t>* Problem<data_t>::cloneImpl() const 175 : { 176 0 : return new Problem(*this); 177 : } 178 : 179 : template <typename data_t> 180 0 : bool Problem<data_t>::isEqual(const Problem<data_t>& other) const 181 : { 182 0 : if (typeid(*this) != typeid(other)) 183 0 : return false; 184 : 185 0 : if (_currentSolution != other._currentSolution) 186 0 : return false; 187 : 188 0 : if (*_dataTerm != *other._dataTerm) 189 0 : return false; 190 : 191 0 : if (_regTerms.size() != other._regTerms.size()) 192 0 : return false; 193 : 194 0 : for (std::size_t i = 0; i < _regTerms.size(); ++i) 195 0 : if (_regTerms.at(i) != other._regTerms.at(i)) 196 0 : return false; 197 : 198 0 : return true; 199 : } 200 : 201 : template <typename data_t> 202 0 : data_t Problem<data_t>::evaluate() 203 : { 204 0 : return evaluateImpl(); 205 : } 206 : 207 : template <typename data_t> 208 0 : DataContainer<data_t> Problem<data_t>::getGradient() 209 : { 210 0 : DataContainer<data_t> result(_currentSolution.getDataDescriptor(), 211 : _currentSolution.getDataHandlerType()); 212 0 : getGradient(result); 213 0 : return result; 214 0 : } 215 : 216 : template <typename data_t> 217 0 : void Problem<data_t>::getGradient(DataContainer<data_t>& result) 218 : { 219 0 : getGradientImpl(result); 220 0 : } 221 : 222 : template <typename data_t> 223 0 : LinearOperator<data_t> Problem<data_t>::getHessian() const 224 : { 225 0 : return getHessianImpl(); 226 : } 227 : 228 : template <typename data_t> 229 0 : data_t Problem<data_t>::getLipschitzConstant(index_t nIterations) const 230 : { 231 0 : return getLipschitzConstantImpl(nIterations); 232 : } 233 : 234 : // ------------------------------------------ 235 : // explicit template instantiation 236 : template class Problem<float>; 237 : 238 : template class Problem<double>; 239 : 240 : template class Problem<complex<float>>; 241 : 242 : template class Problem<complex<double>>; 243 : 244 : } // namespace elsa