1 | //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <optional> |
10 | #include <type_traits> |
11 | |
12 | #include "mlir/Analysis/SliceAnalysis.h" |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
17 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Linalg/Passes.h" |
20 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
21 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
22 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
23 | #include "mlir/Dialect/SCF/IR/SCF.h" |
24 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
25 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
26 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
27 | #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" |
28 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
29 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
30 | #include "mlir/Pass/Pass.h" |
31 | #include "mlir/Pass/PassManager.h" |
32 | #include "mlir/Support/LLVM.h" |
33 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::linalg; |
37 | using namespace mlir::vector; |
38 | |
39 | namespace { |
40 | |
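// Test pass exercising vector-to-vector lowering patterns, optionally with
// unrolling. As an illustrative sketch (not verbatim pass output), with
// `-unroll` a 4x4 elementwise op is split into 2x2 pieces along these lines:
//
//   %0 = arith.addf %a, %b : vector<4x4xf32>
//   // becomes four ops of the form:
//   %sa = vector.extract_strided_slice %a
//           {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
//           : vector<4x4xf32> to vector<2x2xf32>
//   %sr = arith.addf %sa, %sb : vector<2x2xf32>
//   // ... with results recombined via vector.insert_strided_slice ops.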
41 | struct TestVectorToVectorLowering |
42 | : public PassWrapper<TestVectorToVectorLowering, |
43 | OperationPass<func::FuncOp>> { |
44 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) |
45 | |
46 | TestVectorToVectorLowering() = default; |
47 | TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) |
48 | : PassWrapper(pass) {} |
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }
55 | |
  void getDependentDialects(DialectRegistry &registry) const override {
57 | registry.insert<affine::AffineDialect>(); |
58 | registry.insert<vector::VectorDialect>(); |
59 | } |
60 | |
  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};
63 | |
64 | void runOnOperation() override { |
65 | auto *ctx = &getContext(); |
66 | RewritePatternSet patterns(ctx); |
67 | if (unroll) { |
68 | populateVectorUnrollPatterns( |
69 | patterns, |
70 | UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( |
71 | filter)); |
72 | } |
73 | populateVectorToVectorCanonicalizationPatterns(patterns); |
74 | populateBubbleVectorBitCastOpPatterns(patterns); |
75 | populateCastAwayVectorLeadingOneDimPatterns(patterns); |
76 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
77 | } |
78 | |
79 | private: |
80 | // Return the target shape based on op type. |
81 | static std::optional<SmallVector<int64_t>> getShape(Operation *op) { |
82 | if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) |
83 | return SmallVector<int64_t>(2, 2); |
    if (isa<vector::ContractionOp>(op))
85 | return SmallVector<int64_t>(3, 2); |
86 | // For transfer ops, just propagate the shape coming from |
87 | // InsertStridedSlices/ExtractStridedSlices. |
88 | if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { |
89 | VectorType dstVec; |
      for (Operation *user : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(user);
92 | if (!extract) |
93 | return std::nullopt; |
94 | auto vecType = cast<VectorType>(extract.getResult().getType()); |
95 | if (dstVec && dstVec != vecType) |
96 | return std::nullopt; |
97 | dstVec = vecType; |
98 | } |
99 | return SmallVector<int64_t>(dstVec.getShape().begin(), |
100 | dstVec.getShape().end()); |
101 | } |
102 | if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { |
103 | auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); |
104 | if (!insert) |
105 | return std::nullopt; |
106 | ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); |
107 | return SmallVector<int64_t>(shape.begin(), shape.end()); |
108 | } |
109 | return std::nullopt; |
110 | } |
111 | |
112 | static LogicalResult filter(Operation *op) { |
113 | return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, |
114 | ContractionOp, TransferReadOp, TransferWriteOp>(op)); |
115 | } |
116 | }; |
117 | |
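// Test pass canonicalizing matmul-like vector.contract ops into the
// matmul-transposed ("MMT") form, in which the RHS is indexed as (n, k).
// Sketch of the target indexing maps, assuming (m, n, k) iterator order
// (affine-map syntax abbreviated):
//
//   lhs: (m, n, k) -> (m, k)
//   rhs: (m, n, k) -> (n, k)   // B transposed
//   acc: (m, n, k) -> (m, n)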
118 | struct TestVectorContractionPrepareForMMTLowering |
119 | : public PassWrapper<TestVectorContractionPrepareForMMTLowering, |
120 | OperationPass<func::FuncOp>> { |
121 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
122 | TestVectorContractionPrepareForMMTLowering) |
123 | |
  StringRef getArgument() const final {
    return "test-vector-contraction-prepare-for-mmt-lowering";
  }
  StringRef getDescription() const final {
    return "Test vector.contraction matmul canonicalization for MMT lowering.";
  }
130 | TestVectorContractionPrepareForMMTLowering() = default; |
131 | |
  void getDependentDialects(DialectRegistry &registry) const override {
133 | registry.insert<affine::AffineDialect, arith::ArithDialect, |
134 | vector::VectorDialect>(); |
135 | } |
136 | |
137 | void runOnOperation() override { |
138 | MLIRContext *ctx = &getContext(); |
139 | RewritePatternSet patterns(ctx); |
140 | vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); |
141 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
142 | } |
143 | }; |
144 | |
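// Test pass unrolling vector ops to fixed native shapes. Illustrative tile
// traversal for a vector<4x4xf32> op unrolled to native shape [2, 2]: the
// default row-major order visits tile offsets
//   [0,0], [0,2], [2,0], [2,2]
// whereas `-unroll-order=1,0` would visit
//   [0,0], [2,0], [0,2], [2,2]
// (a sketch of the intended option semantics, not verified pass output).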
145 | struct TestVectorUnrollingPatterns |
146 | : public PassWrapper<TestVectorUnrollingPatterns, |
147 | OperationPass<func::FuncOp>> { |
148 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) |
149 | |
  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
157 | TestVectorUnrollingPatterns() = default; |
158 | TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) |
159 | : PassWrapper(pass) {} |
160 | void runOnOperation() override { |
161 | MLIRContext *ctx = &getContext(); |
162 | RewritePatternSet patterns(ctx); |
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));
182 | |
183 | if (unrollBasedOnType) { |
184 | UnrollVectorOptions::NativeShapeFnType nativeShapeFn = |
185 | [](Operation *op) -> std::optional<SmallVector<int64_t>> { |
186 | vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); |
187 | SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(), |
188 | 4); |
189 | Type lhsType = contractOp.getLhsType().getElementType(); |
190 | nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2; |
191 | return nativeShape; |
192 | }; |
193 | |
194 | UnrollVectorOptions opts; |
195 | opts.setNativeShapeFn(nativeShapeFn) |
196 | .setFilterConstraint( |
              [](Operation *op) { return success(isa<ContractionOp>(op)); });
198 | |
199 | if (!unrollOrder.empty()) { |
200 | opts.setUnrollTraversalOrderFn( |
201 | [this](Operation *op) -> std::optional<SmallVector<int64_t>> { |
202 | vector::ContractionOp contractOp = |
203 | cast<vector::ContractionOp>(op); |
204 | if (contractOp.getIteratorTypes().size() == unrollOrder.size()) |
205 | return SmallVector<int64_t>(unrollOrder.begin(), |
206 | unrollOrder.end()); |
207 | return std::nullopt; |
208 | }); |
209 | } |
      populateVectorUnrollPatterns(patterns, opts);
211 | } else { |
212 | auto nativeShapeFn = |
213 | [](Operation *op) -> std::optional<SmallVector<int64_t>> { |
214 | auto contractOp = dyn_cast<ContractionOp>(op); |
215 | if (!contractOp) |
216 | return std::nullopt; |
217 | return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2); |
218 | }; |
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
225 | } |
226 | populateVectorToVectorCanonicalizationPatterns(patterns); |
227 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
228 | } |
229 | |
  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("Set the unroll order")};

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
237 | }; |
238 | |
239 | struct TestVectorTransferUnrollingPatterns |
240 | : public PassWrapper<TestVectorTransferUnrollingPatterns, |
241 | OperationPass<func::FuncOp>> { |
242 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
243 | TestVectorTransferUnrollingPatterns) |
244 | |
245 | TestVectorTransferUnrollingPatterns() = default; |
246 | TestVectorTransferUnrollingPatterns( |
247 | const TestVectorTransferUnrollingPatterns &pass) |
248 | : PassWrapper(pass) {} |
249 | |
  void getDependentDialects(DialectRegistry &registry) const override {
251 | registry.insert<affine::AffineDialect>(); |
252 | } |
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
260 | void runOnOperation() override { |
261 | MLIRContext *ctx = &getContext(); |
262 | RewritePatternSet patterns(ctx); |
263 | UnrollVectorOptions opts; |
264 | opts.setNativeShape(ArrayRef<int64_t>{2, 2}) |
265 | .setFilterConstraint([](Operation *op) { |
266 | return success(isa<vector::TransferReadOp, vector::TransferWriteOp, |
267 | vector::GatherOp>(op)); |
268 | }); |
269 | if (reverseUnrollOrder.getValue()) { |
270 | opts.setUnrollTraversalOrderFn( |
271 | [](Operation *op) -> std::optional<SmallVector<int64_t>> { |
272 | int64_t numLoops = 0; |
273 | if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) |
274 | numLoops = readOp.getVectorType().getRank(); |
275 | else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) |
276 | numLoops = writeOp.getVectorType().getRank(); |
277 | else if (auto gatherOp = dyn_cast<vector::GatherOp>(op)) |
278 | numLoops = gatherOp.getVectorType().getRank(); |
279 | else |
280 | return std::nullopt; |
            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
            return llvm::to_vector(order);
283 | }); |
284 | } |
    populateVectorUnrollPatterns(patterns, opts);
286 | populateVectorToVectorCanonicalizationPatterns(patterns); |
287 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
288 | } |
289 | |
  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "Reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
295 | }; |
296 | |
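// Test pass rewriting single-element vector transfers into scalar memref or
// tensor accesses. Illustrative rewrite, assuming an in-bounds read (exact
// attributes elided):
//
//   %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<1xf32>
//   // becomes roughly:
//   %s = memref.load %m[%i] : memref<?xf32>
//   %v = vector.broadcast %s : f32 to vector<1xf32>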
297 | struct TestScalarVectorTransferLoweringPatterns |
298 | : public PassWrapper<TestScalarVectorTransferLoweringPatterns, |
299 | OperationPass<func::FuncOp>> { |
300 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
301 | TestScalarVectorTransferLoweringPatterns) |
302 | |
303 | TestScalarVectorTransferLoweringPatterns() = default; |
304 | TestScalarVectorTransferLoweringPatterns( |
305 | const TestScalarVectorTransferLoweringPatterns &pass) |
306 | : PassWrapper(pass) {} |
307 | |
  StringRef getArgument() const final {
    return "test-scalar-vector-transfer-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering of scalar vector transfers to memref loads/stores.";
  }
314 | |
  void getDependentDialects(DialectRegistry &registry) const override {
316 | registry.insert<affine::AffineDialect, memref::MemRefDialect, |
317 | tensor::TensorDialect, vector::VectorDialect>(); |
318 | } |
319 | |
  Option<bool> allowMultipleUses{
      *this, "allow-multiple-uses",
      llvm::cl::desc("Fold transfer operations with multiple uses"),
      llvm::cl::init(false)};
324 | |
325 | void runOnOperation() override { |
326 | MLIRContext *ctx = &getContext(); |
327 | RewritePatternSet patterns(ctx); |
328 | vector::populateScalarVectorTransferLoweringPatterns( |
        patterns, /*benefit=*/1, allowMultipleUses.getValue());
330 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
331 | } |
332 | }; |
333 | |
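// Test pass for the transfer-op data-flow optimization, which removes
// provably redundant transfer ops, e.g. forwarding a just-written value to a
// subsequent read of the same indices (a sketch of the kind of pair it
// targets, not an exhaustive description):
//
//   vector.transfer_write %v, %m[%i, %j] ...
//   %r = vector.transfer_read %m[%i, %j], %pad ...  // may reuse %v directly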
334 | struct TestVectorTransferOpt |
335 | : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> { |
336 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) |
337 | |
  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
341 | } |
342 | void runOnOperation() override { |
343 | IRRewriter rewriter(&getContext()); |
344 | transferOpflowOpt(rewriter, getOperation()); |
345 | } |
346 | }; |
347 | |
348 | struct TestVectorTransferCollapseInnerMostContiguousDims |
349 | : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, |
350 | OperationPass<func::FuncOp>> { |
351 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
352 | TestVectorTransferCollapseInnerMostContiguousDims) |
353 | |
354 | TestVectorTransferCollapseInnerMostContiguousDims() = default; |
355 | TestVectorTransferCollapseInnerMostContiguousDims( |
356 | const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; |
357 | |
  void getDependentDialects(DialectRegistry &registry) const override {
359 | registry.insert<memref::MemRefDialect, affine::AffineDialect>(); |
360 | } |
361 | |
  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
369 | } |
370 | |
371 | void runOnOperation() override { |
372 | RewritePatternSet patterns(&getContext()); |
373 | populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); |
374 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
375 | } |
376 | }; |
377 | |
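// Test pass sinking broadcasts past elementwise ops so the arithmetic runs
// on the narrower operands. Illustrative rewrite:
//
//   %a = vector.broadcast %x : f32 to vector<4xf32>
//   %b = vector.broadcast %y : f32 to vector<4xf32>
//   %r = arith.addf %a, %b : vector<4xf32>
//   // becomes:
//   %s = arith.addf %x, %y : f32
//   %r = vector.broadcast %s : f32 to vector<4xf32>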
378 | struct TestSinkVectorBroadcast |
379 | : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> { |
380 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast) |
381 | |
382 | TestSinkVectorBroadcast() = default; |
383 | TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default; |
384 | |
  void getDependentDialects(DialectRegistry &registry) const override {
386 | registry.insert<memref::MemRefDialect, affine::AffineDialect>(); |
387 | } |
388 | |
  StringRef getArgument() const final { return "test-sink-vector-broadcast"; }

  StringRef getDescription() const final {
    return "Test lowering patterns that eliminate redundant broadcast "
           "operations.";
394 | } |
395 | |
396 | void runOnOperation() override { |
397 | RewritePatternSet patterns(&getContext()); |
398 | populateSinkVectorBroadcastPatterns(patterns); |
399 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
400 | } |
401 | }; |
402 | |
403 | struct TestVectorReduceToContractPatternsPatterns |
404 | : public PassWrapper<TestVectorReduceToContractPatternsPatterns, |
405 | OperationPass<func::FuncOp>> { |
406 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
407 | TestVectorReduceToContractPatternsPatterns) |
408 | |
  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multi_reduction ops to contract and to "
           "combine broadcast/transpose ops with contract";
415 | } |
416 | void runOnOperation() override { |
417 | RewritePatternSet patterns(&getContext()); |
418 | populateVectorReductionToContractPatterns(patterns); |
419 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
420 | } |
421 | }; |
422 | |
423 | struct TestVectorChainedReductionFoldingPatterns |
424 | : public PassWrapper<TestVectorChainedReductionFoldingPatterns, |
425 | OperationPass<func::FuncOp>> { |
426 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
427 | TestVectorChainedReductionFoldingPatterns) |
428 | |
  StringRef getArgument() const final {
    return "test-vector-chained-reduction-folding-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to fold chained vector reductions";
434 | } |
435 | void runOnOperation() override { |
436 | RewritePatternSet patterns(&getContext()); |
437 | populateChainedVectorReductionFoldingPatterns(patterns); |
438 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
439 | } |
440 | }; |
441 | |
442 | struct TestVectorBreakDownReductionPatterns |
443 | : public PassWrapper<TestVectorBreakDownReductionPatterns, |
444 | OperationPass<func::FuncOp>> { |
445 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
446 | TestVectorBreakDownReductionPatterns) |
447 | |
  StringRef getArgument() const final {
    return "test-vector-break-down-reduction-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to break down vector reductions into arith "
           "reductions";
454 | } |
455 | void runOnOperation() override { |
456 | RewritePatternSet patterns(&getContext()); |
457 | populateBreakDownVectorReductionPatterns(patterns, |
458 | /*maxNumElementsToExtract=*/2); |
459 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
460 | } |
461 | }; |
462 | |
463 | struct TestFlattenVectorTransferPatterns |
464 | : public PassWrapper<TestFlattenVectorTransferPatterns, |
465 | OperationPass<func::FuncOp>> { |
466 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
467 | TestFlattenVectorTransferPatterns) |
468 | |
469 | TestFlattenVectorTransferPatterns() = default; |
470 | TestFlattenVectorTransferPatterns( |
471 | const TestFlattenVectorTransferPatterns &pass) |
472 | : PassWrapper(pass) {} |
473 | |
  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }

  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
481 | } |
482 | |
  void getDependentDialects(DialectRegistry &registry) const override {
484 | registry.insert<memref::MemRefDialect>(); |
485 | registry.insert<affine::AffineDialect>(); |
486 | registry.insert<vector::VectorDialect>(); |
487 | } |
488 | |
  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation. "
          "For scalable vectors this is the base size, i.e. the size "
          "corresponding to vscale=1."),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};
496 | |
497 | void runOnOperation() override { |
498 | RewritePatternSet patterns(&getContext()); |
499 | populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth); |
500 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
501 | } |
502 | }; |
503 | |
504 | struct TestVectorScanLowering |
505 | : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> { |
506 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) |
507 | |
  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
512 | } |
513 | void runOnOperation() override { |
514 | RewritePatternSet patterns(&getContext()); |
515 | populateVectorScanLoweringPatterns(patterns); |
516 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
517 | } |
518 | }; |
519 | |
520 | /// Allocate shared memory for a single warp to test lowering of |
521 | /// WarpExecuteOnLane0Op. |
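/// For a vector<32x32xf32> payload this produces, illustratively (the exact
/// printed form may differ):
///   memref.global "private" @__shared_32x32xf32 : memref<32x32xf32, 3>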
522 | static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, |
523 | WarpExecuteOnLane0Op warpOp, |
524 | Type type) { |
525 | static constexpr int64_t kSharedMemorySpace = 3; |
526 | // Compute type of shared memory buffer. |
527 | MemRefType memrefType; |
528 | if (auto vectorType = dyn_cast<VectorType>(type)) { |
529 | memrefType = |
530 | MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, |
531 | kSharedMemorySpace); |
532 | } else { |
533 | memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); |
534 | } |
535 | |
536 | // Get symbol table holding all shared memory globals. |
537 | ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>(); |
538 | SymbolTable symbolTable(moduleOp); |
539 | |
540 | // Create a pretty name. |
541 | SmallString<64> buf; |
542 | llvm::raw_svector_ostream os(buf); |
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();
546 | |
547 | auto ip = builder.saveInsertionPoint(); |
548 | builder.setInsertionPoint(moduleOp); |
549 | auto global = builder.create<memref::GlobalOp>( |
550 | loc, |
551 | /*sym_name=*/symbolName, |
      /*sym_visibility=*/builder.getStringAttr("private"),
553 | /*type=*/memrefType, |
554 | /*initial_value=*/Attribute(), |
555 | /*constant=*/false, |
556 | /*alignment=*/IntegerAttr()); |
  symbolTable.insert(global);
558 | // The symbol table inserts at the end of the module, but globals are a bit |
559 | // nicer if they are at the beginning. |
560 | global->moveBefore(&moduleOp.front()); |
561 | |
562 | builder.restoreInsertionPoint(ip); |
563 | return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); |
564 | } |
565 | |
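// Butterfly (XOR) warp reduction: after the round with offset i, every lane
// holds the reduction of a group of 2*i lanes. Worked example for size = 4
// with addition and per-lane inputs a0..a3:
//   i = 1: lane L combines with lane L^1 -> {a0+a1, a0+a1, a2+a3, a2+a3}
//   i = 2: lane L combines with lane L^2 -> all lanes hold a0+a1+a2+a3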
566 | static Value warpReduction(Location loc, OpBuilder &builder, Value input, |
567 | CombiningKind kind, uint32_t size) { |
568 | // First reduce on a single thread to get per lane reduction value. |
569 | Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); |
570 | // Parallel reduction using butterfly shuffles. |
571 | for (uint64_t i = 1; i < size; i <<= 1) { |
572 | Value shuffled = builder |
573 | .create<gpu::ShuffleOp>(loc, laneVal, i, |
574 | /*width=*/size, |
575 | /*mode=*/gpu::ShuffleMode::XOR) |
576 | .getShuffleResult(); |
577 | laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); |
578 | } |
579 | return laneVal; |
580 | } |
581 | |
582 | struct TestVectorDistribution |
583 | : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> { |
584 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) |
585 | |
  void getDependentDialects(DialectRegistry &registry) const override {
587 | registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect, |
588 | affine::AffineDialect>(); |
589 | } |
590 | |
  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
594 | } |
595 | TestVectorDistribution() = default; |
596 | TestVectorDistribution(const TestVectorDistribution &pass) |
597 | : PassWrapper(pass) {} |
598 | |
  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
      llvm::cl::init(false)};

  Option<bool> distributeTransferWriteOps{
      *this, "distribute-transfer-write",
      llvm::cl::desc("Test distribution of transfer write"),
      llvm::cl::init(false)};

  Option<unsigned> maxTransferWriteElements{
      *this, "max-transfer-write-elements",
      llvm::cl::desc("Maximum number of transfer write elements to distribute"),
      llvm::cl::init(1)};

  Option<bool> hoistUniform{*this, "hoist-uniform",
                            llvm::cl::desc("Test hoist uniform"),
                            llvm::cl::init(false)};

  Option<bool> propagateDistribution{
      *this, "propagate-distribution",
      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};
621 | |
622 | void runOnOperation() override { |
623 | RewritePatternSet patterns(&getContext()); |
624 | |
    getOperation().walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
        if (hoistUniform)
          moveScalarUniformCode(warpOp);
        // Only the first warp op is expected per test function.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
633 | MLIRContext *ctx = &getContext(); |
    auto distributionFn = [](Value val) {
      // Create an identity dim map of the same rank as the vector.
      VectorType vecType = dyn_cast<VectorType>(val.getType());
      int64_t vecRank = vecType ? vecType.getRank() : 0;
      if (vecRank == 0)
        return AffineMap::get(val.getContext());
      return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
    };
643 | auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, |
644 | Value srcIdx, int64_t warpSz) { |
645 | assert((val.getType().isF32() || val.getType().isInteger(32)) && |
646 | "unsupported shuffle type" ); |
647 | Type i32Type = builder.getIntegerType(32); |
648 | Value srcIdxI32 = |
649 | builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx); |
650 | Value warpSzI32 = builder.create<arith::ConstantOp>( |
651 | loc, builder.getIntegerAttr(i32Type, warpSz)); |
652 | Value result = builder |
653 | .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32, |
654 | gpu::ShuffleMode::IDX) |
655 | .getResult(0); |
656 | return result; |
657 | }; |
658 | if (distributeTransferWriteOps && propagateDistribution) { |
659 | RewritePatternSet patterns(ctx); |
660 | vector::populatePropagateWarpVectorDistributionPatterns( |
661 | patterns, distributionFn, shuffleFn, /*benefit=*/1, |
662 | /*readBenefit=*/0); |
      vector::populateDistributeReduction(patterns, warpReduction,
                                          /*benefit=*/1);
664 | populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2); |
665 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
666 | } else if (distributeTransferWriteOps) { |
667 | RewritePatternSet patterns(ctx); |
668 | populateDistributeTransferWriteOpPatterns(patterns, distributionFn, |
669 | maxTransferWriteElements); |
670 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
671 | } else if (propagateDistribution) { |
672 | RewritePatternSet patterns(ctx); |
673 | vector::populatePropagateWarpVectorDistributionPatterns( |
674 | patterns, distributionFn, shuffleFn); |
      vector::populateDistributeReduction(patterns, warpReduction);
676 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
677 | } |
678 | WarpExecuteOnLane0LoweringOptions options; |
679 | options.warpAllocationFn = allocateGlobalSharedMemory; |
680 | options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, |
681 | WarpExecuteOnLane0Op warpOp) { |
682 | builder.create<gpu::BarrierOp>(loc); |
683 | }; |
684 | // Test on one pattern in isolation. |
685 | if (warpOpToSCF) { |
686 | populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); |
687 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
688 | return; |
689 | } |
690 | } |
691 | }; |
692 | |
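// Test pass lowering vector.extract_strided_slice into chains of
// vector.extract/vector.insert ops. Illustrative 1-D rewrite (op syntax
// abbreviated):
//
//   %s = vector.extract_strided_slice %v
//          {offsets = [1], sizes = [2], strides = [1]}
//          : vector<4xf32> to vector<2xf32>
//   // becomes vector.extract ops pulling elements 1 and 2 out of %v,
//   // reassembled into a vector<2xf32> with vector.insert ops.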
struct TestVectorExtractStridedSliceLowering
    : public PassWrapper<TestVectorExtractStridedSliceLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorExtractStridedSliceLowering)

  StringRef getArgument() const final {
    return "test-vector-extract-strided-slice-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that convert vector.extract_strided_slice "
           "into a chain of vector.extract and vector.insert ops";
  }
  void runOnOperation() override {
707 | RewritePatternSet patterns(&getContext()); |
708 | populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns); |
709 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
710 | } |
711 | }; |
712 | |
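// Test pass breaking down wide vector.bitcast ops; the filter below only
// fires when the innermost source dimension has more than 4 elements.
// Sketch of the intended effect (shapes illustrative):
//
//   %c = vector.bitcast %v : vector<8xf32> to vector<16xf16>
//   // becomes bitcasts on extracted halves of %v, recombined with
//   // vector.insert_strided_slice.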
713 | struct TestVectorBreakDownBitCast |
714 | : public PassWrapper<TestVectorBreakDownBitCast, |
715 | OperationPass<func::FuncOp>> { |
716 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast) |
717 | |
  StringRef getArgument() const final {
    return "test-vector-break-down-bitcast";
  }
  StringRef getDescription() const final {
    return "Test pattern that breaks down vector.bitcast ops";
723 | } |
724 | void runOnOperation() override { |
725 | RewritePatternSet patterns(&getContext()); |
726 | populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) { |
727 | return op.getSourceVectorType().getShape().back() > 4; |
728 | }); |
729 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
730 | } |
731 | }; |
732 | |
733 | struct TestCreateVectorBroadcast |
734 | : public PassWrapper<TestCreateVectorBroadcast, |
735 | OperationPass<func::FuncOp>> { |
736 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast) |
737 | |
  StringRef getArgument() const final {
    return "test-create-vector-broadcast";
  }
  StringRef getDescription() const final {
    return "Test creation of vector broadcasts through "
           "BroadcastOp::createOrFoldBroadcastOp";
741 | } |
  void getDependentDialects(DialectRegistry &registry) const override {
743 | registry.insert<vector::VectorDialect>(); |
744 | } |
745 | |
746 | void runOnOperation() override { |
747 | getOperation()->walk([](Operation *op) { |
      if (op->getName().getStringRef() != "test_create_broadcast")
749 | return; |
750 | auto targetShape = |
751 | cast<VectorType>(op->getResult(0).getType()).getShape(); |
752 | auto arrayAttr = |
          cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
754 | .asArrayRef(); |
755 | llvm::SetVector<int64_t> broadcastedDims; |
756 | broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end()); |
757 | OpBuilder b(op); |
758 | Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp( |
759 | b, op->getOperand(0), targetShape, broadcastedDims); |
760 | op->getResult(0).replaceAllUsesWith(bcast); |
761 | op->erase(); |
762 | }); |
763 | } |
764 | }; |
765 | |
766 | struct TestVectorGatherLowering |
767 | : public PassWrapper<TestVectorGatherLowering, |
768 | OperationPass<func::FuncOp>> { |
769 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering) |
770 | |
  StringRef getArgument() const final { return "test-vector-gather-lowering"; }
  StringRef getDescription() const final {
    return "Test patterns that lower the gather op in the vector dialect to "
           "conditional loads";
775 | } |
  void getDependentDialects(DialectRegistry &registry) const override {
777 | registry.insert<arith::ArithDialect, func::FuncDialect, |
778 | memref::MemRefDialect, scf::SCFDialect, |
779 | tensor::TensorDialect, vector::VectorDialect>(); |
780 | } |
781 | |
782 | void runOnOperation() override { |
783 | RewritePatternSet patterns(&getContext()); |
784 | populateVectorGatherLoweringPatterns(patterns); |
785 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
786 | } |
787 | }; |
788 | |
789 | struct TestFoldArithExtensionIntoVectorContractPatterns |
790 | : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns, |
791 | OperationPass<func::FuncOp>> { |
792 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
793 | TestFoldArithExtensionIntoVectorContractPatterns) |
794 | |
  StringRef getArgument() const final {
    return "test-fold-arith-extf-into-vector-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns that fold arithmetic extension ops into vector "
           "contract ops";
801 | } |
802 | |
  void getDependentDialects(DialectRegistry &registry) const override {
804 | registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect, |
805 | memref::MemRefDialect, scf::SCFDialect, |
806 | tensor::TensorDialect, vector::VectorDialect>(); |
807 | } |
808 | |
809 | void runOnOperation() override { |
810 | RewritePatternSet patterns(&getContext()); |
811 | populateFoldArithExtensionPatterns(patterns); |
812 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
813 | } |
814 | }; |
815 | |
816 | struct TestVectorEmulateMaskedLoadStore final |
817 | : public PassWrapper<TestVectorEmulateMaskedLoadStore, |
818 | OperationPass<func::FuncOp>> { |
819 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore) |
820 | |
  StringRef getArgument() const override {
    return "test-vector-emulate-masked-load-store";
  }
  StringRef getDescription() const override {
    return "Test patterns that emulate the maskedload/maskedstore op by "
           "memref.load/store and scf.if";
827 | } |
  void getDependentDialects(DialectRegistry &registry) const override {
829 | registry |
830 | .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect, |
831 | scf::SCFDialect, vector::VectorDialect>(); |
832 | } |
833 | |
834 | void runOnOperation() override { |
835 | RewritePatternSet patterns(&getContext()); |
836 | populateVectorMaskedLoadStoreEmulationPatterns(patterns); |
837 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
838 | } |
839 | }; |
840 | |
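// Test pass linearizing n-D vector ops to 1-D. Illustrative conversion, with
// vector.shape_cast ops inserted at the value boundaries:
//
//   %r = arith.addf %a, %b : vector<2x2xf32>
//   // becomes, on flattened operands:
//   %f = arith.addf %fa, %fb : vector<4xf32>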
841 | struct TestVectorLinearize final |
842 | : public PassWrapper<TestVectorLinearize, OperationPass<>> { |
843 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) |
844 | |
845 | TestVectorLinearize() = default; |
846 | TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {} |
847 | |
  StringRef getArgument() const override { return "test-vector-linearize"; }
  StringRef getDescription() const override {
    return "Linearizes ND vectors for N >= 2 into 1D vectors";
851 | } |
  void getDependentDialects(DialectRegistry &registry) const override {
853 | registry.insert<vector::VectorDialect>(); |
854 | } |
855 | |
  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation"),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};
861 | void runOnOperation() override { |
862 | auto *context = &getContext(); |
863 | |
864 | TypeConverter typeConverter; |
865 | RewritePatternSet patterns(context); |
866 | ConversionTarget target(*context); |
867 | |
    vector::populateVectorLinearizeTypeConversionsAndLegality(
        typeConverter, patterns, target, targetVectorBitwidth);
    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
        typeConverter, patterns, target, targetVectorBitwidth);
872 | if (failed(applyPartialConversion(getOperation(), target, |
873 | std::move(patterns)))) |
874 | return signalPassFailure(); |
875 | } |
876 | }; |
877 | } // namespace |
878 | |
879 | namespace mlir { |
880 | namespace test { |
881 | void registerTestVectorLowerings() { |
882 | PassRegistration<TestVectorToVectorLowering>(); |
883 | |
884 | PassRegistration<TestVectorContractionPrepareForMMTLowering>(); |
885 | |
886 | PassRegistration<TestVectorUnrollingPatterns>(); |
887 | |
888 | PassRegistration<TestVectorTransferUnrollingPatterns>(); |
889 | |
890 | PassRegistration<TestScalarVectorTransferLoweringPatterns>(); |
891 | |
892 | PassRegistration<TestVectorTransferOpt>(); |
893 | |
894 | PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); |
895 | |
896 | PassRegistration<TestSinkVectorBroadcast>(); |
897 | |
898 | PassRegistration<TestVectorReduceToContractPatternsPatterns>(); |
899 | |
900 | PassRegistration<TestVectorChainedReductionFoldingPatterns>(); |
901 | |
902 | PassRegistration<TestVectorBreakDownReductionPatterns>(); |
903 | |
904 | PassRegistration<TestFlattenVectorTransferPatterns>(); |
905 | |
906 | PassRegistration<TestVectorScanLowering>(); |
907 | |
908 | PassRegistration<TestVectorDistribution>(); |
909 | |
910 | PassRegistration<TestVectorExtractStridedSliceLowering>(); |
911 | |
912 | PassRegistration<TestVectorBreakDownBitCast>(); |
913 | |
914 | PassRegistration<TestCreateVectorBroadcast>(); |
915 | |
916 | PassRegistration<TestVectorGatherLowering>(); |
917 | |
918 | PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); |
919 | |
920 | PassRegistration<TestVectorEmulateMaskedLoadStore>(); |
921 | |
922 | PassRegistration<TestVectorLinearize>(); |
923 | } |
924 | } // namespace test |
925 | } // namespace mlir |
926 | |