1//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <optional>
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/GPU/IR/GPUDialect.h"
16#include "mlir/Dialect/Linalg/Passes.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/SCF/Transforms/Patterns.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
25#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
26#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Pass/PassManager.h"
29#include "mlir/Support/LLVM.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31
32using namespace mlir;
33using namespace mlir::linalg;
34using namespace mlir::vector;
35
36namespace {
37
38struct TestVectorToVectorLowering
39 : public PassWrapper<TestVectorToVectorLowering,
40 OperationPass<func::FuncOp>> {
41 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
42
43 TestVectorToVectorLowering() = default;
44 TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
45 : PassWrapper(pass) {}
46 StringRef getArgument() const final {
47 return "test-vector-to-vector-lowering";
48 }
49 StringRef getDescription() const final {
50 return "Test lowering patterns between ops in the vector dialect";
51 }
52
53 void getDependentDialects(DialectRegistry &registry) const override {
54 registry.insert<affine::AffineDialect>();
55 registry.insert<vector::VectorDialect>();
56 }
57
58 Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
59 llvm::cl::init(Val: false)};
60
61 void runOnOperation() override {
62 auto *ctx = &getContext();
63 RewritePatternSet patterns(ctx);
64 if (unroll) {
65 populateVectorUnrollPatterns(
66 patterns,
67 UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
68 filter));
69 }
70 populateVectorToVectorCanonicalizationPatterns(patterns);
71 populateBubbleVectorBitCastOpPatterns(patterns);
72 populateCastAwayVectorLeadingOneDimPatterns(patterns);
73 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
74 }
75
76private:
77 // Return the target shape based on op type.
78 static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
79 if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
80 return SmallVector<int64_t>(2, 2);
81 if (isa<vector::ContractionOp>(Val: op))
82 return SmallVector<int64_t>(3, 2);
83 // For transfer ops, just propagate the shape coming from
84 // InsertStridedSlices/ExtractStridedSlices.
85 if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
86 VectorType dstVec;
87 for (Operation *users : readOp->getUsers()) {
88 auto extract = dyn_cast<ExtractStridedSliceOp>(users);
89 if (!extract)
90 return std::nullopt;
91 auto vecType = cast<VectorType>(extract.getResult().getType());
92 if (dstVec && dstVec != vecType)
93 return std::nullopt;
94 dstVec = vecType;
95 }
96 return SmallVector<int64_t>(dstVec.getShape());
97 }
98 if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
99 auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
100 if (!insert)
101 return std::nullopt;
102 ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
103 return SmallVector<int64_t>(shape);
104 }
105 return std::nullopt;
106 }
107
108 static LogicalResult filter(Operation *op) {
109 return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
110 ContractionOp, TransferReadOp, TransferWriteOp>(op));
111 }
112};
113
114struct TestVectorContractionPrepareForMMTLowering
115 : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
116 OperationPass<func::FuncOp>> {
117 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
118 TestVectorContractionPrepareForMMTLowering)
119
120 StringRef getArgument() const final {
121 return "test-vector-contraction-prepare-for-mmt-lowering";
122 }
123 StringRef getDescription() const final {
124 return "Test vector.contraction matmul canonicalization for MMT lowering.";
125 }
126 TestVectorContractionPrepareForMMTLowering() = default;
127
128 void getDependentDialects(DialectRegistry &registry) const override {
129 registry.insert<affine::AffineDialect, arith::ArithDialect,
130 vector::VectorDialect>();
131 }
132
133 void runOnOperation() override {
134 MLIRContext *ctx = &getContext();
135 RewritePatternSet patterns(ctx);
136 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
137 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
138 }
139};
140
141struct TestVectorUnrollingPatterns
142 : public PassWrapper<TestVectorUnrollingPatterns,
143 OperationPass<func::FuncOp>> {
144 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
145
146 StringRef getArgument() const final {
147 return "test-vector-unrolling-patterns";
148 }
149 StringRef getDescription() const final {
150 return "Test lowering patterns to unroll contract ops in the vector "
151 "dialect";
152 }
153 TestVectorUnrollingPatterns() = default;
154 TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
155 : PassWrapper(pass) {}
156 void runOnOperation() override {
157 MLIRContext *ctx = &getContext();
158 RewritePatternSet patterns(ctx);
159 populateVectorUnrollPatterns(
160 patterns,
161 options: UnrollVectorOptions()
162 .setNativeShape(ArrayRef<int64_t>{2, 2})
163 .setFilterConstraint([](Operation *op) {
164 return success(
165 isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
166 vector::BroadcastOp>(op));
167 }));
168 populateVectorUnrollPatterns(
169 patterns, options: UnrollVectorOptions()
170 .setNativeShape(ArrayRef<int64_t>{2})
171 .setFilterConstraint([](Operation *op) {
172 return success(isa<vector::ReductionOp>(op));
173 }));
174 populateVectorUnrollPatterns(
175 patterns, options: UnrollVectorOptions()
176 .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
177 .setFilterConstraint([](Operation *op) {
178 return success(isa<vector::TransposeOp>(op));
179 }));
180
181 if (unrollBasedOnType) {
182 UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
183 [](Operation *op) -> std::optional<SmallVector<int64_t>> {
184 vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
185 SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
186 4);
187 Type lhsType = contractOp.getLhsType().getElementType();
188 nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
189 return nativeShape;
190 };
191
192 UnrollVectorOptions opts;
193 opts.setNativeShapeFn(nativeShapeFn)
194 .setFilterConstraint(
195 [](Operation *op) { return success(IsSuccess: isa<ContractionOp>(Val: op)); });
196
197 if (!unrollOrder.empty()) {
198 opts.setUnrollTraversalOrderFn(
199 [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
200 vector::ContractionOp contractOp =
201 cast<vector::ContractionOp>(op);
202 if (contractOp.getIteratorTypes().size() == unrollOrder.size())
203 return SmallVector<int64_t>(unrollOrder.begin(),
204 unrollOrder.end());
205 return std::nullopt;
206 });
207 }
208 populateVectorUnrollPatterns(patterns, options: opts);
209 } else {
210 auto nativeShapeFn =
211 [](Operation *op) -> std::optional<SmallVector<int64_t>> {
212 auto contractOp = dyn_cast<ContractionOp>(op);
213 if (!contractOp)
214 return std::nullopt;
215 return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
216 };
217 populateVectorUnrollPatterns(patterns,
218 UnrollVectorOptions()
219 .setNativeShapeFn(nativeShapeFn)
220 .setFilterConstraint([](Operation *op) {
221 return success(IsSuccess: isa<ContractionOp>(Val: op));
222 }));
223 }
224 populateVectorToVectorCanonicalizationPatterns(patterns);
225 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
226 }
227
228 ListOption<int64_t> unrollOrder{*this, "unroll-order",
229 llvm::cl::desc("set the unroll order")};
230
231 Option<bool> unrollBasedOnType{
232 *this, "unroll-based-on-type",
233 llvm::cl::desc("Set the unroll factor based on type of the operation"),
234 llvm::cl::init(Val: false)};
235};
236
237struct TestVectorTransferUnrollingPatterns
238 : public PassWrapper<TestVectorTransferUnrollingPatterns,
239 OperationPass<func::FuncOp>> {
240 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
241 TestVectorTransferUnrollingPatterns)
242
243 TestVectorTransferUnrollingPatterns() = default;
244 TestVectorTransferUnrollingPatterns(
245 const TestVectorTransferUnrollingPatterns &pass)
246 : PassWrapper(pass) {}
247
248 void getDependentDialects(DialectRegistry &registry) const override {
249 registry.insert<affine::AffineDialect>();
250 }
251 StringRef getArgument() const final {
252 return "test-vector-transfer-unrolling-patterns";
253 }
254 StringRef getDescription() const final {
255 return "Test lowering patterns to unroll transfer ops in the vector "
256 "dialect";
257 }
258 void runOnOperation() override {
259 MLIRContext *ctx = &getContext();
260 RewritePatternSet patterns(ctx);
261 UnrollVectorOptions opts;
262 opts.setNativeShape(ArrayRef<int64_t>{2, 2})
263 .setFilterConstraint([](Operation *op) {
264 return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
265 vector::GatherOp>(op));
266 });
267 if (reverseUnrollOrder.getValue()) {
268 opts.setUnrollTraversalOrderFn(
269 [](Operation *op) -> std::optional<SmallVector<int64_t>> {
270 int64_t numLoops = 0;
271 if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
272 numLoops = readOp.getVectorType().getRank();
273 else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
274 numLoops = writeOp.getVectorType().getRank();
275 else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
276 numLoops = gatherOp.getVectorType().getRank();
277 else
278 return std::nullopt;
279 auto order = llvm::reverse(C: llvm::seq<int64_t>(Begin: 0, End: numLoops));
280 return llvm::to_vector(Range&: order);
281 });
282 }
283 populateVectorUnrollPatterns(patterns, options: opts);
284 populateVectorToVectorCanonicalizationPatterns(patterns);
285 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
286 }
287
288 Option<bool> reverseUnrollOrder{
289 *this, "reverse-unroll-order",
290 llvm::cl::desc(
291 "reverse the order of unrolling of vector transfer operations"),
292 llvm::cl::init(Val: false)};
293};
294
295struct TestScalarVectorTransferLoweringPatterns
296 : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
297 OperationPass<func::FuncOp>> {
298 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
299 TestScalarVectorTransferLoweringPatterns)
300
301 TestScalarVectorTransferLoweringPatterns() = default;
302 TestScalarVectorTransferLoweringPatterns(
303 const TestScalarVectorTransferLoweringPatterns &pass)
304 : PassWrapper(pass) {}
305
306 StringRef getArgument() const final {
307 return "test-scalar-vector-transfer-lowering";
308 }
309 StringRef getDescription() const final {
310 return "Test lowering of scalar vector transfers to memref loads/stores.";
311 }
312
313 void getDependentDialects(DialectRegistry &registry) const override {
314 registry.insert<affine::AffineDialect, memref::MemRefDialect,
315 tensor::TensorDialect, vector::VectorDialect>();
316 }
317
318 Option<bool> allowMultipleUses{
319 *this, "allow-multiple-uses",
320 llvm::cl::desc("Fold transfer operations with multiple uses"),
321 llvm::cl::init(Val: false)};
322
323 void runOnOperation() override {
324 MLIRContext *ctx = &getContext();
325 RewritePatternSet patterns(ctx);
326 vector::populateScalarVectorTransferLoweringPatterns(
327 patterns, /*benefit=*/1, allowMultipleUses: allowMultipleUses.getValue());
328 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
329 }
330};
331
332struct TestVectorTransferOpt
333 : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
334 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
335
336 StringRef getArgument() const final { return "test-vector-transferop-opt"; }
337 StringRef getDescription() const final {
338 return "Test optimization transformations for transfer ops";
339 }
340 void runOnOperation() override {
341 IRRewriter rewriter(&getContext());
342 transferOpflowOpt(rewriter, getOperation());
343 }
344};
345
346struct TestVectorTransferCollapseInnerMostContiguousDims
347 : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
348 OperationPass<func::FuncOp>> {
349 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
350 TestVectorTransferCollapseInnerMostContiguousDims)
351
352 TestVectorTransferCollapseInnerMostContiguousDims() = default;
353 TestVectorTransferCollapseInnerMostContiguousDims(
354 const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
355
356 void getDependentDialects(DialectRegistry &registry) const override {
357 registry.insert<memref::MemRefDialect, affine::AffineDialect>();
358 }
359
360 StringRef getArgument() const final {
361 return "test-vector-transfer-collapse-inner-most-dims";
362 }
363
364 StringRef getDescription() const final {
365 return "Test lowering patterns that reduces the rank of the vector "
366 "transfer memory and vector operands.";
367 }
368
369 void runOnOperation() override {
370 RewritePatternSet patterns(&getContext());
371 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
372 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
373 }
374};
375
376struct TestVectorSinkPatterns
377 : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
378 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
379
380 TestVectorSinkPatterns() = default;
381 TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
382
383 void getDependentDialects(DialectRegistry &registry) const override {
384 registry.insert<memref::MemRefDialect, affine::AffineDialect>();
385 }
386
387 StringRef getArgument() const final { return "test-vector-sink-patterns"; }
388
389 StringRef getDescription() const final {
390 return "Test lowering patterns that eliminate redundant broadcast "
391 "and transpose operations.";
392 }
393
394 void runOnOperation() override {
395 RewritePatternSet patterns(&getContext());
396 populateSinkVectorOpsPatterns(patterns);
397 populateSinkVectorMemOpsPatterns(patterns);
398 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
399 }
400};
401
402struct TestVectorReduceToContractPatternsPatterns
403 : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
404 OperationPass<func::FuncOp>> {
405 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
406 TestVectorReduceToContractPatternsPatterns)
407
408 StringRef getArgument() const final {
409 return "test-vector-reduction-to-contract-patterns";
410 }
411 StringRef getDescription() const final {
412 return "Test patterns to convert multireduce op to contract and combine "
413 "broadcast/transpose to contract";
414 }
415 void runOnOperation() override {
416 RewritePatternSet patterns(&getContext());
417 populateVectorReductionToContractPatterns(patterns);
418 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
419 }
420};
421
422struct TestVectorChainedReductionFoldingPatterns
423 : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
424 OperationPass<func::FuncOp>> {
425 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
426 TestVectorChainedReductionFoldingPatterns)
427
428 StringRef getArgument() const final {
429 return "test-vector-chained-reduction-folding-patterns";
430 }
431 StringRef getDescription() const final {
432 return "Test patterns to fold chained vector reductions";
433 }
434 void runOnOperation() override {
435 RewritePatternSet patterns(&getContext());
436 populateChainedVectorReductionFoldingPatterns(patterns);
437 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
438 }
439};
440
441struct TestVectorBreakDownReductionPatterns
442 : public PassWrapper<TestVectorBreakDownReductionPatterns,
443 OperationPass<func::FuncOp>> {
444 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
445 TestVectorBreakDownReductionPatterns)
446
447 StringRef getArgument() const final {
448 return "test-vector-break-down-reduction-patterns";
449 }
450 StringRef getDescription() const final {
451 return "Test patterns to break down vector reductions into arith "
452 "reductions";
453 }
454 void runOnOperation() override {
455 RewritePatternSet patterns(&getContext());
456 populateBreakDownVectorReductionPatterns(patterns,
457 /*maxNumElementsToExtract=*/2);
458 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
459 }
460};
461
462struct TestFlattenVectorTransferPatterns
463 : public PassWrapper<TestFlattenVectorTransferPatterns,
464 OperationPass<func::FuncOp>> {
465 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
466 TestFlattenVectorTransferPatterns)
467
468 TestFlattenVectorTransferPatterns() = default;
469 TestFlattenVectorTransferPatterns(
470 const TestFlattenVectorTransferPatterns &pass)
471 : PassWrapper(pass) {}
472
473 StringRef getArgument() const final {
474 return "test-vector-transfer-flatten-patterns";
475 }
476
477 StringRef getDescription() const final {
478 return "Test patterns to rewrite contiguous row-major N-dimensional "
479 "vector.transfer_{read,write} ops into 1D transfers";
480 }
481
482 void getDependentDialects(DialectRegistry &registry) const override {
483 registry.insert<memref::MemRefDialect>();
484 registry.insert<affine::AffineDialect>();
485 registry.insert<vector::VectorDialect>();
486 }
487
488 Option<unsigned> targetVectorBitwidth{
489 *this, "target-vector-bitwidth",
490 llvm::cl::desc(
491 "Minimum vector bitwidth to enable the flattening transformation. "
492 "For scalable vectors this is the base size, i.e. the size "
493 "corresponding to vscale=1."),
494 llvm::cl::init(Val: std::numeric_limits<unsigned>::max())};
495
496 void runOnOperation() override {
497 RewritePatternSet patterns(&getContext());
498 populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
499 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
500 }
501};
502
503struct TestVectorScanLowering
504 : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
505 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
506
507 StringRef getArgument() const final { return "test-vector-scan-lowering"; }
508 StringRef getDescription() const final {
509 return "Test lowering patterns that lower the scan op in the vector "
510 "dialect";
511 }
512 void runOnOperation() override {
513 RewritePatternSet patterns(&getContext());
514 populateVectorScanLoweringPatterns(patterns);
515 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
516 }
517};
518
519/// Allocate shared memory for a single warp to test lowering of
520/// WarpExecuteOnLane0Op.
521static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
522 gpu::WarpExecuteOnLane0Op warpOp,
523 Type type) {
524 static constexpr int64_t kSharedMemorySpace = 3;
525 // Compute type of shared memory buffer.
526 MemRefType memrefType;
527 if (auto vectorType = dyn_cast<VectorType>(type)) {
528 memrefType =
529 MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
530 kSharedMemorySpace);
531 } else {
532 memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
533 }
534
535 // Get symbol table holding all shared memory globals.
536 ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
537 SymbolTable symbolTable(moduleOp);
538
539 // Create a pretty name.
540 SmallString<64> buf;
541 llvm::raw_svector_ostream os(buf);
542 interleave(memrefType.getShape(), os, "x");
543 os << "x" << memrefType.getElementType();
544 std::string symbolName = (Twine("__shared_") + os.str()).str();
545
546 auto ip = builder.saveInsertionPoint();
547 builder.setInsertionPoint(moduleOp);
548 auto global = builder.create<memref::GlobalOp>(
549 loc,
550 /*sym_name=*/symbolName,
551 /*sym_visibility=*/builder.getStringAttr("private"),
552 /*type=*/memrefType,
553 /*initial_value=*/Attribute(),
554 /*constant=*/false,
555 /*alignment=*/IntegerAttr());
556 symbolTable.insert(symbol: global);
557 // The symbol table inserts at the end of the module, but globals are a bit
558 // nicer if they are at the beginning.
559 global->moveBefore(&moduleOp.front());
560
561 builder.restoreInsertionPoint(ip);
562 return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
563}
564
565static Value warpReduction(Location loc, OpBuilder &builder, Value input,
566 CombiningKind kind, uint32_t size) {
567 // First reduce on a single thread to get per lane reduction value.
568 Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
569 // Parallel reduction using butterfly shuffles.
570 for (uint64_t i = 1; i < size; i <<= 1) {
571 Value shuffled = builder
572 .create<gpu::ShuffleOp>(loc, laneVal, i,
573 /*width=*/size,
574 /*mode=*/gpu::ShuffleMode::XOR)
575 .getShuffleResult();
576 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
577 }
578 return laneVal;
579}
580
581struct TestVectorDistribution
582 : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
583 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
584
585 void getDependentDialects(DialectRegistry &registry) const override {
586 registry
587 .insert<vector::VectorDialect, scf::SCFDialect, memref::MemRefDialect,
588 gpu::GPUDialect, affine::AffineDialect>();
589 }
590
591 StringRef getArgument() const final { return "test-vector-warp-distribute"; }
592 StringRef getDescription() const final {
593 return "Test vector warp distribute transformation and lowering patterns";
594 }
595 TestVectorDistribution() = default;
596 TestVectorDistribution(const TestVectorDistribution &pass)
597 : PassWrapper(pass) {}
598
599 Option<bool> warpOpToSCF{
600 *this, "rewrite-warp-ops-to-scf-if",
601 llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
602 llvm::cl::init(Val: false)};
603
604 Option<bool> distributeTransferWriteOps{
605 *this, "distribute-transfer-write",
606 llvm::cl::desc("Test distribution of transfer write"),
607 llvm::cl::init(Val: false)};
608
609 Option<unsigned> maxTransferWriteElements{
610 *this, "max-transfer-write-elements",
611 llvm::cl::desc("Maximum number of transfer write elements to distribute"),
612 llvm::cl::init(Val: 1)};
613
614 Option<bool> hoistUniform{*this, "hoist-uniform",
615 llvm::cl::desc("Test hoist uniform"),
616 llvm::cl::init(Val: false)};
617
618 Option<bool> propagateDistribution{
619 *this, "propagate-distribution",
620 llvm::cl::desc("Test distribution propagation"), llvm::cl::init(Val: false)};
621
622 void runOnOperation() override {
623 RewritePatternSet patterns(&getContext());
624
625 getOperation().walk([&](Operation *op) {
626 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
627 if (hoistUniform) {
628 moveScalarUniformCode(warpOp);
629 }
630 WalkResult::interrupt();
631 }
632 });
633 MLIRContext *ctx = &getContext();
634 auto distributionFn = [](Value val) {
635 // Create an identity dim map of the same rank as the vector.
636 VectorType vecType = dyn_cast<VectorType>(val.getType());
637 int64_t vecRank = vecType ? vecType.getRank() : 0;
638 OpBuilder builder(val.getContext());
639 if (vecRank == 0)
640 return AffineMap::get(context: val.getContext());
641 return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
642 };
643 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
644 Value srcIdx, int64_t warpSz) {
645 assert((val.getType().isF32() || val.getType().isInteger(32)) &&
646 "unsupported shuffle type");
647 Type i32Type = builder.getIntegerType(32);
648 Value srcIdxI32 =
649 builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
650 Value warpSzI32 = builder.create<arith::ConstantOp>(
651 loc, builder.getIntegerAttr(i32Type, warpSz));
652 Value result = builder
653 .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
654 gpu::ShuffleMode::IDX)
655 .getResult(0);
656 return result;
657 };
658 if (distributeTransferWriteOps && propagateDistribution) {
659 RewritePatternSet patterns(ctx);
660 vector::populatePropagateWarpVectorDistributionPatterns(
661 patterns, distributionFn, shuffleFn, /*benefit=*/1,
662 /*readBenefit=*/0);
663 vector::populateDistributeReduction(pattern&: patterns, distributedReductionFn: warpReduction, benefit: 1);
664 populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
665 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
666 } else if (distributeTransferWriteOps) {
667 RewritePatternSet patterns(ctx);
668 populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
669 maxTransferWriteElements);
670 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
671 } else if (propagateDistribution) {
672 RewritePatternSet patterns(ctx);
673 vector::populatePropagateWarpVectorDistributionPatterns(
674 patterns, distributionFn, shuffleFn);
675 vector::populateDistributeReduction(pattern&: patterns, distributedReductionFn: warpReduction);
676 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
677 }
678 WarpExecuteOnLane0LoweringOptions options;
679 options.warpAllocationFn = allocateGlobalSharedMemory;
680 options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
681 gpu::WarpExecuteOnLane0Op warpOp) {
682 builder.create<gpu::BarrierOp>(loc);
683 };
684 // Test on one pattern in isolation.
685 if (warpOpToSCF) {
686 populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
687 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
688 return;
689 }
690 }
691};
692
693struct TestVectorExtractStridedSliceLowering
694 : public PassWrapper<TestVectorExtractStridedSliceLowering,
695 OperationPass<func::FuncOp>> {
696 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
697 TestVectorExtractStridedSliceLowering)
698
699 StringRef getArgument() const final {
700 return "test-vector-extract-strided-slice-lowering";
701 }
702 StringRef getDescription() const final {
703 return "Test lowering patterns that converts vector.extract_strided_slice "
704 "into a chain of vector.extract and vector.insert ops";
705 }
706 void runOnOperation() override {
707 RewritePatternSet patterns(&getContext());
708 populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
709 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
710 }
711};
712
713struct TestVectorBreakDownBitCast
714 : public PassWrapper<TestVectorBreakDownBitCast,
715 OperationPass<func::FuncOp>> {
716 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)
717
718 StringRef getArgument() const final {
719 return "test-vector-break-down-bitcast";
720 }
721 StringRef getDescription() const final {
722 return "Test pattern that breaks down vector.bitcast ops ";
723 }
724 void runOnOperation() override {
725 RewritePatternSet patterns(&getContext());
726 populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
727 return op.getSourceVectorType().getShape().back() > 4;
728 });
729 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
730 }
731};
732
733struct TestCreateVectorBroadcast
734 : public PassWrapper<TestCreateVectorBroadcast,
735 OperationPass<func::FuncOp>> {
736 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
737
738 StringRef getArgument() const final { return "test-create-vector-broadcast"; }
739 StringRef getDescription() const final {
740 return "Test optimization transformations for transfer ops";
741 }
742 void getDependentDialects(DialectRegistry &registry) const override {
743 registry.insert<vector::VectorDialect>();
744 }
745
746 void runOnOperation() override {
747 getOperation()->walk([](Operation *op) {
748 if (op->getName().getStringRef() != "test_create_broadcast")
749 return;
750 auto targetShape =
751 cast<VectorType>(op->getResult(0).getType()).getShape();
752 auto arrayAttr =
753 cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
754 .asArrayRef();
755 llvm::SetVector<int64_t> broadcastedDims;
756 broadcastedDims.insert_range(arrayAttr);
757 OpBuilder b(op);
758 Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
759 b, op->getOperand(0), targetShape, broadcastedDims);
760 op->getResult(0).replaceAllUsesWith(bcast);
761 op->erase();
762 });
763 }
764};
765
766struct TestVectorGatherLowering
767 : public PassWrapper<TestVectorGatherLowering,
768 OperationPass<func::FuncOp>> {
769 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)
770
771 StringRef getArgument() const final { return "test-vector-gather-lowering"; }
772 StringRef getDescription() const final {
773 return "Test patterns that lower the gather op in the vector conditional "
774 "loads";
775 }
776 void getDependentDialects(DialectRegistry &registry) const override {
777 registry.insert<arith::ArithDialect, func::FuncDialect,
778 memref::MemRefDialect, scf::SCFDialect,
779 tensor::TensorDialect, vector::VectorDialect>();
780 }
781
782 void runOnOperation() override {
783 RewritePatternSet patterns(&getContext());
784 populateVectorGatherLoweringPatterns(patterns);
785 populateVectorGatherToConditionalLoadPatterns(patterns);
786 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
787 }
788};
789
790struct TestFoldArithExtensionIntoVectorContractPatterns
791 : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
792 OperationPass<func::FuncOp>> {
793 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
794 TestFoldArithExtensionIntoVectorContractPatterns)
795
796 StringRef getArgument() const final {
797 return "test-fold-arith-extf-into-vector-contract-patterns";
798 }
799 StringRef getDescription() const final {
800 return "Test patterns that fold arithmetic extension ops into vector "
801 "contract ops";
802 }
803
804 void getDependentDialects(DialectRegistry &registry) const override {
805 registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
806 memref::MemRefDialect, scf::SCFDialect,
807 tensor::TensorDialect, vector::VectorDialect>();
808 }
809
810 void runOnOperation() override {
811 RewritePatternSet patterns(&getContext());
812 populateFoldArithExtensionPatterns(patterns);
813 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
814 }
815};
816
817struct TestVectorEmulateMaskedLoadStore final
818 : public PassWrapper<TestVectorEmulateMaskedLoadStore,
819 OperationPass<func::FuncOp>> {
820 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
821
822 StringRef getArgument() const override {
823 return "test-vector-emulate-masked-load-store";
824 }
825 StringRef getDescription() const override {
826 return "Test patterns that emulate the maskedload/maskedstore op by "
827 " memref.load/store and scf.if";
828 }
829 void getDependentDialects(DialectRegistry &registry) const override {
830 registry
831 .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
832 scf::SCFDialect, vector::VectorDialect>();
833 }
834
835 void runOnOperation() override {
836 RewritePatternSet patterns(&getContext());
837 populateVectorMaskedLoadStoreEmulationPatterns(patterns);
838 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
839 }
840};
841
842/// Get the set of operand/result types to check for sufficiently
843/// small inner-most dimension size.
844static SmallVector<std::pair<Type, unsigned>>
845getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
846
847 if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
848 unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
849 ? targetBitWidth + 1
850 : targetBitWidth;
851 return {{insertOp.getValueToStoreType(), w}};
852 }
853
854 auto resultTypes = op->getResultTypes();
855 SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
856 resultsWithBitWidth.reserve(N: resultTypes.size());
857 for (Type type : resultTypes) {
858 resultsWithBitWidth.push_back(Elt: {type, targetBitWidth});
859 }
860 return resultsWithBitWidth;
861}
862
863/// If `type` is VectorType with trailing dimension of (bit) size greater than
864/// or equal to `targetBitWidth`, its defining op is considered legal.
865static bool
866isNotLinearizableBecauseLargeInnerDimension(Type type,
867 unsigned targetBitWidth) {
868
869 VectorType vecType = dyn_cast<VectorType>(type);
870
871 // Not linearizable for reasons other than what this function checks.
872 if (!vecType || vecType.getRank() == 0)
873 return false;
874
875 // The width of the type 'index' is unbounded (and therefore potentially above
876 // the target width).
877 if (vecType.getElementType().isIndex())
878 return true;
879
880 unsigned finalDimSize = vecType.getShape().back();
881 unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
882 unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
883 return trailingVecDimBitWidth >= targetBitWidth;
884}
885
886static bool
887isNotLinearizableBecauseLargeInnerDimension(Operation *op,
888 unsigned targetBitWidth) {
889 // Check on bitwidths.
890 SmallVector<std::pair<Type, unsigned>> toCheck =
891 getTypeBitWidthBoundPairs(op, targetBitWidth);
892 return llvm::any_of(Range&: toCheck, P: [&](std::pair<Type, unsigned> typeWidth) {
893 return isNotLinearizableBecauseLargeInnerDimension(type: typeWidth.first,
894 targetBitWidth: typeWidth.second);
895 });
896}
897
898void populateWithBitWidthConstraints(TypeConverter &typeConverter,
899 ConversionTarget &target,
900 unsigned targetBitWidth) {
901
902 // The general purpose definition of what ops are legal must come first.
903 populateForVectorLinearize(typeConverter, conversionTarget&: target);
904
905 // Extend the set of legal ops to include those with large inner-most
906 // dimensions on selected operands/results.
907 target.markUnknownOpDynamicallyLegal(
908 fn: [=](Operation *op) -> std::optional<bool> {
909 if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
910 return true;
911 }
912 return {};
913 });
914}
915
916struct TestVectorBitWidthLinearize final
917 : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
918 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
919
920 TestVectorBitWidthLinearize() = default;
921 TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
922 : PassWrapper(pass) {}
923
924 StringRef getArgument() const override {
925 return "test-bit-width-constrained-vector-linearize";
926 }
927 StringRef getDescription() const override {
928 return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
929 "in inner-most dimension's bit width.";
930 }
931 void getDependentDialects(DialectRegistry &registry) const override {
932 registry.insert<vector::VectorDialect>();
933 }
934
935 Option<unsigned> targetVectorBitwidth{
936 *this, "target-vector-bitwidth",
937 llvm::cl::desc(
938 "Minimum vector bitwidth to enable the flattening transformation"),
939 llvm::cl::init(Val: std::numeric_limits<unsigned>::max())};
940 void runOnOperation() override {
941 auto *context = &getContext();
942
943 TypeConverter typeConverter;
944 RewritePatternSet patterns(context);
945 ConversionTarget target(*context);
946
947 populateWithBitWidthConstraints(typeConverter, target,
948 targetBitWidth: targetVectorBitwidth);
949
950 vector::populateVectorLinearizeBasePatterns(typeConverter, target,
951 patterns);
952
953 vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
954 patterns);
955
956 if (failed(applyPartialConversion(getOperation(), target,
957 std::move(patterns))))
958 return signalPassFailure();
959 }
960};
961
962struct TestVectorLinearize final
963 : public PassWrapper<TestVectorLinearize, OperationPass<>> {
964 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
965
966 TestVectorLinearize() = default;
967
968 StringRef getArgument() const override { return "test-vector-linearize"; }
969 StringRef getDescription() const override {
970 return "Linearizes ND vectors for N >= 2 into 1D vectors";
971 }
972 void getDependentDialects(DialectRegistry &registry) const override {
973 registry.insert<vector::VectorDialect, arith::ArithDialect>();
974 }
975
976 void runOnOperation() override {
977 MLIRContext &context = getContext();
978 TypeConverter converter;
979 RewritePatternSet patterns(&context);
980 ConversionTarget target(context);
981
982 vector::populateForVectorLinearize(typeConverter&: converter, conversionTarget&: target);
983
984 vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
985 vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
986 patterns);
987 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
988 typeConverter: converter, patterns, target);
989
990 if (failed(applyPartialConversion(getOperation(), target,
991 std::move(patterns))))
992 return signalPassFailure();
993 }
994};
995
996struct TestEliminateVectorMasks
997 : public PassWrapper<TestEliminateVectorMasks,
998 OperationPass<func::FuncOp>> {
999 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)
1000
1001 TestEliminateVectorMasks() = default;
1002 TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
1003 : PassWrapper(pass) {}
1004
1005 Option<unsigned> vscaleMin{
1006 *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
1007 llvm::cl::init(Val: 1)};
1008 Option<unsigned> vscaleMax{
1009 *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
1010 llvm::cl::init(Val: 16)};
1011
1012 StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
1013 StringRef getDescription() const final {
1014 return "Test eliminating vector masks";
1015 }
1016 void runOnOperation() override {
1017 IRRewriter rewriter(&getContext());
1018 eliminateVectorMasks(rewriter, getOperation(),
1019 VscaleRange{.vscaleMin: vscaleMin, .vscaleMax: vscaleMax});
1020 }
1021};
1022} // namespace
1023
1024namespace mlir {
1025namespace test {
1026void registerTestVectorLowerings() {
1027 PassRegistration<TestVectorToVectorLowering>();
1028
1029 PassRegistration<TestVectorContractionPrepareForMMTLowering>();
1030
1031 PassRegistration<TestVectorUnrollingPatterns>();
1032
1033 PassRegistration<TestVectorTransferUnrollingPatterns>();
1034
1035 PassRegistration<TestScalarVectorTransferLoweringPatterns>();
1036
1037 PassRegistration<TestVectorTransferOpt>();
1038
1039 PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
1040
1041 PassRegistration<TestVectorSinkPatterns>();
1042
1043 PassRegistration<TestVectorReduceToContractPatternsPatterns>();
1044
1045 PassRegistration<TestVectorChainedReductionFoldingPatterns>();
1046
1047 PassRegistration<TestVectorBreakDownReductionPatterns>();
1048
1049 PassRegistration<TestFlattenVectorTransferPatterns>();
1050
1051 PassRegistration<TestVectorScanLowering>();
1052
1053 PassRegistration<TestVectorDistribution>();
1054
1055 PassRegistration<TestVectorExtractStridedSliceLowering>();
1056
1057 PassRegistration<TestVectorBreakDownBitCast>();
1058
1059 PassRegistration<TestCreateVectorBroadcast>();
1060
1061 PassRegistration<TestVectorGatherLowering>();
1062
1063 PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
1064
1065 PassRegistration<TestVectorEmulateMaskedLoadStore>();
1066
1067 PassRegistration<TestVectorLinearize>();
1068
1069 PassRegistration<TestVectorBitWidthLinearize>();
1070
1071 PassRegistration<TestEliminateVectorMasks>();
1072}
1073} // namespace test
1074} // namespace mlir
1075

// Source: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp