Line data Source code
1 : #include "QuadricProblem.h" 2 : #include "L2NormPow2.h" 3 : #include "WeightedL2NormPow2.h" 4 : #include "LinearOperator.h" 5 : #include "Identity.h" 6 : #include "TypeCasts.hpp" 7 : 8 : namespace elsa 9 : { 10 : 11 : template <typename data_t> 12 : QuadricProblem<data_t>::QuadricProblem(const LinearOperator<data_t>& A, 13 : const DataContainer<data_t>& b, 14 : const DataContainer<data_t>& x0, bool spdA) 15 : : QuadricProblem<data_t>{spdA ? Quadric{A, b} : Quadric{adjoint(A) * A, A.applyAdjoint(b)}, 16 : x0} 17 8 : { 18 : // sanity checks are done in the member constructors already 19 8 : } 20 : 21 : template <typename data_t> 22 : QuadricProblem<data_t>::QuadricProblem(const LinearOperator<data_t>& A, 23 : const DataContainer<data_t>& b, bool spdA) 24 : : QuadricProblem<data_t>{spdA ? Quadric{A, b} : Quadric{adjoint(A) * A, A.applyAdjoint(b)}} 25 16 : { 26 : // sanity checks are done in the member constructors already 27 16 : } 28 : 29 : template <typename data_t> 30 : QuadricProblem<data_t>::QuadricProblem(const Quadric<data_t>& quadric, 31 : const DataContainer<data_t>& x0) 32 : : Problem<data_t>{quadric, x0} 33 18 : { 34 : // sanity checks are done in the member constructors already 35 18 : } 36 : 37 : template <typename data_t> 38 : QuadricProblem<data_t>::QuadricProblem(const Quadric<data_t>& quadric) 39 : : Problem<data_t>{quadric} 40 22 : { 41 : // sanity checks are done in the member constructors already 42 22 : } 43 : 44 : template <typename data_t> 45 : QuadricProblem<data_t>::QuadricProblem(const Problem<data_t>& problem) 46 : : Problem<data_t>{*quadricFromProblem(problem), problem.getCurrentSolution()} 47 72 : { 48 : // sanity checks are done in the member constructors already 49 72 : } 50 : 51 : template <typename data_t> 52 : QuadricProblem<data_t>* QuadricProblem<data_t>::cloneImpl() const 53 26 : { 54 26 : return new QuadricProblem(*this); 55 26 : } 56 : 57 : template <typename data_t> 58 : LinearResidual<data_t> 59 : QuadricProblem<data_t>::getGradientExpression(const RegularizationTerm<data_t>& regTerm) 60 42 : { 61 42 : const auto lambda = regTerm.getWeight(); 62 42 : const Scaling<data_t> lambdaOp{regTerm.getFunctional().getDomainDescriptor(), lambda}; 63 : 64 42 : if (is<L2NormPow2<data_t>>(regTerm.getFunctional())) { 65 38 : const auto& regFunc = downcast<L2NormPow2<data_t>>(regTerm.getFunctional()); 66 : 67 38 : if (!is<LinearResidual<data_t>>(regFunc.getResidual())) { 68 0 : throw LogicError("QuadricProblem: cannot convert a non-linear regularization " 69 0 : "term to quadric form"); 70 0 : } 71 38 : const auto& regTermResidual = downcast<LinearResidual<data_t>>(regFunc.getResidual()); 72 : 73 38 : if (regTermResidual.hasOperator()) { 74 34 : const auto& regTermOp = regTermResidual.getOperator(); 75 34 : LinearOperator<data_t> A = lambdaOp * adjoint(regTermOp) * regTermOp; 76 : 77 34 : if (regTermResidual.hasDataVector()) { 78 34 : DataContainer<data_t> b = 79 34 : lambda * regTermOp.applyAdjoint(regTermResidual.getDataVector()); 80 34 : return LinearResidual<data_t>{A, b}; 81 34 : } else { 82 0 : return LinearResidual<data_t>{A}; 83 0 : } 84 4 : } else { 85 4 : if (regTermResidual.hasDataVector()) { 86 0 : DataContainer<data_t> b = lambda * regTermResidual.getDataVector(); 87 0 : return LinearResidual<data_t>{lambdaOp, b}; 88 4 : } else { 89 4 : return LinearResidual<data_t>{lambdaOp}; 90 4 : } 91 4 : } 92 4 : } else if (is<WeightedL2NormPow2<data_t>>(regTerm.getFunctional())) { 93 0 : const auto& regFunc = downcast<WeightedL2NormPow2<data_t>>(regTerm.getFunctional()); 94 : 95 0 : if (!is<LinearResidual<data_t>>(regFunc.getResidual())) 96 0 : throw LogicError("QuadricProblem: cannot convert a non-linear regularization " 97 0 : "term to quadric form"); 98 : 99 0 : const auto& regTermResidual = downcast<LinearResidual<data_t>>(regFunc.getResidual()); 100 : 101 0 : const auto& W = regFunc.getWeightingOperator(); 102 0 : std::unique_ptr<Scaling<data_t>> lambdaWPtr; 103 : 104 0 : if (W.isIsotropic()) 105 0 : lambdaWPtr = std::make_unique<Scaling<data_t>>(W.getDomainDescriptor(), 106 0 : lambda * W.getScaleFactor()); 107 0 : else 108 0 : lambdaWPtr = std::make_unique<Scaling<data_t>>(W.getDomainDescriptor(), 109 0 : lambda * W.getScaleFactors()); 110 : 111 0 : const auto& lambdaW = *lambdaWPtr; 112 : 113 0 : if (regTermResidual.hasOperator()) { 114 0 : auto& regTermOp = regTermResidual.getOperator(); 115 0 : LinearOperator<data_t> A = adjoint(regTermOp) * lambdaW * regTermOp; 116 : 117 0 : if (regTermResidual.hasDataVector()) { 118 0 : DataContainer<data_t> b = 119 0 : regTermOp.applyAdjoint(lambdaW.apply(regTermResidual.getDataVector())); 120 0 : return LinearResidual<data_t>{A, b}; 121 0 : } else { 122 0 : return LinearResidual<data_t>{A}; 123 0 : } 124 0 : } else { 125 0 : if (regTermResidual.hasDataVector()) { 126 0 : DataContainer<data_t> b = lambdaW.apply(regTermResidual.getDataVector()); 127 0 : return LinearResidual<data_t>{lambdaW, b}; 128 0 : } else { 129 0 : return LinearResidual<data_t>{lambdaW}; 130 0 : } 131 4 : } 132 4 : } else if (is<Quadric<data_t>>(regTerm.getFunctional())) { 133 2 : const auto& regFunc = downcast<Quadric<data_t>>(regTerm.getFunctional()); 134 2 : const auto& quadricResidual = regFunc.getGradientExpression(); 135 2 : if (quadricResidual.hasOperator()) { 136 0 : LinearOperator<data_t> A = lambdaOp * quadricResidual.getOperator(); 137 : 138 0 : if (quadricResidual.hasDataVector()) { 139 0 : const DataContainer<data_t>& b = quadricResidual.getDataVector(); 140 0 : return LinearResidual<data_t>{A, lambda * b}; 141 0 : } else { 142 0 : return LinearResidual<data_t>{A}; 143 0 : } 144 2 : } else { 145 2 : if (quadricResidual.hasDataVector()) { 146 0 : const DataContainer<data_t>& b = quadricResidual.getDataVector(); 147 0 : return LinearResidual<data_t>{lambdaOp, lambda * b}; 148 2 : } else { 149 2 : return LinearResidual<data_t>{lambdaOp}; 150 2 : } 151 2 : } 152 2 : } else { 153 2 : throw InvalidArgumentError("QuadricProblem: regularization terms should be of type " 154 2 : "(Weighted)L2NormPow2 or Quadric"); 155 2 : } 156 42 : } 157 : 158 : template <typename data_t> 159 : std::unique_ptr<Quadric<data_t>> 160 : QuadricProblem<data_t>::quadricFromProblem(const Problem<data_t>& problem) 161 72 : { 162 72 : const auto& functional = problem.getDataTerm(); 163 72 : if (is<Quadric<data_t>>(functional) && problem.getRegularizationTerms().empty()) { 164 21 : return downcast<Quadric<data_t>>(functional.clone()); 165 51 : } else { 166 51 : std::unique_ptr<LinearOperator<data_t>> dataTermOp; 167 51 : std::unique_ptr<DataContainer<data_t>> quadricVec; 168 : 169 : // convert data term 170 51 : if (is<Quadric<data_t>>(functional)) { 171 2 : const auto& trueFunctional = downcast<Quadric<data_t>>(functional); 172 2 : const LinearResidual<data_t>& residual = trueFunctional.getGradientExpression(); 173 : 174 2 : if (residual.hasOperator()) { 175 0 : dataTermOp = residual.getOperator().clone(); 176 2 : } else { 177 2 : dataTermOp = std::make_unique<Identity<data_t>>(residual.getDomainDescriptor()); 178 2 : } 179 : 180 2 : if (residual.hasDataVector()) { 181 0 : quadricVec = std::make_unique<DataContainer<data_t>>(residual.getDataVector()); 182 0 : } 183 49 : } else if (is<L2NormPow2<data_t>>(functional)) { 184 47 : const auto& trueFunctional = downcast<L2NormPow2<data_t>>(functional); 185 : 186 47 : if (!is<LinearResidual<data_t>>(trueFunctional.getResidual())) 187 0 : throw LogicError( 188 0 : "QuadricProblem: cannot convert a non-linear term to quadric form"); 189 : 190 47 : const auto& residual = 191 47 : downcast<LinearResidual<data_t>>(trueFunctional.getResidual()); 192 : 193 47 : if (residual.hasOperator()) { 194 41 : const auto& A = residual.getOperator(); 195 41 : dataTermOp = std::make_unique<LinearOperator<data_t>>(adjoint(A) * A); 196 : 197 41 : if (residual.hasDataVector()) { 198 41 : quadricVec = std::make_unique<DataContainer<data_t>>( 199 41 : A.applyAdjoint(residual.getDataVector())); 200 41 : } 201 41 : } else { 202 6 : dataTermOp = std::make_unique<Identity<data_t>>(residual.getDomainDescriptor()); 203 : 204 6 : if (residual.hasDataVector()) { 205 0 : quadricVec = 206 0 : std::make_unique<DataContainer<data_t>>(residual.getDataVector()); 207 0 : } 208 6 : } 209 47 : } else if (is<WeightedL2NormPow2<data_t>>(functional)) { 210 0 : const auto& trueFunctional = downcast<WeightedL2NormPow2<data_t>>(functional); 211 : 212 0 : if (!is<LinearResidual<data_t>>(trueFunctional.getResidual())) 213 0 : throw LogicError( 214 0 : "QuadricProblem: cannot convert a non-linear term to quadric form"); 215 0 : const auto& residual = 216 0 : downcast<LinearResidual<data_t>>(trueFunctional.getResidual()); 217 : 218 0 : const auto& W = trueFunctional.getWeightingOperator(); 219 : 220 0 : if (residual.hasOperator()) { 221 0 : const auto& A = residual.getOperator(); 222 0 : dataTermOp = std::make_unique<LinearOperator<data_t>>(adjoint(A) * W * A); 223 : 224 0 : if (residual.hasDataVector()) { 225 0 : quadricVec = std::make_unique<DataContainer<data_t>>( 226 0 : A.applyAdjoint(W.apply(residual.getDataVector()))); 227 0 : } 228 0 : } else { 229 0 : dataTermOp = W.clone(); 230 : 231 0 : if (residual.hasDataVector()) { 232 0 : quadricVec = std::make_unique<DataContainer<data_t>>( 233 0 : W.apply(residual.getDataVector())); 234 0 : } 235 0 : } 236 2 : } else { 237 2 : throw LogicError("QuadricProblem: can only convert functionals of type " 238 2 : "(Weighted)L2NormPow2 to Quadric"); 239 2 : } 240 : 241 49 : if (problem.getRegularizationTerms().empty()) { 242 7 : if (!quadricVec) { 243 4 : return std::make_unique<Quadric<data_t>>(*dataTermOp); 244 4 : } else { 245 3 : return std::make_unique<Quadric<data_t>>(*dataTermOp, *quadricVec); 246 3 : } 247 42 : } 248 : 249 : // add regularization terms 250 42 : LinearOperator<data_t> quadricOp{dataTermOp->getDomainDescriptor(), 251 42 : dataTermOp->getRangeDescriptor()}; 252 : 253 84 : for (std::size_t i = 0; i < problem.getRegularizationTerms().size(); i++) { 254 42 : const auto& regTerm = problem.getRegularizationTerms()[i]; 255 42 : LinearResidual<data_t> residual = getGradientExpression(regTerm); 256 : 257 42 : if (i == 0) 258 40 : quadricOp = (*dataTermOp + residual.getOperator()); 259 2 : else 260 2 : quadricOp = quadricOp + residual.getOperator(); 261 : 262 42 : if (residual.hasDataVector()) { 263 34 : if (!quadricVec) 264 0 : quadricVec = 265 0 : std::make_unique<DataContainer<data_t>>(residual.getDataVector()); 266 34 : else 267 34 : *quadricVec += residual.getDataVector(); 268 34 : } 269 42 : } 270 : 271 42 : if (!quadricVec) { 272 2 : return std::make_unique<Quadric<data_t>>(quadricOp); 273 40 : } else { 274 40 : return std::make_unique<Quadric<data_t>>(quadricOp, *quadricVec); 275 40 : } 276 42 : } 277 72 : } 278 : 279 : // ------------------------------------------ 280 : // explicit template instantiation 281 : template class QuadricProblem<float>; 282 : template class QuadricProblem<double>; 283 : 284 : } // namespace elsa