Line data Source code
1 : #include "Initializer.h"
2 :
3 : namespace elsa
4 : {
5 : namespace ml
6 : {
7 : namespace detail
8 : {
9 : template <typename data_t>
10 : std::random_device InitializerImpl<data_t>::randomDevice_{};
11 :
12 : template <typename data_t>
13 : bool InitializerImpl<data_t>::useSeed_ = false;
14 :
15 : template <typename data_t>
16 : uint64_t InitializerImpl<data_t>::seed_ = 1;
17 :
18 : template <typename data_t>
19 0 : void InitializerImpl<data_t>::setSeed(uint64_t seed)
20 : {
21 0 : seed_ = seed;
22 0 : useSeed_ = true;
23 0 : }
24 :
25 : template <typename data_t>
26 0 : void InitializerImpl<data_t>::clearSeed()
27 : {
28 0 : useSeed_ = false;
29 0 : }
30 :
31 : template <typename data_t>
32 2 : void InitializerImpl<data_t>::initialize(
33 : data_t* data, index_t size, Initializer initializer,
34 : [[maybe_unused]] const InitializerImpl<data_t>::FanPairType& fanInOut)
35 : {
36 2 : switch (initializer) {
37 1 : case Initializer::Ones:
38 1 : InitializerImpl::ones(data, size);
39 1 : return;
40 1 : case Initializer::Zeros:
41 1 : InitializerImpl::zeros(data, size);
42 1 : return;
43 0 : case Initializer::Uniform:
44 0 : InitializerImpl::uniform(data, size, -1, 1);
45 0 : return;
46 0 : case Initializer::GlorotUniform:
47 0 : InitializerImpl::glorotUniform(data, size, fanInOut);
48 0 : return;
49 0 : case Initializer::HeUniform:
50 0 : InitializerImpl::heUniform(data, size, fanInOut);
51 0 : return;
52 0 : case Initializer::TruncatedNormal:
53 0 : InitializerImpl::truncatedNormal(data, size, 0, 1);
54 0 : return;
55 0 : case Initializer::Normal:
56 0 : InitializerImpl::normal(data, size, 0, 1);
57 0 : return;
58 0 : case Initializer::GlorotNormal:
59 0 : InitializerImpl::glorotNormal(data, size, fanInOut);
60 0 : return;
61 0 : case Initializer::RamLak:
62 0 : InitializerImpl::ramlak(data, size);
63 0 : return;
64 0 : default:
65 0 : throw std::invalid_argument("Unkown random initializer");
66 : }
67 : }
68 :
69 : template <typename data_t>
70 2 : void InitializerImpl<data_t>::initialize(data_t* data, index_t size,
71 : Initializer initializer)
72 : {
73 2 : FanPairType fan{0, 0};
74 2 : initialize(data, size, initializer, fan);
75 2 : }
76 :
77 : template <typename data_t>
78 0 : std::mt19937_64 InitializerImpl<data_t>::getEngine()
79 : {
80 0 : if (useSeed_)
81 0 : return std::mt19937_64(seed_);
82 : else
83 0 : return std::mt19937_64(randomDevice_());
84 : }
85 :
86 : template <typename data_t>
87 0 : void InitializerImpl<data_t>::uniform(data_t* data, index_t size, data_t lowerBound,
88 : data_t upperBound)
89 : {
90 0 : UniformDistributionType dist(lowerBound, upperBound);
91 0 : std::mt19937_64 engine = getEngine();
92 :
93 0 : for (index_t i = 0; i < size; ++i)
94 0 : data[i] = dist(engine);
95 0 : }
96 :
97 : template <typename data_t>
98 0 : void InitializerImpl<data_t>::uniform(data_t* data, index_t size)
99 : {
100 0 : InitializerImpl<data_t>::uniform(data, size, 0, std::numeric_limits<data_t>::max());
101 0 : }
102 :
103 : template <typename data_t>
104 2 : void InitializerImpl<data_t>::constant(data_t* data, index_t size, data_t constant)
105 : {
106 2002 : for (index_t i = 0; i < size; ++i)
107 2000 : data[i] = constant;
108 2 : }
109 :
110 : template <typename data_t>
111 1 : void InitializerImpl<data_t>::ones(data_t* data, index_t size)
112 : {
113 1 : constant(data, size, static_cast<data_t>(1));
114 1 : }
115 :
116 : template <typename data_t>
117 1 : void InitializerImpl<data_t>::zeros(data_t* data, index_t size)
118 : {
119 1 : constant(data, size, static_cast<data_t>(0));
120 1 : }
121 :
122 : template <typename data_t>
123 0 : void InitializerImpl<data_t>::glorotUniform(
124 : data_t* data, index_t size, const InitializerImpl<data_t>::FanPairType& fan)
125 : {
126 0 : auto bound = static_cast<data_t>(std::sqrt(
127 0 : 6 / (static_cast<data_t>(fan.first) + static_cast<data_t>(fan.second))));
128 0 : uniform(data, size, -1 * bound, bound);
129 0 : }
130 :
131 : template <typename data_t>
132 0 : void InitializerImpl<data_t>::glorotNormal(
133 : data_t* data, index_t size, const InitializerImpl<data_t>::FanPairType& fan)
134 : {
135 0 : auto stddev = static_cast<data_t>(std::sqrt(
136 0 : 2 / (static_cast<data_t>(fan.first) + static_cast<data_t>(fan.second))));
137 0 : truncatedNormal(data, size, 0, stddev);
138 0 : }
139 :
140 : template <typename data_t>
141 0 : void InitializerImpl<data_t>::normal(data_t* data, index_t size, data_t mean,
142 : data_t stddev)
143 : {
144 : static_assert(!std::is_same<std::false_type, NormalDistributionType>::value,
145 : "Cannot use normal distribution with the given data-type");
146 :
147 0 : NormalDistributionType dist(mean, stddev);
148 0 : std::mt19937_64 engine = getEngine();
149 :
150 0 : for (index_t i = 0; i < size; ++i)
151 0 : data[i] = dist(engine);
152 0 : }
153 :
154 : template <typename data_t>
155 0 : void InitializerImpl<data_t>::truncatedNormal(data_t* data, index_t size, data_t mean,
156 : data_t stddev)
157 : {
158 : static_assert(!std::is_same<std::false_type, NormalDistributionType>::value,
159 : "Cannot use normal distribution with the given data-type");
160 :
161 0 : NormalDistributionType dist(mean, stddev);
162 0 : std::mt19937_64 engine = getEngine();
163 :
164 0 : for (index_t i = 0; i < size; ++i) {
165 0 : auto value = dist(engine);
166 0 : while (std::abs(mean - value) > 2 * stddev) {
167 0 : value = dist(engine);
168 : }
169 0 : data[i] = value;
170 : }
171 0 : }
172 :
173 : template <typename data_t>
174 0 : void InitializerImpl<data_t>::heNormal(
175 : data_t* data, index_t size, const InitializerImpl<data_t>::FanPairType& fanInOut)
176 : {
177 : auto stddev =
178 0 : std::sqrt(static_cast<data_t>(2) / static_cast<data_t>(fanInOut.first));
179 0 : truncatedNormal(data, size, 0, stddev);
180 0 : }
181 :
182 : template <typename data_t>
183 0 : void InitializerImpl<data_t>::heUniform(data_t* data, index_t size,
184 : const InitializerImpl<data_t>::FanPairType& fan)
185 : {
186 0 : auto bound = static_cast<data_t>(std::sqrt(6 / (static_cast<data_t>(fan.first))));
187 0 : uniform(data, size, -1 * bound, bound);
188 0 : }
189 :
190 : template <typename data_t>
191 0 : void InitializerImpl<data_t>::ramlak(data_t* data, index_t size)
192 : {
193 0 : const index_t hw = as<index_t>((as<data_t>(size) - 1) / 2);
194 :
195 0 : for (index_t i = -hw; i <= hw; ++i) {
196 0 : if ((i % 2) != 0)
197 0 : data[i + hw] =
198 0 : data_t(-1) / (as<data_t>(i) * as<data_t>(i) * pi<data_t> * pi<data_t>);
199 : else
200 0 : data[i + hw] = data_t(0);
201 : }
202 0 : data[hw] = data_t(0.25);
203 0 : }
204 :
205 : template class InitializerImpl<float>;
206 : } // namespace detail
207 : } // namespace ml
208 : } // namespace elsa
|