Line data Source code
1 : #include "Trainable.h" 2 : 3 : namespace elsa::ml 4 : { 5 : template <typename data_t> 6 3 : Trainable<data_t>::Trainable(LayerType layerType, Activation activation, bool useBias, 7 : Initializer kernelInitializer, Initializer biasInitializer, 8 : const std::string& name, int requiredNumberOfDimensions) 9 : : Layer<data_t>(layerType, name, requiredNumberOfDimensions, 10 : /* allowed number of inputs */ 1, 11 : /* is trainable */ true), 12 : useBias_(useBias), 13 : activation_(activation), 14 : kernelInitializer_(kernelInitializer), 15 3 : biasInitializer_(biasInitializer) 16 : { 17 3 : } 18 : 19 : template <typename data_t> 20 2 : Activation Trainable<data_t>::getActivation() const 21 : { 22 2 : return activation_; 23 : } 24 : 25 : template <typename data_t> 26 2 : bool Trainable<data_t>::useBias() const 27 : { 28 2 : return useBias_; 29 : } 30 : 31 : template <typename data_t> 32 1 : Initializer Trainable<data_t>::getKernelInitializer() const 33 : { 34 1 : return kernelInitializer_; 35 : } 36 : 37 : template <typename data_t> 38 1 : Initializer Trainable<data_t>::getBiasInitializer() const 39 : { 40 1 : return biasInitializer_; 41 : } 42 : 43 : template class Trainable<float>; 44 : } // namespace elsa::ml