Line data Source code
1 : /**
2 : * @file test_TikhonovProblem.cpp
3 : *
4 : * @brief Tests for the TikhonovProblem class
5 : *
6 : * @author Nikola Dinev
7 : */
8 :
9 : #include "doctest/doctest.h"
10 :
11 : #include "Problem.h"
12 : #include "Identity.h"
13 : #include "Scaling.h"
14 : #include "LinearResidual.h"
15 : #include "L2NormPow2.h"
16 : #include "L1Norm.h"
17 : #include "TikhonovProblem.h"
18 : #include "VolumeDescriptor.h"
19 : #include "Logger.h"
20 : #include "testHelpers.h"
21 :
22 : using namespace elsa;
23 : using namespace doctest;
24 :
25 : TEST_SUITE_BEGIN("problems");
26 :
27 : TEST_CASE_TEMPLATE("TikhonovProblem: Testing with one regularization term", data_t, float, double)
28 16 : {
29 16 : Logger::setLevel(Logger::LogLevel::WARN);
30 :
31 16 : GIVEN("some data term and some regularization term")
32 16 : {
33 : // least squares data term
34 16 : IndexVector_t numCoeff(2);
35 16 : numCoeff << 23, 47;
36 16 : VolumeDescriptor dd(numCoeff);
37 :
38 16 : Vector_t<data_t> scaling(dd.getNumberOfCoefficients());
39 16 : scaling.setRandom();
40 16 : DataContainer<data_t> dcScaling(dd, scaling);
41 16 : Scaling scaleOp(dd, dcScaling);
42 :
43 16 : Vector_t<data_t> dataVec(dd.getNumberOfCoefficients());
44 16 : dataVec.setRandom();
45 16 : DataContainer<data_t> dcData(dd, dataVec);
46 :
47 16 : WLSProblem<data_t> wls(scaleOp, dcData);
48 :
49 : // l2 norm regularization term
50 16 : L2NormPow2<data_t> regFunc(dd);
51 16 : auto weight = data_t{2.0};
52 16 : RegularizationTerm<data_t> regTerm(weight, regFunc);
53 :
54 16 : WHEN("setting up a TikhonovProblem without regularization terms")
55 16 : {
56 2 : THEN("an exception is thrown")
57 2 : {
58 2 : REQUIRE_THROWS_AS(
59 2 : TikhonovProblem<data_t>(wls, std::vector<RegularizationTerm<data_t>>{}),
60 2 : InvalidArgumentError);
61 2 : }
62 2 : }
63 :
64 16 : WHEN("setting up a TikhonovProblem with a non (Weighted)L2NormPow2 regularization term")
65 16 : {
66 2 : L1Norm<data_t> invalidRegFunc(dd);
67 2 : RegularizationTerm<data_t> invalidRegTerm(1.0, invalidRegFunc);
68 2 : THEN("an exception is thrown")
69 2 : {
70 2 : REQUIRE_THROWS_AS(TikhonovProblem<data_t>(wls, invalidRegTerm),
71 2 : InvalidArgumentError);
72 2 : }
73 2 : }
74 :
75 16 : WHEN("setting up the TikhonovProblem without x0")
76 16 : {
77 6 : TikhonovProblem<data_t> prob(wls, regTerm);
78 :
79 6 : THEN("the clone works correctly")
80 6 : {
81 2 : auto probClone = prob.clone();
82 :
83 2 : REQUIRE_NE(probClone.get(), &prob);
84 2 : REQUIRE_EQ(*probClone, prob);
85 2 : }
86 :
87 6 : THEN("the TikhonovProblem behaves as expected")
88 6 : {
89 2 : DataContainer<data_t> dcZero(dd);
90 2 : dcZero = 0;
91 2 : REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcZero));
92 :
93 2 : REQUIRE_UNARY(
94 2 : checkApproxEq(prob.evaluate(), as<data_t>(0.5) * dataVec.squaredNorm()));
95 2 : REQUIRE_UNARY(isApprox(prob.getGradient(), as<data_t>(-1.0) * dcScaling * dcData));
96 :
97 2 : auto hessian = prob.getHessian();
98 2 : auto result = hessian.apply(dcData);
99 2164 : for (index_t i = 0; i < result.getSize(); ++i)
100 2 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
101 2 : + weight * dataVec[i]));
102 2 : }
103 :
104 6 : THEN("the TikhonovProblem is different from a Problem with the same terms")
105 6 : {
106 2 : Problem optProb(prob.getDataTerm(), prob.getRegularizationTerms());
107 2 : REQUIRE_NE(prob, optProb);
108 2 : REQUIRE_NE(optProb, prob);
109 2 : }
110 6 : }
111 :
112 16 : WHEN("setting up the TikhonovProblem with x0")
113 16 : {
114 6 : Vector_t<data_t> x0Vec(dd.getNumberOfCoefficients());
115 6 : x0Vec.setRandom();
116 6 : DataContainer<data_t> dcX0(dd, x0Vec);
117 :
118 6 : wls.getCurrentSolution() = dcX0;
119 6 : TikhonovProblem<data_t> prob(wls, regTerm);
120 :
121 6 : THEN("the clone works correctly")
122 6 : {
123 2 : auto probClone = prob.clone();
124 :
125 2 : REQUIRE_NE(probClone.get(), &prob);
126 2 : REQUIRE_EQ(*probClone, prob);
127 2 : }
128 :
129 6 : THEN("the TikhonovProblem behaves as expected")
130 6 : {
131 2 : REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcX0));
132 :
133 2 : auto valueData =
134 2 : as<data_t>(0.5)
135 2 : * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
136 2 : REQUIRE_UNARY(checkApproxEq(
137 2 : prob.evaluate(), valueData + weight * as<data_t>(0.5) * x0Vec.squaredNorm()));
138 :
139 2 : auto gradient = prob.getGradient();
140 2 : DataContainer gradientDirect =
141 2 : dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0;
142 :
143 2164 : for (index_t i = 0; i < gradient.getSize(); ++i)
144 2 : REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
145 :
146 2 : auto hessian = prob.getHessian();
147 2 : auto result = hessian.apply(dcData);
148 2164 : for (index_t i = 0; i < result.getSize(); ++i)
149 2 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
150 2 : + weight * dataVec[i]));
151 2 : }
152 :
153 6 : THEN("the TikhonovProblem is different from a Problem with the same terms")
154 6 : {
155 2 : Problem optProb(prob.getDataTerm(), prob.getRegularizationTerms());
156 2 : REQUIRE_NE(prob, optProb);
157 2 : REQUIRE_NE(optProb, prob);
158 2 : }
159 6 : }
160 16 : }
161 16 : }
162 :
163 : TEST_CASE_TEMPLATE("TikhonovProblem: Testing with several regularization terms", data_t, float,
164 : double)
165 14 : {
166 14 : Logger::setLevel(Logger::LogLevel::WARN);
167 :
168 14 : GIVEN("some data term and several regularization terms")
169 14 : {
170 : // least squares data term
171 14 : IndexVector_t numCoeff(3);
172 14 : numCoeff << 17, 33, 52;
173 14 : VolumeDescriptor dd(numCoeff);
174 :
175 14 : Vector_t<data_t> scaling(dd.getNumberOfCoefficients());
176 14 : scaling.setRandom();
177 14 : DataContainer<data_t> dcScaling(dd, scaling);
178 14 : Scaling scaleOp(dd, dcScaling);
179 :
180 14 : Vector_t<data_t> dataVec(dd.getNumberOfCoefficients());
181 14 : dataVec.setRandom();
182 14 : DataContainer<data_t> dcData(dd, dataVec);
183 :
184 14 : WLSProblem<data_t> wls(scaleOp, dcData);
185 :
186 : // l2 norm regularization term
187 14 : L2NormPow2<data_t> regFunc(dd);
188 14 : auto weight1 = data_t{2.0};
189 14 : RegularizationTerm<data_t> regTerm1(weight1, regFunc);
190 :
191 14 : auto weight2 = data_t{3.0};
192 14 : RegularizationTerm<data_t> regTerm2(weight2, regFunc);
193 :
194 14 : std::vector<RegularizationTerm<data_t>> vecReg{regTerm1, regTerm2};
195 :
196 14 : WHEN("setting up a TikhonovProblem with a non (Weighted)L2NormPow2 regularization term")
197 14 : {
198 2 : L1Norm<data_t> invalidRegFunc(dd);
199 2 : RegularizationTerm<data_t> invalidRegTerm(1.0, invalidRegFunc);
200 2 : std::vector<RegularizationTerm<data_t>> invalidVecReg1{regTerm1, invalidRegTerm};
201 2 : std::vector<RegularizationTerm<data_t>> invalidVecReg2{invalidRegTerm, regTerm2};
202 2 : THEN("an exception is thrown")
203 2 : {
204 2 : REQUIRE_THROWS_AS(TikhonovProblem<data_t>(wls, invalidVecReg1),
205 2 : InvalidArgumentError);
206 2 : REQUIRE_THROWS_AS(TikhonovProblem<data_t>(wls, invalidVecReg2),
207 2 : InvalidArgumentError);
208 2 : }
209 2 : }
210 :
211 14 : WHEN("setting up the TikhonovProblem without x0")
212 14 : {
213 6 : TikhonovProblem<data_t> prob(wls, vecReg);
214 :
215 6 : THEN("the clone works correctly")
216 6 : {
217 2 : auto probClone = prob.clone();
218 :
219 2 : REQUIRE_NE(probClone.get(), &prob);
220 2 : REQUIRE_EQ(*probClone, prob);
221 2 : }
222 :
223 6 : THEN("the TikhonovProblem behaves as expected")
224 6 : {
225 2 : DataContainer<data_t> dcZero(dd);
226 2 : dcZero = 0;
227 2 : REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcZero));
228 :
229 2 : REQUIRE_UNARY(
230 2 : checkApproxEq(prob.evaluate(), as<data_t>(0.5) * dataVec.squaredNorm()));
231 2 : REQUIRE_UNARY(isApprox(prob.getGradient(), as<data_t>(-1.0) * dcScaling * dcData));
232 :
233 2 : auto hessian = prob.getHessian();
234 2 : auto result = hessian.apply(dcData);
235 58346 : for (index_t i = 0; i < result.getSize(); ++i)
236 2 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
237 2 : + weight1 * dataVec[i]
238 2 : + weight2 * dataVec[i]));
239 2 : }
240 :
241 6 : THEN("the TikhonovProblem is different from a Problem with the same terms")
242 6 : {
243 2 : Problem optProb(prob.getDataTerm(), prob.getRegularizationTerms());
244 2 : REQUIRE_NE(prob, optProb);
245 2 : REQUIRE_NE(optProb, prob);
246 2 : }
247 6 : }
248 :
249 14 : WHEN("setting up the TikhonovProblem with x0")
250 14 : {
251 6 : Vector_t<data_t> x0Vec(dd.getNumberOfCoefficients());
252 6 : x0Vec.setRandom();
253 6 : DataContainer<data_t> dcX0(dd, x0Vec);
254 :
255 6 : wls.getCurrentSolution() = dcX0;
256 6 : TikhonovProblem<data_t> prob(wls, vecReg);
257 :
258 6 : THEN("the clone works correctly")
259 6 : {
260 2 : auto probClone = prob.clone();
261 :
262 2 : REQUIRE_NE(probClone.get(), &prob);
263 2 : REQUIRE_EQ(*probClone, prob);
264 2 : }
265 :
266 6 : THEN("the TikhonovProblem behaves as expected")
267 6 : {
268 2 : REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcX0));
269 :
270 2 : auto valueData =
271 2 : as<data_t>(0.5)
272 2 : * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
273 2 : REQUIRE_UNARY(checkApproxEq(
274 2 : prob.evaluate(), valueData + weight1 * as<data_t>(0.5) * x0Vec.squaredNorm()
275 2 : + weight2 * as<data_t>(0.5) * x0Vec.squaredNorm()));
276 :
277 2 : auto gradient = prob.getGradient();
278 2 : DataContainer gradientDirect =
279 2 : dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0 + weight2 * dcX0;
280 58346 : for (index_t i = 0; i < gradient.getSize(); ++i)
281 2 : REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
282 :
283 2 : auto hessian = prob.getHessian();
284 2 : auto result = hessian.apply(dcData);
285 58346 : for (index_t i = 0; i < result.getSize(); ++i)
286 2 : REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
287 2 : + weight1 * dataVec[i]
288 2 : + weight2 * dataVec[i]));
289 2 : }
290 :
291 6 : THEN("the TikhonovProblem is different from a Problem with the same terms")
292 6 : {
293 2 : Problem optProb(prob.getDataTerm(), prob.getRegularizationTerms());
294 2 : REQUIRE_NE(prob, optProb);
295 2 : REQUIRE_NE(optProb, prob);
296 2 : }
297 6 : }
298 14 : }
299 14 : }
300 :
301 : TEST_SUITE_END();
|