//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::gpu;

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   gpu.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different than the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

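  /// Compute the lane-dependent offset along dimension `index`, i.e.
  /// `laneId * distributedVectorType.getDimSize(index)`: each lane owns a
  /// contiguous slice of that size along the distributed dimension.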
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  ///
  /// When the requested `type` is not distributed (i.e. it is a scalar or
  /// matches the sequential type), the load is not offset by the lane ID, to
  /// account for the broadcast semantics of the `gpu.warp_execute_on_lane_0`
  /// op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///   gpu.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcast to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vectors atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
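///
/// A rough sketch of the resulting IR, assuming `options.warpAllocationFn`
/// allocates a memref of the sequential type and
/// `options.warpSyncronizationFn` inserts `gpu.barrier` (value names are
/// illustrative only):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %v = "some_def"() : () -> (vector<32xf32>)
///   gpu.yield %v : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %c0 = arith.constant 0 : index
/// %cond = arith.cmpi eq, %laneid, %c0 : index
/// %buffer = memref.alloc() : memref<32xf32>
/// scf.if %cond {
///   %v = "some_def"() : () -> (vector<32xf32>)
///   vector.transfer_write %v, %buffer[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// gpu.barrier
/// %r = vector.transfer_read %buffer[%laneid], %pad
///     : memref<32xf32>, vector<1xf32>
/// ```
/// Each lane reads its slice at offset `laneid * 1`; the padding value %pad
/// is omitted above for brevity.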
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcast to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
/// distributed dimensions. The number of results should be equal to the number
/// of warp sizes which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1) and
/// a warp size of 16 distributes the second dimension (associated with d1) and
/// returns vector<16x2x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes with more elements than
/// `maxNumElementsToExtract` will not be extracted into a separate warp op
/// (the limit should be less than the warp size).
///
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, there is nothing to distribute; bail out (it may
    // still be handled by `tryExtractOp`).
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5. Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
  /// execute op with the proper return type. The new write op is updated to
  /// write the result of the new warp execute op. The old `writeOp` is
  /// deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   gpu.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   gpu.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted read vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
/// Remove any result that has no use along with the matching yieldOp operand.
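/// For example (sketch), if %r#1 is unused and %r#2 yields the same value as
/// %r#0:
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   gpu.yield %0, %1, %0 : vector<32xf32>, vector<32xf32>, vector<32xf32>
/// }
/// ```
/// the op is rebuilt as
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   gpu.yield %0 : vector<32xf32>
/// }
/// ```
/// and both %r#0 and %r#2 are replaced by %r.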
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

/// If an operand is directly yielded out of the region, we can forward it
/// directly: it does not need to go through the region.
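/// For example (sketch), with %uniform defined above the warp op:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   gpu.yield %uniform : vector<4xf32>
/// }
/// ```
/// all uses of %0 are replaced by %uniform directly (the result and operand
/// types must match; a warp op block argument yielded back is forwarded to
/// the corresponding warp op operand in the same way).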
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

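/// Sink out a vector.broadcast of a value that is uniform across lanes and
/// feeds into a warp op yield. For example (sketch):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %1 = vector.broadcast %src : f32 to vector<32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (f32) {
///   ...
///   gpu.yield %src : f32
/// }
/// %0 = vector.broadcast %r : f32 to vector<1xf32>
/// ```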
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct its
    // own broadcast.
    // For that we simply check that the broadcast we want to build makes
    // sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is
/// essentially a no-op for warp distribution; only the shape has to be
/// adjusted.
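/// For example (sketch), with a warp size of 32:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %1 = vector.shape_cast %v : vector<1x32xf32> to vector<32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x1xf32>) {
///   ...
///   gpu.yield %v : vector<1x32xf32>
/// }
/// %0 = vector.shape_cast %r : vector<1x1xf32> to vector<1xf32>
/// ```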
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   gpu.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], or in other words,
      // the mask sizes are always in the range [0, mask_vector_size[i]].
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Pattern to move out a vector.extract with a >= 2-D source vector feeding
/// into a warp op yield. 0-D and 1-D sources are handled by
/// WarpOpExtractScalar instead.
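/// For example (sketch, warp size 32), extracting a row of a 2-D vector whose
/// trailing dimension is distributed:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %1 = vector.extract %v[9] : vector<32xf32> from vector<16x32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
///   ...
///   gpu.yield %v : vector<16x32xf32>
/// }
/// %0 = vector.extract %r[9] : vector<1xf32> from vector<16x1xf32>
/// ```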
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extract with a scalar result.
/// Only supports 1-D and 0-D sources for now.
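/// For example (sketch, warp size 32), extracting at a dynamic position:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (f32) {
///   ...
///   %1 = vector.extract %v[%idx] : f32 from vector<64xf32>
///   gpu.yield %1 : f32
/// }
/// ```
/// is rewritten so that the warp op yields a vector<2xf32> slice per lane
/// (plus %idx), the owning lane extracts the scalar, and
/// `warpShuffleFromIdxFn` broadcasts it to all other lanes.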
struct WarpOpExtractScalar : public WarpDistributionPattern {
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only supports 1-D or 0-D sources for now.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    // TODO: Supported shuffle types should be parameterizable, similar to
    // `WarpShuffleFromIdxFn`.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield source vector and position (if present) from warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0d extract: The new warp op broadcasts the source vector to all lanes.
    // All lanes extract the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      newExtract =
          rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
    // 1d extract: Distribute the source vector. One lane extracts and shuffles
    // the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Extract at position: pos % elementsPerLane
    Value newPos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
    Value extracted =
        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

/// Pattern to convert vector.extractelement to vector.extract.
struct WarpOpExtractElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!operand)
      return failure();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = extractOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(extractOp);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getVector(), indices);
    return success();
  }
};

/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
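/// For example (sketch, warp size 32):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %1 = vector.insert %f, %v[%idx] : f32 into vector<32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// is rewritten so that each lane keeps its vector<1xf32> slice and only the
/// lane owning position %idx performs the insert, guarded by an scf.if.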
struct WarpOpInsertScalar : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());

    // Only supports 1-D or 0-D destinations for now.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D destinations supported for now");
    }

    // Yield destination vector, source scalar and position from warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getValueToStore()};
    SmallVector<Type> additionalResultTypes{
        distrType, insertOp.getValueToStore().getType()};
    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);
    rewriter.setInsertionPointAfter(newWarpOp);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-d vectors.
    if (vecType == distrType) {
      Value newInsert;
      SmallVector<OpFoldResult> indices;
      if (pos) {
        indices.push_back(pos);
      }
      newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
                                                    distributedVec, indices);
      // Broadcast: Simply move the vector.insert op out.
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution. Only one lane should insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of the inserting lane: pos / elementsPerLane
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Insert position: pos % elementsPerLane
    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
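    // For example (illustrative): with a warp of 32 lanes and a vector<64xf32>
    // destination, elementsPerLane is 2, so an insert at position 38 is done
    // by lane 38 / 2 = 19, at local position 38 % 2 = 0 within that lane's
    // vector<2xf32> slice.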
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertOp>(
                      loc, newSource, distributedVec, newPos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

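/// Pattern to move a vector.insert with a 2-D or higher dimensional
/// destination out of the warp op. If the distributed dimension is covered by
/// the inserted source vector, every lane inserts its own slice; otherwise a
/// single lane, selected from the insertion position, performs the insert
/// inside an scf.if. See the worked cases in the body below.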
struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // For 1-d or 0-d destination cases, we rely on the WarpOpInsertScalar
    // pattern.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    // All following cases are 2-d or higher dimensional destination vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of the inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
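      // For example (illustrative): distributing dim 0 of vector<128x96xf32>
      // over 32 lanes gives elementsPerLane = 4; an insert at [50] is then
      // performed only by lane 50 / 4 = 12, at local position 50 % 4 = 2.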
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

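/// Pattern to convert vector.insertelement to vector.insert; the resulting
/// vector.insert with a scalar source is then handled by WarpOpInsertScalar.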
struct WarpOpInsertElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
    if (!operand)
      return failure();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = insertOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(insertOp);
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   gpu.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   gpu.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///     %iw = gpu.warp_execute_on_lane_0(%laneid)
///         args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///       ^bb0(%arg: vector<128xf32>):
///       ...
///       gpu.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
///   }
/// ```
struct WarpOpScfForOp : public WarpDistributionPattern {

  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect Values that come from the warp op but are outside the forOp.
    // Those values need to be returned by the original warpOp and passed to
    // the new op.
    llvm::SmallSetVector<Value, 32> escapingValues;
    SmallVector<Type> inputTypes;
    SmallVector<Type> distTypes;
    auto collectEscapingValues = [&](Value value) {
      if (!escapingValues.insert(value))
        return;
      Type distType = value.getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(value);
        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
      }
      inputTypes.push_back(value.getType());
      distTypes.push_back(distType);
    };

    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            collectEscapingValues(operand->get());
          }
        });

    // Any forOp result that is not already yielded by the warpOp
    // region is also considered escaping and must be returned by the
    // original warpOp.
    for (OpResult forResult : forOp.getResults()) {
      // Check if this forResult is already yielded by the yield op.
      if (llvm::is_contained(yield->getOperands(), forResult))
        continue;
      collectEscapingValues(forResult);
    }

    if (llvm::is_contained(distTypes, Type{}))
      return failure();

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
        newRetIndices);
    yield = cast<gpu::YieldOp>(
        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = cast<OpResult>(yieldOperand.get());
      newOperands.push_back(
          newWarpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPointToStart(newForOp.getBody());

    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
                                 newForOp.getRegionIterArgs().end());
    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                    forOp.getResultTypes().end());
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      auto newWarpResult = newWarpOp.getResult(retIdx);
      // Unused forOp results yielded by the warpOp region are already included
      // in the new ForOp.
      if (llvm::is_contained(newOperands, newWarpResult))
        continue;
      warpInput.push_back(newWarpResult);
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
    }
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), warpInput, warpInputType);

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
    }
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vectors whose size
/// is a multiple of the warpOp size. E.g.:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   gpu.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   gpu.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public WarpDistributionPattern {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vectors whose size is a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
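    // E.g., reducing a vector<64xf32> on a 32-lane warp gives numElements =
    // 64 / 32 = 2, so each lane contributes a vector<2xf32> to the distributed
    // reduction below.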
    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
      patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(),
                                    warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}
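
// Example of wiring these entry points into a rewrite driver (a minimal
// sketch, not part of this file's implementation): `funcOp`, `distributionFn`,
// `shuffleFn`, and `reductionFn` are assumed to be supplied by the caller, and
// the PatternBenefit arguments are assumed to keep the defaults declared in
// VectorDistribution.h:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionFn, shuffleFn);
//   vector::populateDistributeReduction(patterns, reductionFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));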

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

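/// Hoist scalar, memory-effect-free operations that have no vector results and
/// no nested regions out of the warp op when all of their operands are (or
/// will be) defined outside the region; such operations compute the same value
/// on every lane. An illustrative sketch (assuming %x is defined above the
/// warp op):
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %c = arith.constant 4 : index
///   %i = arith.addi %c, %x : index
///   ...
/// }
/// ```
/// becomes, after hoisting:
/// ```
/// %c = arith.constant 4 : index
/// %i = arith.addi %c, %x : index
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   ...
/// }
/// ```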
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
