//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic Async runtime API for supporting Async dialect
// to LLVM dialect lowering.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward declare class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations, and the Async runtime API is
// built on top of the default runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  ~AsyncRuntime() {
    threadPool.wait(); // wait for the completion of all async tasks
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  int64_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

  llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int64_t> numRefCountedObjects;
  llvm::DefaultThreadPool threadPool;
};

// -------------------------------------------------------------------------- //
// A state of the async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state cannot
    // transition to any other state.
    kAvailable = 1,
    // The underlying value is available and contains an error. This state
    // cannot transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
  }

private:
  StateEnum state;
};

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  void addRef(int64_t count = 1) { refCount.fetch_add(count); }

  void dropRef(int64_t count = 1) {
    int64_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int64_t> refCount;
};

} // namespace

// Returns the default per-process instance of an async runtime.
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime;
}

static void resetDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().reset();
}

static AsyncRuntime *getDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().get();
}

// Async token provides a mechanism to signal asynchronous operation completion.
struct AsyncToken : public RefCounted {
  // AsyncToken is created with a reference count of 2 because it will be
  // returned to the `async.execute` caller and later emplaced by the
  // asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the token stays alive until the
  // asynchronous operation is completed.
  AsyncToken(AsyncRuntime *runtime)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}

  std::atomic<State::StateEnum> state;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
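
// Illustrative sketch (not part of the runtime implementation): a typical
// token lifecycle seen from client code, using only the C API defined below.
// The explicit producer thread is an assumption made for this example; in real
// use these calls are emitted by the Async-to-LLVM lowering.
//
//   AsyncToken *token = mlirAsyncRuntimeCreateToken(); // refCount == 2
//
//   std::thread producer([token] {
//     // ... asynchronous work ...
//     mlirAsyncRuntimeEmplaceToken(token); // runs awaiters, drops task ref
//   });
//
//   mlirAsyncRuntimeAwaitToken(token); // blocks until available or error
//   mlirAsyncRuntimeDropRef(token, 1); // release the caller's reference
//   producer.join();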

// Async value provides a mechanism to access the result of asynchronous
// operations. It owns the storage that is used to store/load the value of the
// underlying type, and a flag to signal if the value is ready or not.
struct AsyncValue : public RefCounted {
  // AsyncValue, like an AsyncToken, is created with a reference count of 2.
  AsyncValue(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
        storage(size) {}

  std::atomic<State::StateEnum> state;

  // Use vector of bytes to store async value payload.
  std::vector<std::byte> storage;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};

// Async group provides a mechanism to group together multiple async tokens or
// values to await on all of them together (wait for the completion of all
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
  AsyncGroup(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}

  std::atomic<int> pendingTokens;
  std::atomic<int> numErrors;
  std::atomic<int> rank;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
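
// Illustrative sketch (not part of the runtime implementation): grouping two
// tokens and awaiting both, with error propagation into the group. The
// producer thread is an assumption made for this example.
//
//   AsyncGroup *group = mlirAsyncRuntimeCreateGroup(/*size=*/2);
//   AsyncToken *a = mlirAsyncRuntimeCreateToken();
//   AsyncToken *b = mlirAsyncRuntimeCreateToken();
//   mlirAsyncRuntimeAddTokenToGroup(a, group); // returns rank 0
//   mlirAsyncRuntimeAddTokenToGroup(b, group); // returns rank 1
//
//   std::thread producer([a, b] {
//     mlirAsyncRuntimeEmplaceToken(a);
//     mlirAsyncRuntimeSetTokenError(b); // counted in group->numErrors
//   });
//
//   mlirAsyncRuntimeAwaitAllInGroup(group); // returns once both are ready
//   bool failed = mlirAsyncRuntimeIsGroupError(group); // true in this sketch
//   mlirAsyncRuntimeDropRef(a, 1);
//   mlirAsyncRuntimeDropRef(b, 1);
//   mlirAsyncRuntimeDropRef(group, 1);
//   producer.join();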

// Adds references to reference counted runtime object.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->addRef(count);
}

// Drops references from reference counted runtime object.
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->dropRef(count);
}
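
// Illustrative sketch (not part of the runtime implementation): sharing a
// runtime object with an additional consumer. The extra consumer is an
// assumption made for this example; every holder of a reference releases it
// with a matching drop, and the object is destroyed when the count hits zero.
//
//   AsyncToken *token = mlirAsyncRuntimeCreateToken();
//   mlirAsyncRuntimeAddRef(token, 1); // +1 for a second consumer
//   // ... each consumer, when done with the token:
//   mlirAsyncRuntimeDropRef(token, 1);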

// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
  return token;
}

// Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
  return value;
}

// Create a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
  return group;
}

255extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
256 AsyncGroup *group) {
257 std::unique_lock<std::mutex> lockToken(token->mu);
258 std::unique_lock<std::mutex> lockGroup(group->mu);
259
260 // Get the rank of the token inside the group before we drop the reference.
261 int rank = group->rank.fetch_add(i: 1);
262
263 auto onTokenReady = [group, token]() {
264 // Increment the number of errors in the group.
265 if (State(token->state).isError())
266 group->numErrors.fetch_add(i: 1);
267
268 // If pending tokens go below zero it means that more tokens than the group
269 // size were added to this group.
270 assert(group->pendingTokens > 0 && "wrong group size");
271
272 // Run all group awaiters if it was the last token in the group.
273 if (group->pendingTokens.fetch_sub(i: 1) == 1) {
274 group->cv.notify_all();
275 for (auto &awaiter : group->awaiters)
276 awaiter();
277 }
278 };
279
280 if (State(token->state).isAvailableOrError()) {
281 // Update group pending tokens immediately and maybe run awaiters.
282 onTokenReady();
283
284 } else {
285 // Update group pending tokens when token will become ready. Because this
286 // will happen asynchronously we must ensure that `group` is alive until
287 // then, and re-ackquire the lock.
288 group->addRef();
289
290 token->awaiters.emplace_back(args: [group, onTokenReady]() {
291 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
292 {
293 std::unique_lock<std::mutex> lockGroup(group->mu);
294 onTokenReady();
295 }
296 group->dropRef();
297 });
298 }
299
300 return rank;
301}

// Switches `async.token` to available or error state (terminal state) and runs
// all awaiters.
static void setTokenState(AsyncToken *token, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(token->state).isUnavailable() && "token must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(token->mu);
    token->state = state;
    token->cv.notify_all();
    for (auto &awaiter : token->awaiters)
      awaiter();
  }

  // Async tokens are created with a ref count of 2 to keep the token alive
  // until the async task completes. Drop this reference explicitly when the
  // token is emplaced.
  token->dropRef();
}

static void setValueState(AsyncValue *value, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(value->state).isUnavailable() && "value must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(value->mu);
    value->state = state;
    value->cv.notify_all();
    for (auto &awaiter : value->awaiters)
      awaiter();
  }

  // Async values are created with a ref count of 2 to keep the value alive
  // until the async task completes. Drop this reference explicitly when the
  // value is emplaced.
  value->dropRef();
}

extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  setTokenState(token, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
  setValueState(value, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
  setTokenState(token, State::kError);
}

extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
  setValueState(value, State::kError);
}

extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
  return State(token->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
  return State(value->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
  return group->numErrors.load() > 0;
}

extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> lock(token->mu);
  if (!State(token->state).isAvailableOrError())
    token->cv.wait(
        lock, [token] { return State(token->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
  std::unique_lock<std::mutex> lock(value->mu);
  if (!State(value->state).isAvailableOrError())
    value->cv.wait(
        lock, [value] { return State(value->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens != 0)
    group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}

// Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
  assert(!State(value->state).isError() && "unexpected error state");
  return value->storage.data();
}
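
// Illustrative sketch (not part of the runtime implementation): producing and
// consuming an `async.value` payload through its raw storage. The float
// payload, the producer thread, and the cast of ValueStorage to a raw pointer
// are assumptions made for this example.
//
//   AsyncValue *value = mlirAsyncRuntimeCreateValue(sizeof(float));
//
//   std::thread producer([value] {
//     auto *payload =
//         reinterpret_cast<float *>(mlirAsyncRuntimeGetValueStorage(value));
//     *payload = 42.0f;                    // write the payload first ...
//     mlirAsyncRuntimeEmplaceValue(value); // ... then publish it to awaiters
//   });
//
//   mlirAsyncRuntimeAwaitValue(value);
//   float result =
//       *reinterpret_cast<float *>(mlirAsyncRuntimeGetValueStorage(value));
//   mlirAsyncRuntimeDropRef(value, 1);
//   producer.join();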

extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto *runtime = getDefaultAsyncRuntime();
  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}

extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(token->mu);
  if (State(token->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    token->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(value->mu);
  if (State(value->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    value->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                          CoroHandle handle,
                                                          CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens == 0) {
    lock.unlock();
    execute();
  } else {
    group->awaiters.emplace_back([execute]() { execute(); });
  }
}
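
// Illustrative sketch (not part of the runtime implementation): non-blocking
// await through a resume callback, mirroring what the coroutine lowering
// emits. The plain function standing in for a coroutine resume and the null
// handle are assumptions made for this example; the runtime only requires
// that the callback can be invoked as (*resume)(handle).
//
//   static void resumeCallback(void *handle) {
//     // In lowered code this resumes the suspended coroutine frame.
//   }
//
//   AsyncToken *token = mlirAsyncRuntimeCreateToken();
//   mlirAsyncRuntimeAwaitTokenAndExecute(token, /*handle=*/nullptr,
//                                        &resumeCallback);
//   mlirAsyncRuntimeEmplaceToken(token); // runs resumeCallback(nullptr)
//   mlirAsyncRuntimeDropRef(token, 1);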

extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
  return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
}

//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id thisId = std::this_thread::get_id();
  std::cout << "Current thread id: " << thisId << '\n';
}

//===----------------------------------------------------------------------===//
// MLIR ExecutionEngine dynamic library integration.
//===----------------------------------------------------------------------===//

// Visual Studio had a bug that caused compilation of nested generic lambdas
// inside an `extern "C"` function to fail.
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a workaround for older versions of Visual Studio.
// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
__mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);

// NOLINTNEXTLINE(*-identifier-naming): externally called.
void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
    assert(exportSymbols.count(name) == 0 && "symbol already exists");
    exportSymbols[name] = reinterpret_cast<void *>(ptr);
  };

  exportSymbol("mlirAsyncRuntimeAddRef",
               &mlir::runtime::mlirAsyncRuntimeAddRef);
  exportSymbol("mlirAsyncRuntimeDropRef",
               &mlir::runtime::mlirAsyncRuntimeDropRef);
  exportSymbol("mlirAsyncRuntimeExecute",
               &mlir::runtime::mlirAsyncRuntimeExecute);
  exportSymbol("mlirAsyncRuntimeGetValueStorage",
               &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
  exportSymbol("mlirAsyncRuntimeCreateToken",
               &mlir::runtime::mlirAsyncRuntimeCreateToken);
  exportSymbol("mlirAsyncRuntimeCreateValue",
               &mlir::runtime::mlirAsyncRuntimeCreateValue);
  exportSymbol("mlirAsyncRuntimeEmplaceToken",
               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
  exportSymbol("mlirAsyncRuntimeEmplaceValue",
               &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
  exportSymbol("mlirAsyncRuntimeSetTokenError",
               &mlir::runtime::mlirAsyncRuntimeSetTokenError);
  exportSymbol("mlirAsyncRuntimeSetValueError",
               &mlir::runtime::mlirAsyncRuntimeSetValueError);
  exportSymbol("mlirAsyncRuntimeIsTokenError",
               &mlir::runtime::mlirAsyncRuntimeIsTokenError);
  exportSymbol("mlirAsyncRuntimeIsValueError",
               &mlir::runtime::mlirAsyncRuntimeIsValueError);
  exportSymbol("mlirAsyncRuntimeIsGroupError",
               &mlir::runtime::mlirAsyncRuntimeIsGroupError);
  exportSymbol("mlirAsyncRuntimeAwaitToken",
               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
  exportSymbol("mlirAsyncRuntimeAwaitValue",
               &mlir::runtime::mlirAsyncRuntimeAwaitValue);
  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
  exportSymbol("mlirAsyncRuntimeCreateGroup",
               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
  exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
               &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}

// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy() {
  resetDefaultAsyncRuntime();
}

} // namespace runtime
} // namespace mlir

