//===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma.sync
// operations on the f32 input datatype. Only 1xtf32 (TF32) is currently
// implemented; requesting 3xtf32 reports an error.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Interfaces/SideEffectInterfaces.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/MathExtras.h"
24
25using namespace mlir;
26using namespace mlir::nvgpu;
27
28namespace {
29
30struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
31
32 using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
33
34 MmaSyncF32ToTF32Pattern(MLIRContext *context,
35 nvgpu::MmaSyncF32Lowering precision)
36 : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
37 precision(precision) {}
38
39 LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
40 PatternRewriter &rewriter) const override {
41 Location location = op->getLoc();
42
43 if (op->hasAttr(op.getTf32EnabledAttrName()) ||
44 !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
45 return failure();
46
47 if (precision == MmaSyncF32Lowering::Unkown)
48 return emitError(loc: location, message: "MmaSync F32-to-TF32 cannot be lowered with "
49 "unknown precision level");
50
51 if (precision == MmaSyncF32Lowering::TF32x3)
52 return emitError(loc: location, message: "TF32x3 is not supported at the moment "
53 "for nvgpu.mma.sync on f32 datatype");
54
55 if (precision == MmaSyncF32Lowering::TF32) {
56 rewriter.modifyOpInPlace(
57 op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
58 }
59
60 return success();
61 }
62
63private:
64 /// Precision for F32 Tensor Cores (TF32 or TF32x3)
65 nvgpu::MmaSyncF32Lowering precision;
66};
67
68} // namespace
69
70void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
71 RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
72
73 patterns.add<MmaSyncF32ToTF32Pattern>(arg: patterns.getContext(), args&: precision);
74}
75

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp