//===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU dialect pattern rewriters that make GPU op
// within a region execute asynchronously.
//
//===----------------------------------------------------------------------===//

14#include "mlir/Dialect/GPU/Transforms/Passes.h"
15
16#include "mlir/Dialect/Async/IR/Async.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/Interfaces/SideEffectInterfaces.h"
23#include "mlir/Support/LLVM.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26namespace mlir {
27#define GEN_PASS_DEF_GPUASYNCREGIONPASS
28#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32
33namespace {
34class GpuAsyncRegionPass
35 : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
36 struct ThreadTokenCallback;
37 struct DeferWaitCallback;
38 struct SingleTokenUseCallback;
39 void runOnOperation() override;
40};
41} // namespace
42
43static bool isTerminator(Operation *op) {
44 return op->mightHaveTrait<OpTrait::IsTerminator>();
45}
46static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); }
47
48// Region walk callback which makes GPU ops implementing the AsyncOpInterface
49// execute asynchronously.
50struct GpuAsyncRegionPass::ThreadTokenCallback {
51 ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
52
53 WalkResult operator()(Block *block) {
54 for (Operation &op : make_early_inc_range(Range&: *block)) {
55 if (failed(Result: visit(op: &op)))
56 return WalkResult::interrupt();
57 }
58 return WalkResult::advance();
59 }
60
61private:
62 // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
63 // create a current token (unless it already exists), and 'thread' that token
64 // through the `op` so that it executes asynchronously.
65 //
66 // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
67 // host-synchronize execution. A `!gpu.async.token` will therefore only be
68 // used inside of its block and GPU execution will always synchronize with
69 // the host at block boundaries.
70 LogicalResult visit(Operation *op) {
71 if (isa<gpu::LaunchOp>(Val: op))
72 return op->emitOpError(message: "replace with gpu.launch_func first");
73 if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(Val: op)) {
74 if (currentToken)
75 waitOp.addAsyncDependency(token: currentToken);
76 currentToken = waitOp.getAsyncToken();
77 return success();
78 }
79 builder.setInsertionPoint(op);
80 if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(Val: op))
81 return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
82 if (!currentToken)
83 return success();
84 // Insert host synchronization before terminator or op with side effects.
85 if (isTerminator(op) || hasSideEffects(op))
86 currentToken = createWaitOp(loc: op->getLoc(), resultType: Type(), operands: {currentToken});
87 return success();
88 }
89
90 // Replaces asyncOp with a clone that returns a token.
91 LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
92 auto *op = asyncOp.getOperation();
93 auto tokenType = builder.getType<gpu::AsyncTokenType>();
94
95 // If there is no current token, insert a `gpu.wait async` without
96 // dependencies to create one.
97 if (!currentToken)
98 currentToken = createWaitOp(loc: op->getLoc(), resultType: tokenType, operands: {});
99 asyncOp.addAsyncDependency(token: currentToken);
100
101 // Return early if op returns a token already.
102 currentToken = asyncOp.getAsyncToken();
103 if (currentToken)
104 return success();
105
106 // Clone the op to return a token in addition to the other results.
107 SmallVector<Type, 1> resultTypes;
108 resultTypes.reserve(N: 1 + op->getNumResults());
109 copy(Range: op->getResultTypes(), Out: std::back_inserter(x&: resultTypes));
110 resultTypes.push_back(Elt: tokenType);
111 auto *newOp = Operation::create(
112 location: op->getLoc(), name: op->getName(), resultTypes, operands: op->getOperands(),
113 attributes: op->getDiscardableAttrDictionary(), properties: op->getPropertiesStorage(),
114 successors: op->getSuccessors(), numRegions: op->getNumRegions());
115
116 // Clone regions into new op.
117 IRMapping mapping;
118 for (auto pair : llvm::zip_first(t: op->getRegions(), u: newOp->getRegions()))
119 std::get<0>(t&: pair).cloneInto(dest: &std::get<1>(t&: pair), mapper&: mapping);
120
121 // Replace the op with the async clone.
122 auto results = newOp->getResults();
123 currentToken = results.back();
124 builder.insert(op: newOp);
125 op->replaceAllUsesWith(values: results.drop_back());
126 op->erase();
127
128 return success();
129 }
130
131 Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
132 return builder.create<gpu::WaitOp>(location: loc, args&: resultType, args&: operands)
133 .getAsyncToken();
134 }
135
136 OpBuilder builder;
137
138 // The token that represents the current asynchronous dependency. It's valid
139 // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
140 // In between, each gpu::AsyncOpInterface depends on the current token and
141 // produces the new one.
142 Value currentToken = {};
143};
144
145/// Erases `executeOp` and returns a clone with additional `results`.
146async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
147 ValueRange results) {
148 // Add values to async.yield op.
149 Operation *yieldOp = executeOp.getBody()->getTerminator();
150 yieldOp->insertOperands(index: yieldOp->getNumOperands(), operands: results);
151
152 // Construct new result type list with additional types.
153 SmallVector<Type, 2> resultTypes;
154 resultTypes.reserve(N: executeOp.getNumResults() + results.size());
155 transform(Range: executeOp.getResultTypes(), d_first: std::back_inserter(x&: resultTypes),
156 F: [](Type type) {
157 // Extract value type from !async.value.
158 if (auto valueType = dyn_cast<async::ValueType>(Val&: type))
159 return valueType.getValueType();
160 assert(isa<async::TokenType>(type) && "expected token type");
161 return type;
162 });
163 transform(Range&: results, d_first: std::back_inserter(x&: resultTypes),
164 F: [](Value value) { return value.getType(); });
165
166 // Clone executeOp with the extra results.
167 OpBuilder builder(executeOp);
168 auto newOp = builder.create<async::ExecuteOp>(
169 location: executeOp.getLoc(), args: TypeRange{resultTypes}.drop_front() /*drop token*/,
170 args: executeOp.getDependencies(), args: executeOp.getBodyOperands());
171 IRMapping mapper;
172 newOp.getRegion().getBlocks().clear();
173 executeOp.getRegion().cloneInto(dest: &newOp.getRegion(), mapper);
174
175 // Replace executeOp with cloned one.
176 executeOp.getOperation()->replaceAllUsesWith(
177 values: newOp.getResults().drop_back(n: results.size()));
178 executeOp.erase();
179
180 return newOp;
181}
182
183// Callback for `async.execute` ops which tries to push the contained
184// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
185struct GpuAsyncRegionPass::DeferWaitCallback {
186 // If the `executeOp`s token is used only in `async.execute` or `async.await`
187 // ops, add the region's last `gpu.wait` op to the worklist if it is
188 // synchronous and is the last op with side effects.
189 void operator()(async::ExecuteOp executeOp) {
190 if (!areAllUsersExecuteOrAwait(token: executeOp.getToken()))
191 return;
192 // async.execute's region is currently restricted to one block.
193 for (auto &op : llvm::reverse(C: executeOp.getBody()->without_terminator())) {
194 if (auto waitOp = dyn_cast<gpu::WaitOp>(Val&: op)) {
195 if (!waitOp.getAsyncToken())
196 worklist.push_back(Elt: waitOp);
197 return;
198 }
199 if (hasSideEffects(op: &op))
200 return;
201 }
202 }
203
204 // The destructor performs the actual rewrite work.
205 ~DeferWaitCallback() {
206 for (size_t i = 0; i < worklist.size(); ++i) {
207 auto waitOp = worklist[i];
208 auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
209
210 // Erase `gpu.wait` and return async dependencies from execute op instead.
211 SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
212 waitOp.erase();
213 executeOp = addExecuteResults(executeOp, results: dependencies);
214
215 // Add the async dependency to each user of the `async.execute` token.
216 auto asyncTokens = executeOp.getResults().take_back(n: dependencies.size());
217 SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
218 executeOp.getToken().user_end());
219 for (Operation *user : users)
220 addAsyncDependencyAfter(asyncTokens, op: user);
221 }
222 }
223
224private:
225 // Returns whether all token users are either 'async.execute' or 'async.await'
226 // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
227 // 'async.execute' body to it's users. Specifically, we do not allow
228 // terminator users, because it could mean that the `async.execute` is inside
229 // control flow code.
230 static bool areAllUsersExecuteOrAwait(Value token) {
231 return !token.use_empty() &&
232 llvm::all_of(Range: token.getUsers(),
233 P: llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
234 }
235
236 // Add the `asyncToken` as dependency as needed after `op`.
237 void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
238 OpBuilder builder(op->getContext());
239 auto loc = op->getLoc();
240
241 Block::iterator it;
242 SmallVector<Value, 1> tokens;
243 tokens.reserve(N: asyncTokens.size());
244 TypeSwitch<Operation *>(op)
245 .Case<async::AwaitOp>(caseFn: [&](auto awaitOp) {
246 // Add async.await ops to wait for the !gpu.async.tokens.
247 builder.setInsertionPointAfter(op);
248 for (auto asyncToken : asyncTokens)
249 tokens.push_back(
250 Elt: builder.create<async::AwaitOp>(location: loc, args&: asyncToken).getResult());
251 // Set `it` after the inserted async.await ops.
252 it = builder.getInsertionPoint();
253 })
254 .Case<async::ExecuteOp>(caseFn: [&](auto executeOp) {
255 // Set `it` to the beginning of the region and add asyncTokens to the
256 // async.execute operands.
257 it = executeOp.getBody()->begin();
258 executeOp.getBodyOperandsMutable().append(asyncTokens);
259 SmallVector<Type, 1> tokenTypes(
260 asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
261 SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
262 executeOp.getLoc());
263 copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
264 std::back_inserter(x&: tokens));
265 });
266
267 // Advance `it` to terminator or op with side-effects.
268 it = std::find_if(first: it, last: Block::iterator(), pred: [](Operation &op) {
269 return isTerminator(op: &op) || hasSideEffects(op: &op);
270 });
271
272 // If `op` implements the AsyncOpInterface, add `token` to the list of async
273 // dependencies.
274 if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(Val&: *it)) {
275 for (auto token : tokens)
276 asyncOp.addAsyncDependency(token);
277 return;
278 }
279
280 // Otherwise, insert a gpu.wait before 'it'.
281 builder.setInsertionPoint(block: it->getBlock(), insertPoint: it);
282 auto waitOp = builder.create<gpu::WaitOp>(location: loc, args: Type{}, args&: tokens);
283
284 // If the new waitOp is at the end of an async.execute region, add it to the
285 // worklist. 'operator()(executeOp)' would do the same, but this is faster.
286 auto executeOp = dyn_cast<async::ExecuteOp>(Val: it->getParentOp());
287 if (executeOp && areAllUsersExecuteOrAwait(token: executeOp.getToken()) &&
288 !it->getNextNode())
289 worklist.push_back(Elt: waitOp);
290 }
291
292 SmallVector<gpu::WaitOp, 8> worklist;
293};
294
295// Callback for `async.execute` ops which repeats !gpu.async.token results
296// so that each of them is only used once.
297struct GpuAsyncRegionPass::SingleTokenUseCallback {
298 void operator()(async::ExecuteOp executeOp) {
299 // Extract !gpu.async.token results which have multiple uses.
300 auto multiUseResults = llvm::make_filter_range(
301 Range: executeOp.getBodyResults(), Pred: [](OpResult result) {
302 if (result.use_empty() || result.hasOneUse())
303 return false;
304 auto valueType = dyn_cast<async::ValueType>(Val: result.getType());
305 return valueType &&
306 isa<gpu::AsyncTokenType>(Val: valueType.getValueType());
307 });
308 if (multiUseResults.empty())
309 return;
310
311 // Indices within !async.execute results (i.e. without the async.token).
312 SmallVector<int, 4> indices;
313 transform(Range&: multiUseResults, d_first: std::back_inserter(x&: indices),
314 F: [](OpResult result) {
315 return result.getResultNumber() - 1; // Index without token.
316 });
317
318 for (auto index : indices) {
319 assert(!executeOp.getBodyResults()[index].getUses().empty());
320 // Repeat async.yield token result, one for each use after the first one.
321 auto uses = llvm::drop_begin(RangeOrContainer: executeOp.getBodyResults()[index].getUses());
322 auto count = std::distance(first: uses.begin(), last: uses.end());
323 auto yieldOp = cast<async::YieldOp>(Val: executeOp.getBody()->getTerminator());
324 SmallVector<Value, 4> operands(count, yieldOp.getOperand(i: index));
325 executeOp = addExecuteResults(executeOp, results: operands);
326 // Update 'uses' to refer to the new executeOp.
327 uses = llvm::drop_begin(RangeOrContainer: executeOp.getBodyResults()[index].getUses());
328 auto results = executeOp.getBodyResults().take_back(n: count);
329 for (auto pair : llvm::zip(t&: uses, u&: results))
330 std::get<0>(t&: pair).set(std::get<1>(t&: pair));
331 }
332 }
333};
334
335// Replaces synchronous GPU ops in the op's region with asynchronous ones and
336// inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
337// execution semantics and that no GPU ops are asynchronous yet.
338void GpuAsyncRegionPass::runOnOperation() {
339 if (getOperation()->walk(callback: ThreadTokenCallback(getContext())).wasInterrupted())
340 return signalPassFailure();
341
342 // Collect gpu.wait ops that we can move out of async.execute regions.
343 getOperation().getRegion().walk(callback: DeferWaitCallback());
344 // Makes each !gpu.async.token returned from async.execute op have single use.
345 getOperation().getRegion().walk(callback: SingleTokenUseCallback());
346}
347

// Source: mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp