1//===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewrites that fuse 'arm_sme.outerproduct' operations
10// into the 2-way or 4-way widening outerproduct operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
15#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
16#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/ADT/TypeSwitch.h"
21
22#define DEBUG_TYPE "arm-sme-outerproduct-fusion"
23
24namespace mlir::arm_sme {
25#define GEN_PASS_DEF_OUTERPRODUCTFUSION
26#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
27} // namespace mlir::arm_sme
28
29using namespace mlir;
30using namespace mlir::arm_sme;
31
32namespace {
33
34// Common match failure reasons.
35static constexpr StringLiteral
36 kMatchFailureNoAccumulator("no accumulator operand");
37static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp(
38 "defining op of accumulator must be 'arm_sme.outerproduct'");
39static constexpr StringLiteral kMatchFailureInconsistentCombiningKind(
40 "combining kind (add or sub) of outer products must match");
41static constexpr StringLiteral kMatchFailureInconsistentMasking(
42 "unsupported masking, either both outerproducts are masked "
43 "or neither");
44static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse(
45 "outer product(s) not single use and cannot be removed, no benefit to "
46 "fusing");
47
48// An outer product is compatible if all of the following are true:
49// - the result type matches `resultType`.
50// - the defining operation of LHS is of the type `LhsExtOp`.
51// - the defining operation of RHS is of the type `RhsExtOp`.
52// - the input types of the defining operations are identical and match
53// `inputType`.
54template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
55static LogicalResult isCompatible(PatternRewriter &rewriter,
56 arm_sme::OuterProductOp op,
57 VectorType resultType, VectorType inputType) {
58 if (op.getResultType() != resultType)
59 return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
60 diag << "unsupported result type, expected " << resultType;
61 });
62
63 auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
64 auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
65
66 if (!lhsDefOp || !rhsDefOp)
67 return rewriter.notifyMatchFailure(
68 op, "defining op of outerproduct operands must be one of: "
69 "'arith.extf' or 'arith.extsi' or 'arith.extui'");
70
71 auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
72 auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
73
74 if (lhsInType != inputType || rhsInType != inputType)
75 return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
76 diag << "unsupported input type, expected " << inputType;
77 });
78
79 return success();
80}
81
82// Fuse two 'arm_sme.outerproduct' operations that are chained via the
83// accumulator into 2-way outer product operation.
84//
85// For example:
86//
87// %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
88// %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
89// %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
90// vector<[4]xf32>
91//
92// %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
93// %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
94// %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
95// vector<[4]xf32>
96//
97// Becomes:
98//
99// %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
100// %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
101// %0 = arm_sme.fmopa_2way %a_packed, %b_packed
102// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
103class OuterProductFusion2Way
104 : public OpRewritePattern<arm_sme::OuterProductOp> {
105public:
106 using OpRewritePattern::OpRewritePattern;
107
108 LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
109 PatternRewriter &rewriter) const override {
110 Value acc = op.getAcc();
111 if (!acc)
112 return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
113
114 arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
115 arm_sme::OuterProductOp op2 = op;
116 if (!op1)
117 return rewriter.notifyMatchFailure(
118 op, kMatchFailureExpectedOuterProductDefOp);
119
120 if (op1.getKind() != op2.getKind())
121 return rewriter.notifyMatchFailure(
122 op, kMatchFailureInconsistentCombiningKind);
123
124 if (!op1->hasOneUse()) {
125 // If the first outer product has uses other than as the input to another
126 // outer product, it can't be erased after fusion.
127 return rewriter.notifyMatchFailure(op,
128 kMatchFailureOuterProductNotSingleUse);
129 }
130
131 if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
132 return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking);
133
134 if (failed(canFuseOuterProducts(rewriter, op1, op2)))
135 return failure();
136
137 auto loc = op.getLoc();
138 auto packInputs = [&](Value lhs, Value rhs) {
139 return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
140 };
141
142 auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
143 op2.getLhs().getDefiningOp()->getOperand(0));
144 auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
145 op2.getRhs().getDefiningOp()->getOperand(0));
146
147 Value lhsMask, rhsMask;
148 if (op1.getLhsMask() || op2.getLhsMask()) {
149 lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask());
150 rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
151 }
152
153 auto extOp = op.getLhs().getDefiningOp();
154
155 arm_sme::CombiningKind kind = op.getKind();
156 if (kind == arm_sme::CombiningKind::Add) {
157 TypeSwitch<Operation *>(extOp)
158 .Case<arith::ExtFOp>([&](auto) {
159 rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>(
160 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
161 op1.getAcc());
162 })
163 .Case<arith::ExtSIOp>([&](auto) {
164 rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>(
165 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
166 op1.getAcc());
167 })
168 .Case<arith::ExtUIOp>([&](auto) {
169 rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>(
170 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
171 op1.getAcc());
172 })
173 .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
174 } else if (kind == arm_sme::CombiningKind::Sub) {
175 TypeSwitch<Operation *>(extOp)
176 .Case<arith::ExtFOp>([&](auto) {
177 rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>(
178 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
179 op1.getAcc());
180 })
181 .Case<arith::ExtSIOp>([&](auto) {
182 rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>(
183 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
184 op1.getAcc());
185 })
186 .Case<arith::ExtUIOp>([&](auto) {
187 rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>(
188 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
189 op1.getAcc());
190 })
191 .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
192 } else {
193 llvm_unreachable("unexpected arm_sme::CombiningKind!");
194 }
195
196 return success();
197 }
198
199private:
200 // A pair of outer product can be fused if all of the following are true:
201 // - input and result types match.
202 // - the defining operations of the inputs are identical extensions,
203 // specifically either:
204 // - a signed or unsigned extension for integer types.
205 // - a floating-point extension for floating-point types.
206 // - the types and extension are supported, i.e. there's a 2-way operation
207 // they can be fused into.
208 LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
209 arm_sme::OuterProductOp op1,
210 arm_sme::OuterProductOp op2) const {
211 // Supported result types.
212 auto nxnxv4i32 =
213 VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
214 auto nxnxv4f32 =
215 VectorType::get({4, 4}, rewriter.getF32Type(), {true, true});
216 // Supported input types.
217 // Note: this is before packing so these have half the number of elements
218 // of the input vector types of the 2-way operations.
219 auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true);
220 auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true);
221 auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true);
222 if ((failed(
223 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
224 failed(
225 isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
226 (failed(
227 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
228 failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32,
229 nxv4bf16))) &&
230 (failed(
231 isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
232 failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32,
233 nxv4i16))) &&
234 (failed(
235 isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
236 failed(
237 isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
238 return failure();
239
240 return success();
241 }
242};
243
244// Fuse four 'arm_sme.outerproduct' operations that are chained via the
245// accumulator into 4-way outer product operation.
246class OuterProductFusion4Way
247 : public OpRewritePattern<arm_sme::OuterProductOp> {
248public:
249 using OpRewritePattern::OpRewritePattern;
250
251 LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
252 PatternRewriter &rewriter) const override {
253 SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
254 outerProductChain.push_back(op);
255
256 for (int i = 0; i < 3; ++i) {
257 auto currentOp = outerProductChain.back();
258 auto acc = currentOp.getAcc();
259 if (!acc)
260 return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
261 auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
262 if (!previousOp)
263 return rewriter.notifyMatchFailure(
264 op, kMatchFailureExpectedOuterProductDefOp);
265 if (!previousOp->hasOneUse())
266 return rewriter.notifyMatchFailure(
267 op, kMatchFailureOuterProductNotSingleUse);
268 if (previousOp.getKind() != currentOp.getKind())
269 return rewriter.notifyMatchFailure(
270 op, kMatchFailureInconsistentCombiningKind);
271 if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
272 return rewriter.notifyMatchFailure(
273 op, kMatchFailureInconsistentCombiningKind);
274 outerProductChain.push_back(previousOp);
275 }
276
277 if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
278 return failure();
279
280 arm_sme::OuterProductOp op1 = outerProductChain[3];
281 arm_sme::OuterProductOp op2 = outerProductChain[2];
282 arm_sme::OuterProductOp op3 = outerProductChain[1];
283 arm_sme::OuterProductOp op4 = outerProductChain[0];
284
285 auto loc = op.getLoc();
286 auto packInputs = [&](Value lhs, Value rhs) {
287 return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
288 };
289
290 auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
291 op3.getLhs().getDefiningOp()->getOperand(0));
292 auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
293 op4.getLhs().getDefiningOp()->getOperand(0));
294 auto lhs = packInputs(lhs0, lhs1);
295
296 auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
297 op3.getRhs().getDefiningOp()->getOperand(0));
298 auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
299 op4.getRhs().getDefiningOp()->getOperand(0));
300 auto rhs = packInputs(rhs0, rhs1);
301
302 Value lhsMask, rhsMask;
303 if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
304 op4.getLhsMask()) {
305 auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
306 auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
307 lhsMask = packInputs(lhs0Mask, lhs1Mask);
308
309 auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
310 auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
311 rhsMask = packInputs(rhs0Mask, rhs1Mask);
312 }
313
314 auto lhsExtOp = op.getLhs().getDefiningOp();
315 auto rhsExtOp = op.getRhs().getDefiningOp();
316
317 arm_sme::CombiningKind kind = op.getKind();
318 if (kind == arm_sme::CombiningKind::Add) {
319 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
320 // signed
321 rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
322 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
323 } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
324 isa<arith::ExtUIOp>(rhsExtOp)) {
325 // unsigned
326 rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
327 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
328 } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
329 isa<arith::ExtUIOp>(rhsExtOp)) {
330 // signed by unsigned
331 rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
332 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
333 } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
334 isa<arith::ExtSIOp>(rhsExtOp)) {
335 // unsigned by signed
336 rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
337 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
338 } else {
339 llvm_unreachable("unexpected extend op!");
340 }
341 } else if (kind == arm_sme::CombiningKind::Sub) {
342 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
343 // signed
344 rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
345 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
346 } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
347 isa<arith::ExtUIOp>(rhsExtOp)) {
348 // unsigned
349 rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
350 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
351 } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
352 isa<arith::ExtUIOp>(rhsExtOp)) {
353 // signed by unsigned
354 rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
355 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
356 } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
357 isa<arith::ExtSIOp>(rhsExtOp)) {
358 // unsigned by signed
359 rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
360 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
361 } else {
362 llvm_unreachable("unexpected extend op!");
363 }
364 } else {
365 llvm_unreachable("unexpected arm_sme::CombiningKind!");
366 }
367
368 return success();
369 }
370
371private:
372 // Four outer products can be fused if all of the following are true:
373 // - input and result types match.
374 // - the defining operations of the inputs are identical extensions,
375 // specifically either:
376 // - a signed or unsigned extension for integer types.
377 // - a floating-point extension for floating-point types.
378 // - the types and extension are supported, i.e. there's a 4-way operation
379 // they can be fused into.
380 LogicalResult
381 canFuseOuterProducts(PatternRewriter &rewriter,
382 ArrayRef<arm_sme::OuterProductOp> ops) const {
383 // Supported result types.
384 auto nxnxv4i32 =
385 VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
386 auto nxnxv2i64 =
387 VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
388
389 // Supported input types.
390 // Note: this is before packing so these have 1/4 the number of elements
391 // of the input vector types of the 4-way operations.
392 auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
393 auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
394
395 auto failedToMatch = [&](VectorType resultType, VectorType inputType,
396 auto lhsExtendOp, auto rhsExtendOp) {
397 using LhsExtendOpTy = decltype(lhsExtendOp);
398 using RhsExtendOpTy = decltype(rhsExtendOp);
399 for (auto op : ops) {
400 if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
401 rewriter, op, resultType, inputType)))
402 return true;
403 }
404 return false;
405 };
406
407 if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
408 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
409 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
410 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
411 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
412 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
413 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
414 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
415 return failure();
416
417 return success();
418 }
419};
420
421// Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
422//
423// This transforms IR like:
424// %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
425// %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
426// Into:
427// %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
428// %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
429//
430// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
431// pass when the result is the input to an outer product.
432struct SwapVectorExtractOfArithExtend
433 : public OpRewritePattern<vector::ExtractOp> {
434 using OpRewritePattern::OpRewritePattern;
435
436 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
437 PatternRewriter &rewriter) const override {
438 VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
439 if (!resultType)
440 return rewriter.notifyMatchFailure(extractOp,
441 "extracted type is not a vector type");
442
443 auto numScalableDims = resultType.getNumScalableDims();
444 if (numScalableDims != 1)
445 return rewriter.notifyMatchFailure(
446 extractOp, "extracted type is not a 1-D scalable vector type");
447
448 auto *extendOp = extractOp.getVector().getDefiningOp();
449 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
450 extendOp))
451 return rewriter.notifyMatchFailure(extractOp,
452 "extract not from extend op");
453
454 auto loc = extractOp.getLoc();
455 StringAttr extendOpName = extendOp->getName().getIdentifier();
456 Value extendSource = extendOp->getOperand(0);
457
458 // Create new extract from source of extend.
459 Value newExtract = rewriter.create<vector::ExtractOp>(
460 loc, extendSource, extractOp.getMixedPosition());
461
462 // Extend new extract to original result type.
463 Operation *newExtend =
464 rewriter.create(loc, extendOpName, Value(newExtract), resultType);
465
466 rewriter.replaceOp(extractOp, newExtend);
467
468 return success();
469 }
470};
471
472// Same as above, but for vector.scalable.extract.
473//
474// This transforms IR like:
475// %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
476// %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
477// Into:
478// %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
479// %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
480//
481// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
482// pass when the result is the input to an outer product.
483struct SwapVectorScalableExtractOfArithExtend
484 : public OpRewritePattern<vector::ScalableExtractOp> {
485 using OpRewritePattern::OpRewritePattern;
486
487 LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
488 PatternRewriter &rewriter) const override {
489 auto *extendOp = extractOp.getSource().getDefiningOp();
490 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
491 extendOp))
492 return rewriter.notifyMatchFailure(extractOp,
493 "extract not from extend op");
494
495 auto loc = extractOp.getLoc();
496 VectorType resultType = extractOp.getResultVectorType();
497
498 Value extendSource = extendOp->getOperand(0);
499 StringAttr extendOpName = extendOp->getName().getIdentifier();
500 VectorType extendSourceVectorType =
501 cast<VectorType>(extendSource.getType());
502
503 // Create new extract from source of extend.
504 VectorType extractResultVectorType =
505 resultType.clone(extendSourceVectorType.getElementType());
506 Value newExtract = rewriter.create<vector::ScalableExtractOp>(
507 loc, extractResultVectorType, extendSource, extractOp.getPos());
508
509 // Extend new extract to original result type.
510 Operation *newExtend =
511 rewriter.create(loc, extendOpName, Value(newExtract), resultType);
512
513 rewriter.replaceOp(extractOp, newExtend);
514
515 return success();
516 }
517};
518
519struct OuterProductFusionPass
520 : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
521
522 void runOnOperation() override {
523 RewritePatternSet patterns(&getContext());
524 populateOuterProductFusionPatterns(patterns);
525
526 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
527 signalPassFailure();
528 }
529};
530
531} // namespace
532
533void mlir::arm_sme::populateOuterProductFusionPatterns(
534 RewritePatternSet &patterns) {
535 MLIRContext *context = patterns.getContext();
536 // Note: High benefit to ensure extract(extend) are swapped first.
537 patterns.add<SwapVectorExtractOfArithExtend,
538 SwapVectorScalableExtractOfArithExtend>(arg&: context, args: 1024);
539 patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(arg&: context);
540}
541
542std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {
543 return std::make_unique<OuterProductFusionPass>();
544}
545

// Source: mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp