Line data Source code
1 : #include "JointRicianLikelihood.h" 2 : #include "DataContainer.h" 3 : #include "TypeCasts.hpp" 4 : #include "LinearOperator.h" 5 : #include "BlockLinearOperator.h" 6 : #include "Timer.h" 7 : #include "Scaling.h" 8 : 9 : namespace elsa 10 : { 11 : 12 : /// construct a RandomBlockDescriptor to wrap up the absorption projection data 13 : /// and the axdt projection data as a single input for this functional 14 : RandomBlocksDescriptor mergeDescriptors(const DataDescriptor& desc1, 15 : const DataDescriptor& desc2) 16 24 : { 17 24 : std::vector<std::unique_ptr<DataDescriptor>> descs; 18 : 19 24 : descs.emplace_back(desc1.clone()); 20 24 : descs.emplace_back(desc2.clone()); 21 : 22 24 : return RandomBlocksDescriptor(descs); 23 24 : } 24 : 25 : template <typename data_t> 26 : RicianLoss<data_t>::RicianLoss(const DataContainer<data_t>& ffa, 27 : const DataContainer<data_t>& ffb, const DataContainer<data_t>& a, 28 : const DataContainer<data_t>& b, 29 : const LinearOperator<data_t>& absorp_op, 30 : const LinearOperator<data_t>& axdt_op, index_t N, 31 : bool approximate) 32 : : Functional<data_t>{mergeDescriptors(absorp_op.getDomainDescriptor(), 33 : axdt_op.getDomainDescriptor())}, 34 : ffa_(ffa), 35 : ffb_(ffb), 36 : a_tilde_(a), 37 : b_tilde_(b), 38 : absorp_op_(absorp_op.clone()), 39 : axdt_op_(axdt_op.clone()), 40 : alpha_(ffb / ffa), 41 : log_ffa_(log(ffa)), 42 : N_(as<data_t>(N)), 43 : approximate_(approximate) 44 24 : { 45 24 : } 46 : 47 : template <typename data_t> 48 : bool RicianLoss<data_t>::isDifferentiable() const 49 0 : { 50 0 : return true; 51 0 : } 52 : 53 : template <typename data_t> 54 : data_t RicianLoss<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 55 4 : { 56 4 : Timer timeguard("AXDTStatRecon", "evaluate"); 57 : 58 4 : const auto mu = materialize(Rx.getBlock(0)); 59 4 : const auto eta = materialize(Rx.getBlock(1)); 60 : 61 4 : const auto log_d = -axdt_op_->apply(eta); 62 4 : const auto d = exp(log_d); 63 : 64 4 : auto log_a = -absorp_op_->apply(mu) + log_ffa_; 65 4 : auto a = exp(log_a); 66 : 67 4 : auto alpha_d = alpha_ * d; 68 : 69 4 : if (approximate_) { 70 2 : auto reorder = sq(a_tilde_ - a); 71 2 : reorder *= 2.0f; 72 2 : reorder += sq(b_tilde_ - (a * alpha_d)); 73 2 : reorder /= a; 74 : 75 2 : return log_a.sum() + (N_ / 4.0f) * reorder.sum(); 76 2 : } else { 77 : 78 2 : auto reorder = sq(a_tilde_); 79 2 : reorder *= 2.0f; 80 2 : reorder += sq(b_tilde_); 81 2 : reorder /= a; 82 : 83 2 : return 1.5f * log_a.sum() + (N_ * 0.5f) * a.sum() 84 2 : + (N_ / 4.0f) * (reorder.sum() + a.dot(sq(alpha_d))) 85 2 : - bessel_log_0((N_ / 2) * b_tilde_ * alpha_d).sum(); 86 2 : } 87 4 : } 88 : 89 : template <typename data_t> 90 : void RicianLoss<data_t>::getGradientImpl(const DataContainer<data_t>& Rx, 91 : DataContainer<data_t>& out) const 92 4 : { 93 4 : Timer timeguard("AXDTStatRecon", "getGradient"); 94 : 95 4 : const auto mu = materialize(Rx.getBlock(0)); 96 4 : const auto eta = materialize(Rx.getBlock(1)); 97 : 98 4 : const auto log_d = -axdt_op_->apply(eta); 99 4 : const auto d = exp(log_d); 100 : 101 4 : const auto log_a = -absorp_op_->apply(mu) + log_ffa_; 102 4 : const auto a = exp(log_a); 103 : // Some temporaries, potential improvement? First measure! 104 4 : const auto sqa = sq(a); 105 4 : const auto sqa_tilde = sq(a_tilde_); 106 : 107 4 : const auto alpha_d = alpha_ * d; 108 : 109 4 : auto tmp_a_alpha_d = a * alpha_d; // Might reuse this later in approximation codepath 110 : 111 4 : auto tmp_mu = 2.0f * (sqa - sqa_tilde); 112 4 : tmp_mu += sq(tmp_a_alpha_d); 113 4 : tmp_mu -= sq(b_tilde_); 114 4 : tmp_mu /= a; 115 4 : tmp_mu *= -N_ / 4.0f; 116 4 : tmp_mu -= approximate_ ? 1.0f : 1.5f; // only part that differs 117 4 : const auto grad_mu = absorp_op_->applyAdjoint(tmp_mu); 118 : 119 4 : const auto grad_eta = [&]() { 120 4 : if (approximate_) { 121 2 : tmp_a_alpha_d -= b_tilde_; 122 2 : tmp_a_alpha_d *= alpha_d; 123 2 : tmp_a_alpha_d *= -N_ * 0.5f; 124 : 125 2 : return axdt_op_->applyAdjoint(tmp_a_alpha_d); 126 2 : } else { 127 2 : auto tmp = a * sq(alpha_d); 128 2 : tmp *= -N_ * 0.5f; 129 2 : auto grad_eta_tmp_bessel = b_tilde_ * alpha_d; 130 2 : grad_eta_tmp_bessel *= N_ * 0.5f; 131 2 : tmp += grad_eta_tmp_bessel * bessel_1_0(grad_eta_tmp_bessel); 132 : 133 2 : return axdt_op_->applyAdjoint(tmp); 134 2 : } 135 4 : }(); 136 : 137 4 : out.getBlock(0) = grad_mu; 138 4 : out.getBlock(1) = grad_eta; 139 4 : } 140 : 141 : template <typename data_t> 142 : LinearOperator<data_t> RicianLoss<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const 143 4 : { 144 4 : Timer timeguard("AXDTStatRecon", "getHessian"); 145 : 146 4 : const auto mu = materialize(Rx.getBlock(0)); 147 4 : const auto eta = materialize(Rx.getBlock(1)); 148 : 149 4 : using BlockType = typename BlockLinearOperator<data_t>::BlockType; 150 : 151 4 : const auto d = exp(-axdt_op_->apply(eta)); 152 4 : const auto a = exp(-absorp_op_->apply(mu)) * ffa_; 153 : 154 4 : const auto alpha_d = alpha_ * d; 155 : 156 4 : const auto H_1_1 = 157 4 : N_ * (2.f * sq(a) + 2.f * sq(a_tilde_) + (sq(a * alpha_d)) + sq(b_tilde_)) / 4.f / a; 158 4 : const auto H_1_2 = N_ * 0.5f * a * sq(alpha_d); 159 4 : const auto H_2_2 = [&]() { 160 4 : if (approximate_) { 161 2 : return (N_ * 0.5f) * (alpha_d * (2.f * a * alpha_d - b_tilde_)); 162 2 : } else { 163 2 : auto tmp = N_ * a * sq(alpha_d); 164 2 : auto z = (N_ * 0.5f) * b_tilde_ * alpha_d; 165 2 : auto quot_z = bessel_1_0(z); 166 2 : tmp -= sq(z) * (1.f - sq(quot_z)); 167 2 : return tmp; 168 2 : } 169 4 : }(); 170 : 171 4 : auto hessian_absorp = [&]() { 172 4 : auto hessian_absorp_0 = adjoint(*absorp_op_) * H_1_1 * *absorp_op_; 173 4 : auto hessian_absorp_1 = adjoint(*absorp_op_) * H_1_2 * *axdt_op_; 174 : 175 4 : typename BlockLinearOperator<data_t>::OperatorList ops; 176 4 : ops.emplace_back(hessian_absorp_0.clone()); 177 4 : ops.emplace_back(hessian_absorp_1.clone()); 178 4 : return BlockLinearOperator<data_t>(ops, BlockType::COL).clone(); 179 4 : }(); 180 : 181 4 : auto hessian_axdt = [&]() { 182 4 : auto hessian_axdt_0 = adjoint(*axdt_op_) * H_1_2 * *absorp_op_; 183 4 : auto hessian_axdt_1 = adjoint(*axdt_op_) * H_2_2 * *axdt_op_; 184 : 185 4 : typename BlockLinearOperator<data_t>::OperatorList ops; 186 4 : ops.emplace_back(hessian_axdt_0.clone()); 187 4 : ops.emplace_back(hessian_axdt_1.clone()); 188 4 : return BlockLinearOperator<data_t>(ops, BlockType::COL).clone(); 189 4 : }(); 190 : 191 4 : typename BlockLinearOperator<data_t>::OperatorList ops; 192 4 : ops.emplace_back(std::move(hessian_absorp)); 193 4 : ops.emplace_back(std::move(hessian_axdt)); 194 4 : return leaf(BlockLinearOperator<data_t>(ops, BlockType::ROW)); 195 4 : } 196 : 197 : template <typename data_t> 198 : RicianLoss<data_t>* RicianLoss<data_t>::cloneImpl() const 199 4 : { 200 4 : return new RicianLoss(ffa_, ffb_, a_tilde_, b_tilde_, *absorp_op_, *axdt_op_, 201 4 : as<index_t>(N_), approximate_); 202 4 : } 203 : 204 : template <typename data_t> 205 : bool RicianLoss<data_t>::isEqual(const Functional<data_t>& other) const 206 4 : { 207 4 : if (!Functional<data_t>::isEqual(other)) 208 0 : return false; 209 : 210 4 : auto otherFn = downcast_safe<RicianLoss>(&other); 211 4 : if (!otherFn) 212 0 : return false; 213 : 214 4 : if (otherFn->approximate_ != approximate_) 215 0 : return false; 216 : 217 4 : if (otherFn->ffa_ != ffa_ || otherFn->ffb_ != ffb_ || otherFn->a_tilde_ != a_tilde_ 218 4 : || otherFn->b_tilde_ != b_tilde_ || *(otherFn->absorp_op_) != *(absorp_op_) 219 4 : || *(otherFn->axdt_op_) != *(axdt_op_) || otherFn->N_ != N_) 220 0 : return false; 221 4 : else 222 4 : return true; 223 4 : } 224 : 225 : // ------------------------------------------ 226 : // explicit template instantiation 227 : template class RicianLoss<float>; 228 : template class RicianLoss<double>; 229 : } // namespace elsa