1//===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewrites that fuse 'arm_sme.outerproduct' operations
10// into the 2-way or 4-way widening outerproduct operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
15#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
16#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23#define DEBUG_TYPE "arm-sme-outerproduct-fusion"
24
25namespace mlir::arm_sme {
26#define GEN_PASS_DEF_OUTERPRODUCTFUSION
27#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
28} // namespace mlir::arm_sme
29
30using namespace mlir;
31using namespace mlir::arm_sme;
32
33namespace {
34
35// Common match failure reasons.
36static constexpr StringLiteral
37 kMatchFailureNoAccumulator("no accumulator operand");
38static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp(
39 "defining op of accumulator must be 'arm_sme.outerproduct'");
40static constexpr StringLiteral kMatchFailureInconsistentCombiningKind(
41 "combining kind (add or sub) of outer products must match");
42static constexpr StringLiteral kMatchFailureInconsistentMasking(
43 "unsupported masking, either both outerproducts are masked "
44 "or neither");
45static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse(
46 "outer product(s) not single use and cannot be removed, no benefit to "
47 "fusing");
48
49// An outer product is compatible if all of the following are true:
50// - the result type matches `resultType`.
51// - the defining operation of LHS is of the type `LhsExtOp`.
52// - the defining operation of RHS is of the type `RhsExtOp`.
53// - the input types of the defining operations are identical and match
54// `inputType`.
55template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
56static LogicalResult isCompatible(PatternRewriter &rewriter,
57 arm_sme::OuterProductOp op,
58 VectorType resultType, VectorType inputType) {
59 if (op.getResultType() != resultType)
60 return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
61 diag << "unsupported result type, expected " << resultType;
62 });
63
64 auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
65 auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
66
67 if (!lhsDefOp || !rhsDefOp)
68 return rewriter.notifyMatchFailure(
69 op, "defining op of outerproduct operands must be one of: "
70 "'arith.extf' or 'arith.extsi' or 'arith.extui'");
71
72 auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
73 auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
74
75 if (lhsInType != inputType || rhsInType != inputType)
76 return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
77 diag << "unsupported input type, expected " << inputType;
78 });
79
80 return success();
81}
82
83// Create 'llvm.experimental.vector.interleave2' intrinsic from `lhs` and `rhs`.
84static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
85 Value lhs, Value rhs) {
86 auto inputType = cast<VectorType>(lhs.getType());
87 VectorType inputTypeX2 =
88 VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
89 return rewriter.create<LLVM::experimental_vector_interleave2>(
90 loc, inputTypeX2, lhs, rhs);
91}
92
93// Fuse two 'arm_sme.outerproduct' operations that are chained via the
94// accumulator into 2-way outer product operation.
95//
96// For example:
97//
98// %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
99// %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
100// %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
101// vector<[4]xf32>
102//
103// %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
104// %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
105// %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
106// vector<[4]xf32>
107//
108// Becomes:
109//
110// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
111// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
112// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
113// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
114// %0 = arm_sme.fmopa_2way %a_packed, %b_packed
115// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
116class OuterProductFusion2Way
117 : public OpRewritePattern<arm_sme::OuterProductOp> {
118public:
119 using OpRewritePattern::OpRewritePattern;
120
121 LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
122 PatternRewriter &rewriter) const override {
123 Value acc = op.getAcc();
124 if (!acc)
125 return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
126
127 arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
128 arm_sme::OuterProductOp op2 = op;
129 if (!op1)
130 return rewriter.notifyMatchFailure(
131 op, kMatchFailureExpectedOuterProductDefOp);
132
133 if (op1.getKind() != op2.getKind())
134 return rewriter.notifyMatchFailure(
135 op, kMatchFailureInconsistentCombiningKind);
136
137 if (!op1->hasOneUse()) {
138 // If the first outer product has uses other than as the input to another
139 // outer product, it can't be erased after fusion. This is a problem when
140 // it also has an accumulator as this will be used as the root for tile
141 // allocation and since the widening outer product uses the same
142 // accumulator it will get assigned the same tile ID, resulting in 3
143 // outer products accumulating to the same tile and incorrect results.
144 //
145 // Example:
146 //
147 // %acc = arith.constant dense<0.0> ; root for tile allocation
148 // %0 = arm_sme.outerproduct %a0, %b0 acc(%acc)
149 // vector.print %0 ; intermediary use, can't erase %0
150 // %1 = arm_sme.outerproduct %a1, %b1 acc(%0)
151 //
152 // After fusion and tile allocation
153 //
154 // %0 = arm_sme.zero {tile_id = 0 : i32}
155 // %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32}
156 // vector.print %1
157 // %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32}
158 //
159 // No accumulator would be ok, but it's simpler to prevent this
160 // altogether, since it has no benefit.
161 return rewriter.notifyMatchFailure(op,
162 kMatchFailureOuterProductNotSingleUse);
163 }
164
165 if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
166 return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking);
167
168 if (failed(canFuseOuterProducts(rewriter, op1, op2)))
169 return failure();
170
171 auto loc = op.getLoc();
172 auto packInputs = [&](Value lhs, Value rhs) {
173 return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
174 };
175
176 auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
177 op2.getLhs().getDefiningOp()->getOperand(0));
178 auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
179 op2.getRhs().getDefiningOp()->getOperand(0));
180
181 Value lhsMask, rhsMask;
182 if (op1.getLhsMask() || op2.getLhsMask()) {
183 lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask());
184 rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
185 }
186
187 auto extOp = op.getLhs().getDefiningOp();
188
189 arm_sme::CombiningKind kind = op.getKind();
190 if (kind == arm_sme::CombiningKind::Add) {
191 TypeSwitch<Operation *>(extOp)
192 .Case<arith::ExtFOp>([&](auto) {
193 rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>(
194 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
195 op1.getAcc());
196 })
197 .Case<arith::ExtSIOp>([&](auto) {
198 rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>(
199 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
200 op1.getAcc());
201 })
202 .Case<arith::ExtUIOp>([&](auto) {
203 rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>(
204 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
205 op1.getAcc());
206 })
207 .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
208 } else if (kind == arm_sme::CombiningKind::Sub) {
209 TypeSwitch<Operation *>(extOp)
210 .Case<arith::ExtFOp>([&](auto) {
211 rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>(
212 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
213 op1.getAcc());
214 })
215 .Case<arith::ExtSIOp>([&](auto) {
216 rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>(
217 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
218 op1.getAcc());
219 })
220 .Case<arith::ExtUIOp>([&](auto) {
221 rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>(
222 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
223 op1.getAcc());
224 })
225 .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
226 } else {
227 llvm_unreachable("unexpected arm_sme::CombiningKind!");
228 }
229
230 rewriter.eraseOp(op: op1);
231
232 return success();
233 }
234
235private:
236 // A pair of outer product can be fused if all of the following are true:
237 // - input and result types match.
238 // - the defining operations of the inputs are identical extensions,
239 // specifically either:
240 // - a signed or unsigned extension for integer types.
241 // - a floating-point extension for floating-point types.
242 // - the types and extension are supported, i.e. there's a 2-way operation
243 // they can be fused into.
244 LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
245 arm_sme::OuterProductOp op1,
246 arm_sme::OuterProductOp op2) const {
247 // Supported result types.
248 auto nxnxv4i32 =
249 VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
250 auto nxnxv4f32 =
251 VectorType::get({4, 4}, rewriter.getF32Type(), {true, true});
252 // Supported input types.
253 // Note: this is before packing so these have half the number of elements
254 // of the input vector types of the 2-way operations.
255 auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true);
256 auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true);
257 auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true);
258 if ((failed(
259 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
260 failed(
261 isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
262 (failed(
263 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
264 failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32,
265 nxv4bf16))) &&
266 (failed(
267 isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
268 failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32,
269 nxv4i16))) &&
270 (failed(
271 isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
272 failed(
273 isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
274 return failure();
275
276 return success();
277 }
278};
279
280// Fuse four 'arm_sme.outerproduct' operations that are chained via the
281// accumulator into 4-way outer product operation.
282class OuterProductFusion4Way
283 : public OpRewritePattern<arm_sme::OuterProductOp> {
284public:
285 using OpRewritePattern::OpRewritePattern;
286
287 LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
288 PatternRewriter &rewriter) const override {
289 SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
290 outerProductChain.push_back(op);
291
292 for (int i = 0; i < 3; ++i) {
293 auto currentOp = outerProductChain.back();
294 auto acc = currentOp.getAcc();
295 if (!acc)
296 return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
297 auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
298 if (!previousOp)
299 return rewriter.notifyMatchFailure(
300 op, kMatchFailureExpectedOuterProductDefOp);
301 if (!previousOp->hasOneUse())
302 return rewriter.notifyMatchFailure(
303 op, kMatchFailureOuterProductNotSingleUse);
304 if (previousOp.getKind() != currentOp.getKind())
305 return rewriter.notifyMatchFailure(
306 op, kMatchFailureInconsistentCombiningKind);
307 if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
308 return rewriter.notifyMatchFailure(
309 op, kMatchFailureInconsistentCombiningKind);
310 outerProductChain.push_back(previousOp);
311 }
312
313 if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
314 return failure();
315
316 arm_sme::OuterProductOp op1 = outerProductChain[3];
317 arm_sme::OuterProductOp op2 = outerProductChain[2];
318 arm_sme::OuterProductOp op3 = outerProductChain[1];
319 arm_sme::OuterProductOp op4 = outerProductChain[0];
320
321 auto loc = op.getLoc();
322 auto packInputs = [&](Value lhs, Value rhs) {
323 return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
324 };
325
326 auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
327 op3.getLhs().getDefiningOp()->getOperand(0));
328 auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
329 op4.getLhs().getDefiningOp()->getOperand(0));
330 auto lhs = packInputs(lhs0, lhs1);
331
332 auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
333 op3.getRhs().getDefiningOp()->getOperand(0));
334 auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
335 op4.getRhs().getDefiningOp()->getOperand(0));
336 auto rhs = packInputs(rhs0, rhs1);
337
338 Value lhsMask, rhsMask;
339 if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
340 op4.getLhsMask()) {
341 auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
342 auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
343 lhsMask = packInputs(lhs0Mask, lhs1Mask);
344
345 auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
346 auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
347 rhsMask = packInputs(rhs0Mask, rhs1Mask);
348 }
349
350 auto lhsExtOp = op.getLhs().getDefiningOp();
351 auto rhsExtOp = op.getRhs().getDefiningOp();
352
353 arm_sme::CombiningKind kind = op.getKind();
354 if (kind == arm_sme::CombiningKind::Add) {
355 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
356 // signed
357 rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
358 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
359 } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
360 isa<arith::ExtUIOp>(rhsExtOp)) {
361 // unsigned
362 rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
363 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
364 } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
365 isa<arith::ExtUIOp>(rhsExtOp)) {
366 // signed by unsigned
367 rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
368 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
369 } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
370 isa<arith::ExtSIOp>(rhsExtOp)) {
371 // unsigned by signed
372 rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
373 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
374 } else {
375 llvm_unreachable("unexpected extend op!");
376 }
377 } else if (kind == arm_sme::CombiningKind::Sub) {
378 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
379 // signed
380 rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
381 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
382 } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
383 isa<arith::ExtUIOp>(rhsExtOp)) {
384 // unsigned
385 rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
386 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
387 } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
388 isa<arith::ExtUIOp>(rhsExtOp)) {
389 // signed by unsigned
390 rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
391 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
392 } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
393 isa<arith::ExtSIOp>(rhsExtOp)) {
394 // unsigned by signed
395 rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
396 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
397 } else {
398 llvm_unreachable("unexpected extend op!");
399 }
400 } else {
401 llvm_unreachable("unexpected arm_sme::CombiningKind!");
402 }
403
404 rewriter.eraseOp(op: op3);
405 rewriter.eraseOp(op: op2);
406 rewriter.eraseOp(op: op1);
407
408 return success();
409 }
410
411private:
412 // Four outer products can be fused if all of the following are true:
413 // - input and result types match.
414 // - the defining operations of the inputs are identical extensions,
415 // specifically either:
416 // - a signed or unsigned extension for integer types.
417 // - a floating-point extension for floating-point types.
418 // - the types and extension are supported, i.e. there's a 4-way operation
419 // they can be fused into.
420 LogicalResult
421 canFuseOuterProducts(PatternRewriter &rewriter,
422 ArrayRef<arm_sme::OuterProductOp> ops) const {
423 // Supported result types.
424 auto nxnxv4i32 =
425 VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
426 auto nxnxv2i64 =
427 VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
428
429 // Supported input types.
430 // Note: this is before packing so these have 1/4 the number of elements
431 // of the input vector types of the 4-way operations.
432 auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
433 auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
434
435 auto failedToMatch = [&](VectorType resultType, VectorType inputType,
436 auto lhsExtendOp, auto rhsExtendOp) {
437 using LhsExtendOpTy = decltype(lhsExtendOp);
438 using RhsExtendOpTy = decltype(rhsExtendOp);
439 for (auto op : ops) {
440 if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
441 rewriter, op, resultType, inputType)))
442 return true;
443 }
444 return false;
445 };
446
447 if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
448 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
449 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
450 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
451 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
452 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
453 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
454 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
455 return failure();
456
457 return success();
458 }
459};
460
461// Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
462//
463// This transforms IR like:
464// %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
465// %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
466// Into:
467// %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
468// %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
469//
470// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
471// pass when the result is the input to an outer product.
472struct SwapVectorExtractOfArithExtend
473 : public OpRewritePattern<vector::ExtractOp> {
474 using OpRewritePattern::OpRewritePattern;
475
476 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
477 PatternRewriter &rewriter) const override {
478 VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
479 if (!resultType)
480 return rewriter.notifyMatchFailure(extractOp,
481 "extracted type is not a vector type");
482
483 auto numScalableDims = llvm::count(resultType.getScalableDims(), true);
484 if (numScalableDims != 1)
485 return rewriter.notifyMatchFailure(
486 extractOp, "extracted type is not a 1-D scalable vector type");
487
488 auto *extendOp = extractOp.getVector().getDefiningOp();
489 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
490 extendOp))
491 return rewriter.notifyMatchFailure(extractOp,
492 "extract not from extend op");
493
494 auto loc = extractOp.getLoc();
495 StringAttr extendOpName = extendOp->getName().getIdentifier();
496 Value extendSource = extendOp->getOperand(0);
497
498 // Create new extract from source of extend.
499 Value newExtract = rewriter.create<vector::ExtractOp>(
500 loc, extendSource, extractOp.getMixedPosition());
501
502 // Extend new extract to original result type.
503 Operation *newExtend =
504 rewriter.create(loc, extendOpName, Value(newExtract), resultType);
505
506 rewriter.replaceOp(extractOp, newExtend);
507
508 return success();
509 }
510};
511
512// Same as above, but for vector.scalable.extract.
513//
514// This transforms IR like:
515// %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
516// %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
517// Into:
518// %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
519// %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
520//
521// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
522// pass when the result is the input to an outer product.
523struct SwapVectorScalableExtractOfArithExtend
524 : public OpRewritePattern<vector::ScalableExtractOp> {
525 using OpRewritePattern::OpRewritePattern;
526
527 LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
528 PatternRewriter &rewriter) const override {
529 auto *extendOp = extractOp.getSource().getDefiningOp();
530 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
531 extendOp))
532 return rewriter.notifyMatchFailure(extractOp,
533 "extract not from extend op");
534
535 auto loc = extractOp.getLoc();
536 VectorType resultType = extractOp.getResultVectorType();
537
538 Value extendSource = extendOp->getOperand(0);
539 StringAttr extendOpName = extendOp->getName().getIdentifier();
540 VectorType extendSourceVectorType =
541 cast<VectorType>(extendSource.getType());
542
543 // Create new extract from source of extend.
544 VectorType extractResultVectorType =
545 resultType.clone(extendSourceVectorType.getElementType());
546 Value newExtract = rewriter.create<vector::ScalableExtractOp>(
547 loc, extractResultVectorType, extendSource, extractOp.getPos());
548
549 // Extend new extract to original result type.
550 Operation *newExtend =
551 rewriter.create(loc, extendOpName, Value(newExtract), resultType);
552
553 rewriter.replaceOp(extractOp, newExtend);
554
555 return success();
556 }
557};
558
559struct OuterProductFusionPass
560 : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
561
562 void runOnOperation() override {
563 RewritePatternSet patterns(&getContext());
564 populateOuterProductFusionPatterns(patterns);
565
566 if (failed(
567 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
568 signalPassFailure();
569 }
570};
571
572} // namespace
573
574void mlir::arm_sme::populateOuterProductFusionPatterns(
575 RewritePatternSet &patterns) {
576 MLIRContext *context = patterns.getContext();
577 // Note: High benefit to ensure extract(extend) are swapped first.
578 patterns.add<SwapVectorExtractOfArithExtend,
579 SwapVectorScalableExtractOfArithExtend>(arg&: context, args: 1024);
580 patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(arg&: context);
581}
582
583std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {
584 return std::make_unique<OuterProductFusionPass>();
585}
586

source code of mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp