1 | //===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements rewrites that fuse 'arm_sme.outerproduct' operations |
10 | // into the 2-way or 4-way widening outerproduct operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
15 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
16 | #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
20 | #include "llvm/ADT/TypeSwitch.h" |
21 | |
22 | #define DEBUG_TYPE "arm-sme-outerproduct-fusion" |
23 | |
24 | namespace mlir::arm_sme { |
25 | #define GEN_PASS_DEF_OUTERPRODUCTFUSION |
26 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
27 | } // namespace mlir::arm_sme |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::arm_sme; |
31 | |
32 | namespace { |
33 | |
34 | // Common match failure reasons. |
35 | static constexpr StringLiteral |
36 | kMatchFailureNoAccumulator("no accumulator operand"); |
37 | static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp( |
38 | "defining op of accumulator must be 'arm_sme.outerproduct'"); |
39 | static constexpr StringLiteral kMatchFailureInconsistentCombiningKind( |
40 | "combining kind (add or sub) of outer products must match"); |
41 | static constexpr StringLiteral kMatchFailureInconsistentMasking( |
42 | "unsupported masking, either both outerproducts are masked " |
43 | "or neither"); |
44 | static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse( |
45 | "outer product(s) not single use and cannot be removed, no benefit to " |
46 | "fusing"); |
47 | |
48 | // An outer product is compatible if all of the following are true: |
49 | // - the result type matches `resultType`. |
50 | // - the defining operation of LHS is of the type `LhsExtOp`. |
51 | // - the defining operation of RHS is of the type `RhsExtOp`. |
52 | // - the input types of the defining operations are identical and match |
53 | // `inputType`. |
54 | template <typename LhsExtOp, typename RhsExtOp = LhsExtOp> |
55 | static LogicalResult isCompatible(PatternRewriter &rewriter, |
56 | arm_sme::OuterProductOp op, |
57 | VectorType resultType, VectorType inputType) { |
58 | if (op.getResultType() != resultType) |
59 | return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) { |
60 | diag << "unsupported result type, expected "<< resultType; |
61 | }); |
62 | |
63 | auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>(); |
64 | auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>(); |
65 | |
66 | if (!lhsDefOp || !rhsDefOp) |
67 | return rewriter.notifyMatchFailure( |
68 | op, "defining op of outerproduct operands must be one of: " |
69 | "'arith.extf' or 'arith.extsi' or 'arith.extui'"); |
70 | |
71 | auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType()); |
72 | auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType()); |
73 | |
74 | if (lhsInType != inputType || rhsInType != inputType) |
75 | return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) { |
76 | diag << "unsupported input type, expected "<< inputType; |
77 | }); |
78 | |
79 | return success(); |
80 | } |
81 | |
82 | // Fuse two 'arm_sme.outerproduct' operations that are chained via the |
83 | // accumulator into 2-way outer product operation. |
84 | // |
85 | // For example: |
86 | // |
87 | // %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> |
88 | // %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> |
89 | // %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, |
90 | // vector<[4]xf32> |
91 | // |
92 | // %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> |
93 | // %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> |
94 | // %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, |
95 | // vector<[4]xf32> |
96 | // |
97 | // Becomes: |
98 | // |
99 | // %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16> |
100 | // %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16> |
101 | // %0 = arm_sme.fmopa_2way %a_packed, %b_packed |
102 | // : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> |
103 | class OuterProductFusion2Way |
104 | : public OpRewritePattern<arm_sme::OuterProductOp> { |
105 | public: |
106 | using OpRewritePattern::OpRewritePattern; |
107 | |
108 | LogicalResult matchAndRewrite(arm_sme::OuterProductOp op, |
109 | PatternRewriter &rewriter) const override { |
110 | Value acc = op.getAcc(); |
111 | if (!acc) |
112 | return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator); |
113 | |
114 | arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>(); |
115 | arm_sme::OuterProductOp op2 = op; |
116 | if (!op1) |
117 | return rewriter.notifyMatchFailure( |
118 | op, kMatchFailureExpectedOuterProductDefOp); |
119 | |
120 | if (op1.getKind() != op2.getKind()) |
121 | return rewriter.notifyMatchFailure( |
122 | op, kMatchFailureInconsistentCombiningKind); |
123 | |
124 | if (!op1->hasOneUse()) { |
125 | // If the first outer product has uses other than as the input to another |
126 | // outer product, it can't be erased after fusion. |
127 | return rewriter.notifyMatchFailure(op, |
128 | kMatchFailureOuterProductNotSingleUse); |
129 | } |
130 | |
131 | if (bool(op1.getLhsMask()) != bool(op2.getLhsMask())) |
132 | return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking); |
133 | |
134 | if (failed(canFuseOuterProducts(rewriter, op1, op2))) |
135 | return failure(); |
136 | |
137 | auto loc = op.getLoc(); |
138 | auto packInputs = [&](Value lhs, Value rhs) { |
139 | return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs); |
140 | }; |
141 | |
142 | auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), |
143 | op2.getLhs().getDefiningOp()->getOperand(0)); |
144 | auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0), |
145 | op2.getRhs().getDefiningOp()->getOperand(0)); |
146 | |
147 | Value lhsMask, rhsMask; |
148 | if (op1.getLhsMask() || op2.getLhsMask()) { |
149 | lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask()); |
150 | rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask()); |
151 | } |
152 | |
153 | auto extOp = op.getLhs().getDefiningOp(); |
154 | |
155 | arm_sme::CombiningKind kind = op.getKind(); |
156 | if (kind == arm_sme::CombiningKind::Add) { |
157 | TypeSwitch<Operation *>(extOp) |
158 | .Case<arith::ExtFOp>([&](auto) { |
159 | rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>( |
160 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
161 | op1.getAcc()); |
162 | }) |
163 | .Case<arith::ExtSIOp>([&](auto) { |
164 | rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>( |
165 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
166 | op1.getAcc()); |
167 | }) |
168 | .Case<arith::ExtUIOp>([&](auto) { |
169 | rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>( |
170 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
171 | op1.getAcc()); |
172 | }) |
173 | .Default([&](auto) { llvm_unreachable("unexpected extend op!"); }); |
174 | } else if (kind == arm_sme::CombiningKind::Sub) { |
175 | TypeSwitch<Operation *>(extOp) |
176 | .Case<arith::ExtFOp>([&](auto) { |
177 | rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>( |
178 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
179 | op1.getAcc()); |
180 | }) |
181 | .Case<arith::ExtSIOp>([&](auto) { |
182 | rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>( |
183 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
184 | op1.getAcc()); |
185 | }) |
186 | .Case<arith::ExtUIOp>([&](auto) { |
187 | rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>( |
188 | op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, |
189 | op1.getAcc()); |
190 | }) |
191 | .Default([&](auto) { llvm_unreachable("unexpected extend op!"); }); |
192 | } else { |
193 | llvm_unreachable("unexpected arm_sme::CombiningKind!"); |
194 | } |
195 | |
196 | return success(); |
197 | } |
198 | |
199 | private: |
200 | // A pair of outer product can be fused if all of the following are true: |
201 | // - input and result types match. |
202 | // - the defining operations of the inputs are identical extensions, |
203 | // specifically either: |
204 | // - a signed or unsigned extension for integer types. |
205 | // - a floating-point extension for floating-point types. |
206 | // - the types and extension are supported, i.e. there's a 2-way operation |
207 | // they can be fused into. |
208 | LogicalResult canFuseOuterProducts(PatternRewriter &rewriter, |
209 | arm_sme::OuterProductOp op1, |
210 | arm_sme::OuterProductOp op2) const { |
211 | // Supported result types. |
212 | auto nxnxv4i32 = |
213 | VectorType::get({4, 4}, rewriter.getI32Type(), {true, true}); |
214 | auto nxnxv4f32 = |
215 | VectorType::get({4, 4}, rewriter.getF32Type(), {true, true}); |
216 | // Supported input types. |
217 | // Note: this is before packing so these have half the number of elements |
218 | // of the input vector types of the 2-way operations. |
219 | auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true); |
220 | auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true); |
221 | auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true); |
222 | if ((failed( |
223 | isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) || |
224 | failed( |
225 | isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) && |
226 | (failed( |
227 | isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) || |
228 | failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, |
229 | nxv4bf16))) && |
230 | (failed( |
231 | isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) || |
232 | failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, |
233 | nxv4i16))) && |
234 | (failed( |
235 | isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) || |
236 | failed( |
237 | isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16)))) |
238 | return failure(); |
239 | |
240 | return success(); |
241 | } |
242 | }; |
243 | |
244 | // Fuse four 'arm_sme.outerproduct' operations that are chained via the |
245 | // accumulator into 4-way outer product operation. |
246 | class OuterProductFusion4Way |
247 | : public OpRewritePattern<arm_sme::OuterProductOp> { |
248 | public: |
249 | using OpRewritePattern::OpRewritePattern; |
250 | |
251 | LogicalResult matchAndRewrite(arm_sme::OuterProductOp op, |
252 | PatternRewriter &rewriter) const override { |
253 | SmallVector<arm_sme::OuterProductOp, 4> outerProductChain; |
254 | outerProductChain.push_back(op); |
255 | |
256 | for (int i = 0; i < 3; ++i) { |
257 | auto currentOp = outerProductChain.back(); |
258 | auto acc = currentOp.getAcc(); |
259 | if (!acc) |
260 | return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator); |
261 | auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>(); |
262 | if (!previousOp) |
263 | return rewriter.notifyMatchFailure( |
264 | op, kMatchFailureExpectedOuterProductDefOp); |
265 | if (!previousOp->hasOneUse()) |
266 | return rewriter.notifyMatchFailure( |
267 | op, kMatchFailureOuterProductNotSingleUse); |
268 | if (previousOp.getKind() != currentOp.getKind()) |
269 | return rewriter.notifyMatchFailure( |
270 | op, kMatchFailureInconsistentCombiningKind); |
271 | if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask())) |
272 | return rewriter.notifyMatchFailure( |
273 | op, kMatchFailureInconsistentCombiningKind); |
274 | outerProductChain.push_back(previousOp); |
275 | } |
276 | |
277 | if (failed(canFuseOuterProducts(rewriter, outerProductChain))) |
278 | return failure(); |
279 | |
280 | arm_sme::OuterProductOp op1 = outerProductChain[3]; |
281 | arm_sme::OuterProductOp op2 = outerProductChain[2]; |
282 | arm_sme::OuterProductOp op3 = outerProductChain[1]; |
283 | arm_sme::OuterProductOp op4 = outerProductChain[0]; |
284 | |
285 | auto loc = op.getLoc(); |
286 | auto packInputs = [&](Value lhs, Value rhs) { |
287 | return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs); |
288 | }; |
289 | |
290 | auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), |
291 | op3.getLhs().getDefiningOp()->getOperand(0)); |
292 | auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0), |
293 | op4.getLhs().getDefiningOp()->getOperand(0)); |
294 | auto lhs = packInputs(lhs0, lhs1); |
295 | |
296 | auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0), |
297 | op3.getRhs().getDefiningOp()->getOperand(0)); |
298 | auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0), |
299 | op4.getRhs().getDefiningOp()->getOperand(0)); |
300 | auto rhs = packInputs(rhs0, rhs1); |
301 | |
302 | Value lhsMask, rhsMask; |
303 | if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() || |
304 | op4.getLhsMask()) { |
305 | auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask()); |
306 | auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask()); |
307 | lhsMask = packInputs(lhs0Mask, lhs1Mask); |
308 | |
309 | auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask()); |
310 | auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask()); |
311 | rhsMask = packInputs(rhs0Mask, rhs1Mask); |
312 | } |
313 | |
314 | auto lhsExtOp = op.getLhs().getDefiningOp(); |
315 | auto rhsExtOp = op.getRhs().getDefiningOp(); |
316 | |
317 | arm_sme::CombiningKind kind = op.getKind(); |
318 | if (kind == arm_sme::CombiningKind::Add) { |
319 | if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) { |
320 | // signed |
321 | rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>( |
322 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
323 | } else if (isa<arith::ExtUIOp>(lhsExtOp) && |
324 | isa<arith::ExtUIOp>(rhsExtOp)) { |
325 | // unsigned |
326 | rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>( |
327 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
328 | } else if (isa<arith::ExtSIOp>(lhsExtOp) && |
329 | isa<arith::ExtUIOp>(rhsExtOp)) { |
330 | // signed by unsigned |
331 | rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>( |
332 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
333 | } else if (isa<arith::ExtUIOp>(lhsExtOp) && |
334 | isa<arith::ExtSIOp>(rhsExtOp)) { |
335 | // unsigned by signed |
336 | rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>( |
337 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
338 | } else { |
339 | llvm_unreachable("unexpected extend op!"); |
340 | } |
341 | } else if (kind == arm_sme::CombiningKind::Sub) { |
342 | if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) { |
343 | // signed |
344 | rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>( |
345 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
346 | } else if (isa<arith::ExtUIOp>(lhsExtOp) && |
347 | isa<arith::ExtUIOp>(rhsExtOp)) { |
348 | // unsigned |
349 | rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>( |
350 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
351 | } else if (isa<arith::ExtSIOp>(lhsExtOp) && |
352 | isa<arith::ExtUIOp>(rhsExtOp)) { |
353 | // signed by unsigned |
354 | rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>( |
355 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
356 | } else if (isa<arith::ExtUIOp>(lhsExtOp) && |
357 | isa<arith::ExtSIOp>(rhsExtOp)) { |
358 | // unsigned by signed |
359 | rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>( |
360 | op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); |
361 | } else { |
362 | llvm_unreachable("unexpected extend op!"); |
363 | } |
364 | } else { |
365 | llvm_unreachable("unexpected arm_sme::CombiningKind!"); |
366 | } |
367 | |
368 | return success(); |
369 | } |
370 | |
371 | private: |
372 | // Four outer products can be fused if all of the following are true: |
373 | // - input and result types match. |
374 | // - the defining operations of the inputs are identical extensions, |
375 | // specifically either: |
376 | // - a signed or unsigned extension for integer types. |
377 | // - a floating-point extension for floating-point types. |
378 | // - the types and extension are supported, i.e. there's a 4-way operation |
379 | // they can be fused into. |
380 | LogicalResult |
381 | canFuseOuterProducts(PatternRewriter &rewriter, |
382 | ArrayRef<arm_sme::OuterProductOp> ops) const { |
383 | // Supported result types. |
384 | auto nxnxv4i32 = |
385 | VectorType::get({4, 4}, rewriter.getI32Type(), {true, true}); |
386 | auto nxnxv2i64 = |
387 | VectorType::get({2, 2}, rewriter.getI64Type(), {true, true}); |
388 | |
389 | // Supported input types. |
390 | // Note: this is before packing so these have 1/4 the number of elements |
391 | // of the input vector types of the 4-way operations. |
392 | auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true); |
393 | auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true); |
394 | |
395 | auto failedToMatch = [&](VectorType resultType, VectorType inputType, |
396 | auto lhsExtendOp, auto rhsExtendOp) { |
397 | using LhsExtendOpTy = decltype(lhsExtendOp); |
398 | using RhsExtendOpTy = decltype(rhsExtendOp); |
399 | for (auto op : ops) { |
400 | if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>( |
401 | rewriter, op, resultType, inputType))) |
402 | return true; |
403 | } |
404 | return false; |
405 | }; |
406 | |
407 | if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) && |
408 | failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) && |
409 | failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) && |
410 | failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) && |
411 | failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) && |
412 | failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) && |
413 | failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) && |
414 | failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{})) |
415 | return failure(); |
416 | |
417 | return success(); |
418 | } |
419 | }; |
420 | |
421 | // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract). |
422 | // |
423 | // This transforms IR like: |
424 | // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32> |
425 | // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32> |
426 | // Into: |
427 | // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8> |
428 | // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32> |
429 | // |
430 | // This enables outer product fusion in the `-arm-sme-outer-product-fusion` |
431 | // pass when the result is the input to an outer product. |
432 | struct SwapVectorExtractOfArithExtend |
433 | : public OpRewritePattern<vector::ExtractOp> { |
434 | using OpRewritePattern::OpRewritePattern; |
435 | |
436 | LogicalResult matchAndRewrite(vector::ExtractOp extractOp, |
437 | PatternRewriter &rewriter) const override { |
438 | VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType()); |
439 | if (!resultType) |
440 | return rewriter.notifyMatchFailure(extractOp, |
441 | "extracted type is not a vector type"); |
442 | |
443 | auto numScalableDims = resultType.getNumScalableDims(); |
444 | if (numScalableDims != 1) |
445 | return rewriter.notifyMatchFailure( |
446 | extractOp, "extracted type is not a 1-D scalable vector type"); |
447 | |
448 | auto *extendOp = extractOp.getVector().getDefiningOp(); |
449 | if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>( |
450 | extendOp)) |
451 | return rewriter.notifyMatchFailure(extractOp, |
452 | "extract not from extend op"); |
453 | |
454 | auto loc = extractOp.getLoc(); |
455 | StringAttr extendOpName = extendOp->getName().getIdentifier(); |
456 | Value extendSource = extendOp->getOperand(0); |
457 | |
458 | // Create new extract from source of extend. |
459 | Value newExtract = rewriter.create<vector::ExtractOp>( |
460 | loc, extendSource, extractOp.getMixedPosition()); |
461 | |
462 | // Extend new extract to original result type. |
463 | Operation *newExtend = |
464 | rewriter.create(loc, extendOpName, Value(newExtract), resultType); |
465 | |
466 | rewriter.replaceOp(extractOp, newExtend); |
467 | |
468 | return success(); |
469 | } |
470 | }; |
471 | |
472 | // Same as above, but for vector.scalable.extract. |
473 | // |
474 | // This transforms IR like: |
475 | // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32> |
476 | // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32> |
477 | // Into: |
478 | // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8> |
479 | // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32> |
480 | // |
481 | // This enables outer product fusion in the `-arm-sme-outer-product-fusion` |
482 | // pass when the result is the input to an outer product. |
483 | struct SwapVectorScalableExtractOfArithExtend |
484 | : public OpRewritePattern<vector::ScalableExtractOp> { |
485 | using OpRewritePattern::OpRewritePattern; |
486 | |
487 | LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp, |
488 | PatternRewriter &rewriter) const override { |
489 | auto *extendOp = extractOp.getSource().getDefiningOp(); |
490 | if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>( |
491 | extendOp)) |
492 | return rewriter.notifyMatchFailure(extractOp, |
493 | "extract not from extend op"); |
494 | |
495 | auto loc = extractOp.getLoc(); |
496 | VectorType resultType = extractOp.getResultVectorType(); |
497 | |
498 | Value extendSource = extendOp->getOperand(0); |
499 | StringAttr extendOpName = extendOp->getName().getIdentifier(); |
500 | VectorType extendSourceVectorType = |
501 | cast<VectorType>(extendSource.getType()); |
502 | |
503 | // Create new extract from source of extend. |
504 | VectorType extractResultVectorType = |
505 | resultType.clone(extendSourceVectorType.getElementType()); |
506 | Value newExtract = rewriter.create<vector::ScalableExtractOp>( |
507 | loc, extractResultVectorType, extendSource, extractOp.getPos()); |
508 | |
509 | // Extend new extract to original result type. |
510 | Operation *newExtend = |
511 | rewriter.create(loc, extendOpName, Value(newExtract), resultType); |
512 | |
513 | rewriter.replaceOp(extractOp, newExtend); |
514 | |
515 | return success(); |
516 | } |
517 | }; |
518 | |
519 | struct OuterProductFusionPass |
520 | : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> { |
521 | |
522 | void runOnOperation() override { |
523 | RewritePatternSet patterns(&getContext()); |
524 | populateOuterProductFusionPatterns(patterns); |
525 | |
526 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
527 | signalPassFailure(); |
528 | } |
529 | }; |
530 | |
531 | } // namespace |
532 | |
533 | void mlir::arm_sme::populateOuterProductFusionPatterns( |
534 | RewritePatternSet &patterns) { |
535 | MLIRContext *context = patterns.getContext(); |
536 | // Note: High benefit to ensure extract(extend) are swapped first. |
537 | patterns.add<SwapVectorExtractOfArithExtend, |
538 | SwapVectorScalableExtractOfArithExtend>(arg&: context, args: 1024); |
539 | patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(arg&: context); |
540 | } |
541 | |
542 | std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() { |
543 | return std::make_unique<OuterProductFusionPass>(); |
544 | } |
545 |
Definitions
- kMatchFailureNoAccumulator
- kMatchFailureExpectedOuterProductDefOp
- kMatchFailureInconsistentCombiningKind
- kMatchFailureInconsistentMasking
- kMatchFailureOuterProductNotSingleUse
- isCompatible
- OuterProductFusion2Way
- matchAndRewrite
- canFuseOuterProducts
- OuterProductFusion4Way
- matchAndRewrite
- canFuseOuterProducts
- SwapVectorExtractOfArithExtend
- matchAndRewrite
- SwapVectorScalableExtractOfArithExtend
- matchAndRewrite
- OuterProductFusionPass
- runOnOperation
- populateOuterProductFusionPatterns
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more