Line data Source code
1 : #pragma once 2 : 3 : #include <tuple> 4 : #include <variant> 5 : #include <optional> 6 : 7 : #include "elsaDefines.h" 8 : #include "TypeCasts.hpp" 9 : #include "DataDescriptor.h" 10 : #include "ExpressionPredicates.h" 11 : #include "DataHandlerCPU.h" 12 : #include "DataHandlerMapCPU.h" 13 : 14 : #ifdef ELSA_CUDA_VECTOR 15 : #include "DataHandlerGPU.h" 16 : #endif 17 : 18 : namespace elsa 19 : { 20 : /// Compile time switch to select to recursively evaluate for expression type 21 : template <bool GPU, class Operand, std::enable_if_t<isExpression<Operand>, int> = 0> 22 : constexpr auto evaluateOrReturn(Operand const& operand) 23 48606 : { 24 48606 : return operand.template eval<GPU>(); 25 48606 : } 26 : 27 : /// Compile time switch to return-by-value of arithmetic types 28 : template <bool GPU, class Operand, std::enable_if_t<isArithmetic<Operand>, int> = 0> 29 : constexpr auto evaluateOrReturn(Operand const operand) 30 41799 : { 31 41799 : return operand; 32 41799 : } 33 : 34 : /// Compile time switch to return data in container 35 : template <bool GPU, class Operand, std::enable_if_t<isDataContainer<Operand>, int> = 0> 36 : constexpr auto evaluateOrReturn(Operand const& operand) 37 143335 : { 38 143335 : using data_t = typename Operand::value_type; 39 : 40 143335 : if constexpr (GPU) { 41 : #ifdef ELSA_CUDA_VECTOR 42 : if (auto handler = downcast_safe<DataHandlerGPU<data_t>>(operand._dataHandler.get())) { 43 : return handler->accessData(); 44 : } else if (auto handler = 45 : downcast_safe<DataHandlerMapGPU<data_t>>(operand._dataHandler.get())) { 46 : return handler->accessData(); 47 : } else { 48 : throw InternalError("Unknown handler type"); 49 : } 50 : #endif 51 143335 : } else { 52 143335 : if (auto handler = downcast_safe<DataHandlerCPU<data_t>>(operand._dataHandler.get())) { 53 143228 : return handler->accessData(); 54 143228 : } else if (auto handler = 55 107 : downcast_safe<DataHandlerMapCPU<data_t>>(operand._dataHandler.get())) { 56 107 : return handler->accessData(); 57 107 : } else { 58 0 : throw InternalError("Unknown handler type"); 59 0 : } 60 143335 : } 61 143335 : } 62 : 63 : /// Type trait which decides if const lvalue reference or not 64 : template <typename Operand> 65 : class ReferenceOrNot 66 : { 67 : public: 68 : using type = typename std::conditional<isExpression<Operand> || isArithmetic<Operand>, 69 : Operand, const Operand&>::type; 70 : }; 71 : 72 : /** 73 : * @brief Temporary expression type which enables lazy-evaluation of expression 74 : * 75 : * @author Jens Petit 76 : * 77 : * @tparam Callable - the operation to be performed 78 : * @tparam Operands - the objects on which the operation is performed 79 : * 80 : */ 81 : template <typename Callable, typename... Operands> 82 : class Expression 83 : { 84 : public: 85 : /// type which bundles the meta information to create a new DataContainer 86 : using MetaInfo_t = std::pair<DataDescriptor const&, DataHandlerType>; 87 : 88 : /// indicates data type is used in the expression 89 : using data_t = typename GetOperandsDataType<Operands...>::data_t; 90 : 91 : /// Constructor 92 : Expression(Callable func, Operands const&... args) 93 : : _callable(func), _args(args...), _dataMetaInfo(initDescriptor(args...)) 94 116906 : { 95 116906 : } 96 : 97 : /// Evaluates the expression 98 : template <bool GPU = false> 99 : auto eval() const 100 116900 : { 101 : // generic lambda for evaluating tree, we need this to get a pack again out of the tuple 102 116900 : auto const callCallable = [this](Operands const&... args) { 103 : // here evaluateOrReturn is called on each Operand within args as the unpacking 104 : // takes place after the fcn call 105 : // selects the right callable from the Callables struct with multiple lambdas 106 116900 : if constexpr (GPU) { 107 116900 : return _callable(evaluateOrReturn<GPU>(args)..., GPU); 108 116900 : } else { 109 116900 : return _callable(evaluateOrReturn<GPU>(args)...); 110 116900 : } 111 116900 : }; 112 116900 : return std::apply(callCallable, _args); 113 116900 : } 114 : 115 41704 : MetaInfo_t getDataMetaInfo() const { return _dataMetaInfo; } 116 : 117 : private: 118 : /// The function to call on the operand(s) 119 : const Callable _callable; 120 : 121 : /// Contains all operands saved as const references (DataContainers) or copies 122 : /// (Expressions and arithmetic types) 123 : std::tuple<typename ReferenceOrNot<Operands>::type...> _args; 124 : 125 : /// saves the meta information to create a new DataContainer out of an expression 126 : const MetaInfo_t _dataMetaInfo; 127 : 128 : /// correctly returns the DataContainer descriptor based on the operands (either 129 : /// expressions or Datacontainers) 130 : MetaInfo_t initDescriptor(Operands const&... args) 131 116906 : { 132 116906 : if (auto info = getMetaInfoFromContainers(args...); info.has_value()) { 133 97146 : return *info; 134 97146 : } else { 135 19760 : if (auto info = getMetaInfoFromExpressions(args...); info.has_value()) { 136 19760 : return *info; 137 19760 : } else { 138 0 : throw LogicError("No meta info available, cannot create expression"); 139 0 : } 140 19760 : } 141 116906 : } 142 : 143 : /// base recursive case if no DataContainer as operand 144 19760 : std::optional<MetaInfo_t> getMetaInfoFromContainers() { return {}; } 145 : 146 : /// recursive traversal of all contained DataContainers 147 : template <class T, class... Ts> 148 : std::optional<MetaInfo_t> getMetaInfoFromContainers(T& arg, Ts&... args) 149 148924 : { 150 148924 : if constexpr (isDataContainer<T>) { 151 51778 : return MetaInfo_t{arg.getDataDescriptor(), arg.getDataHandlerType()}; 152 51778 : } else { 153 51778 : return getMetaInfoFromContainers(args...); 154 51778 : } 155 148924 : } 156 : 157 : /// base recursive case if no Expression as operand 158 0 : std::optional<MetaInfo_t> getMetaInfoFromExpressions() { return {}; } 159 : 160 : /// recursive traversal of all contained Expressions 161 : template <class T, class... Ts> 162 : std::optional<MetaInfo_t> getMetaInfoFromExpressions(T& arg, Ts&... args) 163 35222 : { 164 35222 : if constexpr (isExpression<T>) { 165 15462 : return arg.getDataMetaInfo(); 166 15462 : } else { 167 15462 : return getMetaInfoFromExpressions(args...); 168 15462 : } 169 35222 : } 170 : }; 171 : } // namespace elsa