1 | //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements lowering of vector transfer operations to SCF. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include <numeric> |
14 | #include <optional> |
15 | #include <type_traits> |
16 | |
17 | #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
18 | |
19 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
20 | #include "mlir/Dialect/Arith/IR/Arith.h" |
21 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
22 | #include "mlir/Dialect/SCF/IR/SCF.h" |
23 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
24 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
25 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
26 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
27 | #include "mlir/IR/Builders.h" |
28 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
29 | #include "mlir/Pass/Pass.h" |
30 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
31 | #include "mlir/Transforms/Passes.h" |
32 | |
33 | namespace mlir { |
34 | #define GEN_PASS_DEF_CONVERTVECTORTOSCF |
35 | #include "mlir/Conversion/Passes.h.inc" |
36 | } // namespace mlir |
37 | |
38 | using namespace mlir; |
39 | using vector::TransferReadOp; |
40 | using vector::TransferWriteOp; |
41 | |
42 | namespace { |
43 | |
44 | /// Attribute name used for labeling transfer ops during progressive lowering. |
static const char kPassLabel[] = "__vector_to_scf_lowering__";
46 | |
47 | /// Patterns that inherit from this struct have access to |
48 | /// VectorTransferToSCFOptions. |
49 | template <typename OpTy> |
50 | struct VectorToSCFPattern : public OpRewritePattern<OpTy> { |
51 | explicit VectorToSCFPattern(MLIRContext *context, |
52 | VectorTransferToSCFOptions opt) |
53 | : OpRewritePattern<OpTy>(context), options(opt) {} |
54 | |
55 | VectorTransferToSCFOptions options; |
56 | }; |
57 | |
58 | /// Given a vector transfer op, calculate which dimension of the `source` |
59 | /// memref should be unpacked in the next application of TransferOpConversion. |
60 | /// A return value of std::nullopt indicates a broadcast. |
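/// E.g. (illustrative): for permutation map (d0, d1, d2) -> (d1, d2) the
/// result is 1; for (d0, d1) -> (0, d1) the first result is a constant, i.e.
/// a broadcast, so std::nullopt is returned.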
61 | template <typename OpTy> |
62 | static std::optional<int64_t> unpackedDim(OpTy xferOp) { |
63 | // TODO: support 0-d corner case. |
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
65 | auto map = xferOp.getPermutationMap(); |
66 | if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { |
67 | return expr.getPosition(); |
68 | } |
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
71 | return std::nullopt; |
72 | } |
73 | |
74 | /// Compute the permutation map for the new (N-1)-D vector transfer op. This |
75 | /// map is identical to the current permutation map, but the first result is |
76 | /// omitted. |
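/// E.g. (illustrative): (d0, d1, d2) -> (d1, d2) becomes (d0, d1, d2) -> (d2).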
77 | template <typename OpTy> |
78 | static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) { |
79 | // TODO: support 0-d corner case. |
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
81 | auto map = xferOp.getPermutationMap(); |
82 | return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), |
83 | b.getContext()); |
84 | } |
85 | |
86 | /// Calculate the indices for the new vector transfer op. |
87 | /// |
88 | /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ... |
/// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
90 | /// ^^^^^^ |
91 | /// `iv` is the iteration variable of the (new) surrounding loop. |
92 | template <typename OpTy> |
93 | static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv, |
94 | SmallVector<Value, 8> &indices) { |
95 | typename OpTy::Adaptor adaptor(xferOp); |
96 | // Corresponding memref dim of the vector dim that is unpacked. |
97 | auto dim = unpackedDim(xferOp); |
98 | auto prevIndices = adaptor.getIndices(); |
99 | indices.append(prevIndices.begin(), prevIndices.end()); |
100 | |
101 | Location loc = xferOp.getLoc(); |
102 | bool isBroadcast = !dim.has_value(); |
103 | if (!isBroadcast) { |
104 | AffineExpr d0, d1; |
105 | bindDims(xferOp.getContext(), d0, d1); |
106 | Value offset = adaptor.getIndices()[*dim]; |
107 | indices[*dim] = |
108 | affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); |
109 | } |
110 | } |
111 | |
112 | static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, |
113 | Value value) { |
114 | if (hasRetVal) { |
    assert(value && "Expected non-empty value");
116 | b.create<scf::YieldOp>(loc, value); |
117 | } else { |
118 | b.create<scf::YieldOp>(loc); |
119 | } |
120 | } |
121 | |
122 | /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask |
/// is set to true. No such check is generated under the following
/// circumstances:
124 | /// * xferOp does not have a mask. |
125 | /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is |
126 | /// computed and attached to the new transfer op in the pattern.) |
127 | /// * The to-be-unpacked dim of xferOp is a broadcast. |
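/// For a 1-D mask, the generated check is a single element extract, e.g.
/// (sketch, assuming a mask of type vector<5xi1>):
/// ```
/// %m = vector.extractelement %mask[%iv : index] : vector<5xi1>
/// ```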
128 | template <typename OpTy> |
129 | static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { |
130 | if (!xferOp.getMask()) |
131 | return Value(); |
132 | if (xferOp.getMaskType().getRank() != 1) |
133 | return Value(); |
134 | if (xferOp.isBroadcastDim(0)) |
135 | return Value(); |
136 | |
137 | Location loc = xferOp.getLoc(); |
138 | return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv); |
139 | } |
140 | |
/// Helper function for TransferOpConversion and TransferOp1dConversion.
142 | /// Generate an in-bounds check if the transfer op may go out-of-bounds on the |
143 | /// specified dimension `dim` with the loop iteration variable `iv`. |
144 | /// E.g., when unpacking dimension 0 from: |
145 | /// ``` |
146 | /// %vec = vector.transfer_read %A[%a, %b] %cst |
147 | /// : vector<5x4xf32>, memref<?x?xf32> |
148 | /// ``` |
149 | /// An if check similar to this will be generated inside the loop: |
150 | /// ``` |
151 | /// %d = memref.dim %A, %c0 : memref<?x?xf32> |
152 | /// if (%a + iv < %d) { |
153 | /// (in-bounds case) |
154 | /// } else { |
155 | /// (out-of-bounds case) |
156 | /// } |
157 | /// ``` |
158 | /// |
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
161 | /// |
162 | /// This function variant returns the value returned by `inBoundsCase` or |
163 | /// `outOfBoundsCase`. The MLIR type of the return value must be specified in |
164 | /// `resultTypes`. |
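///
/// E.g., for a masked 1-D transfer, the generated condition is the conjunction
/// of the bounds check and the mask check (sketch, names chosen for
/// exposition):
/// ```
/// %in_bounds = arith.cmpi sgt, %dim, %idx : index
/// %mask_bit = vector.extractelement %mask[%iv : index] : vector<5xi1>
/// %cond = arith.andi %in_bounds, %mask_bit : i1
/// ```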
165 | template <typename OpTy> |
166 | static Value generateInBoundsCheck( |
167 | OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, |
168 | TypeRange resultTypes, |
169 | function_ref<Value(OpBuilder &, Location)> inBoundsCase, |
170 | function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) { |
171 | bool hasRetVal = !resultTypes.empty(); |
172 | Value cond; // Condition to be built... |
173 | |
174 | // Condition check 1: Access in-bounds? |
175 | bool isBroadcast = !dim; // No in-bounds check for broadcasts. |
176 | Location loc = xferOp.getLoc(); |
177 | ImplicitLocOpBuilder lb(xferOp.getLoc(), b); |
178 | if (!xferOp.isDimInBounds(0) && !isBroadcast) { |
179 | Value memrefDim = |
        vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
181 | AffineExpr d0, d1; |
182 | bindDims(xferOp.getContext(), d0, d1); |
183 | Value base = xferOp.getIndices()[*dim]; |
184 | Value memrefIdx = |
185 | affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); |
186 | cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim, |
187 | memrefIdx); |
188 | } |
189 | |
190 | // Condition check 2: Masked in? |
191 | if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { |
192 | if (cond) |
193 | cond = lb.create<arith::AndIOp>(cond, maskCond); |
194 | else |
195 | cond = maskCond; |
196 | } |
197 | |
198 | // If the condition is non-empty, generate an SCF::IfOp. |
199 | if (cond) { |
200 | auto check = lb.create<scf::IfOp>( |
201 | cond, |
202 | /*thenBuilder=*/ |
203 | [&](OpBuilder &b, Location loc) { |
204 | maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc)); |
205 | }, |
206 | /*elseBuilder=*/ |
207 | [&](OpBuilder &b, Location loc) { |
208 | if (outOfBoundsCase) { |
209 | maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc)); |
210 | } else { |
211 | b.create<scf::YieldOp>(loc); |
212 | } |
213 | }); |
214 | |
215 | return hasRetVal ? check.getResult(0) : Value(); |
216 | } |
217 | |
218 | // Condition is empty, no need for an SCF::IfOp. |
219 | return inBoundsCase(b, loc); |
220 | } |
221 | |
222 | /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have |
223 | /// a return value. Consequently, this function does not have a return value. |
224 | template <typename OpTy> |
225 | static void generateInBoundsCheck( |
226 | OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, |
227 | function_ref<void(OpBuilder &, Location)> inBoundsCase, |
228 | function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) { |
229 | generateInBoundsCheck( |
230 | b, xferOp, iv, dim, /*resultTypes=*/TypeRange(), |
231 | /*inBoundsCase=*/ |
232 | [&](OpBuilder &b, Location loc) { |
233 | inBoundsCase(b, loc); |
234 | return Value(); |
235 | }, |
236 | /*outOfBoundsCase=*/ |
237 | [&](OpBuilder &b, Location loc) { |
238 | if (outOfBoundsCase) |
239 | outOfBoundsCase(b, loc); |
240 | return Value(); |
241 | }); |
242 | } |
243 | |
244 | /// Given an ArrayAttr, return a copy where the first element is dropped. |
245 | static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) { |
246 | if (!attr) |
247 | return attr; |
248 | return ArrayAttr::get(b.getContext(), attr.getValue().drop_front()); |
249 | } |
250 | |
251 | /// Add the pass label to a vector transfer op if its rank is not the target |
252 | /// rank. |
253 | template <typename OpTy> |
254 | static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp, |
255 | unsigned targetRank) { |
256 | if (newXferOp.getVectorType().getRank() > targetRank) |
257 | newXferOp->setAttr(kPassLabel, b.getUnitAttr()); |
258 | } |
259 | |
260 | /// Return true if this transfer op operates on a source tensor. |
261 | template <typename OpTy> |
262 | static bool isTensorOp(OpTy xferOp) { |
263 | if (isa<RankedTensorType>(xferOp.getShapedType())) { |
264 | if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) { |
265 | // TransferWriteOps on tensors have a result. |
266 | assert(xferOp->getNumResults() > 0); |
267 | } |
268 | return true; |
269 | } |
270 | return false; |
271 | } |
272 | |
273 | namespace lowering_n_d { |
274 | |
275 | /// Helper data structure for data and mask buffers. |
276 | struct BufferAllocs { |
277 | Value dataBuffer; |
278 | Value maskBuffer; |
279 | }; |
280 | |
281 | // TODO: Parallelism and threadlocal considerations with a ParallelScope trait. |
282 | static Operation *getAutomaticAllocationScope(Operation *op) { |
283 | Operation *scope = |
284 | op->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); |
  assert(scope && "Expected op to be inside automatic allocation scope");
286 | return scope; |
287 | } |
288 | |
289 | /// Allocate temporary buffers for data (vector) and mask (if present). |
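/// The allocas are hoisted to the start of the closest enclosing
/// AutomaticAllocationScope. E.g. (sketch, for a transfer op with
/// vector<5x4xf32> and mask vector<5xi1>):
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5xi1>>
/// memref.store %mask, %mask_buf[] : memref<vector<5xi1>>
/// %mask_val = memref.load %mask_buf[] : memref<vector<5xi1>>
/// ```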
290 | template <typename OpTy> |
291 | static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { |
292 | Location loc = xferOp.getLoc(); |
293 | OpBuilder::InsertionGuard guard(b); |
294 | Operation *scope = getAutomaticAllocationScope(xferOp); |
295 | assert(scope->getNumRegions() == 1 && |
296 | "AutomaticAllocationScope with >1 regions" ); |
  b.setInsertionPointToStart(&scope->getRegion(0).front());
298 | |
299 | BufferAllocs result; |
300 | auto bufferType = MemRefType::get({}, xferOp.getVectorType()); |
301 | result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType); |
302 | |
303 | if (xferOp.getMask()) { |
304 | auto maskType = MemRefType::get({}, xferOp.getMask().getType()); |
305 | auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType); |
306 | b.setInsertionPoint(xferOp); |
307 | b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer); |
308 | result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange()); |
309 | } |
310 | |
311 | return result; |
312 | } |
313 | |
314 | /// Given a MemRefType with VectorType element type, unpack one dimension from |
315 | /// the VectorType into the MemRefType. |
316 | /// |
317 | /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> |
318 | static FailureOr<MemRefType> unpackOneDim(MemRefType type) { |
319 | auto vectorType = dyn_cast<VectorType>(type.getElementType()); |
320 | // Vectors with leading scalable dims are not supported. |
321 | // It may be possible to support these in future by using dynamic memref dims. |
322 | if (vectorType.getScalableDims().front()) |
323 | return failure(); |
324 | auto memrefShape = type.getShape(); |
325 | SmallVector<int64_t, 8> newMemrefShape; |
326 | newMemrefShape.append(memrefShape.begin(), memrefShape.end()); |
327 | newMemrefShape.push_back(vectorType.getDimSize(0)); |
328 | return MemRefType::get(newMemrefShape, |
329 | VectorType::Builder(vectorType).dropDim(0)); |
330 | } |
331 | |
332 | /// Given a transfer op, find the memref from which the mask is loaded. This |
333 | /// is similar to Strategy<TransferWriteOp>::getBuffer. |
334 | template <typename OpTy> |
335 | static Value getMaskBuffer(OpTy xferOp) { |
  assert(xferOp.getMask() && "Expected that transfer op has mask");
337 | auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>(); |
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
339 | return loadOp.getMemRef(); |
340 | } |
341 | |
342 | /// Codegen strategy, depending on the operation. |
343 | template <typename OpTy> |
344 | struct Strategy; |
345 | |
/// Codegen strategy for vector TransferReadOp.
347 | template <> |
348 | struct Strategy<TransferReadOp> { |
349 | /// Find the StoreOp that is used for writing the current TransferReadOp's |
350 | /// result to the temporary buffer allocation. |
351 | static memref::StoreOp getStoreOp(TransferReadOp xferOp) { |
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
353 | auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner()); |
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
355 | return storeOp; |
356 | } |
357 | |
358 | /// Find the temporary buffer allocation. All labeled TransferReadOps are |
359 | /// used like this, where %buf is either the buffer allocation or a type cast |
360 | /// of the buffer allocation: |
361 | /// ``` |
362 | /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ... |
363 | /// memref.store %vec, %buf[...] ... |
364 | /// ``` |
365 | static Value getBuffer(TransferReadOp xferOp) { |
366 | return getStoreOp(xferOp).getMemRef(); |
367 | } |
368 | |
369 | /// Retrieve the indices of the current StoreOp that stores into the buffer. |
370 | static void getBufferIndices(TransferReadOp xferOp, |
371 | SmallVector<Value, 8> &indices) { |
372 | auto storeOp = getStoreOp(xferOp); |
373 | auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices(); |
374 | indices.append(prevIndices.begin(), prevIndices.end()); |
375 | } |
376 | |
377 | /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds |
378 | /// accesses on the to-be-unpacked dimension. |
379 | /// |
380 | /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration |
381 | /// variable `iv`. |
382 | /// 2. Store the result into the (already `vector.type_cast`ed) buffer. |
383 | /// |
384 | /// E.g.: |
385 | /// ``` |
386 | /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst |
387 | /// : memref<?x?x?xf32>, vector<4x3xf32> |
388 | /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>> |
389 | /// ``` |
390 | /// Is rewritten to: |
391 | /// ``` |
392 | /// %casted = vector.type_cast %buf |
393 | /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> |
394 | /// for %j = 0 to 4 { |
395 | /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst |
396 | /// : memref<?x?x?xf32>, vector<3xf32> |
397 | /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>> |
398 | /// } |
399 | /// ``` |
400 | /// |
401 | /// Note: The loop and type cast are generated in TransferOpConversion. |
402 | /// The original TransferReadOp and store op are deleted in `cleanup`. |
403 | /// Note: The `mask` operand is set in TransferOpConversion. |
404 | static TransferReadOp rewriteOp(OpBuilder &b, |
405 | VectorTransferToSCFOptions options, |
406 | TransferReadOp xferOp, Value buffer, Value iv, |
407 | ValueRange /*loopState*/) { |
408 | SmallVector<Value, 8> storeIndices; |
    getBufferIndices(xferOp, storeIndices);
410 | storeIndices.push_back(iv); |
411 | |
412 | SmallVector<Value, 8> xferIndices; |
413 | getXferIndices(b, xferOp, iv, xferIndices); |
414 | |
415 | Location loc = xferOp.getLoc(); |
416 | auto bufferType = dyn_cast<ShapedType>(buffer.getType()); |
417 | auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); |
418 | auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); |
419 | auto newXferOp = b.create<vector::TransferReadOp>( |
420 | loc, vecType, xferOp.getSource(), xferIndices, |
421 | AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), |
422 | xferOp.getPadding(), Value(), inBoundsAttr); |
423 | |
424 | maybeApplyPassLabel(b, newXferOp, options.targetRank); |
425 | |
426 | b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices); |
427 | return newXferOp; |
428 | } |
429 | |
430 | /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write |
431 | /// padding value to the temporary buffer. |
432 | static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp, |
433 | Value buffer, Value iv, |
434 | ValueRange /*loopState*/) { |
435 | SmallVector<Value, 8> storeIndices; |
    getBufferIndices(xferOp, storeIndices);
437 | storeIndices.push_back(iv); |
438 | |
439 | Location loc = xferOp.getLoc(); |
440 | auto bufferType = dyn_cast<ShapedType>(buffer.getType()); |
441 | auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); |
442 | auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding()); |
443 | b.create<memref::StoreOp>(loc, vec, buffer, storeIndices); |
444 | |
445 | return Value(); |
446 | } |
447 | |
448 | /// Cleanup after rewriting the op. |
449 | static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp, |
450 | scf::ForOp /*forOp*/) { |
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
453 | } |
454 | |
455 | /// Return the initial loop state for the generated scf.for loop. |
456 | static Value initialLoopState(TransferReadOp xferOp) { return Value(); } |
457 | }; |
458 | |
459 | /// Codegen strategy for vector TransferWriteOp. |
460 | template <> |
461 | struct Strategy<TransferWriteOp> { |
462 | /// Find the temporary buffer allocation. All labeled TransferWriteOps are |
463 | /// used like this, where %buf is either the buffer allocation or a type cast |
464 | /// of the buffer allocation: |
465 | /// ``` |
466 | /// %vec = memref.load %buf[...] ... |
467 | /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ... |
468 | /// ``` |
469 | static Value getBuffer(TransferWriteOp xferOp) { |
470 | auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>(); |
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
472 | return loadOp.getMemRef(); |
473 | } |
474 | |
475 | /// Retrieve the indices of the current LoadOp that loads from the buffer. |
476 | static void getBufferIndices(TransferWriteOp xferOp, |
477 | SmallVector<Value, 8> &indices) { |
478 | auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>(); |
479 | auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices(); |
480 | indices.append(prevIndices.begin(), prevIndices.end()); |
481 | } |
482 | |
483 | /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds |
484 | /// accesses on the to-be-unpacked dimension. |
485 | /// |
486 | /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer, |
487 | /// using the loop iteration variable `iv`. |
488 | /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back |
489 | /// to memory. |
490 | /// |
491 | /// Note: For more details, see comments on Strategy<TransferReadOp>. |
492 | static TransferWriteOp rewriteOp(OpBuilder &b, |
493 | VectorTransferToSCFOptions options, |
494 | TransferWriteOp xferOp, Value buffer, |
495 | Value iv, ValueRange loopState) { |
496 | SmallVector<Value, 8> loadIndices; |
    getBufferIndices(xferOp, loadIndices);
498 | loadIndices.push_back(iv); |
499 | |
500 | SmallVector<Value, 8> xferIndices; |
501 | getXferIndices(b, xferOp, iv, xferIndices); |
502 | |
503 | Location loc = xferOp.getLoc(); |
504 | auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices); |
505 | auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); |
506 | auto source = loopState.empty() ? xferOp.getSource() : loopState[0]; |
507 | Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); |
508 | auto newXferOp = b.create<vector::TransferWriteOp>( |
509 | loc, type, vec, source, xferIndices, |
510 | AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), |
511 | inBoundsAttr); |
512 | |
513 | maybeApplyPassLabel(b, newXferOp, options.targetRank); |
514 | |
515 | return newXferOp; |
516 | } |
517 | |
518 | /// Handle out-of-bounds accesses on the to-be-unpacked dimension. |
519 | static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp, |
520 | Value buffer, Value iv, |
521 | ValueRange loopState) { |
522 | return isTensorOp(xferOp) ? loopState[0] : Value(); |
523 | } |
524 | |
525 | /// Cleanup after rewriting the op. |
526 | static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp, |
527 | scf::ForOp forOp) { |
528 | if (isTensorOp(xferOp)) { |
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
530 | rewriter.replaceOp(xferOp, forOp->getResult(0)); |
531 | } else { |
      rewriter.eraseOp(xferOp);
533 | } |
534 | } |
535 | |
536 | /// Return the initial loop state for the generated scf.for loop. |
537 | static Value initialLoopState(TransferWriteOp xferOp) { |
538 | return isTensorOp(xferOp) ? xferOp.getSource() : Value(); |
539 | } |
540 | }; |
541 | |
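/// Check whether the given transfer op is eligible for preparation: it must
/// not yet carry the pass label, its vector rank must exceed the target rank,
/// its leading vector dimension must not be scalable, tensor sources require
/// `lowerTensors` to be set, and the element types of vector and source must
/// match.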
542 | template <typename OpTy> |
543 | LogicalResult checkPrepareXferOp(OpTy xferOp, |
544 | VectorTransferToSCFOptions options) { |
545 | if (xferOp->hasAttr(kPassLabel)) |
546 | return failure(); |
547 | if (xferOp.getVectorType().getRank() <= options.targetRank) |
548 | return failure(); |
549 | // Currently the unpacking of the leading dimension into the memref is not |
550 | // supported for scalable dimensions. |
551 | if (xferOp.getVectorType().getScalableDims().front()) |
552 | return failure(); |
553 | if (isTensorOp(xferOp) && !options.lowerTensors) |
554 | return failure(); |
555 | // Transfer ops that modify the element type are not supported atm. |
556 | if (xferOp.getVectorType().getElementType() != |
557 | xferOp.getShapedType().getElementType()) |
558 | return failure(); |
559 | return success(); |
560 | } |
561 | |
562 | /// Prepare a TransferReadOp for progressive lowering. |
563 | /// |
564 | /// 1. Allocate a temporary buffer. |
565 | /// 2. Label the TransferReadOp, marking it eligible for progressive lowering. |
566 | /// 3. Store the result of the TransferReadOp into the temporary buffer. |
567 | /// 4. Load the result from the temporary buffer and replace all uses of the |
568 | /// original TransferReadOp with this load. |
569 | /// |
570 | /// E.g.: |
571 | /// ``` |
572 | /// %vec = vector.transfer_read %A[%a, %b, %c], %cst |
573 | /// : vector<5x4xf32>, memref<?x?x?xf32> |
574 | /// ``` |
575 | /// is rewritten to: |
576 | /// ``` |
577 | /// %0 = memref.alloca() : memref<vector<5x4xf32>> |
578 | /// %1 = vector.transfer_read %A[%a, %b, %c], %cst |
579 | /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32> |
580 | /// memref.store %1, %0[] : memref<vector<5x4xf32>> |
581 | /// %vec = memref.load %0[] : memref<vector<5x4xf32>> |
582 | /// ``` |
583 | /// |
584 | /// Note: A second temporary buffer may be allocated for the `mask` operand. |
585 | struct PrepareTransferReadConversion |
586 | : public VectorToSCFPattern<TransferReadOp> { |
587 | using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; |
588 | |
589 | LogicalResult matchAndRewrite(TransferReadOp xferOp, |
590 | PatternRewriter &rewriter) const override { |
591 | if (checkPrepareXferOp(xferOp, options).failed()) |
592 | return failure(); |
593 | |
594 | auto buffers = allocBuffers(rewriter, xferOp); |
595 | auto *newXfer = rewriter.clone(*xferOp.getOperation()); |
596 | newXfer->setAttr(kPassLabel, rewriter.getUnitAttr()); |
597 | if (xferOp.getMask()) { |
598 | dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign( |
599 | buffers.maskBuffer); |
600 | } |
601 | |
602 | Location loc = xferOp.getLoc(); |
603 | rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0), |
604 | buffers.dataBuffer); |
605 | rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer); |
606 | |
607 | return success(); |
608 | } |
609 | }; |
610 | |
611 | /// Prepare a TransferWriteOp for progressive lowering. |
612 | /// |
613 | /// 1. Allocate a temporary buffer. |
614 | /// 2. Store the vector into the buffer. |
615 | /// 3. Load the vector from the buffer again. |
616 | /// 4. Use the loaded vector as a TransferWriteOp operand and label the op, |
617 | /// marking it eligible for progressive lowering via TransferOpConversion. |
618 | /// |
619 | /// E.g.: |
620 | /// ``` |
621 | /// vector.transfer_write %vec, %A[%a, %b, %c] |
622 | /// : vector<5x4xf32>, memref<?x?x?xf32> |
623 | /// ``` |
624 | /// is rewritten to: |
625 | /// ``` |
626 | /// %0 = memref.alloca() : memref<vector<5x4xf32>> |
627 | /// memref.store %vec, %0[] : memref<vector<5x4xf32>> |
628 | /// %1 = memref.load %0[] : memref<vector<5x4xf32>> |
629 | /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ } |
630 | /// : vector<5x4xf32>, memref<?x?x?xf32> |
631 | /// ``` |
632 | /// |
633 | /// Note: A second temporary buffer may be allocated for the `mask` operand. |
634 | struct PrepareTransferWriteConversion |
635 | : public VectorToSCFPattern<TransferWriteOp> { |
636 | using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; |
637 | |
638 | LogicalResult matchAndRewrite(TransferWriteOp xferOp, |
639 | PatternRewriter &rewriter) const override { |
640 | if (checkPrepareXferOp(xferOp, options).failed()) |
641 | return failure(); |
642 | |
643 | Location loc = xferOp.getLoc(); |
644 | auto buffers = allocBuffers(rewriter, xferOp); |
645 | rewriter.create<memref::StoreOp>(loc, xferOp.getVector(), |
646 | buffers.dataBuffer); |
647 | auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer); |
648 | rewriter.modifyOpInPlace(xferOp, [&]() { |
649 | xferOp.getVectorMutable().assign(loadedVec); |
650 | xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); |
651 | }); |
652 | |
653 | if (xferOp.getMask()) { |
654 | rewriter.modifyOpInPlace(xferOp, [&]() { |
655 | xferOp.getMaskMutable().assign(buffers.maskBuffer); |
656 | }); |
657 | } |
658 | |
659 | return success(); |
660 | } |
661 | }; |
662 | |
/// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This
/// allows printing both 1D scalable vectors and n-D fixed size vectors.
665 | /// |
666 | /// E.g.: |
667 | /// ``` |
668 | /// vector.print %v : vector<[4]xi32> |
669 | /// ``` |
670 | /// is rewritten to: |
671 | /// ``` |
672 | /// %c0 = arith.constant 0 : index |
673 | /// %c4 = arith.constant 4 : index |
674 | /// %c1 = arith.constant 1 : index |
675 | /// %vscale = vector.vscale |
676 | /// %length = arith.muli %vscale, %c4 : index |
677 | /// %lastIndex = arith.subi %length, %c1 : index |
678 | /// vector.print punctuation <open> |
679 | /// scf.for %i = %c0 to %length step %c1 { |
680 | /// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> |
681 | /// vector.print %el : i32 punctuation <no_punctuation> |
682 | /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index |
683 | /// scf.if %notLastIndex { |
684 | /// vector.print punctuation <comma> |
685 | /// } |
686 | /// } |
687 | /// vector.print punctuation <close> |
688 | /// vector.print |
689 | /// ``` |
690 | struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { |
691 | using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern; |
692 | LogicalResult matchAndRewrite(vector::PrintOp printOp, |
693 | PatternRewriter &rewriter) const override { |
694 | if (!printOp.getSource()) |
695 | return failure(); |
696 | |
697 | VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType()); |
698 | if (!vectorType) |
699 | return failure(); |
700 | |
701 | // Currently >= 2D scalable vectors are not supported. |
702 | // These can't be lowered to LLVM (as LLVM does not support scalable vectors |
703 | // of scalable vectors), and due to limitations of current ops can't be |
704 | // indexed with SSA values or flattened. This may change after |
705 | // https://reviews.llvm.org/D155034, though there still needs to be a path |
706 | // for lowering to LLVM. |
707 | if (vectorType.getRank() > 1 && vectorType.isScalable()) |
708 | return failure(); |
709 | |
710 | auto loc = printOp.getLoc(); |
711 | auto value = printOp.getSource(); |
712 | |
713 | if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) { |
714 | // Oddly sized integers are (somewhat) buggy on a lot of backends, so to |
715 | // avoid issues extend them to a more standard size. |
716 | // https://github.com/llvm/llvm-project/issues/30613 |
717 | auto width = intTy.getWidth(); |
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
719 | auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth, |
720 | intTy.getSignedness()); |
721 | // arith can only take signless integers, so we must cast back and forth. |
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
726 | auto targetVectorType = vectorType.cloneWith({}, legalIntTy); |
727 | value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType, |
728 | value); |
729 | if (value.getType() != signlessTargetVectorType) { |
730 | if (width == 1 || intTy.isUnsigned()) |
731 | value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType, |
732 | value); |
733 | else |
734 | value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType, |
735 | value); |
736 | } |
737 | value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value); |
738 | vectorType = targetVectorType; |
739 | } |
740 | |
741 | auto scalableDimensions = vectorType.getScalableDims(); |
742 | auto shape = vectorType.getShape(); |
743 | constexpr int64_t singletonShape[] = {1}; |
744 | if (vectorType.getRank() == 0) |
745 | shape = singletonShape; |
746 | |
747 | if (vectorType.getRank() != 1) { |
748 | // Flatten n-D vectors to 1D. This is done to allow indexing with a |
749 | // non-constant value (which can currently only be done via |
750 | // vector.extractelement for 1D vectors). |
751 | auto flatLength = std::accumulate(shape.begin(), shape.end(), 1, |
752 | std::multiplies<int64_t>()); |
753 | auto flatVectorType = |
754 | VectorType::get({flatLength}, vectorType.getElementType()); |
755 | value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value); |
756 | } |
757 | |
758 | vector::PrintOp firstClose; |
759 | SmallVector<Value, 8> loopIndices; |
760 | for (unsigned d = 0; d < shape.size(); d++) { |
761 | // Setup loop bounds and step. |
762 | Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
763 | Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]); |
764 | Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
765 | if (!scalableDimensions.empty() && scalableDimensions[d]) { |
766 | auto vscale = rewriter.create<vector::VectorScaleOp>( |
767 | loc, rewriter.getIndexType()); |
768 | upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale); |
769 | } |
770 | auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step); |
771 | |
772 | // Create a loop to print the elements surrounded by parentheses. |
773 | rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open); |
774 | auto loop = |
775 | rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); |
776 | auto printClose = rewriter.create<vector::PrintOp>( |
777 | loc, vector::PrintPunctuation::Close); |
778 | if (!firstClose) |
779 | firstClose = printClose; |
780 | |
781 | auto loopIdx = loop.getInductionVar(); |
782 | loopIndices.push_back(loopIdx); |
783 | |
784 | // Print a comma after all but the last element. |
785 | rewriter.setInsertionPointToStart(loop.getBody()); |
786 | auto notLastIndex = rewriter.create<arith::CmpIOp>( |
787 | loc, arith::CmpIPredicate::ult, loopIdx, lastIndex); |
788 | rewriter.create<scf::IfOp>(loc, notLastIndex, |
789 | [&](OpBuilder &builder, Location loc) { |
790 | builder.create<vector::PrintOp>( |
791 | loc, vector::PrintPunctuation::Comma); |
792 | builder.create<scf::YieldOp>(loc); |
793 | }); |
794 | |
795 | rewriter.setInsertionPointToStart(loop.getBody()); |
796 | } |
797 | |
798 | // Compute the flattened index. |
    // Note: For vectors of rank > 1, this assumes non-scalable dimensions.
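    // E.g. (illustrative): for shape [5, 4, 3], the flattened index is
    // i0 * 12 + i1 * 3 + i2.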
800 | Value flatIndex; |
801 | auto currentStride = 1; |
802 | for (int d = shape.size() - 1; d >= 0; d--) { |
803 | auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride); |
804 | auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]); |
805 | if (flatIndex) |
806 | flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index); |
807 | else |
808 | flatIndex = index; |
809 | currentStride *= shape[d]; |
810 | } |
811 | |
    // Print the scalar elements in the innermost loop.
813 | auto element = |
814 | rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex); |
815 | rewriter.create<vector::PrintOp>(loc, element, |
816 | vector::PrintPunctuation::NoPunctuation); |
817 | |
818 | rewriter.setInsertionPointAfter(firstClose); |
819 | rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation()); |
    rewriter.eraseOp(printOp);
821 | return success(); |
822 | } |
823 | |
824 | static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) { |
825 | return IntegerType::get(intTy.getContext(), intTy.getWidth(), |
826 | IntegerType::Signless); |
827 | }; |
828 | }; |
829 | |
830 | /// Progressive lowering of vector transfer ops: Unpack one dimension. |
831 | /// |
832 | /// 1. Unpack one dimension from the current buffer type and cast the buffer |
833 | /// to that new type. E.g.: |
834 | /// ``` |
835 | /// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>> |
836 | /// vector.transfer_write %vec ... |
837 | /// ``` |
838 | /// The following cast is generated: |
839 | /// ``` |
840 | /// %casted = vector.type_cast %0 |
841 | /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> |
842 | /// ``` |
843 | /// 2. Generate a for loop and rewrite the transfer op according to the |
844 | /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be |
845 | /// out-of-bounds, generate an if-check and handle both cases separately. |
846 | /// 3. Clean up according to the corresponding Strategy<OpTy>. |
847 | /// |
848 | /// Note: If the transfer op is a TransferWriteOp and operates on a tensor |
849 | /// source (as opposed to a memref source), then each iteration of the generated |
850 | /// scf.for loop yields the new tensor value. E.g.: |
851 | /// ``` |
852 | /// %result = scf.for i = 0 to 5 { |
853 | /// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>> |
854 | /// %1 = vector.transfer_write %0, %source[...] |
855 | /// : vector<4x3xf32>, tensor<5x4x3xf32> |
856 | /// scf.yield %1 : tensor<5x4x3xf32> |
857 | /// } |
858 | /// ``` |
859 | template <typename OpTy> |
860 | struct TransferOpConversion : public VectorToSCFPattern<OpTy> { |
861 | using VectorToSCFPattern<OpTy>::VectorToSCFPattern; |
862 | |
863 | void initialize() { |
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
866 | this->setHasBoundedRewriteRecursion(); |
867 | } |
868 | |
869 | static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer, |
870 | SmallVectorImpl<Value> &loadIndices, |
871 | Value iv) { |
    assert(xferOp.getMask() && "Expected transfer op to have mask");
873 | |
874 | // Add load indices from the previous iteration. |
    // The mask buffer depends on the permutation map, which makes determining
    // the indices quite complex, which is why we need to "look back" to the
    // previous iteration to find the right indices.
878 | Value maskBuffer = getMaskBuffer(xferOp); |
879 | for (Operation *user : maskBuffer.getUsers()) { |
880 | // If there is no previous load op, then the indices are empty. |
881 | if (auto loadOp = dyn_cast<memref::LoadOp>(user)) { |
882 | Operation::operand_range prevIndices = loadOp.getIndices(); |
883 | loadIndices.append(prevIndices.begin(), prevIndices.end()); |
884 | break; |
885 | } |
886 | } |
887 | |
888 | // In case of broadcast: Use same indices to load from memref |
889 | // as before. |
890 | if (!xferOp.isBroadcastDim(0)) |
      loadIndices.push_back(iv);
892 | } |
893 | |
894 | LogicalResult matchAndRewrite(OpTy xferOp, |
895 | PatternRewriter &rewriter) const override { |
896 | if (!xferOp->hasAttr(kPassLabel)) |
897 | return failure(); |
898 | |
899 | // Find and cast data buffer. How the buffer can be found depends on OpTy. |
900 | ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter); |
901 | Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp); |
902 | auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType()); |
903 | FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType); |
904 | if (failed(castedDataType)) |
905 | return failure(); |
906 | |
907 | auto castedDataBuffer = |
908 | locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer); |
909 | |
910 | // If the xferOp has a mask: Find and cast mask buffer. |
911 | Value castedMaskBuffer; |
912 | if (xferOp.getMask()) { |
913 | Value maskBuffer = getMaskBuffer(xferOp); |
914 | if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { |
915 | // Do not unpack a dimension of the mask, if: |
916 | // * To-be-unpacked transfer op dimension is a broadcast. |
917 | // * Mask is 1D, i.e., the mask cannot be further unpacked. |
918 | // (That means that all remaining dimensions of the transfer op must |
919 | // be broadcasted.) |
920 | castedMaskBuffer = maskBuffer; |
921 | } else { |
922 | // It's safe to assume the mask buffer can be unpacked if the data |
923 | // buffer was unpacked. |
924 | auto maskBufferType = cast<MemRefType>(maskBuffer.getType()); |
925 | MemRefType castedMaskType = *unpackOneDim(maskBufferType); |
926 | castedMaskBuffer = |
927 | locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer); |
928 | } |
929 | } |
930 | |
931 | // Loop bounds and step. |
932 | auto lb = locB.create<arith::ConstantIndexOp>(0); |
933 | auto ub = locB.create<arith::ConstantIndexOp>( |
934 | castedDataType->getDimSize(castedDataType->getRank() - 1)); |
935 | auto step = locB.create<arith::ConstantIndexOp>(1); |
936 | // TransferWriteOps that operate on tensors return the modified tensor and |
937 | // require a loop state. |
938 | auto loopState = Strategy<OpTy>::initialLoopState(xferOp); |
939 | |
940 | // Generate for loop. |
941 | auto result = locB.create<scf::ForOp>( |
942 | lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), |
943 | [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { |
944 | Type stateType = loopState.empty() ? Type() : loopState[0].getType(); |
945 | |
946 | auto result = generateInBoundsCheck( |
947 | b, xferOp, iv, unpackedDim(xferOp), |
948 | stateType ? TypeRange(stateType) : TypeRange(), |
949 | /*inBoundsCase=*/ |
950 | [&](OpBuilder &b, Location loc) { |
951 | // Create new transfer op. |
952 | OpTy newXfer = Strategy<OpTy>::rewriteOp( |
953 | b, this->options, xferOp, castedDataBuffer, iv, loopState); |
954 | |
955 | // If old transfer op has a mask: Set mask on new transfer op. |
956 | // Special case: If the mask of the old transfer op is 1D and |
957 | // the unpacked dim is not a broadcast, no mask is needed on |
958 | // the new transfer op. |
959 | if (xferOp.getMask() && (xferOp.isBroadcastDim(0) || |
960 | xferOp.getMaskType().getRank() > 1)) { |
961 | OpBuilder::InsertionGuard guard(b); |
962 | b.setInsertionPoint(newXfer); // Insert load before newXfer. |
963 | |
964 | SmallVector<Value, 8> loadIndices; |
965 | getMaskBufferLoadIndices(xferOp, castedMaskBuffer, |
966 | loadIndices, iv); |
967 | auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer, |
968 | loadIndices); |
969 | rewriter.modifyOpInPlace(newXfer, [&]() { |
970 | newXfer.getMaskMutable().assign(mask); |
971 | }); |
972 | } |
973 | |
974 | return loopState.empty() ? Value() : newXfer->getResult(0); |
975 | }, |
976 | /*outOfBoundsCase=*/ |
977 | [&](OpBuilder &b, Location /*loc*/) { |
978 | return Strategy<OpTy>::handleOutOfBoundsDim( |
979 | b, xferOp, castedDataBuffer, iv, loopState); |
980 | }); |
981 | |
982 | maybeYieldValue(b, loc, !loopState.empty(), result); |
983 | }); |
984 | |
985 | Strategy<OpTy>::cleanup(rewriter, xferOp, result); |
986 | return success(); |
987 | } |
988 | }; |
989 | |
990 | } // namespace lowering_n_d |
991 | |
992 | namespace lowering_n_d_unrolled { |
993 | |
994 | /// If the original transfer op has a mask, compute the mask of the new transfer |
995 | /// op (for the current iteration `i`) and assign it. |
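/// E.g. (sketch, assuming a mask of type vector<5x4xi1> and i = 2), the mask
/// for the new transfer op is obtained via:
/// ```
/// %new_mask = vector.extract %mask[2] : vector<4xi1> from vector<5x4xi1>
/// ```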
996 | template <typename OpTy> |
997 | static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, |
998 | int64_t i) { |
999 | if (!xferOp.getMask()) |
1000 | return; |
1001 | |
1002 | if (xferOp.isBroadcastDim(0)) { |
1003 | // To-be-unpacked dimension is a broadcast, which does not have a |
1004 | // corresponding mask dimension. Mask attribute remains unchanged. |
1005 | newXferOp.getMaskMutable().assign(xferOp.getMask()); |
1006 | return; |
1007 | } |
1008 | |
1009 | if (xferOp.getMaskType().getRank() > 1) { |
1010 | // Unpack one dimension of the mask. |
1011 | OpBuilder::InsertionGuard guard(b); |
1012 | b.setInsertionPoint(newXferOp); // Insert load before newXfer. |
1013 | |
1014 | llvm::SmallVector<int64_t, 1> indices({i}); |
1015 | Location loc = xferOp.getLoc(); |
1016 | auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices); |
1017 | newXferOp.getMaskMutable().assign(newMask); |
1018 | } |
1019 | |
1020 | // If we end up here: The mask of the old transfer op is 1D and the unpacked |
1021 | // dim is not a broadcast, so no mask is needed on the new transfer op. |
1022 | // `generateInBoundsCheck` will have evaluated the mask already. |
1023 | } |
1024 | |
1025 | /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one |
1026 | /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no |
1027 | /// memref buffer is allocated and the SCF loop is fully unrolled. |
1028 | /// |
1029 | /// ``` |
1030 | /// E.g.: |
1031 | /// ``` |
1032 | /// %vec = vector.transfer_read %A[%a, %b, %c], %padding |
1033 | /// : memref<?x?x?xf32>, vector<5x4xf32> |
1034 | /// ``` |
1035 | /// is rewritten to IR such as (simplified): |
1036 | /// ``` |
1037 | /// %v_init = splat %padding : vector<5x4xf32> |
1038 | /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding |
1039 | /// : memref<?x?x?xf32>, vector<4xf32> |
1040 | /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32> |
1041 | /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding |
1042 | /// : memref<?x?x?xf32>, vector<4xf32> |
1043 | /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32> |
1044 | /// ... |
1045 | /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding |
1046 | /// : memref<?x?x?xf32>, vector<4xf32> |
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
1048 | /// ``` |
1049 | /// |
1050 | /// Note: As an optimization, if the result of the original TransferReadOp |
1051 | /// was directly inserted into another vector, no new %v_init vector is created. |
1052 | /// Instead, the new TransferReadOp results are inserted into that vector. |
1053 | struct UnrollTransferReadConversion |
1054 | : public VectorToSCFPattern<TransferReadOp> { |
1055 | using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; |
1056 | |
1057 | void initialize() { |
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
1060 | setHasBoundedRewriteRecursion(); |
1061 | } |
1062 | |
1063 | /// Get or build the vector into which the newly created TransferReadOp |
1064 | /// results are inserted. |
1065 | Value buildResultVector(PatternRewriter &rewriter, |
1066 | TransferReadOp xferOp) const { |
1067 | if (auto insertOp = getInsertOp(xferOp)) |
1068 | return insertOp.getDest(); |
1069 | Location loc = xferOp.getLoc(); |
1070 | return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(), |
1071 | xferOp.getPadding()); |
1072 | } |
1073 | |
1074 | /// If the result of the TransferReadOp has exactly one user, which is a |
1075 | /// vector::InsertOp, return that operation. |
1076 | vector::InsertOp getInsertOp(TransferReadOp xferOp) const { |
1077 | if (xferOp->hasOneUse()) { |
1078 | Operation *xferOpUser = *xferOp->getUsers().begin(); |
1079 | if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser)) |
1080 | return insertOp; |
1081 | } |
1082 | |
1083 | return vector::InsertOp(); |
1084 | } |
1085 | |
1086 | /// If the result of the TransferReadOp has exactly one user, which is a |
1087 | /// vector::InsertOp, return that operation's indices. |
1088 | void getInsertionIndices(TransferReadOp xferOp, |
1089 | SmallVectorImpl<OpFoldResult> &indices) const { |
1090 | if (auto insertOp = getInsertOp(xferOp)) { |
1091 | auto pos = insertOp.getMixedPosition(); |
1092 | indices.append(pos.begin(), pos.end()); |
1093 | } |
1094 | } |
1095 | |
1096 | /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds |
1097 | /// accesses, and broadcasts and transposes in permutation maps. |
1098 | LogicalResult matchAndRewrite(TransferReadOp xferOp, |
1099 | PatternRewriter &rewriter) const override { |
1100 | if (xferOp.getVectorType().getRank() <= options.targetRank) |
1101 | return rewriter.notifyMatchFailure( |
1102 | xferOp, "vector rank is less or equal to target rank" ); |
1103 | if (isTensorOp(xferOp) && !options.lowerTensors) |
1104 | return rewriter.notifyMatchFailure( |
1105 | xferOp, "transfers operating on tensors are excluded" ); |
1106 | // Transfer ops that modify the element type are not supported atm. |
1107 | if (xferOp.getVectorType().getElementType() != |
1108 | xferOp.getShapedType().getElementType()) |
1109 | return rewriter.notifyMatchFailure( |
1110 | xferOp, "not yet supported: element type mismatch" ); |
1111 | auto xferVecType = xferOp.getVectorType(); |
1112 | if (xferVecType.getScalableDims()[0]) { |
1113 | // Cannot unroll a scalable dimension at compile time. |
1114 | return rewriter.notifyMatchFailure( |
1115 | xferOp, "scalable dimensions cannot be unrolled" ); |
1116 | } |
1117 | |
1118 | auto insertOp = getInsertOp(xferOp); |
    auto vec = buildResultVector(rewriter, xferOp);
1120 | auto vecType = dyn_cast<VectorType>(vec.getType()); |
1121 | |
1122 | VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0); |
1123 | |
1124 | int64_t dimSize = xferVecType.getShape()[0]; |
1125 | |
1126 | // Generate fully unrolled loop of transfer ops. |
1127 | Location loc = xferOp.getLoc(); |
1128 | for (int64_t i = 0; i < dimSize; ++i) { |
1129 | Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); |
1130 | |
1131 | vec = generateInBoundsCheck( |
1132 | rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), |
1133 | /*inBoundsCase=*/ |
1134 | [&](OpBuilder &b, Location loc) { |
1135 | // Indices for the new transfer op. |
1136 | SmallVector<Value, 8> xferIndices; |
1137 | getXferIndices(b, xferOp, iv, xferIndices); |
1138 | |
1139 | // Indices for the new vector.insert op. |
1140 | SmallVector<OpFoldResult, 8> insertionIndices; |
            getInsertionIndices(xferOp, insertionIndices);
1142 | insertionIndices.push_back(rewriter.getIndexAttr(i)); |
1143 | |
1144 | auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); |
1145 | auto newXferOp = b.create<vector::TransferReadOp>( |
1146 | loc, newXferVecType, xferOp.getSource(), xferIndices, |
1147 | AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), |
1148 | xferOp.getPadding(), Value(), inBoundsAttr); |
1149 | maybeAssignMask(b, xferOp, newXferOp, i); |
1150 | return b.create<vector::InsertOp>(loc, newXferOp, vec, |
1151 | insertionIndices); |
1152 | }, |
1153 | /*outOfBoundsCase=*/ |
1154 | [&](OpBuilder &b, Location loc) { |
1155 | // Loop through original (unmodified) vector. |
1156 | return vec; |
1157 | }); |
1158 | } |
1159 | |
1160 | if (insertOp) { |
1161 | // Rewrite single user of the old TransferReadOp, which was an InsertOp. |
1162 | rewriter.replaceOp(insertOp, vec); |
      rewriter.eraseOp(xferOp);
1164 | } else { |
1165 | rewriter.replaceOp(xferOp, vec); |
1166 | } |
1167 | |
1168 | return success(); |
1169 | } |
1170 | }; |
1171 | |
1172 | /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one |
1173 | /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no |
1174 | /// memref buffer is allocated and the SCF loop is fully unrolled. |
1175 | /// |
1176 | /// ``` |
1177 | /// E.g.: |
1178 | /// ``` |
1179 | /// vector.transfer_write %vec, %A[%a, %b, %c] |
1180 | /// : vector<5x4xf32>, memref<?x?x?xf32> |
1181 | /// ``` |
1182 | /// is rewritten to IR such as (simplified): |
1183 | /// ``` |
1184 | /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32> |
1185 | /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...> |
1186 | /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32> |
1187 | /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...> |
1188 | /// ... |
1189 | /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32> |
1190 | /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...> |
1191 | /// ``` |
1192 | /// |
1193 | /// Note: As an optimization, if the vector of the original TransferWriteOp |
1194 | /// was directly extracted from another vector via an ExtractOp `a`, extract |
1195 | /// the vectors for the newly generated TransferWriteOps from `a`'s input. By |
1196 | /// doing so, `a` may become dead, and the number of ExtractOps generated during |
1197 | /// recursive application of this pattern will be minimal. |
1198 | struct UnrollTransferWriteConversion |
1199 | : public VectorToSCFPattern<TransferWriteOp> { |
1200 | using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; |
1201 | |
1202 | void initialize() { |
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
1205 | setHasBoundedRewriteRecursion(); |
1206 | } |
1207 | |
  /// Return the vector from which newly generated ExtractOps will extract.
1209 | Value getDataVector(TransferWriteOp xferOp) const { |
1210 | if (auto extractOp = getExtractOp(xferOp)) |
1211 | return extractOp.getVector(); |
1212 | return xferOp.getVector(); |
1213 | } |
1214 | |
1215 | /// If the input of the given TransferWriteOp is an ExtractOp, return it. |
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1217 | if (auto *op = xferOp.getVector().getDefiningOp()) |
1218 | return dyn_cast<vector::ExtractOp>(op); |
1219 | return vector::ExtractOp(); |
1220 | } |
1221 | |
1222 | /// If the input of the given TransferWriteOp is an ExtractOp, return its |
1223 | /// indices. |
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
1226 | if (auto extractOp = getExtractOp(xferOp)) { |
1227 | auto pos = extractOp.getMixedPosition(); |
1228 | indices.append(pos.begin(), pos.end()); |
1229 | } |
1230 | } |
1231 | |
1232 | /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds |
1233 | /// accesses, and broadcasts and transposes in permutation maps. |
1234 | LogicalResult matchAndRewrite(TransferWriteOp xferOp, |
1235 | PatternRewriter &rewriter) const override { |
1236 | VectorType inputVectorTy = xferOp.getVectorType(); |
1237 | |
1238 | if (inputVectorTy.getRank() <= options.targetRank) |
1239 | return failure(); |
1240 | |
1241 | if (isTensorOp(xferOp) && !options.lowerTensors) |
1242 | return failure(); |
1243 | // Transfer ops that modify the element type are not supported atm. |
1244 | if (inputVectorTy.getElementType() != |
1245 | xferOp.getShapedType().getElementType()) |
1246 | return failure(); |
1247 | |
    auto vec = getDataVector(xferOp);
1249 | if (inputVectorTy.getScalableDims()[0]) { |
1250 | // Cannot unroll a scalable dimension at compile time. |
1251 | return failure(); |
1252 | } |
1253 | |
1254 | int64_t dimSize = inputVectorTy.getShape()[0]; |
1255 | Value source = xferOp.getSource(); // memref or tensor to be written to. |
1256 | auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); |
1257 | |
1258 | // Generate fully unrolled loop of transfer ops. |
1259 | Location loc = xferOp.getLoc(); |
1260 | for (int64_t i = 0; i < dimSize; ++i) { |
1261 | Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); |
1262 | |
1263 | auto updatedSource = generateInBoundsCheck( |
1264 | rewriter, xferOp, iv, unpackedDim(xferOp), |
1265 | isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(), |
1266 | /*inBoundsCase=*/ |
1267 | [&](OpBuilder &b, Location loc) { |
1268 | // Indices for the new transfer op. |
1269 | SmallVector<Value, 8> xferIndices; |
1270 | getXferIndices(b, xferOp, iv, xferIndices); |
1271 | |
1272 | // Indices for the new vector.extract op. |
            SmallVector<OpFoldResult, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(b.getI64IntegerAttr(i));
1276 | |
            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1279 | auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); |
1280 | Value xferVec; |
1281 | if (inputVectorTy.getRank() == 1) { |
              // When target-rank=0, unrolling would cause the vector input
              // argument to `transfer_write` to become a scalar. We solve
1284 | // this by broadcasting the scalar to a 0D vector. |
1285 | xferVec = b.create<vector::BroadcastOp>( |
1286 | loc, VectorType::get({}, extracted.getType()), extracted); |
1287 | } else { |
1288 | xferVec = extracted; |
1289 | } |
1290 | auto newXferOp = b.create<vector::TransferWriteOp>( |
1291 | loc, sourceType, xferVec, source, xferIndices, |
1292 | AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), |
1293 | inBoundsAttr); |
1294 | |
1295 | maybeAssignMask(b, xferOp, newXferOp, i); |
1296 | |
1297 | return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value(); |
1298 | }, |
1299 | /*outOfBoundsCase=*/ |
1300 | [&](OpBuilder &b, Location loc) { |
1301 | return isTensorOp(xferOp) ? source : Value(); |
1302 | }); |
1303 | |
1304 | if (isTensorOp(xferOp)) |
1305 | source = updatedSource; |
1306 | } |
1307 | |
1308 | if (isTensorOp(xferOp)) |
1309 | rewriter.replaceOp(xferOp, source); |
1310 | else |
      rewriter.eraseOp(xferOp);
1312 | |
1313 | return success(); |
1314 | } |
1315 | }; |
1316 | |
1317 | } // namespace lowering_n_d_unrolled |
1318 | |
1319 | namespace lowering_1_d { |
1320 | |
1321 | /// Compute the indices into the memref for the LoadOp/StoreOp generated as |
1322 | /// part of TransferOp1dConversion. Return the memref dimension on which |
1323 | /// the transfer is operating. A return value of std::nullopt indicates a |
1324 | /// broadcast. |
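/// E.g. (an illustrative sketch, assuming indices [%a, %b] and permutation
/// map (d0, d1) -> (d0)):
/// ```
/// %idx0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%a, %iv)
/// // memrefIndices = [%idx0, %b], returned dim = 0
/// ```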
1325 | template <typename OpTy> |
1326 | static std::optional<int64_t> |
1327 | get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, |
1328 | SmallVector<Value, 8> &memrefIndices) { |
1329 | auto indices = xferOp.getIndices(); |
1330 | auto map = xferOp.getPermutationMap(); |
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1332 | |
1333 | memrefIndices.append(indices.begin(), indices.end()); |
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
1336 | if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { |
1337 | Location loc = xferOp.getLoc(); |
1338 | auto dim = expr.getPosition(); |
1339 | AffineExpr d0, d1; |
1340 | bindDims(xferOp.getContext(), d0, d1); |
1341 | Value offset = memrefIndices[dim]; |
1342 | memrefIndices[dim] = |
1343 | affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); |
1344 | return dim; |
1345 | } |
1346 | |
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
1349 | return std::nullopt; |
1350 | } |
1351 | |
1352 | /// Codegen strategy for TransferOp1dConversion, depending on the |
1353 | /// operation. |
1354 | template <typename OpTy> |
1355 | struct Strategy1d; |
1356 | |
1357 | /// Codegen strategy for TransferReadOp. |
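/// Per loop iteration this emits roughly the following (a sketch; the scf.if
/// guard comes from generateInBoundsCheck and is elided for accesses that are
/// statically in-bounds):
/// ```
/// %next = scf.if %in_bounds -> (vector<9xf32>) {
///   %e = memref.load %A[...] : memref<?x?xf32>
///   %ins = vector.insertelement %e, %vec[%iv : index] : vector<9xf32>
///   scf.yield %ins : vector<9xf32>
/// } else {
///   scf.yield %vec : vector<9xf32>  // keep the padding value
/// }
/// scf.yield %next : vector<9xf32>
/// ```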
1358 | template <> |
1359 | struct Strategy1d<TransferReadOp> { |
1360 | static void generateForLoopBody(OpBuilder &b, Location loc, |
1361 | TransferReadOp xferOp, Value iv, |
1362 | ValueRange loopState) { |
1363 | SmallVector<Value, 8> indices; |
1364 | auto dim = get1dMemrefIndices(b, xferOp, iv, indices); |
1365 | auto vec = loopState[0]; |
1366 | |
1367 | // In case of out-of-bounds access, leave `vec` as is (was initialized with |
1368 | // padding value). |
1369 | auto nextVec = generateInBoundsCheck( |
1370 | b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()), |
1371 | /*inBoundsCase=*/ |
1372 | [&](OpBuilder &b, Location loc) { |
1373 | Value val = |
1374 | b.create<memref::LoadOp>(loc, xferOp.getSource(), indices); |
1375 | return b.create<vector::InsertElementOp>(loc, val, vec, iv); |
1376 | }, |
1377 | /*outOfBoundsCase=*/ |
1378 | [&](OpBuilder & /*b*/, Location loc) { return vec; }); |
1379 | b.create<scf::YieldOp>(loc, nextVec); |
1380 | } |
1381 | |
1382 | static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) { |
    // Initialize vector with padding value.
1384 | Location loc = xferOp.getLoc(); |
1385 | return b.create<vector::SplatOp>(loc, xferOp.getVectorType(), |
1386 | xferOp.getPadding()); |
1387 | } |
1388 | }; |
1389 | |
1390 | /// Codegen strategy for TransferWriteOp. |
1391 | template <> |
1392 | struct Strategy1d<TransferWriteOp> { |
1393 | static void generateForLoopBody(OpBuilder &b, Location loc, |
1394 | TransferWriteOp xferOp, Value iv, |
1395 | ValueRange /*loopState*/) { |
1396 | SmallVector<Value, 8> indices; |
1397 | auto dim = get1dMemrefIndices(b, xferOp, iv, indices); |
1398 | |
1399 | // Nothing to do in case of out-of-bounds access. |
1400 | generateInBoundsCheck( |
1401 | b, xferOp, iv, dim, |
1402 | /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { |
1403 | auto val = |
1404 | b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv); |
1405 | b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices); |
1406 | }); |
1407 | b.create<scf::YieldOp>(loc); |
1408 | } |
1409 | |
1410 | static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) { |
1411 | return Value(); |
1412 | } |
1413 | }; |
1414 | |
1415 | /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is |
1416 | /// necessary in cases where a 1D vector transfer op cannot be lowered into |
1417 | /// vector load/stores due to non-unit strides or broadcasts: |
1418 | /// |
1419 | /// * Transfer dimension is not the last memref dimension |
1420 | /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast) |
1421 | /// * Memref has a layout map with non-unit stride on the last dimension |
1422 | /// |
1423 | /// This pattern generates IR as follows: |
1424 | /// |
1425 | /// 1. Generate a for loop iterating over each vector element. |
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1427 | /// depending on OpTy. |
1428 | /// |
1429 | /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp |
1430 | /// can be generated instead of TransferOp1dConversion. Add such a pattern |
1431 | /// to ConvertVectorToLLVM. |
1432 | /// |
1433 | /// E.g.: |
1434 | /// ``` |
1435 | /// vector.transfer_write %vec, %A[%a, %b] |
1436 | /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} |
1437 | /// : vector<9xf32>, memref<?x?xf32> |
1438 | /// ``` |
1439 | /// Is rewritten to approximately the following pseudo-IR: |
1440 | /// ``` |
1441 | /// for i = 0 to 9 { |
1442 | /// %t = vector.extractelement %vec[i] : vector<9xf32> |
1443 | /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> |
1444 | /// } |
1445 | /// ``` |
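/// A 1D transfer_read with an analogous permutation map is rewritten to
/// approximately the following pseudo-IR (an illustrative sketch; the result
/// vector is threaded through the loop as an iter_args value initialized with
/// the padding value):
/// ```
/// %init = vector.splat %padding : vector<9xf32>
/// %v = for i = 0 to 9 iter_args(%acc = %init) {
///   %t = memref.load %arg0[%a + i, %b] : memref<?x?xf32>
///   %acc_next = vector.insertelement %t, %acc[i] : vector<9xf32>
///   yield %acc_next
/// }
/// ```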
1446 | template <typename OpTy> |
1447 | struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> { |
1448 | using VectorToSCFPattern<OpTy>::VectorToSCFPattern; |
1449 | |
1450 | LogicalResult matchAndRewrite(OpTy xferOp, |
1451 | PatternRewriter &rewriter) const override { |
1452 | // TODO: support 0-d corner case. |
1453 | if (xferOp.getTransferRank() == 0) |
1454 | return failure(); |
1455 | auto map = xferOp.getPermutationMap(); |
1456 | auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType()); |
1457 | |
1458 | if (!memRefType) |
1459 | return failure(); |
1460 | if (xferOp.getVectorType().getRank() != 1) |
1461 | return failure(); |
1462 | if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) |
1463 | return failure(); // Handled by ConvertVectorToLLVM |
1464 | |
1465 | // Loop bounds, step, state... |
1466 | Location loc = xferOp.getLoc(); |
1467 | auto vecType = xferOp.getVectorType(); |
1468 | auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
1469 | Value ub = |
1470 | rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0)); |
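    // For scalable vectors the trip count is dynamic, e.g. for vector<[4]xf32>
    // the upper bound below becomes 4 * vector.vscale.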
1471 | if (vecType.isScalable()) { |
1472 | Value vscale = |
1473 | rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); |
1474 | ub = rewriter.create<arith::MulIOp>(loc, ub, vscale); |
1475 | } |
1476 | auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
1477 | auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp); |
1478 | |
1479 | // Generate for loop. |
1480 | rewriter.replaceOpWithNewOp<scf::ForOp>( |
1481 | xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), |
1482 | [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { |
1483 | Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState); |
1484 | }); |
1485 | |
1486 | return success(); |
1487 | } |
1488 | }; |
1489 | |
1490 | } // namespace lowering_1_d |
1491 | } // namespace |
1492 | |
1493 | void mlir::populateVectorToSCFConversionPatterns( |
1494 | RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) { |
1495 | if (options.unroll) { |
1496 | patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion, |
1497 | lowering_n_d_unrolled::UnrollTransferWriteConversion>( |
        patterns.getContext(), options);
1499 | } else { |
1500 | patterns.add<lowering_n_d::PrepareTransferReadConversion, |
1501 | lowering_n_d::PrepareTransferWriteConversion, |
1502 | lowering_n_d::TransferOpConversion<TransferReadOp>, |
1503 | lowering_n_d::TransferOpConversion<TransferWriteOp>>( |
        patterns.getContext(), options);
1505 | } |
1506 | |
1507 | if (options.targetRank == 1) { |
1508 | patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>, |
1509 | lowering_1_d::TransferOp1dConversion<TransferWriteOp>>( |
        patterns.getContext(), options);
1511 | } |
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
1514 | } |
1515 | |
1516 | namespace { |
1517 | |
1518 | struct ConvertVectorToSCFPass |
1519 | : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> { |
1520 | ConvertVectorToSCFPass() = default; |
1521 | ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { |
1522 | this->fullUnroll = options.unroll; |
1523 | this->targetRank = options.targetRank; |
1524 | this->lowerTensors = options.lowerTensors; |
1525 | } |
1526 | |
1527 | void runOnOperation() override { |
1528 | VectorTransferToSCFOptions options; |
1529 | options.unroll = fullUnroll; |
1530 | options.targetRank = targetRank; |
1531 | options.lowerTensors = lowerTensors; |
1532 | |
1533 | // Lower permutation maps first. |
1534 | RewritePatternSet lowerTransferPatterns(&getContext()); |
1535 | mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( |
        lowerTransferPatterns);
1537 | (void)applyPatternsAndFoldGreedily(getOperation(), |
1538 | std::move(lowerTransferPatterns)); |
1539 | |
1540 | RewritePatternSet patterns(&getContext()); |
1541 | populateVectorToSCFConversionPatterns(patterns, options); |
1542 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
1543 | } |
1544 | }; |
1545 | |
1546 | } // namespace |
1547 | |
1548 | std::unique_ptr<Pass> |
1549 | mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { |
  return std::make_unique<ConvertVectorToSCFPass>(options);
1551 | } |
1552 | |