1 | //===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements rewrites that fuse 'arm_sme.outerproduct' operations |
10 | // into the 2-way or 4-way widening outerproduct operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
15 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
16 | #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
21 | #include "llvm/ADT/TypeSwitch.h" |
22 | |
23 | #define DEBUG_TYPE "arm-sme-outerproduct-fusion" |
24 | |
25 | namespace mlir::arm_sme { |
26 | #define GEN_PASS_DEF_OUTERPRODUCTFUSION |
27 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
28 | } // namespace mlir::arm_sme |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::arm_sme; |
32 | |
33 | namespace { |
34 | |
35 | // Common match failure reasons. |
36 | static constexpr StringLiteral |
37 | kMatchFailureNoAccumulator("no accumulator operand" ); |
38 | static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp( |
39 | "defining op of accumulator must be 'arm_sme.outerproduct'" ); |
40 | static constexpr StringLiteral kMatchFailureInconsistentCombiningKind( |
41 | "combining kind (add or sub) of outer products must match" ); |
42 | static constexpr StringLiteral kMatchFailureInconsistentMasking( |
43 | "unsupported masking, either both outerproducts are masked " |
44 | "or neither" ); |
45 | static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse( |
46 | "outer product(s) not single use and cannot be removed, no benefit to " |
47 | "fusing" ); |
48 | |
49 | // An outer product is compatible if all of the following are true: |
50 | // - the result type matches `resultType`. |
51 | // - the defining operation of LHS is of the type `LhsExtOp`. |
52 | // - the defining operation of RHS is of the type `RhsExtOp`. |
53 | // - the input types of the defining operations are identical and match |
54 | // `inputType`. |
55 | template <typename LhsExtOp, typename RhsExtOp = LhsExtOp> |
56 | static LogicalResult isCompatible(PatternRewriter &rewriter, |
57 | arm_sme::OuterProductOp op, |
58 | VectorType resultType, VectorType inputType) { |
59 | if (op.getResultType() != resultType) |
60 | return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) { |
61 | diag << "unsupported result type, expected " << resultType; |
62 | }); |
63 | |
64 | auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>(); |
65 | auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>(); |
66 | |
67 | if (!lhsDefOp || !rhsDefOp) |
68 | return rewriter.notifyMatchFailure( |
69 | op, "defining op of outerproduct operands must be one of: " |
70 | "'arith.extf' or 'arith.extsi' or 'arith.extui'" ); |
71 | |
72 | auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType()); |
73 | auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType()); |
74 | |
75 | if (lhsInType != inputType || rhsInType != inputType) |
76 | return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) { |
77 | diag << "unsupported input type, expected " << inputType; |
78 | }); |
79 | |
80 | return success(); |
81 | } |
82 | |
83 | // Create 'llvm.experimental.vector.interleave2' intrinsic from `lhs` and `rhs`. |
84 | static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc, |
85 | Value lhs, Value rhs) { |
86 | auto inputType = cast<VectorType>(lhs.getType()); |
87 | VectorType inputTypeX2 = |
88 | VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2); |
89 | return rewriter.create<LLVM::experimental_vector_interleave2>( |
90 | loc, inputTypeX2, lhs, rhs); |
91 | } |
92 | |
93 | // Fuse two 'arm_sme.outerproduct' operations that are chained via the |
94 | // accumulator into 2-way outer product operation. |
95 | // |
96 | // For example: |
97 | // |
98 | // %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> |
99 | // %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> |
100 | // %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, |
101 | // vector<[4]xf32> |
102 | // |
103 | // %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> |
104 | // %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> |
105 | // %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, |
106 | // vector<[4]xf32> |
107 | // |
108 | // Becomes: |
109 | // |
110 | // %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) |
111 | // : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> |
112 | // %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) |
113 | // : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> |
114 | // %0 = arm_sme.fmopa_2way %a_packed, %b_packed |
115 | // : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> |
116 | class OuterProductFusion2Way |
117 | : public OpRewritePattern<arm_sme::OuterProductOp> { |
118 | public: |
119 | using OpRewritePattern::OpRewritePattern; |
120 | |
121 | LogicalResult matchAndRewrite(arm_sme::OuterProductOp op, |
122 | PatternRewriter &rewriter) const override { |
123 | Value acc = op.getAcc(); |
124 | if (!acc) |
125 | return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator); |
126 | |
127 | arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>(); |
128 | arm_sme::OuterProductOp op2 = op; |
129 | if (!op1) |
130 | return rewriter.notifyMatchFailure( |
131 | op, kMatchFailureExpectedOuterProductDefOp); |
132 | |
133 | if (op1.getKind() != op2.getKind()) |
134 | return rewriter.notifyMatchFailure( |
135 | op, kMatchFailureInconsistentCombiningKind); |
136 | |
137 | if (!op1->hasOneUse()) { |
138 | // If the first outer product has uses other than as the input to another |
139 | // outer product, it can't be erased after fusion. This is a problem when |
140 | // it also has an accumulator as this will be used as the root for tile |
141 | // allocation and since the widening outer product uses the same |
142 | // accumulator it will get assigned the same tile ID, resulting in 3 |
143 | // outer products accumulating to the same tile and incorrect results. |
144 | // |
145 | // Example: |
146 | // |
147 | // %acc = arith.constant dense<0.0> ; root for tile allocation |
148 | // %0 = arm_sme.outerproduct %a0, %b0 acc(%acc) |
149 | // vector.print %0 ; intermediary use, can't erase %0 |
150 | // %1 = arm_sme.outerproduct %a1, %b1 acc(%0) |
151 | // |
152 | // After fusion and tile allocation |
153 | // |
154 | // %0 = arm_sme.zero {tile_id = 0 : i32} |
155 | // %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32} |
156 | // vector.print %1 |
157 | // %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32} |
158 | // |
159 | // No accumulator would be ok, but it's simpler to prevent this |
160 | // altogether, since it has no benefit. |
161 | return rewriter.notifyMatchFailure(op, |
162 | kMatchFailureOuterProductNotSingleUse); |
163 | } |
164 | |
165 | if (bool(op1.getLhsMask()) != bool(op2.getLhsMask())) |
166 | return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking); |
167 | |
168 | if (failed(canFuseOuterProducts(rewriter, op1, op2))) |
169 | return failure(); |
170 | |
171 | auto loc = op.getLoc(); |
172 | auto packInputs = [&](Value lhs, Value rhs) { |
173 | return createInterleave2Intrinsic(rewriter, loc, lhs, rhs); |
174 | }; |
175 | |
176 | auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), |
177 | op2.getLhs().getDefiningOp()->getOperand(0)); |
178 | auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0), |
179 | op2.getRhs().getDefiningOp()->getOperand(0)); |
180 | |
181 | Value lhsMask, rhsMask; |
182 | if (op1.getLhsMask() || op2.getLhsMask()) { |
183 | lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask()); |
184 | rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask()); |
185 | } |
186 | |
187 | auto extOp = op.getLhs().getDefiningOp(); |
188 | |
189 | arm_sme::CombiningKind kind = op.getKind(); |
190 | if (kind == arm_sme::CombiningKind::Add) { |
191 | TypeSwitch<Operation *>(extOp) |
192 | .Case<arith::ExtFOp>([&](auto) { |
193 | rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>( |
194 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
195 | op1.getAcc()); |
196 | }) |
197 | .Case<arith::ExtSIOp>([&](auto) { |
198 | rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>( |
199 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
200 | op1.getAcc()); |
201 | }) |
202 | .Case<arith::ExtUIOp>([&](auto) { |
203 | rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>( |
204 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
205 | op1.getAcc()); |
206 | }) |
207 | .Default([&](auto) { llvm_unreachable("unexpected extend op!" ); }); |
208 | } else if (kind == arm_sme::CombiningKind::Sub) { |
209 | TypeSwitch<Operation *>(extOp) |
210 | .Case<arith::ExtFOp>([&](auto) { |
211 | rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>( |
212 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
213 | op1.getAcc()); |
214 | }) |
215 | .Case<arith::ExtSIOp>([&](auto) { |
216 | rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>( |
217 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
218 | op1.getAcc()); |
219 | }) |
220 | .Case<arith::ExtUIOp>([&](auto) { |
221 | rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>( |
222 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
223 | op1.getAcc()); |
224 | }) |
225 | .Default([&](auto) { llvm_unreachable("unexpected extend op!" ); }); |
226 | } else { |
227 | llvm_unreachable("unexpected arm_sme::CombiningKind!" ); |
228 | } |
229 | |
230 | rewriter.eraseOp(op: op1); |
231 | |
232 | return success(); |
233 | } |
234 | |
235 | private: |
236 | // A pair of outer product can be fused if all of the following are true: |
237 | // - input and result types match. |
238 | // - the defining operations of the inputs are identical extensions, |
239 | // specifically either: |
240 | // - a signed or unsigned extension for integer types. |
241 | // - a floating-point extension for floating-point types. |
242 | // - the types and extension are supported, i.e. there's a 2-way operation |
243 | // they can be fused into. |
244 | LogicalResult canFuseOuterProducts(PatternRewriter &rewriter, |
245 | arm_sme::OuterProductOp op1, |
246 | arm_sme::OuterProductOp op2) const { |
247 | // Supported result types. |
248 | auto nxnxv4i32 = |
249 | VectorType::get({4, 4}, rewriter.getI32Type(), {true, true}); |
250 | auto nxnxv4f32 = |
251 | VectorType::get({4, 4}, rewriter.getF32Type(), {true, true}); |
252 | // Supported input types. |
253 | // Note: this is before packing so these have half the number of elements |
254 | // of the input vector types of the 2-way operations. |
255 | auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true); |
256 | auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true); |
257 | auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true); |
258 | if ((failed( |
259 | isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) || |
260 | failed( |
261 | isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) && |
262 | (failed( |
263 | isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) || |
264 | failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, |
265 | nxv4bf16))) && |
266 | (failed( |
267 | isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) || |
268 | failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, |
269 | nxv4i16))) && |
270 | (failed( |
271 | isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) || |
272 | failed( |
273 | isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16)))) |
274 | return failure(); |
275 | |
276 | return success(); |
277 | } |
278 | }; |
279 | |
280 | // Fuse four 'arm_sme.outerproduct' operations that are chained via the |
281 | // accumulator into 4-way outer product operation. |
282 | class OuterProductFusion4Way |
283 | : public OpRewritePattern<arm_sme::OuterProductOp> { |
284 | public: |
285 | using OpRewritePattern::OpRewritePattern; |
286 | |
287 | LogicalResult matchAndRewrite(arm_sme::OuterProductOp op, |
288 | PatternRewriter &rewriter) const override { |
289 | SmallVector<arm_sme::OuterProductOp, 4> outerProductChain; |
290 | outerProductChain.push_back(op); |
291 | |
292 | for (int i = 0; i < 3; ++i) { |
293 | auto currentOp = outerProductChain.back(); |
294 | auto acc = currentOp.getAcc(); |
295 | if (!acc) |
296 | return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator); |
297 | auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>(); |
298 | if (!previousOp) |
299 | return rewriter.notifyMatchFailure( |
300 | op, kMatchFailureExpectedOuterProductDefOp); |
301 | if (!previousOp->hasOneUse()) |
302 | return rewriter.notifyMatchFailure( |
303 | op, kMatchFailureOuterProductNotSingleUse); |
304 | if (previousOp.getKind() != currentOp.getKind()) |
305 | return rewriter.notifyMatchFailure( |
306 | op, kMatchFailureInconsistentCombiningKind); |
307 | if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask())) |
308 | return rewriter.notifyMatchFailure( |
309 | op, kMatchFailureInconsistentCombiningKind); |
310 | outerProductChain.push_back(previousOp); |
311 | } |
312 | |
313 | if (failed(canFuseOuterProducts(rewriter, outerProductChain))) |
314 | return failure(); |
315 | |
316 | arm_sme::OuterProductOp op1 = outerProductChain[3]; |
317 | arm_sme::OuterProductOp op2 = outerProductChain[2]; |
318 | arm_sme::OuterProductOp op3 = outerProductChain[1]; |
319 | arm_sme::OuterProductOp op4 = outerProductChain[0]; |
320 | |
321 | auto loc = op.getLoc(); |
322 | auto packInputs = [&](Value lhs, Value rhs) { |
323 | return createInterleave2Intrinsic(rewriter, loc, lhs, rhs); |
324 | }; |
325 | |
326 | auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), |
327 | op3.getLhs().getDefiningOp()->getOperand(0)); |
328 | auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0), |
329 | op4.getLhs().getDefiningOp()->getOperand(0)); |
330 | auto lhs = packInputs(lhs0, lhs1); |
331 | |
332 | auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0), |
333 | op3.getRhs().getDefiningOp()->getOperand(0)); |
334 | auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0), |
335 | op4.getRhs().getDefiningOp()->getOperand(0)); |
336 | auto rhs = packInputs(rhs0, rhs1); |
337 | |
338 | Value lhsMask, rhsMask; |
339 | if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() || |
340 | op4.getLhsMask()) { |
341 | auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask()); |
342 | auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask()); |
343 | lhsMask = packInputs(lhs0Mask, lhs1Mask); |
344 | |
345 | auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask()); |
346 | auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask()); |
347 | rhsMask = packInputs(rhs0Mask, rhs1Mask); |
348 | } |
349 | |
350 | auto lhsExtOp = op.getLhs().getDefiningOp(); |
351 | auto rhsExtOp = op.getRhs().getDefiningOp(); |
352 | |
353 | arm_sme::CombiningKind kind = op.getKind(); |
354 | if (kind == arm_sme::CombiningKind::Add) { |
355 | if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) { |
356 | // signed |
357 | rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>( |
358 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
359 | } else if (isa<arith::ExtUIOp>(lhsExtOp) && |
360 | isa<arith::ExtUIOp>(rhsExtOp)) { |
361 | // unsigned |
362 | rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>( |
363 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
364 | } else if (isa<arith::ExtSIOp>(lhsExtOp) && |
365 | isa<arith::ExtUIOp>(rhsExtOp)) { |
366 | // signed by unsigned |
367 | rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>( |
368 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
369 | } else if (isa<arith::ExtUIOp>(lhsExtOp) && |
370 | isa<arith::ExtSIOp>(rhsExtOp)) { |
371 | // unsigned by signed |
372 | rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>( |
373 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
374 | } else { |
375 | llvm_unreachable("unexpected extend op!" ); |
376 | } |
377 | } else if (kind == arm_sme::CombiningKind::Sub) { |
378 | if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) { |
379 | // signed |
380 | rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>( |
381 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
382 | } else if (isa<arith::ExtUIOp>(lhsExtOp) && |
383 | isa<arith::ExtUIOp>(rhsExtOp)) { |
384 | // unsigned |
385 | rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>( |
386 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
387 | } else if (isa<arith::ExtSIOp>(lhsExtOp) && |
388 | isa<arith::ExtUIOp>(rhsExtOp)) { |
389 | // signed by unsigned |
390 | rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>( |
391 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
392 | } else if (isa<arith::ExtUIOp>(lhsExtOp) && |
393 | isa<arith::ExtSIOp>(rhsExtOp)) { |
394 | // unsigned by signed |
395 | rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>( |
396 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
397 | } else { |
398 | llvm_unreachable("unexpected extend op!" ); |
399 | } |
400 | } else { |
401 | llvm_unreachable("unexpected arm_sme::CombiningKind!" ); |
402 | } |
403 | |
404 | rewriter.eraseOp(op: op3); |
405 | rewriter.eraseOp(op: op2); |
406 | rewriter.eraseOp(op: op1); |
407 | |
408 | return success(); |
409 | } |
410 | |
411 | private: |
412 | // Four outer products can be fused if all of the following are true: |
413 | // - input and result types match. |
414 | // - the defining operations of the inputs are identical extensions, |
415 | // specifically either: |
416 | // - a signed or unsigned extension for integer types. |
417 | // - a floating-point extension for floating-point types. |
418 | // - the types and extension are supported, i.e. there's a 4-way operation |
419 | // they can be fused into. |
420 | LogicalResult |
421 | canFuseOuterProducts(PatternRewriter &rewriter, |
422 | ArrayRef<arm_sme::OuterProductOp> ops) const { |
423 | // Supported result types. |
424 | auto nxnxv4i32 = |
425 | VectorType::get({4, 4}, rewriter.getI32Type(), {true, true}); |
426 | auto nxnxv2i64 = |
427 | VectorType::get({2, 2}, rewriter.getI64Type(), {true, true}); |
428 | |
429 | // Supported input types. |
430 | // Note: this is before packing so these have 1/4 the number of elements |
431 | // of the input vector types of the 4-way operations. |
432 | auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true); |
433 | auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true); |
434 | |
435 | auto failedToMatch = [&](VectorType resultType, VectorType inputType, |
436 | auto lhsExtendOp, auto rhsExtendOp) { |
437 | using LhsExtendOpTy = decltype(lhsExtendOp); |
438 | using RhsExtendOpTy = decltype(rhsExtendOp); |
439 | for (auto op : ops) { |
440 | if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>( |
441 | rewriter, op, resultType, inputType))) |
442 | return true; |
443 | } |
444 | return false; |
445 | }; |
446 | |
447 | if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) && |
448 | failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) && |
449 | failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) && |
450 | failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) && |
451 | failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) && |
452 | failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) && |
453 | failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) && |
454 | failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{})) |
455 | return failure(); |
456 | |
457 | return success(); |
458 | } |
459 | }; |
460 | |
461 | // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract). |
462 | // |
463 | // This transforms IR like: |
464 | // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32> |
465 | // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32> |
466 | // Into: |
467 | // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8> |
468 | // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32> |
469 | // |
470 | // This enables outer product fusion in the `-arm-sme-outer-product-fusion` |
471 | // pass when the result is the input to an outer product. |
472 | struct |
473 | : public OpRewritePattern<vector::ExtractOp> { |
474 | using OpRewritePattern::OpRewritePattern; |
475 | |
476 | LogicalResult matchAndRewrite(vector::ExtractOp , |
477 | PatternRewriter &rewriter) const override { |
478 | VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType()); |
479 | if (!resultType) |
480 | return rewriter.notifyMatchFailure(extractOp, |
481 | "extracted type is not a vector type" ); |
482 | |
483 | auto numScalableDims = llvm::count(resultType.getScalableDims(), true); |
484 | if (numScalableDims != 1) |
485 | return rewriter.notifyMatchFailure( |
486 | extractOp, "extracted type is not a 1-D scalable vector type" ); |
487 | |
488 | auto *extendOp = extractOp.getVector().getDefiningOp(); |
489 | if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>( |
490 | extendOp)) |
491 | return rewriter.notifyMatchFailure(extractOp, |
492 | "extract not from extend op" ); |
493 | |
494 | auto loc = extractOp.getLoc(); |
495 | StringAttr extendOpName = extendOp->getName().getIdentifier(); |
496 | Value extendSource = extendOp->getOperand(0); |
497 | |
498 | // Create new extract from source of extend. |
499 | Value = rewriter.create<vector::ExtractOp>( |
500 | loc, extendSource, extractOp.getMixedPosition()); |
501 | |
502 | // Extend new extract to original result type. |
503 | Operation *newExtend = |
504 | rewriter.create(loc, extendOpName, Value(newExtract), resultType); |
505 | |
506 | rewriter.replaceOp(extractOp, newExtend); |
507 | |
508 | return success(); |
509 | } |
510 | }; |
511 | |
512 | // Same as above, but for vector.scalable.extract. |
513 | // |
514 | // This transforms IR like: |
515 | // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32> |
516 | // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32> |
517 | // Into: |
518 | // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8> |
519 | // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32> |
520 | // |
521 | // This enables outer product fusion in the `-arm-sme-outer-product-fusion` |
522 | // pass when the result is the input to an outer product. |
523 | struct |
524 | : public OpRewritePattern<vector::ScalableExtractOp> { |
525 | using OpRewritePattern::OpRewritePattern; |
526 | |
527 | LogicalResult matchAndRewrite(vector::ScalableExtractOp , |
528 | PatternRewriter &rewriter) const override { |
529 | auto *extendOp = extractOp.getSource().getDefiningOp(); |
530 | if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>( |
531 | extendOp)) |
532 | return rewriter.notifyMatchFailure(extractOp, |
533 | "extract not from extend op" ); |
534 | |
535 | auto loc = extractOp.getLoc(); |
536 | VectorType resultType = extractOp.getResultVectorType(); |
537 | |
538 | Value extendSource = extendOp->getOperand(0); |
539 | StringAttr extendOpName = extendOp->getName().getIdentifier(); |
540 | VectorType extendSourceVectorType = |
541 | cast<VectorType>(extendSource.getType()); |
542 | |
543 | // Create new extract from source of extend. |
544 | VectorType = |
545 | resultType.clone(extendSourceVectorType.getElementType()); |
546 | Value = rewriter.create<vector::ScalableExtractOp>( |
547 | loc, extractResultVectorType, extendSource, extractOp.getPos()); |
548 | |
549 | // Extend new extract to original result type. |
550 | Operation *newExtend = |
551 | rewriter.create(loc, extendOpName, Value(newExtract), resultType); |
552 | |
553 | rewriter.replaceOp(extractOp, newExtend); |
554 | |
555 | return success(); |
556 | } |
557 | }; |
558 | |
559 | struct OuterProductFusionPass |
560 | : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> { |
561 | |
562 | void runOnOperation() override { |
563 | RewritePatternSet patterns(&getContext()); |
564 | populateOuterProductFusionPatterns(patterns); |
565 | |
566 | if (failed( |
567 | applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) |
568 | signalPassFailure(); |
569 | } |
570 | }; |
571 | |
572 | } // namespace |
573 | |
574 | void mlir::arm_sme::populateOuterProductFusionPatterns( |
575 | RewritePatternSet &patterns) { |
576 | MLIRContext *context = patterns.getContext(); |
577 | // Note: High benefit to ensure extract(extend) are swapped first. |
578 | patterns.add<SwapVectorExtractOfArithExtend, |
579 | SwapVectorScalableExtractOfArithExtend>(arg&: context, args: 1024); |
580 | patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(arg&: context); |
581 | } |
582 | |
583 | std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() { |
584 | return std::make_unique<OuterProductFusionPass>(); |
585 | } |
586 | |