1//===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/GPU/Transforms/Passes.h"
15
16#include "mlir/Dialect/Async/IR/Async.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19#include "mlir/Dialect/GPU/Transforms/Utils.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Interfaces/SideEffectInterfaces.h"
25#include "mlir/Support/LLVM.h"
26#include "mlir/Transforms/RegionUtils.h"
27#include "llvm/ADT/TypeSwitch.h"
28
29namespace mlir {
30#define GEN_PASS_DEF_GPUASYNCREGIONPASS
31#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37class GpuAsyncRegionPass
38 : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
39 struct ThreadTokenCallback;
40 struct DeferWaitCallback;
41 struct SingleTokenUseCallback;
42 void runOnOperation() override;
43};
44} // namespace
45
46static bool isTerminator(Operation *op) {
47 return op->mightHaveTrait<OpTrait::IsTerminator>();
48}
49static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); }
50
51// Region walk callback which makes GPU ops implementing the AsyncOpInterface
52// execute asynchronously.
53struct GpuAsyncRegionPass::ThreadTokenCallback {
54 ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
55
56 WalkResult operator()(Block *block) {
57 for (Operation &op : make_early_inc_range(Range&: *block)) {
58 if (failed(result: visit(op: &op)))
59 return WalkResult::interrupt();
60 }
61 return WalkResult::advance();
62 }
63
64private:
65 // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
66 // create a current token (unless it already exists), and 'thread' that token
67 // through the `op` so that it executes asynchronously.
68 //
69 // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
70 // host-synchronize execution. A `!gpu.async.token` will therefore only be
71 // used inside of its block and GPU execution will always synchronize with
72 // the host at block boundaries.
73 LogicalResult visit(Operation *op) {
74 if (isa<gpu::LaunchOp>(Val: op))
75 return op->emitOpError(message: "replace with gpu.launch_func first");
76 if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
77 if (currentToken)
78 waitOp.addAsyncDependency(currentToken);
79 currentToken = waitOp.getAsyncToken();
80 return success();
81 }
82 builder.setInsertionPoint(op);
83 if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
84 return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
85 if (!currentToken)
86 return success();
87 // Insert host synchronization before terminator or op with side effects.
88 if (isTerminator(op) || hasSideEffects(op))
89 currentToken = createWaitOp(loc: op->getLoc(), resultType: Type(), operands: {currentToken});
90 return success();
91 }
92
93 // Replaces asyncOp with a clone that returns a token.
94 LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
95 auto *op = asyncOp.getOperation();
96 auto tokenType = builder.getType<gpu::AsyncTokenType>();
97
98 // If there is no current token, insert a `gpu.wait async` without
99 // dependencies to create one.
100 if (!currentToken)
101 currentToken = createWaitOp(loc: op->getLoc(), resultType: tokenType, operands: {});
102 asyncOp.addAsyncDependency(currentToken);
103
104 // Return early if op returns a token already.
105 currentToken = asyncOp.getAsyncToken();
106 if (currentToken)
107 return success();
108
109 // Clone the op to return a token in addition to the other results.
110 SmallVector<Type, 1> resultTypes;
111 resultTypes.reserve(N: 1 + op->getNumResults());
112 copy(op->getResultTypes(), std::back_inserter(x&: resultTypes));
113 resultTypes.push_back(Elt: tokenType);
114 auto *newOp = Operation::create(
115 op->getLoc(), op->getName(), resultTypes, op->getOperands(),
116 op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
117 op->getSuccessors(), op->getNumRegions());
118
119 // Clone regions into new op.
120 IRMapping mapping;
121 for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
122 std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
123
124 // Replace the op with the async clone.
125 auto results = newOp->getResults();
126 currentToken = results.back();
127 builder.insert(op: newOp);
128 op->replaceAllUsesWith(results.drop_back());
129 op->erase();
130
131 return success();
132 }
133
134 Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
135 return builder.create<gpu::WaitOp>(loc, resultType, operands)
136 .getAsyncToken();
137 }
138
139 OpBuilder builder;
140
141 // The token that represents the current asynchronous dependency. It's valid
142 // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
143 // In between, each gpu::AsyncOpInterface depends on the current token and
144 // produces the new one.
145 Value currentToken = {};
146};
147
148/// Erases `executeOp` and returns a clone with additional `results`.
149async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
150 ValueRange results) {
151 // Add values to async.yield op.
152 Operation *yieldOp = executeOp.getBody()->getTerminator();
153 yieldOp->insertOperands(index: yieldOp->getNumOperands(), operands: results);
154
155 // Construct new result type list with additional types.
156 SmallVector<Type, 2> resultTypes;
157 resultTypes.reserve(N: executeOp.getNumResults() + results.size());
158 transform(executeOp.getResultTypes(), std::back_inserter(x&: resultTypes),
159 [](Type type) {
160 // Extract value type from !async.value.
161 if (auto valueType = dyn_cast<async::ValueType>(type))
162 return valueType.getValueType();
163 assert(isa<async::TokenType>(type) && "expected token type");
164 return type;
165 });
166 transform(Range&: results, d_first: std::back_inserter(x&: resultTypes),
167 F: [](Value value) { return value.getType(); });
168
169 // Clone executeOp with the extra results.
170 OpBuilder builder(executeOp);
171 auto newOp = builder.create<async::ExecuteOp>(
172 executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
173 executeOp.getDependencies(), executeOp.getBodyOperands());
174 IRMapping mapper;
175 newOp.getRegion().getBlocks().clear();
176 executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
177
178 // Replace executeOp with cloned one.
179 executeOp.getOperation()->replaceAllUsesWith(
180 newOp.getResults().drop_back(results.size()));
181 executeOp.erase();
182
183 return newOp;
184}
185
186// Callback for `async.execute` ops which tries to push the contained
187// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
188struct GpuAsyncRegionPass::DeferWaitCallback {
189 // If the `executeOp`s token is used only in `async.execute` or `async.await`
190 // ops, add the region's last `gpu.wait` op to the worklist if it is
191 // synchronous and is the last op with side effects.
192 void operator()(async::ExecuteOp executeOp) {
193 if (!areAllUsersExecuteOrAwait(token: executeOp.getToken()))
194 return;
195 // async.execute's region is currently restricted to one block.
196 for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
197 if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
198 if (!waitOp.getAsyncToken())
199 worklist.push_back(waitOp);
200 return;
201 }
202 if (hasSideEffects(&op))
203 return;
204 }
205 }
206
207 // The destructor performs the actual rewrite work.
208 ~DeferWaitCallback() {
209 for (size_t i = 0; i < worklist.size(); ++i) {
210 auto waitOp = worklist[i];
211 auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
212
213 // Erase `gpu.wait` and return async dependencies from execute op instead.
214 SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
215 waitOp.erase();
216 executeOp = addExecuteResults(executeOp, dependencies);
217
218 // Add the async dependency to each user of the `async.execute` token.
219 auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
220 SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
221 executeOp.getToken().user_end());
222 for (Operation *user : users)
223 addAsyncDependencyAfter(asyncTokens, user);
224 }
225 }
226
227private:
228 // Returns whether all token users are either 'async.execute' or 'async.await'
229 // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
230 // 'async.execute' body to it's users. Specifically, we do not allow
231 // terminator users, because it could mean that the `async.execute` is inside
232 // control flow code.
233 static bool areAllUsersExecuteOrAwait(Value token) {
234 return !token.use_empty() &&
235 llvm::all_of(token.getUsers(),
236 llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
237 }
238
239 // Add the `asyncToken` as dependency as needed after `op`.
240 void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
241 OpBuilder builder(op->getContext());
242 auto loc = op->getLoc();
243
244 Block::iterator it;
245 SmallVector<Value, 1> tokens;
246 tokens.reserve(N: asyncTokens.size());
247 TypeSwitch<Operation *>(op)
248 .Case<async::AwaitOp>([&](auto awaitOp) {
249 // Add async.await ops to wait for the !gpu.async.tokens.
250 builder.setInsertionPointAfter(op);
251 for (auto asyncToken : asyncTokens)
252 tokens.push_back(
253 builder.create<async::AwaitOp>(loc, asyncToken).getResult());
254 // Set `it` after the inserted async.await ops.
255 it = builder.getInsertionPoint();
256 })
257 .Case<async::ExecuteOp>([&](auto executeOp) {
258 // Set `it` to the beginning of the region and add asyncTokens to the
259 // async.execute operands.
260 it = executeOp.getBody()->begin();
261 executeOp.getBodyOperandsMutable().append(asyncTokens);
262 SmallVector<Type, 1> tokenTypes(
263 asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
264 SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
265 executeOp.getLoc());
266 copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
267 std::back_inserter(tokens));
268 });
269
270 // Advance `it` to terminator or op with side-effects.
271 it = std::find_if(first: it, last: Block::iterator(), pred: [](Operation &op) {
272 return isTerminator(op: &op) || hasSideEffects(op: &op);
273 });
274
275 // If `op` implements the AsyncOpInterface, add `token` to the list of async
276 // dependencies.
277 if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
278 for (auto token : tokens)
279 asyncOp.addAsyncDependency(token);
280 return;
281 }
282
283 // Otherwise, insert a gpu.wait before 'it'.
284 builder.setInsertionPoint(block: it->getBlock(), insertPoint: it);
285 auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
286
287 // If the new waitOp is at the end of an async.execute region, add it to the
288 // worklist. 'operator()(executeOp)' would do the same, but this is faster.
289 auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
290 if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
291 !it->getNextNode())
292 worklist.push_back(waitOp);
293 }
294
295 SmallVector<gpu::WaitOp, 8> worklist;
296};
297
298// Callback for `async.execute` ops which repeats !gpu.async.token results
299// so that each of them is only used once.
300struct GpuAsyncRegionPass::SingleTokenUseCallback {
301 void operator()(async::ExecuteOp executeOp) {
302 // Extract !gpu.async.token results which have multiple uses.
303 auto multiUseResults = llvm::make_filter_range(
304 executeOp.getBodyResults(), [](OpResult result) {
305 if (result.use_empty() || result.hasOneUse())
306 return false;
307 auto valueType = dyn_cast<async::ValueType>(result.getType());
308 return valueType &&
309 isa<gpu::AsyncTokenType>(valueType.getValueType());
310 });
311 if (multiUseResults.empty())
312 return;
313
314 // Indices within !async.execute results (i.e. without the async.token).
315 SmallVector<int, 4> indices;
316 transform(multiUseResults, std::back_inserter(x&: indices),
317 [](OpResult result) {
318 return result.getResultNumber() - 1; // Index without token.
319 });
320
321 for (auto index : indices) {
322 assert(!executeOp.getBodyResults()[index].getUses().empty());
323 // Repeat async.yield token result, one for each use after the first one.
324 auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
325 auto count = std::distance(uses.begin(), uses.end());
326 auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
327 SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
328 executeOp = addExecuteResults(executeOp, operands);
329 // Update 'uses' to refer to the new executeOp.
330 uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
331 auto results = executeOp.getBodyResults().take_back(count);
332 for (auto pair : llvm::zip(uses, results))
333 std::get<0>(pair).set(std::get<1>(pair));
334 }
335 }
336};
337
338// Replaces synchronous GPU ops in the op's region with asynchronous ones and
339// inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
340// execution semantics and that no GPU ops are asynchronous yet.
341void GpuAsyncRegionPass::runOnOperation() {
342 if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
343 return signalPassFailure();
344
345 // Collect gpu.wait ops that we can move out of async.execute regions.
346 getOperation().getRegion().walk(DeferWaitCallback());
347 // Makes each !gpu.async.token returned from async.execute op have single use.
348 getOperation().getRegion().walk(SingleTokenUseCallback());
349}
350
351std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() {
352 return std::make_unique<GpuAsyncRegionPass>();
353}
354

source code of mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp