//===- VectorDistribute.cpp - patterns to do vector distribution ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::gpu;

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   gpu.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different than the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

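  /// Compute the offset, along the distributed dimension `index`, at which
  /// the current lane accesses its slice of the sequential value: the slice
  /// owned by `laneId` starts at `laneId * distributedSize`. (Illustrative
  /// note, not from the original comments: distributing vector<64xf32> over
  /// 32 lanes gives a distributed size of 2, so lane 5 starts at offset 10.)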
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  /// 1. scalars of type T transit through a memref<1xT>.
  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>.
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  /// 1. scalars of type T transit through a memref<1xT>.
  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>.
  ///
  /// When broadcastMode is true, the load is not distributed to account for
  /// the broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///   gpu.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vector atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
/// - the IR within the scf.if op can be thought of as executing sequentially
///   (from the point of view of threads along `laneId`).
/// - the IR outside of the scf.if op can be thought of as executing in
///   parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
/// 1. Create the scf.if op.
/// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///    within the scf.if to transit the values captured from above.
/// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///    consistent within the scf.if.
/// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
/// 5. Insert appropriate writes within scf.if and reads after the scf.if to
///    transit the values returned by the op.
/// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///    consistent after the scf.if.
/// 7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
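///
/// Illustrative sketch (assuming a 32-lane warp and made-up value names):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %v = "some_def"() : () -> (vector<32xf32>)
///   gpu.yield %v : vector<32xf32>
/// }
/// ```
/// roughly becomes
/// ```
/// %buffer = memref.alloc() : memref<32xf32>   // options.warpAllocationFn
/// %cond = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %cond {
///   %v = "some_def"() : () -> (vector<32xf32>)
///   vector.transfer_write %v, %buffer[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// // options.warpSyncronizationFn inserts a barrier here.
/// %off = affine.apply affine_map<()[s0] -> (s0)>()[%laneid]
/// %r = vector.transfer_read %buffer[%off], %pad
///     : memref<32xf32>, vector<1xf32>
/// ```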
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
/// distributed dimensions. The number of results should be equal to the number
/// of warp sizes which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1)
/// and a warp size of 16 distributes the second dimension (associated with
/// d1) and returns vector<16x2x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
/// will not be distributed (it should be less than the warp size).
///
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, do not distribute it here; bail out so that it
    // is instead cloned into a separate WarpExecuteOnLane0Op by the extract
    // path.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5 Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
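    // For example (illustrative values only): distributing vector<2x64xf32>
    // to vector<2x2xf32> over a 32-lane warp yields delinearizedIdSizes of
    // [1, 32], i.e. all lanes share dim 0 and are spread along dim 1.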
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
  /// execute op with the proper return type. The new write op is updated to
  /// write the result of the new warp execute op. The old `writeOp` is
  /// deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   gpu.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   gpu.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted read vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this into WarpExecuteOnLane0Op canonicalization.
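/// Illustrative sketch (with invented values): if only the first result of
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<1xf32>, vector<2xf32>) {
///   ...
///   gpu.yield %a, %b : vector<32xf32>, vector<64xf32>
/// }
/// ```
/// is used, the op is rewritten to yield (and return) only %a, and uses of
/// %r#0 are redirected to the new single-result warp op.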
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    // 1. recording the unique first position at which the value is yielded.
    // 2. recording, for the result, the first position at which the dedup'ed
    //    value is yielded.
    // 3. skipping from the new result types / new yielded values any result
    //    that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is yielded out of the region unchanged, it can be forwarded
// directly and does not need to go through the region.
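// Illustrative sketch (with an invented example): given a uniform value %cst
// defined above the warp op,
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
//     gpu.yield %cst : f32
//   }
// all uses of %r can simply be replaced by %cst.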
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

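/// Sink out a uniform vector.broadcast feeding into a warp op yield. The
/// following sketch is illustrative only (assuming a 32-lane warp and
/// made-up value names):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %b = vector.broadcast %src : f32 to vector<32xf32>
///   gpu.yield %b : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   gpu.yield %src : f32
/// }
/// %r = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```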
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // its own broadcast.
    // For that we simply check that the broadcast we want to build makes
    // sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is basically
/// a no-op for warp distribution; only the shape needs to be adjusted.
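/// Illustrative sketch (assuming a 32-lane warp):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %c = vector.shape_cast %v : vector<1x32xf32> to vector<32xf32>
///   gpu.yield %c : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) {
///   gpu.yield %v : vector<1x32xf32>
/// }
/// %r = vector.shape_cast %r0 : vector<1x1xf32> to vector<1xf32>
/// ```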
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   gpu.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], or in other words, the
      // mask sizes are always in the range [0, mask_vector_size[i]].
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
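/// Illustrative sketch (assuming a 32-lane warp), extracting a row of a
/// rank-2 source:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %e = vector.extract %v[2] : vector<32xf32> from vector<4x32xf32>
///   gpu.yield %e : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x1xf32>) {
///   gpu.yield %v : vector<4x32xf32>
/// }
/// %r = vector.extract %r0[2] : vector<1xf32> from vector<4x1xf32>
/// ```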
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extract with a scalar result.
/// Only supports 1-D and 0-D sources for now.
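/// Illustrative sketch (assuming a 32-lane warp), extracting element %i from
/// a 64-element source:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %e = vector.extract %v[%i] : f32 from vector<64xf32>
///   gpu.yield %e : f32
/// }
/// ```
/// Each lane keeps a vector<2xf32> slice; the owning lane (%i / 2) extracts
/// the element at position %i % 2, and the result is shuffled to all other
/// lanes via `warpShuffleFromIdxFn`.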
1161 | struct WarpOpExtractScalar : public WarpDistributionPattern { |
1162 | WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn, |
1163 | PatternBenefit b = 1) |
1164 | : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {} |
1165 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1166 | PatternRewriter &rewriter) const override { |
1167 | OpOperand *operand = |
1168 | getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>); |
1169 | if (!operand) |
1170 | return failure(); |
1171 | unsigned int operandNumber = operand->getOperandNumber(); |
1172 | auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>(); |
1173 | VectorType extractSrcType = extractOp.getSourceVectorType(); |
1174 | // Only supports 1-D or 0-D sources for now. |
1175 | if (extractSrcType.getRank() > 1) { |
1176 | return rewriter.notifyMatchFailure( |
1177 | extractOp, "only 0-D or 1-D source supported for now"); |
1178 | } |
1179 | // TODO: Supported shuffle types should be parameterizable, similar to |
1180 | // `WarpShuffleFromIdxFn`. |
1181 | if (!extractSrcType.getElementType().isF32() && |
1182 | !extractSrcType.getElementType().isInteger(32)) |
1183 | return rewriter.notifyMatchFailure( |
1184 | extractOp, "only f32/i32 element types are supported"); |
1185 | bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1; |
1186 | Type elType = extractSrcType.getElementType(); |
1187 | VectorType distributedVecType; |
1188 | if (!is0dOrVec1Extract) { |
1189 | assert(extractSrcType.getRank() == 1 && |
1190 | "expected that extract src rank is 0 or 1"); |
1191 | if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0) |
1192 | return failure(); |
1193 | int64_t elementsPerLane = |
1194 | extractSrcType.getShape()[0] / warpOp.getWarpSize(); |
1195 | distributedVecType = VectorType::get({elementsPerLane}, elType); |
1196 | } else { |
1197 | distributedVecType = extractSrcType; |
1198 | } |
1199 | // Yield source vector and position (if present) from warp op. |
1200 | SmallVector<Value> additionalResults{extractOp.getVector()}; |
1201 | SmallVector<Type> additionalResultTypes{distributedVecType}; |
1202 | additionalResults.append( |
1203 | RHS: SmallVector<Value>(extractOp.getDynamicPosition())); |
1204 | additionalResultTypes.append( |
1205 | RHS: SmallVector<Type>(extractOp.getDynamicPosition().getTypes())); |
1206 | |
1207 | Location loc = extractOp.getLoc(); |
1208 | SmallVector<size_t> newRetIndices; |
1209 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1210 | rewriter, warpOp, additionalResults, additionalResultTypes, |
1211 | newRetIndices); |
1212 | rewriter.setInsertionPointAfter(newWarpOp); |
1213 | Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
1214 | |
1215 | // 0d extract: The new warp op broadcasts the source vector to all lanes. |
1216 | // All lanes extract the scalar. |
1217 | if (is0dOrVec1Extract) { |
1218 | Value newExtract; |
1219 | SmallVector<int64_t> indices(extractSrcType.getRank(), 0); |
1220 | newExtract = |
1221 | rewriter.create<vector::ExtractOp>(loc, distributedVec, indices); |
1222 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
1223 | newExtract); |
1224 | return success(); |
1225 | } |
1226 | |
1227 | int64_t staticPos = extractOp.getStaticPosition()[0]; |
1228 | OpFoldResult pos = ShapedType::isDynamic(staticPos) |
1229 | ? (newWarpOp->getResult(newRetIndices[1])) |
1230 | : OpFoldResult(rewriter.getIndexAttr(staticPos)); |
1231 | // 1d extract: Distribute the source vector. One lane extracts and shuffles |
1232 | // the value to all other lanes. |
1233 | int64_t elementsPerLane = distributedVecType.getShape()[0]; |
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Extract at position: pos % elementsPerLane
    Value newPos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
1244 | Value extracted = |
1245 | rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos); |
1246 | |
1247 | // Shuffle the extracted value to all lanes. |
1248 | Value shuffled = warpShuffleFromIdxFn( |
1249 | loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize()); |
1250 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled); |
1251 | return success(); |
1252 | } |
1253 | |
1254 | private: |
1255 | WarpShuffleFromIdxFn warpShuffleFromIdxFn; |
1256 | }; |
1257 | |
1258 | /// Pattern to convert vector.extractelement to vector.extract. |
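/// For example (illustrative only; %v and %idx stand for arbitrary values):
/// ```
/// %e = vector.extractelement %v[%idx : index] : vector<16xf32>
/// ```
/// is rewritten to `%e = vector.extract %v[%idx] : f32 from vector<16xf32>`,
/// which the extract patterns above can then distribute.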
1259 | struct WarpOpExtractElement : public WarpDistributionPattern { |
1260 | using Base::Base; |
1261 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1262 | PatternRewriter &rewriter) const override { |
1263 | OpOperand *operand = |
1264 | getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>); |
1265 | if (!operand) |
1266 | return failure(); |
1267 | auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>(); |
1268 | SmallVector<OpFoldResult> indices; |
1269 | if (auto pos = extractOp.getPosition()) { |
      indices.push_back(pos);
1271 | } |
1272 | rewriter.setInsertionPoint(extractOp); |
1273 | rewriter.replaceOpWithNewOp<vector::ExtractOp>( |
1274 | extractOp, extractOp.getVector(), indices); |
1275 | return success(); |
1276 | } |
1277 | }; |
1278 | |
1279 | /// Pattern to move out vector.insert with a scalar input. |
1280 | /// Only supports 1-D and 0-D destinations for now. |
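/// A sketch of the intended rewrite (illustrative only; warp size 32 assumed):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
///   ...
///   %1 = vector.insert %f, %0 [16] : f32 into vector<64xf32>
///   gpu.yield %1 : vector<64xf32>
/// }
/// ```
/// The destination is distributed as vector<2xf32> per lane; only lane 8
/// (16 / 2) performs the insert at position 0 (16 % 2) inside an scf.if,
/// while all other lanes forward their slice unchanged.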
1281 | struct WarpOpInsertScalar : public WarpDistributionPattern { |
1282 | using Base::Base; |
1283 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1284 | PatternRewriter &rewriter) const override { |
1285 | OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>); |
1286 | if (!operand) |
1287 | return failure(); |
1288 | unsigned int operandNumber = operand->getOperandNumber(); |
1289 | auto insertOp = operand->get().getDefiningOp<vector::InsertOp>(); |
1290 | VectorType vecType = insertOp.getDestVectorType(); |
1291 | VectorType distrType = |
1292 | cast<VectorType>(warpOp.getResult(operandNumber).getType()); |
1293 | |
1294 | // Only supports 1-D or 0-D destinations for now. |
1295 | if (vecType.getRank() > 1) { |
1296 | return rewriter.notifyMatchFailure( |
          insertOp, "only 0-D or 1-D destination supported for now");
1298 | } |
1299 | |
1300 | // Yield destination vector, source scalar and position from warp op. |
1301 | SmallVector<Value> additionalResults{insertOp.getDest(), |
1302 | insertOp.getValueToStore()}; |
1303 | SmallVector<Type> additionalResultTypes{ |
1304 | distrType, insertOp.getValueToStore().getType()}; |
    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1308 | |
1309 | Location loc = insertOp.getLoc(); |
1310 | SmallVector<size_t> newRetIndices; |
1311 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1312 | rewriter, warpOp, additionalResults, additionalResultTypes, |
1313 | newRetIndices); |
1314 | rewriter.setInsertionPointAfter(newWarpOp); |
1315 | Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
1316 | Value newSource = newWarpOp->getResult(newRetIndices[1]); |
1317 | rewriter.setInsertionPointAfter(newWarpOp); |
1318 | |
1319 | OpFoldResult pos; |
1320 | if (vecType.getRank() != 0) { |
1321 | int64_t staticPos = insertOp.getStaticPosition()[0]; |
1322 | pos = ShapedType::isDynamic(staticPos) |
1323 | ? (newWarpOp->getResult(newRetIndices[2])) |
1324 | : OpFoldResult(rewriter.getIndexAttr(staticPos)); |
1325 | } |
1326 | |
1327 | // This condition is always true for 0-d vectors. |
1328 | if (vecType == distrType) { |
1329 | Value newInsert; |
1330 | SmallVector<OpFoldResult> indices; |
1331 | if (pos) { |
        indices.push_back(pos);
1333 | } |
1334 | newInsert = rewriter.create<vector::InsertOp>(loc, newSource, |
1335 | distributedVec, indices); |
1336 | // Broadcast: Simply move the vector.insert op out. |
1337 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
1338 | newInsert); |
1339 | return success(); |
1340 | } |
1341 | |
1342 | // This is a distribution. Only one lane should insert. |
1343 | int64_t elementsPerLane = distrType.getShape()[0]; |
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting lane: pos / elementsPerLane
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Insert position: pos % elementsPerLane
    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
1351 | Value isInsertingLane = rewriter.create<arith::CmpIOp>( |
1352 | loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); |
1353 | Value newResult = |
1354 | rewriter |
1355 | .create<scf::IfOp>( |
1356 | loc, isInsertingLane, |
1357 | /*thenBuilder=*/ |
1358 | [&](OpBuilder &builder, Location loc) { |
1359 | Value newInsert = builder.create<vector::InsertOp>( |
1360 | loc, newSource, distributedVec, newPos); |
1361 | builder.create<scf::YieldOp>(loc, newInsert); |
1362 | }, |
1363 | /*elseBuilder=*/ |
1364 | [&](OpBuilder &builder, Location loc) { |
1365 | builder.create<scf::YieldOp>(loc, distributedVec); |
1366 | }) |
1367 | .getResult(0); |
1368 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); |
1369 | return success(); |
1370 | } |
1371 | }; |
1372 | |
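/// Pattern to move out vector.insert with a vector source and an n-D (n >= 2)
/// destination. An illustrative example (warp size 32 assumed):
/// ```
/// vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
/// ```
/// If the dimension of size 96 is distributed, every lane inserts a
/// vector<3xf32> slice. If the dimension of size 128 is distributed instead,
/// only the owning lane inserts the full vector<96xf32>, guarded by an scf.if.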
1373 | struct WarpOpInsert : public WarpDistributionPattern { |
1374 | using Base::Base; |
1375 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1376 | PatternRewriter &rewriter) const override { |
1377 | OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>); |
1378 | if (!operand) |
1379 | return failure(); |
1380 | unsigned int operandNumber = operand->getOperandNumber(); |
1381 | auto insertOp = operand->get().getDefiningOp<vector::InsertOp>(); |
1382 | Location loc = insertOp.getLoc(); |
1383 | |
1384 | // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern. |
1385 | if (insertOp.getDestVectorType().getRank() <= 1) { |
1386 | return failure(); |
1387 | } |
1388 | |
1389 | // All following cases are 2d or higher dimensional source vectors. |
    // All following cases have a 2-D or higher-dimensional destination vector.
1391 | if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { |
1392 | // There is no distribution, this is a broadcast. Simply move the insert |
1393 | // out of the warp op. |
1394 | SmallVector<size_t> newRetIndices; |
1395 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1396 | rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()}, |
1397 | {insertOp.getValueToStoreType(), insertOp.getDestVectorType()}, |
1398 | newRetIndices); |
1399 | rewriter.setInsertionPointAfter(newWarpOp); |
1400 | Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); |
1401 | Value distributedDest = newWarpOp->getResult(newRetIndices[1]); |
1402 | Value newResult = rewriter.create<vector::InsertOp>( |
1403 | loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); |
1404 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
1405 | newResult); |
1406 | return success(); |
1407 | } |
1408 | |
1409 | // Find the distributed dimension. There should be exactly one. |
1410 | auto distrDestType = |
1411 | cast<VectorType>(warpOp.getResult(operandNumber).getType()); |
1412 | auto yieldedType = cast<VectorType>(operand->get().getType()); |
1413 | int64_t distrDestDim = -1; |
1414 | for (int64_t i = 0; i < yieldedType.getRank(); ++i) { |
1415 | if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) { |
1416 | // Keep this assert here in case WarpExecuteOnLane0Op gets extended to |
1417 | // support distributing multiple dimensions in the future. |
1418 | assert(distrDestDim == -1 && "found multiple distributed dims"); |
1419 | distrDestDim = i; |
1420 | } |
1421 | } |
1422 | assert(distrDestDim != -1 && "could not find distributed dimension"); |
1423 | |
1424 | // Compute the distributed source vector type. |
1425 | VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType()); |
1426 | SmallVector<int64_t> distrSrcShape(srcVecType.getShape()); |
1427 | // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32> |
1428 | // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will |
1429 | // insert a smaller vector<3xf32>. |
1430 | // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that |
1431 | // case, one lane will insert the source vector<96xf32>. The other |
1432 | // lanes will not do anything. |
1433 | int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices(); |
1434 | if (distrSrcDim >= 0) |
1435 | distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim); |
1436 | auto distrSrcType = |
1437 | VectorType::get(distrSrcShape, distrDestType.getElementType()); |
1438 | |
1439 | // Yield source and dest vectors from warp op. |
1440 | SmallVector<size_t> newRetIndices; |
1441 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1442 | rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()}, |
1443 | {distrSrcType, distrDestType}, newRetIndices); |
1444 | rewriter.setInsertionPointAfter(newWarpOp); |
1445 | Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); |
1446 | Value distributedDest = newWarpOp->getResult(newRetIndices[1]); |
1447 | |
1448 | // Insert into the distributed vector. |
1449 | Value newResult; |
1450 | if (distrSrcDim >= 0) { |
1451 | // Every lane inserts a small piece. |
1452 | newResult = rewriter.create<vector::InsertOp>( |
1453 | loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); |
1454 | } else { |
1455 | // One lane inserts the entire source vector. |
1456 | int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); |
1457 | SmallVector<OpFoldResult> pos = insertOp.getMixedPosition(); |
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
1462 | Value isInsertingLane = rewriter.create<arith::CmpIOp>( |
1463 | loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); |
1464 | // Insert position: pos % elementsPerLane |
1465 | newPos[distrDestDim] %= elementsPerLane; |
1466 | auto insertingBuilder = [&](OpBuilder &builder, Location loc) { |
1467 | Value newInsert = builder.create<vector::InsertOp>( |
1468 | loc, distributedSrc, distributedDest, newPos); |
1469 | builder.create<scf::YieldOp>(loc, newInsert); |
1470 | }; |
1471 | auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) { |
1472 | builder.create<scf::YieldOp>(loc, distributedDest); |
1473 | }; |
1474 | newResult = rewriter |
1475 | .create<scf::IfOp>(loc, isInsertingLane, |
1476 | /*thenBuilder=*/insertingBuilder, |
1477 | /*elseBuilder=*/nonInsertingBuilder) |
1478 | .getResult(0); |
1479 | } |
1480 | |
1481 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); |
1482 | return success(); |
1483 | } |
1484 | }; |
1485 | |
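/// Pattern to convert vector.insertelement to vector.insert.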
1486 | struct WarpOpInsertElement : public WarpDistributionPattern { |
1487 | using Base::Base; |
1488 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1489 | PatternRewriter &rewriter) const override { |
1490 | OpOperand *operand = |
1491 | getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>); |
1492 | if (!operand) |
1493 | return failure(); |
1494 | auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>(); |
1495 | SmallVector<OpFoldResult> indices; |
1496 | if (auto pos = insertOp.getPosition()) { |
      indices.push_back(pos);
1498 | } |
1499 | rewriter.setInsertionPoint(insertOp); |
1500 | rewriter.replaceOpWithNewOp<vector::InsertOp>( |
1501 | insertOp, insertOp.getSource(), insertOp.getDest(), indices); |
1502 | return success(); |
1503 | } |
1504 | }; |
1505 | |
1506 | /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if |
1507 | /// the scf.ForOp is the last operation in the region so that it doesn't |
1508 | /// change the order of execution. This creates a new scf.for region after the |
1509 | /// WarpExecuteOnLane0Op. The new scf.for region will contain a new |
1510 | /// WarpExecuteOnLane0Op region. Example: |
1511 | /// ``` |
1512 | /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) { |
1513 | /// ... |
1514 | /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v) |
1515 | /// -> (vector<128xf32>) { |
1516 | /// ... |
1517 | /// scf.yield %r : vector<128xf32> |
1518 | /// } |
1519 | /// gpu.yield %v1 : vector<128xf32> |
1520 | /// } |
1521 | /// ``` |
/// To:
/// ```
/// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1524 | /// ... |
1525 | /// gpu.yield %v : vector<128xf32> |
1526 | /// } |
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1528 | /// -> (vector<4xf32>) { |
1529 | /// %iw = gpu.warp_execute_on_lane_0(%laneid) |
1530 | /// args(%varg : vector<4xf32>) -> (vector<4xf32>) { |
1531 | /// ^bb0(%arg: vector<128xf32>): |
1532 | /// ... |
1533 | /// gpu.yield %ir : vector<128xf32> |
1534 | /// } |
1535 | /// scf.yield %iw : vector<4xf32> |
1536 | /// } |
1537 | /// ``` |
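/// Values defined in the warp op region and used inside the loop body, as
/// well as loop results that are not yielded by the warp op, are treated as
/// escaping values: they are returned by the rewritten warp op (in
/// distributed form where applicable) and threaded into the inner warp op as
/// extra arguments.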
1538 | struct WarpOpScfForOp : public WarpDistributionPattern { |
1539 | |
1540 | WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) |
1541 | : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} |
1542 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1543 | PatternRewriter &rewriter) const override { |
1544 | auto yield = cast<gpu::YieldOp>( |
1545 | warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
1546 | // Only pick up forOp if it is the last op in the region. |
1547 | Operation *lastNode = yield->getPrevNode(); |
1548 | auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode); |
1549 | if (!forOp) |
1550 | return failure(); |
1551 | // Collect Values that come from the warp op but are outside the forOp. |
    // Those values need to be returned by the original warpOp and passed to
1553 | // the new op. |
1554 | llvm::SmallSetVector<Value, 32> escapingValues; |
1555 | SmallVector<Type> inputTypes; |
1556 | SmallVector<Type> distTypes; |
1557 | auto collectEscapingValues = [&](Value value) { |
      if (!escapingValues.insert(value))
1559 | return; |
1560 | Type distType = value.getType(); |
1561 | if (auto vecType = dyn_cast<VectorType>(distType)) { |
1562 | AffineMap map = distributionMapFn(value); |
1563 | distType = getDistributedType(vecType, map, warpOp.getWarpSize()); |
1564 | } |
      inputTypes.push_back(value.getType());
      distTypes.push_back(distType);
1567 | }; |
1568 | |
1569 | mlir::visitUsedValuesDefinedAbove( |
1570 | forOp.getBodyRegion(), [&](OpOperand *operand) { |
1571 | Operation *parent = operand->get().getParentRegion()->getParentOp(); |
1572 | if (warpOp->isAncestor(parent)) { |
1573 | collectEscapingValues(operand->get()); |
1574 | } |
1575 | }); |
1576 | |
1577 | // Any forOp result that is not already yielded by the warpOp |
1578 | // region is also considered escaping and must be returned by the |
1579 | // original warpOp. |
1580 | for (OpResult forResult : forOp.getResults()) { |
1581 | // Check if this forResult is already yielded by the yield op. |
1582 | if (llvm::is_contained(yield->getOperands(), forResult)) |
1583 | continue; |
1584 | collectEscapingValues(forResult); |
1585 | } |
1586 | |
    if (llvm::is_contained(distTypes, Type{}))
1588 | return failure(); |
1589 | |
1590 | SmallVector<size_t> newRetIndices; |
1591 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1592 | rewriter, warpOp, escapingValues.getArrayRef(), distTypes, |
1593 | newRetIndices); |
1594 | yield = cast<gpu::YieldOp>( |
1595 | newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
1596 | |
1597 | SmallVector<Value> newOperands; |
1598 | SmallVector<unsigned> resultIdx; |
1599 | // Collect all the outputs coming from the forOp. |
1600 | for (OpOperand &yieldOperand : yield->getOpOperands()) { |
1601 | if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) |
1602 | continue; |
1603 | auto forResult = cast<OpResult>(yieldOperand.get()); |
1604 | newOperands.push_back( |
1605 | newWarpOp.getResult(yieldOperand.getOperandNumber())); |
1606 | yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); |
1607 | resultIdx.push_back(yieldOperand.getOperandNumber()); |
1608 | } |
1609 | |
1610 | OpBuilder::InsertionGuard g(rewriter); |
1611 | rewriter.setInsertionPointAfter(newWarpOp); |
1612 | |
1613 | // Create a new for op outside the region with a WarpExecuteOnLane0Op |
1614 | // region inside. |
1615 | auto newForOp = rewriter.create<scf::ForOp>( |
1616 | forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
1617 | forOp.getStep(), newOperands); |
1618 | rewriter.setInsertionPointToStart(newForOp.getBody()); |
1619 | |
1620 | SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(), |
1621 | newForOp.getRegionIterArgs().end()); |
1622 | SmallVector<Type> warpInputType(forOp.getResultTypes().begin(), |
1623 | forOp.getResultTypes().end()); |
1624 | llvm::SmallDenseMap<Value, int64_t> argIndexMapping; |
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1626 | auto newWarpResult = newWarpOp.getResult(retIdx); |
1627 | // Unused forOp results yielded by the warpOp region are already included |
1628 | // in the new ForOp. |
1629 | if (llvm::is_contained(newOperands, newWarpResult)) |
1630 | continue; |
      warpInput.push_back(newWarpResult);
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
1634 | } |
1635 | auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>( |
1636 | newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), |
1637 | newWarpOp.getWarpSize(), warpInput, warpInputType); |
1638 | |
1639 | SmallVector<Value> argMapping; |
    argMapping.push_back(newForOp.getInductionVar());
1641 | for (Value args : innerWarp.getBody()->getArguments()) { |
1642 | argMapping.push_back(args); |
1643 | } |
1644 | argMapping.resize(forOp.getBody()->getNumArguments()); |
1645 | SmallVector<Value> yieldOperands; |
1646 | for (Value operand : forOp.getBody()->getTerminator()->getOperands()) |
1647 | yieldOperands.push_back(operand); |
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1650 | rewriter.setInsertionPointToEnd(innerWarp.getBody()); |
1651 | rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands); |
1652 | rewriter.setInsertionPointAfter(innerWarp); |
1653 | if (!innerWarp.getResults().empty()) |
1654 | rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults()); |
    rewriter.eraseOp(forOp);
1656 | // Replace the warpOp result coming from the original ForOp. |
    for (const auto &res : llvm::enumerate(resultIdx)) {
1658 | rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), |
1659 | newForOp.getResult(res.index())); |
1660 | newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); |
1661 | } |
1662 | newForOp.walk([&](Operation *op) { |
1663 | for (OpOperand &operand : op->getOpOperands()) { |
        auto it = argIndexMapping.find(operand.get());
1665 | if (it == argIndexMapping.end()) |
1666 | continue; |
1667 | operand.set(innerWarp.getBodyRegion().getArgument(it->second)); |
1668 | } |
1669 | }); |
1670 | |
1671 | // Finally, hoist out any now uniform code from the inner warp op. |
1672 | mlir::vector::moveScalarUniformCode(innerWarp); |
1673 | return success(); |
1674 | } |
1675 | |
1676 | private: |
1677 | DistributionMapFn distributionMapFn; |
1678 | }; |
1679 | |
1680 | /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op. |
1681 | /// The vector is reduced in parallel. Currently limited to vector size |
1682 | /// matching the warpOp size. E.g.: |
1683 | /// ``` |
1684 | /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) { |
1685 | /// %0 = "some_def"() : () -> (vector<32xf32>) |
1686 | /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32 |
1687 | /// gpu.yield %1 : f32 |
1688 | /// } |
1689 | /// ``` |
1690 | /// is lowered to: |
1691 | /// ``` |
1692 | /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { |
1693 | /// %1 = "some_def"() : () -> (vector<32xf32>) |
1694 | /// gpu.yield %1 : vector<32xf32> |
1695 | /// } |
1696 | /// %a = vector.extract %0[0] : f32 from vector<1xf32> |
1697 | /// %r = ("warp.reduction %a") |
1698 | /// ``` |
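/// If the reduction op carries an accumulator, the accumulator is also
/// yielded from the warp op and combined with the warp-reduced value
/// afterwards.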
1699 | struct WarpOpReduction : public WarpDistributionPattern { |
1700 | WarpOpReduction(MLIRContext *context, |
1701 | DistributedReductionFn distributedReductionFn, |
1702 | PatternBenefit benefit = 1) |
1703 | : WarpDistributionPattern(context, benefit), |
1704 | distributedReductionFn(std::move(distributedReductionFn)) {} |
1705 | |
1706 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1707 | PatternRewriter &rewriter) const override { |
1708 | OpOperand *yieldOperand = |
1709 | getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>); |
1710 | if (!yieldOperand) |
1711 | return failure(); |
1712 | |
1713 | auto reductionOp = |
1714 | cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp()); |
1715 | auto vectorType = cast<VectorType>(reductionOp.getVector().getType()); |
1716 | // Only rank 1 vectors supported. |
1717 | if (vectorType.getRank() != 1) |
1718 | return rewriter.notifyMatchFailure( |
1719 | warpOp, "Only rank 1 reductions can be distributed."); |
1720 | // Only warp_size-sized vectors supported. |
    // Only vectors whose size is a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of warp size.");
1725 | return rewriter.notifyMatchFailure( |
1726 | warpOp, "Reduction distribution currently only supports floats and " |
1727 | "integer types."); |
1728 | |
1729 | int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); |
1730 | // Return vector that will be reduced from the WarpExecuteOnLane0Op. |
1731 | unsigned operandIndex = yieldOperand->getOperandNumber(); |
1732 | SmallVector<Value> yieldValues = {reductionOp.getVector()}; |
1733 | SmallVector<Type> retTypes = { |
1734 | VectorType::get({numElements}, reductionOp.getType())}; |
1735 | if (reductionOp.getAcc()) { |
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
1738 | } |
1739 | SmallVector<size_t> newRetIndices; |
1740 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1741 | rewriter, warpOp, yieldValues, retTypes, newRetIndices); |
1742 | rewriter.setInsertionPointAfter(newWarpOp); |
1743 | |
1744 | // Obtain data to reduce for a single lane. |
1745 | Value laneValVec = newWarpOp.getResult(newRetIndices[0]); |
1746 | // Distribute and reduce across threads. |
1747 | Value fullReduce = |
1748 | distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec, |
1749 | reductionOp.getKind(), newWarpOp.getWarpSize()); |
1750 | if (reductionOp.getAcc()) { |
1751 | fullReduce = vector::makeArithReduction( |
1752 | rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce, |
1753 | newWarpOp.getResult(newRetIndices[1])); |
1754 | } |
1755 | rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce); |
1756 | return success(); |
1757 | } |
1758 | |
1759 | private: |
1760 | DistributedReductionFn distributedReductionFn; |
1761 | }; |
1762 | |
1763 | } // namespace |
1764 | |
1765 | void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( |
1766 | RewritePatternSet &patterns, |
1767 | const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) { |
  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1769 | } |
1770 | |
1771 | void mlir::vector::populateDistributeTransferWriteOpPatterns( |
1772 | RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, |
1773 | unsigned maxNumElementsToExtract, PatternBenefit benefit) { |
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
1776 | } |
1777 | |
1778 | void mlir::vector::populatePropagateWarpVectorDistributionPatterns( |
1779 | RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, |
1780 | const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit, |
1781 | PatternBenefit readBenefit) { |
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
      patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                    benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
1792 | } |
1793 | |
1794 | void mlir::vector::populateDistributeReduction( |
1795 | RewritePatternSet &patterns, |
1796 | const DistributedReductionFn &distributedReductionFn, |
1797 | PatternBenefit benefit) { |
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
1800 | } |
1801 | |
1802 | /// Helper to know if an op can be hoisted out of the region. |
1803 | static bool canBeHoisted(Operation *op, |
1804 | function_ref<bool(Value)> definedOutside) { |
  return llvm::all_of(op->getOperands(), definedOutside) &&
1806 | isMemoryEffectFree(op) && op->getNumRegions() == 0; |
1807 | } |
1808 | |
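/// Hoist ops without vector results, without nested regions and without
/// memory effects out of the warp op region, provided all of their operands
/// are (or become) defined outside the region. Such ops are uniform across
/// lanes and can safely execute before the warp op.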
1809 | void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { |
1810 | Block *body = warpOp.getBody(); |
1811 | |
1812 | // Keep track of the ops we want to hoist. |
1813 | llvm::SmallSetVector<Operation *, 8> opsToMove; |
1814 | |
1815 | // Helper to check if a value is or will be defined outside of the region. |
1816 | auto isDefinedOutsideOfBody = [&](Value value) { |
1817 | auto *definingOp = value.getDefiningOp(); |
1818 | return (definingOp && opsToMove.count(definingOp)) || |
1819 | warpOp.isDefinedOutsideOfRegion(value); |
1820 | }; |
1821 | |
1822 | // Do not use walk here, as we do not want to go into nested regions and hoist |
1823 | // operations from there. |
1824 | for (auto &op : body->without_terminator()) { |
1825 | bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) { |
1826 | return isa<VectorType>(result.getType()); |
1827 | }); |
1828 | if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody)) |
1829 | opsToMove.insert(&op); |
1830 | } |
1831 | |
1832 | // Move all the ops marked as uniform outside of the region. |
1833 | for (Operation *op : opsToMove) |
1834 | op->moveBefore(warpOp); |
1835 | } |
1836 |