1//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <optional>
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/GPU/IR/GPUDialect.h"
16#include "mlir/Dialect/Linalg/Passes.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/SCF/Transforms/Patterns.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
25#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
26#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Pass/PassManager.h"
29#include "mlir/Support/LLVM.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31
32using namespace mlir;
33using namespace mlir::linalg;
34using namespace mlir::vector;
35
36namespace {
37
38struct TestVectorToVectorLowering
39 : public PassWrapper<TestVectorToVectorLowering,
40 OperationPass<func::FuncOp>> {
41 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
42
43 TestVectorToVectorLowering() = default;
44 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
45 : PassWrapper(pass) {}
46 StringRef getArgument() const final {
47 return "test-vector-to-vector-lowering";
48 }
49 StringRef getDescription() const final {
50 return "Test lowering patterns between ops in the vector dialect";
51 }
52
53 void getDependentDialects(DialectRegistry &registry) const override {
54 registry.insert<affine::AffineDialect>();
55 registry.insert<vector::VectorDialect>();
56 }
57
58 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
59 llvm::cl::init(Val: false)};
60
61 void runOnOperation() override {
62 auto *ctx = &getContext();
63 RewritePatternSet patterns(ctx);
64 if (unroll) {
65 populateVectorUnrollPatterns(
66 patterns,
67 options: UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
68 filter));
69 }
70 populateVectorToVectorCanonicalizationPatterns(patterns);
71 populateBubbleVectorBitCastOpPatterns(patterns);
72 populateCastAwayVectorLeadingOneDimPatterns(patterns);
73 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
74 }
75
76private:
77 // Return the target shape based on op type.
78 static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
79 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(Val: op))
80 return SmallVector<int64_t>(2, 2);
81 if (isa<vector::ContractionOp>(Val: op))
82 return SmallVector<int64_t>(3, 2);
83 // For transfer ops, just propagate the shape coming from
84 // InsertStridedSlices/ExtractStridedSlices.
85 if (auto readOp = dyn_cast<vector::TransferReadOp>(Val: op)) {
86 VectorType dstVec;
87 for (Operation *users : readOp->getUsers()) {
88 auto extract = dyn_cast<ExtractStridedSliceOp>(Val: users);
89 if (!extract)
90 return std::nullopt;
91 auto vecType = cast<VectorType>(Val: extract.getResult().getType());
92 if (dstVec && dstVec != vecType)
93 return std::nullopt;
94 dstVec = vecType;
95 }
96 return SmallVector<int64_t>(dstVec.getShape());
97 }
98 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(Val: op)) {
99 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
100 if (!insert)
101 return std::nullopt;
102 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
103 return SmallVector<int64_t>(shape);
104 }
105 return std::nullopt;
106 }
107
108 static LogicalResult filter(Operation *op) {
109 return success(IsSuccess: isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
110 ContractionOp, TransferReadOp, TransferWriteOp>(Val: op));
111 }
112};
113
114struct TestVectorContractionPrepareForMMTLowering
115 : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
116 OperationPass<func::FuncOp>> {
117 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
118 TestVectorContractionPrepareForMMTLowering)
119
120 StringRef getArgument() const final {
121 return "test-vector-contraction-prepare-for-mmt-lowering";
122 }
123 StringRef getDescription() const final {
124 return "Test vector.contraction matmul canonicalization for MMT lowering.";
125 }
126 TestVectorContractionPrepareForMMTLowering() = default;
127
128 void getDependentDialects(DialectRegistry &registry) const override {
129 registry.insert<affine::AffineDialect, arith::ArithDialect,
130 vector::VectorDialect>();
131 }
132
133 void runOnOperation() override {
134 MLIRContext *ctx = &getContext();
135 RewritePatternSet patterns(ctx);
136 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
137 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
138 }
139};
140
141struct TestVectorUnrollingPatterns
142 : public PassWrapper<TestVectorUnrollingPatterns,
143 OperationPass<func::FuncOp>> {
144 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
145
146 StringRef getArgument() const final {
147 return "test-vector-unrolling-patterns";
148 }
149 StringRef getDescription() const final {
150 return "Test lowering patterns to unroll contract ops in the vector "
151 "dialect";
152 }
153 TestVectorUnrollingPatterns() = default;
154 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
155 : PassWrapper(pass) {}
156 void runOnOperation() override {
157 MLIRContext *ctx = &getContext();
158 RewritePatternSet patterns(ctx);
159 populateVectorUnrollPatterns(
160 patterns,
161 options: UnrollVectorOptions()
162 .setNativeShape(ArrayRef<int64_t>{2, 2})
163 .setFilterConstraint([](Operation *op) {
164 return success(
165 IsSuccess: isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
166 vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(
167 Val: op));
168 }));
169 populateVectorUnrollPatterns(
170 patterns, options: UnrollVectorOptions()
171 .setNativeShape(ArrayRef<int64_t>{2})
172 .setFilterConstraint([](Operation *op) {
173 return success(IsSuccess: isa<vector::ReductionOp>(Val: op));
174 }));
175 populateVectorUnrollPatterns(
176 patterns, options: UnrollVectorOptions()
177 .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
178 .setFilterConstraint([](Operation *op) {
179 return success(IsSuccess: isa<vector::TransposeOp>(Val: op));
180 }));
181
182 if (unrollBasedOnType) {
183 UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
184 [](Operation *op) -> std::optional<SmallVector<int64_t>> {
185 vector::ContractionOp contractOp = cast<vector::ContractionOp>(Val: op);
186 SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
187 4);
188 Type lhsType = contractOp.getLhsType().getElementType();
189 nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
190 return nativeShape;
191 };
192
193 UnrollVectorOptions opts;
194 opts.setNativeShapeFn(nativeShapeFn)
195 .setFilterConstraint(
196 [](Operation *op) { return success(IsSuccess: isa<ContractionOp>(Val: op)); });
197
198 if (!unrollOrder.empty()) {
199 opts.setUnrollTraversalOrderFn(
200 [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
201 vector::ContractionOp contractOp =
202 cast<vector::ContractionOp>(Val: op);
203 if (contractOp.getIteratorTypes().size() == unrollOrder.size())
204 return SmallVector<int64_t>(unrollOrder.begin(),
205 unrollOrder.end());
206 return std::nullopt;
207 });
208 }
209 populateVectorUnrollPatterns(patterns, options: opts);
210 } else {
211 auto nativeShapeFn =
212 [](Operation *op) -> std::optional<SmallVector<int64_t>> {
213 auto contractOp = dyn_cast<ContractionOp>(Val: op);
214 if (!contractOp)
215 return std::nullopt;
216 return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
217 };
218 populateVectorUnrollPatterns(patterns,
219 options: UnrollVectorOptions()
220 .setNativeShapeFn(nativeShapeFn)
221 .setFilterConstraint([](Operation *op) {
222 return success(IsSuccess: isa<ContractionOp>(Val: op));
223 }));
224 }
225 populateVectorToVectorCanonicalizationPatterns(patterns);
226 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
227 }
228
229 ListOption<int64_t> unrollOrder{*this, "unroll-order",
230 llvm::cl::desc("set the unroll order")};
231
232 Option<bool> unrollBasedOnType{
233 *this, "unroll-based-on-type",
234 llvm::cl::desc("Set the unroll factor based on type of the operation"),
235 llvm::cl::init(Val: false)};
236};
237
238struct TestVectorTransferUnrollingPatterns
239 : public PassWrapper<TestVectorTransferUnrollingPatterns,
240 OperationPass<func::FuncOp>> {
241 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
242 TestVectorTransferUnrollingPatterns)
243
244 TestVectorTransferUnrollingPatterns() = default;
245 TestVectorTransferUnrollingPatterns(
246 const TestVectorTransferUnrollingPatterns &pass)
247 : PassWrapper(pass) {}
248
249 void getDependentDialects(DialectRegistry &registry) const override {
250 registry.insert<affine::AffineDialect>();
251 }
252 StringRef getArgument() const final {
253 return "test-vector-transfer-unrolling-patterns";
254 }
255 StringRef getDescription() const final {
256 return "Test lowering patterns to unroll transfer ops in the vector "
257 "dialect";
258 }
259 void runOnOperation() override {
260 MLIRContext *ctx = &getContext();
261 RewritePatternSet patterns(ctx);
262 UnrollVectorOptions opts;
263 opts.setNativeShape(ArrayRef<int64_t>{2, 2})
264 .setFilterConstraint([](Operation *op) {
265 return success(IsSuccess: isa<vector::TransferReadOp, vector::TransferWriteOp,
266 vector::GatherOp>(Val: op));
267 });
268 if (reverseUnrollOrder.getValue()) {
269 opts.setUnrollTraversalOrderFn(
270 [](Operation *op) -> std::optional<SmallVector<int64_t>> {
271 int64_t numLoops = 0;
272 if (auto readOp = dyn_cast<vector::TransferReadOp>(Val: op))
273 numLoops = readOp.getVectorType().getRank();
274 else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(Val: op))
275 numLoops = writeOp.getVectorType().getRank();
276 else if (auto gatherOp = dyn_cast<vector::GatherOp>(Val: op))
277 numLoops = gatherOp.getVectorType().getRank();
278 else
279 return std::nullopt;
280 auto order = llvm::reverse(C: llvm::seq<int64_t>(Begin: 0, End: numLoops));
281 return llvm::to_vector(Range&: order);
282 });
283 }
284 populateVectorUnrollPatterns(patterns, options: opts);
285 populateVectorToVectorCanonicalizationPatterns(patterns);
286 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
287 }
288
289 Option<bool> reverseUnrollOrder{
290 *this, "reverse-unroll-order",
291 llvm::cl::desc(
292 "reverse the order of unrolling of vector transfer operations"),
293 llvm::cl::init(Val: false)};
294};
295
296struct TestScalarVectorTransferLoweringPatterns
297 : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
298 OperationPass<func::FuncOp>> {
299 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
300 TestScalarVectorTransferLoweringPatterns)
301
302 TestScalarVectorTransferLoweringPatterns() = default;
303 TestScalarVectorTransferLoweringPatterns(
304 const TestScalarVectorTransferLoweringPatterns &pass)
305 : PassWrapper(pass) {}
306
307 StringRef getArgument() const final {
308 return "test-scalar-vector-transfer-lowering";
309 }
310 StringRef getDescription() const final {
311 return "Test lowering of scalar vector transfers to memref loads/stores.";
312 }
313
314 void getDependentDialects(DialectRegistry &registry) const override {
315 registry.insert<affine::AffineDialect, memref::MemRefDialect,
316 tensor::TensorDialect, vector::VectorDialect>();
317 }
318
319 Option<bool> allowMultipleUses{
320 *this, "allow-multiple-uses",
321 llvm::cl::desc("Fold transfer operations with multiple uses"),
322 llvm::cl::init(Val: false)};
323
324 void runOnOperation() override {
325 MLIRContext *ctx = &getContext();
326 RewritePatternSet patterns(ctx);
327 vector::populateScalarVectorTransferLoweringPatterns(
328 patterns, /*benefit=*/1, allowMultipleUses: allowMultipleUses.getValue());
329 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
330 }
331};
332
333struct TestVectorTransferOpt
334 : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
335 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
336
337 StringRef getArgument() const final { return "test-vector-transferop-opt"; }
338 StringRef getDescription() const final {
339 return "Test optimization transformations for transfer ops";
340 }
341 void runOnOperation() override {
342 IRRewriter rewriter(&getContext());
343 transferOpflowOpt(rewriter, rootOp: getOperation());
344 }
345};
346
347struct TestVectorTransferCollapseInnerMostContiguousDims
348 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
349 OperationPass<func::FuncOp>> {
350 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
351 TestVectorTransferCollapseInnerMostContiguousDims)
352
353 TestVectorTransferCollapseInnerMostContiguousDims() = default;
354 TestVectorTransferCollapseInnerMostContiguousDims(
355 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
356
357 void getDependentDialects(DialectRegistry &registry) const override {
358 registry.insert<memref::MemRefDialect, affine::AffineDialect>();
359 }
360
361 StringRef getArgument() const final {
362 return "test-vector-transfer-collapse-inner-most-dims";
363 }
364
365 StringRef getDescription() const final {
366 return "Test lowering patterns that reduces the rank of the vector "
367 "transfer memory and vector operands.";
368 }
369
370 void runOnOperation() override {
371 RewritePatternSet patterns(&getContext());
372 populateDropInnerMostUnitDimsXferOpPatterns(patterns);
373 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
374 }
375};
376
377struct TestVectorSinkPatterns
378 : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
379 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
380
381 TestVectorSinkPatterns() = default;
382 TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
383
384 void getDependentDialects(DialectRegistry &registry) const override {
385 registry.insert<memref::MemRefDialect, affine::AffineDialect>();
386 }
387
388 StringRef getArgument() const final { return "test-vector-sink-patterns"; }
389
390 StringRef getDescription() const final {
391 return "Test lowering patterns that eliminate redundant broadcast "
392 "and transpose operations.";
393 }
394
395 void runOnOperation() override {
396 RewritePatternSet patterns(&getContext());
397 populateSinkVectorOpsPatterns(patterns);
398 populateSinkVectorMemOpsPatterns(patterns);
399 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
400 }
401};
402
403struct TestVectorReduceToContractPatternsPatterns
404 : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
405 OperationPass<func::FuncOp>> {
406 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
407 TestVectorReduceToContractPatternsPatterns)
408
409 StringRef getArgument() const final {
410 return "test-vector-reduction-to-contract-patterns";
411 }
412 StringRef getDescription() const final {
413 return "Test patterns to convert multireduce op to contract and combine "
414 "broadcast/transpose to contract";
415 }
416 void runOnOperation() override {
417 RewritePatternSet patterns(&getContext());
418 populateVectorReductionToContractPatterns(patterns);
419 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
420 }
421};
422
423struct TestVectorChainedReductionFoldingPatterns
424 : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
425 OperationPass<func::FuncOp>> {
426 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
427 TestVectorChainedReductionFoldingPatterns)
428
429 StringRef getArgument() const final {
430 return "test-vector-chained-reduction-folding-patterns";
431 }
432 StringRef getDescription() const final {
433 return "Test patterns to fold chained vector reductions";
434 }
435 void runOnOperation() override {
436 RewritePatternSet patterns(&getContext());
437 populateChainedVectorReductionFoldingPatterns(patterns);
438 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
439 }
440};
441
442struct TestVectorBreakDownReductionPatterns
443 : public PassWrapper<TestVectorBreakDownReductionPatterns,
444 OperationPass<func::FuncOp>> {
445 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
446 TestVectorBreakDownReductionPatterns)
447
448 StringRef getArgument() const final {
449 return "test-vector-break-down-reduction-patterns";
450 }
451 StringRef getDescription() const final {
452 return "Test patterns to break down vector reductions into arith "
453 "reductions";
454 }
455 void runOnOperation() override {
456 RewritePatternSet patterns(&getContext());
457 populateBreakDownVectorReductionPatterns(patterns,
458 /*maxNumElementsToExtract=*/2);
459 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
460 }
461};
462
463struct TestFlattenVectorTransferPatterns
464 : public PassWrapper<TestFlattenVectorTransferPatterns,
465 OperationPass<func::FuncOp>> {
466 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
467 TestFlattenVectorTransferPatterns)
468
469 TestFlattenVectorTransferPatterns() = default;
470 TestFlattenVectorTransferPatterns(
471 const TestFlattenVectorTransferPatterns &pass)
472 : PassWrapper(pass) {}
473
474 StringRef getArgument() const final {
475 return "test-vector-transfer-flatten-patterns";
476 }
477
478 StringRef getDescription() const final {
479 return "Test patterns to rewrite contiguous row-major N-dimensional "
480 "vector.transfer_{read,write} ops into 1D transfers";
481 }
482
483 void getDependentDialects(DialectRegistry &registry) const override {
484 registry.insert<memref::MemRefDialect>();
485 registry.insert<affine::AffineDialect>();
486 registry.insert<vector::VectorDialect>();
487 }
488
489 Option<unsigned> targetVectorBitwidth{
490 *this, "target-vector-bitwidth",
491 llvm::cl::desc(
492 "Minimum vector bitwidth to enable the flattening transformation. "
493 "For scalable vectors this is the base size, i.e. the size "
494 "corresponding to vscale=1."),
495 llvm::cl::init(Val: std::numeric_limits<unsigned>::max())};
496
497 void runOnOperation() override {
498 RewritePatternSet patterns(&getContext());
499 populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
500 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
501 }
502};
503
504struct TestVectorScanLowering
505 : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
506 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
507
508 StringRef getArgument() const final { return "test-vector-scan-lowering"; }
509 StringRef getDescription() const final {
510 return "Test lowering patterns that lower the scan op in the vector "
511 "dialect";
512 }
513 void runOnOperation() override {
514 RewritePatternSet patterns(&getContext());
515 populateVectorScanLoweringPatterns(patterns);
516 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
517 }
518};
519
520/// Allocate shared memory for a single warp to test lowering of
521/// WarpExecuteOnLane0Op.
522static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
523 gpu::WarpExecuteOnLane0Op warpOp,
524 Type type) {
525 static constexpr int64_t kSharedMemorySpace = 3;
526 // Compute type of shared memory buffer.
527 MemRefType memrefType;
528 if (auto vectorType = dyn_cast<VectorType>(Val&: type)) {
529 memrefType =
530 MemRefType::get(shape: vectorType.getShape(), elementType: vectorType.getElementType(), map: {},
531 memorySpaceInd: kSharedMemorySpace);
532 } else {
533 memrefType = MemRefType::get(shape: {1}, elementType: type, map: {}, memorySpaceInd: kSharedMemorySpace);
534 }
535
536 // Get symbol table holding all shared memory globals.
537 ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
538 SymbolTable symbolTable(moduleOp);
539
540 // Create a pretty name.
541 SmallString<64> buf;
542 llvm::raw_svector_ostream os(buf);
543 interleave(c: memrefType.getShape(), os, separator: "x");
544 os << "x" << memrefType.getElementType();
545 std::string symbolName = (Twine("__shared_") + os.str()).str();
546
547 auto ip = builder.saveInsertionPoint();
548 builder.setInsertionPoint(moduleOp);
549 auto global = builder.create<memref::GlobalOp>(
550 location: loc,
551 /*sym_name=*/args&: symbolName,
552 /*sym_visibility=*/args: builder.getStringAttr(bytes: "private"),
553 /*type=*/args&: memrefType,
554 /*initial_value=*/args: Attribute(),
555 /*constant=*/args: false,
556 /*alignment=*/args: IntegerAttr());
557 symbolTable.insert(symbol: global);
558 // The symbol table inserts at the end of the module, but globals are a bit
559 // nicer if they are at the beginning.
560 global->moveBefore(existingOp: &moduleOp.front());
561
562 builder.restoreInsertionPoint(ip);
563 return builder.create<memref::GetGlobalOp>(location: loc, args&: memrefType, args&: symbolName);
564}
565
566static Value warpReduction(Location loc, OpBuilder &builder, Value input,
567 CombiningKind kind, uint32_t size) {
568 // First reduce on a single thread to get per lane reduction value.
569 Value laneVal = builder.create<vector::ReductionOp>(location: loc, args&: kind, args&: input);
570 // Parallel reduction using butterfly shuffles.
571 for (uint64_t i = 1; i < size; i <<= 1) {
572 Value shuffled = builder
573 .create<gpu::ShuffleOp>(location: loc, args&: laneVal, args&: i,
574 /*width=*/args&: size,
575 /*mode=*/args: gpu::ShuffleMode::XOR)
576 .getShuffleResult();
577 laneVal = makeArithReduction(b&: builder, loc, kind, v1: laneVal, acc: shuffled);
578 }
579 return laneVal;
580}
581
582struct TestVectorDistribution
583 : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
584 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
585
586 void getDependentDialects(DialectRegistry &registry) const override {
587 registry
588 .insert<vector::VectorDialect, scf::SCFDialect, memref::MemRefDialect,
589 gpu::GPUDialect, affine::AffineDialect>();
590 }
591
592 StringRef getArgument() const final { return "test-vector-warp-distribute"; }
593 StringRef getDescription() const final {
594 return "Test vector warp distribute transformation and lowering patterns";
595 }
596 TestVectorDistribution() = default;
597 TestVectorDistribution(const TestVectorDistribution &pass)
598 : PassWrapper(pass) {}
599
600 Option<bool> warpOpToSCF{
601 *this, "rewrite-warp-ops-to-scf-if",
602 llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
603 llvm::cl::init(Val: false)};
604
605 Option<bool> distributeTransferWriteOps{
606 *this, "distribute-transfer-write",
607 llvm::cl::desc("Test distribution of transfer write"),
608 llvm::cl::init(Val: false)};
609
610 Option<unsigned> maxTransferWriteElements{
611 *this, "max-transfer-write-elements",
612 llvm::cl::desc("Maximum number of transfer write elements to distribute"),
613 llvm::cl::init(Val: 1)};
614
615 Option<bool> hoistUniform{*this, "hoist-uniform",
616 llvm::cl::desc("Test hoist uniform"),
617 llvm::cl::init(Val: false)};
618
619 Option<bool> propagateDistribution{
620 *this, "propagate-distribution",
621 llvm::cl::desc("Test distribution propagation"), llvm::cl::init(Val: false)};
622
623 void runOnOperation() override {
624 RewritePatternSet patterns(&getContext());
625
626 getOperation().walk(callback: [&](Operation *op) {
627 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(Val: op)) {
628 if (hoistUniform) {
629 moveScalarUniformCode(op: warpOp);
630 }
631 WalkResult::interrupt();
632 }
633 });
634 MLIRContext *ctx = &getContext();
635 auto distributionFn = [](Value val) {
636 // Create an identity dim map of the same rank as the vector.
637 VectorType vecType = dyn_cast<VectorType>(Val: val.getType());
638 int64_t vecRank = vecType ? vecType.getRank() : 0;
639 OpBuilder builder(val.getContext());
640 if (vecRank == 0)
641 return AffineMap::get(context: val.getContext());
642 return AffineMap::getMultiDimIdentityMap(numDims: vecRank, context: val.getContext());
643 };
644 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
645 Value srcIdx, int64_t warpSz) {
646 assert((val.getType().isF32() || val.getType().isInteger(32)) &&
647 "unsupported shuffle type");
648 Type i32Type = builder.getIntegerType(width: 32);
649 Value srcIdxI32 =
650 builder.create<arith::IndexCastOp>(location: loc, args&: i32Type, args&: srcIdx);
651 Value warpSzI32 = builder.create<arith::ConstantOp>(
652 location: loc, args: builder.getIntegerAttr(type: i32Type, value: warpSz));
653 Value result = builder
654 .create<gpu::ShuffleOp>(location: loc, args&: val, args&: srcIdxI32, args&: warpSzI32,
655 args: gpu::ShuffleMode::IDX)
656 .getResult(i: 0);
657 return result;
658 };
659 if (distributeTransferWriteOps && propagateDistribution) {
660 RewritePatternSet patterns(ctx);
661 vector::populatePropagateWarpVectorDistributionPatterns(
662 pattern&: patterns, distributionMapFn: distributionFn, warpShuffleFromIdxFn: shuffleFn, /*benefit=*/1,
663 /*readBenefit=*/0);
664 vector::populateDistributeReduction(pattern&: patterns, distributedReductionFn: warpReduction, benefit: 1);
665 populateDistributeTransferWriteOpPatterns(patterns, distributionMapFn: distributionFn, maxNumElementsToExtract: 2);
666 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
667 } else if (distributeTransferWriteOps) {
668 RewritePatternSet patterns(ctx);
669 populateDistributeTransferWriteOpPatterns(patterns, distributionMapFn: distributionFn,
670 maxNumElementsToExtract: maxTransferWriteElements);
671 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
672 } else if (propagateDistribution) {
673 RewritePatternSet patterns(ctx);
674 vector::populatePropagateWarpVectorDistributionPatterns(
675 pattern&: patterns, distributionMapFn: distributionFn, warpShuffleFromIdxFn: shuffleFn);
676 vector::populateDistributeReduction(pattern&: patterns, distributedReductionFn: warpReduction);
677 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
678 }
679 WarpExecuteOnLane0LoweringOptions options;
680 options.warpAllocationFn = allocateGlobalSharedMemory;
681 options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
682 gpu::WarpExecuteOnLane0Op warpOp) {
683 builder.create<gpu::BarrierOp>(location: loc);
684 };
685 // Test on one pattern in isolation.
686 if (warpOpToSCF) {
687 populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
688 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
689 return;
690 }
691 }
692};
693
694struct TestVectorExtractStridedSliceLowering
695 : public PassWrapper<TestVectorExtractStridedSliceLowering,
696 OperationPass<func::FuncOp>> {
697 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
698 TestVectorExtractStridedSliceLowering)
699
700 StringRef getArgument() const final {
701 return "test-vector-extract-strided-slice-lowering";
702 }
703 StringRef getDescription() const final {
704 return "Test lowering patterns that converts vector.extract_strided_slice "
705 "into a chain of vector.extract and vector.insert ops";
706 }
707 void runOnOperation() override {
708 RewritePatternSet patterns(&getContext());
709 populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
710 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
711 }
712};
713
714struct TestVectorBreakDownBitCast
715 : public PassWrapper<TestVectorBreakDownBitCast,
716 OperationPass<func::FuncOp>> {
717 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)
718
719 StringRef getArgument() const final {
720 return "test-vector-break-down-bitcast";
721 }
722 StringRef getDescription() const final {
723 return "Test pattern that breaks down vector.bitcast ops ";
724 }
725 void runOnOperation() override {
726 RewritePatternSet patterns(&getContext());
727 populateBreakDownVectorBitCastOpPatterns(patterns, controlFn: [](BitCastOp op) {
728 return op.getSourceVectorType().getShape().back() > 4;
729 });
730 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
731 }
732};
733
734struct TestCreateVectorBroadcast
735 : public PassWrapper<TestCreateVectorBroadcast,
736 OperationPass<func::FuncOp>> {
737 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
738
739 StringRef getArgument() const final { return "test-create-vector-broadcast"; }
740 StringRef getDescription() const final {
741 return "Test optimization transformations for transfer ops";
742 }
743 void getDependentDialects(DialectRegistry &registry) const override {
744 registry.insert<vector::VectorDialect>();
745 }
746
747 void runOnOperation() override {
748 getOperation()->walk(callback: [](Operation *op) {
749 if (op->getName().getStringRef() != "test_create_broadcast")
750 return;
751 auto targetShape =
752 cast<VectorType>(Val: op->getResult(idx: 0).getType()).getShape();
753 auto arrayAttr =
754 cast<DenseI64ArrayAttr>(Val: op->getDiscardableAttr(name: "broadcast_dims"))
755 .asArrayRef();
756 llvm::SetVector<int64_t> broadcastedDims;
757 broadcastedDims.insert_range(R&: arrayAttr);
758 OpBuilder b(op);
759 Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
760 b, value: op->getOperand(idx: 0), dstShape: targetShape, broadcastedDims);
761 op->getResult(idx: 0).replaceAllUsesWith(newValue: bcast);
762 op->erase();
763 });
764 }
765};
766
767struct TestVectorGatherLowering
768 : public PassWrapper<TestVectorGatherLowering,
769 OperationPass<func::FuncOp>> {
770 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)
771
772 StringRef getArgument() const final { return "test-vector-gather-lowering"; }
773 StringRef getDescription() const final {
774 return "Test patterns that lower the gather op in the vector conditional "
775 "loads";
776 }
777 void getDependentDialects(DialectRegistry &registry) const override {
778 registry.insert<arith::ArithDialect, func::FuncDialect,
779 memref::MemRefDialect, scf::SCFDialect,
780 tensor::TensorDialect, vector::VectorDialect>();
781 }
782
783 void runOnOperation() override {
784 RewritePatternSet patterns(&getContext());
785 populateVectorGatherLoweringPatterns(patterns);
786 populateVectorGatherToConditionalLoadPatterns(patterns);
787 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
788 }
789};
790
791struct TestFoldArithExtensionIntoVectorContractPatterns
792 : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
793 OperationPass<func::FuncOp>> {
794 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
795 TestFoldArithExtensionIntoVectorContractPatterns)
796
797 StringRef getArgument() const final {
798 return "test-fold-arith-extf-into-vector-contract-patterns";
799 }
800 StringRef getDescription() const final {
801 return "Test patterns that fold arithmetic extension ops into vector "
802 "contract ops";
803 }
804
805 void getDependentDialects(DialectRegistry &registry) const override {
806 registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
807 memref::MemRefDialect, scf::SCFDialect,
808 tensor::TensorDialect, vector::VectorDialect>();
809 }
810
811 void runOnOperation() override {
812 RewritePatternSet patterns(&getContext());
813 populateFoldArithExtensionPatterns(patterns);
814 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
815 }
816};
817
818struct TestVectorEmulateMaskedLoadStore final
819 : public PassWrapper<TestVectorEmulateMaskedLoadStore,
820 OperationPass<func::FuncOp>> {
821 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
822
823 StringRef getArgument() const override {
824 return "test-vector-emulate-masked-load-store";
825 }
826 StringRef getDescription() const override {
827 return "Test patterns that emulate the maskedload/maskedstore op by "
828 " memref.load/store and scf.if";
829 }
830 void getDependentDialects(DialectRegistry &registry) const override {
831 registry
832 .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
833 scf::SCFDialect, vector::VectorDialect>();
834 }
835
836 void runOnOperation() override {
837 RewritePatternSet patterns(&getContext());
838 populateVectorMaskedLoadStoreEmulationPatterns(patterns);
839 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
840 }
841};
842
843/// Get the set of operand/result types to check for sufficiently
844/// small inner-most dimension size.
845static SmallVector<std::pair<Type, unsigned>>
846getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
847
848 if (auto insertOp = dyn_cast<vector::InsertOp>(Val: op)) {
849 unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
850 ? targetBitWidth + 1
851 : targetBitWidth;
852 return {{insertOp.getValueToStoreType(), w}};
853 }
854
855 auto resultTypes = op->getResultTypes();
856 SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
857 resultsWithBitWidth.reserve(N: resultTypes.size());
858 for (Type type : resultTypes) {
859 resultsWithBitWidth.push_back(Elt: {type, targetBitWidth});
860 }
861 return resultsWithBitWidth;
862}
863
864/// If `type` is VectorType with trailing dimension of (bit) size greater than
865/// or equal to `targetBitWidth`, its defining op is considered legal.
866static bool
867isNotLinearizableBecauseLargeInnerDimension(Type type,
868 unsigned targetBitWidth) {
869
870 VectorType vecType = dyn_cast<VectorType>(Val&: type);
871
872 // Not linearizable for reasons other than what this function checks.
873 if (!vecType || vecType.getRank() == 0)
874 return false;
875
876 // The width of the type 'index' is unbounded (and therefore potentially above
877 // the target width).
878 if (vecType.getElementType().isIndex())
879 return true;
880
881 unsigned finalDimSize = vecType.getShape().back();
882 unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
883 unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
884 return trailingVecDimBitWidth >= targetBitWidth;
885}
886
887static bool
888isNotLinearizableBecauseLargeInnerDimension(Operation *op,
889 unsigned targetBitWidth) {
890 // Check on bitwidths.
891 SmallVector<std::pair<Type, unsigned>> toCheck =
892 getTypeBitWidthBoundPairs(op, targetBitWidth);
893 return llvm::any_of(Range&: toCheck, P: [&](std::pair<Type, unsigned> typeWidth) {
894 return isNotLinearizableBecauseLargeInnerDimension(type: typeWidth.first,
895 targetBitWidth: typeWidth.second);
896 });
897}
898
899void populateWithBitWidthConstraints(TypeConverter &typeConverter,
900 ConversionTarget &target,
901 unsigned targetBitWidth) {
902
903 // The general purpose definition of what ops are legal must come first.
904 populateForVectorLinearize(typeConverter, conversionTarget&: target);
905
906 // Extend the set of legal ops to include those with large inner-most
907 // dimensions on selected operands/results.
908 target.markUnknownOpDynamicallyLegal(
909 fn: [=](Operation *op) -> std::optional<bool> {
910 if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
911 return true;
912 }
913 return {};
914 });
915}
916
917struct TestVectorBitWidthLinearize final
918 : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
919 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
920
921 TestVectorBitWidthLinearize() = default;
922 TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
923 : PassWrapper(pass) {}
924
925 StringRef getArgument() const override {
926 return "test-bit-width-constrained-vector-linearize";
927 }
928 StringRef getDescription() const override {
929 return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
930 "in inner-most dimension's bit width.";
931 }
932 void getDependentDialects(DialectRegistry &registry) const override {
933 registry.insert<vector::VectorDialect>();
934 }
935
936 Option<unsigned> targetVectorBitwidth{
937 *this, "target-vector-bitwidth",
938 llvm::cl::desc(
939 "Minimum vector bitwidth to enable the flattening transformation"),
940 llvm::cl::init(Val: std::numeric_limits<unsigned>::max())};
941 void runOnOperation() override {
942 auto *context = &getContext();
943
944 TypeConverter typeConverter;
945 RewritePatternSet patterns(context);
946 ConversionTarget target(*context);
947
948 populateWithBitWidthConstraints(typeConverter, target,
949 targetBitWidth: targetVectorBitwidth);
950
951 vector::populateVectorLinearizeBasePatterns(typeConverter, target,
952 patterns);
953
954 vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
955 patterns);
956
957 if (failed(Result: applyPartialConversion(op: getOperation(), target,
958 patterns: std::move(patterns))))
959 return signalPassFailure();
960 }
961};
962
963struct TestVectorLinearize final
964 : public PassWrapper<TestVectorLinearize, OperationPass<>> {
965 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
966
967 TestVectorLinearize() = default;
968
969 StringRef getArgument() const override { return "test-vector-linearize"; }
970 StringRef getDescription() const override {
971 return "Linearizes ND vectors for N >= 2 into 1D vectors";
972 }
973 void getDependentDialects(DialectRegistry &registry) const override {
974 registry.insert<vector::VectorDialect, arith::ArithDialect>();
975 }
976
977 void runOnOperation() override {
978 MLIRContext &context = getContext();
979 TypeConverter converter;
980 RewritePatternSet patterns(&context);
981 ConversionTarget target(context);
982
983 vector::populateForVectorLinearize(typeConverter&: converter, conversionTarget&: target);
984
985 vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
986 vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
987 patterns);
988 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
989 typeConverter: converter, patterns, target);
990
991 if (failed(Result: applyPartialConversion(op: getOperation(), target,
992 patterns: std::move(patterns))))
993 return signalPassFailure();
994 }
995};
996
997struct TestEliminateVectorMasks
998 : public PassWrapper<TestEliminateVectorMasks,
999 OperationPass<func::FuncOp>> {
1000 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)
1001
1002 TestEliminateVectorMasks() = default;
1003 TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
1004 : PassWrapper(pass) {}
1005
1006 Option<unsigned> vscaleMin{
1007 *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
1008 llvm::cl::init(Val: 1)};
1009 Option<unsigned> vscaleMax{
1010 *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
1011 llvm::cl::init(Val: 16)};
1012
1013 StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
1014 StringRef getDescription() const final {
1015 return "Test eliminating vector masks";
1016 }
1017 void runOnOperation() override {
1018 IRRewriter rewriter(&getContext());
1019 eliminateVectorMasks(rewriter, function: getOperation(),
1020 vscaleRange: VscaleRange{.vscaleMin: vscaleMin, .vscaleMax: vscaleMax});
1021 }
1022};
1023} // namespace
1024
1025namespace mlir {
1026namespace test {
1027void registerTestVectorLowerings() {
1028 PassRegistration<TestVectorToVectorLowering>();
1029
1030 PassRegistration<TestVectorContractionPrepareForMMTLowering>();
1031
1032 PassRegistration<TestVectorUnrollingPatterns>();
1033
1034 PassRegistration<TestVectorTransferUnrollingPatterns>();
1035
1036 PassRegistration<TestScalarVectorTransferLoweringPatterns>();
1037
1038 PassRegistration<TestVectorTransferOpt>();
1039
1040 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
1041
1042 PassRegistration<TestVectorSinkPatterns>();
1043
1044 PassRegistration<TestVectorReduceToContractPatternsPatterns>();
1045
1046 PassRegistration<TestVectorChainedReductionFoldingPatterns>();
1047
1048 PassRegistration<TestVectorBreakDownReductionPatterns>();
1049
1050 PassRegistration<TestFlattenVectorTransferPatterns>();
1051
1052 PassRegistration<TestVectorScanLowering>();
1053
1054 PassRegistration<TestVectorDistribution>();
1055
1056 PassRegistration<TestVectorExtractStridedSliceLowering>();
1057
1058 PassRegistration<TestVectorBreakDownBitCast>();
1059
1060 PassRegistration<TestCreateVectorBroadcast>();
1061
1062 PassRegistration<TestVectorGatherLowering>();
1063
1064 PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
1065
1066 PassRegistration<TestVectorEmulateMaskedLoadStore>();
1067
1068 PassRegistration<TestVectorLinearize>();
1069
1070 PassRegistration<TestVectorBitWidthLinearize>();
1071
1072 PassRegistration<TestEliminateVectorMasks>();
1073}
1074} // namespace test
1075} // namespace mlir
1076

// Source: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp