//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))

using namespace mlir;
using namespace mlir::linalg;
// Given an elementwise single-binary-op linalg.generic, checks whether the
// binary op accesses its operands swapped, e.g.
// this differentiates between a linalg-generic body that contains:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %0 = arith.subf %a, %b : f32
//     linalg.yield %0 : f32
// against:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %0 = arith.subf %b, %a : f32
//     linalg.yield %0 : f32
// The former is linalg.sub(a,b), the latter is linalg.sub(b,a).
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}
//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
/// Identifies a linalg.generic that is essentially a named op of the form:
//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
//
// A linalg.generic may implement a matmul in a non-straightforward way,
// e.g. below is a matrix multiply over some slice:
// ```
// %0 = linalg.generic {
//        indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                         affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                         affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//        iterator_types = ["parallel", "parallel", "parallel"]}
//        ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
//        outs(%C : tensor<20x20x20xf32>) {
//      ^bb0(%a: f32, %b: f32, %c : f32):
//        %mul = arith.mulf %a, %b : f32
//        %add = arith.addf %mul, %c : f32
//        linalg.yield %add : f32
//      } -> tensor<20x20x20xf32>
// ```
// The above cannot be represented as a named op,
// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
// not the same as the linalg.generic above.
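//
// For illustration, a generic with the canonical matmul indexing maps and a
// multiply-accumulate body, e.g.
// ```
// %0 = linalg.generic {
//        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                         affine_map<(d0, d1, d2) -> (d2, d1)>,
//                         affine_map<(d0, d1, d2) -> (d0, d1)>],
//        iterator_types = ["parallel", "parallel", "reduction"]}
//        ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
//        outs(%C : tensor<?x?xf32>) {
//      ^bb0(%a: f32, %b: f32, %c: f32):
//        %m = arith.mulf %a, %b : f32
//        %s = arith.addf %c, %m : f32
//        linalg.yield %s : f32
//      } -> tensor<?x?xf32>
// ```
// can be rewritten as the named op `linalg.matmul ins(%A, %B) outs(%C)`.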
namespace {
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};

// Checks whether the input Affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row
// and column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
// e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
// whether the map of A is identity (match), transposed, or something
// completely different (mismatch). Similarly for B and C.
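//
// For illustration, take `C[M,N] = A[M,K] * B[K,N]` with loop dims
// (d0, d1, d2) = (M, N, K), no batch dims (rowDimIdx = 0), and A expected
// at (row = d0, col = d2):
//   A map (d0, d2)              -> Match
//   A map (d2, d0)              -> Transposed
//   A map (d1, d2) or (d0 + d1, d2) -> Mismatch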
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}
// Replaces genericOp with a `NamedOpTy` op, supplied as a template arg.
// All the variants expressed as a pseudo regular expression:
//    `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have the same number of ins/outs, so it is easy to stamp out the different
// versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // A linalg.generic contraction can reduce across multiple axes, e.g.
  // ```
  // linalg.generic
  //      {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                        affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                        affine_map<(m, n, k1, k2) -> (m, n)>],
  //       iterator_types = ["parallel", "parallel",
  //                         "reduction", "reduction"]}
  //      ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //      outs(%C : tensor<10x40xf32>) {
  //    ^bb0(%a: f32, %b: f32, %c: f32):
  //      %1 = arith.mulf %a, %b : f32
  //      %2 = arith.addf %c, %1 : f32
  //      linalg.yield %2 : f32
  //    } -> tensor<10x40xf32>
  // ```
  // The above contraction has two reduction dimensions {k1, k2}; although it
  // is a valid linalg contraction, it is not of the named-op matrix-multiply
  // kind. Therefore, reject multi-dim reductions.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
              return true;
            return false;
          }))
    return failure();

  // Check the rank of the operands.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction could express a different
    // permutation of its batch dimensions. But for a named op they must be
    // the identity, since separate maps are not specified.
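    // e.g. with two batch dims, an operand map whose leading results are
    // (d1, d0, ...) rather than (d0, d1, ...) would be rejected below.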
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }

  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  /// Codegen the different matmul variants.
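  // Dispatch summary: C must use the identity map and at most one of A/B may
  // be transposed (both-transposed was rejected above):
  //   a == Match,      b == Match      -> linalg.[batch_]matmul
  //   a == Transposed, b == Match      -> linalg.[batch_]matmul_transpose_a
  //   a == Match,      b == Transposed -> linalg.[batch_]matmul_transpose_b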
  if (numOfBatchDims) {
    if (a == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
                                                               genericOp);
    if (b == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
                                                               genericOp);
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }

  if (a == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
  if (b == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (isaFillOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
      return namedOp;
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
      return namedOp;
    }
    if (isa<arith::SubFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
      return namedOp;
    }
    if (isa<arith::MulFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
      return namedOp;
    }
    if (isa<arith::DivFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
      return namedOp;
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
  return failure();
}

namespace {
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  populateDecomposeProjectedPermutationPatterns(patterns);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
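
// Usage sketch (illustrative, not part of this file's API surface): outside
// of the pass, the same rewrites can be driven with a greedy rewrite, e.g.
//   RewritePatternSet patterns(ctx);
//   linalg::populateLinalgGenericOpsSpecializationPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));
// or from mlir-opt via the registered pass (assuming the standard mnemonic):
//   mlir-opt input.mlir -linalg-specialize-generic-ops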

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}