1//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This is intended to be a simple high-level (target-agnostic) matmul
9// transposition transformation.
10//===----------------------------------------------------------------------===//
11
12#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15
16#define DEBUG_TYPE "linalg-transpose-matmul"
17
18using namespace mlir;
19using namespace mlir::linalg;
20
21/// Pattern to replace
22///
23/// linalg.matmul(a, b)
24///
25/// with
26///
27/// linalg.matmul_transpose_a(linalg.transpose(a), b)
28///
29/// By default the LHS is transposed. Set `transposeLHS=false` to
30/// transpose RHS instead.
31FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
32 linalg::MatmulOp matmulOp,
33 bool transposeLHS) {
34 // Check to not let go the matmul with extended semantic, through this
35 // transform.
36 if (matmulOp.hasUserDefinedMaps()) {
37 return rewriter.notifyMatchFailure(
38 matmulOp, "only matmul ops with non-extended semantics are supported");
39 }
40
41 if (!bufferization::hasTensorSemantics(op: matmulOp))
42 return rewriter.notifyMatchFailure(
43 matmulOp, "only matmul ops with tensors are supported");
44
45 Location loc = matmulOp.getLoc();
46 Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
47 auto type = cast<ShapedType>(input.getType());
48
49 SmallVector<Value> dynamicDims;
50 if (type.isDynamicDim(1))
51 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
52 if (type.isDynamicDim(0))
53 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
54
55 ArrayRef<int64_t> shape = type.getShape();
56 Value empty = rewriter.create<tensor::EmptyOp>(
57 loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
58 dynamicDims);
59 auto transposeOp = rewriter.create<linalg::TransposeOp>(
60 loc, input, empty, ArrayRef<int64_t>{1, 0});
61 Operation *newMatmulOp;
62 if (transposeLHS) {
63 newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
64 loc, matmulOp.getResultTypes(),
65 ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
66 matmulOp.getOutputs());
67 } else {
68 newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
69 loc, matmulOp.getResultTypes(),
70 ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
71 matmulOp.getOutputs());
72 }
73 rewriter.replaceOp(matmulOp, newMatmulOp);
74 return newMatmulOp;
75}
76
77/// Pattern to replace
78///
79/// linalg.batch_matmul(a, b)
80///
81/// with
82///
83/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
84///
85/// Only the non-batch dimensions are transposed. By default the LHS is
86/// transposed. Set `transposeLHS=false` to transpose RHS instead.
87FailureOr<Operation *>
88mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
89 linalg::BatchMatmulOp batchMatmulOp,
90 bool transposeLHS) {
91 if (batchMatmulOp.hasUserDefinedMaps()) {
92 return rewriter.notifyMatchFailure(
93 batchMatmulOp, "ops with user-defined maps are not supported");
94 }
95
96 if (!bufferization::hasTensorSemantics(op: batchMatmulOp))
97 return rewriter.notifyMatchFailure(
98 batchMatmulOp, "only matmul ops with tensors are supported");
99
100 Location loc = batchMatmulOp.getLoc();
101 Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
102 auto type = cast<ShapedType>(input.getType());
103
104 SmallVector<Value> dynamicDims;
105 if (type.isDynamicDim(0))
106 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
107 if (type.isDynamicDim(2))
108 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
109 if (type.isDynamicDim(1))
110 dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
111
112 ArrayRef<int64_t> shape = type.getShape();
113 Value empty = rewriter.create<tensor::EmptyOp>(
114 loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
115 type.getElementType(), dynamicDims);
116 auto transposeOp = rewriter.create<linalg::TransposeOp>(
117 loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
118 Operation *newMatmulOp;
119 if (transposeLHS) {
120 newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
121 loc, batchMatmulOp.getResultTypes(),
122 ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
123 batchMatmulOp.getOutputs());
124 } else {
125 newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
126 loc, batchMatmulOp.getResultTypes(),
127 ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
128 batchMatmulOp.getOutputs());
129 }
130 rewriter.replaceOp(batchMatmulOp, newMatmulOp);
131 return newMatmulOp;
132}
133
134namespace {
135struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
136 TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
137 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
138
139 LogicalResult matchAndRewrite(linalg::MatmulOp op,
140 PatternRewriter &rewriter) const override {
141 if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
142 return failure();
143 }
144 return success();
145 }
146
147private:
148 bool transposeLHS;
149};
150
151struct TransposeBatchMatmul final
152 : public OpRewritePattern<linalg::BatchMatmulOp> {
153 TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
154 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
155
156 LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
157 PatternRewriter &rewriter) const override {
158 if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
159 return failure();
160 }
161 return success();
162 }
163
164private:
165 bool transposeLHS;
166};
167} // namespace
168
169void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
170 bool transposeLHS) {
171 patterns.add<TransposeMatmul, TransposeBatchMatmul>(arg: patterns.getContext(),
172 args&: transposeLHS);
173}
174

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp