1 | //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | |
10 | #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H |
11 | #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H |
12 | |
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Config/llvm-config.h"
#include <cassert>
17 | |
18 | #ifdef LLVM_HAVE_TFLITE |
19 | #include "llvm/Analysis/MLModelRunner.h" |
20 | #include "llvm/Analysis/Utils/TFUtils.h" |
21 | #include "llvm/IR/LLVMContext.h" |
22 | #include "llvm/IR/PassManager.h" |
23 | |
24 | namespace llvm { |
25 | |
26 | /// ModelUnderTrainingRunner - training mode implementation. It uses TFLite |
27 | /// to dynamically load and evaluate a TF SavedModel |
28 | /// (https://www.tensorflow.org/guide/saved_model) converted to TFLite. see |
29 | /// lib/Analysis/models/saved-model-to-tflite.py. Runtime performance is |
30 | /// sacrificed for ease of use while training. |
31 | class ModelUnderTrainingRunner final : public MLModelRunner { |
32 | public: |
33 | // Disallows copy and assign. |
34 | ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete; |
35 | ModelUnderTrainingRunner & |
36 | operator=(const ModelUnderTrainingRunner &) = delete; |
37 | |
38 | const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const { |
39 | return ExtraOutputsForLogging; |
40 | } |
41 | |
42 | const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const { |
43 | return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1); |
44 | } |
45 | |
46 | const std::optional<TFModelEvaluator::EvaluationResult> & |
47 | lastEvaluationResult() const { |
48 | return LastEvaluationResult; |
49 | } |
50 | static bool classof(const MLModelRunner *R) { |
51 | return R->getKind() == MLModelRunner::Kind::Development; |
52 | } |
53 | |
54 | static std::unique_ptr<ModelUnderTrainingRunner> |
55 | createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath, |
56 | StringRef DecisionName, |
57 | const std::vector<TensorSpec> &InputSpecs, |
58 | StringRef OutputSpecsPathOverride = "" ); |
59 | |
60 | ModelUnderTrainingRunner( |
61 | LLVMContext &Ctx, const std::string &ModelPath, |
62 | const std::vector<TensorSpec> &InputSpecs, |
63 | const std::vector<TensorSpec> &OutputSpecs, |
64 | const std::vector<TensorSpec> &ExtraOutputsForLogging = {}); |
65 | |
66 | bool isValid() const { return !!Evaluator; } |
67 | |
68 | private: |
69 | std::unique_ptr<TFModelEvaluator> Evaluator; |
70 | const std::vector<TensorSpec> OutputSpecs; |
71 | const std::vector<TensorSpec> ExtraOutputsForLogging; |
72 | std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult; |
73 | void *evaluateUntyped() override; |
74 | }; |
75 | |
76 | } // namespace llvm |
#endif // defined(LLVM_HAVE_TFLITE)
78 | #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H |
79 | |