Line data Source code
1 : /** 2 : * @file test_ADMM.cpp 3 : * 4 : * @brief Tests for the ADMM class 5 : * 6 : * @author Andi Braimllari 7 : */ 8 : 9 : #include "doctest/doctest.h" 10 : 11 : #include "ADMM.h" 12 : #include "CG.h" 13 : #include "SoftThresholding.h" 14 : #include "HardThresholding.h" 15 : #include "Identity.h" 16 : #include "FISTA.h" 17 : #include "Logger.h" 18 : #include "VolumeDescriptor.h" 19 : #include <testHelpers.h> 20 : 21 : using namespace elsa; 22 : using namespace doctest; 23 : 24 : TEST_SUITE_BEGIN("solvers"); 25 : 26 10 : TEST_CASE_TEMPLATE("ADMM: Solving problems", data_t, float, double) 27 : { 28 4 : Logger::setLevel(Logger::LogLevel::OFF); 29 : 30 8 : GIVEN("some problems and a constraint") 31 : { 32 8 : IndexVector_t numCoeff(2); 33 4 : numCoeff << 21, 11; 34 8 : VolumeDescriptor volDescr(numCoeff); 35 : 36 8 : Vector_t<data_t> bVec(volDescr.getNumberOfCoefficients()); 37 4 : bVec.setRandom(); 38 8 : DataContainer<data_t> dcB(volDescr, bVec); 39 : 40 8 : Identity<data_t> idOp(volDescr); 41 8 : Scaling<data_t> negativeIdOp(volDescr, -1); 42 8 : DataContainer<data_t> dCC(volDescr); 43 4 : dCC = 0; 44 : 45 8 : WLSProblem<data_t> wlsProb(idOp, dcB); 46 : 47 8 : Constraint<data_t> constraint(idOp, negativeIdOp, dCC); 48 : 49 6 : WHEN("setting up ADMM and FISTA to solve a LASSOProblem") 50 : { 51 4 : L1Norm<data_t> regFunc(volDescr); 52 4 : RegularizationTerm<data_t> regTerm(0.000001f, regFunc); 53 : 54 4 : SplittingProblem<data_t> splittingProblem(wlsProb.getDataTerm(), regTerm, constraint); 55 : 56 4 : ADMM<CG, SoftThresholding, data_t> admm(splittingProblem); 57 : 58 4 : LASSOProblem<data_t> lassoProb(wlsProb, regTerm); 59 4 : FISTA<data_t> fista(lassoProb); 60 : 61 4 : THEN("the solutions match") 62 : { 63 2 : REQUIRE_UNARY(isApprox(admm.solve(100), fista.solve(100))); 64 : } 65 : } 66 : 67 6 : WHEN("setting up ADMM to solve a WLSProblem + L0PseudoNorm") 68 : { 69 4 : L0PseudoNorm<data_t> regFunc(volDescr); 70 4 : RegularizationTerm<data_t> regTerm(0.000001f, regFunc); 71 : 72 4 : SplittingProblem<data_t> splittingProblem(wlsProb.getDataTerm(), regTerm, constraint); 73 : 74 4 : ADMM<CG, HardThresholding, data_t> admm(splittingProblem); 75 : 76 4 : THEN("the solution doesn't throw, is not nan and is approximate to the b vector") 77 : { 78 2 : REQUIRE_NOTHROW(admm.solve(10)); 79 2 : REQUIRE_UNARY(!std::isnan(admm.solve(20).squaredL2Norm())); 80 2 : REQUIRE_UNARY(isApprox(admm.solve(20), dcB)); 81 : } 82 : } 83 : } 84 4 : } 85 : 86 : TEST_SUITE_END();