//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {

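/// Test pass applying vector-to-vector lowering and canonicalization
/// patterns, with optional unrolling of elementwise, contraction, and
/// transfer ops.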
struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)

  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  // Return the target shape based on op type.
  static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *users : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (dstVec && dstVec != vecType)
          return std::nullopt;
        dstVec = vecType;
      }
      return SmallVector<int64_t>(dstVec.getShape().begin(),
                                  dstVec.getShape().end());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return std::nullopt;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t>(shape.begin(), shape.end());
    }
    return std::nullopt;
  }

  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};

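/// Test pass canonicalizing vector.contraction matmuls into the form expected
/// by the MMT lowering.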
struct TestVectorContractionPrepareForMMTLowering
    : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorContractionPrepareForMMTLowering)

  StringRef getArgument() const final {
    return "test-vector-contraction-prepare-for-mmt-lowering";
  }
  StringRef getDescription() const final {
    return "Test vector.contraction matmul canonicalization for MMT lowering.";
  }
  TestVectorContractionPrepareForMMTLowering() = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    vector::VectorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

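/// Test pass exercising vector unrolling patterns for elementwise, reduction,
/// transpose, and contraction ops.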
struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
                                         4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });

      if (!unrollOrder.empty()) {
        opts.setUnrollTraversalOrderFn(
            [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
              vector::ContractionOp contractOp =
                  cast<vector::ContractionOp>(op);
              if (contractOp.getIteratorTypes().size() == unrollOrder.size())
                return SmallVector<int64_t>(unrollOrder.begin(),
                                            unrollOrder.end());
              return std::nullopt;
            });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      auto nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return std::nullopt;
        return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order")};

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};

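/// Test pass unrolling vector transfer_read, transfer_write, and gather ops,
/// optionally using a reversed unroll traversal order.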
struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)

  TestVectorTransferUnrollingPatterns() = default;
  TestVectorTransferUnrollingPatterns(
      const TestVectorTransferUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    UnrollVectorOptions opts;
    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
        .setFilterConstraint([](Operation *op) {
          return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
                             vector::GatherOp>(op));
        });
    if (reverseUnrollOrder.getValue()) {
      opts.setUnrollTraversalOrderFn(
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
            int64_t numLoops = 0;
            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
              numLoops = readOp.getVectorType().getRank();
            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
              numLoops = writeOp.getVectorType().getRank();
            else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
              numLoops = gatherOp.getVectorType().getRank();
            else
              return std::nullopt;
            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
            return llvm::to_vector(order);
          });
    }
    populateVectorUnrollPatterns(patterns, opts);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
};

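/// Test pass lowering single-element vector transfers to scalar loads and
/// stores.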
struct TestScalarVectorTransferLoweringPatterns
    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestScalarVectorTransferLoweringPatterns)

  TestScalarVectorTransferLoweringPatterns() = default;
  TestScalarVectorTransferLoweringPatterns(
      const TestScalarVectorTransferLoweringPatterns &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-scalar-vector-transfer-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering of scalar vector transfers to memref loads/stores.";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, memref::MemRefDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  Option<bool> allowMultipleUses{
      *this, "allow-multiple-uses",
      llvm::cl::desc("Fold transfer operations with multiple uses"),
      llvm::cl::init(false)};

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    vector::populateScalarVectorTransferLoweringPatterns(
        patterns, /*benefit=*/1, allowMultipleUses.getValue());
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

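/// Test pass running transferOpflowOpt, which performs store-to-load
/// forwarding and dead store elimination on vector transfer ops.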
struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override {
    IRRewriter rewriter(&getContext());
    transferOpflowOpt(rewriter, getOperation());
  }
};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)

  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

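/// Test pass for patterns that sink vector.broadcast ops in order to
/// eliminate redundant broadcasts.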
struct TestSinkVectorBroadcast
    : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)

  TestSinkVectorBroadcast() = default;
  TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-sink-vector-broadcast"; }

  StringRef getDescription() const final {
    return "Test lowering patterns that eliminate redundant broadcast "
           "operations.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateSinkVectorBroadcastPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

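/// Test pass converting vector.multi_reduction ops to vector.contract and
/// folding broadcast/transpose ops into contractions.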
struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)

  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multireduce op to contract and combine "
           "broadcast/transpose to contract";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorReductionToContractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorChainedReductionFoldingPatterns
    : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorChainedReductionFoldingPatterns)

  StringRef getArgument() const final {
    return "test-vector-chained-reduction-folding-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to fold chained vector reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateChainedVectorReductionFoldingPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorBreakDownReductionPatterns
    : public PassWrapper<TestVectorBreakDownReductionPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorBreakDownReductionPatterns)

  StringRef getArgument() const final {
    return "test-vector-break-down-reduction-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to break down vector reductions into arith "
           "reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateBreakDownVectorReductionPatterns(patterns,
                                             /*maxNumElementsToExtract=*/2);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)

  TestFlattenVectorTransferPatterns() = default;
  TestFlattenVectorTransferPatterns(
      const TestFlattenVectorTransferPatterns &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }

  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
    registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
  }

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation. "
          "For scalable vectors this is the base size, i.e. the size "
          "corresponding to vscale=1."),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)

  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorScanLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name.
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}

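/// Reduce `input` to a single value per warp: reduce to a per-lane scalar
/// with vector.reduction, then combine the lane values with butterfly (XOR)
/// gpu.shuffle ops.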
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  // First reduce on a single thread to get per lane reduction value.
  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
  // Parallel reduction using butterfly shuffles.
  for (uint64_t i = 1; i < size; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                 /*width=*/size,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
  }
  return laneVal;
}

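/// Test pass for vector warp distribution: hoists uniform code out of
/// warp_execute_on_lane_0 ops, distributes transfer writes and reductions,
/// propagates distribution through producers, and optionally lowers
/// warp_execute_on_lane_0 to scf.if.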
struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
                    affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
      llvm::cl::init(false)};

  Option<bool> distributeTransferWriteOps{
      *this, "distribute-transfer-write",
      llvm::cl::desc("Test distribution of transfer write"),
      llvm::cl::init(false)};

  Option<unsigned> maxTransferWriteElements{
      *this, "max-transfer-write-elements",
      llvm::cl::desc("Maximum number of transfer write elements to distribute"),
      llvm::cl::init(1)};

  Option<bool> hoistUniform{*this, "hoist-uniform",
                            llvm::cl::desc("Test hoist uniform"),
                            llvm::cl::init(false)};

  Option<bool> propagateDistribution{
      *this, "propagate-distribution",
      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    getOperation().walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
        if (hoistUniform) {
          moveScalarUniformCode(warpOp);
        }
        WalkResult::interrupt();
      }
    });
    MLIRContext *ctx = &getContext();
    auto distributionFn = [](Value val) {
      // Create an identity dim map of the same rank as the vector.
      VectorType vecType = dyn_cast<VectorType>(val.getType());
      int64_t vecRank = vecType ? vecType.getRank() : 0;
      OpBuilder builder(val.getContext());
      if (vecRank == 0)
        return AffineMap::get(val.getContext());
      return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
    };
    auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                        Value srcIdx, int64_t warpSz) {
      assert((val.getType().isF32() || val.getType().isInteger(32)) &&
             "unsupported shuffle type");
      Type i32Type = builder.getIntegerType(32);
      Value srcIdxI32 =
          builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
      Value warpSzI32 = builder.create<arith::ConstantOp>(
          loc, builder.getIntegerAttr(i32Type, warpSz));
      Value result = builder
                         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
                                                 gpu::ShuffleMode::IDX)
                         .getResult(0);
      return result;
    };
    if (distributeTransferWriteOps && propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(
          patterns, distributionFn, shuffleFn, /*benefit=*/1,
          /*readBenefit=*/0);
      vector::populateDistributeReduction(patterns, warpReduction, 1);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    } else if (distributeTransferWriteOps) {
      RewritePatternSet patterns(ctx);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
                                                maxTransferWriteElements);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    } else if (propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(
          patterns, distributionFn, shuffleFn);
      vector::populateDistributeReduction(patterns, warpReduction);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};

struct TestVectorExtractStridedSliceLowering
    : public PassWrapper<TestVectorExtractStridedSliceLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorExtractStridedSliceLowering)

  StringRef getArgument() const final {
    return "test-vector-extract-strided-slice-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that convert vector.extract_strided_slice "
           "into a chain of vector.extract and vector.insert ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorBreakDownBitCast
    : public PassWrapper<TestVectorBreakDownBitCast,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)

  StringRef getArgument() const final {
    return "test-vector-break-down-bitcast";
  }
  StringRef getDescription() const final {
    return "Test pattern that breaks down vector.bitcast ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
      return op.getSourceVectorType().getShape().back() > 4;
    });
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestCreateVectorBroadcast
    : public PassWrapper<TestCreateVectorBroadcast,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)

  StringRef getArgument() const final { return "test-create-vector-broadcast"; }
  StringRef getDescription() const final {
    return "Test creating vector broadcasts through "
           "BroadcastOp::createOrFoldBroadcastOp";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }

  void runOnOperation() override {
    getOperation()->walk([](Operation *op) {
      if (op->getName().getStringRef() != "test_create_broadcast")
        return;
      auto targetShape =
          cast<VectorType>(op->getResult(0).getType()).getShape();
      auto arrayAttr =
          cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
              .asArrayRef();
      llvm::SetVector<int64_t> broadcastedDims;
      broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
      OpBuilder b(op);
      Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
          b, op->getOperand(0), targetShape, broadcastedDims);
      op->getResult(0).replaceAllUsesWith(bcast);
      op->erase();
    });
  }
};

struct TestVectorGatherLowering
    : public PassWrapper<TestVectorGatherLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)

  StringRef getArgument() const final { return "test-vector-gather-lowering"; }
  StringRef getDescription() const final {
    return "Test patterns that lower the gather op in the vector dialect to "
           "conditional loads";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect,
                    memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorGatherLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestFoldArithExtensionIntoVectorContractPatterns
    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFoldArithExtensionIntoVectorContractPatterns)

  StringRef getArgument() const final {
    return "test-fold-arith-extf-into-vector-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns that fold arithmetic extension ops into vector "
           "contract ops";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
                    memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFoldArithExtensionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorEmulateMaskedLoadStore final
    : public PassWrapper<TestVectorEmulateMaskedLoadStore,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)

  StringRef getArgument() const override {
    return "test-vector-emulate-masked-load-store";
  }
  StringRef getDescription() const override {
    return "Test patterns that emulate the maskedload/maskedstore op by "
           "memref.load/store and scf.if";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
                scf::SCFDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMaskedLoadStoreEmulationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

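/// Test pass linearizing n-D vectors (n >= 2) into 1-D vectors.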
struct TestVectorLinearize final
    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)

  TestVectorLinearize() = default;
  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}

  StringRef getArgument() const override { return "test-vector-linearize"; }
  StringRef getDescription() const override {
    return "Linearizes ND vectors for N >= 2 into 1D vectors";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation"),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};
  void runOnOperation() override {
    auto *context = &getContext();

    TypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    vector::populateVectorLinearizeTypeConversionsAndLegality(
        typeConverter, patterns, target, targetVectorBitwidth);
    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
        typeConverter, patterns, target, targetVectorBitwidth);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

namespace mlir {
namespace test {
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionPrepareForMMTLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestScalarVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestSinkVectorBroadcast>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorChainedReductionFoldingPatterns>();

  PassRegistration<TestVectorBreakDownReductionPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();

  PassRegistration<TestVectorExtractStridedSliceLowering>();

  PassRegistration<TestVectorBreakDownBitCast>();

  PassRegistration<TestCreateVectorBroadcast>();

  PassRegistration<TestVectorGatherLowering>();

  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

  PassRegistration<TestVectorEmulateMaskedLoadStore>();

  PassRegistration<TestVectorLinearize>();
}
} // namespace test
} // namespace mlir