//===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
//
//===----------------------------------------------------------------------===//
| 13 | |
#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_GPUASYNCREGIONPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
| 32 | |
namespace {
/// Pass that converts synchronous GPU ops within a region into their
/// asynchronous form, inserting `gpu.wait` ops for host synchronization where
/// needed. The three nested callbacks (defined below) implement the three
/// rewrite phases driven by runOnOperation().
class GpuAsyncRegionPass
    : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
  // Phase 1: threads a !gpu.async.token through GPU async ops.
  struct ThreadTokenCallback;
  // Phase 2: hoists trailing gpu.wait ops out of async.execute regions.
  struct DeferWaitCallback;
  // Phase 3: duplicates multi-use !gpu.async.token results so each has one use.
  struct SingleTokenUseCallback;
  void runOnOperation() override;
};
} // namespace
| 42 | |
// Returns whether `op` might be a block terminator (conservatively true for
// unregistered ops, hence `mightHaveTrait` rather than `hasTrait`).
static bool isTerminator(Operation *op) {
  return op->mightHaveTrait<OpTrait::IsTerminator>();
}
| 46 | static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); } |
| 47 | |
// Block walk callback which makes GPU ops implementing the AsyncOpInterface
// execute asynchronously by threading a !gpu.async.token through them.
struct GpuAsyncRegionPass::ThreadTokenCallback {
  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}

  /// Visits every op in `block` in order; interrupts the walk on the first
  /// failed visit. Uses make_early_inc_range because visit() may erase the
  /// current op (rewriteAsyncOp replaces it with a clone).
  WalkResult operator()(Block *block) {
    for (Operation &op : make_early_inc_range(*block)) {
      if (failed(visit(&op)))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  }

private:
  // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
  // create a current token (unless it already exists), and 'thread' that token
  // through the `op` so that it executes asynchronously.
  //
  // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
  // host-synchronize execution. A `!gpu.async.token` will therefore only be
  // used inside of its block and GPU execution will always synchronize with
  // the host at block boundaries.
  LogicalResult visit(Operation *op) {
    // gpu.launch regions are not handled here; the pass requires them to have
    // been outlined to gpu.launch_func beforehand.
    if (isa<gpu::LaunchOp>(op))
      return op->emitOpError("replace with gpu.launch_func first");
    if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
      // Fold the current token into an existing wait and continue from its
      // async token (null if the wait is synchronous, ending the chain).
      if (currentToken)
        waitOp.addAsyncDependency(currentToken);
      currentToken = waitOp.getAsyncToken();
      return success();
    }
    builder.setInsertionPoint(op);
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
      return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
    if (!currentToken)
      return success();
    // Insert host synchronization before terminator or op with side effects.
    if (isTerminator(op) || hasSideEffects(op))
      currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
    return success();
  }

  // Replaces asyncOp with a clone that returns a token.
  LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
    auto *op = asyncOp.getOperation();
    auto tokenType = builder.getType<gpu::AsyncTokenType>();

    // If there is no current token, insert a `gpu.wait async` without
    // dependencies to create one.
    if (!currentToken)
      currentToken = createWaitOp(op->getLoc(), tokenType, {});
    asyncOp.addAsyncDependency(currentToken);

    // Return early if op returns a token already.
    currentToken = asyncOp.getAsyncToken();
    if (currentToken)
      return success();

    // Clone the op to return a token in addition to the other results.
    SmallVector<Type, 1> resultTypes;
    resultTypes.reserve(1 + op->getNumResults());
    copy(op->getResultTypes(), std::back_inserter(resultTypes));
    resultTypes.push_back(tokenType);
    auto *newOp = Operation::create(
        op->getLoc(), op->getName(), resultTypes, op->getOperands(),
        op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
        op->getSuccessors(), op->getNumRegions());

    // Clone regions into new op.
    IRMapping mapping;
    for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
      std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);

    // Replace the op with the async clone.
    auto results = newOp->getResults();
    currentToken = results.back();
    builder.insert(newOp);
    // drop_back: the extra token result has no counterpart on the old op.
    op->replaceAllUsesWith(results.drop_back());
    op->erase();

    return success();
  }

  /// Creates a gpu.wait at the current insertion point. With a token
  /// `resultType` this is an asynchronous wait (returns a new token); with a
  /// null Type it is a host-synchronizing wait (returns null).
  Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
    return builder.create<gpu::WaitOp>(loc, resultType, operands)
        .getAsyncToken();
  }

  OpBuilder builder;

  // The token that represents the current asynchronous dependency. It's valid
  // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
  // In between, each gpu::AsyncOpInterface depends on the current token and
  // produces the new one.
  Value currentToken = {};
};
| 144 | |
| 145 | /// Erases `executeOp` and returns a clone with additional `results`. |
| 146 | async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, |
| 147 | ValueRange results) { |
| 148 | // Add values to async.yield op. |
| 149 | Operation *yieldOp = executeOp.getBody()->getTerminator(); |
| 150 | yieldOp->insertOperands(index: yieldOp->getNumOperands(), operands: results); |
| 151 | |
| 152 | // Construct new result type list with additional types. |
| 153 | SmallVector<Type, 2> resultTypes; |
| 154 | resultTypes.reserve(N: executeOp.getNumResults() + results.size()); |
| 155 | transform(Range: executeOp.getResultTypes(), d_first: std::back_inserter(x&: resultTypes), |
| 156 | F: [](Type type) { |
| 157 | // Extract value type from !async.value. |
| 158 | if (auto valueType = dyn_cast<async::ValueType>(Val&: type)) |
| 159 | return valueType.getValueType(); |
| 160 | assert(isa<async::TokenType>(type) && "expected token type" ); |
| 161 | return type; |
| 162 | }); |
| 163 | transform(Range&: results, d_first: std::back_inserter(x&: resultTypes), |
| 164 | F: [](Value value) { return value.getType(); }); |
| 165 | |
| 166 | // Clone executeOp with the extra results. |
| 167 | OpBuilder builder(executeOp); |
| 168 | auto newOp = builder.create<async::ExecuteOp>( |
| 169 | location: executeOp.getLoc(), args: TypeRange{resultTypes}.drop_front() /*drop token*/, |
| 170 | args: executeOp.getDependencies(), args: executeOp.getBodyOperands()); |
| 171 | IRMapping mapper; |
| 172 | newOp.getRegion().getBlocks().clear(); |
| 173 | executeOp.getRegion().cloneInto(dest: &newOp.getRegion(), mapper); |
| 174 | |
| 175 | // Replace executeOp with cloned one. |
| 176 | executeOp.getOperation()->replaceAllUsesWith( |
| 177 | values: newOp.getResults().drop_back(n: results.size())); |
| 178 | executeOp.erase(); |
| 179 | |
| 180 | return newOp; |
| 181 | } |
| 182 | |
// Callback for `async.execute` ops which tries to push the contained
// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
struct GpuAsyncRegionPass::DeferWaitCallback {
  // If the `executeOp`s token is used only in `async.execute` or `async.await`
  // ops, add the region's last `gpu.wait` op to the worklist if it is
  // synchronous and is the last op with side effects.
  void operator()(async::ExecuteOp executeOp) {
    if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
      return;
    // async.execute's region is currently restricted to one block.
    for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
      if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
        // Only a synchronous wait (no async token result) can be deferred.
        if (!waitOp.getAsyncToken())
          worklist.push_back(waitOp);
        return;
      }
      // A side-effecting op after the wait would observe the synchronization;
      // deferring the wait past it would change behavior.
      if (hasSideEffects(&op))
        return;
    }
  }

  // The destructor performs the actual rewrite work.
  ~DeferWaitCallback() {
    // Index-based loop on purpose: addAsyncDependencyAfter() may append new
    // waits to `worklist` while we iterate, and iterators would be invalidated.
    for (size_t i = 0; i < worklist.size(); ++i) {
      auto waitOp = worklist[i];
      auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();

      // Erase `gpu.wait` and return async dependencies from execute op instead.
      SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
      waitOp.erase();
      executeOp = addExecuteResults(executeOp, dependencies);

      // Add the async dependency to each user of the `async.execute` token.
      // Snapshot the users first: addAsyncDependencyAfter mutates the use list.
      auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
      SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
                                        executeOp.getToken().user_end());
      for (Operation *user : users)
        addAsyncDependencyAfter(asyncTokens, user);
    }
  }

private:
  // Returns whether all token users are either 'async.execute' or 'async.await'
  // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
  // 'async.execute' body to it's users. Specifically, we do not allow
  // terminator users, because it could mean that the `async.execute` is inside
  // control flow code.
  static bool areAllUsersExecuteOrAwait(Value token) {
    return !token.use_empty() &&
           llvm::all_of(token.getUsers(),
                        llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
  }

  // Add the `asyncToken` as dependency as needed after `op`.
  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
    OpBuilder builder(op->getContext());
    auto loc = op->getLoc();

    // `it` is set below to the position from which to search for the op that
    // should consume the !gpu.async.token values in `tokens`.
    Block::iterator it;
    SmallVector<Value, 1> tokens;
    tokens.reserve(asyncTokens.size());
    TypeSwitch<Operation *>(op)
        .Case<async::AwaitOp>([&](auto awaitOp) {
          // Add async.await ops to wait for the !gpu.async.tokens.
          builder.setInsertionPointAfter(op);
          for (auto asyncToken : asyncTokens)
            tokens.push_back(
                builder.create<async::AwaitOp>(loc, asyncToken).getResult());
          // Set `it` after the inserted async.await ops.
          it = builder.getInsertionPoint();
        })
        .Case<async::ExecuteOp>([&](auto executeOp) {
          // Set `it` to the beginning of the region and add asyncTokens to the
          // async.execute operands.
          it = executeOp.getBody()->begin();
          executeOp.getBodyOperandsMutable().append(asyncTokens);
          SmallVector<Type, 1> tokenTypes(
              asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
          SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
                                             executeOp.getLoc());
          copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
               std::back_inserter(tokens));
        });

    // Advance `it` to terminator or op with side-effects.
    it = std::find_if(it, Block::iterator(), [](Operation &op) {
      return isTerminator(&op) || hasSideEffects(&op);
    });

    // If `op` implements the AsyncOpInterface, add `token` to the list of async
    // dependencies.
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
      for (auto token : tokens)
        asyncOp.addAsyncDependency(token);
      return;
    }

    // Otherwise, insert a gpu.wait before 'it'.
    builder.setInsertionPoint(it->getBlock(), it);
    auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);

    // If the new waitOp is at the end of an async.execute region, add it to the
    // worklist. 'operator()(executeOp)' would do the same, but this is faster.
    auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
    if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
        !it->getNextNode())
      worklist.push_back(waitOp);
  }

  SmallVector<gpu::WaitOp, 8> worklist;
};
| 294 | |
// Callback for `async.execute` ops which repeats !gpu.async.token results
// so that each of them is only used once.
struct GpuAsyncRegionPass::SingleTokenUseCallback {
  void operator()(async::ExecuteOp executeOp) {
    // Extract !gpu.async.token results which have multiple uses.
    auto multiUseResults = llvm::make_filter_range(
        executeOp.getBodyResults(), [](OpResult result) {
          if (result.use_empty() || result.hasOneUse())
            return false;
          auto valueType = dyn_cast<async::ValueType>(result.getType());
          return valueType &&
                 isa<gpu::AsyncTokenType>(valueType.getValueType());
        });
    if (multiUseResults.empty())
      return;

    // Indices within !async.execute results (i.e. without the async.token).
    // Collected up front because the loop below replaces `executeOp`, which
    // would invalidate the lazy filter range.
    SmallVector<int, 4> indices;
    transform(multiUseResults, std::back_inserter(indices),
              [](OpResult result) {
                return result.getResultNumber() - 1; // Index without token.
              });

    for (auto index : indices) {
      assert(!executeOp.getBodyResults()[index].getUses().empty());
      // Repeat async.yield token result, one for each use after the first one.
      auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
      auto count = std::distance(uses.begin(), uses.end());
      auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
      SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
      executeOp = addExecuteResults(executeOp, operands);
      // Update 'uses' to refer to the new executeOp.
      uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
      auto results = executeOp.getBodyResults().take_back(count);
      // Point each extra use at its own duplicated token result.
      for (auto pair : llvm::zip(uses, results))
        std::get<0>(pair).set(std::get<1>(pair));
    }
  }
};
| 334 | |
| 335 | // Replaces synchronous GPU ops in the op's region with asynchronous ones and |
| 336 | // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential |
| 337 | // execution semantics and that no GPU ops are asynchronous yet. |
| 338 | void GpuAsyncRegionPass::runOnOperation() { |
| 339 | if (getOperation()->walk(callback: ThreadTokenCallback(getContext())).wasInterrupted()) |
| 340 | return signalPassFailure(); |
| 341 | |
| 342 | // Collect gpu.wait ops that we can move out of async.execute regions. |
| 343 | getOperation().getRegion().walk(callback: DeferWaitCallback()); |
| 344 | // Makes each !gpu.async.token returned from async.execute op have single use. |
| 345 | getOperation().getRegion().walk(callback: SingleTokenUseCallback()); |
| 346 | } |
| 347 | |