1 | //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | |
10 | #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H |
11 | #define LLVM_ANALYSIS_MLMODELRUNNER_H |
12 | |
13 | #include "llvm/Analysis/TensorSpec.h" |
14 | #include "llvm/IR/PassManager.h" |
15 | |
16 | namespace llvm { |
17 | class LLVMContext; |
18 | |
/// MLModelRunner interface: abstraction of a mechanism for evaluating an
/// ML model. More abstractly, evaluating a function that has tensors as
/// arguments, described via TensorSpecs, and returns a tensor. Currently, the
/// latter is assumed to be a scalar, in the absence of more elaborate
/// scenarios.
/// NOTE: feature indices are expected to be consistent across all
/// MLModelRunners (pertaining to the same model), and also Loggers (see
/// TFUtils.h).
26 | class MLModelRunner { |
27 | public: |
28 | // Disallows copy and assign. |
29 | MLModelRunner(const MLModelRunner &) = delete; |
30 | MLModelRunner &operator=(const MLModelRunner &) = delete; |
31 | virtual ~MLModelRunner() = default; |
32 | |
33 | template <typename T> T evaluate() { |
34 | return *reinterpret_cast<T *>(evaluateUntyped()); |
35 | } |
36 | |
37 | template <typename T, typename I> T *getTensor(I FeatureID) { |
38 | return reinterpret_cast<T *>( |
39 | getTensorUntyped(Index: static_cast<size_t>(FeatureID))); |
40 | } |
41 | |
42 | template <typename T, typename I> const T *getTensor(I FeatureID) const { |
43 | return reinterpret_cast<const T *>( |
44 | getTensorUntyped(Index: static_cast<size_t>(FeatureID))); |
45 | } |
46 | |
47 | void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; } |
48 | const void *getTensorUntyped(size_t Index) const { |
49 | return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index); |
50 | } |
51 | |
52 | enum class Kind : int { Unknown, Release, Development, NoOp, Interactive }; |
53 | Kind getKind() const { return Type; } |
54 | virtual void switchContext(StringRef Name) {} |
55 | |
56 | protected: |
57 | MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs) |
58 | : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) { |
59 | assert(Type != Kind::Unknown); |
60 | } |
61 | virtual void *evaluateUntyped() = 0; |
62 | |
63 | void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, |
64 | void *Buffer) { |
65 | if (!Buffer) { |
66 | OwnedBuffers.emplace_back(args: Spec.getTotalTensorBufferSize()); |
67 | Buffer = OwnedBuffers.back().data(); |
68 | } |
69 | InputBuffers[Index] = Buffer; |
70 | } |
71 | |
72 | LLVMContext &Ctx; |
73 | const Kind Type; |
74 | |
75 | private: |
76 | std::vector<void *> InputBuffers; |
77 | std::vector<std::vector<char *>> OwnedBuffers; |
78 | }; |
79 | } // namespace llvm |
80 | |
81 | #endif // LLVM_ANALYSIS_MLMODELRUNNER_H |
82 | |