1 | //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements lowering of vector transfer operations to SCF. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include <numeric> |
14 | #include <optional> |
15 | #include <type_traits> |
16 | |
17 | #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
18 | |
19 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
20 | #include "mlir/Dialect/Arith/IR/Arith.h" |
21 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
22 | #include "mlir/Dialect/SCF/IR/SCF.h" |
23 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
24 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
25 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
26 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
27 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
28 | #include "mlir/IR/Builders.h" |
29 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
30 | #include "mlir/Pass/Pass.h" |
31 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
32 | #include "mlir/Transforms/Passes.h" |
33 | |
34 | namespace mlir { |
35 | #define GEN_PASS_DEF_CONVERTVECTORTOSCF |
36 | #include "mlir/Conversion/Passes.h.inc" |
37 | } // namespace mlir |
38 | |
39 | using namespace mlir; |
40 | using vector::TransferReadOp; |
41 | using vector::TransferWriteOp; |
42 | |
43 | namespace { |
44 | |
45 | /// Attribute name used for labeling transfer ops during progressive lowering. |
46 | static const char kPassLabel[] = "__vector_to_scf_lowering__"; |
47 | |
48 | /// Return true if this transfer op operates on a source tensor. |
49 | static bool isTensorOp(VectorTransferOpInterface xferOp) { |
50 | if (isa<RankedTensorType>(xferOp.getShapedType())) { |
51 | if (isa<vector::TransferWriteOp>(xferOp)) { |
52 | // TransferWriteOps on tensors have a result. |
53 | assert(xferOp->getNumResults() > 0); |
54 | } |
55 | return true; |
56 | } |
57 | return false; |
58 | } |
59 | |
60 | /// Patterns that inherit from this struct have access to |
61 | /// VectorTransferToSCFOptions. |
62 | template <typename OpTy> |
63 | struct VectorToSCFPattern : public OpRewritePattern<OpTy> { |
64 | explicit VectorToSCFPattern(MLIRContext *context, |
65 | VectorTransferToSCFOptions opt) |
66 | : OpRewritePattern<OpTy>(context), options(opt) {} |
67 | |
68 | LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp, |
69 | PatternRewriter &rewriter) const { |
70 | if (isTensorOp(xferOp) && !options.lowerTensors) { |
71 | return rewriter.notifyMatchFailure( |
72 | xferOp, "lowering tensor transfers is disabled"); |
73 | } |
74 | return success(); |
75 | } |
76 | |
77 | VectorTransferToSCFOptions options; |
78 | }; |
79 | |
80 | /// Given a vector transfer op, calculate which dimension of the `source` |
81 | /// memref should be unpacked in the next application of TransferOpConversion. |
82 | /// A return value of std::nullopt indicates a broadcast. |
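/// For example (illustrative): with permutation map `(d0, d1, d2) -> (d2, d1)`
/// the unpacked dim is 2; with `(d0, d1) -> (0, d1)` (leading broadcast) the
/// result is std::nullopt.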
83 | template <typename OpTy> |
84 | static std::optional<int64_t> unpackedDim(OpTy xferOp) { |
85 | // TODO: support 0-d corner case. |
86 | assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); |
87 | auto map = xferOp.getPermutationMap(); |
88 | if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { |
89 | return expr.getPosition(); |
90 | } |
91 | assert(xferOp.isBroadcastDim(0) && |
92 | "Expected AffineDimExpr or AffineConstantExpr"); |
93 | return std::nullopt; |
94 | } |
95 | |
96 | /// Compute the permutation map for the new (N-1)-D vector transfer op. This |
97 | /// map is identical to the current permutation map, but the first result is |
98 | /// omitted. |
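/// For example (illustrative): `(d0, d1, d2) -> (d2, d1)` becomes
/// `(d0, d1, d2) -> (d1)`.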
99 | template <typename OpTy> |
100 | static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) { |
101 | // TODO: support 0-d corner case. |
102 | assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); |
103 | auto map = xferOp.getPermutationMap(); |
104 | return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), |
105 | b.getContext()); |
106 | } |
107 | |
108 | /// Calculate the indices for the new vector transfer op. |
109 | /// |
110 | /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ... |
111 | ///      --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32> |
112 | /// ^^^^^^ |
113 | /// `iv` is the iteration variable of the (new) surrounding loop. |
114 | template <typename OpTy> |
115 | static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv, |
116 | SmallVector<Value, 8> &indices) { |
117 | typename OpTy::Adaptor adaptor(xferOp); |
118 | // Corresponding memref dim of the vector dim that is unpacked. |
119 | auto dim = unpackedDim(xferOp); |
120 | auto prevIndices = adaptor.getIndices(); |
121 | indices.append(prevIndices.begin(), prevIndices.end()); |
122 | |
123 | Location loc = xferOp.getLoc(); |
124 | bool isBroadcast = !dim.has_value(); |
125 | if (!isBroadcast) { |
126 | AffineExpr d0, d1; |
127 | bindDims(xferOp.getContext(), d0, d1); |
128 | Value offset = adaptor.getIndices()[*dim]; |
129 | indices[*dim] = |
130 | affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); |
131 | } |
132 | } |
133 | |
134 | static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, |
135 | Value value) { |
136 | if (hasRetVal) { |
137 | assert(value && "Expected non-empty value"); |
138 | b.create<scf::YieldOp>(loc, value); |
139 | } else { |
140 | b.create<scf::YieldOp>(loc); |
141 | } |
142 | } |
143 | |
144 | /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask |
145 | /// is set to true. No such check is generated under the following circumstances: |
146 | /// * xferOp does not have a mask. |
147 | /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is |
148 | /// computed and attached to the new transfer op in the pattern.) |
149 | /// * The to-be-unpacked dim of xferOp is a broadcast. |
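/// Otherwise, the generated check is a single element extraction from the 1-D
/// mask, roughly (illustrative):
/// `%in_bounds = vector.extractelement %mask[%iv : index] : vector<5xi1>`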
150 | template <typename OpTy> |
151 | static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { |
152 | if (!xferOp.getMask()) |
153 | return Value(); |
154 | if (xferOp.getMaskType().getRank() != 1) |
155 | return Value(); |
156 | if (xferOp.isBroadcastDim(0)) |
157 | return Value(); |
158 | |
159 | Location loc = xferOp.getLoc(); |
160 | return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv); |
161 | } |
162 | |
163 | /// Helper function for TransferOpConversion and TransferOp1dConversion. |
164 | /// Generate an in-bounds check if the transfer op may go out-of-bounds on the |
165 | /// specified dimension `dim` with the loop iteration variable `iv`. |
166 | /// E.g., when unpacking dimension 0 from: |
167 | /// ``` |
168 | /// %vec = vector.transfer_read %A[%a, %b] %cst |
169 | /// : vector<5x4xf32>, memref<?x?xf32> |
170 | /// ``` |
171 | /// An if check similar to this will be generated inside the loop: |
172 | /// ``` |
173 | /// %d = memref.dim %A, %c0 : memref<?x?xf32> |
174 | /// if (%a + iv < %d) { |
175 | /// (in-bounds case) |
176 | /// } else { |
177 | /// (out-of-bounds case) |
178 | /// } |
179 | /// ``` |
180 | /// |
181 | /// If the transfer is 1D and has a mask, this function generates a more complex |
182 | /// check that also accounts for potentially masked-out elements. |
183 | /// |
184 | /// This function variant returns the value returned by `inBoundsCase` or |
185 | /// `outOfBoundsCase`. The MLIR type of the return value must be specified in |
186 | /// `resultTypes`. |
187 | template <typename OpTy> |
188 | static Value generateInBoundsCheck( |
189 | OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, |
190 | TypeRange resultTypes, |
191 | function_ref<Value(OpBuilder &, Location)> inBoundsCase, |
192 | function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) { |
193 | bool hasRetVal = !resultTypes.empty(); |
194 | Value cond; // Condition to be built... |
195 | |
196 | // Condition check 1: Access in-bounds? |
197 | bool isBroadcast = !dim; // No in-bounds check for broadcasts. |
198 | Location loc = xferOp.getLoc(); |
199 | ImplicitLocOpBuilder lb(xferOp.getLoc(), b); |
200 | if (!xferOp.isDimInBounds(0) && !isBroadcast) { |
201 |     Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim); |
202 | AffineExpr d0, d1; |
203 | bindDims(xferOp.getContext(), d0, d1); |
204 | Value base = xferOp.getIndices()[*dim]; |
205 | Value memrefIdx = |
206 | affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); |
207 | cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim, |
208 | memrefIdx); |
209 | } |
210 | |
211 | // Condition check 2: Masked in? |
212 | if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { |
213 | if (cond) |
214 | cond = lb.create<arith::AndIOp>(cond, maskCond); |
215 | else |
216 | cond = maskCond; |
217 | } |
218 | |
219 | // If the condition is non-empty, generate an SCF::IfOp. |
220 | if (cond) { |
221 | auto check = lb.create<scf::IfOp>( |
222 | cond, |
223 | /*thenBuilder=*/ |
224 | [&](OpBuilder &b, Location loc) { |
225 | maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc)); |
226 | }, |
227 | /*elseBuilder=*/ |
228 | [&](OpBuilder &b, Location loc) { |
229 | if (outOfBoundsCase) { |
230 | maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc)); |
231 | } else { |
232 | b.create<scf::YieldOp>(loc); |
233 | } |
234 | }); |
235 | |
236 | return hasRetVal ? check.getResult(0) : Value(); |
237 | } |
238 | |
239 | // Condition is empty, no need for an SCF::IfOp. |
240 | return inBoundsCase(b, loc); |
241 | } |
242 | |
243 | /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have |
244 | /// a return value. Consequently, this function does not have a return value. |
245 | template <typename OpTy> |
246 | static void generateInBoundsCheck( |
247 | OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, |
248 | function_ref<void(OpBuilder &, Location)> inBoundsCase, |
249 | function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) { |
250 | generateInBoundsCheck( |
251 | b, xferOp, iv, dim, /*resultTypes=*/TypeRange(), |
252 | /*inBoundsCase=*/ |
253 | [&](OpBuilder &b, Location loc) { |
254 | inBoundsCase(b, loc); |
255 | return Value(); |
256 | }, |
257 | /*outOfBoundsCase=*/ |
258 | [&](OpBuilder &b, Location loc) { |
259 | if (outOfBoundsCase) |
260 | outOfBoundsCase(b, loc); |
261 | return Value(); |
262 | }); |
263 | } |
264 | |
265 | /// Given an ArrayAttr, return a copy where the first element is dropped. |
266 | static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) { |
267 | if (!attr) |
268 | return attr; |
269 | return ArrayAttr::get(b.getContext(), attr.getValue().drop_front()); |
270 | } |
271 | |
272 | /// Add the pass label to a vector transfer op if its rank is greater than the |
273 | /// target rank. |
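/// E.g., with a target rank of 1, a newly created rank-2 transfer op is
/// labeled (and will be unpacked further), whereas a rank-1 op is not.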
274 | template <typename OpTy> |
275 | static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp, |
276 | unsigned targetRank) { |
277 | if (newXferOp.getVectorType().getRank() > targetRank) |
278 | newXferOp->setAttr(kPassLabel, b.getUnitAttr()); |
279 | } |
280 | |
281 | namespace lowering_n_d { |
282 | |
283 | /// Helper data structure for data and mask buffers. |
284 | struct BufferAllocs { |
285 | Value dataBuffer; |
286 | Value maskBuffer; |
287 | }; |
288 | |
289 | // TODO: Parallelism and threadlocal considerations with a ParallelScope trait. |
290 | static Operation *getAutomaticAllocationScope(Operation *op) { |
291 | Operation *scope = |
292 | op->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); |
293 | assert(scope && "Expected op to be inside automatic allocation scope"); |
294 | return scope; |
295 | } |
296 | |
297 | /// Allocate temporary buffers for data (vector) and mask (if present). |
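/// Both buffers are rank-0 memrefs of the respective vector type, e.g.
/// (illustrative) for a transfer of vector<5x4xf32> with a vector<5x4xi1> mask:
/// ```
/// %data = memref.alloca() : memref<vector<5x4xf32>>
/// %mask = memref.alloca() : memref<vector<5x4xi1>>
/// ```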
298 | template <typename OpTy> |
299 | static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { |
300 | Location loc = xferOp.getLoc(); |
301 | OpBuilder::InsertionGuard guard(b); |
302 | Operation *scope = getAutomaticAllocationScope(xferOp); |
303 | assert(scope->getNumRegions() == 1 && |
304 | "AutomaticAllocationScope with >1 regions"); |
305 |   b.setInsertionPointToStart(&scope->getRegion(0).front()); |
306 | |
307 | BufferAllocs result; |
308 | auto bufferType = MemRefType::get({}, xferOp.getVectorType()); |
309 | result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType); |
310 | |
311 | if (xferOp.getMask()) { |
312 | auto maskType = MemRefType::get({}, xferOp.getMask().getType()); |
313 | auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType); |
314 | b.setInsertionPoint(xferOp); |
315 | b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer); |
316 | result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange()); |
317 | } |
318 | |
319 | return result; |
320 | } |
321 | |
322 | /// Given a MemRefType with VectorType element type, unpack one dimension from |
323 | /// the VectorType into the MemRefType. |
324 | /// |
325 | /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> |
326 | static FailureOr<MemRefType> unpackOneDim(MemRefType type) { |
327 | auto vectorType = dyn_cast<VectorType>(type.getElementType()); |
328 | // Vectors with leading scalable dims are not supported. |
329 | // It may be possible to support these in future by using dynamic memref dims. |
330 | if (vectorType.getScalableDims().front()) |
331 | return failure(); |
332 | auto memrefShape = type.getShape(); |
333 | SmallVector<int64_t, 8> newMemrefShape; |
334 | newMemrefShape.append(memrefShape.begin(), memrefShape.end()); |
335 | newMemrefShape.push_back(vectorType.getDimSize(0)); |
336 | return MemRefType::get(newMemrefShape, |
337 | VectorType::Builder(vectorType).dropDim(0)); |
338 | } |
339 | |
340 | /// Given a transfer op, find the memref from which the mask is loaded. This |
341 | /// is similar to Strategy<TransferWriteOp>::getBuffer. |
342 | template <typename OpTy> |
343 | static Value getMaskBuffer(OpTy xferOp) { |
344 | assert(xferOp.getMask() && "Expected that transfer op has mask"); |
345 | auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>(); |
346 | assert(loadOp && "Expected transfer op mask produced by LoadOp"); |
347 | return loadOp.getMemRef(); |
348 | } |
349 | |
350 | /// Codegen strategy, depending on the operation. |
351 | template <typename OpTy> |
352 | struct Strategy; |
353 | |
354 | /// Codegen strategy for vector TransferReadOp. |
355 | template <> |
356 | struct Strategy<TransferReadOp> { |
357 | /// Find the StoreOp that is used for writing the current TransferReadOp's |
358 | /// result to the temporary buffer allocation. |
359 | static memref::StoreOp getStoreOp(TransferReadOp xferOp) { |
360 | assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp"); |
361 | auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner()); |
362 | assert(storeOp && "Expected TransferReadOp result used by StoreOp"); |
363 | return storeOp; |
364 | } |
365 | |
366 | /// Find the temporary buffer allocation. All labeled TransferReadOps are |
367 | /// used like this, where %buf is either the buffer allocation or a type cast |
368 | /// of the buffer allocation: |
369 | /// ``` |
370 | /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ... |
371 | /// memref.store %vec, %buf[...] ... |
372 | /// ``` |
373 | static Value getBuffer(TransferReadOp xferOp) { |
374 | return getStoreOp(xferOp).getMemRef(); |
375 | } |
376 | |
377 | /// Retrieve the indices of the current StoreOp that stores into the buffer. |
378 | static void getBufferIndices(TransferReadOp xferOp, |
379 | SmallVector<Value, 8> &indices) { |
380 | auto storeOp = getStoreOp(xferOp); |
381 | auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices(); |
382 | indices.append(prevIndices.begin(), prevIndices.end()); |
383 | } |
384 | |
385 | /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds |
386 | /// accesses on the to-be-unpacked dimension. |
387 | /// |
388 | /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration |
389 | /// variable `iv`. |
390 | /// 2. Store the result into the (already `vector.type_cast`ed) buffer. |
391 | /// |
392 | /// E.g.: |
393 | /// ``` |
394 | /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst |
395 | /// : memref<?x?x?xf32>, vector<4x3xf32> |
396 | /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>> |
397 | /// ``` |
398 | /// Is rewritten to: |
399 | /// ``` |
400 | /// %casted = vector.type_cast %buf |
401 | /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> |
402 | /// for %j = 0 to 4 { |
403 | /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst |
404 | /// : memref<?x?x?xf32>, vector<3xf32> |
405 | /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>> |
406 | /// } |
407 | /// ``` |
408 | /// |
409 | /// Note: The loop and type cast are generated in TransferOpConversion. |
410 | /// The original TransferReadOp and store op are deleted in `cleanup`. |
411 | /// Note: The `mask` operand is set in TransferOpConversion. |
412 | static TransferReadOp rewriteOp(OpBuilder &b, |
413 | VectorTransferToSCFOptions options, |
414 | TransferReadOp xferOp, Value buffer, Value iv, |
415 | ValueRange /*loopState*/) { |
416 | SmallVector<Value, 8> storeIndices; |
417 |     getBufferIndices(xferOp, storeIndices); |
418 | storeIndices.push_back(iv); |
419 | |
420 | SmallVector<Value, 8> xferIndices; |
421 | getXferIndices(b, xferOp, iv, xferIndices); |
422 | |
423 | Location loc = xferOp.getLoc(); |
424 | auto bufferType = dyn_cast<ShapedType>(buffer.getType()); |
425 | auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); |
426 | auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); |
427 | auto newXferOp = b.create<vector::TransferReadOp>( |
428 | loc, vecType, xferOp.getBase(), xferIndices, |
429 | AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), |
430 | xferOp.getPadding(), Value(), inBoundsAttr); |
431 | |
432 | maybeApplyPassLabel(b, newXferOp, options.targetRank); |
433 | |
434 | b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices); |
435 | return newXferOp; |
436 | } |
437 | |
438 | /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write |
439 | /// padding value to the temporary buffer. |
440 | static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp, |
441 | Value buffer, Value iv, |
442 | ValueRange /*loopState*/) { |
443 | SmallVector<Value, 8> storeIndices; |
444 |     getBufferIndices(xferOp, storeIndices); |
445 | storeIndices.push_back(iv); |
446 | |
447 | Location loc = xferOp.getLoc(); |
448 | auto bufferType = dyn_cast<ShapedType>(buffer.getType()); |
449 | auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); |
450 | auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding()); |
451 | b.create<memref::StoreOp>(loc, vec, buffer, storeIndices); |
452 | |
453 | return Value(); |
454 | } |
455 | |
456 | /// Cleanup after rewriting the op. |
457 | static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp, |
458 | scf::ForOp /*forOp*/) { |
459 |     rewriter.eraseOp(getStoreOp(xferOp)); |
460 |     rewriter.eraseOp(xferOp); |
461 | } |
462 | |
463 | /// Return the initial loop state for the generated scf.for loop. |
464 | static Value initialLoopState(TransferReadOp xferOp) { return Value(); } |
465 | }; |
466 | |
467 | /// Codegen strategy for vector TransferWriteOp. |
468 | template <> |
469 | struct Strategy<TransferWriteOp> { |
470 | /// Find the temporary buffer allocation. All labeled TransferWriteOps are |
471 | /// used like this, where %buf is either the buffer allocation or a type cast |
472 | /// of the buffer allocation: |
473 | /// ``` |
474 | /// %vec = memref.load %buf[...] ... |
475 | /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ... |
476 | /// ``` |
477 | static Value getBuffer(TransferWriteOp xferOp) { |
478 | auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>(); |
479 | assert(loadOp && "Expected transfer op vector produced by LoadOp"); |
480 | return loadOp.getMemRef(); |
481 | } |
482 | |
483 | /// Retrieve the indices of the current LoadOp that loads from the buffer. |
484 | static void getBufferIndices(TransferWriteOp xferOp, |
485 | SmallVector<Value, 8> &indices) { |
486 | auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>(); |
487 | auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices(); |
488 | indices.append(prevIndices.begin(), prevIndices.end()); |
489 | } |
490 | |
491 | /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds |
492 | /// accesses on the to-be-unpacked dimension. |
493 | /// |
494 | /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer, |
495 | /// using the loop iteration variable `iv`. |
496 | /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back |
497 | /// to memory. |
498 | /// |
499 | /// Note: For more details, see comments on Strategy<TransferReadOp>. |
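  ///
  /// E.g. (illustrative, mirroring the TransferReadOp example above):
  /// ```
  /// %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
  ///     : vector<3xf32>, memref<?x?x?xf32>
  /// ```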
500 | static TransferWriteOp rewriteOp(OpBuilder &b, |
501 | VectorTransferToSCFOptions options, |
502 | TransferWriteOp xferOp, Value buffer, |
503 | Value iv, ValueRange loopState) { |
504 | SmallVector<Value, 8> loadIndices; |
505 |     getBufferIndices(xferOp, loadIndices); |
506 | loadIndices.push_back(iv); |
507 | |
508 | SmallVector<Value, 8> xferIndices; |
509 | getXferIndices(b, xferOp, iv, xferIndices); |
510 | |
511 | Location loc = xferOp.getLoc(); |
512 | auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices); |
513 | auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); |
514 | auto source = loopState.empty() ? xferOp.getBase() : loopState[0]; |
515 | Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); |
516 | auto newXferOp = b.create<vector::TransferWriteOp>( |
517 | loc, type, vec, source, xferIndices, |
518 | AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), |
519 | inBoundsAttr); |
520 | |
521 | maybeApplyPassLabel(b, newXferOp, options.targetRank); |
522 | |
523 | return newXferOp; |
524 | } |
525 | |
526 | /// Handle out-of-bounds accesses on the to-be-unpacked dimension. |
527 | static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp, |
528 | Value buffer, Value iv, |
529 | ValueRange loopState) { |
530 | return isTensorOp(xferOp) ? loopState[0] : Value(); |
531 | } |
532 | |
533 | /// Cleanup after rewriting the op. |
534 | static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp, |
535 | scf::ForOp forOp) { |
536 | if (isTensorOp(xferOp)) { |
537 | assert(forOp->getNumResults() == 1 && "Expected one for loop result"); |
538 | rewriter.replaceOp(xferOp, forOp->getResult(0)); |
539 | } else { |
540 |       rewriter.eraseOp(xferOp); |
541 | } |
542 | } |
543 | |
544 | /// Return the initial loop state for the generated scf.for loop. |
545 | static Value initialLoopState(TransferWriteOp xferOp) { |
546 | return isTensorOp(xferOp) ? xferOp.getBase() : Value(); |
547 | } |
548 | }; |
549 | |
550 | template <typename OpTy> |
551 | static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter, |
552 | VectorTransferToSCFOptions options) { |
553 | if (xferOp->hasAttr(kPassLabel)) |
554 | return rewriter.notifyMatchFailure( |
555 | xferOp, "kPassLabel is present (vector-to-scf lowering in progress)"); |
556 | if (xferOp.getVectorType().getRank() <= options.targetRank) |
557 | return rewriter.notifyMatchFailure( |
558 | xferOp, "xferOp vector rank <= transformation target rank"); |
559 | if (xferOp.getVectorType().getScalableDims().front()) |
560 | return rewriter.notifyMatchFailure( |
561 | xferOp, "Unpacking of the leading dimension into the memref is not yet " |
562 | "supported for scalable dims"); |
563 | if (isTensorOp(xferOp) && !options.lowerTensors) |
564 | return rewriter.notifyMatchFailure( |
565 | xferOp, "Unpacking for tensors has been disabled."); |
566 | if (xferOp.getVectorType().getElementType() != |
567 | xferOp.getShapedType().getElementType()) |
568 | return rewriter.notifyMatchFailure( |
569 | xferOp, "Mismatching source and destination element types."); |
570 | |
571 | return success(); |
572 | } |
573 | |
574 | /// Prepare a TransferReadOp for progressive lowering. |
575 | /// |
576 | /// 1. Allocate a temporary buffer. |
577 | /// 2. Label the TransferReadOp, marking it eligible for progressive lowering. |
578 | /// 3. Store the result of the TransferReadOp into the temporary buffer. |
579 | /// 4. Load the result from the temporary buffer and replace all uses of the |
580 | /// original TransferReadOp with this load. |
581 | /// |
582 | /// E.g.: |
583 | /// ``` |
584 | /// %vec = vector.transfer_read %A[%a, %b, %c], %cst |
585 | /// : vector<5x4xf32>, memref<?x?x?xf32> |
586 | /// ``` |
587 | /// is rewritten to: |
588 | /// ``` |
589 | /// %0 = memref.alloca() : memref<vector<5x4xf32>> |
590 | /// %1 = vector.transfer_read %A[%a, %b, %c], %cst |
591 | /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32> |
592 | /// memref.store %1, %0[] : memref<vector<5x4xf32>> |
593 | /// %vec = memref.load %0[] : memref<vector<5x4xf32>> |
594 | /// ``` |
595 | /// |
596 | /// Note: A second temporary buffer may be allocated for the `mask` operand. |
597 | struct PrepareTransferReadConversion |
598 | : public VectorToSCFPattern<TransferReadOp> { |
599 | using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; |
600 | |
601 | LogicalResult matchAndRewrite(TransferReadOp xferOp, |
602 | PatternRewriter &rewriter) const override { |
603 | if (checkPrepareXferOp(xferOp, rewriter, options).failed()) |
604 | return rewriter.notifyMatchFailure( |
605 | xferOp, "checkPrepareXferOp conditions not met!"); |
606 | |
607 | auto buffers = allocBuffers(rewriter, xferOp); |
608 | auto *newXfer = rewriter.clone(*xferOp.getOperation()); |
609 | newXfer->setAttr(kPassLabel, rewriter.getUnitAttr()); |
610 | if (xferOp.getMask()) { |
611 | dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign( |
612 | buffers.maskBuffer); |
613 | } |
614 | |
615 | Location loc = xferOp.getLoc(); |
616 | rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0), |
617 | buffers.dataBuffer); |
618 | rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer); |
619 | |
620 | return success(); |
621 | } |
622 | }; |
623 | |
624 | /// Prepare a TransferWriteOp for progressive lowering. |
625 | /// |
626 | /// 1. Allocate a temporary buffer. |
627 | /// 2. Store the vector into the buffer. |
628 | /// 3. Load the vector from the buffer again. |
629 | /// 4. Use the loaded vector as a TransferWriteOp operand and label the op, |
630 | /// marking it eligible for progressive lowering via TransferOpConversion. |
631 | /// |
632 | /// E.g.: |
633 | /// ``` |
634 | /// vector.transfer_write %vec, %A[%a, %b, %c] |
635 | /// : vector<5x4xf32>, memref<?x?x?xf32> |
636 | /// ``` |
637 | /// is rewritten to: |
638 | /// ``` |
639 | /// %0 = memref.alloca() : memref<vector<5x4xf32>> |
640 | /// memref.store %vec, %0[] : memref<vector<5x4xf32>> |
641 | /// %1 = memref.load %0[] : memref<vector<5x4xf32>> |
642 | /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ } |
643 | /// : vector<5x4xf32>, memref<?x?x?xf32> |
644 | /// ``` |
645 | /// |
646 | /// Note: A second temporary buffer may be allocated for the `mask` operand. |
647 | struct PrepareTransferWriteConversion |
648 | : public VectorToSCFPattern<TransferWriteOp> { |
649 | using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; |
650 | |
651 | LogicalResult matchAndRewrite(TransferWriteOp xferOp, |
652 | PatternRewriter &rewriter) const override { |
653 | if (checkPrepareXferOp(xferOp, rewriter, options).failed()) |
654 | return rewriter.notifyMatchFailure( |
655 | xferOp, "checkPrepareXferOp conditions not met!"); |
656 | |
657 | Location loc = xferOp.getLoc(); |
658 | auto buffers = allocBuffers(rewriter, xferOp); |
659 | rewriter.create<memref::StoreOp>(loc, xferOp.getVector(), |
660 | buffers.dataBuffer); |
661 | auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer); |
662 | rewriter.modifyOpInPlace(xferOp, [&]() { |
663 | xferOp.getValueToStoreMutable().assign(loadedVec); |
664 | xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); |
665 | }); |
666 | |
667 | if (xferOp.getMask()) { |
668 | rewriter.modifyOpInPlace(xferOp, [&]() { |
669 | xferOp.getMaskMutable().assign(buffers.maskBuffer); |
670 | }); |
671 | } |
672 | |
673 | return success(); |
674 | } |
675 | }; |
676 | |
677 | /// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This allows |
678 | /// printing both 1D scalable vectors and n-D fixed size vectors. |
679 | /// |
680 | /// E.g.: |
681 | /// ``` |
682 | /// vector.print %v : vector<[4]xi32> |
683 | /// ``` |
684 | /// is rewritten to: |
685 | /// ``` |
686 | /// %c0 = arith.constant 0 : index |
687 | /// %c4 = arith.constant 4 : index |
688 | /// %c1 = arith.constant 1 : index |
689 | /// %vscale = vector.vscale |
690 | /// %length = arith.muli %vscale, %c4 : index |
691 | /// %lastIndex = arith.subi %length, %c1 : index |
692 | /// vector.print punctuation <open> |
693 | /// scf.for %i = %c0 to %length step %c1 { |
694 | /// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> |
695 | /// vector.print %el : i32 punctuation <no_punctuation> |
696 | /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index |
697 | /// scf.if %notLastIndex { |
698 | /// vector.print punctuation <comma> |
699 | /// } |
700 | /// } |
701 | /// vector.print punctuation <close> |
702 | /// vector.print |
703 | /// ``` |
704 | struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { |
705 | using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern; |
706 | LogicalResult matchAndRewrite(vector::PrintOp printOp, |
707 | PatternRewriter &rewriter) const override { |
708 | if (!printOp.getSource()) |
709 | return failure(); |
710 | |
711 | VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType()); |
712 | if (!vectorType) |
713 | return failure(); |
714 | |
715 | // Currently >= 2D scalable vectors are not supported. |
716 | // These can't be lowered to LLVM (as LLVM does not support scalable vectors |
717 | // of scalable vectors), and due to limitations of current ops can't be |
718 | // indexed with SSA values or flattened. This may change after |
719 | // https://reviews.llvm.org/D155034, though there still needs to be a path |
720 | // for lowering to LLVM. |
721 | if (vectorType.getRank() > 1 && vectorType.isScalable()) |
722 | return failure(); |
723 | |
724 | auto loc = printOp.getLoc(); |
725 | auto value = printOp.getSource(); |
726 | |
727 | if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) { |
728 | // Oddly sized integers are (somewhat) buggy on a lot of backends, so to |
729 | // avoid issues extend them to a more standard size. |
730 | // https://github.com/llvm/llvm-project/issues/30613 |
731 | auto width = intTy.getWidth(); |
732 |       auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1); |
733 | auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth, |
734 | intTy.getSignedness()); |
735 | // arith can only take signless integers, so we must cast back and forth. |
736 | auto signlessSourceVectorType = |
737 |           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy)); |
738 | auto signlessTargetVectorType = |
739 |           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy)); |
740 | auto targetVectorType = vectorType.cloneWith({}, legalIntTy); |
741 | value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType, |
742 | value); |
743 | if (value.getType() != signlessTargetVectorType) { |
744 | if (width == 1 || intTy.isUnsigned()) |
745 | value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType, |
746 | value); |
747 | else |
748 | value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType, |
749 | value); |
750 | } |
751 | value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value); |
752 | vectorType = targetVectorType; |
753 | } |
754 | |
755 | auto scalableDimensions = vectorType.getScalableDims(); |
756 | auto shape = vectorType.getShape(); |
757 | constexpr int64_t singletonShape[] = {1}; |
758 | if (vectorType.getRank() == 0) |
759 | shape = singletonShape; |
760 | |
761 | if (vectorType.getRank() != 1) { |
762 | // Flatten n-D vectors to 1D. This is done to allow indexing with a |
763 | // non-constant value (which can currently only be done via |
764 | // vector.extractelement for 1D vectors). |
765 | auto flatLength = std::accumulate(shape.begin(), shape.end(), 1, |
766 | std::multiplies<int64_t>()); |
767 | auto flatVectorType = |
768 | VectorType::get({flatLength}, vectorType.getElementType()); |
769 | value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value); |
770 | } |
771 | |
772 | vector::PrintOp firstClose; |
773 | SmallVector<Value, 8> loopIndices; |
774 | for (unsigned d = 0; d < shape.size(); d++) { |
775 | // Setup loop bounds and step. |
776 | Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
777 | Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]); |
778 | Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
779 | if (!scalableDimensions.empty() && scalableDimensions[d]) { |
780 | auto vscale = rewriter.create<vector::VectorScaleOp>( |
781 | loc, rewriter.getIndexType()); |
782 | upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale); |
783 | } |
784 | auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step); |
785 | |
786 | // Create a loop to print the elements surrounded by parentheses. |
787 | rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open); |
788 | auto loop = |
789 | rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); |
790 | auto printClose = rewriter.create<vector::PrintOp>( |
791 | loc, vector::PrintPunctuation::Close); |
792 | if (!firstClose) |
793 | firstClose = printClose; |
794 | |
795 | auto loopIdx = loop.getInductionVar(); |
796 | loopIndices.push_back(loopIdx); |
797 | |
798 | // Print a comma after all but the last element. |
799 | rewriter.setInsertionPointToStart(loop.getBody()); |
800 | auto notLastIndex = rewriter.create<arith::CmpIOp>( |
801 | loc, arith::CmpIPredicate::ult, loopIdx, lastIndex); |
802 | rewriter.create<scf::IfOp>(loc, notLastIndex, |
803 | [&](OpBuilder &builder, Location loc) { |
804 | builder.create<vector::PrintOp>( |
805 | loc, vector::PrintPunctuation::Comma); |
806 | builder.create<scf::YieldOp>(loc); |
807 | }); |
808 | |
809 | rewriter.setInsertionPointToStart(loop.getBody()); |
810 | } |
811 | |
812 | // Compute the flattened index. |
813 |     // Note: For vectors of rank > 1 this assumes the shape is not scalable. |
814 | Value flatIndex; |
815 | auto currentStride = 1; |
816 | for (int d = shape.size() - 1; d >= 0; d--) { |
817 | auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride); |
818 | auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]); |
819 | if (flatIndex) |
820 | flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index); |
821 | else |
822 | flatIndex = index; |
823 | currentStride *= shape[d]; |
824 | } |
825 | |
826 |     // Print the scalar elements in the innermost loop. |
827 | auto element = |
828 | rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex); |
829 | rewriter.create<vector::PrintOp>(loc, element, |
830 | vector::PrintPunctuation::NoPunctuation); |
831 | |
832 | rewriter.setInsertionPointAfter(firstClose); |
833 | rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation()); |
834 |     rewriter.eraseOp(printOp); |
835 | return success(); |
836 | } |
837 | |
838 | static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) { |
839 | return IntegerType::get(intTy.getContext(), intTy.getWidth(), |
840 | IntegerType::Signless); |
841 | }; |
842 | }; |
843 | |
844 | /// Progressive lowering of vector transfer ops: Unpack one dimension. |
845 | /// |
846 | /// 1. Unpack one dimension from the current buffer type and cast the buffer |
847 | /// to that new type. E.g.: |
848 | /// ``` |
849 | /// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>> |
850 | /// vector.transfer_write %vec ... |
851 | /// ``` |
852 | /// The following cast is generated: |
853 | /// ``` |
854 | /// %casted = vector.type_cast %0 |
855 | /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> |
856 | /// ``` |
857 | /// 2. Generate a for loop and rewrite the transfer op according to the |
858 | /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be |
859 | /// out-of-bounds, generate an if-check and handle both cases separately. |
860 | /// 3. Clean up according to the corresponding Strategy<OpTy>. |
861 | /// |
862 | /// Note: If the transfer op is a TransferWriteOp and operates on a tensor |
863 | /// source (as opposed to a memref source), then each iteration of the generated |
864 | /// scf.for loop yields the new tensor value. E.g.: |
865 | /// ``` |
866 | /// %result = scf.for i = 0 to 5 { |
867 | /// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>> |
868 | /// %1 = vector.transfer_write %0, %source[...] |
869 | /// : vector<4x3xf32>, tensor<5x4x3xf32> |
870 | /// scf.yield %1 : tensor<5x4x3xf32> |
871 | /// } |
872 | /// ``` |
873 | template <typename OpTy> |
874 | struct TransferOpConversion : public VectorToSCFPattern<OpTy> { |
875 | using VectorToSCFPattern<OpTy>::VectorToSCFPattern; |
876 | |
877 | void initialize() { |
878 | // This pattern recursively unpacks one dimension at a time. The recursion |
879 |     // is bounded as the rank is strictly decreasing. |
880 | this->setHasBoundedRewriteRecursion(); |
881 | } |
882 | |
883 | static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer, |
884 | SmallVectorImpl<Value> &loadIndices, |
885 | Value iv) { |
886 | assert(xferOp.getMask() && "Expected transfer op to have mask"); |
887 | |
888 | // Add load indices from the previous iteration. |
889 | // The mask buffer depends on the permutation map, which makes determining |
890 |     // the indices quite complex, so we need to "look back" to the |
891 | // previous iteration to find the right indices. |
892 | Value maskBuffer = getMaskBuffer(xferOp); |
893 | for (Operation *user : maskBuffer.getUsers()) { |
894 | // If there is no previous load op, then the indices are empty. |
895 | if (auto loadOp = dyn_cast<memref::LoadOp>(user)) { |
896 | Operation::operand_range prevIndices = loadOp.getIndices(); |
897 | loadIndices.append(prevIndices.begin(), prevIndices.end()); |
898 | break; |
899 | } |
900 | } |
901 | |
902 | // In case of broadcast: Use same indices to load from memref |
903 | // as before. |
904 | if (!xferOp.isBroadcastDim(0)) |
905 |       loadIndices.push_back(iv); |
906 | } |
907 | |
908 | LogicalResult matchAndRewrite(OpTy xferOp, |
909 | PatternRewriter &rewriter) const override { |
910 | if (!xferOp->hasAttr(kPassLabel)) |
911 | return rewriter.notifyMatchFailure( |
912 |           xferOp, "kPassLabel is absent (op not prepared for progressive lowering)"); |
913 | |
914 | // Find and cast data buffer. How the buffer can be found depends on OpTy. |
915 | ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter); |
916 | Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp); |
917 | auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType()); |
918 | FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType); |
919 | if (failed(castedDataType)) |
920 | return rewriter.notifyMatchFailure(xferOp, |
921 | "Failed to unpack one vector dim."); |
922 | |
923 | auto castedDataBuffer = |
924 | locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer); |
925 | |
926 | // If the xferOp has a mask: Find and cast mask buffer. |
927 | Value castedMaskBuffer; |
928 | if (xferOp.getMask()) { |
929 | Value maskBuffer = getMaskBuffer(xferOp); |
930 | if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { |
931 | // Do not unpack a dimension of the mask, if: |
932 | // * To-be-unpacked transfer op dimension is a broadcast. |
933 | // * Mask is 1D, i.e., the mask cannot be further unpacked. |
934 | // (That means that all remaining dimensions of the transfer op must |
935 | // be broadcasted.) |
936 | castedMaskBuffer = maskBuffer; |
937 | } else { |
938 | // It's safe to assume the mask buffer can be unpacked if the data |
939 | // buffer was unpacked. |
940 | auto maskBufferType = cast<MemRefType>(maskBuffer.getType()); |
941 | MemRefType castedMaskType = *unpackOneDim(maskBufferType); |
942 | castedMaskBuffer = |
943 | locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer); |
944 | } |
945 | } |
946 | |
947 | // Loop bounds and step. |
948 | auto lb = locB.create<arith::ConstantIndexOp>(0); |
949 | auto ub = locB.create<arith::ConstantIndexOp>( |
950 | castedDataType->getDimSize(castedDataType->getRank() - 1)); |
951 | auto step = locB.create<arith::ConstantIndexOp>(1); |
952 | // TransferWriteOps that operate on tensors return the modified tensor and |
953 | // require a loop state. |
954 | auto loopState = Strategy<OpTy>::initialLoopState(xferOp); |
955 | |
956 | // Generate for loop. |
957 | auto result = locB.create<scf::ForOp>( |
958 | lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), |
959 | [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { |
960 | Type stateType = loopState.empty() ? Type() : loopState[0].getType(); |
961 | |
962 | auto result = generateInBoundsCheck( |
963 | b, xferOp, iv, unpackedDim(xferOp), |
964 | stateType ? TypeRange(stateType) : TypeRange(), |
965 | /*inBoundsCase=*/ |
966 | [&](OpBuilder &b, Location loc) { |
967 | // Create new transfer op. |
968 | OpTy newXfer = Strategy<OpTy>::rewriteOp( |
969 | b, this->options, xferOp, castedDataBuffer, iv, loopState); |
970 | |
971 | // If old transfer op has a mask: Set mask on new transfer op. |
972 | // Special case: If the mask of the old transfer op is 1D and |
973 | // the unpacked dim is not a broadcast, no mask is needed on |
974 | // the new transfer op. |
975 | if (xferOp.getMask() && (xferOp.isBroadcastDim(0) || |
976 | xferOp.getMaskType().getRank() > 1)) { |
977 | OpBuilder::InsertionGuard guard(b); |
978 | b.setInsertionPoint(newXfer); // Insert load before newXfer. |
979 | |
980 | SmallVector<Value, 8> loadIndices; |
981 | getMaskBufferLoadIndices(xferOp, castedMaskBuffer, |
982 | loadIndices, iv); |
983 | auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer, |
984 | loadIndices); |
985 | rewriter.modifyOpInPlace(newXfer, [&]() { |
986 | newXfer.getMaskMutable().assign(mask); |
987 | }); |
988 | } |
989 | |
990 | return loopState.empty() ? Value() : newXfer->getResult(0); |
991 | }, |
992 | /*outOfBoundsCase=*/ |
993 | [&](OpBuilder &b, Location /*loc*/) { |
994 | return Strategy<OpTy>::handleOutOfBoundsDim( |
995 | b, xferOp, castedDataBuffer, iv, loopState); |
996 | }); |
997 | |
998 | maybeYieldValue(b, loc, !loopState.empty(), result); |
999 | }); |
1000 | |
1001 | Strategy<OpTy>::cleanup(rewriter, xferOp, result); |
1002 | return success(); |
1003 | } |
1004 | }; |
1005 | |
1006 | /// Retrieves the dimensions sizes of a mask. Currently supports CreateMaskOp |
1007 | /// and ConstantMaskOp. |
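/// For example (illustrative): for `vector.create_mask %a, %b` this returns
/// {%a, %b}; for a ConstantMaskOp the constant sizes are returned, with each
/// scalable dim size multiplied by vscale.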
1008 | template <typename VscaleConstantBuilder> |
1009 | static FailureOr<SmallVector<OpFoldResult>> |
1010 | getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) { |
1011 | if (!mask) |
1012 | return SmallVector<OpFoldResult>{}; |
1013 | if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) { |
1014 | return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) { |
1015 | return OpFoldResult(dimSize); |
1016 | }); |
1017 | } |
1018 | if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) { |
1019 | int dimIdx = 0; |
1020 | VectorType maskType = constantMask.getVectorType(); |
1021 | auto indexType = IndexType::get(mask.getContext()); |
1022 | return llvm::map_to_vector( |
1023 | constantMask.getMaskDimSizes(), [&](int64_t dimSize) { |
1024 | // A scalable dim in a constant_mask means vscale x dimSize. |
1025 | if (maskType.getScalableDims()[dimIdx++]) |
1026 | return OpFoldResult(createVscaleMultiple(dimSize)); |
1027 | return OpFoldResult(IntegerAttr::get(indexType, dimSize)); |
1028 | }); |
1029 | } |
1030 | return failure(); |
1031 | } |
1032 | |
1033 | /// Scalable vector lowering of transfer_write(transpose). This lowering only |
1034 | /// supports rank 2 (scalable) vectors, but can be used in conjunction with |
1035 | /// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion |
1036 | /// unrolls until the first scalable dimension. |
1037 | /// |
1038 | /// Example: |
1039 | /// |
1040 | /// BEFORE: |
1041 | /// ```mlir |
1042 | /// %transpose = vector.transpose %vec, [1, 0] |
1043 | /// : vector<4x[4]xf32> to vector<[4]x4xf32> |
1044 | /// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} |
1045 | /// : vector<[4]x4xf32>, memref<?x?xf32> |
1046 | /// ``` |
1047 | /// |
1048 | /// AFTER: |
1049 | /// ```mlir |
1050 | /// %c1 = arith.constant 1 : index |
1051 | /// %c4 = arith.constant 4 : index |
1052 | /// %c0 = arith.constant 0 : index |
1053 | /// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32> |
1054 | /// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32> |
1055 | /// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32> |
1056 | /// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32> |
1057 | /// %vscale = vector.vscale |
1058 | /// %c4_vscale = arith.muli %vscale, %c4 : index |
1059 | /// scf.for %idx = %c0 to %c4_vscale step %c1 { |
1060 | /// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32> |
1061 | /// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32> |
1062 | /// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32> |
1063 | /// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32> |
1064 | /// %slice_i = affine.apply #map(%idx)[%i] |
1065 | /// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32> |
1066 | /// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]} |
1067 | /// : vector<4xf32>, memref<?x?xf32> |
1068 | /// } |
1069 | /// ``` |
1070 | struct ScalableTransposeTransferWriteConversion |
1071 | : VectorToSCFPattern<vector::TransferWriteOp> { |
1072 | using VectorToSCFPattern::VectorToSCFPattern; |
1073 | |
1074 | LogicalResult matchAndRewrite(TransferWriteOp writeOp, |
1075 | PatternRewriter &rewriter) const override { |
1076 | if (failed(checkLowerTensors(writeOp, rewriter))) |
1077 | return failure(); |
1078 | |
1079 | VectorType vectorType = writeOp.getVectorType(); |
1080 | |
1081 | // Note: By comparing the scalable dims to an ArrayRef of length two this |
1082 |     // implicitly checks that the rank is also two. |
1083 | ArrayRef<bool> scalableFlags = vectorType.getScalableDims(); |
1084 | if (scalableFlags != ArrayRef<bool>{true, false}) { |
1085 | return rewriter.notifyMatchFailure( |
1086 | writeOp, "expected vector of the form vector<[N]xMxty>"); |
1087 | } |
1088 | |
1089 | auto permutationMap = writeOp.getPermutationMap(); |
1090 | if (!permutationMap.isIdentity()) { |
1091 | return rewriter.notifyMatchFailure( |
1092 | writeOp, "non-identity permutations are unsupported (lower first)"); |
1093 | } |
1094 | |
1095 | // Note: This pattern is only lowering the leading dimension (to a loop), |
1096 | // so we only check if the leading dimension is in bounds. The in-bounds |
1097 | // attribute for the trailing dimension will be propagated. |
1098 | if (!writeOp.isDimInBounds(0)) { |
1099 | return rewriter.notifyMatchFailure( |
1100 | writeOp, "out-of-bounds dims are unsupported (use masking)"); |
1101 | } |
1102 | |
1103 | Value vector = writeOp.getVector(); |
1104 | auto transposeOp = vector.getDefiningOp<vector::TransposeOp>(); |
1105 | if (!transposeOp || |
1106 | transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) { |
1107 | return rewriter.notifyMatchFailure(writeOp, "source not transpose"); |
1108 | } |
1109 | |
1110 | auto loc = writeOp.getLoc(); |
1111 | auto createVscaleMultiple = |
1112 |         vector::makeVscaleConstantBuilder(rewriter, loc); |
1113 | |
1114 | auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple); |
1115 | if (failed(maskDims)) { |
1116 | return rewriter.notifyMatchFailure(writeOp, |
1117 | "failed to resolve mask dims"); |
1118 | } |
1119 | |
1120 | int64_t fixedDimSize = vectorType.getDimSize(1); |
1121 | auto fixedDimOffsets = llvm::seq(fixedDimSize); |
1122 | |
1123 | // Extract all slices from the source of the transpose. |
1124 | auto transposeSource = transposeOp.getVector(); |
1125 | SmallVector<Value> transposeSourceSlices = |
1126 | llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { |
1127 | return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx); |
1128 | }); |
1129 | |
1130 | // Loop bounds and step. |
1131 | auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
1132 | auto ub = |
1133 | maskDims->empty() |
1134 | ? Value(createVscaleMultiple(vectorType.getDimSize(0))) |
1135 |             : vector::getAsValues(rewriter, loc, maskDims->front()).front(); |
1136 | auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
1137 | |
1138 | // Generate a new mask for the slice. |
1139 | VectorType sliceType = VectorType::Builder(vectorType).dropDim(0); |
1140 | Value sliceMask = nullptr; |
1141 | if (!maskDims->empty()) { |
1142 | sliceMask = rewriter.create<vector::CreateMaskOp>( |
1143 | loc, sliceType.clone(rewriter.getI1Type()), |
1144 | ArrayRef<OpFoldResult>(*maskDims).drop_front()); |
1145 | } |
1146 | |
1147 | Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{}; |
1148 | ValueRange initLoopArgs = initDest ? initDest : ValueRange{}; |
1149 | auto result = rewriter.create<scf::ForOp>( |
1150 | loc, lb, ub, step, initLoopArgs, |
1151 | [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) { |
1152 | // Indices for the new transfer op. |
1153 | SmallVector<Value, 8> xferIndices; |
1154 | getXferIndices(b, writeOp, iv, xferIndices); |
1155 | |
1156 | // Extract a transposed slice from the source vector. |
1157 | SmallVector<Value> transposeElements = |
1158 | llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { |
1159 | return b.create<vector::ExtractOp>( |
1160 | loc, transposeSourceSlices[idx], iv); |
1161 | }); |
1162 | auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType, |
1163 | transposeElements); |
1164 | |
1165 | // Create the transfer_write for the slice. |
1166 | Value dest = |
1167 | loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front(); |
1168 | auto newWriteOp = b.create<vector::TransferWriteOp>( |
1169 | loc, sliceVec, dest, xferIndices, |
1170 | ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()); |
1171 | if (sliceMask) |
1172 | newWriteOp.getMaskMutable().assign(sliceMask); |
1173 | |
1174 | // Yield from the loop. |
1175 | b.create<scf::YieldOp>(loc, loopIterArgs.empty() |
1176 | ? ValueRange{} |
1177 | : newWriteOp.getResult()); |
1178 | }); |
1179 | |
1180 | if (isTensorOp(writeOp)) |
1181 | rewriter.replaceOp(writeOp, result); |
1182 | else |
1183 |       rewriter.eraseOp(writeOp); |
1184 | |
1185 | return success(); |
1186 | } |
1187 | }; |
1188 | |
1189 | } // namespace lowering_n_d |
1190 | |
1191 | namespace lowering_n_d_unrolled { |
1192 | |
1193 | /// If the original transfer op has a mask, compute the mask of the new transfer |
1194 | /// op (for the current iteration `i`) and assign it. |
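/// E.g. (illustrative): for i = 2 and a mask of type vector<5x4xi1>, the new
/// mask is `vector.extract %mask[2] : vector<4xi1> from vector<5x4xi1>`.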
1195 | template <typename OpTy> |
1196 | static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, |
1197 | int64_t i) { |
1198 | if (!xferOp.getMask()) |
1199 | return; |
1200 | |
1201 | if (xferOp.isBroadcastDim(0)) { |
1202 | // To-be-unpacked dimension is a broadcast, which does not have a |
1203 | // corresponding mask dimension. Mask attribute remains unchanged. |
1204 | newXferOp.getMaskMutable().assign(xferOp.getMask()); |
1205 | return; |
1206 | } |
1207 | |
1208 | if (xferOp.getMaskType().getRank() > 1) { |
1209 | // Unpack one dimension of the mask. |
1210 | OpBuilder::InsertionGuard guard(b); |
1211 | b.setInsertionPoint(newXferOp); // Insert load before newXfer. |
1212 | |
1213 | llvm::SmallVector<int64_t, 1> indices({i}); |
1214 | Location loc = xferOp.getLoc(); |
1215 | auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices); |
1216 | newXferOp.getMaskMutable().assign(newMask); |
1217 | } |
1218 | |
1219 | // If we end up here: The mask of the old transfer op is 1D and the unpacked |
1220 | // dim is not a broadcast, so no mask is needed on the new transfer op. |
1221 | // `generateInBoundsCheck` will have evaluated the mask already. |
1222 | } |
1223 | |
1224 | /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one |
1225 | /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no |
1226 | /// memref buffer is allocated and the SCF loop is fully unrolled. |
1227 | /// |
1228 | /// E.g.: |
1229 | /// ``` |
1231 | /// %vec = vector.transfer_read %A[%a, %b, %c], %padding |
1232 | /// : memref<?x?x?xf32>, vector<5x4xf32> |
1233 | /// ``` |
1234 | /// is rewritten to IR such as (simplified): |
1235 | /// ``` |
1236 | /// %v_init = splat %padding : vector<5x4xf32> |
1237 | /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding |
1238 | /// : memref<?x?x?xf32>, vector<4xf32> |
1239 | /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32> |
1240 | /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding |
1241 | /// : memref<?x?x?xf32>, vector<4xf32> |
1242 | /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32> |
1243 | /// ... |
1244 | /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding |
1245 | /// : memref<?x?x?xf32>, vector<4xf32> |
1246 | /// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32> |
1247 | /// ``` |
1248 | /// |
1249 | /// Note: As an optimization, if the result of the original TransferReadOp |
1250 | /// was directly inserted into another vector, no new %v_init vector is created. |
1251 | /// Instead, the new TransferReadOp results are inserted into that vector. |
1252 | struct UnrollTransferReadConversion |
1253 | : public VectorToSCFPattern<TransferReadOp> { |
1254 | using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; |
1255 | |
1256 | void initialize() { |
1257 | // This pattern recursively unpacks one dimension at a time. The recursion |
1258 |     // is bounded as the rank is strictly decreasing. |
1259 | setHasBoundedRewriteRecursion(); |
1260 | } |
1261 | |
1262 | /// Get or build the vector into which the newly created TransferReadOp |
1263 | /// results are inserted. |
1264 | Value buildResultVector(PatternRewriter &rewriter, |
1265 | TransferReadOp xferOp) const { |
1266 | if (auto insertOp = getInsertOp(xferOp)) |
1267 | return insertOp.getDest(); |
1268 | Location loc = xferOp.getLoc(); |
1269 | return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(), |
1270 | xferOp.getPadding()); |
1271 | } |
1272 | |
1273 | /// If the result of the TransferReadOp has exactly one user, which is a |
1274 | /// vector::InsertOp, return that operation. |
1275 | vector::InsertOp getInsertOp(TransferReadOp xferOp) const { |
1276 | if (xferOp->hasOneUse()) { |
1277 | Operation *xferOpUser = *xferOp->getUsers().begin(); |
1278 | if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser)) |
1279 | return insertOp; |
1280 | } |
1281 | |
1282 | return vector::InsertOp(); |
1283 | } |
1284 | |
1285 | /// If the result of the TransferReadOp has exactly one user, which is a |
1286 | /// vector::InsertOp, return that operation's indices. |
1287 | void getInsertionIndices(TransferReadOp xferOp, |
1288 | SmallVectorImpl<OpFoldResult> &indices) const { |
1289 | if (auto insertOp = getInsertOp(xferOp)) { |
1290 | auto pos = insertOp.getMixedPosition(); |
1291 | indices.append(pos.begin(), pos.end()); |
1292 | } |
1293 | } |
1294 | |
1295 | /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds |
1296 | /// accesses, and broadcasts and transposes in permutation maps. |
1297 | LogicalResult matchAndRewrite(TransferReadOp xferOp, |
1298 | PatternRewriter &rewriter) const override { |
1299 | if (xferOp.getVectorType().getRank() <= options.targetRank) |
1300 | return rewriter.notifyMatchFailure( |
xferOp, "vector rank is less than or equal to target rank");
1302 | if (failed(checkLowerTensors(xferOp, rewriter))) |
1303 | return failure(); |
1304 | if (xferOp.getVectorType().getElementType() != |
1305 | xferOp.getShapedType().getElementType()) |
1306 | return rewriter.notifyMatchFailure( |
1307 | xferOp, "not yet supported: element type mismatch"); |
1308 | auto xferVecType = xferOp.getVectorType(); |
1309 | if (xferVecType.getScalableDims()[0]) { |
1310 | return rewriter.notifyMatchFailure( |
1311 | xferOp, "scalable dimensions cannot be unrolled at compile time"); |
1312 | } |
1313 | |
1314 | auto insertOp = getInsertOp(xferOp); |
auto vec = buildResultVector(rewriter, xferOp);
1316 | auto vecType = dyn_cast<VectorType>(vec.getType()); |
1317 | |
1318 | VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0); |
1319 | |
1320 | int64_t dimSize = xferVecType.getShape()[0]; |
1321 | |
1322 | // Generate fully unrolled loop of transfer ops. |
1323 | Location loc = xferOp.getLoc(); |
1324 | for (int64_t i = 0; i < dimSize; ++i) { |
1325 | Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); |
1326 | |
1327 | vec = generateInBoundsCheck( |
1328 | rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), |
1329 | /*inBoundsCase=*/ |
1330 | [&](OpBuilder &b, Location loc) { |
1331 | // Indices for the new transfer op. |
1332 | SmallVector<Value, 8> xferIndices; |
1333 | getXferIndices(b, xferOp, iv, xferIndices); |
1334 | |
1335 | // Indices for the new vector.insert op. |
1336 | SmallVector<OpFoldResult, 8> insertionIndices; |
getInsertionIndices(xferOp, insertionIndices);
1338 | insertionIndices.push_back(rewriter.getIndexAttr(i)); |
1339 | |
1340 | auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); |
1341 | auto newXferOp = b.create<vector::TransferReadOp>( |
1342 | loc, newXferVecType, xferOp.getBase(), xferIndices, |
1343 | AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), |
1344 | xferOp.getPadding(), Value(), inBoundsAttr); |
1345 | maybeAssignMask(b, xferOp, newXferOp, i); |
1346 | return b.create<vector::InsertOp>(loc, newXferOp, vec, |
1347 | insertionIndices); |
1348 | }, |
1349 | /*outOfBoundsCase=*/ |
1350 | [&](OpBuilder &b, Location loc) { |
// Pass through the original (unmodified) vector.
1352 | return vec; |
1353 | }); |
1354 | } |
1355 | |
1356 | if (insertOp) { |
1357 | // Rewrite single user of the old TransferReadOp, which was an InsertOp. |
1358 | rewriter.replaceOp(insertOp, vec); |
rewriter.eraseOp(xferOp);
1360 | } else { |
1361 | rewriter.replaceOp(xferOp, vec); |
1362 | } |
1363 | |
1364 | return success(); |
1365 | } |
1366 | }; |
1367 | |
1368 | /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one |
1369 | /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no |
1370 | /// memref buffer is allocated and the SCF loop is fully unrolled. |
1371 | /// |
1373 | /// E.g.: |
1374 | /// ``` |
1375 | /// vector.transfer_write %vec, %A[%a, %b, %c] |
1376 | /// : vector<5x4xf32>, memref<?x?x?xf32> |
1377 | /// ``` |
1378 | /// is rewritten to IR such as (simplified): |
1379 | /// ``` |
1380 | /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32> |
1381 | /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...> |
1382 | /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32> |
1383 | /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...> |
1384 | /// ... |
1385 | /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32> |
1386 | /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...> |
1387 | /// ``` |
1388 | /// |
1389 | /// Note: As an optimization, if the vector of the original TransferWriteOp |
1390 | /// was directly extracted from another vector via an ExtractOp `a`, extract |
1391 | /// the vectors for the newly generated TransferWriteOps from `a`'s input. By |
1392 | /// doing so, `a` may become dead, and the number of ExtractOps generated during |
1393 | /// recursive application of this pattern will be minimal. |
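///
/// For illustration, a simplified sketch (the value names and shapes below are
/// made up for this example): given
/// ```
/// %s = vector.extract %big[2] : vector<5x4xf32> from vector<3x5x4xf32>
/// vector.transfer_write %s, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// the unrolled 1-D writes take their operands directly from %big at positions
/// [2, 0], [2, 1], ..., [2, 4], which may leave the original vector.extract
/// without uses.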
1394 | struct UnrollTransferWriteConversion |
1395 | : public VectorToSCFPattern<TransferWriteOp> { |
1396 | using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; |
1397 | |
1398 | void initialize() { |
1399 | // This pattern recursively unpacks one dimension at a time. The recursion |
// is bounded because the rank strictly decreases.
1401 | setHasBoundedRewriteRecursion(); |
1402 | } |
1403 | |
/// Return the vector from which newly generated ExtractOps will extract.
1405 | Value getDataVector(TransferWriteOp xferOp) const { |
1406 | if (auto extractOp = getExtractOp(xferOp)) |
1407 | return extractOp.getVector(); |
1408 | return xferOp.getVector(); |
1409 | } |
1410 | |
1411 | /// If the input of the given TransferWriteOp is an ExtractOp, return it. |
1412 | vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const { |
1413 | if (auto *op = xferOp.getVector().getDefiningOp()) |
1414 | return dyn_cast<vector::ExtractOp>(op); |
1415 | return vector::ExtractOp(); |
1416 | } |
1417 | |
1418 | /// If the input of the given TransferWriteOp is an ExtractOp, return its |
1419 | /// indices. |
1420 | void getExtractionIndices(TransferWriteOp xferOp, |
1421 | SmallVectorImpl<OpFoldResult> &indices) const { |
1422 | if (auto extractOp = getExtractOp(xferOp)) { |
1423 | auto pos = extractOp.getMixedPosition(); |
1424 | indices.append(pos.begin(), pos.end()); |
1425 | } |
1426 | } |
1427 | |
1428 | /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds |
/// accesses, as well as broadcasts and transposes in permutation maps.
1430 | LogicalResult matchAndRewrite(TransferWriteOp xferOp, |
1431 | PatternRewriter &rewriter) const override { |
1432 | VectorType inputVectorTy = xferOp.getVectorType(); |
1433 | |
1434 | if (inputVectorTy.getRank() <= options.targetRank) |
1435 | return failure(); |
1436 | |
1437 | if (failed(checkLowerTensors(xferOp, rewriter))) |
1438 | return failure(); |
// Transfer ops that modify the element type are not yet supported.
1440 | if (inputVectorTy.getElementType() != |
1441 | xferOp.getShapedType().getElementType()) |
1442 | return failure(); |
1443 | |
auto vec = getDataVector(xferOp);
1445 | if (inputVectorTy.getScalableDims()[0]) { |
1446 | // Cannot unroll a scalable dimension at compile time. |
1447 | return failure(); |
1448 | } |
1449 | |
1450 | int64_t dimSize = inputVectorTy.getShape()[0]; |
1451 | Value source = xferOp.getBase(); // memref or tensor to be written to. |
1452 | auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); |
1453 | |
1454 | // Generate fully unrolled loop of transfer ops. |
1455 | Location loc = xferOp.getLoc(); |
1456 | for (int64_t i = 0; i < dimSize; ++i) { |
1457 | Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); |
1458 | |
1459 | auto updatedSource = generateInBoundsCheck( |
1460 | rewriter, xferOp, iv, unpackedDim(xferOp), |
1461 | isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(), |
1462 | /*inBoundsCase=*/ |
1463 | [&](OpBuilder &b, Location loc) { |
1464 | // Indices for the new transfer op. |
1465 | SmallVector<Value, 8> xferIndices; |
1466 | getXferIndices(b, xferOp, iv, xferIndices); |
1467 | |
1468 | // Indices for the new vector.extract op. |
1469 | SmallVector<OpFoldResult, 8> extractionIndices; |
getExtractionIndices(xferOp, extractionIndices);
1471 | extractionIndices.push_back(b.getI64IntegerAttr(i)); |
1472 | |
1473 | auto extracted = |
1474 | b.create<vector::ExtractOp>(loc, vec, extractionIndices); |
1475 | auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); |
1476 | Value xferVec; |
1477 | if (inputVectorTy.getRank() == 1) { |
// When target-rank=0, unrolling would cause the vector input
// argument of `transfer_write` to become a scalar. We solve
1480 | // this by broadcasting the scalar to a 0D vector. |
1481 | xferVec = b.create<vector::BroadcastOp>( |
1482 | loc, VectorType::get({}, extracted.getType()), extracted); |
1483 | } else { |
1484 | xferVec = extracted; |
1485 | } |
1486 | auto newXferOp = b.create<vector::TransferWriteOp>( |
1487 | loc, sourceType, xferVec, source, xferIndices, |
1488 | AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), |
1489 | inBoundsAttr); |
1490 | |
1491 | maybeAssignMask(b, xferOp, newXferOp, i); |
1492 | |
1493 | return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value(); |
1494 | }, |
1495 | /*outOfBoundsCase=*/ |
1496 | [&](OpBuilder &b, Location loc) { |
1497 | return isTensorOp(xferOp) ? source : Value(); |
1498 | }); |
1499 | |
1500 | if (isTensorOp(xferOp)) |
1501 | source = updatedSource; |
1502 | } |
1503 | |
1504 | if (isTensorOp(xferOp)) |
1505 | rewriter.replaceOp(xferOp, source); |
1506 | else |
rewriter.eraseOp(xferOp);
1508 | |
1509 | return success(); |
1510 | } |
1511 | }; |
1512 | |
1513 | } // namespace lowering_n_d_unrolled |
1514 | |
1515 | namespace lowering_1_d { |
1516 | |
1517 | /// Compute the indices into the memref for the LoadOp/StoreOp generated as |
1518 | /// part of TransferOp1dConversion. Return the memref dimension on which |
1519 | /// the transfer is operating. A return value of std::nullopt indicates a |
1520 | /// broadcast. |
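///
/// For example (a simplified sketch with illustrative values): for a transfer
/// with indices [%a, %b] and permutation map (d0, d1) -> (d0), the index for
/// memref dimension 0 becomes approximately
/// ```
/// %idx = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%a, %iv)
/// ```
/// all other indices are forwarded unchanged, and 0 is returned as the
/// unpacked dimension.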
1521 | template <typename OpTy> |
1522 | static std::optional<int64_t> |
1523 | get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, |
1524 | SmallVector<Value, 8> &memrefIndices) { |
1525 | auto indices = xferOp.getIndices(); |
1526 | auto map = xferOp.getPermutationMap(); |
1527 | assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); |
1528 | |
1529 | memrefIndices.append(indices.begin(), indices.end()); |
1530 | assert(map.getNumResults() == 1 && |
1531 | "Expected 1 permutation map result for 1D transfer"); |
1532 | if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { |
1533 | Location loc = xferOp.getLoc(); |
1534 | auto dim = expr.getPosition(); |
1535 | AffineExpr d0, d1; |
1536 | bindDims(xferOp.getContext(), d0, d1); |
1537 | Value offset = memrefIndices[dim]; |
1538 | memrefIndices[dim] = |
1539 | affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); |
1540 | return dim; |
1541 | } |
1542 | |
1543 | assert(xferOp.isBroadcastDim(0) && |
1544 | "Expected AffineDimExpr or AffineConstantExpr"); |
1545 | return std::nullopt; |
1546 | } |
1547 | |
1548 | /// Codegen strategy for TransferOp1dConversion, depending on the |
1549 | /// operation. |
1550 | template <typename OpTy> |
1551 | struct Strategy1d; |
1552 | |
1553 | /// Codegen strategy for TransferReadOp. |
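///
/// The generated loop body looks roughly as follows (a simplified sketch,
/// assuming an in-bounds, unmasked transfer and illustrative types):
/// ```
/// %val = memref.load %A[%idx] : memref<?xf32>
/// %next = vector.insertelement %val, %vec[%iv : index] : vector<9xf32>
/// scf.yield %next : vector<9xf32>
/// ```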
1554 | template <> |
1555 | struct Strategy1d<TransferReadOp> { |
1556 | static void generateForLoopBody(OpBuilder &b, Location loc, |
1557 | TransferReadOp xferOp, Value iv, |
1558 | ValueRange loopState) { |
1559 | SmallVector<Value, 8> indices; |
1560 | auto dim = get1dMemrefIndices(b, xferOp, iv, indices); |
1561 | auto vec = loopState[0]; |
1562 | |
// In case of an out-of-bounds access, leave `vec` as is (it was initialized
// with the padding value).
1565 | auto nextVec = generateInBoundsCheck( |
1566 | b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()), |
1567 | /*inBoundsCase=*/ |
1568 | [&](OpBuilder &b, Location loc) { |
1569 | Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices); |
1570 | return b.create<vector::InsertElementOp>(loc, val, vec, iv); |
1571 | }, |
1572 | /*outOfBoundsCase=*/ |
1573 | [&](OpBuilder & /*b*/, Location loc) { return vec; }); |
1574 | b.create<scf::YieldOp>(loc, nextVec); |
1575 | } |
1576 | |
1577 | static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) { |
// Initialize the vector with the padding value.
1579 | Location loc = xferOp.getLoc(); |
1580 | return b.create<vector::SplatOp>(loc, xferOp.getVectorType(), |
1581 | xferOp.getPadding()); |
1582 | } |
1583 | }; |
1584 | |
1585 | /// Codegen strategy for TransferWriteOp. |
1586 | template <> |
1587 | struct Strategy1d<TransferWriteOp> { |
1588 | static void generateForLoopBody(OpBuilder &b, Location loc, |
1589 | TransferWriteOp xferOp, Value iv, |
1590 | ValueRange /*loopState*/) { |
1591 | SmallVector<Value, 8> indices; |
1592 | auto dim = get1dMemrefIndices(b, xferOp, iv, indices); |
1593 | |
1594 | // Nothing to do in case of out-of-bounds access. |
1595 | generateInBoundsCheck( |
1596 | b, xferOp, iv, dim, |
1597 | /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { |
1598 | auto val = |
1599 | b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv); |
1600 | b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices); |
1601 | }); |
1602 | b.create<scf::YieldOp>(loc); |
1603 | } |
1604 | |
1605 | static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) { |
1606 | return Value(); |
1607 | } |
1608 | }; |
1609 | |
1610 | /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is |
1611 | /// necessary in cases where a 1D vector transfer op cannot be lowered into |
1612 | /// vector load/stores due to non-unit strides or broadcasts: |
1613 | /// |
1614 | /// * Transfer dimension is not the last memref dimension |
1615 | /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast) |
1616 | /// * Memref has a layout map with non-unit stride on the last dimension |
1617 | /// |
1618 | /// This pattern generates IR as follows: |
1619 | /// |
1620 | /// 1. Generate a for loop iterating over each vector element. |
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1622 | /// depending on OpTy. |
1623 | /// |
1624 | /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp |
1625 | /// can be generated instead of TransferOp1dConversion. Add such a pattern |
1626 | /// to ConvertVectorToLLVM. |
1627 | /// |
1628 | /// E.g.: |
1629 | /// ``` |
1630 | /// vector.transfer_write %vec, %A[%a, %b] |
1631 | /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} |
1632 | /// : vector<9xf32>, memref<?x?xf32> |
1633 | /// ``` |
1634 | /// Is rewritten to approximately the following pseudo-IR: |
1635 | /// ``` |
1636 | /// for i = 0 to 9 { |
1637 | /// %t = vector.extractelement %vec[i] : vector<9xf32> |
1638 | /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> |
1639 | /// } |
1640 | /// ``` |
1641 | template <typename OpTy> |
1642 | struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> { |
1643 | using VectorToSCFPattern<OpTy>::VectorToSCFPattern; |
1644 | |
1645 | LogicalResult matchAndRewrite(OpTy xferOp, |
1646 | PatternRewriter &rewriter) const override { |
1647 | // TODO: support 0-d corner case. |
1648 | if (xferOp.getTransferRank() == 0) |
1649 | return failure(); |
1650 | auto map = xferOp.getPermutationMap(); |
1651 | auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType()); |
1652 | |
1653 | if (!memRefType) |
1654 | return failure(); |
1655 | if (xferOp.getVectorType().getRank() != 1) |
1656 | return failure(); |
1657 | if (map.isMinorIdentity() && memRefType.isLastDimUnitStride()) |
1658 | return failure(); // Handled by ConvertVectorToLLVM |
1659 | |
1660 | // Loop bounds, step, state... |
1661 | Location loc = xferOp.getLoc(); |
1662 | auto vecType = xferOp.getVectorType(); |
1663 | auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
1664 | Value ub = |
1665 | rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0)); |
1666 | if (vecType.isScalable()) { |
1667 | Value vscale = |
1668 | rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); |
1669 | ub = rewriter.create<arith::MulIOp>(loc, ub, vscale); |
1670 | } |
1671 | auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
1672 | auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp); |
1673 | |
1674 | // Generate for loop. |
1675 | rewriter.replaceOpWithNewOp<scf::ForOp>( |
1676 | xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), |
1677 | [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { |
1678 | Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState); |
1679 | }); |
1680 | |
1681 | return success(); |
1682 | } |
1683 | }; |
1684 | |
1685 | } // namespace lowering_1_d |
1686 | } // namespace |
1687 | |
1688 | void mlir::populateVectorToSCFConversionPatterns( |
1689 | RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) { |
1690 | if (options.unroll) { |
1691 | patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion, |
1692 | lowering_n_d_unrolled::UnrollTransferWriteConversion>( |
patterns.getContext(), options);
1694 | } else { |
1695 | patterns.add<lowering_n_d::PrepareTransferReadConversion, |
1696 | lowering_n_d::PrepareTransferWriteConversion, |
1697 | lowering_n_d::TransferOpConversion<TransferReadOp>, |
1698 | lowering_n_d::TransferOpConversion<TransferWriteOp>>( |
patterns.getContext(), options);
1700 | } |
1701 | if (options.lowerScalable) { |
1702 | patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>( |
patterns.getContext(), options);
1704 | } |
1705 | if (options.targetRank == 1) { |
1706 | patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>, |
1707 | lowering_1_d::TransferOp1dConversion<TransferWriteOp>>( |
patterns.getContext(), options);
1709 | } |
patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
options);
1712 | } |
1713 | |
1714 | namespace { |
1715 | |
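/// Pass that lowers vector transfer ops to SCF using the patterns above. As a
/// rough usage sketch from mlir-opt (the textual option names are assumed to
/// mirror the pass options wired up below):
/// ```
/// mlir-opt -convert-vector-to-scf="full-unroll=true target-rank=1" input.mlir
/// ```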
1716 | struct ConvertVectorToSCFPass |
1717 | : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> { |
1718 | ConvertVectorToSCFPass() = default; |
1719 | ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { |
1720 | this->fullUnroll = options.unroll; |
1721 | this->targetRank = options.targetRank; |
1722 | this->lowerTensors = options.lowerTensors; |
1723 | this->lowerScalable = options.lowerScalable; |
1724 | } |
1725 | |
1726 | void runOnOperation() override { |
1727 | VectorTransferToSCFOptions options; |
1728 | options.unroll = fullUnroll; |
1729 | options.targetRank = targetRank; |
1730 | options.lowerTensors = lowerTensors; |
1731 | options.lowerScalable = lowerScalable; |
1732 | |
1733 | // Lower permutation maps first. |
1734 | RewritePatternSet lowerTransferPatterns(&getContext()); |
1735 | mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( |
lowerTransferPatterns);
1737 | (void)applyPatternsGreedily(getOperation(), |
1738 | std::move(lowerTransferPatterns)); |
1739 | |
1740 | RewritePatternSet patterns(&getContext()); |
1741 | populateVectorToSCFConversionPatterns(patterns, options); |
1742 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
1743 | } |
1744 | }; |
1745 | |
1746 | } // namespace |
1747 | |
1748 | std::unique_ptr<Pass> |
1749 | mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { |
return std::make_unique<ConvertVectorToSCFPass>(options);
1751 | } |
1752 |