Line data Source code
1 : #include "QuadricProblem.h" 2 : #include "L2NormPow2.h" 3 : #include "WeightedL2NormPow2.h" 4 : #include "LinearOperator.h" 5 : #include "Identity.h" 6 : #include "TypeCasts.hpp" 7 : 8 : namespace elsa 9 : { 10 : 11 : template <typename data_t> 12 0 : QuadricProblem<data_t>::QuadricProblem(const LinearOperator<data_t>& A, 13 : const DataContainer<data_t>& b, 14 : const DataContainer<data_t>& x0, bool spdA) 15 0 : : QuadricProblem<data_t>{spdA ? Quadric{A, b} : Quadric{adjoint(A) * A, A.applyAdjoint(b)}, 16 0 : x0} 17 : { 18 : // sanity checks are done in the member constructors already 19 0 : } 20 : 21 : template <typename data_t> 22 0 : QuadricProblem<data_t>::QuadricProblem(const LinearOperator<data_t>& A, 23 : const DataContainer<data_t>& b, bool spdA) 24 0 : : QuadricProblem<data_t>{spdA ? Quadric{A, b} : Quadric{adjoint(A) * A, A.applyAdjoint(b)}} 25 : { 26 : // sanity checks are done in the member constructors already 27 0 : } 28 : 29 : template <typename data_t> 30 0 : QuadricProblem<data_t>::QuadricProblem(const Quadric<data_t>& quadric, 31 : const DataContainer<data_t>& x0) 32 0 : : Problem<data_t>{quadric, x0} 33 : { 34 : // sanity checks are done in the member constructors already 35 0 : } 36 : 37 : template <typename data_t> 38 0 : QuadricProblem<data_t>::QuadricProblem(const Quadric<data_t>& quadric) 39 0 : : Problem<data_t>{quadric} 40 : { 41 : // sanity checks are done in the member constructors already 42 0 : } 43 : 44 : template <typename data_t> 45 0 : QuadricProblem<data_t>::QuadricProblem(const Problem<data_t>& problem) 46 0 : : Problem<data_t>{*quadricFromProblem(problem), problem.getCurrentSolution()} 47 : { 48 : // sanity checks are done in the member constructors already 49 0 : } 50 : 51 : template <typename data_t> 52 0 : QuadricProblem<data_t>* QuadricProblem<data_t>::cloneImpl() const 53 : { 54 0 : return new QuadricProblem(*this); 55 : } 56 : 57 : template <typename data_t> 58 : LinearResidual<data_t> 59 0 : QuadricProblem<data_t>::getGradientExpression(const RegularizationTerm<data_t>& regTerm) 60 : { 61 0 : const auto lambda = regTerm.getWeight(); 62 0 : const Scaling<data_t> lambdaOp{regTerm.getFunctional().getDomainDescriptor(), lambda}; 63 : 64 0 : if (is<L2NormPow2<data_t>>(regTerm.getFunctional())) { 65 0 : const auto& regFunc = downcast<L2NormPow2<data_t>>(regTerm.getFunctional()); 66 : 67 0 : if (!is<LinearResidual<data_t>>(regFunc.getResidual())) { 68 0 : throw LogicError("QuadricProblem: cannot convert a non-linear regularization " 69 : "term to quadric form"); 70 : } 71 0 : const auto& regTermResidual = downcast<LinearResidual<data_t>>(regFunc.getResidual()); 72 : 73 0 : if (regTermResidual.hasOperator()) { 74 0 : const auto& regTermOp = regTermResidual.getOperator(); 75 0 : LinearOperator<data_t> A = lambdaOp * adjoint(regTermOp) * regTermOp; 76 : 77 0 : if (regTermResidual.hasDataVector()) { 78 0 : DataContainer<data_t> b = 79 : lambda * regTermOp.applyAdjoint(regTermResidual.getDataVector()); 80 0 : return LinearResidual<data_t>{A, b}; 81 0 : } else { 82 0 : return LinearResidual<data_t>{A}; 83 : } 84 0 : } else { 85 0 : if (regTermResidual.hasDataVector()) { 86 0 : DataContainer<data_t> b = lambda * regTermResidual.getDataVector(); 87 0 : return LinearResidual<data_t>{lambdaOp, b}; 88 0 : } else { 89 0 : return LinearResidual<data_t>{lambdaOp}; 90 : } 91 : } 92 0 : } else if (is<WeightedL2NormPow2<data_t>>(regTerm.getFunctional())) { 93 0 : const auto& regFunc = downcast<WeightedL2NormPow2<data_t>>(regTerm.getFunctional()); 94 : 95 0 : if (!is<LinearResidual<data_t>>(regFunc.getResidual())) 96 0 : throw LogicError("QuadricProblem: cannot convert a non-linear regularization " 97 : "term to quadric form"); 98 : 99 0 : const auto& regTermResidual = downcast<LinearResidual<data_t>>(regFunc.getResidual()); 100 : 101 0 : const auto& W = regFunc.getWeightingOperator(); 102 0 : std::unique_ptr<Scaling<data_t>> lambdaWPtr; 103 : 104 0 : if (W.isIsotropic()) 105 0 : lambdaWPtr = std::make_unique<Scaling<data_t>>(W.getDomainDescriptor(), 106 0 : lambda * W.getScaleFactor()); 107 : else 108 0 : lambdaWPtr = std::make_unique<Scaling<data_t>>(W.getDomainDescriptor(), 109 : lambda * W.getScaleFactors()); 110 : 111 0 : const auto& lambdaW = *lambdaWPtr; 112 : 113 0 : if (regTermResidual.hasOperator()) { 114 0 : auto& regTermOp = regTermResidual.getOperator(); 115 0 : LinearOperator<data_t> A = adjoint(regTermOp) * lambdaW * regTermOp; 116 : 117 0 : if (regTermResidual.hasDataVector()) { 118 0 : DataContainer<data_t> b = 119 : regTermOp.applyAdjoint(lambdaW.apply(regTermResidual.getDataVector())); 120 0 : return LinearResidual<data_t>{A, b}; 121 0 : } else { 122 0 : return LinearResidual<data_t>{A}; 123 : } 124 0 : } else { 125 0 : if (regTermResidual.hasDataVector()) { 126 0 : DataContainer<data_t> b = lambdaW.apply(regTermResidual.getDataVector()); 127 0 : return LinearResidual<data_t>{lambdaW, b}; 128 0 : } else { 129 0 : return LinearResidual<data_t>{lambdaW}; 130 : } 131 : } 132 0 : } else if (is<Quadric<data_t>>(regTerm.getFunctional())) { 133 0 : const auto& regFunc = downcast<Quadric<data_t>>(regTerm.getFunctional()); 134 0 : const auto& quadricResidual = regFunc.getGradientExpression(); 135 0 : if (quadricResidual.hasOperator()) { 136 0 : LinearOperator<data_t> A = lambdaOp * quadricResidual.getOperator(); 137 : 138 0 : if (quadricResidual.hasDataVector()) { 139 0 : const DataContainer<data_t>& b = quadricResidual.getDataVector(); 140 0 : return LinearResidual<data_t>{A, lambda * b}; 141 : } else { 142 0 : return LinearResidual<data_t>{A}; 143 : } 144 0 : } else { 145 0 : if (quadricResidual.hasDataVector()) { 146 0 : const DataContainer<data_t>& b = quadricResidual.getDataVector(); 147 0 : return LinearResidual<data_t>{lambdaOp, lambda * b}; 148 : } else { 149 0 : return LinearResidual<data_t>{lambdaOp}; 150 : } 151 : } 152 : } else { 153 0 : throw InvalidArgumentError("QuadricProblem: regularization terms should be of type " 154 : "(Weighted)L2NormPow2 or Quadric"); 155 : } 156 0 : } 157 : 158 : template <typename data_t> 159 : std::unique_ptr<Quadric<data_t>> 160 0 : QuadricProblem<data_t>::quadricFromProblem(const Problem<data_t>& problem) 161 : { 162 0 : const auto& functional = problem.getDataTerm(); 163 0 : if (is<Quadric<data_t>>(functional) && problem.getRegularizationTerms().empty()) { 164 0 : return downcast<Quadric<data_t>>(functional.clone()); 165 : } else { 166 0 : std::unique_ptr<LinearOperator<data_t>> dataTermOp; 167 0 : std::unique_ptr<DataContainer<data_t>> quadricVec; 168 : 169 : // convert data term 170 0 : if (is<Quadric<data_t>>(functional)) { 171 0 : const auto& trueFunctional = downcast<Quadric<data_t>>(functional); 172 0 : const LinearResidual<data_t>& residual = trueFunctional.getGradientExpression(); 173 : 174 0 : if (residual.hasOperator()) { 175 0 : dataTermOp = residual.getOperator().clone(); 176 : } else { 177 0 : dataTermOp = std::make_unique<Identity<data_t>>(residual.getDomainDescriptor()); 178 : } 179 : 180 0 : if (residual.hasDataVector()) { 181 0 : quadricVec = std::make_unique<DataContainer<data_t>>(residual.getDataVector()); 182 : } 183 0 : } else if (is<L2NormPow2<data_t>>(functional)) { 184 0 : const auto& trueFunctional = downcast<L2NormPow2<data_t>>(functional); 185 : 186 0 : if (!is<LinearResidual<data_t>>(trueFunctional.getResidual())) 187 0 : throw LogicError( 188 : "QuadricProblem: cannot convert a non-linear term to quadric form"); 189 : 190 : const auto& residual = 191 0 : downcast<LinearResidual<data_t>>(trueFunctional.getResidual()); 192 : 193 0 : if (residual.hasOperator()) { 194 0 : const auto& A = residual.getOperator(); 195 0 : dataTermOp = std::make_unique<LinearOperator<data_t>>(adjoint(A) * A); 196 : 197 0 : if (residual.hasDataVector()) { 198 0 : quadricVec = std::make_unique<DataContainer<data_t>>( 199 : A.applyAdjoint(residual.getDataVector())); 200 : } 201 : } else { 202 0 : dataTermOp = std::make_unique<Identity<data_t>>(residual.getDomainDescriptor()); 203 : 204 0 : if (residual.hasDataVector()) { 205 0 : quadricVec = 206 : std::make_unique<DataContainer<data_t>>(residual.getDataVector()); 207 : } 208 : } 209 0 : } else if (is<WeightedL2NormPow2<data_t>>(functional)) { 210 0 : const auto& trueFunctional = downcast<WeightedL2NormPow2<data_t>>(functional); 211 : 212 0 : if (!is<LinearResidual<data_t>>(trueFunctional.getResidual())) 213 0 : throw LogicError( 214 : "QuadricProblem: cannot convert a non-linear term to quadric form"); 215 : const auto& residual = 216 0 : downcast<LinearResidual<data_t>>(trueFunctional.getResidual()); 217 : 218 0 : const auto& W = trueFunctional.getWeightingOperator(); 219 : 220 0 : if (residual.hasOperator()) { 221 0 : const auto& A = residual.getOperator(); 222 0 : dataTermOp = std::make_unique<LinearOperator<data_t>>(adjoint(A) * W * A); 223 : 224 0 : if (residual.hasDataVector()) { 225 0 : quadricVec = std::make_unique<DataContainer<data_t>>( 226 : A.applyAdjoint(W.apply(residual.getDataVector()))); 227 : } 228 : } else { 229 0 : dataTermOp = W.clone(); 230 : 231 0 : if (residual.hasDataVector()) { 232 0 : quadricVec = std::make_unique<DataContainer<data_t>>( 233 : W.apply(residual.getDataVector())); 234 : } 235 : } 236 : } else { 237 0 : throw LogicError("QuadricProblem: can only convert functionals of type " 238 : "(Weighted)L2NormPow2 to Quadric"); 239 : } 240 : 241 0 : if (problem.getRegularizationTerms().empty()) { 242 0 : if (!quadricVec) { 243 0 : return std::make_unique<Quadric<data_t>>(*dataTermOp); 244 : } else { 245 0 : return std::make_unique<Quadric<data_t>>(*dataTermOp, *quadricVec); 246 : } 247 : } 248 : 249 : // add regularization terms 250 0 : LinearOperator<data_t> quadricOp{dataTermOp->getDomainDescriptor(), 251 : dataTermOp->getRangeDescriptor()}; 252 : 253 0 : for (std::size_t i = 0; i < problem.getRegularizationTerms().size(); i++) { 254 0 : const auto& regTerm = problem.getRegularizationTerms()[i]; 255 0 : LinearResidual<data_t> residual = getGradientExpression(regTerm); 256 : 257 0 : if (i == 0) 258 0 : quadricOp = (*dataTermOp + residual.getOperator()); 259 : else 260 0 : quadricOp = quadricOp + residual.getOperator(); 261 : 262 0 : if (residual.hasDataVector()) { 263 0 : if (!quadricVec) 264 0 : quadricVec = 265 : std::make_unique<DataContainer<data_t>>(residual.getDataVector()); 266 : else 267 0 : *quadricVec += residual.getDataVector(); 268 : } 269 : } 270 : 271 0 : if (!quadricVec) { 272 0 : return std::make_unique<Quadric<data_t>>(quadricOp); 273 : } else { 274 0 : return std::make_unique<Quadric<data_t>>(quadricOp, *quadricVec); 275 : } 276 0 : } 277 : } 278 : 279 : // ------------------------------------------ 280 : // explicit template instantiation 281 : template class QuadricProblem<float>; 282 : template class QuadricProblem<double>; 283 : 284 : } // namespace elsa