Line data Source code
#include "Model.h"

#include <algorithm>
#include <deque>
#include <stdexcept>

#include "TypeCasts.hpp"
5 :
6 : namespace elsa::ml
7 : {
8 : template <typename data_t, MlBackend Backend>
9 0 : Model<data_t, Backend>::Model(std::initializer_list<Input<data_t>*> inputs,
10 : std::initializer_list<Layer<data_t>*> outputs,
11 : const std::string& name)
12 0 : : name_(name), inputs_(inputs), outputs_(outputs)
13 : {
14 : // Save the batch-size this model uses
15 0 : batchSize_ = inputs_.front()->getBatchSize();
16 :
17 : // Set all input-descriptors by traversing the graph
18 0 : setInputDescriptors();
19 0 : }
20 :
21 : template <typename data_t, MlBackend Backend>
22 0 : Model<data_t, Backend>::Model(Input<data_t>* input, Layer<data_t>* output,
23 : const std::string& name)
24 0 : : Model({input}, {output}, name)
25 : {
26 0 : }
27 :
28 : template <typename data_t, MlBackend Backend>
29 0 : index_t Model<data_t, Backend>::getBatchSize() const
30 : {
31 0 : return batchSize_;
32 : }
33 :
34 : template <typename data_t, MlBackend Backend>
35 0 : void Model<data_t, Backend>::setInputDescriptors()
36 : {
37 : // TODO(tellenbach): Replace by Graph::visit method
38 :
39 : // Get the graph
40 0 : auto& graph = detail::State<data_t>::getGraph();
41 :
42 : // Get all nodes of the graph, i.e., a map with node-indices as keys
43 : // and nodes as values
44 0 : auto& nodes = graph.getNodes();
45 :
46 : // We maintain a list of nodes we've already visited and a call-queue to
47 : // ensure the correct order of traversal.
48 0 : std::vector<bool> visited(asUnsigned(graph.getNumberOfNodes()));
49 : // Note that this queue is in fact a deque, so we can push and pop from
50 : // both, the front and back.
51 0 : std::deque<index_t> queue;
52 :
53 : // Perform an iterative depth-first traversal through the graph
54 0 : for (auto in : inputs_) {
55 : // Push the input-node onto the call-queue
56 0 : queue.push_back(in->getGlobalIndex());
57 :
58 0 : while (!queue.empty()) {
59 : // The current node is the top of the stack and we compute its
60 : // output descriptor
61 0 : index_t s = queue.back();
62 0 : queue.pop_back();
63 :
64 0 : if (!visited[static_cast<std::size_t>(s)]) {
65 : // If the current node is a merging layer, its
66 : // output-descriptor can depend on *all* input-descriptors.
67 : // We therefore have to make sure that we really set all
68 : // input-descriptors before attempting to compute a merging
69 : // layer's output-descriptor or attempting to continue the
70 : // traversal.
71 : //
72 : // We do this by checking if the number of edges that reach
73 : // a merging layer is equal to the number of set inputs.
74 : //
75 : // +---------+
76 : // | Layer 1 | +---------+
77 : // +---------+ | Layer 2 |
78 : // | +---------+
79 : // v |
80 : // +---------+ |
81 : // | Merging |<-------+
82 : // +---------+
83 : // (?) Do we have the input from Layer1 *and* Layer2?
84 : //
85 : // We also have to make sure that a merging layer get's
86 : // visited again when all of its inputs are set. Pushing the
87 : // layer on top of the queue again causes an infinite loop
88 : // since we will always visit it again, see that we can't
89 : // compute its output-descriptor yet, visit it again...
90 : //
91 : // To solve this problem we push a merging layer to the
92 : // *front* of the queue such that it get's visited again,
93 : // in a delayed fashion.
94 0 : if (!nodes.at(s).getData()->canMerge()
95 0 : || nodes.at(s).getData()->getNumberOfInputs()
96 0 : == static_cast<index_t>(graph.getIncomingEdges(s).size())) {
97 : // We end up here if we either have no merging layer
98 : // or if we have a merging layer but already gathered
99 : // all of its inputs
100 0 : nodes.at(s).getData()->computeOutputDescriptor();
101 :
102 0 : visited[asUnsigned(s)] = true;
103 : } else {
104 : // We end up here if we have a merging layer but haven't
105 : // collected all of its inputs yet. In this case we
106 : // push the layer to the *front* of out call-queue.
107 0 : queue.push_front(s);
108 :
109 : // Make sure we don't handle a merging layer's childs
110 : // before handling the layer itself
111 0 : continue;
112 : }
113 : }
114 :
115 : // TODO(tellenbach): Stop if we reach one of the model's output
116 : // layers
117 :
118 : // Consider all outgoing edges of a node and set their
119 : // input-descriptors to the output-descriptor of their parent
120 : // node
121 0 : for (auto& e : graph.getOutgoingEdges(s)) {
122 0 : auto idx = e.end()->getIndex();
123 :
124 : // If we haven't visited this child node yet, add it to the
125 : // call-queue and set its input-descriptor
126 0 : if (!visited[static_cast<std::size_t>(idx)]) {
127 0 : queue.push_back(idx);
128 0 : e.end()->getData()->setInputDescriptor(
129 0 : nodes.at(s).getData()->getOutputDescriptor());
130 : }
131 : }
132 : }
133 : }
134 0 : }
135 :
136 : template <typename data_t, MlBackend Backend>
137 0 : void Model<data_t, Backend>::compile(const Loss<data_t>& loss, Optimizer<data_t>* optimizer)
138 : {
139 0 : loss_ = loss;
140 0 : optimizer_ = optimizer;
141 0 : detail::BackendAdaptor<data_t, Backend>::constructBackendGraph(this);
142 0 : }
143 :
144 : template <typename data_t, MlBackend Backend>
145 : typename Model<data_t, Backend>::History
146 0 : Model<data_t, Backend>::fit(const std::vector<DataContainer<data_t>>& x,
147 : const std::vector<DataContainer<data_t>>& y, index_t epochs)
148 : {
149 : // Check if all elements of x have the same data-container
150 0 : if (std::adjacent_find(x.begin(), x.end(),
151 0 : [](const auto& dc0, const auto& dc1) {
152 0 : return dc0.getDataDescriptor() != dc1.getDataDescriptor();
153 : })
154 0 : != x.end())
155 0 : throw std::invalid_argument("All elements of x must have the same data-descriptor");
156 :
157 : // Check if all elements of y have the same data-container
158 0 : if (std::adjacent_find(y.begin(), y.end(),
159 0 : [](const auto& dc0, const auto& dc1) {
160 0 : return dc0.getDataDescriptor() != dc1.getDataDescriptor();
161 : })
162 0 : != y.end())
163 0 : throw std::invalid_argument("All elements of y must have the same data-descriptor");
164 :
165 0 : return detail::BackendAdaptor<data_t, Backend>::fit(this, x, y, epochs);
166 : }
167 :
168 : template <typename data_t, MlBackend Backend>
169 0 : DataContainer<data_t> Model<data_t, Backend>::predict(const DataContainer<data_t>& x)
170 : {
171 0 : return detail::BackendAdaptor<data_t, Backend>::predict(this, x);
172 : }
173 :
174 : template <typename data_t, MlBackend Backend>
175 0 : Optimizer<data_t>* Model<data_t, Backend>::getOptimizer()
176 : {
177 0 : return optimizer_;
178 : }
179 :
180 : template <typename data_t, MlBackend Backend>
181 0 : std::string Model<data_t, Backend>::getName() const
182 : {
183 0 : return name_;
184 : }
185 :
186 : template <typename data_t, MlBackend Backend>
187 0 : std::vector<Input<data_t>*> Model<data_t, Backend>::getInputs()
188 : {
189 0 : return inputs_;
190 : }
191 :
192 : template <typename data_t, MlBackend Backend>
193 0 : std::vector<Layer<data_t>*> Model<data_t, Backend>::getOutputs()
194 : {
195 0 : return outputs_;
196 : }
197 :
198 : template <typename data_t, MlBackend Backend>
199 : detail::Graph<typename detail::BackendSelector<data_t, Backend, LayerType::Undefined>::Type,
200 : false>&
201 0 : Model<data_t, Backend>::getBackendGraph()
202 : {
203 0 : return backendGraph_;
204 : }
205 :
206 : template <typename data_t, MlBackend Backend>
207 : const detail::Graph<
208 : typename detail::BackendSelector<data_t, Backend, LayerType::Undefined>::Type, false>&
209 0 : Model<data_t, Backend>::getBackendGraph() const
210 : {
211 0 : return backendGraph_;
212 : }
213 :
214 : template <typename data_t, MlBackend Backend>
215 0 : const Loss<data_t>& Model<data_t, Backend>::getLoss() const
216 : {
217 0 : return loss_;
218 : }
219 :
220 : template class Model<float, MlBackend::Dnnl>;
221 : template class Model<float, MlBackend::Cudnn>;
222 : } // namespace elsa::ml
|