Line data Source code
1 : #include "RicianLoss.h" 2 : #include "DataContainer.h" 3 : #include "TypeCasts.hpp" 4 : #include "IdenticalBlocksDescriptor.h" 5 : #include "Identity.h" 6 : #include "LinearOperator.h" 7 : #include "Scaling.h" 8 : #include "BlockLinearOperator.h" 9 : #include "ZeroOperator.h" 10 : #include "Timer.h" 11 : #include "Assertions.h" 12 : 13 : namespace elsa 14 : { 15 : template <typename data_t> 16 : std::unique_ptr<DataDescriptor> RicianLoss<data_t>::generate_placeholder_descriptor() 17 0 : { 18 0 : index_t numOfDims{1}; 19 0 : IndexVector_t dims(numOfDims); 20 0 : dims << 1; 21 0 : return std::make_unique<VolumeDescriptor>(dims); 22 0 : } 23 : 24 : template <typename data_t> 25 : RandomBlocksDescriptor RicianLoss<data_t>::generate_descriptors(const DataDescriptor& desc1, 26 : const DataDescriptor& desc2) 27 24 : { 28 24 : std::vector<std::unique_ptr<DataDescriptor>> descs; 29 : 30 24 : descs.emplace_back(desc1.clone()); 31 24 : descs.emplace_back(desc2.clone()); 32 24 : return RandomBlocksDescriptor(descs); 33 24 : } 34 : 35 : template <typename data_t> 36 : RicianLoss<data_t>::RicianLoss(const DataContainer<data_t>& ffa, 37 : const DataContainer<data_t>& ffb, const DataContainer<data_t>& a, 38 : const DataContainer<data_t>& b, 39 : const LinearOperator<data_t>& absorp_op, 40 : const LinearOperator<data_t>& axdt_op, index_t N, 41 : bool approximate) 42 : : Functional<data_t>( 43 : generate_descriptors(absorp_op.getDomainDescriptor(), axdt_op.getDomainDescriptor())), 44 : ffa_(ffa), 45 : ffb_(ffb), 46 : a_tilde_(a), 47 : b_tilde_(b), 48 : absorp_op_(absorp_op.clone()), 49 : axdt_op_(axdt_op.clone()), 50 : alpha_(ffb / ffa), 51 : d_tilde_(b_tilde_ / a_tilde_ / alpha_), 52 : N_(static_cast<data_t>(N)), 53 : approximate_(approximate) 54 24 : { 55 24 : } 56 : 57 : template <typename data_t> 58 : bool RicianLoss<data_t>::isDifferentiable() const 59 0 : { 60 0 : return true; 61 0 : } 62 : 63 : template <typename data_t> 64 : data_t RicianLoss<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 65 4 : { 66 4 : Timer timeguard("AXDTStatRecon", "evaluate"); 67 : 68 4 : const auto mu = materialize(Rx.getBlock(0)); 69 4 : const auto eta = materialize(Rx.getBlock(1)); 70 : 71 4 : auto log_d = -axdt_op_->apply(eta); 72 4 : auto d = exp(log_d); 73 : 74 4 : auto a = exp(-absorp_op_->apply(mu)) * ffa_; 75 4 : auto log_a = log(a); 76 : 77 4 : auto alpha_d = alpha_ * d; 78 : 79 4 : if (approximate_) { 80 2 : auto reorder = sq(a_tilde_ - a); 81 2 : reorder *= 2.0f; 82 2 : reorder += sq(b_tilde_ - (a * alpha_d)); 83 2 : reorder /= a; 84 : 85 2 : return log_a.sum() + (N_ / 4.0f) * reorder.sum(); 86 2 : } else { 87 : 88 2 : auto reorder = sq(a_tilde_); 89 2 : reorder *= 2.0f; 90 2 : reorder += sq(b_tilde_); 91 2 : reorder /= a; 92 : 93 2 : return 1.5f * log_a.sum() + (N_ * 0.5f) * a.sum() 94 2 : + (N_ / 4.0f) * (reorder.sum() + a.dot(sq(alpha_d))) 95 2 : - axdt::log_bessel_0((N_ / 2) * b_tilde_ * alpha_d).sum(); 96 2 : } 97 4 : } 98 : 99 : template <typename data_t> 100 : void RicianLoss<data_t>::getGradientImpl(const DataContainer<data_t>& Rx, 101 : DataContainer<data_t>& out) const 102 4 : { 103 4 : Timer timeguard("AXDTStatRecon", "getGradient"); 104 : 105 4 : const auto mu = materialize(Rx.getBlock(0)); 106 4 : const auto eta = materialize(Rx.getBlock(1)); 107 : 108 4 : auto log_d = -axdt_op_->apply(eta); 109 4 : auto d = exp(log_d); 110 : 111 4 : auto grad_eta = empty<data_t>(out.getBlock(1).getDataDescriptor()); 112 : 113 4 : auto a = exp(-absorp_op_->apply(mu)) * ffa_; 114 4 : auto log_a = log(a); 115 : 116 : // Some temporaries, potential improvement? First measure! 117 4 : auto sqa = sq(a); 118 4 : auto sqa_tilde = sq(a_tilde_); 119 : 120 4 : auto alpha_d = alpha_ * d; 121 : 122 4 : auto tmp_a_alpha_d = a * alpha_d; // Might reuse this later in approximation codepath 123 : 124 4 : auto tmp_mu = 2.0f * (sqa - sqa_tilde); 125 4 : tmp_mu += sq(tmp_a_alpha_d); 126 4 : tmp_mu -= sq(b_tilde_); 127 4 : tmp_mu /= a; 128 4 : tmp_mu *= -N_ / 4.0f; 129 4 : tmp_mu -= approximate_ ? 1.0f : 1.5f; // only part that differs 130 4 : auto grad_mu = absorp_op_->applyAdjoint(tmp_mu); 131 : 132 4 : if (approximate_) { 133 2 : tmp_a_alpha_d -= b_tilde_; 134 2 : tmp_a_alpha_d *= alpha_d; 135 2 : tmp_a_alpha_d *= -N_ * 0.5f; 136 : 137 2 : grad_eta = axdt_op_->applyAdjoint(tmp_a_alpha_d); 138 2 : } else { 139 2 : auto tmp = a * sq(alpha_d); 140 2 : tmp *= -N_ * 0.5f; 141 2 : auto grad_eta_tmp_bessel = b_tilde_ * alpha_d; 142 2 : grad_eta_tmp_bessel *= N_ * 0.5f; 143 2 : tmp += grad_eta_tmp_bessel * axdt::quot_bessel_1_0(grad_eta_tmp_bessel); 144 : 145 2 : grad_eta = axdt_op_->applyAdjoint(tmp); 146 2 : } 147 : 148 4 : out.getBlock(0) = grad_mu; 149 4 : out.getBlock(1) = grad_eta; 150 4 : } 151 : 152 : template <typename data_t> 153 : LinearOperator<data_t> RicianLoss<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const 154 4 : { 155 4 : Timer timeguard("AXDTStatRecon", "getHessian"); 156 : 157 4 : const auto mu = materialize(Rx.getBlock(0)); 158 4 : const auto eta = materialize(Rx.getBlock(1)); 159 : 160 4 : using BlockType = typename BlockLinearOperator<data_t>::BlockType; 161 : 162 4 : auto d = exp(-axdt_op_->apply(eta)); 163 4 : auto a = exp(-absorp_op_->apply(mu)) * ffa_; 164 : 165 4 : auto alpha_d = alpha_ * d; 166 : 167 4 : auto H_1_1 = 168 4 : N_ * (2.f * sq(a) + 2.f * sq(a_tilde_) + (sq(a * alpha_d)) + sq(b_tilde_)) / 4.f / a; 169 4 : auto H_1_2 = N_ * 0.5f * a * sq(alpha_d); 170 4 : auto H_2_2 = emptylike(a); 171 4 : if (approximate_) { 172 2 : H_2_2 = (N_ * 0.5f) * (alpha_d * (2.f * a * alpha_d - b_tilde_)); 173 2 : } else { 174 2 : H_2_2 = N_ * a * sq(alpha_d); 175 2 : auto z = (N_ * 0.5f) * b_tilde_ * alpha_d; 176 2 : auto quot_z = axdt::quot_bessel_1_0(z); 177 2 : H_2_2 -= sq(z) * (1.f - sq(quot_z)); 178 2 : } 179 : 180 4 : auto hessian_absorp = [&]() { 181 4 : auto hessian_absorp_0 = adjoint(*absorp_op_) * H_1_1 * *absorp_op_; 182 4 : auto hessian_absorp_1 = adjoint(*absorp_op_) * H_1_2 * *axdt_op_; 183 : 184 4 : typename BlockLinearOperator<data_t>::OperatorList ops; 185 4 : ops.emplace_back(hessian_absorp_0.clone()); 186 4 : ops.emplace_back(hessian_absorp_1.clone()); 187 4 : return BlockLinearOperator<data_t>(ops, BlockType::COL).clone(); 188 4 : }(); 189 : 190 4 : auto hessian_axdt = [&]() { 191 4 : auto hessian_axdt_0 = adjoint(*axdt_op_) * H_1_2 * *absorp_op_; 192 4 : auto hessian_axdt_1 = adjoint(*axdt_op_) * H_2_2 * *axdt_op_; 193 : 194 4 : typename BlockLinearOperator<data_t>::OperatorList ops; 195 4 : ops.emplace_back(hessian_axdt_0.clone()); 196 4 : ops.emplace_back(hessian_axdt_1.clone()); 197 4 : return BlockLinearOperator<data_t>(ops, BlockType::COL).clone(); 198 4 : }(); 199 : 200 4 : typename BlockLinearOperator<data_t>::OperatorList ops; 201 4 : ops.emplace_back(std::move(hessian_absorp)); 202 4 : ops.emplace_back(std::move(hessian_axdt)); 203 4 : return leaf(BlockLinearOperator<data_t>(ops, BlockType::ROW)); 204 4 : } 205 : 206 : template <typename data_t> 207 : RicianLoss<data_t>* RicianLoss<data_t>::cloneImpl() const 208 4 : { 209 4 : return new RicianLoss(ffa_, ffb_, a_tilde_, b_tilde_, *absorp_op_, *axdt_op_, 210 4 : static_cast<index_t>(N_), approximate_); 211 4 : } 212 : 213 : template <typename data_t> 214 : bool RicianLoss<data_t>::isEqual(const Functional<data_t>& other) const 215 4 : { 216 4 : if (!Functional<data_t>::isEqual(other)) 217 0 : return false; 218 : 219 4 : auto otherFn = downcast_safe<RicianLoss>(&other); 220 4 : if (!otherFn) 221 0 : return false; 222 : 223 4 : if (otherFn->approximate_ != approximate_) 224 0 : return false; 225 : 226 4 : if (otherFn->ffa_ != ffa_ || otherFn->ffb_ != ffb_ || otherFn->a_tilde_ != a_tilde_ 227 4 : || otherFn->b_tilde_ != b_tilde_ || *(otherFn->absorp_op_) != *(absorp_op_) 228 4 : || *(otherFn->axdt_op_) != *(axdt_op_) || otherFn->N_ != N_) 229 0 : return false; 230 4 : else 231 4 : return true; 232 4 : } 233 : 234 : // ------------------------------------------ 235 : // explicit template instantiation 236 : template class RicianLoss<float>; 237 : template class RicianLoss<double>; 238 : } // namespace elsa