Line data Source code
1 : /**
2 : * @file test_Problem.cpp
3 : *
4 : * @brief Tests for the Problem class
5 : *
6 : * @author David Frank - initial code
7 : * @author Tobias Lasser - rewrite
8 : */
9 :
10 : #include "doctest/doctest.h"
11 :
12 : #include <iostream>
13 : #include <Logger.h>
14 : #include "Problem.h"
15 : #include "Identity.h"
16 : #include "Scaling.h"
17 : #include "LinearResidual.h"
18 : #include "L2NormPow2.h"
19 : #include "VolumeDescriptor.h"
20 : #include "testHelpers.h"
21 : #include "TypeCasts.hpp"
22 :
23 : using namespace elsa;
24 : using namespace doctest;
25 :
26 : TEST_SUITE_BEGIN("problems");
27 :
28 4 : TEST_CASE("Problem: Testing without regularization")
29 : {
30 : // eliminate the timing info from console for the tests
31 4 : Logger::setLevel(Logger::LogLevel::WARN);
32 :
33 8 : GIVEN("some data term")
34 : {
35 8 : IndexVector_t numCoeff(2);
36 4 : numCoeff << 17, 23;
37 8 : VolumeDescriptor dd(numCoeff);
38 :
39 8 : RealVector_t scaling(dd.getNumberOfCoefficients());
40 4 : scaling.setRandom();
41 8 : DataContainer dcScaling(dd, scaling);
42 8 : Scaling scaleOp(dd, dcScaling);
43 :
44 8 : RealVector_t dataVec(dd.getNumberOfCoefficients());
45 4 : dataVec.setRandom();
46 8 : DataContainer dcData(dd, dataVec);
47 :
48 8 : LinearResidual linRes(scaleOp, dcData);
49 8 : L2NormPow2 func(linRes);
50 :
51 6 : WHEN("setting up the problem without x0")
52 : {
53 4 : Problem prob(func);
54 :
55 3 : THEN("the clone works correctly")
56 : {
57 2 : auto probClone = prob.clone();
58 :
59 1 : REQUIRE_NE(probClone.get(), &prob);
60 1 : REQUIRE_EQ(*probClone, prob);
61 : }
62 :
63 3 : THEN("the problem behaves as expected")
64 : {
65 2 : DataContainer dcZero(dd);
66 1 : dcZero = 0;
67 1 : REQUIRE_EQ(prob.getCurrentSolution(), dcZero);
68 :
69 1 : REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
70 1 : REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
71 :
72 2 : auto hessian = prob.getHessian();
73 1 : auto result = hessian.apply(dcData);
74 392 : for (index_t i = 0; i < result.getSize(); ++i)
75 391 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]));
76 :
77 1 : REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f));
78 : }
79 : }
80 :
81 6 : WHEN("setting up the problem with x0")
82 : {
83 4 : RealVector_t x0Vec(dd.getNumberOfCoefficients());
84 2 : x0Vec.setRandom();
85 4 : DataContainer dcX0(dd, x0Vec);
86 :
87 4 : Problem prob(func, dcX0);
88 :
89 3 : THEN("the clone works correctly")
90 : {
91 2 : auto probClone = prob.clone();
92 :
93 1 : REQUIRE_NE(probClone.get(), &prob);
94 1 : REQUIRE_EQ(*probClone, prob);
95 : }
96 :
97 3 : THEN("the problem behaves as expected")
98 : {
99 1 : REQUIRE_EQ(prob.getCurrentSolution(), dcX0);
100 :
101 1 : REQUIRE_UNARY(checkApproxEq(
102 : prob.evaluate(), 0.5f
103 : * (scaling.array() * x0Vec.array() - dataVec.array())
104 : .matrix()
105 : .squaredNorm()));
106 :
107 2 : DataContainer gradientDirect = dcScaling * (dcScaling * dcX0 - dcData);
108 2 : auto gradient = prob.getGradient();
109 392 : for (index_t i = 0; i < gradientDirect.getSize(); ++i)
110 391 : REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
111 :
112 2 : auto hessian = prob.getHessian();
113 1 : auto result = hessian.apply(dcData);
114 392 : for (index_t i = 0; i < result.getSize(); ++i)
115 391 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]));
116 :
117 1 : REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f));
118 : }
119 : }
120 : }
121 4 : }
122 :
123 5 : TEST_CASE("Problem: Testing with one regularization term")
124 : {
125 : // eliminate the timing info from console for the tests
126 5 : Logger::setLevel(Logger::LogLevel::WARN);
127 :
128 10 : GIVEN("some data term and some regularization term")
129 : {
130 : // least squares data term
131 10 : IndexVector_t numCoeff(2);
132 5 : numCoeff << 23, 47;
133 10 : VolumeDescriptor dd(numCoeff);
134 :
135 10 : RealVector_t scaling(dd.getNumberOfCoefficients());
136 5 : scaling.setRandom();
137 10 : DataContainer dcScaling(dd, scaling);
138 10 : Scaling scaleOp(dd, dcScaling);
139 :
140 10 : RealVector_t dataVec(dd.getNumberOfCoefficients());
141 5 : dataVec.setRandom();
142 10 : DataContainer dcData(dd, dataVec);
143 :
144 10 : LinearResidual linRes(scaleOp, dcData);
145 10 : L2NormPow2 func(linRes);
146 :
147 : // l2 norm regularization term
148 10 : L2NormPow2 regFunc(dd);
149 5 : real_t weight = 2.0;
150 10 : RegularizationTerm regTerm(weight, regFunc);
151 :
152 7 : WHEN("setting up the problem without x0")
153 : {
154 4 : Problem prob(func, regTerm);
155 :
156 3 : THEN("the clone works correctly")
157 : {
158 2 : auto probClone = prob.clone();
159 :
160 1 : REQUIRE_NE(probClone.get(), &prob);
161 1 : REQUIRE_EQ(*probClone, prob);
162 : }
163 :
164 3 : THEN("the problem behaves as expected")
165 : {
166 2 : DataContainer dcZero(dd);
167 1 : dcZero = 0;
168 1 : REQUIRE_UNARY(checkApproxEq(prob.getCurrentSolution(), dcZero));
169 :
170 1 : REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
171 1 : REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
172 :
173 2 : auto hessian = prob.getHessian();
174 1 : auto result = hessian.apply(dcData);
175 1082 : for (index_t i = 0; i < result.getSize(); ++i)
176 1081 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
177 : + weight * dataVec[i]));
178 :
179 1 : REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight));
180 : }
181 : }
182 :
183 7 : WHEN("setting up the problem with x0")
184 : {
185 4 : RealVector_t x0Vec(dd.getNumberOfCoefficients());
186 2 : x0Vec.setRandom();
187 4 : DataContainer dcX0(dd, x0Vec);
188 :
189 4 : Problem prob(func, regTerm, dcX0);
190 :
191 3 : THEN("the clone works correctly")
192 : {
193 2 : auto probClone = prob.clone();
194 :
195 1 : REQUIRE_NE(probClone.get(), &prob);
196 1 : REQUIRE_EQ(*probClone, prob);
197 : }
198 :
199 3 : THEN("the problem behaves as expected")
200 : {
201 1 : REQUIRE_EQ(prob.getCurrentSolution(), dcX0);
202 :
203 : auto valueData =
204 : 0.5f
205 1 : * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
206 1 : REQUIRE_UNARY(checkApproxEq(prob.evaluate(),
207 : valueData + weight * 0.5f * x0Vec.squaredNorm()));
208 :
209 : DataContainer gradientDirect =
210 2 : dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0;
211 2 : auto gradient = prob.getGradient();
212 1082 : for (index_t i = 0; i < gradient.getSize(); ++i)
213 1081 : REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
214 :
215 2 : auto hessian = prob.getHessian();
216 1 : auto result = hessian.apply(dcData);
217 1082 : for (index_t i = 0; i < result.getSize(); ++i)
218 1081 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
219 : + weight * dataVec[i]));
220 :
221 1 : REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight));
222 : }
223 : }
224 :
225 6 : WHEN("given a different data descriptor and another regularization term with a different "
226 : "domain descriptor")
227 : {
228 : // three-dimensional data descriptor
229 2 : IndexVector_t otherNumCoeff(3);
230 1 : otherNumCoeff << 15, 38, 22;
231 2 : VolumeDescriptor otherDD(otherNumCoeff);
232 :
233 : // l2 norm regularization term
234 2 : L2NormPow2 otherRegFunc(otherDD);
235 2 : RegularizationTerm otherRegTerm(weight, otherRegFunc);
236 :
237 2 : THEN("no exception is thrown when setting up a problem with different domain "
238 : "descriptors")
239 : {
240 1 : REQUIRE_NOTHROW(Problem{func, otherRegTerm});
241 : }
242 : }
243 : }
244 5 : }
245 :
246 5 : TEST_CASE("Problem: Testing with several regularization terms")
247 : {
248 : // eliminate the timing info from console for the tests
249 5 : Logger::setLevel(Logger::LogLevel::WARN);
250 :
251 10 : GIVEN("some data term and several regularization terms")
252 : {
253 : // least squares data term
254 10 : IndexVector_t numCoeff(3);
255 5 : numCoeff << 17, 33, 52;
256 10 : VolumeDescriptor dd(numCoeff);
257 :
258 10 : RealVector_t scaling(dd.getNumberOfCoefficients());
259 5 : scaling.setRandom();
260 10 : DataContainer dcScaling(dd, scaling);
261 10 : Scaling scaleOp(dd, dcScaling);
262 :
263 10 : RealVector_t dataVec(dd.getNumberOfCoefficients());
264 5 : dataVec.setRandom();
265 10 : DataContainer dcData(dd, dataVec);
266 :
267 10 : LinearResidual linRes(scaleOp, dcData);
268 10 : L2NormPow2 func(linRes);
269 :
270 : // l2 norm regularization term
271 10 : L2NormPow2 regFunc(dd);
272 5 : real_t weight1 = 2.0;
273 10 : RegularizationTerm regTerm1(weight1, regFunc);
274 :
275 5 : real_t weight2 = 3.0;
276 10 : RegularizationTerm regTerm2(weight2, regFunc);
277 :
278 25 : std::vector<RegularizationTerm<real_t>> vecReg{regTerm1, regTerm2};
279 :
280 7 : WHEN("setting up the problem without x0")
281 : {
282 4 : Problem prob(func, vecReg);
283 :
284 3 : THEN("the clone works correctly")
285 : {
286 2 : auto probClone = prob.clone();
287 :
288 1 : REQUIRE_NE(probClone.get(), &prob);
289 1 : REQUIRE_EQ(*probClone, prob);
290 : }
291 :
292 3 : THEN("the problem behaves as expected")
293 : {
294 2 : DataContainer dcZero(dd);
295 1 : dcZero = 0;
296 1 : REQUIRE_UNARY(checkApproxEq(prob.getCurrentSolution(), dcZero));
297 :
298 1 : REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
299 1 : REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
300 :
301 2 : auto hessian = prob.getHessian();
302 1 : auto result = hessian.apply(dcData);
303 29173 : for (index_t i = 0; i < result.getSize(); ++i)
304 29172 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
305 : + weight1 * dataVec[i]
306 : + weight2 * dataVec[i]));
307 :
308 1 : REQUIRE_UNARY(
309 : checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight1 + weight2));
310 : }
311 : }
312 :
313 7 : WHEN("setting up the problem with x0")
314 : {
315 4 : RealVector_t x0Vec(dd.getNumberOfCoefficients());
316 2 : x0Vec.setRandom();
317 4 : DataContainer dcX0(dd, x0Vec);
318 :
319 4 : Problem prob(func, vecReg, dcX0);
320 :
321 3 : THEN("the clone works correctly")
322 : {
323 2 : auto probClone = prob.clone();
324 :
325 1 : REQUIRE_NE(probClone.get(), &prob);
326 1 : REQUIRE_EQ(*probClone, prob);
327 : }
328 :
329 3 : THEN("the problem behaves as expected")
330 : {
331 1 : REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcX0));
332 :
333 : auto valueData =
334 : 0.5f
335 1 : * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
336 1 : REQUIRE_UNARY(
337 : checkApproxEq(prob.evaluate(), valueData + weight1 * 0.5f * x0Vec.squaredNorm()
338 : + weight2 * 0.5f * x0Vec.squaredNorm()));
339 :
340 2 : auto gradient = prob.getGradient();
341 : DataContainer gradientDirect =
342 2 : dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0 + weight2 * dcX0;
343 29173 : for (index_t i = 0; i < gradient.getSize(); ++i)
344 29172 : REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
345 :
346 2 : auto hessian = prob.getHessian();
347 1 : auto result = hessian.apply(dcData);
348 29173 : for (index_t i = 0; i < result.getSize(); ++i)
349 29172 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
350 : + weight1 * dataVec[i]
351 : + weight2 * dataVec[i]));
352 :
353 1 : REQUIRE_UNARY(
354 : checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight1 + weight2));
355 : }
356 : }
357 :
358 6 : WHEN("given two different data descriptors and two regularization terms with a different "
359 : "domain descriptor")
360 : {
361 : // three-dimensional data descriptor
362 2 : IndexVector_t otherNumCoeff(3);
363 1 : otherNumCoeff << 15, 38, 22;
364 2 : VolumeDescriptor otherDD(otherNumCoeff);
365 :
366 : // four-dimensional data descriptor
367 2 : IndexVector_t anotherNumCoeff(4);
368 1 : anotherNumCoeff << 7, 9, 21, 17;
369 2 : VolumeDescriptor anotherDD(anotherNumCoeff);
370 :
371 : // l2 norm regularization term
372 2 : L2NormPow2 otherRegFunc(otherDD);
373 2 : RegularizationTerm otherRegTerm(weight1, otherRegFunc);
374 :
375 : // l2 norm regularization term
376 2 : L2NormPow2 anotherRegFunc(anotherDD);
377 2 : RegularizationTerm anotherRegTerm(weight2, anotherRegFunc);
378 :
379 2 : THEN("no exception is thrown when setting up a problem with different domain "
380 : "descriptors")
381 : {
382 3 : REQUIRE_NOTHROW(Problem{func, std::vector{otherRegTerm, anotherRegTerm}});
383 : }
384 : }
385 : }
386 5 : }
387 :
388 : TEST_SUITE_END();
|