//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is intended to be a simple high-level (target-agnostic) matmul
// transposition transformation.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "linalg-transpose-matmul"

using namespace mlir;
using namespace mlir::linalg;

/// Pattern to replace
///
///   linalg.matmul(a, b)
///
/// with
///
///   linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to transpose
/// the RHS instead, producing `linalg.matmul_transpose_b`.
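///
/// For illustration only (the shapes and SSA names below are made up, not
/// taken from this file or its tests), with the default `transposeLHS=true`
/// the rewrite turns roughly
///
///   %0 = linalg.matmul ins(%a, %b : tensor<16x8xf32>, tensor<8x32xf32>)
///                      outs(%c : tensor<16x32xf32>) -> tensor<16x32xf32>
///
/// into
///
///   %empty = tensor.empty() : tensor<8x16xf32>
///   %at = linalg.transpose ins(%a : tensor<16x8xf32>)
///                          outs(%empty : tensor<8x16xf32>)
///                          permutation = [1, 0]
///   %0 = linalg.matmul_transpose_a
///            ins(%at, %b : tensor<8x16xf32>, tensor<8x32xf32>)
///            outs(%c : tensor<16x32xf32>) -> tensor<16x32xf32>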
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                     linalg::MatmulOp matmulOp,
                                                     bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(matmulOp))
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with tensors are supported");

  Location loc = matmulOp.getLoc();
  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto type = cast<ShapedType>(input.getType());

  // Collect the dynamic sizes of the operand to be transposed, in the order
  // they appear in the transposed (swapped) shape.
  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(1))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
  if (type.isDynamicDim(0))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

  // Materialize an init tensor with the swapped shape and transpose the
  // operand into it.
  ArrayRef<int64_t> shape = type.getShape();
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
      dynamicDims);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, ArrayRef<int64_t>{1, 0});

  // Rebuild the matmul as the variant that expects the transposed operand.
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
        matmulOp.getOutputs());
  } else {
    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
        matmulOp.getOutputs());
  }
  rewriter.replaceOp(matmulOp, newMatmulOp);
  return newMatmulOp;
}

/// Pattern to replace
///
///   linalg.batch_matmul(a, b)
///
/// with
///
///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose the RHS instead,
/// producing `linalg.batch_matmul_transpose_b`.
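///
/// For illustration only (the shapes and SSA names below are made up, not
/// taken from this file or its tests), with the default `transposeLHS=true`
/// the rewrite turns roughly
///
///   %0 = linalg.batch_matmul
///            ins(%a, %b : tensor<4x16x8xf32>, tensor<4x8x32xf32>)
///            outs(%c : tensor<4x16x32xf32>) -> tensor<4x16x32xf32>
///
/// into
///
///   %empty = tensor.empty() : tensor<4x8x16xf32>
///   %at = linalg.transpose ins(%a : tensor<4x16x8xf32>)
///                          outs(%empty : tensor<4x8x16xf32>)
///                          permutation = [0, 2, 1]
///   %0 = linalg.batch_matmul_transpose_a
///            ins(%at, %b : tensor<4x8x16xf32>, tensor<4x8x32xf32>)
///            outs(%c : tensor<4x16x32xf32>) -> tensor<4x16x32xf32>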
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                   linalg::BatchMatmulOp batchMatmulOp,
                                   bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(batchMatmulOp))
    return rewriter.notifyMatchFailure(
        batchMatmulOp, "only batch_matmul ops with tensors are supported");

  Location loc = batchMatmulOp.getLoc();
  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto type = cast<ShapedType>(input.getType());

  // Collect the dynamic sizes of the operand to be transposed, in the order
  // they appear in the permuted shape (batch dimension first, then the
  // swapped non-batch dimensions).
  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(0))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
  if (type.isDynamicDim(2))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
  if (type.isDynamicDim(1))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));

  // Materialize an init tensor with the permuted shape and transpose the
  // non-batch dimensions of the operand into it.
  ArrayRef<int64_t> shape = type.getShape();
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
      type.getElementType(), dynamicDims);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, ArrayRef<int64_t>{0, 2, 1});

  // Rebuild the batch matmul as the variant that expects the transposed
  // operand.
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
        batchMatmulOp.getOutputs());
  } else {
    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
        batchMatmulOp.getOutputs());
  }
  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
  return newMatmulOp;
}

namespace {
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};

struct TransposeBatchMatmul final
    : public OpRewritePattern<linalg::BatchMatmulOp> {
  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};
} // namespace

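/// Typical driver for these patterns (an illustrative sketch only; the
/// enclosing pass, `funcOp`, and the chosen flag value are assumptions, not
/// part of this file). It relies on the greedy rewrite driver entry point
/// `applyPatternsAndFoldGreedily` declared in GreedyPatternRewriteDriver.h:
///
///   RewritePatternSet patterns(funcOp.getContext());
///   linalg::populateTransposeMatmulPatterns(patterns, /*transposeLHS=*/true);
///   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
///     signalPassFailure();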
void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                                   bool transposeLHS) {
  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
                                                      transposeLHS);
}