//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>
#include <utility>

using namespace mlir;
using namespace mlir::vector;
/// Currently the distribution map is implicit based on the vector shape. In the
/// future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

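  /// Build the distributed offset `laneId * distributedSize` along the
  /// distributed dimension `index`, i.e. the offset of the first element
  /// owned by this lane along that dimension.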
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
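  ///   e.g. a value of type f32 transits through a memref<1xf32>, and a value
  ///   of type vector<32xf32> transits through a memref<32xf32>.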
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// When the result type matches the yielded type (a broadcast), the load is
  /// not distributed, to account for the broadcast semantics of the
  /// `vector.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
  ///   vector.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vectors at the moment.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
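/// The body of `warpOp` is moved into the new op and its terminator is updated
/// to yield `newYieldedValues`; replacing or erasing the original op is left
/// to the caller.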
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.modifyOpInPlace(
      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns the index of each new output.
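/// If a value is already yielded by `warpOp`, no new output is created; the
/// index of the existing output is recorded instead.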
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value already exits the region, don't create a new output.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                const std::function<bool(Operation *)> &fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}
248
249// Clones `op` into a new operation that takes `operands` and returns
250// `resultTypes`.
251static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
252 Location loc, Operation *op,
253 ArrayRef<Value> operands,
254 ArrayRef<Type> resultTypes) {
255 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
256 op->getAttrs());
257 return rewriter.create(state: res);
258}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
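///
/// As an illustrative sketch (assuming a warp size of 32 and an
/// `options.warpAllocationFn` returning a `memref<32xf32>` buffer), a warp op
/// yielding a distributed vector<1xf32> becomes roughly:
/// ```
/// // %buffer = ... produced by options.warpAllocationFn ...
/// %cond = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %cond {
///   ...
///   vector.transfer_write %v, %buffer[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// // Synchronization inserted via options.warpSyncronizationFn.
/// %r = vector.transfer_read %buffer[%laneid], %cst
///     : memref<32xf32>, vector<1xf32>
/// ```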
struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType,
                                            VectorType maybeMaskType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp;
  if (maybeMaskType) {
    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
        TypeRange{targetType, maybeMaskType}, newRetIndices);
  } else {
    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, ValueRange{{writeOp.getVector()}},
        TypeRange{targetType}, newRetIndices);
  }
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  if (maybeMaskType)
    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
  return newWriteOp;
}

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have as many dimensions as the
/// rank of the original type and to be a projection where the results are the
/// distributed dimensions. The number of results should equal the number of
/// distributed dimensions.
/// Example: a vector<16x32x64xf32> distributed with the map
/// (d0, d1, d2) -> (d1) and a warp size of 16 distributes the second
/// dimension (associated with d1) and returns vector<16x2x64xf32>.
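/// When a distributed dimension is smaller than the warp size, the remaining
/// lanes carry over to the next distributed dimension: for instance,
/// vector<4x32xf32> with map (d0, d1) -> (d0, d1) and a warp size of 16
/// collapses d0 to size 1 (consuming 4 lanes) and splits d1 across the
/// remaining 4 lanes, returning vector<1x8xf32>.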
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                   originalType.getShape().end());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
/// will not be distributed (this limit should be less than the warp size).
///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, do not distribute it here; 0-D writes are left
    // to the extraction path (`tryExtractOp`).
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5. Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of small vectors (at most
  /// `maxNumElementsToExtract` elements) into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Delinearize the given `laneId` into multiple dimensions, where each
/// dimension's size is determined by `originalShape` and `distributedShape`
/// together. This function expects the total number of threads needed for
/// distribution to equal `warpSize`; it returns true and updates
/// `delinearizedIds` if so.
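///
/// For example, with `originalShape = [4, 8]`, `distributedShape = [2, 2]`
/// and `warpSize = 8`, the per-dimension thread counts are [2, 4], so lane 6
/// is delinearized to [6 / 4, 6 % 4] = [1, 2].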
bool delinearizeLaneId(OpBuilder &builder, Location loc,
                       ArrayRef<int64_t> originalShape,
                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
  // If the original shape and the distributed shape are the same, we don't
  // distribute at all--every thread handles the whole vector. In that case we
  // should not rely on lane IDs later, so just return an empty lane ID vector.
  if (originalShape == distributedShape) {
    delinearizedIds.clear();
    return true;
  }

  SmallVector<int64_t> sizes;
  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
    if (large % small != 0)
      return false;
    sizes.push_back(large / small);
  }
  if (std::accumulate(sizes.begin(), sizes.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return false;

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  int64_t usedThreads = 1;

  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  delinearizedIds.assign(sizes.size(), zero);

  for (int i = sizes.size() - 1; i >= 0; --i) {
    usedThreads *= sizes[i];
    if (usedThreads == warpSize) {
      // We've used up all available threads. No need to perform the modulo
      // anymore, and we can stop the calculation for further dimensions.
      delinearizedIds[i] = laneId;
      break;
    }
    delinearizedIds[i] =
        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
    laneId = affine::makeComposedAffineApply(
        builder, loc, s0.floorDiv(usedThreads), {laneId});
  }
  return true;
}

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted vector. Currently the details of which
      // lane is responsible for which element are captured strictly by shape
      // information on the warp op, and thus require materializing the
      // permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
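    // E.g., if %v is yielded at positions 0 and 2, both results map to
    // position 0 in the new op and the duplicate yield is dropped.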
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is directly yielded out of the region we can forward it
// directly and it doesn't need to go through the region.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct its
    // own broadcast. For that, we simply check that the broadcast we want to
    // build makes sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is basically
/// a no-op for warp distribution; we only need to handle the shape.
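/// For example, a yielded shape cast from vector<4x8xf32> to vector<32xf32>
/// distributed over 32 lanes yields the cast source as vector<1x1xf32> from
/// the new warp op and re-applies the cast to the distributed result type
/// vector<1xf32> outside of it.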
struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   vector.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], i.e. the mask sizes
      // are always in the range [0, mask_vector_size[i]].
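      // E.g., for distShape = [1] over 32 lanes with mask bound %b, lane i
      // gets mask size `%b - i`, which the op then clamps to [0, 1].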
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Pattern to move out vector.extract of a single-element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
    assert(extractSrcType.getRank() > 0 &&
           "vector.extract does not support rank 0 sources");

    // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
    // canonicalized to %v.
    if (extractOp.getNumIndices() == 0)
      return failure();

    // Rewrite vector.extract with 1d source to vector.extractelement.
    if (extractSrcType.getRank() == 1) {
      if (extractOp.hasDynamicPosition())
        // TODO: Dynamic position not supported yet.
        return failure();

      assert(extractOp.getNumIndices() == 1 && "expected 1 index");
      int64_t pos = extractOp.getStaticPosition()[0];
      rewriter.setInsertionPoint(extractOp);
      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
          extractOp, extractOp.getVector(),
          rewriter.create<arith::ConstantIndexOp>(loc, pos));
      return success();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
                                             extractSrcType.getShape().end());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
1317
1318/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
1319/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                       PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // TODO: Supported shuffle types should be parameterizable, similar to
    // `WarpShuffleFromIdxFn`.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extractelement src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
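    // E.g. (illustrative): with warp size 32, a vector<64xf32> source is
    // distributed as vector<2xf32> per lane (elementsPerLane = 2); a
    // vector<1xf32> or 0-d source is left undistributed.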
    // Yield source vector and position (if present) from warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    if (static_cast<bool>(extractOp.getPosition())) {
      additionalResults.push_back(extractOp.getPosition());
      additionalResultTypes.push_back(extractOp.getPosition().getType());
    }
    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0d extract: The new warp op broadcasts the source vector to all lanes.
    // All lanes extract the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      if (extractSrcType.getRank() == 1) {
        newExtract = rewriter.create<vector::ExtractElementOp>(
            loc, distributedVec,
            rewriter.create<arith::ConstantIndexOp>(loc, 0));
      } else {
        newExtract =
            rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
      }
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // 1d extract: Distribute the source vector. One lane extracts and
    // shuffles the value to all other lanes.
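    // E.g. (illustrative): with warp size 32 and a vector<32xf32> source,
    // elementsPerLane is 1, so lane `pos` extracts its single element and the
    // shuffle below broadcasts it to the whole warp.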
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.ceilDiv(elementsPerLane),
        newWarpOp->getResult(newRetIndices[1]));
    // Extract at position: pos % elementsPerLane
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<affine::AffineApplyOp>(
                      loc, sym0 % elementsPerLane,
                      newWarpOp->getResult(newRetIndices[1]))
                  .getResult();
    Value extracted =
        rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

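/// Pattern to move out vector.insertelement. If the destination vector is
/// distributed, only the lane that owns the position performs the insert; all
/// other lanes forward their slice of the destination unchanged. Illustrative
/// sketch (assuming warp size 32):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.insertelement %f, %0[%pos : index] : vector<32xf32>
///   vector.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes a distributed insert wrapped in an scf.if that only the inserting
/// lane executes.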
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    bool hasPos = static_cast<bool>(insertOp.getPosition());

    // Yield destination vector, source scalar and position from warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getSource()};
    SmallVector<Type> additionalResultTypes{distrType,
                                            insertOp.getSource().getType()};
    if (hasPos) {
      additionalResults.push_back(insertOp.getPosition());
      additionalResultTypes.push_back(insertOp.getPosition().getType());
    }
    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);
    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();

    if (vecType == distrType) {
      // Broadcast: Simply move the vector.insertelement op out.
      Value newInsert = rewriter.create<vector::InsertElementOp>(
          loc, newSource, distributedVec, newPos);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution. Only one lane should insert.
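    // E.g. (illustrative): warp size 32 and a vector<32xf32> destination give
    // elementsPerLane = 1, so lane `pos` inserts at offset 0 of its slice
    // while all other lanes keep their slice unchanged.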
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting lane: pos / elementsPerLane
    Value insertingLane = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.ceilDiv(elementsPerLane), newPos);
    // Insert position: pos % elementsPerLane
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
                                                 newPos)
                  .getResult();
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertElementOp>(
                      loc, newSource, distributedVec, pos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

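/// Pattern to move out vector.insert with a distributed destination. If the
/// distributed dimension falls inside the inserted slice, every lane inserts
/// its own piece; otherwise a single lane inserts the whole source vector
/// (see the in-code example below). Illustrative sketch (assuming warp size
/// 32, destination distributed along dim 0):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
///   ...
///   %1 = vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
///   vector.yield %1 : vector<128x96xf32>
/// }
/// ```
/// becomes an insert guarded by an scf.if so that only the owning lane
/// modifies its vector<4x96xf32> slice.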
struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // "vector.insert %src, %dst[] : ..." can be canonicalized to %src.
    if (insertOp.getNumIndices() == 0)
      return failure();

    // Rewrite vector.insert with 1d dest to vector.insertelement.
    if (insertOp.getDestVectorType().getRank() == 1) {
      if (insertOp.hasDynamicPosition())
        // TODO: Dynamic position not supported yet.
        return failure();

      assert(insertOp.getNumIndices() == 1 && "expected 1 index");
      int64_t pos = insertOp.getStaticPosition()[0];
      rewriter.setInsertionPoint(insertOp);
      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
          insertOp, insertOp.getSource(), insertOp.getDest(),
          rewriter.create<arith::ConstantIndexOp>(loc, pos));
      return success();
    }

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
          {insertOp.getSourceType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
                                       srcVecType.getShape().end());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {

  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        distributionMapFn(std::move(fn)) {}
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect values that come from the warp op but are defined outside of
    // the forOp. Those values need to be returned by the original warpOp and
    // passed on to the new op.
    llvm::SmallSetVector<Value, 32> escapingValues;
    SmallVector<Type> inputTypes;
    SmallVector<Type> distTypes;
    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            if (!escapingValues.insert(operand->get()))
              return;
            Type distType = operand->get().getType();
            if (auto vecType = dyn_cast<VectorType>(distType)) {
              AffineMap map = distributionMapFn(operand->get());
              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
            }
            inputTypes.push_back(operand->get().getType());
            distTypes.push_back(distType);
          }
        });

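    // E.g. (illustrative): a vector value defined in the warp body above the
    // loop and used inside the loop body is yielded from the new warp op in
    // distributed form and threaded into the inner warp op created below.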
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
        newRetIndices);
    yield = cast<vector::YieldOp>(
        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = cast<OpResult>(yieldOperand.get());
      newOperands.push_back(
          newWarpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());

    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
                                 newForOp.getRegionIterArgs().end());
    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                    forOp.getResultTypes().end());
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      warpInput.push_back(newWarpOp.getResult(retIdx));
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
    }
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), warpInput, warpInputType);

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp results coming from the original forOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
    }
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now-uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size. E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vector sizes that are a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns
      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
           WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
          patterns.getContext(), benefit);
  patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                     warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}
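
// Illustrative wiring of the populate functions above (a sketch, not part of
// this file's API; assumes suitable `distributionMapFn`,
// `warpShuffleFromIdxFn` and `distributedReductionFn` callbacks, a context
// `ctx` and a target `op`):
//   RewritePatternSet patterns(ctx);
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionMapFn, warpShuffleFromIdxFn);
//   vector::populateDistributeReduction(patterns, distributedReductionFn);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));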

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
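  // E.g. (illustrative): a scalar arith.addi whose operands are defined
  // outside the warp op has no vector results, so it is a hoisting candidate.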
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
1929
