//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 | #include <optional> |
10 | |
11 | #include "mlir/Analysis/SliceAnalysis.h" |
12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
16 | #include "mlir/Dialect/Linalg/Passes.h" |
17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
20 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
23 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
24 | #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" |
25 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
26 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
27 | #include "mlir/Pass/Pass.h" |
28 | #include "mlir/Pass/PassManager.h" |
29 | #include "mlir/Support/LLVM.h" |
30 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
31 | |
32 | using namespace mlir; |
33 | using namespace mlir::linalg; |
34 | using namespace mlir::vector; |
35 | |
36 | namespace { |
37 | |
38 | struct TestVectorToVectorLowering |
39 | : public PassWrapper<TestVectorToVectorLowering, |
40 | OperationPass<func::FuncOp>> { |
41 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) |
42 | |
43 | TestVectorToVectorLowering() = default; |
44 | TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) |
45 | : PassWrapper(pass) {} |
46 | StringRef getArgument() const final { |
47 | return "test-vector-to-vector-lowering"; |
48 | } |
49 | StringRef getDescription() const final { |
50 | return "Test lowering patterns between ops in the vector dialect"; |
51 | } |
52 | |
53 | void getDependentDialects(DialectRegistry ®istry) const override { |
54 | registry.insert<affine::AffineDialect>(); |
55 | registry.insert<vector::VectorDialect>(); |
56 | } |
57 | |
58 | Option<bool> unroll{*this, "unroll", llvm::cl::desc( "Include unrolling"), |
59 | llvm::cl::init(Val: false)}; |
60 | |
61 | void runOnOperation() override { |
62 | auto *ctx = &getContext(); |
63 | RewritePatternSet patterns(ctx); |
64 | if (unroll) { |
65 | populateVectorUnrollPatterns( |
66 | patterns, |
67 | UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( |
68 | filter)); |
69 | } |
70 | populateVectorToVectorCanonicalizationPatterns(patterns); |
71 | populateBubbleVectorBitCastOpPatterns(patterns); |
72 | populateCastAwayVectorLeadingOneDimPatterns(patterns); |
73 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
74 | } |
75 | |
76 | private: |
77 | // Return the target shape based on op type. |
78 | static std::optional<SmallVector<int64_t>> getShape(Operation *op) { |
79 | if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) |
80 | return SmallVector<int64_t>(2, 2); |
81 | if (isa<vector::ContractionOp>(Val: op)) |
82 | return SmallVector<int64_t>(3, 2); |
83 | // For transfer ops, just propagate the shape coming from |
84 | // InsertStridedSlices/ExtractStridedSlices. |
85 | if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { |
86 | VectorType dstVec; |
87 | for (Operation *users : readOp->getUsers()) { |
88 | auto extract = dyn_cast<ExtractStridedSliceOp>(users); |
89 | if (!extract) |
90 | return std::nullopt; |
91 | auto vecType = cast<VectorType>(extract.getResult().getType()); |
92 | if (dstVec && dstVec != vecType) |
93 | return std::nullopt; |
94 | dstVec = vecType; |
95 | } |
96 | return SmallVector<int64_t>(dstVec.getShape()); |
97 | } |
98 | if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { |
99 | auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); |
100 | if (!insert) |
101 | return std::nullopt; |
102 | ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); |
103 | return SmallVector<int64_t>(shape); |
104 | } |
105 | return std::nullopt; |
106 | } |
107 | |
108 | static LogicalResult filter(Operation *op) { |
109 | return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, |
110 | ContractionOp, TransferReadOp, TransferWriteOp>(op)); |
111 | } |
112 | }; |
113 | |
114 | struct TestVectorContractionPrepareForMMTLowering |
115 | : public PassWrapper<TestVectorContractionPrepareForMMTLowering, |
116 | OperationPass<func::FuncOp>> { |
117 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
118 | TestVectorContractionPrepareForMMTLowering) |
119 | |
120 | StringRef getArgument() const final { |
121 | return "test-vector-contraction-prepare-for-mmt-lowering"; |
122 | } |
123 | StringRef getDescription() const final { |
124 | return "Test vector.contraction matmul canonicalization for MMT lowering."; |
125 | } |
126 | TestVectorContractionPrepareForMMTLowering() = default; |
127 | |
128 | void getDependentDialects(DialectRegistry ®istry) const override { |
129 | registry.insert<affine::AffineDialect, arith::ArithDialect, |
130 | vector::VectorDialect>(); |
131 | } |
132 | |
133 | void runOnOperation() override { |
134 | MLIRContext *ctx = &getContext(); |
135 | RewritePatternSet patterns(ctx); |
136 | vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); |
137 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
138 | } |
139 | }; |
140 | |
141 | struct TestVectorUnrollingPatterns |
142 | : public PassWrapper<TestVectorUnrollingPatterns, |
143 | OperationPass<func::FuncOp>> { |
144 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) |
145 | |
146 | StringRef getArgument() const final { |
147 | return "test-vector-unrolling-patterns"; |
148 | } |
149 | StringRef getDescription() const final { |
150 | return "Test lowering patterns to unroll contract ops in the vector " |
151 | "dialect"; |
152 | } |
153 | TestVectorUnrollingPatterns() = default; |
154 | TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) |
155 | : PassWrapper(pass) {} |
156 | void runOnOperation() override { |
157 | MLIRContext *ctx = &getContext(); |
158 | RewritePatternSet patterns(ctx); |
159 | populateVectorUnrollPatterns( |
160 | patterns, |
161 | options: UnrollVectorOptions() |
162 | .setNativeShape(ArrayRef<int64_t>{2, 2}) |
163 | .setFilterConstraint([](Operation *op) { |
164 | return success( |
165 | isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp, |
166 | vector::BroadcastOp>(op)); |
167 | })); |
168 | populateVectorUnrollPatterns( |
169 | patterns, options: UnrollVectorOptions() |
170 | .setNativeShape(ArrayRef<int64_t>{2}) |
171 | .setFilterConstraint([](Operation *op) { |
172 | return success(isa<vector::ReductionOp>(op)); |
173 | })); |
174 | populateVectorUnrollPatterns( |
175 | patterns, options: UnrollVectorOptions() |
176 | .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) |
177 | .setFilterConstraint([](Operation *op) { |
178 | return success(isa<vector::TransposeOp>(op)); |
179 | })); |
180 | |
181 | if (unrollBasedOnType) { |
182 | UnrollVectorOptions::NativeShapeFnType nativeShapeFn = |
183 | [](Operation *op) -> std::optional<SmallVector<int64_t>> { |
184 | vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); |
185 | SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(), |
186 | 4); |
187 | Type lhsType = contractOp.getLhsType().getElementType(); |
188 | nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2; |
189 | return nativeShape; |
190 | }; |
191 | |
192 | UnrollVectorOptions opts; |
193 | opts.setNativeShapeFn(nativeShapeFn) |
194 | .setFilterConstraint( |
195 | [](Operation *op) { return success(IsSuccess: isa<ContractionOp>(Val: op)); }); |
196 | |
197 | if (!unrollOrder.empty()) { |
198 | opts.setUnrollTraversalOrderFn( |
199 | [this](Operation *op) -> std::optional<SmallVector<int64_t>> { |
200 | vector::ContractionOp contractOp = |
201 | cast<vector::ContractionOp>(op); |
202 | if (contractOp.getIteratorTypes().size() == unrollOrder.size()) |
203 | return SmallVector<int64_t>(unrollOrder.begin(), |
204 | unrollOrder.end()); |
205 | return std::nullopt; |
206 | }); |
207 | } |
208 | populateVectorUnrollPatterns(patterns, options: opts); |
209 | } else { |
210 | auto nativeShapeFn = |
211 | [](Operation *op) -> std::optional<SmallVector<int64_t>> { |
212 | auto contractOp = dyn_cast<ContractionOp>(op); |
213 | if (!contractOp) |
214 | return std::nullopt; |
215 | return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2); |
216 | }; |
217 | populateVectorUnrollPatterns(patterns, |
218 | UnrollVectorOptions() |
219 | .setNativeShapeFn(nativeShapeFn) |
220 | .setFilterConstraint([](Operation *op) { |
221 | return success(IsSuccess: isa<ContractionOp>(Val: op)); |
222 | })); |
223 | } |
224 | populateVectorToVectorCanonicalizationPatterns(patterns); |
225 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
226 | } |
227 | |
228 | ListOption<int64_t> unrollOrder{*this, "unroll-order", |
229 | llvm::cl::desc("set the unroll order")}; |
230 | |
231 | Option<bool> unrollBasedOnType{ |
232 | *this, "unroll-based-on-type", |
233 | llvm::cl::desc("Set the unroll factor based on type of the operation"), |
234 | llvm::cl::init(Val: false)}; |
235 | }; |
236 | |
237 | struct TestVectorTransferUnrollingPatterns |
238 | : public PassWrapper<TestVectorTransferUnrollingPatterns, |
239 | OperationPass<func::FuncOp>> { |
240 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
241 | TestVectorTransferUnrollingPatterns) |
242 | |
243 | TestVectorTransferUnrollingPatterns() = default; |
244 | TestVectorTransferUnrollingPatterns( |
245 | const TestVectorTransferUnrollingPatterns &pass) |
246 | : PassWrapper(pass) {} |
247 | |
248 | void getDependentDialects(DialectRegistry ®istry) const override { |
249 | registry.insert<affine::AffineDialect>(); |
250 | } |
251 | StringRef getArgument() const final { |
252 | return "test-vector-transfer-unrolling-patterns"; |
253 | } |
254 | StringRef getDescription() const final { |
255 | return "Test lowering patterns to unroll transfer ops in the vector " |
256 | "dialect"; |
257 | } |
258 | void runOnOperation() override { |
259 | MLIRContext *ctx = &getContext(); |
260 | RewritePatternSet patterns(ctx); |
261 | UnrollVectorOptions opts; |
262 | opts.setNativeShape(ArrayRef<int64_t>{2, 2}) |
263 | .setFilterConstraint([](Operation *op) { |
264 | return success(isa<vector::TransferReadOp, vector::TransferWriteOp, |
265 | vector::GatherOp>(op)); |
266 | }); |
267 | if (reverseUnrollOrder.getValue()) { |
268 | opts.setUnrollTraversalOrderFn( |
269 | [](Operation *op) -> std::optional<SmallVector<int64_t>> { |
270 | int64_t numLoops = 0; |
271 | if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) |
272 | numLoops = readOp.getVectorType().getRank(); |
273 | else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) |
274 | numLoops = writeOp.getVectorType().getRank(); |
275 | else if (auto gatherOp = dyn_cast<vector::GatherOp>(op)) |
276 | numLoops = gatherOp.getVectorType().getRank(); |
277 | else |
278 | return std::nullopt; |
279 | auto order = llvm::reverse(C: llvm::seq<int64_t>(Begin: 0, End: numLoops)); |
280 | return llvm::to_vector(Range&: order); |
281 | }); |
282 | } |
283 | populateVectorUnrollPatterns(patterns, options: opts); |
284 | populateVectorToVectorCanonicalizationPatterns(patterns); |
285 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
286 | } |
287 | |
288 | Option<bool> reverseUnrollOrder{ |
289 | *this, "reverse-unroll-order", |
290 | llvm::cl::desc( |
291 | "reverse the order of unrolling of vector transfer operations"), |
292 | llvm::cl::init(Val: false)}; |
293 | }; |
294 | |
295 | struct TestScalarVectorTransferLoweringPatterns |
296 | : public PassWrapper<TestScalarVectorTransferLoweringPatterns, |
297 | OperationPass<func::FuncOp>> { |
298 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
299 | TestScalarVectorTransferLoweringPatterns) |
300 | |
301 | TestScalarVectorTransferLoweringPatterns() = default; |
302 | TestScalarVectorTransferLoweringPatterns( |
303 | const TestScalarVectorTransferLoweringPatterns &pass) |
304 | : PassWrapper(pass) {} |
305 | |
306 | StringRef getArgument() const final { |
307 | return "test-scalar-vector-transfer-lowering"; |
308 | } |
309 | StringRef getDescription() const final { |
310 | return "Test lowering of scalar vector transfers to memref loads/stores."; |
311 | } |
312 | |
313 | void getDependentDialects(DialectRegistry ®istry) const override { |
314 | registry.insert<affine::AffineDialect, memref::MemRefDialect, |
315 | tensor::TensorDialect, vector::VectorDialect>(); |
316 | } |
317 | |
318 | Option<bool> allowMultipleUses{ |
319 | *this, "allow-multiple-uses", |
320 | llvm::cl::desc("Fold transfer operations with multiple uses"), |
321 | llvm::cl::init(Val: false)}; |
322 | |
323 | void runOnOperation() override { |
324 | MLIRContext *ctx = &getContext(); |
325 | RewritePatternSet patterns(ctx); |
326 | vector::populateScalarVectorTransferLoweringPatterns( |
327 | patterns, /*benefit=*/1, allowMultipleUses: allowMultipleUses.getValue()); |
328 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
329 | } |
330 | }; |
331 | |
332 | struct TestVectorTransferOpt |
333 | : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> { |
334 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) |
335 | |
336 | StringRef getArgument() const final { return "test-vector-transferop-opt"; } |
337 | StringRef getDescription() const final { |
338 | return "Test optimization transformations for transfer ops"; |
339 | } |
340 | void runOnOperation() override { |
341 | IRRewriter rewriter(&getContext()); |
342 | transferOpflowOpt(rewriter, getOperation()); |
343 | } |
344 | }; |
345 | |
346 | struct TestVectorTransferCollapseInnerMostContiguousDims |
347 | : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, |
348 | OperationPass<func::FuncOp>> { |
349 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
350 | TestVectorTransferCollapseInnerMostContiguousDims) |
351 | |
352 | TestVectorTransferCollapseInnerMostContiguousDims() = default; |
353 | TestVectorTransferCollapseInnerMostContiguousDims( |
354 | const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; |
355 | |
356 | void getDependentDialects(DialectRegistry ®istry) const override { |
357 | registry.insert<memref::MemRefDialect, affine::AffineDialect>(); |
358 | } |
359 | |
360 | StringRef getArgument() const final { |
361 | return "test-vector-transfer-collapse-inner-most-dims"; |
362 | } |
363 | |
364 | StringRef getDescription() const final { |
365 | return "Test lowering patterns that reduces the rank of the vector " |
366 | "transfer memory and vector operands."; |
367 | } |
368 | |
369 | void runOnOperation() override { |
370 | RewritePatternSet patterns(&getContext()); |
371 | populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); |
372 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
373 | } |
374 | }; |
375 | |
376 | struct TestVectorSinkPatterns |
377 | : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> { |
378 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns) |
379 | |
380 | TestVectorSinkPatterns() = default; |
381 | TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default; |
382 | |
383 | void getDependentDialects(DialectRegistry ®istry) const override { |
384 | registry.insert<memref::MemRefDialect, affine::AffineDialect>(); |
385 | } |
386 | |
387 | StringRef getArgument() const final { return "test-vector-sink-patterns"; } |
388 | |
389 | StringRef getDescription() const final { |
390 | return "Test lowering patterns that eliminate redundant broadcast " |
391 | "and transpose operations."; |
392 | } |
393 | |
394 | void runOnOperation() override { |
395 | RewritePatternSet patterns(&getContext()); |
396 | populateSinkVectorOpsPatterns(patterns); |
397 | populateSinkVectorMemOpsPatterns(patterns); |
398 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
399 | } |
400 | }; |
401 | |
402 | struct TestVectorReduceToContractPatternsPatterns |
403 | : public PassWrapper<TestVectorReduceToContractPatternsPatterns, |
404 | OperationPass<func::FuncOp>> { |
405 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
406 | TestVectorReduceToContractPatternsPatterns) |
407 | |
408 | StringRef getArgument() const final { |
409 | return "test-vector-reduction-to-contract-patterns"; |
410 | } |
411 | StringRef getDescription() const final { |
412 | return "Test patterns to convert multireduce op to contract and combine " |
413 | "broadcast/transpose to contract"; |
414 | } |
415 | void runOnOperation() override { |
416 | RewritePatternSet patterns(&getContext()); |
417 | populateVectorReductionToContractPatterns(patterns); |
418 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
419 | } |
420 | }; |
421 | |
422 | struct TestVectorChainedReductionFoldingPatterns |
423 | : public PassWrapper<TestVectorChainedReductionFoldingPatterns, |
424 | OperationPass<func::FuncOp>> { |
425 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
426 | TestVectorChainedReductionFoldingPatterns) |
427 | |
428 | StringRef getArgument() const final { |
429 | return "test-vector-chained-reduction-folding-patterns"; |
430 | } |
431 | StringRef getDescription() const final { |
432 | return "Test patterns to fold chained vector reductions"; |
433 | } |
434 | void runOnOperation() override { |
435 | RewritePatternSet patterns(&getContext()); |
436 | populateChainedVectorReductionFoldingPatterns(patterns); |
437 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
438 | } |
439 | }; |
440 | |
441 | struct TestVectorBreakDownReductionPatterns |
442 | : public PassWrapper<TestVectorBreakDownReductionPatterns, |
443 | OperationPass<func::FuncOp>> { |
444 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
445 | TestVectorBreakDownReductionPatterns) |
446 | |
447 | StringRef getArgument() const final { |
448 | return "test-vector-break-down-reduction-patterns"; |
449 | } |
450 | StringRef getDescription() const final { |
451 | return "Test patterns to break down vector reductions into arith " |
452 | "reductions"; |
453 | } |
454 | void runOnOperation() override { |
455 | RewritePatternSet patterns(&getContext()); |
456 | populateBreakDownVectorReductionPatterns(patterns, |
457 | /*maxNumElementsToExtract=*/2); |
458 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
459 | } |
460 | }; |
461 | |
462 | struct TestFlattenVectorTransferPatterns |
463 | : public PassWrapper<TestFlattenVectorTransferPatterns, |
464 | OperationPass<func::FuncOp>> { |
465 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
466 | TestFlattenVectorTransferPatterns) |
467 | |
468 | TestFlattenVectorTransferPatterns() = default; |
469 | TestFlattenVectorTransferPatterns( |
470 | const TestFlattenVectorTransferPatterns &pass) |
471 | : PassWrapper(pass) {} |
472 | |
473 | StringRef getArgument() const final { |
474 | return "test-vector-transfer-flatten-patterns"; |
475 | } |
476 | |
477 | StringRef getDescription() const final { |
478 | return "Test patterns to rewrite contiguous row-major N-dimensional " |
479 | "vector.transfer_{read,write} ops into 1D transfers"; |
480 | } |
481 | |
482 | void getDependentDialects(DialectRegistry ®istry) const override { |
483 | registry.insert<memref::MemRefDialect>(); |
484 | registry.insert<affine::AffineDialect>(); |
485 | registry.insert<vector::VectorDialect>(); |
486 | } |
487 | |
488 | Option<unsigned> targetVectorBitwidth{ |
489 | *this, "target-vector-bitwidth", |
490 | llvm::cl::desc( |
491 | "Minimum vector bitwidth to enable the flattening transformation. " |
492 | "For scalable vectors this is the base size, i.e. the size " |
493 | "corresponding to vscale=1."), |
494 | llvm::cl::init(Val: std::numeric_limits<unsigned>::max())}; |
495 | |
496 | void runOnOperation() override { |
497 | RewritePatternSet patterns(&getContext()); |
498 | populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth); |
499 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
500 | } |
501 | }; |
502 | |
503 | struct TestVectorScanLowering |
504 | : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> { |
505 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) |
506 | |
507 | StringRef getArgument() const final { return "test-vector-scan-lowering"; } |
508 | StringRef getDescription() const final { |
509 | return "Test lowering patterns that lower the scan op in the vector " |
510 | "dialect"; |
511 | } |
512 | void runOnOperation() override { |
513 | RewritePatternSet patterns(&getContext()); |
514 | populateVectorScanLoweringPatterns(patterns); |
515 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
516 | } |
517 | }; |
518 | |
519 | /// Allocate shared memory for a single warp to test lowering of |
520 | /// WarpExecuteOnLane0Op. |
521 | static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, |
522 | gpu::WarpExecuteOnLane0Op warpOp, |
523 | Type type) { |
524 | static constexpr int64_t kSharedMemorySpace = 3; |
525 | // Compute type of shared memory buffer. |
526 | MemRefType memrefType; |
527 | if (auto vectorType = dyn_cast<VectorType>(type)) { |
528 | memrefType = |
529 | MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, |
530 | kSharedMemorySpace); |
531 | } else { |
532 | memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); |
533 | } |
534 | |
535 | // Get symbol table holding all shared memory globals. |
536 | ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>(); |
537 | SymbolTable symbolTable(moduleOp); |
538 | |
539 | // Create a pretty name. |
540 | SmallString<64> buf; |
541 | llvm::raw_svector_ostream os(buf); |
542 | interleave(memrefType.getShape(), os, "x"); |
543 | os << "x"<< memrefType.getElementType(); |
544 | std::string symbolName = (Twine("__shared_") + os.str()).str(); |
545 | |
546 | auto ip = builder.saveInsertionPoint(); |
547 | builder.setInsertionPoint(moduleOp); |
548 | auto global = builder.create<memref::GlobalOp>( |
549 | loc, |
550 | /*sym_name=*/symbolName, |
551 | /*sym_visibility=*/builder.getStringAttr("private"), |
552 | /*type=*/memrefType, |
553 | /*initial_value=*/Attribute(), |
554 | /*constant=*/false, |
555 | /*alignment=*/IntegerAttr()); |
556 | symbolTable.insert(symbol: global); |
557 | // The symbol table inserts at the end of the module, but globals are a bit |
558 | // nicer if they are at the beginning. |
559 | global->moveBefore(&moduleOp.front()); |
560 | |
561 | builder.restoreInsertionPoint(ip); |
562 | return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); |
563 | } |
564 | |
565 | static Value warpReduction(Location loc, OpBuilder &builder, Value input, |
566 | CombiningKind kind, uint32_t size) { |
567 | // First reduce on a single thread to get per lane reduction value. |
568 | Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); |
569 | // Parallel reduction using butterfly shuffles. |
570 | for (uint64_t i = 1; i < size; i <<= 1) { |
571 | Value shuffled = builder |
572 | .create<gpu::ShuffleOp>(loc, laneVal, i, |
573 | /*width=*/size, |
574 | /*mode=*/gpu::ShuffleMode::XOR) |
575 | .getShuffleResult(); |
576 | laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); |
577 | } |
578 | return laneVal; |
579 | } |
580 | |
581 | struct TestVectorDistribution |
582 | : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> { |
583 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) |
584 | |
585 | void getDependentDialects(DialectRegistry ®istry) const override { |
586 | registry |
587 | .insert<vector::VectorDialect, scf::SCFDialect, memref::MemRefDialect, |
588 | gpu::GPUDialect, affine::AffineDialect>(); |
589 | } |
590 | |
591 | StringRef getArgument() const final { return "test-vector-warp-distribute"; } |
592 | StringRef getDescription() const final { |
593 | return "Test vector warp distribute transformation and lowering patterns"; |
594 | } |
595 | TestVectorDistribution() = default; |
596 | TestVectorDistribution(const TestVectorDistribution &pass) |
597 | : PassWrapper(pass) {} |
598 | |
599 | Option<bool> warpOpToSCF{ |
600 | *this, "rewrite-warp-ops-to-scf-if", |
601 | llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"), |
602 | llvm::cl::init(Val: false)}; |
603 | |
604 | Option<bool> distributeTransferWriteOps{ |
605 | *this, "distribute-transfer-write", |
606 | llvm::cl::desc("Test distribution of transfer write"), |
607 | llvm::cl::init(Val: false)}; |
608 | |
609 | Option<unsigned> maxTransferWriteElements{ |
610 | *this, "max-transfer-write-elements", |
611 | llvm::cl::desc("Maximum number of transfer write elements to distribute"), |
612 | llvm::cl::init(Val: 1)}; |
613 | |
614 | Option<bool> hoistUniform{*this, "hoist-uniform", |
615 | llvm::cl::desc("Test hoist uniform"), |
616 | llvm::cl::init(Val: false)}; |
617 | |
618 | Option<bool> propagateDistribution{ |
619 | *this, "propagate-distribution", |
620 | llvm::cl::desc("Test distribution propagation"), llvm::cl::init(Val: false)}; |
621 | |
622 | void runOnOperation() override { |
623 | RewritePatternSet patterns(&getContext()); |
624 | |
625 | getOperation().walk([&](Operation *op) { |
626 | if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) { |
627 | if (hoistUniform) { |
628 | moveScalarUniformCode(warpOp); |
629 | } |
630 | WalkResult::interrupt(); |
631 | } |
632 | }); |
633 | MLIRContext *ctx = &getContext(); |
634 | auto distributionFn = [](Value val) { |
635 | // Create an identity dim map of the same rank as the vector. |
636 | VectorType vecType = dyn_cast<VectorType>(val.getType()); |
637 | int64_t vecRank = vecType ? vecType.getRank() : 0; |
638 | OpBuilder builder(val.getContext()); |
639 | if (vecRank == 0) |
640 | return AffineMap::get(context: val.getContext()); |
641 | return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); |
642 | }; |
643 | auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, |
644 | Value srcIdx, int64_t warpSz) { |
645 | assert((val.getType().isF32() || val.getType().isInteger(32)) && |
646 | "unsupported shuffle type"); |
647 | Type i32Type = builder.getIntegerType(32); |
648 | Value srcIdxI32 = |
649 | builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx); |
650 | Value warpSzI32 = builder.create<arith::ConstantOp>( |
651 | loc, builder.getIntegerAttr(i32Type, warpSz)); |
652 | Value result = builder |
653 | .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32, |
654 | gpu::ShuffleMode::IDX) |
655 | .getResult(0); |
656 | return result; |
657 | }; |
658 | if (distributeTransferWriteOps && propagateDistribution) { |
659 | RewritePatternSet patterns(ctx); |
660 | vector::populatePropagateWarpVectorDistributionPatterns( |
661 | patterns, distributionFn, shuffleFn, /*benefit=*/1, |
662 | /*readBenefit=*/0); |
663 | vector::populateDistributeReduction(pattern&: patterns, distributedReductionFn: warpReduction, benefit: 1); |
664 | populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2); |
665 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
666 | } else if (distributeTransferWriteOps) { |
667 | RewritePatternSet patterns(ctx); |
668 | populateDistributeTransferWriteOpPatterns(patterns, distributionFn, |
669 | maxTransferWriteElements); |
670 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
671 | } else if (propagateDistribution) { |
672 | RewritePatternSet patterns(ctx); |
673 | vector::populatePropagateWarpVectorDistributionPatterns( |
674 | patterns, distributionFn, shuffleFn); |
675 | vector::populateDistributeReduction(pattern&: patterns, distributedReductionFn: warpReduction); |
676 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
677 | } |
678 | WarpExecuteOnLane0LoweringOptions options; |
679 | options.warpAllocationFn = allocateGlobalSharedMemory; |
680 | options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, |
681 | gpu::WarpExecuteOnLane0Op warpOp) { |
682 | builder.create<gpu::BarrierOp>(loc); |
683 | }; |
684 | // Test on one pattern in isolation. |
685 | if (warpOpToSCF) { |
686 | populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); |
687 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
688 | return; |
689 | } |
690 | } |
691 | }; |
692 | |
693 | struct TestVectorExtractStridedSliceLowering |
694 | : public PassWrapper<TestVectorExtractStridedSliceLowering, |
695 | OperationPass<func::FuncOp>> { |
696 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
697 | TestVectorExtractStridedSliceLowering) |
698 | |
699 | StringRef getArgument() const final { |
700 | return "test-vector-extract-strided-slice-lowering"; |
701 | } |
702 | StringRef getDescription() const final { |
703 | return "Test lowering patterns that converts vector.extract_strided_slice " |
704 | "into a chain of vector.extract and vector.insert ops"; |
705 | } |
706 | void runOnOperation() override { |
707 | RewritePatternSet patterns(&getContext()); |
708 | populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns); |
709 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
710 | } |
711 | }; |
712 | |
713 | struct TestVectorBreakDownBitCast |
714 | : public PassWrapper<TestVectorBreakDownBitCast, |
715 | OperationPass<func::FuncOp>> { |
716 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast) |
717 | |
718 | StringRef getArgument() const final { |
719 | return "test-vector-break-down-bitcast"; |
720 | } |
721 | StringRef getDescription() const final { |
722 | return "Test pattern that breaks down vector.bitcast ops "; |
723 | } |
724 | void runOnOperation() override { |
725 | RewritePatternSet patterns(&getContext()); |
726 | populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) { |
727 | return op.getSourceVectorType().getShape().back() > 4; |
728 | }); |
729 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
730 | } |
731 | }; |
732 | |
733 | struct TestCreateVectorBroadcast |
734 | : public PassWrapper<TestCreateVectorBroadcast, |
735 | OperationPass<func::FuncOp>> { |
736 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast) |
737 | |
738 | StringRef getArgument() const final { return "test-create-vector-broadcast"; } |
739 | StringRef getDescription() const final { |
740 | return "Test optimization transformations for transfer ops"; |
741 | } |
742 | void getDependentDialects(DialectRegistry ®istry) const override { |
743 | registry.insert<vector::VectorDialect>(); |
744 | } |
745 | |
746 | void runOnOperation() override { |
747 | getOperation()->walk([](Operation *op) { |
748 | if (op->getName().getStringRef() != "test_create_broadcast") |
749 | return; |
750 | auto targetShape = |
751 | cast<VectorType>(op->getResult(0).getType()).getShape(); |
752 | auto arrayAttr = |
753 | cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims")) |
754 | .asArrayRef(); |
755 | llvm::SetVector<int64_t> broadcastedDims; |
756 | broadcastedDims.insert_range(arrayAttr); |
757 | OpBuilder b(op); |
758 | Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp( |
759 | b, op->getOperand(0), targetShape, broadcastedDims); |
760 | op->getResult(0).replaceAllUsesWith(bcast); |
761 | op->erase(); |
762 | }); |
763 | } |
764 | }; |
765 | |
766 | struct TestVectorGatherLowering |
767 | : public PassWrapper<TestVectorGatherLowering, |
768 | OperationPass<func::FuncOp>> { |
769 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering) |
770 | |
771 | StringRef getArgument() const final { return "test-vector-gather-lowering"; } |
772 | StringRef getDescription() const final { |
773 | return "Test patterns that lower the gather op in the vector conditional " |
774 | "loads"; |
775 | } |
776 | void getDependentDialects(DialectRegistry ®istry) const override { |
777 | registry.insert<arith::ArithDialect, func::FuncDialect, |
778 | memref::MemRefDialect, scf::SCFDialect, |
779 | tensor::TensorDialect, vector::VectorDialect>(); |
780 | } |
781 | |
782 | void runOnOperation() override { |
783 | RewritePatternSet patterns(&getContext()); |
784 | populateVectorGatherLoweringPatterns(patterns); |
785 | populateVectorGatherToConditionalLoadPatterns(patterns); |
786 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
787 | } |
788 | }; |
789 | |
790 | struct TestFoldArithExtensionIntoVectorContractPatterns |
791 | : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns, |
792 | OperationPass<func::FuncOp>> { |
793 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
794 | TestFoldArithExtensionIntoVectorContractPatterns) |
795 | |
796 | StringRef getArgument() const final { |
797 | return "test-fold-arith-extf-into-vector-contract-patterns"; |
798 | } |
799 | StringRef getDescription() const final { |
800 | return "Test patterns that fold arithmetic extension ops into vector " |
801 | "contract ops"; |
802 | } |
803 | |
804 | void getDependentDialects(DialectRegistry ®istry) const override { |
805 | registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect, |
806 | memref::MemRefDialect, scf::SCFDialect, |
807 | tensor::TensorDialect, vector::VectorDialect>(); |
808 | } |
809 | |
810 | void runOnOperation() override { |
811 | RewritePatternSet patterns(&getContext()); |
812 | populateFoldArithExtensionPatterns(patterns); |
813 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
814 | } |
815 | }; |
816 | |
817 | struct TestVectorEmulateMaskedLoadStore final |
818 | : public PassWrapper<TestVectorEmulateMaskedLoadStore, |
819 | OperationPass<func::FuncOp>> { |
820 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore) |
821 | |
822 | StringRef getArgument() const override { |
823 | return "test-vector-emulate-masked-load-store"; |
824 | } |
825 | StringRef getDescription() const override { |
826 | return "Test patterns that emulate the maskedload/maskedstore op by " |
827 | " memref.load/store and scf.if"; |
828 | } |
829 | void getDependentDialects(DialectRegistry ®istry) const override { |
830 | registry |
831 | .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect, |
832 | scf::SCFDialect, vector::VectorDialect>(); |
833 | } |
834 | |
835 | void runOnOperation() override { |
836 | RewritePatternSet patterns(&getContext()); |
837 | populateVectorMaskedLoadStoreEmulationPatterns(patterns); |
838 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
839 | } |
840 | }; |
841 | |
842 | /// Get the set of operand/result types to check for sufficiently |
843 | /// small inner-most dimension size. |
844 | static SmallVector<std::pair<Type, unsigned>> |
845 | getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) { |
846 | |
847 | if (auto insertOp = dyn_cast<vector::InsertOp>(op)) { |
848 | unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max() |
849 | ? targetBitWidth + 1 |
850 | : targetBitWidth; |
851 | return {{insertOp.getValueToStoreType(), w}}; |
852 | } |
853 | |
854 | auto resultTypes = op->getResultTypes(); |
855 | SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth; |
856 | resultsWithBitWidth.reserve(N: resultTypes.size()); |
857 | for (Type type : resultTypes) { |
858 | resultsWithBitWidth.push_back(Elt: {type, targetBitWidth}); |
859 | } |
860 | return resultsWithBitWidth; |
861 | } |
862 | |
863 | /// If `type` is VectorType with trailing dimension of (bit) size greater than |
864 | /// or equal to `targetBitWidth`, its defining op is considered legal. |
865 | static bool |
866 | isNotLinearizableBecauseLargeInnerDimension(Type type, |
867 | unsigned targetBitWidth) { |
868 | |
869 | VectorType vecType = dyn_cast<VectorType>(type); |
870 | |
871 | // Not linearizable for reasons other than what this function checks. |
872 | if (!vecType || vecType.getRank() == 0) |
873 | return false; |
874 | |
875 | // The width of the type 'index' is unbounded (and therefore potentially above |
876 | // the target width). |
877 | if (vecType.getElementType().isIndex()) |
878 | return true; |
879 | |
880 | unsigned finalDimSize = vecType.getShape().back(); |
881 | unsigned nbBitsPerElm = vecType.getElementTypeBitWidth(); |
882 | unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm; |
883 | return trailingVecDimBitWidth >= targetBitWidth; |
884 | } |
885 | |
886 | static bool |
887 | isNotLinearizableBecauseLargeInnerDimension(Operation *op, |
888 | unsigned targetBitWidth) { |
889 | // Check on bitwidths. |
890 | SmallVector<std::pair<Type, unsigned>> toCheck = |
891 | getTypeBitWidthBoundPairs(op, targetBitWidth); |
892 | return llvm::any_of(Range&: toCheck, P: [&](std::pair<Type, unsigned> typeWidth) { |
893 | return isNotLinearizableBecauseLargeInnerDimension(type: typeWidth.first, |
894 | targetBitWidth: typeWidth.second); |
895 | }); |
896 | } |
897 | |
898 | void populateWithBitWidthConstraints(TypeConverter &typeConverter, |
899 | ConversionTarget &target, |
900 | unsigned targetBitWidth) { |
901 | |
902 | // The general purpose definition of what ops are legal must come first. |
903 | populateForVectorLinearize(typeConverter, conversionTarget&: target); |
904 | |
905 | // Extend the set of legal ops to include those with large inner-most |
906 | // dimensions on selected operands/results. |
907 | target.markUnknownOpDynamicallyLegal( |
908 | fn: [=](Operation *op) -> std::optional<bool> { |
909 | if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) { |
910 | return true; |
911 | } |
912 | return {}; |
913 | }); |
914 | } |
915 | |
916 | struct TestVectorBitWidthLinearize final |
917 | : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> { |
918 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize) |
919 | |
920 | TestVectorBitWidthLinearize() = default; |
921 | TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass) |
922 | : PassWrapper(pass) {} |
923 | |
924 | StringRef getArgument() const override { |
925 | return "test-bit-width-constrained-vector-linearize"; |
926 | } |
927 | StringRef getDescription() const override { |
928 | return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints " |
929 | "in inner-most dimension's bit width."; |
930 | } |
931 | void getDependentDialects(DialectRegistry ®istry) const override { |
932 | registry.insert<vector::VectorDialect>(); |
933 | } |
934 | |
935 | Option<unsigned> targetVectorBitwidth{ |
936 | *this, "target-vector-bitwidth", |
937 | llvm::cl::desc( |
938 | "Minimum vector bitwidth to enable the flattening transformation"), |
939 | llvm::cl::init(Val: std::numeric_limits<unsigned>::max())}; |
940 | void runOnOperation() override { |
941 | auto *context = &getContext(); |
942 | |
943 | TypeConverter typeConverter; |
944 | RewritePatternSet patterns(context); |
945 | ConversionTarget target(*context); |
946 | |
947 | populateWithBitWidthConstraints(typeConverter, target, |
948 | targetBitWidth: targetVectorBitwidth); |
949 | |
950 | vector::populateVectorLinearizeBasePatterns(typeConverter, target, |
951 | patterns); |
952 | |
953 | vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target, |
954 | patterns); |
955 | |
956 | if (failed(applyPartialConversion(getOperation(), target, |
957 | std::move(patterns)))) |
958 | return signalPassFailure(); |
959 | } |
960 | }; |
961 | |
962 | struct TestVectorLinearize final |
963 | : public PassWrapper<TestVectorLinearize, OperationPass<>> { |
964 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) |
965 | |
966 | TestVectorLinearize() = default; |
967 | |
968 | StringRef getArgument() const override { return "test-vector-linearize"; } |
969 | StringRef getDescription() const override { |
970 | return "Linearizes ND vectors for N >= 2 into 1D vectors"; |
971 | } |
972 | void getDependentDialects(DialectRegistry ®istry) const override { |
973 | registry.insert<vector::VectorDialect, arith::ArithDialect>(); |
974 | } |
975 | |
976 | void runOnOperation() override { |
977 | MLIRContext &context = getContext(); |
978 | TypeConverter converter; |
979 | RewritePatternSet patterns(&context); |
980 | ConversionTarget target(context); |
981 | |
982 | vector::populateForVectorLinearize(typeConverter&: converter, conversionTarget&: target); |
983 | |
984 | vector::populateVectorLinearizeBasePatterns(converter, target, patterns); |
985 | vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target, |
986 | patterns); |
987 | mlir::scf::populateSCFStructuralTypeConversionsAndLegality( |
988 | typeConverter: converter, patterns, target); |
989 | |
990 | if (failed(applyPartialConversion(getOperation(), target, |
991 | std::move(patterns)))) |
992 | return signalPassFailure(); |
993 | } |
994 | }; |
995 | |
996 | struct TestEliminateVectorMasks |
997 | : public PassWrapper<TestEliminateVectorMasks, |
998 | OperationPass<func::FuncOp>> { |
999 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks) |
1000 | |
1001 | TestEliminateVectorMasks() = default; |
1002 | TestEliminateVectorMasks(const TestEliminateVectorMasks &pass) |
1003 | : PassWrapper(pass) {} |
1004 | |
1005 | Option<unsigned> vscaleMin{ |
1006 | *this, "vscale-min", llvm::cl::desc( "Minimum possible value of vscale."), |
1007 | llvm::cl::init(Val: 1)}; |
1008 | Option<unsigned> vscaleMax{ |
1009 | *this, "vscale-max", llvm::cl::desc( "Maximum possible value of vscale."), |
1010 | llvm::cl::init(Val: 16)}; |
1011 | |
1012 | StringRef getArgument() const final { return "test-eliminate-vector-masks"; } |
1013 | StringRef getDescription() const final { |
1014 | return "Test eliminating vector masks"; |
1015 | } |
1016 | void runOnOperation() override { |
1017 | IRRewriter rewriter(&getContext()); |
1018 | eliminateVectorMasks(rewriter, getOperation(), |
1019 | VscaleRange{.vscaleMin: vscaleMin, .vscaleMax: vscaleMax}); |
1020 | } |
1021 | }; |
1022 | } // namespace |
1023 | |
1024 | namespace mlir { |
1025 | namespace test { |
1026 | void registerTestVectorLowerings() { |
1027 | PassRegistration<TestVectorToVectorLowering>(); |
1028 | |
1029 | PassRegistration<TestVectorContractionPrepareForMMTLowering>(); |
1030 | |
1031 | PassRegistration<TestVectorUnrollingPatterns>(); |
1032 | |
1033 | PassRegistration<TestVectorTransferUnrollingPatterns>(); |
1034 | |
1035 | PassRegistration<TestScalarVectorTransferLoweringPatterns>(); |
1036 | |
1037 | PassRegistration<TestVectorTransferOpt>(); |
1038 | |
1039 | PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); |
1040 | |
1041 | PassRegistration<TestVectorSinkPatterns>(); |
1042 | |
1043 | PassRegistration<TestVectorReduceToContractPatternsPatterns>(); |
1044 | |
1045 | PassRegistration<TestVectorChainedReductionFoldingPatterns>(); |
1046 | |
1047 | PassRegistration<TestVectorBreakDownReductionPatterns>(); |
1048 | |
1049 | PassRegistration<TestFlattenVectorTransferPatterns>(); |
1050 | |
1051 | PassRegistration<TestVectorScanLowering>(); |
1052 | |
1053 | PassRegistration<TestVectorDistribution>(); |
1054 | |
1055 | PassRegistration<TestVectorExtractStridedSliceLowering>(); |
1056 | |
1057 | PassRegistration<TestVectorBreakDownBitCast>(); |
1058 | |
1059 | PassRegistration<TestCreateVectorBroadcast>(); |
1060 | |
1061 | PassRegistration<TestVectorGatherLowering>(); |
1062 | |
1063 | PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); |
1064 | |
1065 | PassRegistration<TestVectorEmulateMaskedLoadStore>(); |
1066 | |
1067 | PassRegistration<TestVectorLinearize>(); |
1068 | |
1069 | PassRegistration<TestVectorBitWidthLinearize>(); |
1070 | |
1071 | PassRegistration<TestEliminateVectorMasks>(); |
1072 | } |
1073 | } // namespace test |
1074 | } // namespace mlir |
1075 |
Definitions
- TestVectorToVectorLowering
- TestVectorToVectorLowering
- TestVectorToVectorLowering
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- getShape
- filter
- TestVectorContractionPrepareForMMTLowering
- getArgument
- getDescription
- TestVectorContractionPrepareForMMTLowering
- getDependentDialects
- runOnOperation
- TestVectorUnrollingPatterns
- getArgument
- getDescription
- TestVectorUnrollingPatterns
- TestVectorUnrollingPatterns
- runOnOperation
- TestVectorTransferUnrollingPatterns
- TestVectorTransferUnrollingPatterns
- TestVectorTransferUnrollingPatterns
- getDependentDialects
- getArgument
- getDescription
- runOnOperation
- TestScalarVectorTransferLoweringPatterns
- TestScalarVectorTransferLoweringPatterns
- TestScalarVectorTransferLoweringPatterns
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- TestVectorTransferOpt
- getArgument
- getDescription
- runOnOperation
- TestVectorTransferCollapseInnerMostContiguousDims
- TestVectorTransferCollapseInnerMostContiguousDims
- TestVectorTransferCollapseInnerMostContiguousDims
- getDependentDialects
- getArgument
- getDescription
- runOnOperation
- TestVectorSinkPatterns
- TestVectorSinkPatterns
- TestVectorSinkPatterns
- getDependentDialects
- getArgument
- getDescription
- runOnOperation
- TestVectorReduceToContractPatternsPatterns
- getArgument
- getDescription
- runOnOperation
- TestVectorChainedReductionFoldingPatterns
- getArgument
- getDescription
- runOnOperation
- TestVectorBreakDownReductionPatterns
- getArgument
- getDescription
- runOnOperation
- TestFlattenVectorTransferPatterns
- TestFlattenVectorTransferPatterns
- TestFlattenVectorTransferPatterns
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- TestVectorScanLowering
- getArgument
- getDescription
- runOnOperation
- allocateGlobalSharedMemory
- warpReduction
- TestVectorDistribution
- getDependentDialects
- getArgument
- getDescription
- TestVectorDistribution
- TestVectorDistribution
- runOnOperation
- TestVectorExtractStridedSliceLowering
- getArgument
- getDescription
- runOnOperation
- TestVectorBreakDownBitCast
- getArgument
- getDescription
- runOnOperation
- TestCreateVectorBroadcast
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- TestVectorGatherLowering
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- TestFoldArithExtensionIntoVectorContractPatterns
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- TestVectorEmulateMaskedLoadStore
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- getTypeBitWidthBoundPairs
- isNotLinearizableBecauseLargeInnerDimension
- isNotLinearizableBecauseLargeInnerDimension
- populateWithBitWidthConstraints
- TestVectorBitWidthLinearize
- TestVectorBitWidthLinearize
- TestVectorBitWidthLinearize
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- TestVectorLinearize
- TestVectorLinearize
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
- TestEliminateVectorMasks
- TestEliminateVectorMasks
- TestEliminateVectorMasks
- getArgument
- getDescription
- runOnOperation
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more