//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::gpu;

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   gpu.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different than the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

/// Given a sequential and distributed vector type, returns the distributed
/// dimension. This function expects that only a single dimension is
/// distributed.
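/// For example (illustrative): a sequential vector<16x32xf32> distributed to
/// vector<16x1xf32> gives distributed dim 1, the only dim whose size differs.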
static int getDistributedDim(VectorType sequentialType,
                             VectorType distributedType) {
  assert(sequentialType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  int64_t distributedDim = -1;
  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
      // support distributing multiple dimensions in the future.
      assert(distributedDim == -1 && "found multiple distributed dims");
      distributedDim = i;
    }
  }
  return distributedDim;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

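  /// Build the offset of lane `laneId` along distributed dimension `index`,
  /// i.e. `laneId * distributedSize`. For example, with a distributed dim
  /// size of 2, lane 3 accesses the sequential buffer starting at offset 6.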
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  /// 1. scalars of type T transit through a memref<1xT>.
  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
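  /// For example, under this convention an f32 scalar transits through a
  /// memref<1xf32> buffer and a vector<32xf32> value transits through a
  /// memref<32xf32> buffer.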
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  /// 1. scalars of type T transit through a memref<1xT>.
  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// When the requested type is a scalar, the load is not distributed, to
  /// account for the broadcast semantics of the `gpu.warp_execute_on_lane_0`
  /// op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///   gpu.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // All other cases must be vectors at the moment.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
/// - the IR within the scf.if op can be thought of as executing sequentially
/// (from the point of view of threads along `laneId`).
/// - the IR outside of the scf.if op can be thought of as executing in
/// parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
/// 1. Create the scf.if op.
/// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
/// within the scf.if to transit the values captured from above.
/// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
/// consistent within the scf.if.
/// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
/// 5. Insert appropriate writes within scf.if and reads after the scf.if to
/// transit the values returned by the op.
/// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
/// consistent after the scf.if.
/// 7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
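///
/// A minimal sketch of the rewrite (illustrative only; the actual buffer
/// allocation and synchronization are supplied by `options`):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %v = "compute"() : () -> (vector<32xf32>)
///   gpu.yield %v : vector<32xf32>
/// }
/// ```
/// becomes, roughly:
/// ```
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane_0 {
///   %v = "compute"() : () -> (vector<32xf32>)
///   vector.transfer_write %v, %buffer[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// // warp synchronization
/// %r = vector.transfer_read %buffer[%laneid] : memref<32xf32>, vector<1xf32>
/// ```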
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: Insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension count equal to
/// the original type rank and should be a projection where the results are
/// the distributed dimensions. The number of results should be equal to the
/// number of warp sizes which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1)
/// and a warp size of 16 distributes the second dimension (associated to d1)
/// and returns vector<16x2x64>.
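/// When the mapped dimensions cannot absorb the full warp size (i.e. some
/// factor of the warp size is left over after walking all map results), no
/// distributed type exists and a null VectorType is returned.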
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract`
/// elements will not be extracted (the limit should be less than the warp
/// size).
///
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, fail here; tryExtractOp will clone it into a
    // new WarpExecuteOnLane0Op to separate it from the rest.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5 Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
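    // E.g. (illustrative): a vector<4x32xf32> write distributed to
    // vector<1x2xf32> over 64 lanes delinearizes the lane id with basis
    // (4, 16): laneid = id0 * 16 + id1, where id0 indexes dim 0 and id1
    // indexes dim 1.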
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] : llvm::zip_equal(
             writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
  /// execute op with the proper return type. The new write op is updated to
  /// write the result of the new warp execute op. The old `writeOp` is
  /// deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(
          newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   gpu.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   gpu.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()),
        scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this into WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    // 1. recording the unique first position at which the value is yielded.
    // 2. recording for the result, the first position at which the dedup'ed
    // value is yielded.
    // 3. skipping from the new result types / new yielded values any result
    // that has no use or whose yielded value has already been seen.
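    // For example, if the same value is yielded as both result #0 and result
    // #2 and both results are used, the new warp op yields it once and both
    // results map to that single position.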
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is directly yielded out of the region, we can forward it;
// it doesn't need to go through the region.
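// For example (illustrative), a value defined above the warp op and yielded
// unchanged:
// ```
// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (f32) {
//   gpu.yield %cst : f32
// }
// ```
// has its uses of %0 replaced directly with %cst (types must match).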
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(
              operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

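/// Sink out a broadcast feeding into a warp op yield, assuming the broadcast
/// source is uniform across all lanes. For example (illustrative):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   %1 = vector.broadcast %src : f32 to vector<32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes a `vector.broadcast` of %src to vector<1xf32> after the warp op.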
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // their own broadcast.
    // For that we simply check that the broadcast we want to build makes
    // sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is basically
/// a no-op for warp distribution; we only need to handle the shape.
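/// For example (illustrative): a yielded shape cast from vector<32x1xf32> to
/// vector<32xf32>, distributed to vector<1xf32>, becomes a shape cast from
/// the distributed vector<1x1xf32> to vector<1xf32> after the warp op.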
978struct WarpOpShapeCast : public WarpDistributionPattern {
979 using Base::Base;
980 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
981 PatternRewriter &rewriter) const override {
982 OpOperand *operand =
983 getWarpResult(warpOp, fn: llvm::IsaPred<vector::ShapeCastOp>);
984 if (!operand)
985 return failure();
986
987 auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
988
989 unsigned int operandNumber = operand->getOperandNumber();
990 auto castDistributedType =
991 cast<VectorType>(Val: warpOp->getResultTypes()[operandNumber]);
992 VectorType castOriginalType = oldCastOp.getSourceVectorType();
993 VectorType castResultType = castDistributedType;
994
995 // We expect the distributed type to have a smaller rank than the original
996 // type. Prepend with size-one dimensions to make them the same.
997 unsigned castDistributedRank = castDistributedType.getRank();
998 unsigned castOriginalRank = castOriginalType.getRank();
999 if (castDistributedRank < castOriginalRank) {
1000 SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1001 llvm::append_range(C&: shape, R: castDistributedType.getShape());
1002 castDistributedType =
1003 VectorType::get(shape, elementType: castDistributedType.getElementType());
1004 }
1005
1006 SmallVector<size_t> newRetIndices;
1007 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1008 rewriter, warpOp, newYieldedValues: {oldCastOp.getSource()}, newReturnTypes: {castDistributedType},
1009 indices&: newRetIndices);
1010 rewriter.setInsertionPointAfter(newWarpOp);
1011 Value newCast = rewriter.create<vector::ShapeCastOp>(
1012 location: oldCastOp.getLoc(), args&: castResultType,
1013 args: newWarpOp->getResult(idx: newRetIndices[0]));
1014 rewriter.replaceAllUsesWith(from: newWarpOp->getResult(idx: operandNumber), to: newCast);
1015 return success();
1016 }
1017};
1018
1019/// Sink out vector.create_mask op feeding into a warp op yield.
1020/// ```
1021/// %0 = ...
1022/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1023/// ...
1024/// %mask = vector.create_mask %0 : vector<32xi1>
1025/// gpu.yield %mask : vector<32xi1>
1026/// }
1027/// ```
1028/// To
1029/// ```
1030/// %0 = ...
1031/// gpu.warp_execute_on_lane_0(%arg0) {
1032/// ...
1033/// }
1034/// %cmp = arith.cmpi ult, %laneid, %0
1035/// %ub = arith.select %cmp, %c0, %c1
1036/// %1 = vector.create_mask %ub : vector<1xi1>
1037struct WarpOpCreateMask : public WarpDistributionPattern {
1038 using Base::Base;
1039 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1040 PatternRewriter &rewriter) const override {
1041 OpOperand *yieldOperand =
1042 getWarpResult(warpOp, fn: llvm::IsaPred<vector::CreateMaskOp>);
1043 if (!yieldOperand)
1044 return failure();
1045
1046 auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1047
1048 // Early exit if any values needed for calculating the new mask indices
1049 // are defined inside the warp op.
1050 if (!llvm::all_of(Range: mask->getOperands(), P: [&](Value value) {
1051 return warpOp.isDefinedOutsideOfRegion(value);
1052 }))
1053 return failure();
1054
1055 Location loc = mask.getLoc();
1056 unsigned operandIndex = yieldOperand->getOperandNumber();
1057
1058 auto distType = cast<VectorType>(Val: warpOp.getResult(i: operandIndex).getType());
1059 VectorType seqType = mask.getVectorType();
1060 ArrayRef<int64_t> seqShape = seqType.getShape();
1061 ArrayRef<int64_t> distShape = distType.getShape();
1062
1063 rewriter.setInsertionPointAfter(warpOp);
1064
1065 // Delinearize the lane ID for constructing the distributed mask sizes.
1066 SmallVector<Value> delinearizedIds;
1067 if (!delinearizeLaneId(builder&: rewriter, loc, originalShape: seqShape, distributedShape: distShape,
1068 warpSize: warpOp.getWarpSize(), laneId: warpOp.getLaneid(),
1069 delinearizedIds))
1070 return rewriter.notifyMatchFailure(
1071 arg&: mask, msg: "cannot delinearize lane ID for distribution");
1072 assert(!delinearizedIds.empty());
1073
1074 // Notify the rewriter that the warp op is changing (see the comment on
1075 // the WarpOpTransferRead pattern).
1076 rewriter.startOpModification(op: warpOp);
1077
1078 AffineExpr s0, s1;
1079 bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1);
1080 SmallVector<Value> newOperands;
1081 for (int i = 0, e = distShape.size(); i < e; ++i) {
1082 // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1083 // find the distance from the largest mask index owned by this lane to the
1084 // original mask size. `vector.create_mask` implicitly clamps mask
1085 // operands to the range [0, mask_vector_size[i]], or in other words, the
1086 // mask sizes are always in the range [0, mask_vector_size[i]).
1087 Value maskDimIdx = affine::makeComposedAffineApply(
1088 b&: rewriter, loc, e: s1 - s0 * distShape[i],
1089 operands: {delinearizedIds[i], mask.getOperand(i)});
1090 newOperands.push_back(Elt: maskDimIdx);
1091 }
1092
1093 auto newMask =
1094 rewriter.create<vector::CreateMaskOp>(location: loc, args&: distType, args&: newOperands);
1095 rewriter.replaceAllUsesWith(from: warpOp.getResult(i: operandIndex), to: newMask);
1096 rewriter.finalizeOpModification(op: warpOp);
1097 return success();
1098 }
1099};
1100
1101/// Sink out insert_strided_slice op feeding into a warp op yield.
1102/// ```
1103/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1104/// ...
1105/// %src = ... : vector<4x32xf32>
1106/// %dest = ... : vector<8x32xf32>
1107/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1108/// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1109/// gpu.yield %insert : vector<8x32xf32>
1110/// }
1111/// ```
1112/// To
1113/// ```
1114/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1115/// vector<8x1xf32>) {
1116/// ...
1117/// %src = ... : vector<4x32xf32>
1118/// %dest = ... : vector<8x32xf32>
1119/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
1120/// }
1121/// %insert = vector.insert_strided_slice %0#0, %0#1,
1122/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1123/// ```
1124/// NOTE: Current support assumes that both src and dest vectors are distributed
1125/// to lanes and sinking the insert op does not require any cross lane
1126/// communication.
1127struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
1128 using Base::Base;
1129 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1130 PatternRewriter &rewriter) const override {
1131 OpOperand *operand =
1132 getWarpResult(warpOp, fn: llvm::IsaPred<vector::InsertStridedSliceOp>);
1133 if (!operand)
1134 return failure();
1135 unsigned int operandNumber = operand->getOperandNumber();
1136 auto insertOp =
1137 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1138 auto distributedType =
1139 cast<VectorType>(Val: warpOp.getResult(i: operandNumber).getType());
1140 // Distributed type must be 2D or higher.
1141 // TODO: Support 1D distributed types.
1142 if (distributedType.getRank() < 2)
1143 return rewriter.notifyMatchFailure(
1144 arg&: insertOp, msg: "result vector type must be 2D or higher");
1145 // Find the distributed dimension of the dest vector. There should be
1146 // exactly one.
1147 auto yieldedType = cast<VectorType>(Val: operand->get().getType());
1148 int64_t destDistributedDim =
1149 getDistributedDim(sequentialType: yieldedType, distributedType);
1150 assert(destDistributedDim != -1 && "could not find distributed dimension");
1151
1152 VectorType srcType = insertOp.getSourceVectorType();
1153 VectorType destType = insertOp.getDestVectorType();
1154 // Currently we require that both source (kD) and dest (nD) vectors are
1155 // distributed. This requires that distributedDim (d) is contained in the
1156 // last k dims of the dest vector (d >= n - k).
1157 // TODO: Add support for case where source vector is not distributed.
1158 int64_t sourceDistributedDim =
1159 destDistributedDim - (destType.getRank() - srcType.getRank());
1160 if (sourceDistributedDim < 0)
1161 return rewriter.notifyMatchFailure(
1162 arg&: insertOp,
1163 msg: "distributed dimension must be in the last k dims of dest vector");
1164 // Distributed dimension must be fully inserted.
1165 if (srcType.getDimSize(idx: sourceDistributedDim) !=
1166 destType.getDimSize(idx: destDistributedDim))
1167 return rewriter.notifyMatchFailure(
1168 arg&: insertOp, msg: "distributed dimension must be fully inserted");
1169 SmallVector<int64_t> newSourceDistShape(
1170 insertOp.getSourceVectorType().getShape());
1171 newSourceDistShape[sourceDistributedDim] =
1172 distributedType.getDimSize(idx: destDistributedDim);
1173 auto newSourceTy =
1174 VectorType::get(shape: newSourceDistShape, elementType: distributedType.getElementType());
1175 VectorType newDestTy = distributedType;
1176 SmallVector<size_t> newRetIndices;
1177 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1178 rewriter, warpOp, newYieldedValues: {insertOp.getValueToStore(), insertOp.getDest()},
1179 newReturnTypes: {newSourceTy, newDestTy}, indices&: newRetIndices);
1180 rewriter.setInsertionPointAfter(newWarpOp);
1181 Value distributedSource = newWarpOp->getResult(idx: newRetIndices[0]);
1182 Value distributedDest = newWarpOp->getResult(idx: newRetIndices[1]);
1183 // Create a new insert strided slice op that inserts distributed source into
1184 // distributed dest.
1185 Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
1186 location: insertOp.getLoc(), args: distributedDest.getType(), args&: distributedSource,
1187 args&: distributedDest, args: insertOp.getOffsets(), args: insertOp.getStrides());
1188 rewriter.replaceAllUsesWith(from: newWarpOp->getResult(idx: operandNumber), to: newInsert);
1189 return success();
1190 }
1191};
1192
1193/// Sink out extract_strided_slice op feeding into a warp op yield.
1194/// ```
1195/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1196/// ...
1197/// %src = ... : vector<64x32xf32>
1198/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1199/// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1200/// gpu.yield %extract : vector<16x32xf32>
1201/// }
1202/// ```
1203/// To
1204/// ```
1205/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
1206/// ...
1207/// %src = ... : vector<64x32xf32>
1208/// gpu.yield %src : vector<64x32xf32>
1209/// }
1210/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1211/// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
1212/// ```
1213/// NOTE: Current support assumes that the extraction happens only on non
1214/// distributed dimensions (does not require cross lane communication).
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp =
        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          extractOp, "result vector type must be 2D or higher");

    // Find the distributed dimension. There should be exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");

    int64_t numOfExtractedDims =
        static_cast<int64_t>(extractOp.getSizes().size());
    // If the distributed dim is included in the extracted dims, make sure the
    // distributed dim is fully extracted. If the distributed dim is not
    // included in the extracted dims, it is guaranteed to be fully extracted
    // (i.e. the distributed dim comes after all the extracted dims).
    // TODO: Partial extraction from the distributed dimension requires
    // cross-lane communication.
    if (distributedDim < numOfExtractedDims) {
      int64_t distributedDimOffset =
          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
              .getInt();
      int64_t distributedDimSize =
          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
              .getInt();
      if (distributedDimOffset != 0 ||
          distributedDimSize != yieldedType.getDimSize(distributedDim))
        return rewriter.notifyMatchFailure(
            extractOp, "distributed dimension must be fully extracted");
    }
    SmallVector<int64_t> newDistributedShape(
        extractOp.getSourceVectorType().getShape());
    newDistributedShape[distributedDim] =
        distributedType.getDimSize(distributedDim);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    // Update the distributed sizes to match the distributed type.
    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
      distributedSizes[distributedDim] =
          rewriter.getI64IntegerAttr(distributedType.getDimSize(distributedDim));

    // Create a new extract strided slice op that extracts from the
    // distributed vector.
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), distributedType, distributedVec,
        extractOp.getOffsets(),
        ArrayAttr::get(rewriter.getContext(), distributedSizes),
        extractOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extract from a 2-D or higher-rank source
/// vector: the extract is sunk past the warp op and rewritten to operate on
/// the per-lane distributed vector (0-D and 1-D sources are handled by the
/// WarpOpExtractScalar pattern instead).
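///
/// As an illustrative sketch (assuming a warp size of 32; `"some_def"` is a
/// placeholder op, and the shapes were picked so the trailing dimension is
/// distributed 96 -> 3):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
///   %0 = "some_def"() : () -> (vector<5x96xf32>)
///   %1 = vector.extract %0[2] : vector<96xf32> from vector<5x96xf32>
///   gpu.yield %1 : vector<96xf32>
/// }
/// ```
/// is rewritten so that each lane extracts from its distributed slice:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<5x3xf32>) {
///   %0 = "some_def"() : () -> (vector<5x96xf32>)
///   gpu.yield %0 : vector<5x96xf32>
/// }
/// %r = vector.extract %w[2] : vector<3xf32> from vector<5x3xf32>
/// ```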
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-D or 0-D source cases, we rely on the WarpOpExtractScalar
    // pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases handle 2-D or higher-dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() ==
        operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the
      // extract out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1-D case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extract with a scalar result.
/// Only supports 1-D and 0-D sources for now.
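///
/// As an illustrative sketch (warp size 32, so each lane owns 2 of the 64
/// elements; the final broadcast is produced by the caller-provided
/// `warpShuffleFromIdxFn`, shown here as a hypothetical `gpu.shuffle idx`
/// materialization):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<64xf32>)
///   %1 = vector.extract %0[6] : f32 from vector<64xf32>
///   gpu.yield %1 : f32
/// }
/// ```
/// becomes, in spirit (element 6 lives on lane 3 at offset 0):
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
///   %0 = "some_def"() : () -> (vector<64xf32>)
///   gpu.yield %0 : vector<64xf32>
/// }
/// %e = vector.extract %w[0] : f32 from vector<2xf32>
/// %r, %valid = gpu.shuffle idx %e, %c3, %c32 : f32
/// ```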
struct WarpOpExtractScalar : public WarpDistributionPattern {
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only supports 1-D or 0-D sources for now.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    // TODO: Supported shuffle types should be parameterizable, similar to
    // `WarpShuffleFromIdxFn`.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield source vector and position (if present) from warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0-D extract: The new warp op broadcasts the source vector to all lanes.
    // All lanes extract the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      newExtract =
          rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
    // 1-D extract: Distribute the source vector. One lane extracts and
    // shuffles the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Extract at position: pos % elementsPerLane
    Value newPos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
    Value extracted =
        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
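///
/// As an illustrative sketch (warp size 32, 2 elements per lane, inserting
/// at static position 6, which lives on lane 3 at offset 0; `"some_def"` and
/// `"another_def"` are placeholder ops):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
///   %0 = "some_def"() : () -> (vector<64xf32>)
///   %f = "another_def"() : () -> (f32)
///   %1 = vector.insert %f, %0 [6] : f32 into vector<64xf32>
///   gpu.yield %1 : vector<64xf32>
/// }
/// ```
/// becomes, in spirit:
/// ```
/// %w:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, f32) {
///   %0 = "some_def"() : () -> (vector<64xf32>)
///   %f = "another_def"() : () -> (f32)
///   gpu.yield %0, %f : vector<64xf32>, f32
/// }
/// %on = arith.cmpi eq, %laneid, %c3 : index
/// %r = scf.if %on -> (vector<2xf32>) {
///   %i = vector.insert %w#1, %w#0 [0] : f32 into vector<2xf32>
///   scf.yield %i : vector<2xf32>
/// } else {
///   scf.yield %w#0 : vector<2xf32>
/// }
/// ```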
struct WarpOpInsertScalar : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());

    // Only supports 1-D or 0-D destinations for now.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D destination supported for now");
    }

    // Yield destination vector, source scalar and position from warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getValueToStore()};
    SmallVector<Type> additionalResultTypes{
        distrType, insertOp.getValueToStore().getType()};
    additionalResults.append(
        SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);
    rewriter.setInsertionPointAfter(newWarpOp);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-D vectors.
    if (vecType == distrType) {
      Value newInsert;
      SmallVector<OpFoldResult> indices;
      if (pos) {
        indices.push_back(pos);
      }
      newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
                                                    distributedVec, indices);
      // Broadcast: Simply move the vector.insert op out.
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution. Only one lane should insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting lane: pos / elementsPerLane
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Insert position: pos % elementsPerLane
    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertOp>(
                      loc, newSource, distributedVec, newPos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
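
/// Pattern to move out vector.insert for 2-D and higher-dimensional
/// destinations. If the distributed dimension lies inside the inserted
/// slice, every lane inserts its own piece; otherwise a single lane performs
/// the insert, guarded by an scf.if. As an illustrative sketch of the
/// single-lane case (warp size 32, distributing dim 0 of the destination,
/// so 128 -> 4 rows per lane):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
///   ...
///   %1 = vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
///   gpu.yield %1 : vector<128x96xf32>
/// }
/// ```
/// Lane 0 (since 2 / 4 == 0) inserts the full vector<96xf32> at position
/// 2 % 4 == 2 of its distributed vector<4x96xf32>; all other lanes forward
/// their distributed destination unchanged through the scf.if.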
struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // For 1-D or 0-D destination cases, we rely on the WarpOpInsertScalar
    // pattern.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    // All following cases handle 2-D or higher-dimensional destination
    // vectors.

    if (warpOp.getResult(operandNumber).getType() ==
        operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   gpu.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   gpu.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = gpu.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       gpu.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public WarpDistributionPattern {

  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto warpOpYield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up `ForOp` if it is the last op in the region.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
    // Those Values need to be returned by the new warp op.
    llvm::SmallSetVector<Value, 32> escapingValues;
    SmallVector<Type> escapingValueInputTypes;
    SmallVector<Type> escapingValueDistTypes;
    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            if (!escapingValues.insert(operand->get()))
              return;
            Type distType = operand->get().getType();
            if (auto vecType = dyn_cast<VectorType>(distType)) {
              AffineMap map = distributionMapFn(operand->get());
              distType =
                  getDistributedType(vecType, map, warpOp.getWarpSize());
            }
            escapingValueInputTypes.push_back(operand->get().getType());
            escapingValueDistTypes.push_back(distType);
          }
        });

    if (llvm::is_contained(escapingValueDistTypes, Type{}))
      return failure();
    // `WarpOp` can yield two types of values:
    // 1. Values that are not results of the `ForOp`:
    //    These values must also be yielded by the new `WarpOp`. Also, we need
    //    to record the index mapping for these values to replace them later.
    // 2. Values that are results of the `ForOp`:
    //    In this case, we record the index mapping between the `WarpOp`
    //    result index and the matching `ForOp` result index.
    // Additionally, we keep track of the distributed types for all `ForOp`
    // vector results.
    SmallVector<Value> nonForYieldedValues;
    SmallVector<unsigned> nonForResultIndices;
    llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      // Yielded value is not a result of the forOp.
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
        nonForYieldedValues.push_back(yieldOperand.get());
        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
        continue;
      }
      OpResult forResult = cast<OpResult>(yieldOperand.get());
      unsigned int forResultNumber = forResult.getResultNumber();
      forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
      // If this `ForOp` result is vector type and it is yielded by the
      // `WarpOp`, we keep track of the distributed type for this result.
      if (!isa<VectorType>(forResult.getType()))
        continue;
      VectorType distType = cast<VectorType>(
          warpOp.getResult(yieldOperand.getOperandNumber()).getType());
      forResultDistTypes[forResultNumber] = distType;
    }

    // The newly created `WarpOp` will yield values in the following order:
    // 1. All init args of the `ForOp`.
    // 2. All escaping values.
    // 3. All non-`ForOp` yielded values.
    SmallVector<Value> newWarpOpYieldValues;
    SmallVector<Type> newWarpOpDistTypes;
    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
      newWarpOpYieldValues.push_back(initArg);
      // Compute the distributed type for this init arg.
      Type distType = initArg.getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        // If the `ForOp` result corresponding to this init arg is already
        // yielded by the `WarpOp`, we can get the distributed type from the
        // `forResultDistTypes` map. Otherwise, we compute it using
        // `distributionMapFn`.
        AffineMap map = distributionMapFn(initArg);
        distType = forResultDistTypes.count(i)
                       ? forResultDistTypes[i]
                       : getDistributedType(vecType, map, warpOp.getWarpSize());
      }
      newWarpOpDistTypes.push_back(distType);
    }
    // Insert escaping values and their distributed types.
    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
                                escapingValues.begin(), escapingValues.end());
    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
                              escapingValueDistTypes.begin(),
                              escapingValueDistTypes.end());
    // Next, we insert all non-`ForOp` yielded values and their distributed
    // types. We also create a mapping between the non-`ForOp` yielded value
    // index and the corresponding new `WarpOp` yield value index (needed to
    // update users later).
    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
    for (auto [i, v] :
         llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
      nonForResultMapping[i] = newWarpOpYieldValues.size();
      newWarpOpYieldValues.push_back(v);
      newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
    }
    // Create the new `WarpOp` with the updated yield values and types.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);

    // Next, we create a new `ForOp` with the init args yielded by the new
    // `WarpOp`.
    const unsigned escapingValuesStartIdx =
        forOp.getInitArgs().size(); // `ForOp` init args are positioned before
                                    // escaping values in the new `WarpOp`.
    SmallVector<Value> newForOpOperands;
    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
      newForOpOperands.push_back(newWarpOp.getResult(i));

    // Create a new `ForOp` outside the new `WarpOp` region.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newForOpOperands);
    // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
    // newly created `ForOp`. This `WarpOp` will contain all ops that were
    // contained within the original `ForOp` body.
    rewriter.setInsertionPointToStart(newForOp.getBody());

    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
                                      newForOp.getRegionIterArgs().end());
    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
                                         forOp.getResultTypes().end());
    // Escaping values are forwarded to the inner `WarpOp` as its (additional)
    // arguments. We keep track of the mapping between these values and their
    // argument index in the inner `WarpOp` (to replace users later).
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (size_t i = escapingValuesStartIdx;
         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
      innerWarpInput.push_back(newWarpOp.getResult(i));
      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
          innerWarpInputType.size();
      innerWarpInputType.push_back(
          escapingValueInputTypes[i - escapingValuesStartIdx]);
    }
    // Create the inner `WarpOp` with the new input values and types.
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);

    // Inline the `ForOp` body into the inner `WarpOp` body.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments())
      argMapping.push_back(args);

    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);

    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);

    // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
    // the original `ForOp` results.
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    // Insert a scf.yield op at the end of the new `ForOp` body that yields
    // the inner `WarpOp` results.
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());

    // Update the users of original `WarpOp` results that were coming from the
    // original `ForOp` to the corresponding new `ForOp` result.
    for (auto [origIdx, newIdx] : forResultMapping)
      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
                                    newForOp.getResult(newIdx), newForOp);
    // Similarly, update any users of the `WarpOp` results that were not
    // results of the `ForOp`.
    for (auto [origIdx, newIdx] : nonForResultMapping)
      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
                                  newWarpOp.getResult(newIdx));
    // Remove the original `WarpOp` and `ForOp`. They should not have any uses
    // at this point.
    rewriter.eraseOp(forOp);
    rewriter.eraseOp(warpOp);
    // Update any users of escaping values that were forwarded to the
    // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now-uniform code from the inner `WarpOp`.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size. E.g.:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   gpu.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   gpu.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public WarpDistributionPattern {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vectors whose size is a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns
      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
          patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(),
                                    warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}
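
// Illustrative usage sketch (not part of this file's API): a client pass
// typically gathers these patterns and runs them with the greedy rewrite
// driver. The three callbacks below are hypothetical placeholders the caller
// must supply, and the driver entry point name may vary across MLIR versions.
//
//   RewritePatternSet patterns(ctx);
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, myDistributionMapFn, myWarpShuffleFromIdxFn);
//   vector::populateDistributeReduction(patterns, myDistributedReductionFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));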

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
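
/// Hoist scalar, memory-effect-free ops whose operands are defined outside
/// the warp region (directly, or transitively via other hoisted ops) out of
/// the WarpExecuteOnLane0Op: such ops compute the same value on every lane.
/// Illustrative sketch (`"some_def"` is a placeholder op):
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %c = arith.constant 1 : index    // uniform scalar, gets hoisted
///   %v = "some_def"(%c) : (index) -> (vector<32xf32>)
///   gpu.yield %v : vector<32xf32>
/// }
/// ```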
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}