Line data Source code
1 : #pragma once 2 : 3 : #include <string> 4 : #include <fstream> 5 : 6 : #include "DataContainer.h" 7 : #include "VolumeDescriptor.h" 8 : #include "State.h" 9 : 10 : namespace elsa::ml 11 : { 12 : /// Common ml utilities 13 : /// 14 : /// @author David Tellenbach 15 : struct Utils { 16 : /// Utilities to plot a model 17 : struct Plotting { 18 : 19 : /// Direction for plotting the model graph. 20 : enum class RankDir { 21 : /// Plot model from top to bottom 22 : TD, 23 : /// Plot model from left to right 24 : LR 25 : }; 26 : 27 : /// Convert a model to Graphviz' DOT format. 28 : /// 29 : /// @param model The model to plot 30 : /// @param filename The filename of the DOT file. This parameter is 31 : /// optional and defaults to "model.png" 32 : /// @param rankDir The direction of the plotted DOT graph. This 33 : /// parameter is optional and defaults to RankDir::TD. 34 : /// @param dpi The dots-per-inch encoded in the DOT file. 35 : template <typename T> 36 : static void modelToDot([[maybe_unused]] const T& model, 37 : const std::string& filename = "model.png", 38 : RankDir rankDir = RankDir::TD, int dpi = 96) 39 : { 40 : std::ofstream os(filename); 41 : 42 : auto& graph = detail::State<real_t>::getGraph(); 43 : 44 : // Use Eigen's formatting facilities to control the format of in- and output shapes 45 : Eigen::IOFormat fmt(Eigen::StreamPrecision, 0, ", ", ", ", "", "", "(", ")"); 46 : 47 : graph.toDot( 48 : filename, 49 : [&fmt](auto layer, index_t idx) { 50 : std::stringstream ss; 51 : ss << idx << " [shape=\"record\", label=\"" << layer->getName() << ": " 52 : << detail::getEnumMemberAsString(layer->getLayerType()) 53 : << " | {input: \\l| output:\\l} | { ["; 54 : for (int i = 0; i < layer->getNumberOfInputs(); ++i) { 55 : ss << layer->getInputDescriptor(i) 56 : .getNumberOfCoefficientsPerDimension() 57 : .format(fmt) 58 : << (i == layer->getNumberOfInputs() - 1 ? "]" : ", "); 59 : } 60 : ss << " \\l | " 61 : << layer->getOutputDescriptor() 62 : .getNumberOfCoefficientsPerDimension() 63 : .format(fmt) 64 : << "\\l}\"];\n"; 65 : return ss.str(); 66 : }, 67 : dpi, 68 : rankDir == RankDir::TD ? std::decay_t<decltype(graph)>::RankDir::TD 69 : : std::decay_t<decltype(graph)>::RankDir::LR); 70 : } 71 : }; 72 : 73 : /// Common encoding routines 74 : struct Encoding { 75 : /// Encode a DataContainer in one-hot encoding 76 : /// 77 : /// We expect the input-data to be shaped as ([1,] batchSize). The 78 : /// new one-hot encoded descriptor has shape (num_classes, batch_size) 79 : template <typename data_t> 80 5 : static DataContainer<data_t> toOneHot(const DataContainer<data_t>& dc, 81 : index_t numClasses, index_t batchSize) 82 : { 83 10 : Eigen::VectorXf oneHotData(batchSize * numClasses); 84 5 : oneHotData.setZero(); 85 : 86 5 : index_t segmentIdx = 0; 87 25 : for (index_t i = 0; i < batchSize; ++i) { 88 20 : oneHotData.segment(segmentIdx, numClasses) 89 20 : .coeffRef(static_cast<index_t>(dc[i])) = data_t(1); 90 20 : segmentIdx += numClasses; 91 : } 92 : 93 5 : IndexVector_t dims{{numClasses, batchSize}}; 94 10 : return DataContainer<data_t>(VolumeDescriptor(dims), oneHotData); 95 : } 96 : 97 : /// Decode a DataContainer that is encoded in one-hot encoding. 98 : template <typename data_t> 99 1 : static DataContainer<data_t> fromOneHot(const DataContainer<data_t> dc, 100 : index_t numClasses) 101 : { 102 1 : index_t batchSize = dc.getDataDescriptor().getNumberOfCoefficientsPerDimension()(1); 103 2 : Eigen::VectorXf data(batchSize); 104 1 : data.setZero(); 105 : #ifndef ELSA_CUDA_VECTOR 106 1 : auto expr = (data_t(1) * dc).eval(); 107 : #else 108 : Eigen::VectorXf expr(dc.getSize()); 109 : for (index_t i = 0; i < dc.getSize(); ++i) { 110 : expr[i] = dc[i]; 111 : } 112 : #endif 113 5 : for (int i = 0; i < batchSize; ++i) { 114 : index_t maxIdx; 115 4 : expr.segment(i * numClasses, numClasses).maxCoeff(&maxIdx); 116 4 : data[i] = static_cast<data_t>(maxIdx); 117 : } 118 2 : IndexVector_t dims{{batchSize}}; 119 2 : VolumeDescriptor desc(dims); 120 2 : return DataContainer<data_t>(desc, data); 121 : } 122 : }; 123 : }; 124 : } // namespace elsa::ml