Line data Source code
1 : /**
2 : * @file test_Problem.cpp
3 : *
4 : * @brief Tests for the Problem class
5 : *
6 : * @author David Frank - initial code
7 : * @author Tobias Lasser - rewrite
8 : */
9 :
10 : #include "doctest/doctest.h"
11 :
12 : #include <iostream>
13 : #include <Logger.h>
14 : #include "Problem.h"
15 : #include "Identity.h"
16 : #include "Scaling.h"
17 : #include "LinearResidual.h"
18 : #include "L2NormPow2.h"
19 : #include "VolumeDescriptor.h"
20 : #include "testHelpers.h"
21 : #include "TypeCasts.hpp"
22 :
23 : using namespace elsa;
24 : using namespace doctest;
25 :
26 : TEST_SUITE_BEGIN("problems");
27 :
28 : TEST_CASE("Problem: Testing without regularization")
29 4 : {
30 : // eliminate the timing info from console for the tests
31 4 : Logger::setLevel(Logger::LogLevel::WARN);
32 :
33 4 : GIVEN("some data term")
34 4 : {
35 4 : IndexVector_t numCoeff(2);
36 4 : numCoeff << 17, 23;
37 4 : VolumeDescriptor dd(numCoeff);
38 :
39 4 : RealVector_t scaling(dd.getNumberOfCoefficients());
40 4 : scaling.setRandom();
41 4 : DataContainer dcScaling(dd, scaling);
42 4 : Scaling scaleOp(dd, dcScaling);
43 :
44 4 : RealVector_t dataVec(dd.getNumberOfCoefficients());
45 4 : dataVec.setRandom();
46 4 : DataContainer dcData(dd, dataVec);
47 :
48 4 : LinearResidual linRes(scaleOp, dcData);
49 4 : L2NormPow2 func(linRes);
50 :
51 4 : WHEN("setting up the problem without x0")
52 4 : {
53 2 : Problem prob(func);
54 :
55 2 : THEN("the clone works correctly")
56 2 : {
57 1 : auto probClone = prob.clone();
58 :
59 1 : REQUIRE_NE(probClone.get(), &prob);
60 1 : REQUIRE_EQ(*probClone, prob);
61 1 : }
62 :
63 2 : THEN("the problem behaves as expected")
64 2 : {
65 1 : DataContainer dcZero(dd);
66 1 : dcZero = 0;
67 1 : REQUIRE_EQ(prob.getCurrentSolution(), dcZero);
68 :
69 1 : REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
70 1 : REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
71 :
72 1 : auto hessian = prob.getHessian();
73 1 : auto result = hessian.apply(dcData);
74 392 : for (index_t i = 0; i < result.getSize(); ++i)
75 1 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]));
76 :
77 1 : REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f));
78 1 : }
79 2 : }
80 :
81 4 : WHEN("setting up the problem with x0")
82 4 : {
83 2 : RealVector_t x0Vec(dd.getNumberOfCoefficients());
84 2 : x0Vec.setRandom();
85 2 : DataContainer dcX0(dd, x0Vec);
86 :
87 2 : Problem prob(func, dcX0);
88 :
89 2 : THEN("the clone works correctly")
90 2 : {
91 1 : auto probClone = prob.clone();
92 :
93 1 : REQUIRE_NE(probClone.get(), &prob);
94 1 : REQUIRE_EQ(*probClone, prob);
95 1 : }
96 :
97 2 : THEN("the problem behaves as expected")
98 2 : {
99 1 : REQUIRE_EQ(prob.getCurrentSolution(), dcX0);
100 :
101 1 : REQUIRE_UNARY(checkApproxEq(
102 1 : prob.evaluate(), 0.5f
103 1 : * (scaling.array() * x0Vec.array() - dataVec.array())
104 1 : .matrix()
105 1 : .squaredNorm()));
106 :
107 1 : DataContainer gradientDirect = dcScaling * (dcScaling * dcX0 - dcData);
108 1 : auto gradient = prob.getGradient();
109 392 : for (index_t i = 0; i < gradientDirect.getSize(); ++i)
110 1 : REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
111 :
112 1 : auto hessian = prob.getHessian();
113 1 : auto result = hessian.apply(dcData);
114 392 : for (index_t i = 0; i < result.getSize(); ++i)
115 1 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]));
116 :
117 1 : REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f));
118 1 : }
119 2 : }
120 4 : }
121 4 : }
122 :
123 : TEST_CASE("Problem: Testing with one regularization term")
124 5 : {
125 : // eliminate the timing info from console for the tests
126 5 : Logger::setLevel(Logger::LogLevel::WARN);
127 :
128 5 : GIVEN("some data term and some regularization term")
129 5 : {
130 : // least squares data term
131 5 : IndexVector_t numCoeff(2);
132 5 : numCoeff << 23, 47;
133 5 : VolumeDescriptor dd(numCoeff);
134 :
135 5 : RealVector_t scaling(dd.getNumberOfCoefficients());
136 5 : scaling.setRandom();
137 5 : DataContainer dcScaling(dd, scaling);
138 5 : Scaling scaleOp(dd, dcScaling);
139 :
140 5 : RealVector_t dataVec(dd.getNumberOfCoefficients());
141 5 : dataVec.setRandom();
142 5 : DataContainer dcData(dd, dataVec);
143 :
144 5 : LinearResidual linRes(scaleOp, dcData);
145 5 : L2NormPow2 func(linRes);
146 :
147 : // l2 norm regularization term
148 5 : L2NormPow2 regFunc(dd);
149 5 : real_t weight = 2.0;
150 5 : RegularizationTerm regTerm(weight, regFunc);
151 :
152 5 : WHEN("setting up the problem without x0")
153 5 : {
154 2 : Problem prob(func, regTerm);
155 :
156 2 : THEN("the clone works correctly")
157 2 : {
158 1 : auto probClone = prob.clone();
159 :
160 1 : REQUIRE_NE(probClone.get(), &prob);
161 1 : REQUIRE_EQ(*probClone, prob);
162 1 : }
163 :
164 2 : THEN("the problem behaves as expected")
165 2 : {
166 1 : DataContainer dcZero(dd);
167 1 : dcZero = 0;
168 1 : REQUIRE_UNARY(checkApproxEq(prob.getCurrentSolution(), dcZero));
169 :
170 1 : REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
171 1 : REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
172 :
173 1 : auto hessian = prob.getHessian();
174 1 : auto result = hessian.apply(dcData);
175 1082 : for (index_t i = 0; i < result.getSize(); ++i)
176 1 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
177 1 : + weight * dataVec[i]));
178 :
179 1 : REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight));
180 1 : }
181 2 : }
182 :
183 5 : WHEN("setting up the problem with x0")
184 5 : {
185 2 : RealVector_t x0Vec(dd.getNumberOfCoefficients());
186 2 : x0Vec.setRandom();
187 2 : DataContainer dcX0(dd, x0Vec);
188 :
189 2 : Problem prob(func, regTerm, dcX0);
190 :
191 2 : THEN("the clone works correctly")
192 2 : {
193 1 : auto probClone = prob.clone();
194 :
195 1 : REQUIRE_NE(probClone.get(), &prob);
196 1 : REQUIRE_EQ(*probClone, prob);
197 1 : }
198 :
199 2 : THEN("the problem behaves as expected")
200 2 : {
201 1 : REQUIRE_EQ(prob.getCurrentSolution(), dcX0);
202 :
203 1 : auto valueData =
204 1 : 0.5f
205 1 : * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
206 1 : REQUIRE_UNARY(checkApproxEq(prob.evaluate(),
207 1 : valueData + weight * 0.5f * x0Vec.squaredNorm()));
208 :
209 1 : DataContainer gradientDirect =
210 1 : dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0;
211 1 : auto gradient = prob.getGradient();
212 1082 : for (index_t i = 0; i < gradient.getSize(); ++i)
213 1 : REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
214 :
215 1 : auto hessian = prob.getHessian();
216 1 : auto result = hessian.apply(dcData);
217 1082 : for (index_t i = 0; i < result.getSize(); ++i)
218 1 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
219 1 : + weight * dataVec[i]));
220 :
221 1 : REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight));
222 1 : }
223 2 : }
224 :
225 5 : WHEN("given a different data descriptor and another regularization term with a different "
226 5 : "domain descriptor")
227 5 : {
228 : // three-dimensional data descriptor
229 1 : IndexVector_t otherNumCoeff(3);
230 1 : otherNumCoeff << 15, 38, 22;
231 1 : VolumeDescriptor otherDD(otherNumCoeff);
232 :
233 : // l2 norm regularization term
234 1 : L2NormPow2 otherRegFunc(otherDD);
235 1 : RegularizationTerm otherRegTerm(weight, otherRegFunc);
236 :
237 1 : THEN("no exception is thrown when setting up a problem with different domain "
238 1 : "descriptors")
239 1 : {
240 1 : REQUIRE_NOTHROW(Problem{func, otherRegTerm});
241 1 : }
242 1 : }
243 5 : }
244 5 : }
245 :
246 : TEST_CASE("Problem: Testing with several regularization terms")
247 5 : {
248 : // eliminate the timing info from console for the tests
249 5 : Logger::setLevel(Logger::LogLevel::WARN);
250 :
251 5 : GIVEN("some data term and several regularization terms")
252 5 : {
253 : // least squares data term
254 5 : IndexVector_t numCoeff(3);
255 5 : numCoeff << 17, 33, 52;
256 5 : VolumeDescriptor dd(numCoeff);
257 :
258 5 : RealVector_t scaling(dd.getNumberOfCoefficients());
259 5 : scaling.setRandom();
260 5 : DataContainer dcScaling(dd, scaling);
261 5 : Scaling scaleOp(dd, dcScaling);
262 :
263 5 : RealVector_t dataVec(dd.getNumberOfCoefficients());
264 5 : dataVec.setRandom();
265 5 : DataContainer dcData(dd, dataVec);
266 :
267 5 : LinearResidual linRes(scaleOp, dcData);
268 5 : L2NormPow2 func(linRes);
269 :
270 : // l2 norm regularization term
271 5 : L2NormPow2 regFunc(dd);
272 5 : real_t weight1 = 2.0;
273 5 : RegularizationTerm regTerm1(weight1, regFunc);
274 :
275 5 : real_t weight2 = 3.0;
276 5 : RegularizationTerm regTerm2(weight2, regFunc);
277 :
278 5 : std::vector<RegularizationTerm<real_t>> vecReg{regTerm1, regTerm2};
279 :
280 5 : WHEN("setting up the problem without x0")
281 5 : {
282 2 : Problem prob(func, vecReg);
283 :
284 2 : THEN("the clone works correctly")
285 2 : {
286 1 : auto probClone = prob.clone();
287 :
288 1 : REQUIRE_NE(probClone.get(), &prob);
289 1 : REQUIRE_EQ(*probClone, prob);
290 1 : }
291 :
292 2 : THEN("the problem behaves as expected")
293 2 : {
294 1 : DataContainer dcZero(dd);
295 1 : dcZero = 0;
296 1 : REQUIRE_UNARY(checkApproxEq(prob.getCurrentSolution(), dcZero));
297 :
298 1 : REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
299 1 : REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
300 :
301 1 : auto hessian = prob.getHessian();
302 1 : auto result = hessian.apply(dcData);
303 29173 : for (index_t i = 0; i < result.getSize(); ++i)
304 1 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
305 1 : + weight1 * dataVec[i]
306 1 : + weight2 * dataVec[i]));
307 :
308 1 : REQUIRE_UNARY(
309 1 : checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight1 + weight2));
310 1 : }
311 2 : }
312 :
313 5 : WHEN("setting up the problem with x0")
314 5 : {
315 2 : RealVector_t x0Vec(dd.getNumberOfCoefficients());
316 2 : x0Vec.setRandom();
317 2 : DataContainer dcX0(dd, x0Vec);
318 :
319 2 : Problem prob(func, vecReg, dcX0);
320 :
321 2 : THEN("the clone works correctly")
322 2 : {
323 1 : auto probClone = prob.clone();
324 :
325 1 : REQUIRE_NE(probClone.get(), &prob);
326 1 : REQUIRE_EQ(*probClone, prob);
327 1 : }
328 :
329 2 : THEN("the problem behaves as expected")
330 2 : {
331 1 : REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcX0));
332 :
333 1 : auto valueData =
334 1 : 0.5f
335 1 : * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
336 1 : REQUIRE_UNARY(
337 1 : checkApproxEq(prob.evaluate(), valueData + weight1 * 0.5f * x0Vec.squaredNorm()
338 1 : + weight2 * 0.5f * x0Vec.squaredNorm()));
339 :
340 1 : auto gradient = prob.getGradient();
341 1 : DataContainer gradientDirect =
342 1 : dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0 + weight2 * dcX0;
343 29173 : for (index_t i = 0; i < gradient.getSize(); ++i)
344 1 : REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
345 :
346 1 : auto hessian = prob.getHessian();
347 1 : auto result = hessian.apply(dcData);
348 29173 : for (index_t i = 0; i < result.getSize(); ++i)
349 1 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
350 1 : + weight1 * dataVec[i]
351 1 : + weight2 * dataVec[i]));
352 :
353 1 : REQUIRE_UNARY(
354 1 : checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight1 + weight2));
355 1 : }
356 2 : }
357 :
358 5 : WHEN("given two different data descriptors and two regularization terms with a different "
359 5 : "domain descriptor")
360 5 : {
361 : // three-dimensional data descriptor
362 1 : IndexVector_t otherNumCoeff(3);
363 1 : otherNumCoeff << 15, 38, 22;
364 1 : VolumeDescriptor otherDD(otherNumCoeff);
365 :
366 : // four-dimensional data descriptor
367 1 : IndexVector_t anotherNumCoeff(4);
368 1 : anotherNumCoeff << 7, 9, 21, 17;
369 1 : VolumeDescriptor anotherDD(anotherNumCoeff);
370 :
371 : // l2 norm regularization term
372 1 : L2NormPow2 otherRegFunc(otherDD);
373 1 : RegularizationTerm otherRegTerm(weight1, otherRegFunc);
374 :
375 : // l2 norm regularization term
376 1 : L2NormPow2 anotherRegFunc(anotherDD);
377 1 : RegularizationTerm anotherRegTerm(weight2, anotherRegFunc);
378 :
379 1 : THEN("no exception is thrown when setting up a problem with different domain "
380 1 : "descriptors")
381 1 : {
382 1 : REQUIRE_NOTHROW(Problem{func, std::vector{otherRegTerm, anotherRegTerm}});
383 1 : }
384 1 : }
385 5 : }
386 5 : }
387 :
388 : TEST_SUITE_END();
|