1//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This is intended to be a simple high-level (target-agnostic) matmul
9// transposition transformation.
10//===----------------------------------------------------------------------===//
11
12#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13#include "mlir/IR/PatternMatch.h"
14
15#define DEBUG_TYPE "linalg-transpose-matmul"
16
17using namespace mlir;
18using namespace mlir::linalg;
19
20/// Pattern to replace
21///
22/// linalg.matmul(a, b)
23///
24/// with
25///
26/// linalg.matmul_transpose_a(linalg.transpose(a), b)
27///
28/// By default the LHS is transposed. Set `transposeLHS=false` to
29/// transpose RHS instead.
30FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
31 linalg::MatmulOp matmulOp,
32 bool transposeLHS) {
33 // Check to not let go the matmul with extended semantic, through this
34 // transform.
35 if (matmulOp.hasUserDefinedMaps()) {
36 return rewriter.notifyMatchFailure(
37 arg&: matmulOp, msg: "only matmul ops with non-extended semantics are supported");
38 }
39
40 if (!matmulOp.hasPureTensorSemantics())
41 return rewriter.notifyMatchFailure(
42 arg&: matmulOp, msg: "only matmul ops with tensors are supported");
43
44 Location loc = matmulOp.getLoc();
45 Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
46 auto type = cast<ShapedType>(Val: input.getType());
47
48 SmallVector<Value> dynamicDims;
49 if (type.isDynamicDim(idx: 1))
50 dynamicDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args: 1));
51 if (type.isDynamicDim(idx: 0))
52 dynamicDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args: 0));
53
54 ArrayRef<int64_t> shape = type.getShape();
55 Value empty = rewriter.create<tensor::EmptyOp>(
56 location: loc, args: ArrayRef<int64_t>{shape[1], shape[0]}, args: type.getElementType(),
57 args&: dynamicDims);
58 auto transposeOp = rewriter.create<linalg::TransposeOp>(
59 location: loc, args&: input, args&: empty, args: ArrayRef<int64_t>{1, 0});
60 Operation *newMatmulOp;
61 if (transposeLHS) {
62 newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
63 location: loc, args: matmulOp.getResultTypes(),
64 args: ValueRange{transposeOp->getResult(idx: 0), matmulOp.getInputs()[1]},
65 args: matmulOp.getOutputs());
66 } else {
67 newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
68 location: loc, args: matmulOp.getResultTypes(),
69 args: ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(idx: 0)},
70 args: matmulOp.getOutputs());
71 }
72 rewriter.replaceOp(op: matmulOp, newOp: newMatmulOp);
73 return newMatmulOp;
74}
75
76/// Pattern to replace
77///
78/// linalg.batch_matmul(a, b)
79///
80/// with
81///
82/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
83///
84/// Only the non-batch dimensions are transposed. By default the LHS is
85/// transposed. Set `transposeLHS=false` to transpose RHS instead.
86FailureOr<Operation *>
87mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
88 linalg::BatchMatmulOp batchMatmulOp,
89 bool transposeLHS) {
90 if (batchMatmulOp.hasUserDefinedMaps()) {
91 return rewriter.notifyMatchFailure(
92 arg&: batchMatmulOp, msg: "ops with user-defined maps are not supported");
93 }
94
95 if (!batchMatmulOp.hasPureTensorSemantics())
96 return rewriter.notifyMatchFailure(
97 arg&: batchMatmulOp, msg: "only matmul ops with tensors are supported");
98
99 Location loc = batchMatmulOp.getLoc();
100 Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
101 auto type = cast<ShapedType>(Val: input.getType());
102
103 SmallVector<Value> dynamicDims;
104 if (type.isDynamicDim(idx: 0))
105 dynamicDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args: 0));
106 if (type.isDynamicDim(idx: 2))
107 dynamicDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args: 2));
108 if (type.isDynamicDim(idx: 1))
109 dynamicDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args: 1));
110
111 ArrayRef<int64_t> shape = type.getShape();
112 Value empty = rewriter.create<tensor::EmptyOp>(
113 location: loc, args: ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
114 args: type.getElementType(), args&: dynamicDims);
115 auto transposeOp = rewriter.create<linalg::TransposeOp>(
116 location: loc, args&: input, args&: empty, args: ArrayRef<int64_t>{0, 2, 1});
117 Operation *newMatmulOp;
118 if (transposeLHS) {
119 newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
120 location: loc, args: batchMatmulOp.getResultTypes(),
121 args: ValueRange{transposeOp->getResult(idx: 0), batchMatmulOp.getInputs()[1]},
122 args: batchMatmulOp.getOutputs());
123 } else {
124 newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
125 location: loc, args: batchMatmulOp.getResultTypes(),
126 args: ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(idx: 0)},
127 args: batchMatmulOp.getOutputs());
128 }
129 rewriter.replaceOp(op: batchMatmulOp, newOp: newMatmulOp);
130 return newMatmulOp;
131}
132
133namespace {
134struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
135 TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
136 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
137
138 LogicalResult matchAndRewrite(linalg::MatmulOp op,
139 PatternRewriter &rewriter) const override {
140 if (failed(Result: transposeMatmul(rewriter, matmulOp: op, transposeLHS))) {
141 return failure();
142 }
143 return success();
144 }
145
146private:
147 bool transposeLHS;
148};
149
150struct TransposeBatchMatmul final
151 : public OpRewritePattern<linalg::BatchMatmulOp> {
152 TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
153 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
154
155 LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
156 PatternRewriter &rewriter) const override {
157 if (failed(Result: transposeBatchMatmul(rewriter, batchMatmulOp: op, transposeLHS))) {
158 return failure();
159 }
160 return success();
161 }
162
163private:
164 bool transposeLHS;
165};
166} // namespace
167
168void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
169 bool transposeLHS) {
170 patterns.add<TransposeMatmul, TransposeBatchMatmul>(arg: patterns.getContext(),
171 args&: transposeLHS);
172}
173

source code of mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp