Line data Source code
1 : #include "Problem.h" 2 : #include "Scaling.h" 3 : #include "Logger.h" 4 : #include "Timer.h" 5 : 6 : namespace elsa 7 : { 8 : template <typename data_t> 9 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 10 : const std::vector<RegularizationTerm<data_t>>& regTerms, 11 : const DataContainer<data_t>& x0, 12 : const std::optional<data_t> lipschitzConstant) 13 : : _dataTerm{dataTerm.clone()}, 14 : _regTerms{regTerms}, 15 : _currentSolution{x0}, 16 : _lipschitzConstant{lipschitzConstant} 17 64 : { 18 : // sanity checks 19 64 : if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients() 20 64 : != this->_currentSolution.getSize()) 21 0 : throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match"); 22 64 : } 23 : 24 : template <typename data_t> 25 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 26 : const std::vector<RegularizationTerm<data_t>>& regTerms, 27 : const std::optional<data_t> lipschitzConstant) 28 : : _dataTerm{dataTerm.clone()}, 29 : _regTerms{regTerms}, 30 : _currentSolution{dataTerm.getDomainDescriptor()}, 31 : _lipschitzConstant{lipschitzConstant} 32 19 : { 33 19 : _currentSolution = 0; 34 19 : } 35 : 36 : template <typename data_t> 37 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 38 : const RegularizationTerm<data_t>& regTerm, 39 : const DataContainer<data_t>& x0, 40 : const std::optional<data_t> lipschitzConstant) 41 : : _dataTerm{dataTerm.clone()}, 42 : _regTerms{regTerm}, 43 : _currentSolution{x0}, 44 : _lipschitzConstant{lipschitzConstant} 45 42 : { 46 : // sanity checks 47 42 : if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients() 48 42 : != this->_currentSolution.getSize()) 49 0 : throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match"); 50 42 : } 51 : 52 : template <typename data_t> 53 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 54 : const RegularizationTerm<data_t>& regTerm, 55 : const std::optional<data_t> lipschitzConstant) 56 : : _dataTerm{dataTerm.clone()}, 57 : _regTerms{regTerm}, 58 : _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType}, 59 : _lipschitzConstant{lipschitzConstant} 60 39 : { 61 39 : _currentSolution = 0; 62 39 : } 63 : 64 : template <typename data_t> 65 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, const DataContainer<data_t>& x0, 66 : const std::optional<data_t> lipschitzConstant) 67 : : _dataTerm{dataTerm.clone()}, _currentSolution{x0}, _lipschitzConstant{lipschitzConstant} 68 114 : { 69 : // sanity check 70 114 : if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients() 71 114 : != this->_currentSolution.getSize()) 72 0 : throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match"); 73 114 : } 74 : 75 : template <typename data_t> 76 : Problem<data_t>::Problem(const Functional<data_t>& dataTerm, 77 : const std::optional<data_t> lipschitzConstant) 78 : : _dataTerm{dataTerm.clone()}, 79 : _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType}, 80 : _lipschitzConstant{lipschitzConstant} 81 188 : { 82 188 : _currentSolution = 0; 83 188 : } 84 : 85 : template <typename data_t> 86 : Problem<data_t>::Problem(const Problem<data_t>& problem) 87 : : Cloneable<Problem<data_t>>(), 88 : _dataTerm{problem._dataTerm->clone()}, 89 : _regTerms{problem._regTerms}, 90 : _currentSolution{problem._currentSolution}, 91 : _lipschitzConstant{problem._lipschitzConstant} 92 213 : { 93 213 : } 94 : 95 : template <typename data_t> 96 : const Functional<data_t>& Problem<data_t>::getDataTerm() const 97 260 : { 98 260 : return *_dataTerm; 99 260 : } 100 : 101 : template <typename data_t> 102 : const std::vector<RegularizationTerm<data_t>>& Problem<data_t>::getRegularizationTerms() const 103 266 : { 104 266 : return _regTerms; 105 266 : } 106 : 107 : template <typename data_t> 108 : const DataContainer<data_t>& Problem<data_t>::getCurrentSolution() const 109 120 : { 110 120 : return _currentSolution; 111 120 : } 112 : 113 : template <typename data_t> 114 : DataContainer<data_t>& Problem<data_t>::getCurrentSolution() 115 10246 : { 116 10246 : return _currentSolution; 117 10246 : } 118 : 119 : template <typename data_t> 120 : data_t Problem<data_t>::evaluateImpl() 121 58 : { 122 58 : data_t result = _dataTerm->evaluate(_currentSolution); 123 : 124 58 : for (auto& regTerm : _regTerms) 125 22 : result += regTerm.getWeight() * regTerm.getFunctional().evaluate(_currentSolution); 126 : 127 58 : return result; 128 58 : } 129 : 130 : template <typename data_t> 131 : void Problem<data_t>::getGradientImpl(DataContainer<data_t>& result) 132 10997 : { 133 10997 : _dataTerm->getGradient(_currentSolution, result); 134 : 135 10997 : for (auto& regTerm : _regTerms) 136 42 : result += regTerm.getWeight() * regTerm.getFunctional().getGradient(_currentSolution); 137 10997 : } 138 : 139 : template <typename data_t> 140 : LinearOperator<data_t> Problem<data_t>::getHessianImpl() const 141 87 : { 142 87 : auto hessian = _dataTerm->getHessian(_currentSolution); 143 : 144 87 : for (auto& regTerm : _regTerms) { 145 26 : Scaling weight(_currentSolution.getDataDescriptor(), regTerm.getWeight()); 146 26 : hessian = hessian + (weight * regTerm.getFunctional().getHessian(_currentSolution)); 147 26 : } 148 : 149 87 : return hessian; 150 87 : } 151 : 152 : template <typename data_t> 153 : data_t Problem<data_t>::getLipschitzConstantImpl(index_t nIterations) const 154 34 : { 155 34 : Timer guard("Problem", "Calculating Lipschitz constant"); 156 34 : Logger::get("Problem")->info("Calculating Lipschitz constant"); 157 : 158 34 : if (_lipschitzConstant.has_value()) { 159 8 : return _lipschitzConstant.value(); 160 8 : } 161 : // compute the Lipschitz Constant as the largest eigenvalue of the Hessian 162 26 : const auto hessian = getHessian(); 163 26 : DataContainer<data_t> dcB(hessian.getDomainDescriptor()); 164 26 : dcB = 1; 165 741 : for (index_t i = 0; i < nIterations; i++) { 166 715 : dcB = hessian.apply(dcB); 167 715 : dcB = dcB / dcB.l2Norm(); 168 715 : } 169 : 170 26 : return dcB.dot(hessian.apply(dcB)) / dcB.l2Norm(); 171 26 : } 172 : 173 : template <typename data_t> 174 : Problem<data_t>* Problem<data_t>::cloneImpl() const 175 12 : { 176 12 : return new Problem(*this); 177 12 : } 178 : 179 : template <typename data_t> 180 : bool Problem<data_t>::isEqual(const Problem<data_t>& other) const 181 74 : { 182 74 : if (typeid(*this) != typeid(other)) 183 16 : return false; 184 : 185 58 : if (_currentSolution != other._currentSolution) 186 0 : return false; 187 : 188 58 : if (*_dataTerm != *other._dataTerm) 189 0 : return false; 190 : 191 58 : if (_regTerms.size() != other._regTerms.size()) 192 0 : return false; 193 : 194 84 : for (std::size_t i = 0; i < _regTerms.size(); ++i) 195 26 : if (_regTerms.at(i) != other._regTerms.at(i)) 196 0 : return false; 197 : 198 58 : return true; 199 58 : } 200 : 201 : template <typename data_t> 202 : data_t Problem<data_t>::evaluate() 203 62 : { 204 62 : return evaluateImpl(); 205 62 : } 206 : 207 : template <typename data_t> 208 : DataContainer<data_t> Problem<data_t>::getGradient() 209 11001 : { 210 11001 : DataContainer<data_t> result(_currentSolution.getDataDescriptor(), 211 11001 : _currentSolution.getDataHandlerType()); 212 11001 : getGradient(result); 213 11001 : return result; 214 11001 : } 215 : 216 : template <typename data_t> 217 : void Problem<data_t>::getGradient(DataContainer<data_t>& result) 218 11001 : { 219 11001 : getGradientImpl(result); 220 11001 : } 221 : 222 : template <typename data_t> 223 : LinearOperator<data_t> Problem<data_t>::getHessian() const 224 91 : { 225 91 : return getHessianImpl(); 226 91 : } 227 : 228 : template <typename data_t> 229 : data_t Problem<data_t>::getLipschitzConstant(index_t nIterations) const 230 46 : { 231 46 : return getLipschitzConstantImpl(nIterations); 232 46 : } 233 : 234 : // ------------------------------------------ 235 : // explicit template instantiation 236 : template class Problem<float>; 237 : 238 : template class Problem<double>; 239 : 240 : template class Problem<complex<float>>; 241 : 242 : template class Problem<complex<double>>; 243 : 244 : } // namespace elsa