Line data Source code
1 : #include "WLSProblem.h" 2 : #include "L2NormPow2.h" 3 : #include "WeightedL2NormPow2.h" 4 : #include "RandomBlocksDescriptor.h" 5 : #include "BlockLinearOperator.h" 6 : #include "Identity.h" 7 : #include "TypeCasts.hpp" 8 : 9 : namespace elsa 10 : { 11 : template <typename data_t> 12 0 : WLSProblem<data_t>::WLSProblem(const Scaling<data_t>& W, const LinearOperator<data_t>& A, 13 : const DataContainer<data_t>& b, const DataContainer<data_t>& x0, 14 : const std::optional<data_t> lipschitzConstant) 15 : : Problem<data_t>{WeightedL2NormPow2<data_t>{LinearResidual<data_t>{A, b}, W}, x0, 16 0 : lipschitzConstant} 17 : { 18 : // sanity checks are done in the member constructors already 19 0 : } 20 : 21 : template <typename data_t> 22 0 : WLSProblem<data_t>::WLSProblem(const Scaling<data_t>& W, const LinearOperator<data_t>& A, 23 : const DataContainer<data_t>& b, 24 : const std::optional<data_t> lipschitzConstant) 25 : : Problem<data_t>{WeightedL2NormPow2<data_t>{LinearResidual<data_t>{A, b}, W}, 26 0 : lipschitzConstant} 27 : { 28 : // sanity checks are done in the member constructors already 29 0 : } 30 : 31 : template <typename data_t> 32 0 : WLSProblem<data_t>::WLSProblem(const LinearOperator<data_t>& A, const DataContainer<data_t>& b, 33 : const DataContainer<data_t>& x0, 34 : const std::optional<data_t> lipschitzConstant) 35 0 : : Problem<data_t>{L2NormPow2<data_t>{LinearResidual<data_t>{A, b}}, x0, lipschitzConstant} 36 : { 37 : // sanity checks are done in the member constructors already 38 0 : } 39 : 40 : template <typename data_t> 41 0 : WLSProblem<data_t>::WLSProblem(const LinearOperator<data_t>& A, const DataContainer<data_t>& b, 42 : const std::optional<data_t> lipschitzConstant) 43 0 : : Problem<data_t>{L2NormPow2<data_t>{LinearResidual<data_t>{A, b}}, lipschitzConstant} 44 : { 45 : // sanity checks are done in the member constructors already 46 0 : } 47 : 48 : template <typename data_t> 49 0 : WLSProblem<data_t>::WLSProblem(const Problem<data_t>& problem) 50 0 : : Problem<data_t>{*wlsFromProblem(problem), problem.getCurrentSolution()} 51 : { 52 0 : } 53 : 54 : template <typename data_t> 55 0 : WLSProblem<data_t>* WLSProblem<data_t>::cloneImpl() const 56 : { 57 0 : return new WLSProblem(*this); 58 : } 59 : 60 : template <typename data_t> 61 : std::unique_ptr<Functional<data_t>> 62 0 : WLSProblem<data_t>::wlsFromProblem(const Problem<data_t>& problem) 63 : { 64 0 : const auto& dataTerm = problem.getDataTerm(); 65 0 : const auto& regTerms = problem.getRegularizationTerms(); 66 : 67 0 : if (!is<WeightedL2NormPow2<data_t>>(dataTerm) && !is<L2NormPow2<data_t>>(dataTerm)) 68 0 : throw LogicError("WLSProblem: conversion failed - data term is not " 69 : "of type (Weighted)L2NormPow2"); 70 : 71 : const auto dataTermResidual = 72 0 : downcast_safe<LinearResidual<data_t>>(&dataTerm.getResidual()); 73 : 74 0 : if (!dataTermResidual) 75 0 : throw LogicError("WLSProblem: conversion failed - data term is non-linear"); 76 : 77 : // data term is of convertible type 78 : 79 : // no conversion needed if no regTerms 80 0 : if (regTerms.empty()) 81 0 : return dataTerm.clone(); 82 : 83 : // else regTerms present, determine data vector descriptor 84 0 : std::vector<std::unique_ptr<DataDescriptor>> rangeDescList(0); 85 0 : rangeDescList.push_back(dataTermResidual->getRangeDescriptor().clone()); 86 0 : for (const auto& regTerm : regTerms) { 87 0 : if (!is<WeightedL2NormPow2<data_t>>(regTerm.getFunctional()) 88 0 : && !is<L2NormPow2<data_t>>(regTerm.getFunctional())) 89 0 : throw LogicError("WLSProblem: conversion failed - regularization term is not " 90 : "of type (Weighted)L2NormPow2"); 91 : 92 : // 93 : const auto regTermResidual = 94 0 : downcast_safe<LinearResidual<data_t>>(®Term.getFunctional().getResidual()); 95 : 96 0 : if (!regTermResidual) 97 0 : throw LogicError( 98 : "WLSProblem: conversion failed - regularization term is non-linear"); 99 : 100 0 : rangeDescList.push_back(regTermResidual->getRangeDescriptor().clone()); 101 : } 102 : 103 : // problem is convertible, allocate memory, set to zero and start building block op 104 0 : RandomBlocksDescriptor dataVecDesc{rangeDescList}; 105 0 : DataContainer<data_t> dataVec{dataVecDesc}; 106 0 : dataVec = 0; 107 0 : std::vector<std::unique_ptr<LinearOperator<data_t>>> opList(0); 108 : 109 : // add block corresponding to data term 110 0 : if (is<WeightedL2NormPow2<data_t>>(dataTerm)) { 111 0 : const auto& trueFunc = downcast<WeightedL2NormPow2<data_t>>(dataTerm); 112 0 : const auto& scaling = trueFunc.getWeightingOperator(); 113 0 : const auto& desc = scaling.getDomainDescriptor(); 114 : 115 0 : std::unique_ptr<Scaling<data_t>> sqrtW{}; 116 0 : if (scaling.isIsotropic()) { 117 0 : auto fac = scaling.getScaleFactor(); 118 : 119 : if constexpr (std::is_floating_point_v<data_t>) { 120 0 : if (fac < 0) { 121 0 : throw LogicError("WLSProblem: conversion failed - negative weighting " 122 : "factor in WeightedL2NormPow2 term"); 123 : } 124 : } 125 : 126 0 : sqrtW = std::make_unique<Scaling<data_t>>(desc, sqrt(fac)); 127 : } else { 128 0 : const auto& fac = scaling.getScaleFactors(); 129 : 130 : if constexpr (std::is_floating_point_v<data_t>) { 131 0 : for (const auto& w : fac) { 132 0 : if (w < 0) { 133 0 : throw LogicError("WLSProblem: conversion failed - negative weighting " 134 : "factor in WeightedL2NormPow2 term"); 135 : } 136 : } 137 : } 138 : 139 0 : sqrtW = std::make_unique<Scaling<data_t>>(desc, sqrt(fac)); 140 : } 141 : 142 0 : if (dataTermResidual->hasDataVector()) 143 0 : dataVec.getBlock(0) = sqrtW->apply(dataTermResidual->getDataVector()); 144 : 145 0 : if (dataTermResidual->hasOperator()) { 146 0 : const auto composite = *sqrtW * dataTermResidual->getOperator(); 147 0 : opList.emplace_back(composite.clone()); 148 0 : } else { 149 0 : opList.push_back(std::move(sqrtW)); 150 : } 151 : 152 0 : } else { 153 0 : if (dataTermResidual->hasOperator()) { 154 0 : opList.push_back(dataTermResidual->getOperator().clone()); 155 : } else { 156 0 : opList.push_back( 157 : std::make_unique<Identity<data_t>>(dataTermResidual->getDomainDescriptor())); 158 : } 159 : 160 0 : if (dataTermResidual->hasDataVector()) 161 0 : dataVec.getBlock(0) = dataTermResidual->getDataVector(); 162 : } 163 : 164 : // add blocks corresponding to reg terms 165 0 : index_t blockNum = 1; 166 0 : for (const auto& regTerm : regTerms) { 167 0 : const data_t lambda = regTerm.getWeight(); 168 0 : const auto& func = regTerm.getFunctional(); 169 0 : const auto residual = static_cast<const LinearResidual<data_t>*>(&func.getResidual()); 170 : 171 0 : if (is<WeightedL2NormPow2<data_t>>(func)) { 172 0 : const auto& trueFunc = downcast<WeightedL2NormPow2<data_t>>(func); 173 0 : const auto& scaling = trueFunc.getWeightingOperator(); 174 0 : const auto& desc = scaling.getDomainDescriptor(); 175 : 176 0 : std::unique_ptr<Scaling<data_t>> sqrtLambdaW{}; 177 0 : if (scaling.isIsotropic()) { 178 0 : auto fac = scaling.getScaleFactor(); 179 : 180 : if constexpr (std::is_floating_point_v<data_t>) { 181 0 : if (lambda * fac < 0) { 182 0 : throw LogicError("WLSProblem: conversion failed - negative weighting " 183 : "factor in WeightedL2NormPow2 term"); 184 : } 185 : } 186 : 187 0 : sqrtLambdaW = std::make_unique<Scaling<data_t>>(desc, sqrt(lambda * fac)); 188 : } else { 189 0 : const auto& fac = scaling.getScaleFactors(); 190 : 191 : if constexpr (std::is_floating_point_v<data_t>) { 192 0 : for (const auto& w : fac) { 193 0 : if (lambda * w < 0) { 194 0 : throw LogicError( 195 : "WLSProblem: conversion failed - negative weighting " 196 : "factor in WeightedL2NormPow2 term"); 197 : } 198 : } 199 : } 200 : 201 0 : sqrtLambdaW = std::make_unique<Scaling<data_t>>(desc, sqrt(lambda * fac)); 202 : } 203 : 204 0 : if (residual->hasDataVector()) 205 0 : dataVec.getBlock(blockNum) = sqrtLambdaW->apply(residual->getDataVector()); 206 : 207 0 : if (residual->hasOperator()) { 208 0 : const auto composite = *sqrtLambdaW * residual->getOperator(); 209 0 : opList.push_back(composite.clone()); 210 0 : } else { 211 0 : opList.push_back(std::move(sqrtLambdaW)); 212 : } 213 : 214 0 : } else { 215 : if constexpr (std::is_floating_point_v<data_t>) { 216 0 : if (lambda < 0) { 217 0 : throw LogicError( 218 : "WLSProblem: conversion failed - negative regularization term weight"); 219 : } 220 : } 221 : 222 0 : auto sqrtLambdaScaling = 223 0 : std::make_unique<Scaling<data_t>>(residual->getRangeDescriptor(), sqrt(lambda)); 224 : 225 0 : if (residual->hasDataVector()) 226 0 : dataVec.getBlock(blockNum) = sqrt(lambda) * residual->getDataVector(); 227 : 228 0 : if (residual->hasOperator()) { 229 0 : const auto composite = *sqrtLambdaScaling * residual->getOperator(); 230 0 : opList.emplace_back(composite.clone()); 231 0 : } else { 232 0 : opList.push_back(std::move(sqrtLambdaScaling)); 233 : } 234 0 : } 235 : 236 0 : blockNum++; 237 : } 238 : 239 0 : BlockLinearOperator<data_t> blockOp{opList, BlockLinearOperator<data_t>::BlockType::ROW}; 240 : 241 0 : return std::make_unique<L2NormPow2<data_t>>(LinearResidual<data_t>{blockOp, dataVec}); 242 0 : } 243 : 244 : // ------------------------------------------ 245 : // explicit template instantiation 246 : template class WLSProblem<float>; 247 : template class WLSProblem<double>; 248 : template class WLSProblem<complex<float>>; 249 : template class WLSProblem<complex<double>>; 250 : 251 : } // namespace elsa