LCOV - code coverage report
Current view: top level - ml - Utils.h (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 21 21 100.0 %
Date: 2022-07-06 02:47:47 Functions: 2 2 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "DataContainer.h"
#include "VolumeDescriptor.h"
#include "State.h"
       9             : 
      10             : namespace elsa::ml
      11             : {
      12             :     /// Common ml utilities
      13             :     ///
      14             :     /// @author David Tellenbach
      15             :     struct Utils {
      16             :         /// Utilities to plot a model
      17             :         struct Plotting {
      18             : 
      19             :             /// Direction for plotting the model graph.
      20             :             enum class RankDir {
      21             :                 /// Plot model from top to bottom
      22             :                 TD,
      23             :                 /// Plot model from left to right
      24             :                 LR
      25             :             };
      26             : 
      27             :             /// Convert a model to Graphviz' DOT format.
      28             :             ///
      29             :             /// @param model The model to plot
      30             :             /// @param filename The filename of the DOT file. This parameter is
      31             :             /// optional and defaults to "model.png"
      32             :             /// @param rankDir The direction of the plotted DOT graph. This
      33             :             /// parameter is optional and defaults to RankDir::TD.
      34             :             /// @param dpi The dots-per-inch encoded in the DOT file.
      35             :             template <typename T>
      36             :             static void modelToDot([[maybe_unused]] const T& model,
      37             :                                    const std::string& filename = "model.png",
      38             :                                    RankDir rankDir = RankDir::TD, int dpi = 96)
      39             :             {
      40             :                 std::ofstream os(filename);
      41             : 
      42             :                 auto& graph = detail::State<real_t>::getGraph();
      43             : 
      44             :                 // Use Eigen's formatting facilities to control the format of in- and output shapes
      45             :                 Eigen::IOFormat fmt(Eigen::StreamPrecision, 0, ", ", ", ", "", "", "(", ")");
      46             : 
      47             :                 graph.toDot(
      48             :                     filename,
      49             :                     [&fmt](auto layer, index_t idx) {
      50             :                         std::stringstream ss;
      51             :                         ss << idx << " [shape=\"record\", label=\"" << layer->getName() << ": "
      52             :                            << detail::getEnumMemberAsString(layer->getLayerType())
      53             :                            << " | {input: \\l| output:\\l} | { [";
      54             :                         for (int i = 0; i < layer->getNumberOfInputs(); ++i) {
      55             :                             ss << layer->getInputDescriptor(i)
      56             :                                       .getNumberOfCoefficientsPerDimension()
      57             :                                       .format(fmt)
      58             :                                << (i == layer->getNumberOfInputs() - 1 ? "]" : ", ");
      59             :                         }
      60             :                         ss << " \\l | "
      61             :                            << layer->getOutputDescriptor()
      62             :                                   .getNumberOfCoefficientsPerDimension()
      63             :                                   .format(fmt)
      64             :                            << "\\l}\"];\n";
      65             :                         return ss.str();
      66             :                     },
      67             :                     dpi,
      68             :                     rankDir == RankDir::TD ? std::decay_t<decltype(graph)>::RankDir::TD
      69             :                                            : std::decay_t<decltype(graph)>::RankDir::LR);
      70             :             }
      71             :         };
      72             : 
      73             :         /// Common encoding routines
      74             :         struct Encoding {
      75             :             /// Encode a DataContainer in one-hot encoding
      76             :             ///
      77             :             /// We expect the input-data to be shaped as ([1,] batchSize). The
      78             :             /// new one-hot encoded descriptor has shape (num_classes, batch_size)
      79             :             template <typename data_t>
      80           5 :             static DataContainer<data_t> toOneHot(const DataContainer<data_t>& dc,
      81             :                                                   index_t numClasses, index_t batchSize)
      82             :             {
      83          10 :                 Eigen::VectorXf oneHotData(batchSize * numClasses);
      84           5 :                 oneHotData.setZero();
      85             : 
      86           5 :                 index_t segmentIdx = 0;
      87          25 :                 for (index_t i = 0; i < batchSize; ++i) {
      88          20 :                     oneHotData.segment(segmentIdx, numClasses)
      89          20 :                         .coeffRef(static_cast<index_t>(dc[i])) = data_t(1);
      90          20 :                     segmentIdx += numClasses;
      91             :                 }
      92             : 
      93           5 :                 IndexVector_t dims{{numClasses, batchSize}};
      94          10 :                 return DataContainer<data_t>(VolumeDescriptor(dims), oneHotData);
      95             :             }
      96             : 
      97             :             /// Decode a DataContainer that is encoded in one-hot encoding.
      98             :             template <typename data_t>
      99           1 :             static DataContainer<data_t> fromOneHot(const DataContainer<data_t> dc,
     100             :                                                     index_t numClasses)
     101             :             {
     102           1 :                 index_t batchSize = dc.getDataDescriptor().getNumberOfCoefficientsPerDimension()(1);
     103           2 :                 Eigen::VectorXf data(batchSize);
     104           1 :                 data.setZero();
     105             : #ifndef ELSA_CUDA_VECTOR
     106           1 :                 auto expr = (data_t(1) * dc).eval();
     107             : #else
     108             :                 Eigen::VectorXf expr(dc.getSize());
     109             :                 for (index_t i = 0; i < dc.getSize(); ++i) {
     110             :                     expr[i] = dc[i];
     111             :                 }
     112             : #endif
     113           5 :                 for (int i = 0; i < batchSize; ++i) {
     114             :                     index_t maxIdx;
     115           4 :                     expr.segment(i * numClasses, numClasses).maxCoeff(&maxIdx);
     116           4 :                     data[i] = static_cast<data_t>(maxIdx);
     117             :                 }
     118           2 :                 IndexVector_t dims{{batchSize}};
     119           2 :                 VolumeDescriptor desc(dims);
     120           2 :                 return DataContainer<data_t>(desc, data);
     121             :             }
     122             :         };
     123             :     };
     124             : } // namespace elsa::ml

Generated by: LCOV version 1.15