Line data Source code
1 : #include "SymmetrizedDerivative.h" 2 : #include "Timer.h" 3 : #include "Identity.h" 4 : #include "FiniteDifferences.h" 5 : #include "BlockLinearOperator.h" 6 : #include "DataContainer.h" 7 : 8 : namespace elsa 9 : { 10 : 11 : using namespace std; 12 : 13 : template <typename data_t> 14 : SymmetrizedDerivative<data_t>::SymmetrizedDerivative(const DataDescriptor& domainDescriptor) 15 : : LinearOperator<data_t>(createBase(domainDescriptor)) 16 18 : { 17 18 : precomputeHelpers(); 18 18 : } 19 : 20 : using namespace std; 21 : 22 : template <typename data_t> 23 : void SymmetrizedDerivative<data_t>::applyImpl(const DataContainer<data_t>& x, 24 : DataContainer<data_t>& Ax) const 25 2 : { 26 2 : Timer timeguard("SymmetrizedDerivative", "apply"); 27 : 28 : // D_1 x_1 29 2 : Ax.getBlock(0) = forwardX_->apply(x.getBlock(0)); 30 : 31 : // D_2 x_2 32 2 : Ax.getBlock(1) = forwardY_->apply(x.getBlock(1)); 33 : 34 : // 0.5 (D_2 x_1 + D_1 x_2) 35 2 : Ax.getBlock(2) = lincomb(data_t(0.5), forwardY_->apply(x.getBlock(0)), data_t(0.5), 36 2 : forwardX_->apply(x.getBlock(1))); 37 2 : } 38 : 39 : template <typename data_t> 40 : void SymmetrizedDerivative<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 41 : DataContainer<data_t>& Aty) const 42 2 : { 43 2 : Timer timeguard("SymmetrizedDerivative", "applyAdjoint"); 44 : 45 : //-D_1^{-} y_1 - D_2^{-} y_3 46 2 : Aty.getBlock(0) = lincomb(data_t(-1), backwardX_->apply(y.getBlock(0)), data_t(-1), 47 2 : backwardY_->apply(y.getBlock(2))); 48 : 49 : //-D_1^{-} y_3 - D_2^{-} y_2 50 2 : Aty.getBlock(1) = lincomb(data_t(-1), backwardX_->apply(y.getBlock(2)), data_t(-1), 51 2 : backwardY_->apply(y.getBlock(1))); 52 2 : } 53 : 54 : template <typename data_t> 55 : void SymmetrizedDerivative<data_t>::precomputeHelpers() 56 10 : { 57 : 58 10 : auto& domainBlocked = downcast_safe<BlockDescriptor>(this->getDomainDescriptor()); 59 : 60 10 : BooleanVector_t bv1(2); 61 10 : bv1 << true, false; 62 10 : BooleanVector_t bv2(2); 63 10 : bv2 << false, true; 64 : 65 10 : forwardX_ = 66 10 : make_unique<FiniteDifferences<data_t>>(domainBlocked.getDescriptorOfBlock(0), bv1); 67 10 : forwardY_ = 68 10 : make_unique<FiniteDifferences<data_t>>(domainBlocked.getDescriptorOfBlock(0), bv2); 69 10 : backwardX_ = 70 10 : make_unique<FiniteDifferences<data_t>>(domainBlocked.getDescriptorOfBlock(0), bv1, 71 10 : FiniteDifferences<data_t>::DiffType::BACKWARD); 72 10 : backwardY_ = 73 10 : make_unique<FiniteDifferences<data_t>>(domainBlocked.getDescriptorOfBlock(0), bv2, 74 10 : FiniteDifferences<data_t>::DiffType::BACKWARD); 75 10 : } 76 : 77 : template <typename data_t> 78 : LinearOperator<data_t> 79 : SymmetrizedDerivative<data_t>::createBase(const DataDescriptor& domainDescriptor) 80 18 : { 81 : 82 18 : if (!is<IdenticalBlocksDescriptor>(domainDescriptor)) 83 4 : throw LogicError("SymmetrizedDerivative: cannot construct. Domain should be " 84 4 : "IdenticalBlocksDescriptor"); 85 : 86 14 : auto& domainBlocked = downcast_safe<BlockDescriptor>(domainDescriptor); 87 14 : if (domainBlocked.getNumberOfBlocks() != 2) 88 2 : throw LogicError( 89 2 : "SymmetrizedDerivative: cannot construct. Domain should have 2 blocks"); 90 : 91 12 : if (domainBlocked.getDescriptorOfBlock(0).getNumberOfDimensions() != 2) 92 2 : throw LogicError("SymmetrizedDerivative: cannot construct. Domain blocks should be " 93 2 : "identical and 2-dimensional"); 94 : 95 10 : return LinearOperator<data_t>( 96 10 : domainBlocked, IdenticalBlocksDescriptor{3, domainBlocked.getDescriptorOfBlock(0)}); 97 10 : } 98 : 99 : template <typename data_t> 100 : SymmetrizedDerivative<data_t>* SymmetrizedDerivative<data_t>::cloneImpl() const 101 2 : { 102 2 : return new SymmetrizedDerivative(this->getDomainDescriptor()); 103 2 : } 104 : 105 : template <typename data_t> 106 : bool SymmetrizedDerivative<data_t>::isEqual(const LinearOperator<data_t>& other) const 107 2 : { 108 2 : if (!LinearOperator<data_t>::isEqual(other)) 109 0 : return false; 110 : 111 2 : auto otherSD = downcast_safe<SymmetrizedDerivative>(&other); 112 2 : if (!otherSD) 113 0 : return false; 114 : 115 2 : if (otherSD->getDomainDescriptor() != this->getDomainDescriptor() 116 2 : || otherSD->getRangeDescriptor() != this->getRangeDescriptor()) 117 0 : return false; 118 : 119 2 : return true; 120 2 : } 121 : 122 : // ------------------------------------------ 123 : // explicit template instantiation 124 : template class SymmetrizedDerivative<float>; 125 : template class SymmetrizedDerivative<double>; 126 : 127 : } // namespace elsa