//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>
#include <utility>

using namespace mlir;
using namespace mlir::vector;

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different than the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

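  /// Build the offset of `laneId` along distributed dimension `index`, i.e.
  /// `laneId * distributedSize[index]`. Illustrative sketch, using the
  /// running example at the top of the file: for `index` == 2 the distributed
  /// size is 2, so this folds to
  /// `affine.apply affine_map<()[s0] -> (s0 * 2)>()[%laneId]` and lane k owns
  /// elements [2 * k, 2 * k + 2) along dimension 2.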
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
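  ///
  /// Illustrative sketch for the distributed-vector case, using the running
  /// example at the top of the file (SSA names are for exposition only):
  /// ```
  ///   %o2 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%laneId]
  ///   vector.transfer_write %val, %buffer[%laneId, %c0, %o2]
  ///       : vector<1x16x2xf32>, memref<32x16x64xf32>
  /// ```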
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// When the requested type is not the distributed vector type (e.g. a
  /// scalar), the load is not distributed, to account for the broadcast
  /// semantics of the `vector.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  ///   %r = vector.warp_execute_on_lane_0(...) -> (f32) {
  ///     vector.yield %cst : f32
  ///   }
  ///   // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
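  ///
  /// For the distributed-vector case, the load is offset per lane
  /// (illustrative, using the running example at the top of the file):
  /// ```
  ///   %v = vector.transfer_read %buffer[%laneId, %c0, %o2]
  ///       : memref<32x16x64xf32>, vector<1x16x2xf32>
  /// ```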
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vectors atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.modifyOpInPlace(
      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns the index of each new output.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value already exits the region, don't create a new output.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
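
// Illustrative example: an `arith.constant` or an elementwise op whose
// operands all come from above the region is hoistable (memory-effect free,
// no regions, external operands); a `memref.load` or an op carrying a nested
// region is not.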

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                const std::function<bool(Operation *)> &fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}
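
// Typical usage, as in the patterns below: find a live yielded value defined
// by a given kind of op, e.g.
//   OpOperand *operand =
//       getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
// or pass an arbitrary predicate over the defining op.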

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
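///
/// An illustrative before/after sketch (the transit buffers `%buf0`/`%buf1`,
/// the padding value `%pad`, and the `gpu.barrier`s stand in for whatever
/// `options.warpAllocationFn` and `options.warpSyncronizationFn` actually
/// produce):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32]
///     args(%v : vector<1xf32>) -> (vector<1xf32>) {
/// ^bb0(%arg : vector<32xf32>):
///   %y = ... %arg ... : vector<32xf32>
///   vector.yield %y : vector<32xf32>
/// }
/// ```
/// becomes, roughly:
/// ```
/// %c0 = arith.constant 0 : index
/// %isLane0 = arith.cmpi eq, %laneid, %c0 : index
/// vector.transfer_write %v, %buf0[%laneid] : vector<1xf32>, memref<32xf32>
/// gpu.barrier
/// scf.if %isLane0 {
///   %arg = vector.transfer_read %buf0[%c0], %pad
///       : memref<32xf32>, vector<32xf32>
///   %y = ... %arg ... : vector<32xf32>
///   vector.transfer_write %y, %buf1[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// gpu.barrier
/// %r = vector.transfer_read %buf1[%laneid], %pad
///     : memref<32xf32>, vector<1xf32>
/// ```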
struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType,
                                            VectorType maybeMaskType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp;
  if (maybeMaskType) {
    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
        TypeRange{targetType, maybeMaskType}, newRetIndices);
  } else {
    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, ValueRange{{writeOp.getVector()}},
        TypeRange{targetType}, newRetIndices);
  }
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  if (maybeMaskType)
    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
  return newWriteOp;
}

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a number of dimensions equal
/// to the original type rank and should be a projection where the results are
/// the distributed dimensions. The number of results should be equal to the
/// number of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1)
/// and a warp size of 16 distributes the second dimension (associated with
/// d1) and returns vector<16x2x64>.
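///
/// Illustrative trace of the loop below with two map results (each folds part
/// of the warp size): a vector<2x64xf32> with map (d0, d1) -> (d0, d1) and a
/// warp size of 32 first collapses d0 (size 2) to 1, leaving 16 lanes, then
/// divides d1 (size 64) by 16, yielding vector<1x4xf32>.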
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                   originalType.getShape().end());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract`
/// elements will not be distributed (it should be less than the warp size).
///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
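  ///
  /// Illustrative reindexing for the example above: the vector<32xf32> write
  /// with map (d0) -> (d0) and warp size 32 becomes a vector<1xf32> write,
  /// and its index %c0 is rewritten to `affine.apply (d0 + d1)(%c0, %id)`,
  /// i.e. lane k writes element k.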
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, bail: it is handled by the extraction path
    // (`tryExtractOp`), which clones it into a separate WarpExecuteOnLane0Op.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5 Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///       vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Delinearize the given `laneId` into multiple dimensions, where each
/// dimension's size is determined by `originalShape` and `distributedShape`
/// together. This function expects the total number of threads needed for
/// distribution to equal `warpSize`. Returns true and updates
/// `delinearizedIds` if so.
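///
/// Illustrative example: with `originalShape` = 2x32, `distributedShape` =
/// 1x2 and `warpSize` = 32, the per-dimension thread counts are [2, 16] and
/// the resulting `delinearizedIds` are
/// [%laneId floordiv 16, %laneId mod 16].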
bool delinearizeLaneId(OpBuilder &builder, Location loc,
                       ArrayRef<int64_t> originalShape,
                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
  // If the original shape and the distributed shape are the same, we don't
  // distribute at all--every thread handles the whole vector. In that case,
  // we should not rely on lane IDs later. So just return an empty lane ID
  // vector.
  if (originalShape == distributedShape) {
    delinearizedIds.clear();
    return true;
  }

  SmallVector<int64_t> sizes;
  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
    if (large % small != 0)
      return false;
    sizes.push_back(large / small);
  }
  if (std::accumulate(sizes.begin(), sizes.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return false;

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  int64_t usedThreads = 1;

  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  delinearizedIds.assign(sizes.size(), zero);

  for (int i = sizes.size() - 1; i >= 0; --i) {
    usedThreads *= sizes[i];
    if (usedThreads == warpSize) {
      // We've used up all available threads. Don't need to perform modulo
      // anymore. And we can stop the calculation for further dimensions.
      delinearizedIds[i] = laneId;
      break;
    }
    delinearizedIds[i] =
        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
    // Peel off the bits consumed by this dimension before moving on.
    laneId = affine::makeComposedAffineApply(
        builder, loc, s0.floorDiv(sizes[i]), {laneId});
  }
  return true;
}

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///       vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///       vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted read vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
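/// Illustrative example (IR names are hypothetical): if the warp op yields
/// %a, %a, %b and only the first two results are used, the rewrite both
/// deduplicates the two yields of %a into a single result and drops the dead
/// %b:
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%id) ->
///     (vector<1xf32>, vector<1xf32>, vector<1xf32>) {
///   ...
///   vector.yield %a, %a, %b : vector<32xf32>, vector<32xf32>, vector<32xf32>
/// }
/// use(%r#0, %r#1)  // %r#2 is dead.
/// ```
/// becomes
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %a : vector<32xf32>
/// }
/// use(%r, %r)
/// ```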
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is directly yielded out of the region we can forward it
// directly and it doesn't need to go through the region.
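// Illustrative example (IR names are hypothetical): a block argument yielded
// unchanged is forwarded to the corresponding warp op operand:
//   %0 = vector.warp_execute_on_lane_0(%id)
//       args(%v : vector<4xf32>) -> (vector<4xf32>) {
//   ^bb0(%arg : vector<4xf32>):
//     vector.yield %arg : vector<4xf32>
//   }
// All uses of %0 are replaced by %v directly.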
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

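/// Sink out a broadcast that is uniform across lanes: the (uniform) broadcast
/// source is yielded instead and the broadcast is rebuilt at the distributed
/// type outside the region. Illustrative example (IR names are hypothetical):
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   %1 = vector.broadcast %src : f32 to vector<32xf32>
///   vector.yield %1 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (f32) {
///   ...
///   vector.yield %src : f32
/// }
/// %0 = vector.broadcast %r : f32 to vector<1xf32>
/// ```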
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // their own broadcast.
    // For that we simply check that the broadcast we want to build makes sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move shape cast out of the warp op. Shape cast is basically a
/// no-op for warp distribution; we need to handle the shape though.
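/// Illustrative example (IR names are hypothetical): the distributed cast
/// result may have a smaller rank than the cast source, so size-one
/// dimensions are prepended before yielding the source:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   %1 = vector.shape_cast %src : vector<32x1xf32> to vector<32xf32>
///   vector.yield %1 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1x1xf32>) {
///   ...
///   vector.yield %src : vector<32x1xf32>
/// }
/// %0 = vector.shape_cast %r : vector<1x1xf32> to vector<1xf32>
/// ```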
struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   vector.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], so the resulting
      // distributed mask sizes always stay within that valid range.
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
1217 | |
1218 | /// Pattern to move out vector.extract of single element vector. Those don't |
1219 | /// need to be distributed and can just be propagated outside of the region. |
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
1221 | using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; |
1222 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1223 | PatternRewriter &rewriter) const override { |
1224 | OpOperand *operand = |
1225 | getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>); |
1226 | if (!operand) |
1227 | return failure(); |
1228 | unsigned int operandNumber = operand->getOperandNumber(); |
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
1231 | Location loc = extractOp.getLoc(); |
1232 | |
1233 | // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op. |
    assert(extractSrcType.getRank() > 0 &&
           "vector.extract does not support rank 0 sources");
1236 | |
1237 | // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be |
1238 | // canonicalized to %v. |
1239 | if (extractOp.getNumIndices() == 0) |
1240 | return failure(); |
1241 | |
1242 | // Rewrite vector.extract with 1d source to vector.extractelement. |
1243 | if (extractSrcType.getRank() == 1) { |
1244 | if (extractOp.hasDynamicPosition()) |
        // TODO: Dynamic position not supported yet.
1246 | return failure(); |
1247 | |
      assert(extractOp.getNumIndices() == 1 && "expected 1 index");
1249 | int64_t pos = extractOp.getStaticPosition()[0]; |
1250 | rewriter.setInsertionPoint(extractOp); |
1251 | rewriter.replaceOpWithNewOp<vector::ExtractElementOp>( |
1252 | extractOp, extractOp.getVector(), |
1253 | rewriter.create<arith::ConstantIndexOp>(loc, pos)); |
1254 | return success(); |
1255 | } |
1256 | |
1257 | // All following cases are 2d or higher dimensional source vectors. |
1258 | |
1259 | if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { |
1260 | // There is no distribution, this is a broadcast. Simply move the extract |
1261 | // out of the warp op. |
1262 | // TODO: This could be optimized. E.g., in case of a scalar result, let |
1263 | // one lane extract and shuffle the result to all other lanes (same as |
1264 | // the 1d case). |
1265 | SmallVector<size_t> newRetIndices; |
1266 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1267 | rewriter, warpOp, {extractOp.getVector()}, |
1268 | {extractOp.getSourceVectorType()}, newRetIndices); |
1269 | rewriter.setInsertionPointAfter(newWarpOp); |
1270 | Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
1271 | // Extract from distributed vector. |
      Value newExtract = rewriter.create<vector::ExtractOp>(
1273 | loc, distributedVec, extractOp.getMixedPosition()); |
1274 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
1275 | newExtract); |
1276 | return success(); |
1277 | } |
1278 | |
1279 | // Find the distributed dimension. There should be exactly one. |
1280 | auto distributedType = |
1281 | cast<VectorType>(warpOp.getResult(operandNumber).getType()); |
1282 | auto yieldedType = cast<VectorType>(operand->get().getType()); |
1283 | int64_t distributedDim = -1; |
1284 | for (int64_t i = 0; i < yieldedType.getRank(); ++i) { |
1285 | if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { |
1286 | // Keep this assert here in case WarpExecuteOnLane0Op gets extended to |
1287 | // support distributing multiple dimensions in the future. |
        assert(distributedDim == -1 && "found multiple distributed dims");
1289 | distributedDim = i; |
1290 | } |
1291 | } |
    assert(distributedDim != -1 && "could not find distributed dimension");
1293 | (void)distributedDim; |
1294 | |
1295 | // Yield source vector from warp op. |
1296 | SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(), |
1297 | extractSrcType.getShape().end()); |
1298 | for (int i = 0; i < distributedType.getRank(); ++i) |
1299 | newDistributedShape[i + extractOp.getNumIndices()] = |
1300 | distributedType.getDimSize(i); |
1301 | auto newDistributedType = |
1302 | VectorType::get(newDistributedShape, distributedType.getElementType()); |
1303 | SmallVector<size_t> newRetIndices; |
1304 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1305 | rewriter, warpOp, {extractOp.getVector()}, {newDistributedType}, |
1306 | newRetIndices); |
1307 | rewriter.setInsertionPointAfter(newWarpOp); |
1308 | Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
1309 | // Extract from distributed vector. |
    Value newExtract = rewriter.create<vector::ExtractOp>(
1311 | loc, distributedVec, extractOp.getMixedPosition()); |
1312 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
1313 | newExtract); |
1314 | return success(); |
1315 | } |
1316 | }; |
1317 | |
/// Pattern to move out or distribute vector.extractelement ops. Extracts from
/// 0-d or single-element vectors are not distributed: the source is broadcast
/// and every lane extracts the scalar. For larger 1-d vectors the source is
/// distributed and the owning lane shuffles the value to all other lanes.
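/// An illustrative rewrite of the distributed 1-d case (SSA names
/// hypothetical; the final shuffle is emitted by the injected
/// `warpShuffleFromIdxFn` callback):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   ...
///   %e = vector.extractelement %v[%pos : index] : vector<64xf32>
///   vector.yield %e : f32
/// }
/// ```
/// becomes (with 2 elements per lane)
/// ```
/// %r0 = vector.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<2xf32>, index) {
///   ...
///   vector.yield %v, %pos : vector<64xf32>, index
/// }
/// %tid = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%r0#1]
/// %off = affine.apply affine_map<()[s0] -> (s0 mod 2)>()[%r0#1]
/// %e = vector.extractelement %r0#0[%off : index] : vector<2xf32>
/// %r = ("shuffle %e from lane %tid")
/// ```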
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1322 | PatternBenefit b = 1) |
1323 | : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b), |
1324 | warpShuffleFromIdxFn(std::move(fn)) {} |
1325 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1326 | PatternRewriter &rewriter) const override { |
1327 | OpOperand *operand = |
1328 | getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>); |
1329 | if (!operand) |
1330 | return failure(); |
1331 | unsigned int operandNumber = operand->getOperandNumber(); |
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
1334 | // TODO: Supported shuffle types should be parameterizable, similar to |
1335 | // `WarpShuffleFromIdxFn`. |
1336 | if (!extractSrcType.getElementType().isF32() && |
1337 | !extractSrcType.getElementType().isInteger(32)) |
1338 | return rewriter.notifyMatchFailure( |
1339 | extractOp, "only f32/i32 element types are supported" ); |
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1341 | Type elType = extractSrcType.getElementType(); |
1342 | VectorType distributedVecType; |
1343 | if (!is0dOrVec1Extract) { |
1344 | assert(extractSrcType.getRank() == 1 && |
1345 | "expected that extractelement src rank is 0 or 1" ); |
1346 | if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0) |
1347 | return failure(); |
1348 | int64_t elementsPerLane = |
1349 | extractSrcType.getShape()[0] / warpOp.getWarpSize(); |
1350 | distributedVecType = VectorType::get({elementsPerLane}, elType); |
1351 | } else { |
1352 | distributedVecType = extractSrcType; |
1353 | } |
1354 | // Yield source vector and position (if present) from warp op. |
1355 | SmallVector<Value> additionalResults{extractOp.getVector()}; |
1356 | SmallVector<Type> additionalResultTypes{distributedVecType}; |
1357 | if (static_cast<bool>(extractOp.getPosition())) { |
      additionalResults.push_back(extractOp.getPosition());
      additionalResultTypes.push_back(extractOp.getPosition().getType());
1360 | } |
1361 | Location loc = extractOp.getLoc(); |
1362 | SmallVector<size_t> newRetIndices; |
1363 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1364 | rewriter, warpOp, additionalResults, additionalResultTypes, |
1365 | newRetIndices); |
1366 | rewriter.setInsertionPointAfter(newWarpOp); |
1367 | Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
1368 | |
1369 | // 0d extract: The new warp op broadcasts the source vector to all lanes. |
1370 | // All lanes extract the scalar. |
1371 | if (is0dOrVec1Extract) { |
      Value newExtract;
1373 | if (extractSrcType.getRank() == 1) { |
1374 | newExtract = rewriter.create<vector::ExtractElementOp>( |
1375 | loc, distributedVec, |
1376 | rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
1377 | |
1378 | } else { |
1379 | newExtract = |
1380 | rewriter.create<vector::ExtractElementOp>(loc, distributedVec); |
1381 | } |
1382 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
1383 | newExtract); |
1384 | return success(); |
1385 | } |
1386 | |
1387 | // 1d extract: Distribute the source vector. One lane extracts and shuffles |
1388 | // the value to all other lanes. |
1389 | int64_t elementsPerLane = distributedVecType.getShape()[0]; |
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1391 | // tid of extracting thread: pos / elementsPerLane |
    Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.floorDiv(elementsPerLane),
1394 | newWarpOp->getResult(newRetIndices[1])); |
1395 | // Extract at position: pos % elementsPerLane |
1396 | Value pos = |
1397 | elementsPerLane == 1 |
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1399 | : rewriter |
1400 | .create<affine::AffineApplyOp>( |
1401 | loc, sym0 % elementsPerLane, |
1402 | newWarpOp->getResult(newRetIndices[1])) |
1403 | .getResult(); |
    Value extracted =
1405 | rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos); |
1406 | |
1407 | // Shuffle the extracted value to all lanes. |
1408 | Value shuffled = warpShuffleFromIdxFn( |
1409 | loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize()); |
1410 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled); |
1411 | return success(); |
1412 | } |
1413 | |
1414 | private: |
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1416 | }; |
1417 | |
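/// Pattern to move out or distribute vector.insertelement ops. When the
/// destination vector is distributed, only the lane that owns the insertion
/// position performs the insert inside an scf.if; all other lanes forward
/// their slice unchanged. An illustrative rewrite (SSA names hypothetical):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
///   ...
///   %i = vector.insertelement %f, %v[%pos : index] : vector<64xf32>
///   vector.yield %i : vector<64xf32>
/// }
/// ```
/// becomes (with 2 elements per lane)
/// ```
/// %r0 = vector.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<2xf32>, f32, index) {
///   ...
///   vector.yield %v, %f, %pos : vector<64xf32>, f32, index
/// }
/// %lane = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%r0#2]
/// %off = affine.apply affine_map<()[s0] -> (s0 mod 2)>()[%r0#2]
/// %cmp = arith.cmpi eq, %laneid, %lane
/// %r = scf.if %cmp -> (vector<2xf32>) {
///   %i = vector.insertelement %r0#1, %r0#0[%off : index] : vector<2xf32>
///   scf.yield %i : vector<2xf32>
/// } else {
///   scf.yield %r0#0 : vector<2xf32>
/// }
/// ```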
1418 | struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> { |
1419 | using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; |
1420 | |
1421 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1422 | PatternRewriter &rewriter) const override { |
1423 | OpOperand *operand = |
1424 | getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>); |
1425 | if (!operand) |
1426 | return failure(); |
1427 | unsigned int operandNumber = operand->getOperandNumber(); |
1428 | auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>(); |
1429 | VectorType vecType = insertOp.getDestVectorType(); |
1430 | VectorType distrType = |
1431 | cast<VectorType>(warpOp.getResult(operandNumber).getType()); |
1432 | bool hasPos = static_cast<bool>(insertOp.getPosition()); |
1433 | |
1434 | // Yield destination vector, source scalar and position from warp op. |
1435 | SmallVector<Value> additionalResults{insertOp.getDest(), |
1436 | insertOp.getSource()}; |
1437 | SmallVector<Type> additionalResultTypes{distrType, |
1438 | insertOp.getSource().getType()}; |
1439 | if (hasPos) { |
      additionalResults.push_back(insertOp.getPosition());
      additionalResultTypes.push_back(insertOp.getPosition().getType());
1442 | } |
1443 | Location loc = insertOp.getLoc(); |
1444 | SmallVector<size_t> newRetIndices; |
1445 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1446 | rewriter, warpOp, additionalResults, additionalResultTypes, |
1447 | newRetIndices); |
1448 | rewriter.setInsertionPointAfter(newWarpOp); |
1449 | Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
1450 | Value newSource = newWarpOp->getResult(newRetIndices[1]); |
1451 | Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value(); |
1452 | rewriter.setInsertionPointAfter(newWarpOp); |
1453 | |
1454 | if (vecType == distrType) { |
      // Broadcast: Simply move the vector.insertelement op out.
1456 | Value newInsert = rewriter.create<vector::InsertElementOp>( |
1457 | loc, newSource, distributedVec, newPos); |
1458 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
1459 | newInsert); |
1460 | return success(); |
1461 | } |
1462 | |
1463 | // This is a distribution. Only one lane should insert. |
1464 | int64_t elementsPerLane = distrType.getShape()[0]; |
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting lane: pos / elementsPerLane
    Value insertingLane = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.floorDiv(elementsPerLane), newPos);
1469 | // Insert position: pos % elementsPerLane |
1470 | Value pos = |
1471 | elementsPerLane == 1 |
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1473 | : rewriter |
1474 | .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane, |
1475 | newPos) |
1476 | .getResult(); |
1477 | Value isInsertingLane = rewriter.create<arith::CmpIOp>( |
1478 | loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); |
1479 | Value newResult = |
1480 | rewriter |
1481 | .create<scf::IfOp>( |
1482 | loc, isInsertingLane, |
1483 | /*thenBuilder=*/ |
1484 | [&](OpBuilder &builder, Location loc) { |
1485 | Value newInsert = builder.create<vector::InsertElementOp>( |
1486 | loc, newSource, distributedVec, pos); |
1487 | builder.create<scf::YieldOp>(loc, newInsert); |
1488 | }, |
1489 | /*elseBuilder=*/ |
1490 | [&](OpBuilder &builder, Location loc) { |
1491 | builder.create<scf::YieldOp>(loc, distributedVec); |
1492 | }) |
1493 | .getResult(0); |
1494 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); |
1495 | return success(); |
1496 | } |
1497 | }; |
1498 | |
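/// Pattern to move out or distribute vector.insert ops. With a distributed
/// destination, either every lane inserts a proportionally smaller slice, or,
/// when the insertion position pins down the distributed dimension, a single
/// lane inserts the whole source inside an scf.if. An illustrative rewrite of
/// the single-lane case (SSA names hypothetical):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
///   ...
///   %i = vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
///   vector.yield %i : vector<128x96xf32>
/// }
/// ```
/// becomes (dim 0 distributed, 4 rows per lane, so row 2 belongs to lane 0)
/// ```
/// %r0 = vector.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<96xf32>, vector<4x96xf32>) {
///   ...
///   vector.yield %s, %d : vector<96xf32>, vector<128x96xf32>
/// }
/// %cmp = arith.cmpi eq, %laneid, %c0
/// %r = scf.if %cmp -> (vector<4x96xf32>) {
///   %i = vector.insert %r0#0, %r0#1 [2]
///       : vector<96xf32> into vector<4x96xf32>
///   scf.yield %i : vector<4x96xf32>
/// } else {
///   scf.yield %r0#1 : vector<4x96xf32>
/// }
/// ```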
1499 | struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> { |
1500 | using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; |
1501 | |
1502 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1503 | PatternRewriter &rewriter) const override { |
1504 | OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>); |
1505 | if (!operand) |
1506 | return failure(); |
1507 | unsigned int operandNumber = operand->getOperandNumber(); |
1508 | auto insertOp = operand->get().getDefiningOp<vector::InsertOp>(); |
1509 | Location loc = insertOp.getLoc(); |
1510 | |
1511 | // "vector.insert %v, %v[] : ..." can be canonicalized to %v. |
1512 | if (insertOp.getNumIndices() == 0) |
1513 | return failure(); |
1514 | |
1515 | // Rewrite vector.insert with 1d dest to vector.insertelement. |
1516 | if (insertOp.getDestVectorType().getRank() == 1) { |
1517 | if (insertOp.hasDynamicPosition()) |
        // TODO: Dynamic position not supported yet.
1519 | return failure(); |
1520 | |
      assert(insertOp.getNumIndices() == 1 && "expected 1 index");
1522 | int64_t pos = insertOp.getStaticPosition()[0]; |
1523 | rewriter.setInsertionPoint(insertOp); |
1524 | rewriter.replaceOpWithNewOp<vector::InsertElementOp>( |
1525 | insertOp, insertOp.getSource(), insertOp.getDest(), |
1526 | rewriter.create<arith::ConstantIndexOp>(loc, pos)); |
1527 | return success(); |
1528 | } |
1529 | |
1530 | if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { |
1531 | // There is no distribution, this is a broadcast. Simply move the insert |
1532 | // out of the warp op. |
1533 | SmallVector<size_t> newRetIndices; |
1534 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1535 | rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()}, |
1536 | {insertOp.getSourceType(), insertOp.getDestVectorType()}, |
1537 | newRetIndices); |
1538 | rewriter.setInsertionPointAfter(newWarpOp); |
1539 | Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); |
1540 | Value distributedDest = newWarpOp->getResult(newRetIndices[1]); |
1541 | Value newResult = rewriter.create<vector::InsertOp>( |
1542 | loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); |
1543 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
1544 | newResult); |
1545 | return success(); |
1546 | } |
1547 | |
1548 | // Find the distributed dimension. There should be exactly one. |
1549 | auto distrDestType = |
1550 | cast<VectorType>(warpOp.getResult(operandNumber).getType()); |
1551 | auto yieldedType = cast<VectorType>(operand->get().getType()); |
1552 | int64_t distrDestDim = -1; |
1553 | for (int64_t i = 0; i < yieldedType.getRank(); ++i) { |
1554 | if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) { |
1555 | // Keep this assert here in case WarpExecuteOnLane0Op gets extended to |
1556 | // support distributing multiple dimensions in the future. |
        assert(distrDestDim == -1 && "found multiple distributed dims");
1558 | distrDestDim = i; |
1559 | } |
1560 | } |
    assert(distrDestDim != -1 && "could not find distributed dimension");
1562 | |
1563 | // Compute the distributed source vector type. |
1564 | VectorType srcVecType = cast<VectorType>(insertOp.getSourceType()); |
1565 | SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(), |
1566 | srcVecType.getShape().end()); |
1567 | // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32> |
1568 | // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will |
1569 | // insert a smaller vector<3xf32>. |
1570 | // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that |
1571 | // case, one lane will insert the source vector<96xf32>. The other |
1572 | // lanes will not do anything. |
1573 | int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices(); |
1574 | if (distrSrcDim >= 0) |
1575 | distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim); |
1576 | auto distrSrcType = |
1577 | VectorType::get(distrSrcShape, distrDestType.getElementType()); |
1578 | |
1579 | // Yield source and dest vectors from warp op. |
1580 | SmallVector<size_t> newRetIndices; |
1581 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1582 | rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()}, |
1583 | {distrSrcType, distrDestType}, newRetIndices); |
1584 | rewriter.setInsertionPointAfter(newWarpOp); |
1585 | Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); |
1586 | Value distributedDest = newWarpOp->getResult(newRetIndices[1]); |
1587 | |
1588 | // Insert into the distributed vector. |
1589 | Value newResult; |
1590 | if (distrSrcDim >= 0) { |
1591 | // Every lane inserts a small piece. |
1592 | newResult = rewriter.create<vector::InsertOp>( |
1593 | loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); |
1594 | } else { |
1595 | // One lane inserts the entire source vector. |
1596 | int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); |
1597 | SmallVector<OpFoldResult> pos = insertOp.getMixedPosition(); |
      SmallVector<int64_t> newPos = getAsIntegers(pos);
1599 | // tid of inserting lane: pos / elementsPerLane |
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
1602 | Value isInsertingLane = rewriter.create<arith::CmpIOp>( |
1603 | loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); |
1604 | // Insert position: pos % elementsPerLane |
1605 | newPos[distrDestDim] %= elementsPerLane; |
1606 | auto insertingBuilder = [&](OpBuilder &builder, Location loc) { |
1607 | Value newInsert = builder.create<vector::InsertOp>( |
1608 | loc, distributedSrc, distributedDest, newPos); |
1609 | builder.create<scf::YieldOp>(loc, newInsert); |
1610 | }; |
1611 | auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) { |
1612 | builder.create<scf::YieldOp>(loc, distributedDest); |
1613 | }; |
1614 | newResult = rewriter |
1615 | .create<scf::IfOp>(loc, isInsertingLane, |
1616 | /*thenBuilder=*/insertingBuilder, |
1617 | /*elseBuilder=*/nonInsertingBuilder) |
1618 | .getResult(0); |
1619 | } |
1620 | |
1621 | rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); |
1622 | return success(); |
1623 | } |
1624 | }; |
1625 | |
1626 | /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if |
1627 | /// the scf.ForOp is the last operation in the region so that it doesn't change |
1628 | /// the order of execution. This creates a new scf.for region after the |
1629 | /// WarpExecuteOnLane0Op. The new scf.for region will contain a new |
1630 | /// WarpExecuteOnLane0Op region. Example: |
1631 | /// ``` |
1632 | /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) { |
1633 | /// ... |
1634 | /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v) |
1635 | /// -> (vector<128xf32>) { |
1636 | /// ... |
1637 | /// scf.yield %r : vector<128xf32> |
1638 | /// } |
1639 | /// vector.yield %v1 : vector<128xf32> |
1640 | /// } |
1641 | /// ``` |
/// To:
/// ```
1643 | /// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) { |
1644 | /// ... |
1645 | /// vector.yield %v : vector<128xf32> |
1646 | /// } |
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1648 | /// -> (vector<4xf32>) { |
1649 | /// %iw = vector.warp_execute_on_lane_0(%laneid) |
1650 | /// args(%varg : vector<4xf32>) -> (vector<4xf32>) { |
1651 | /// ^bb0(%arg: vector<128xf32>): |
1652 | /// ... |
1653 | /// vector.yield %ir : vector<128xf32> |
1654 | /// } |
1655 | /// scf.yield %iw : vector<4xf32> |
1656 | /// } |
1657 | /// ``` |
1658 | struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> { |
1659 | |
1660 | WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) |
1661 | : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b), |
1662 | distributionMapFn(std::move(fn)) {} |
1663 | using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; |
1664 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1665 | PatternRewriter &rewriter) const override { |
1666 | auto yield = cast<vector::YieldOp>( |
1667 | warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
1668 | // Only pick up forOp if it is the last op in the region. |
1669 | Operation *lastNode = yield->getPrevNode(); |
1670 | auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode); |
1671 | if (!forOp) |
1672 | return failure(); |
    // Collect values that are defined inside the warp op region but used by
    // the forOp. Those values need to be returned by the original warp op
    // and passed on to the new op.
1676 | llvm::SmallSetVector<Value, 32> escapingValues; |
1677 | SmallVector<Type> inputTypes; |
1678 | SmallVector<Type> distTypes; |
1679 | mlir::visitUsedValuesDefinedAbove( |
1680 | forOp.getBodyRegion(), [&](OpOperand *operand) { |
1681 | Operation *parent = operand->get().getParentRegion()->getParentOp(); |
1682 | if (warpOp->isAncestor(parent)) { |
            if (!escapingValues.insert(operand->get()))
1684 | return; |
1685 | Type distType = operand->get().getType(); |
1686 | if (auto vecType = dyn_cast<VectorType>(distType)) { |
1687 | AffineMap map = distributionMapFn(operand->get()); |
1688 | distType = getDistributedType(vecType, map, warpOp.getWarpSize()); |
1689 | } |
            inputTypes.push_back(operand->get().getType());
            distTypes.push_back(distType);
1692 | } |
1693 | }); |
1694 | |
1695 | SmallVector<size_t> newRetIndices; |
1696 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1697 | rewriter, warpOp, escapingValues.getArrayRef(), distTypes, |
1698 | newRetIndices); |
1699 | yield = cast<vector::YieldOp>( |
1700 | newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
1701 | |
1702 | SmallVector<Value> newOperands; |
1703 | SmallVector<unsigned> resultIdx; |
1704 | // Collect all the outputs coming from the forOp. |
1705 | for (OpOperand &yieldOperand : yield->getOpOperands()) { |
1706 | if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) |
1707 | continue; |
1708 | auto forResult = cast<OpResult>(yieldOperand.get()); |
1709 | newOperands.push_back( |
1710 | newWarpOp.getResult(yieldOperand.getOperandNumber())); |
1711 | yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); |
1712 | resultIdx.push_back(yieldOperand.getOperandNumber()); |
1713 | } |
1714 | |
1715 | OpBuilder::InsertionGuard g(rewriter); |
1716 | rewriter.setInsertionPointAfter(newWarpOp); |
1717 | |
1718 | // Create a new for op outside the region with a WarpExecuteOnLane0Op region |
1719 | // inside. |
1720 | auto newForOp = rewriter.create<scf::ForOp>( |
1721 | forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
1722 | forOp.getStep(), newOperands); |
1723 | rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); |
1724 | |
1725 | SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(), |
1726 | newForOp.getRegionIterArgs().end()); |
1727 | SmallVector<Type> warpInputType(forOp.getResultTypes().begin(), |
1728 | forOp.getResultTypes().end()); |
1729 | llvm::SmallDenseMap<Value, int64_t> argIndexMapping; |
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      warpInput.push_back(newWarpOp.getResult(retIdx));
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
1734 | } |
1735 | auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>( |
1736 | newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), |
1737 | newWarpOp.getWarpSize(), warpInput, warpInputType); |
1738 | |
1739 | SmallVector<Value> argMapping; |
    argMapping.push_back(newForOp.getInductionVar());
1741 | for (Value args : innerWarp.getBody()->getArguments()) { |
1742 | argMapping.push_back(args); |
1743 | } |
1744 | argMapping.resize(forOp.getBody()->getNumArguments()); |
1745 | SmallVector<Value> yieldOperands; |
1746 | for (Value operand : forOp.getBody()->getTerminator()->getOperands()) |
1747 | yieldOperands.push_back(operand); |
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1750 | rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end()); |
1751 | rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands); |
1752 | rewriter.setInsertionPointAfter(innerWarp); |
1753 | if (!innerWarp.getResults().empty()) |
1754 | rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults()); |
    rewriter.eraseOp(forOp);
1756 | // Replace the warpOp result coming from the original ForOp. |
    for (const auto &res : llvm::enumerate(resultIdx)) {
1758 | rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), |
1759 | newForOp.getResult(res.index())); |
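      // The `+ 3` skips the scf.for control operands (lower bound, upper
      // bound, step) to reach the matching init arg of the new loop.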
1760 | newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); |
1761 | } |
1762 | newForOp.walk([&](Operation *op) { |
1763 | for (OpOperand &operand : op->getOpOperands()) { |
        auto it = argIndexMapping.find(operand.get());
1765 | if (it == argIndexMapping.end()) |
1766 | continue; |
1767 | operand.set(innerWarp.getBodyRegion().getArgument(it->second)); |
1768 | } |
1769 | }); |
1770 | |
1771 | // Finally, hoist out any now uniform code from the inner warp op. |
1772 | mlir::vector::moveScalarUniformCode(innerWarp); |
1773 | return success(); |
1774 | } |
1775 | |
1776 | private: |
1777 | DistributionMapFn distributionMapFn; |
1778 | }; |
1779 | |
1780 | /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op. |
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size. E.g.:
1783 | /// ``` |
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
1795 | /// } |
1796 | /// %a = vector.extract %0[0] : f32 from vector<1xf32> |
1797 | /// %r = ("warp.reduction %a") |
1798 | /// ``` |
1799 | struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> { |
1800 | WarpOpReduction(MLIRContext *context, |
1801 | DistributedReductionFn distributedReductionFn, |
1802 | PatternBenefit benefit = 1) |
1803 | : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit), |
1804 | distributedReductionFn(std::move(distributedReductionFn)) {} |
1805 | |
1806 | LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
1807 | PatternRewriter &rewriter) const override { |
1808 | OpOperand *yieldOperand = |
1809 | getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>); |
1810 | if (!yieldOperand) |
1811 | return failure(); |
1812 | |
1813 | auto reductionOp = |
1814 | cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp()); |
1815 | auto vectorType = cast<VectorType>(reductionOp.getVector().getType()); |
1816 | // Only rank 1 vectors supported. |
1817 | if (vectorType.getRank() != 1) |
1818 | return rewriter.notifyMatchFailure( |
1819 | warpOp, "Only rank 1 reductions can be distributed." ); |
1820 | // Only warp_size-sized vectors supported. |
1821 | if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0) |
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
1824 | if (!reductionOp.getType().isIntOrFloat()) |
1825 | return rewriter.notifyMatchFailure( |
1826 | warpOp, "Reduction distribution currently only supports floats and " |
1827 | "integer types." ); |
1828 | |
1829 | int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); |
1830 | // Return vector that will be reduced from the WarpExecuteOnLane0Op. |
1831 | unsigned operandIndex = yieldOperand->getOperandNumber(); |
1832 | SmallVector<Value> yieldValues = {reductionOp.getVector()}; |
1833 | SmallVector<Type> retTypes = { |
1834 | VectorType::get({numElements}, reductionOp.getType())}; |
1835 | if (reductionOp.getAcc()) { |
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
1838 | } |
1839 | SmallVector<size_t> newRetIndices; |
1840 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1841 | rewriter, warpOp, yieldValues, retTypes, newRetIndices); |
1842 | rewriter.setInsertionPointAfter(newWarpOp); |
1843 | |
1844 | // Obtain data to reduce for a single lane. |
1845 | Value laneValVec = newWarpOp.getResult(newRetIndices[0]); |
1846 | // Distribute and reduce across threads. |
1847 | Value fullReduce = |
1848 | distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec, |
1849 | reductionOp.getKind(), newWarpOp.getWarpSize()); |
1850 | if (reductionOp.getAcc()) { |
1851 | fullReduce = vector::makeArithReduction( |
1852 | rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce, |
1853 | newWarpOp.getResult(newRetIndices[1])); |
1854 | } |
1855 | rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce); |
1856 | return success(); |
1857 | } |
1858 | |
1859 | private: |
1860 | DistributedReductionFn distributedReductionFn; |
1861 | }; |
1862 | |
1863 | } // namespace |
1864 | |
1865 | void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( |
1866 | RewritePatternSet &patterns, |
1867 | const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) { |
  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1869 | } |
1870 | |
1871 | void mlir::vector::populateDistributeTransferWriteOpPatterns( |
1872 | RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, |
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
1876 | } |
1877 | |
1878 | void mlir::vector::populatePropagateWarpVectorDistributionPatterns( |
1879 | RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, |
1880 | const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit, |
1881 | PatternBenefit readBenefit) { |
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1883 | patterns |
1884 | .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, |
1885 | WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant, |
1886 | WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>( |
          patterns.getContext(), benefit);
  patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                     warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
1892 | } |
1893 | |
1894 | void mlir::vector::populateDistributeReduction( |
1895 | RewritePatternSet &patterns, |
1896 | const DistributedReductionFn &distributedReductionFn, |
1897 | PatternBenefit benefit) { |
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
1900 | } |
1901 | |
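/// Hoist ops out of the warp op body when they have no side effects, produce
/// no vector results, and all of their operands are defined outside of the
/// region (or are themselves hoisted). Such scalar code is uniform across all
/// lanes, e.g.:
/// ```
/// vector.warp_execute_on_lane_0(%laneid)[32] {
///   %a = arith.addi %x, %y : index  // hoisted before the warp op
///   "some_op_with_side_effects"(%a) : (index) -> ()
/// }
/// ```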
1902 | void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { |
1903 | Block *body = warpOp.getBody(); |
1904 | |
1905 | // Keep track of the ops we want to hoist. |
1906 | llvm::SmallSetVector<Operation *, 8> opsToMove; |
1907 | |
1908 | // Helper to check if a value is or will be defined outside of the region. |
1909 | auto isDefinedOutsideOfBody = [&](Value value) { |
1910 | auto *definingOp = value.getDefiningOp(); |
1911 | return (definingOp && opsToMove.count(definingOp)) || |
1912 | warpOp.isDefinedOutsideOfRegion(value); |
1913 | }; |
1914 | |
1915 | // Do not use walk here, as we do not want to go into nested regions and hoist |
1916 | // operations from there. |
1917 | for (auto &op : body->without_terminator()) { |
1918 | bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) { |
1919 | return isa<VectorType>(result.getType()); |
1920 | }); |
1921 | if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody)) |
1922 | opsToMove.insert(&op); |
1923 | } |
1924 | |
1925 | // Move all the ops marked as uniform outside of the region. |
1926 | for (Operation *op : opsToMove) |
1927 | op->moveBefore(warpOp); |
1928 | } |
1929 | |