Line data Source code
1 : #include "WLSProblem.h" 2 : #include "L2NormPow2.h" 3 : #include "WeightedL2NormPow2.h" 4 : #include "RandomBlocksDescriptor.h" 5 : #include "BlockLinearOperator.h" 6 : #include "Identity.h" 7 : #include "TypeCasts.hpp" 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 : WLSProblem<data_t>::WLSProblem(const Scaling<data_t>& W, const LinearOperator<data_t>& A, 13 : const DataContainer<data_t>& b, const DataContainer<data_t>& x0, 14 : const std::optional<data_t> lipschitzConstant) 15 : : Problem<data_t>{WeightedL2NormPow2<data_t>{LinearResidual<data_t>{A, b}, W}, x0, 16 : lipschitzConstant} 17 4 : { 18 : // sanity checks are done in the member constructors already 19 4 : } 20 : 21 : template <typename data_t> 22 : WLSProblem<data_t>::WLSProblem(const Scaling<data_t>& W, const LinearOperator<data_t>& A, 23 : const DataContainer<data_t>& b, 24 : const std::optional<data_t> lipschitzConstant) 25 : : Problem<data_t>{WeightedL2NormPow2<data_t>{LinearResidual<data_t>{A, b}, W}, 26 : lipschitzConstant} 27 4 : { 28 : // sanity checks are done in the member constructors already 29 4 : } 30 : 31 : template <typename data_t> 32 : WLSProblem<data_t>::WLSProblem(const LinearOperator<data_t>& A, const DataContainer<data_t>& b, 33 : const DataContainer<data_t>& x0, 34 : const std::optional<data_t> lipschitzConstant) 35 : : Problem<data_t>{L2NormPow2<data_t>{LinearResidual<data_t>{A, b}}, x0, lipschitzConstant} 36 4 : { 37 : // sanity checks are done in the member constructors already 38 4 : } 39 : 40 : template <typename data_t> 41 : WLSProblem<data_t>::WLSProblem(const LinearOperator<data_t>& A, const DataContainer<data_t>& b, 42 : const std::optional<data_t> lipschitzConstant) 43 : : Problem<data_t>{L2NormPow2<data_t>{LinearResidual<data_t>{A, b}}, lipschitzConstant} 44 134 : { 45 : // sanity checks are done in the member constructors already 46 134 : } 47 : 48 : template <typename data_t> 49 : WLSProblem<data_t>::WLSProblem(const Problem<data_t>& problem) 50 : : Problem<data_t>{*wlsFromProblem(problem), problem.getCurrentSolution()} 51 26 : { 52 26 : } 53 : 54 : template <typename data_t> 55 : WLSProblem<data_t>* WLSProblem<data_t>::cloneImpl() const 56 94 : { 57 94 : return new WLSProblem(*this); 58 94 : } 59 : 60 : template <typename data_t> 61 : std::unique_ptr<Functional<data_t>> 62 : WLSProblem<data_t>::wlsFromProblem(const Problem<data_t>& problem) 63 26 : { 64 26 : const auto& dataTerm = problem.getDataTerm(); 65 26 : const auto& regTerms = problem.getRegularizationTerms(); 66 : 67 26 : if (!is<WeightedL2NormPow2<data_t>>(dataTerm) && !is<L2NormPow2<data_t>>(dataTerm)) 68 2 : throw LogicError("WLSProblem: conversion failed - data term is not " 69 2 : "of type (Weighted)L2NormPow2"); 70 : 71 24 : const auto dataTermResidual = 72 24 : downcast_safe<LinearResidual<data_t>>(&dataTerm.getResidual()); 73 : 74 24 : if (!dataTermResidual) 75 0 : throw LogicError("WLSProblem: conversion failed - data term is non-linear"); 76 : 77 : // data term is of convertible type 78 : 79 : // no conversion needed if no regTerms 80 24 : if (regTerms.empty()) 81 8 : return dataTerm.clone(); 82 : 83 : // else regTerms present, determine data vector descriptor 84 16 : std::vector<std::unique_ptr<DataDescriptor>> rangeDescList(0); 85 16 : rangeDescList.push_back(dataTermResidual->getRangeDescriptor().clone()); 86 16 : for (const auto& regTerm : regTerms) { 87 16 : if (!is<WeightedL2NormPow2<data_t>>(regTerm.getFunctional()) 88 16 : && !is<L2NormPow2<data_t>>(regTerm.getFunctional())) 89 2 : throw LogicError("WLSProblem: conversion failed - regularization term is not " 90 2 : "of type (Weighted)L2NormPow2"); 91 : 92 : // 93 14 : const auto regTermResidual = 94 14 : downcast_safe<LinearResidual<data_t>>(®Term.getFunctional().getResidual()); 95 : 96 14 : if (!regTermResidual) 97 0 : throw LogicError( 98 0 : "WLSProblem: conversion failed - regularization term is non-linear"); 99 : 100 14 : rangeDescList.push_back(regTermResidual->getRangeDescriptor().clone()); 101 14 : } 102 : 103 : // problem is convertible, allocate memory, set to zero and start building block op 104 16 : RandomBlocksDescriptor dataVecDesc{rangeDescList}; 105 0 : DataContainer<data_t> dataVec{dataVecDesc}; 106 0 : dataVec = 0; 107 0 : std::vector<std::unique_ptr<LinearOperator<data_t>>> opList(0); 108 : 109 : // add block corresponding to data term 110 14 : if (is<WeightedL2NormPow2<data_t>>(dataTerm)) { 111 4 : const auto& trueFunc = downcast<WeightedL2NormPow2<data_t>>(dataTerm); 112 4 : const auto& scaling = trueFunc.getWeightingOperator(); 113 4 : const auto& desc = scaling.getDomainDescriptor(); 114 : 115 4 : std::unique_ptr<Scaling<data_t>> sqrtW{}; 116 4 : if (scaling.isIsotropic()) { 117 2 : auto fac = scaling.getScaleFactor(); 118 : 119 2 : if constexpr (std::is_floating_point_v<data_t>) { 120 2 : if (fac < 0) { 121 2 : throw LogicError("WLSProblem: conversion failed - negative weighting " 122 2 : "factor in WeightedL2NormPow2 term"); 123 2 : } 124 0 : } 125 : 126 0 : sqrtW = std::make_unique<Scaling<data_t>>(desc, sqrt(fac)); 127 2 : } else { 128 2 : const auto& fac = scaling.getScaleFactors(); 129 : 130 2 : if constexpr (std::is_floating_point_v<data_t>) { 131 514 : for (const auto& w : fac) { 132 514 : if (w < 0) { 133 2 : throw LogicError("WLSProblem: conversion failed - negative weighting " 134 2 : "factor in WeightedL2NormPow2 term"); 135 2 : } 136 514 : } 137 2 : } 138 : 139 2 : sqrtW = std::make_unique<Scaling<data_t>>(desc, sqrt(fac)); 140 0 : } 141 : 142 4 : if (dataTermResidual->hasDataVector()) 143 0 : dataVec.getBlock(0) = sqrtW->apply(dataTermResidual->getDataVector()); 144 : 145 0 : if (dataTermResidual->hasOperator()) { 146 0 : const auto composite = *sqrtW * dataTermResidual->getOperator(); 147 0 : opList.emplace_back(composite.clone()); 148 0 : } else { 149 0 : opList.push_back(std::move(sqrtW)); 150 0 : } 151 : 152 10 : } else { 153 10 : if (dataTermResidual->hasOperator()) { 154 2 : opList.push_back(dataTermResidual->getOperator().clone()); 155 8 : } else { 156 8 : opList.push_back( 157 8 : std::make_unique<Identity<data_t>>(dataTermResidual->getDomainDescriptor())); 158 8 : } 159 : 160 10 : if (dataTermResidual->hasDataVector()) 161 2 : dataVec.getBlock(0) = dataTermResidual->getDataVector(); 162 10 : } 163 : 164 : // add blocks corresponding to reg terms 165 10 : index_t blockNum = 1; 166 10 : for (const auto& regTerm : regTerms) { 167 10 : const data_t lambda = regTerm.getWeight(); 168 10 : const auto& func = regTerm.getFunctional(); 169 10 : const auto residual = static_cast<const LinearResidual<data_t>*>(&func.getResidual()); 170 : 171 10 : if (is<WeightedL2NormPow2<data_t>>(func)) { 172 4 : const auto& trueFunc = downcast<WeightedL2NormPow2<data_t>>(func); 173 4 : const auto& scaling = trueFunc.getWeightingOperator(); 174 4 : const auto& desc = scaling.getDomainDescriptor(); 175 : 176 4 : std::unique_ptr<Scaling<data_t>> sqrtLambdaW{}; 177 4 : if (scaling.isIsotropic()) { 178 2 : auto fac = scaling.getScaleFactor(); 179 : 180 2 : if constexpr (std::is_floating_point_v<data_t>) { 181 2 : if (lambda * fac < 0) { 182 2 : throw LogicError("WLSProblem: conversion failed - negative weighting " 183 2 : "factor in WeightedL2NormPow2 term"); 184 2 : } 185 0 : } 186 : 187 0 : sqrtLambdaW = std::make_unique<Scaling<data_t>>(desc, sqrt(lambda * fac)); 188 2 : } else { 189 2 : const auto& fac = scaling.getScaleFactors(); 190 : 191 2 : if constexpr (std::is_floating_point_v<data_t>) { 192 514 : for (const auto& w : fac) { 193 514 : if (lambda * w < 0) { 194 2 : throw LogicError( 195 2 : "WLSProblem: conversion failed - negative weighting " 196 2 : "factor in WeightedL2NormPow2 term"); 197 2 : } 198 514 : } 199 2 : } 200 : 201 2 : sqrtLambdaW = std::make_unique<Scaling<data_t>>(desc, sqrt(lambda * fac)); 202 0 : } 203 : 204 4 : if (residual->hasDataVector()) 205 0 : dataVec.getBlock(blockNum) = sqrtLambdaW->apply(residual->getDataVector()); 206 : 207 0 : if (residual->hasOperator()) { 208 0 : const auto composite = *sqrtLambdaW * residual->getOperator(); 209 0 : opList.push_back(composite.clone()); 210 0 : } else { 211 0 : opList.push_back(std::move(sqrtLambdaW)); 212 0 : } 213 : 214 6 : } else { 215 6 : if constexpr (std::is_floating_point_v<data_t>) { 216 6 : if (lambda < 0) { 217 2 : throw LogicError( 218 2 : "WLSProblem: conversion failed - negative regularization term weight"); 219 2 : } 220 4 : } 221 : 222 4 : auto sqrtLambdaScaling = 223 4 : std::make_unique<Scaling<data_t>>(residual->getRangeDescriptor(), sqrt(lambda)); 224 : 225 4 : if (residual->hasDataVector()) 226 0 : dataVec.getBlock(blockNum) = sqrt(lambda) * residual->getDataVector(); 227 : 228 4 : if (residual->hasOperator()) { 229 0 : const auto composite = *sqrtLambdaScaling * residual->getOperator(); 230 0 : opList.emplace_back(composite.clone()); 231 4 : } else { 232 4 : opList.push_back(std::move(sqrtLambdaScaling)); 233 4 : } 234 4 : } 235 : 236 10 : blockNum++; 237 4 : } 238 : 239 10 : BlockLinearOperator<data_t> blockOp{opList, BlockLinearOperator<data_t>::BlockType::ROW}; 240 : 241 4 : return std::make_unique<L2NormPow2<data_t>>(LinearResidual<data_t>{blockOp, dataVec}); 242 10 : } 243 : 244 : // ------------------------------------------ 245 : // explicit template instantiation 246 : template class WLSProblem<float>; 247 : template class WLSProblem<double>; 248 : template class WLSProblem<complex<float>>; 249 : template class WLSProblem<complex<double>>; 250 : 251 : } // namespace elsa