//===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include <type_traits>
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/Linalg/Passes.h"
18#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Pass/PassManager.h"
24#include "mlir/Support/LLVM.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27using namespace mlir;
28using namespace mlir::nvgpu;
29
30namespace {
31
32struct TestMmaSyncF32ToTF32Patterns
33 : public PassWrapper<TestMmaSyncF32ToTF32Patterns,
34 OperationPass<func::FuncOp>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns)
36
37 StringRef getArgument() const final {
38 return "test-nvgpu-mmasync-f32-to-tf32-patterns";
39 }
40 StringRef getDescription() const final {
41 return "Test patterns to convert mma.sync on f32 with tf32 precision";
42 }
43 TestMmaSyncF32ToTF32Patterns() = default;
44 TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass)
45 : PassWrapper(pass) {}
46
47 Option<std::string> precision{
48 *this, "precision",
49 llvm::cl::desc(
50 "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
51 llvm::cl::init(Val: "tf32")};
52
53 MmaSyncF32Lowering tf32Precision =
54 llvm::StringSwitch<MmaSyncF32Lowering>(precision)
55 .Case(S: "tf32", Value: MmaSyncF32Lowering::TF32)
56 .Case(S: "tf32x3", Value: MmaSyncF32Lowering::TF32x3)
57 .Default(Value: MmaSyncF32Lowering::Unkown);
58
59 void runOnOperation() override {
60 RewritePatternSet patterns(&getContext());
61
62 populateMmaSyncF32ToTF32Patterns(patterns, precision: tf32Precision);
63 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
64 }
65};
66
67} // namespace
68
69namespace mlir {
70namespace test {
71void registerTestNvgpuLowerings() {
72 PassRegistration<TestMmaSyncF32ToTF32Patterns>();
73}
74
75} // namespace test
76} // namespace mlir

source code of mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp