1 | //===- InteractiveModelRunner.h ---- "gym" ML model runner -----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | |
10 | #ifndef LLVM_ANALYSIS_INTERACTIVEMODELRUNNER_H |
11 | #define LLVM_ANALYSIS_INTERACTIVEMODELRUNNER_H |
12 | |
13 | #include "llvm/Analysis/MLModelRunner.h" |
14 | #include "llvm/Analysis/TensorSpec.h" |
15 | #include "llvm/Analysis/Utils/TrainingLogger.h" |
16 | #include "llvm/Config/llvm-config.h" |
17 | #include "llvm/Support/FileSystem.h" |
18 | #include "llvm/Support/raw_ostream.h" |
19 | #include <system_error> |
20 | |
21 | namespace llvm { |
22 | |
23 | /// A MLModelRunner that asks for advice from an external agent, or host. It |
24 | /// uses 2 files - ideally named pipes - one to send data to that agent, and |
25 | /// one to receive advice. |
26 | /// The data exchange uses the training logger (Utils/TrainingLogger.h) format. |
27 | /// Specifically, the compiler will send the log header, set the context, and |
28 | /// send observations; the host is expected to reply with a tensor value after |
29 | /// each observation as a binary buffer that's conforming to the shape of the |
30 | /// advice. Interleaved, the data closely resembles the training log for a |
31 | /// log where we don't capture the reward signal. |
32 | /// |
33 | /// Note that the correctness of the received data is the responsibility of the |
34 | /// host. In particular, if insufficient data were sent, the compiler will block |
35 | /// when waiting for an advice. |
36 | /// |
37 | /// Note that the host can either open the pipes RW, or open first the pipe to |
38 | /// the compiler - i.e. the "Inbound" - and then the "Outbound", to avoid |
39 | /// deadlock. This is because the compiler first tries to open the inbound |
40 | /// (which will hang until there's a writer on the other end). |
41 | class InteractiveModelRunner : public MLModelRunner { |
42 | public: |
43 | InteractiveModelRunner(LLVMContext &Ctx, |
44 | const std::vector<TensorSpec> &Inputs, |
45 | const TensorSpec &Advice, StringRef OutboundName, |
46 | StringRef InboundName); |
47 | |
48 | static bool classof(const MLModelRunner *R) { |
49 | return R->getKind() == MLModelRunner::Kind::Interactive; |
50 | } |
51 | void switchContext(StringRef Name) override { |
52 | Log->switchContext(Name); |
53 | Log->flush(); |
54 | } |
55 | |
56 | virtual ~InteractiveModelRunner(); |
57 | |
58 | private: |
59 | void *evaluateUntyped() override; |
60 | // This must be declared before InEC if we want to initialize it in the |
61 | // ctor initializer list. |
62 | int Inbound = -1; |
63 | const std::vector<TensorSpec> InputSpecs; |
64 | const TensorSpec OutputSpec; |
65 | std::error_code OutEC; |
66 | std::error_code InEC; |
67 | std::vector<char> OutputBuffer; |
68 | std::unique_ptr<Logger> Log; |
69 | }; |
70 | } // namespace llvm |
71 | #endif // LLVM_ANALYSIS_INTERACTIVEMODELRUNNER_H |
72 | |