//===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to enable 1xTF32 and 3xTF32 nvgpu.mma.sync
// operations on the f32 input datatype. (3xTF32 is currently rejected with an
// error.)
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Interfaces/SideEffectInterfaces.h"
22#include "mlir/Support/LogicalResult.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/MathExtras.h"
25
26using namespace mlir;
27using namespace mlir::nvgpu;
28
29namespace {
30
31struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
32
33 using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
34
35 MmaSyncF32ToTF32Pattern(MLIRContext *context,
36 nvgpu::MmaSyncF32Lowering precision)
37 : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
38 precision(precision) {}
39
40 LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
41 PatternRewriter &rewriter) const override {
42 Location location = op->getLoc();
43
44 if (op->hasAttr(op.getTf32EnabledAttrName()) ||
45 !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
46 return failure();
47
48 if (precision == MmaSyncF32Lowering::Unkown)
49 return emitError(loc: location, message: "MmaSync F32-to-TF32 cannot be lowered with "
50 "unknown precision level");
51
52 if (precision == MmaSyncF32Lowering::TF32x3)
53 return emitError(loc: location, message: "TF32x3 is not supported at the moment "
54 "for nvgpu.mma.sync on f32 datatype");
55
56 if (precision == MmaSyncF32Lowering::TF32) {
57 rewriter.modifyOpInPlace(
58 op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
59 }
60
61 return success();
62 }
63
64private:
65 /// Precision for F32 Tensor Cores (TF32 or TF32x3)
66 nvgpu::MmaSyncF32Lowering precision;
67};
68
69} // namespace
70
71void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
72 RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
73
74 patterns.add<MmaSyncF32ToTF32Pattern>(arg: patterns.getContext(), args&: precision);
75}
76

source code of mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp