//===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_GPUASYNCREGIONPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
class GpuAsyncRegionPass
    : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
  struct ThreadTokenCallback;
  struct DeferWaitCallback;
  struct SingleTokenUseCallback;
  void runOnOperation() override;
};
} // namespace
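
// Illustrative usage sketch (an assumption based on the standard pass
// registration generated from Passes.td, where this pass is exposed as
// "gpu-async-region"; not part of the pass logic itself):
//
//   PassManager pm(&context);
//   pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
//   (void)pm.run(module);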

static bool isTerminator(Operation *op) {
  return op->mightHaveTrait<OpTrait::IsTerminator>();
}
static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); }

// Region walk callback which makes GPU ops implementing the AsyncOpInterface
// execute asynchronously.
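//
// For illustration, a sketch of the rewrite on simplified IR (op names and
// operand lists are hypothetical placeholders):
//
//   gpu.launch_func @kernels::@foo(...)   // synchronous GPU op
//   call @use_result(...)                 // op with side effects
//
// becomes
//
//   %t0 = gpu.wait async                  // create the current token
//   %t1 = gpu.launch_func async [%t0] @kernels::@foo(...)
//   gpu.wait [%t1]                        // host-synchronize before the
//   call @use_result(...)                 // side-effecting op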
struct GpuAsyncRegionPass::ThreadTokenCallback {
  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}

  WalkResult operator()(Block *block) {
    for (Operation &op : make_early_inc_range(*block)) {
      if (failed(visit(&op)))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  }

private:
  // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
  // create a current token (unless it already exists), and 'thread' that token
  // through the `op` so that it executes asynchronously.
  //
  // If `op` is a terminator or an op with side effects, insert a `gpu.wait` to
  // host-synchronize execution. A `!gpu.async.token` will therefore only be
  // used inside of its block and GPU execution will always synchronize with
  // the host at block boundaries.
  LogicalResult visit(Operation *op) {
    if (isa<gpu::LaunchOp>(op))
      return op->emitOpError("replace with gpu.launch_func first");
    if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
      if (currentToken)
        waitOp.addAsyncDependency(currentToken);
      currentToken = waitOp.getAsyncToken();
      return success();
    }
    builder.setInsertionPoint(op);
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
      return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
    if (!currentToken)
      return success();
    // Insert host synchronization before terminator or op with side effects.
    if (isTerminator(op) || hasSideEffects(op))
      currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
    return success();
  }

  // Replaces asyncOp with a clone that returns a token.
  LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
    auto *op = asyncOp.getOperation();
    auto tokenType = builder.getType<gpu::AsyncTokenType>();

    // If there is no current token, insert a `gpu.wait async` without
    // dependencies to create one.
    if (!currentToken)
      currentToken = createWaitOp(op->getLoc(), tokenType, {});
    asyncOp.addAsyncDependency(currentToken);

    // Return early if op returns a token already.
    currentToken = asyncOp.getAsyncToken();
    if (currentToken)
      return success();

    // Clone the op to return a token in addition to the other results.
    SmallVector<Type, 1> resultTypes;
    resultTypes.reserve(1 + op->getNumResults());
    copy(op->getResultTypes(), std::back_inserter(resultTypes));
    resultTypes.push_back(tokenType);
    auto *newOp = Operation::create(
        op->getLoc(), op->getName(), resultTypes, op->getOperands(),
        op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
        op->getSuccessors(), op->getNumRegions());

    // Clone regions into new op.
    IRMapping mapping;
    for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
      std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);

    // Replace the op with the async clone.
    auto results = newOp->getResults();
    currentToken = results.back();
    builder.insert(newOp);
    op->replaceAllUsesWith(results.drop_back());
    op->erase();

    return success();
  }

  Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
    return builder.create<gpu::WaitOp>(loc, resultType, operands)
        .getAsyncToken();
  }

  OpBuilder builder;

  // The token that represents the current asynchronous dependency. Its valid
  // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
  // In between, each gpu::AsyncOpInterface depends on the current token and
  // produces the new one.
  Value currentToken = {};
};

/// Erases `executeOp` and returns a clone with additional `results`.
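///
/// For illustration only (simplified IR, names hypothetical), adding a token
/// value %v to
///
///   %token = async.execute {
///     ...
///     async.yield
///   }
///
/// produces a clone along the lines of
///
///   %token, %result = async.execute -> !async.value<!gpu.async.token> {
///     ...
///     async.yield %v : !gpu.async.token
///   }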
static async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
                                          ValueRange results) {
  // Add values to async.yield op.
  Operation *yieldOp = executeOp.getBody()->getTerminator();
  yieldOp->insertOperands(yieldOp->getNumOperands(), results);

  // Construct new result type list with additional types.
  SmallVector<Type, 2> resultTypes;
  resultTypes.reserve(executeOp.getNumResults() + results.size());
  transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
            [](Type type) {
              // Extract value type from !async.value.
              if (auto valueType = dyn_cast<async::ValueType>(type))
                return valueType.getValueType();
              assert(isa<async::TokenType>(type) && "expected token type");
              return type;
            });
  transform(results, std::back_inserter(resultTypes),
            [](Value value) { return value.getType(); });

  // Clone executeOp with the extra results.
  OpBuilder builder(executeOp);
  auto newOp = builder.create<async::ExecuteOp>(
      executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
      executeOp.getDependencies(), executeOp.getBodyOperands());
  IRMapping mapper;
  newOp.getRegion().getBlocks().clear();
  executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

  // Replace executeOp with cloned one.
  executeOp.getOperation()->replaceAllUsesWith(
      newOp.getResults().drop_back(results.size()));
  executeOp.erase();

  return newOp;
}

// Callback for `async.execute` ops which tries to push the contained
// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
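//
// For illustration only (simplified IR, names hypothetical), a trailing
// synchronous wait such as
//
//   %a0 = async.execute {
//     %t = gpu.launch_func async ...
//     gpu.wait [%t]            // synchronous wait at the end of the region
//     async.yield
//   }
//   async.await %a0
//
// is deferred to the users of the token, roughly:
//
//   %a0, %a1 = async.execute -> !async.value<!gpu.async.token> {
//     %t = gpu.launch_func async ...
//     async.yield %t : !gpu.async.token
//   }
//   async.await %a0
//   %t2 = async.await %a1 : !async.value<!gpu.async.token>
//   gpu.wait [%t2]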
struct GpuAsyncRegionPass::DeferWaitCallback {
  // If the `executeOp`'s token is used only in `async.execute` or
  // `async.await` ops, add the region's last `gpu.wait` op to the worklist if
  // it is synchronous and is the last op with side effects.
  void operator()(async::ExecuteOp executeOp) {
    if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
      return;
    // async.execute's region is currently restricted to one block.
    for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
      if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
        if (!waitOp.getAsyncToken())
          worklist.push_back(waitOp);
        return;
      }
      if (hasSideEffects(&op))
        return;
    }
  }

  // The destructor performs the actual rewrite work.
  ~DeferWaitCallback() {
    for (size_t i = 0; i < worklist.size(); ++i) {
      auto waitOp = worklist[i];
      auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();

      // Erase `gpu.wait` and return async dependencies from execute op
      // instead.
      SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
      waitOp.erase();
      executeOp = addExecuteResults(executeOp, dependencies);

      // Add the async dependency to each user of the `async.execute` token.
      auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
      SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
                                        executeOp.getToken().user_end());
      for (Operation *user : users)
        addAsyncDependencyAfter(asyncTokens, user);
    }
  }

private:
  // Returns whether all token users are either `async.execute` or
  // `async.await` ops. This is used as a requirement for pushing `gpu.wait`
  // ops from an `async.execute` body to its users. Specifically, we do not
  // allow terminator users, because it could mean that the `async.execute` is
  // inside control flow code.
  static bool areAllUsersExecuteOrAwait(Value token) {
    return !token.use_empty() &&
           llvm::all_of(token.getUsers(),
                        llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
  }

  // Adds the `asyncTokens` as dependencies as needed after `op`.
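  //
  // For illustration only (simplified IR): in the `async.await` case, the
  // !gpu.async.token values are awaited right after `op` and then either
  // threaded into the next async GPU op or synchronized with a new
  // `gpu.wait`, roughly:
  //
  //   async.await %token
  //   %t = async.await %asyncToken : !async.value<!gpu.async.token>
  //   gpu.wait [%t]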
  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
    OpBuilder builder(op->getContext());
    auto loc = op->getLoc();

    Block::iterator it;
    SmallVector<Value, 1> tokens;
    tokens.reserve(asyncTokens.size());
    TypeSwitch<Operation *>(op)
        .Case<async::AwaitOp>([&](auto awaitOp) {
          // Add async.await ops to wait for the !gpu.async.tokens.
          builder.setInsertionPointAfter(op);
          for (auto asyncToken : asyncTokens)
            tokens.push_back(
                builder.create<async::AwaitOp>(loc, asyncToken).getResult());
          // Set `it` after the inserted async.await ops.
          it = builder.getInsertionPoint();
        })
        .Case<async::ExecuteOp>([&](auto executeOp) {
          // Set `it` to the beginning of the region and add asyncTokens to
          // the async.execute operands.
          it = executeOp.getBody()->begin();
          executeOp.getBodyOperandsMutable().append(asyncTokens);
          SmallVector<Type, 1> tokenTypes(
              asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
          SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
                                             executeOp.getLoc());
          copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
               std::back_inserter(tokens));
        });

    // Advance `it` to the terminator or to an op with side effects.
    it = std::find_if(it, Block::iterator(), [](Operation &op) {
      return isTerminator(&op) || hasSideEffects(&op);
    });

    // If the op at `it` implements the AsyncOpInterface, add the tokens to
    // its list of async dependencies.
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
      for (auto token : tokens)
        asyncOp.addAsyncDependency(token);
      return;
    }

    // Otherwise, insert a gpu.wait before `it`.
    builder.setInsertionPoint(it->getBlock(), it);
    auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);

    // If the new waitOp is at the end of an async.execute region, add it to
    // the worklist. 'operator()(executeOp)' would do the same, but this is
    // faster.
    auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
    if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
        !it->getNextNode())
      worklist.push_back(waitOp);
  }

  SmallVector<gpu::WaitOp, 8> worklist;
};

// Callback for `async.execute` ops which repeats !gpu.async.token results
// so that each of them is only used once.
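//
// For illustration only (simplified IR): if %value is used twice,
//
//   %token, %value = async.execute -> !async.value<!gpu.async.token> {
//     ...
//     async.yield %t : !gpu.async.token
//   }
//
// the token is yielded once per use so that each result has a single use:
//
//   %token, %v0, %v1 = async.execute
//       -> (!async.value<!gpu.async.token>, !async.value<!gpu.async.token>) {
//     ...
//     async.yield %t, %t : !gpu.async.token, !gpu.async.token
//   }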
struct GpuAsyncRegionPass::SingleTokenUseCallback {
  void operator()(async::ExecuteOp executeOp) {
    // Extract !gpu.async.token results which have multiple uses.
    auto multiUseResults = llvm::make_filter_range(
        executeOp.getBodyResults(), [](OpResult result) {
          if (result.use_empty() || result.hasOneUse())
            return false;
          auto valueType = dyn_cast<async::ValueType>(result.getType());
          return valueType &&
                 isa<gpu::AsyncTokenType>(valueType.getValueType());
        });
    if (multiUseResults.empty())
      return;

    // Indices within the async.execute results (i.e. without the
    // async.token).
    SmallVector<int, 4> indices;
    transform(multiUseResults, std::back_inserter(indices),
              [](OpResult result) {
                return result.getResultNumber() - 1; // Index without token.
              });

    for (auto index : indices) {
      assert(!executeOp.getBodyResults()[index].getUses().empty());
      // Repeat async.yield token result, one for each use after the first.
      auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
      auto count = std::distance(uses.begin(), uses.end());
      auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
      SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
      executeOp = addExecuteResults(executeOp, operands);
      // Update 'uses' to refer to the new executeOp.
      uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
      auto results = executeOp.getBodyResults().take_back(count);
      for (auto pair : llvm::zip(uses, results))
        std::get<0>(pair).set(std::get<1>(pair));
    }
  }
};

// Replaces synchronous GPU ops in the op's region with asynchronous ones and
// inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
// execution semantics and that no GPU ops are asynchronous yet.
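//
// The work is split across the three callbacks above: ThreadTokenCallback
// threads !gpu.async.token values through GPU ops, DeferWaitCallback moves
// trailing synchronous gpu.wait ops out of async.execute regions, and
// SingleTokenUseCallback gives every yielded token result a single use.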
void GpuAsyncRegionPass::runOnOperation() {
  if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
    return signalPassFailure();

  // Collect gpu.wait ops that we can move out of async.execute regions.
  getOperation().getRegion().walk(DeferWaitCallback());
  // Make each !gpu.async.token returned from an async.execute op have a
  // single use.
  getOperation().getRegion().walk(SingleTokenUseCallback());
}