1//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a method to specialize generic operations to named
10// operations. Conceptually it is the opposite of generalize.cpp.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Complex/IR/Complex.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
17#include "mlir/Dialect/Linalg/Passes.h"
18#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
25#include "mlir/Dialect/Linalg/Passes.h.inc"
26} // namespace mlir
27
#define DEBUG_TYPE "linalg-specialization"

// Replaces `genericOp` (expected in scope) with the named binary op NEWOP,
// feeding it the generic's two DPS inputs. When OPERANDS_SWAP is true the
// inputs are passed in reverse order, matching a body that computed
// e.g. `subf %b, %a` instead of `subf %a, %b`.
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

// Replaces `genericOp` (expected in scope) with the named unary op NEWOP,
// forwarding the single DPS input and init operand.
#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))
41
42using namespace mlir;
43using namespace mlir::linalg;
44
45// Given a elementwise single binary linalg generic op, checks whether the
46// binary op accesses operands as swapped. e.g.
47// this differentiates between a linalg-generic body that contains:
48// ^bb0(%a: f32, %b: f32, %c : f32):
49// %0 = arith.subf %a, %b : f32
50// linalg.yield %0: f32
51// against:
52// ^bb0(%a: f32, %b: f32, %c : f32):
53// %0 = arith.subf %b, %a : f32
54// linalg.yield %0: f32
55// Former is linalg.sub(a,b), latter is linalg.sub(b,a).
56static bool areBinOpsSwapped(GenericOp genericOp) {
57 Block *body = genericOp.getBody();
58 Operation *op = &body->front();
59 bool swapped = false;
60 if (op->getOpOperand(idx: 0).get() != body->getArgument(i: 0)) {
61 swapped = true;
62 assert(op->getOpOperand(0).get() == body->getArgument(1) &&
63 op->getOpOperand(1).get() == body->getArgument(0) &&
64 "binary op uses just one block arg");
65 }
66 return swapped;
67}
68
69//===----------------------------------------------------------------------===//
70// Specialize linalg generic to matmul variants.
71//===----------------------------------------------------------------------===//
72/// Identifies linalg.generic that is essentially named op of the form:
73// ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
74//
75// It is possible that a linalg.generic may be implementing a matmul but not
76// in a straight-forward way e.g. below is matrix multiply over some slice
77// ```
78// %0 = linalg.generic {
79// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
80// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
81// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
82// iterator_types = ["parallel", "parallel", "parallel"]}
83// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
84// outs(%C : tensor<20x20x20xf32>) {
85// ^bb0(%a: f32, %b: f32, %c : f32):
86// %mul = arith.mulf %a, %b : f32
87// %add = arith.addf %mul, %c : f32
88// linalg.yield %add : f32
89// } -> tensor<20x20x20xf32>
90// ```
91// It is not possible to represent above as named op.
92// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
93// not the same as linalg.generic above.
namespace {
// Outcome of checking one operand's indexing map against the canonical
// (row, col) dimension positions expected for a matmul-like operand.
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};
100
101// Checks whether the input Affine `map` contains two consecutive dims that
102// can be interpreted as accessing a 2D matrix. It is assumed that the row
103// column dimension are adjacent axis (in this order) and start at
104// `rowDimIdx` in the input map.
105//
106// e.g. consider A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
107// whether the map of A is identity (match), transposed, or something
108// completely different (mis-match). Similar for B and C.
109static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
110 unsigned expectedPosOfRowDim,
111 unsigned expectedPosOfColDim) {
112 // Get the matrix multiply indices. They are past the batch indices.
113 auto exprOfRowDim = map.getResults()[rowDimIdx];
114 auto exprOfColDim = map.getResults()[rowDimIdx + 1];
115
116 // They should be pure dimension ids.
117 if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
118 exprOfColDim.getKind() != AffineExprKind::DimId)
119 return IndexMatchResult::Mismatch;
120
121 auto posRowDim = cast<AffineDimExpr>(Val&: exprOfRowDim).getPosition();
122 auto posColDim = cast<AffineDimExpr>(Val&: exprOfColDim).getPosition();
123
124 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
125 return IndexMatchResult::Match;
126
127 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
128 return IndexMatchResult::Transposed;
129
130 return IndexMatchResult::Mismatch;
131}
132
133// Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
134// All the variants expressed as pseudo regular expression:
135// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
136// have same number of ins/out, so its easy to stamp different versions.
137template <typename NamedOpTy>
138static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
139 LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
140 op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
141 ValueRange{op.getDpsInits()[0]});
142 return namedOp;
143}
144
145// Converts linalg.generic to named linalg.*matmul* where possible.
146static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
147 GenericOp genericOp) {
148 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
149 return failure();
150
151 // Early exit if not projected permutations.
152 auto mapRange = genericOp.getIndexingMapsArray();
153 if (llvm::any_of(Range&: mapRange,
154 P: [](AffineMap m) { return !m.isProjectedPermutation(); }))
155 return failure();
156
157 // Linalg generic contraction can be across multiple axis e.g.
158 // ```
159 // linalg.generic
160 // {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
161 // affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
162 // affine_map<(m, n, k1, k2) -> (m, n)>],
163 // iterator_types = ["parallel", "parallel",
164 // "reduction", "reduction"]}
165 // ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
166 // outs(%C : tensor<10x40xf32>) {
167 // ^bb0(%a: f32, %b: f32, %c: f32):
168 // %1 = arith.mulf %a, %b : f32
169 // %2 = arith.addf %c, %1 : f32
170 // linalg.yield %2 : f32
171 // } -> tensor<10x40xf32>
172 // ```
173 // In above contraction, there are two reduction dimensions {k1, k2}
174 // and although a valid linalg contraction, it is not a named-op
175 // matrix multiply kind. Therefore, reject multi-dim reduction.
176 auto res = inferContractionDims(linalgOp: genericOp);
177 if (!succeeded(Result: res))
178 return failure();
179 auto dims = *res;
180 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
181 return failure();
182
183 if (!mlir::linalg::detail::isContractionBody(
184 block&: *genericOp.getBlock(), isaPair: [](Operation *first, Operation *second) {
185 if ((isa<arith::MulFOp>(Val: first) && isa<arith::AddFOp>(Val: second)) ||
186 (isa<arith::MulIOp>(Val: first) && isa<arith::AddIOp>(Val: second)) ||
187 (isa<complex::MulOp>(Val: first) && isa<complex::AddOp>(Val: second)))
188 return true;
189 return false;
190 }))
191 return failure();
192
193 // Check rank of operands
194 auto indexingMaps = genericOp.getIndexingMapsArray();
195 if (llvm::any_of(Range&: indexingMaps, P: [&dims](AffineMap m) {
196 return m.getResults().size() !=
197 dims.batch.size() + 2 /* any two of {m,n,k} */;
198 }))
199 return failure();
200
201 auto numOfBatchDims = dims.batch.size();
202 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
203 return failure();
204
205 if (numOfBatchDims) {
206 // Each operand in a linalg generic contraction could express different
207 // permutations for its batch dimension. But for named op it must be
208 // identity since separate maps are not specified.
209 if (llvm::any_of(Range&: indexingMaps, P: [numOfBatchDims](AffineMap m) {
210 for (unsigned i = 0; i < numOfBatchDims; ++i) {
211 auto expr = m.getResults()[i];
212 if (expr.getKind() != AffineExprKind::DimId ||
213 cast<AffineDimExpr>(Val&: expr).getPosition() != i)
214 return true;
215 }
216 return false;
217 }))
218 return failure();
219 }
220
221 auto a =
222 matchOperandMap(map: indexingMaps[0], rowDimIdx: numOfBatchDims, expectedPosOfRowDim: dims.m[0], expectedPosOfColDim: dims.k[0]);
223 auto b =
224 matchOperandMap(map: indexingMaps[1], rowDimIdx: numOfBatchDims, expectedPosOfRowDim: dims.k[0], expectedPosOfColDim: dims.n[0]);
225 auto c =
226 matchOperandMap(map: indexingMaps[2], rowDimIdx: numOfBatchDims, expectedPosOfRowDim: dims.m[0], expectedPosOfColDim: dims.n[0]);
227
228 if (llvm::is_contained(Set: {a, b, c}, Element: IndexMatchResult::Mismatch))
229 return failure();
230
231 if (c != IndexMatchResult::Match ||
232 (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
233 return failure();
234
235 /// Codegen the different matmul variants.
236 if (numOfBatchDims) {
237 if (a == IndexMatchResult::Transposed)
238 return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
239 op: genericOp);
240 if (b == IndexMatchResult::Transposed)
241 return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
242 op: genericOp);
243 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, op: genericOp);
244 }
245
246 if (a == IndexMatchResult::Transposed)
247 return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, op: genericOp);
248 if (b == IndexMatchResult::Transposed)
249 return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, op: genericOp);
250 return replaceWithMatmulVariant<MatmulOp>(rewriter, op: genericOp);
251}
252
253} // namespace
254
255//===----------------------------------------------------------------------===//
256// Categorize linalg generic to named op where possible.
257//===----------------------------------------------------------------------===//
258FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
259 GenericOp genericOp) {
260 // Copy
261 if (isaCopyOpInterface(linalgOp: genericOp)) {
262 LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
263 op: genericOp, args&: genericOp.getDpsInputs()[0], args: genericOp.getDpsInits()[0]);
264 return namedOp;
265 }
266
267 // Fill
268 if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
269 // Always use the detected fill value, regardless of pattern
270 LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
271 op: genericOp, args&: *fillValue, args: genericOp.getDpsInits()[0]);
272 return namedOp;
273 }
274
275 // Broadcast
276 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
277 isaBroadcastOpInterface(genericOp);
278 if (equivalentToBroadcast) {
279 auto dims = *equivalentToBroadcast;
280 LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
281 op: genericOp, args&: genericOp.getDpsInputs()[0], args: genericOp.getDpsInits()[0],
282 args&: dims);
283 return namedOp;
284 }
285
286 // Transpose
287 std::optional<SmallVector<int64_t>> equivalentToTranspose =
288 isaTransposeOpInterface(genericOp);
289 if (equivalentToTranspose) {
290 auto permutation = *equivalentToTranspose;
291 LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
292 op: genericOp, args&: genericOp.getDpsInputs()[0], args: genericOp.getDpsInits()[0],
293 args&: permutation);
294 return namedOp;
295 }
296
297 // Elementwise Unary
298 if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
299 Operation *op = &genericOp.getBody()->front();
300 if (isa<math::ExpOp>(Val: op)) {
301 LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
302 return namedOp;
303 }
304 }
305
306 // Elementwise Binary
307 if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
308 bool swap = areBinOpsSwapped(genericOp);
309 Operation *op = &genericOp.getBody()->front();
310 if (isa<arith::AddFOp>(Val: op)) {
311 LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
312 return namedOp;
313 }
314 if (isa<arith::SubFOp>(Val: op)) {
315 LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
316 return namedOp;
317 }
318 if (isa<arith::MulFOp>(Val: op)) {
319 LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
320 return namedOp;
321 }
322 if (isa<arith::DivFOp>(Val: op)) {
323 LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
324 return namedOp;
325 }
326 }
327
328 // Contraction - e.g. matmul
329 if (isaContractionOpInterface(linalgOp: genericOp)) {
330 return specializeLinalgContractions(rewriter, genericOp);
331 }
332 return failure();
333}
334
namespace {
/// Pass wrapper that greedily rewrites linalg.generic ops into their
/// equivalent named linalg ops via the specialization patterns below.
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  // Inherit constructors (and pass options) from the tablegen-generated base.
  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace
345
346void LinalgSpecializeGenericOpsPass::runOnOperation() {
347 RewritePatternSet patterns(&getContext());
348 populateLinalgGenericOpsSpecializationPatterns(patterns);
349 populateDecomposeProjectedPermutationPatterns(patterns);
350
351 if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns))))
352 signalPassFailure();
353}
354
355void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
356 RewritePatternSet &patterns) {
357 patterns.add<LinalgSpecializationPattern>(arg: patterns.getContext());
358}
359

// Source: mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp