1 | //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" |
10 | |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
13 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
16 | |
17 | namespace mlir { |
18 | #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS |
19 | #include "mlir/Conversion/Passes.h.inc" |
20 | } // namespace mlir |
21 | |
22 | #define DEBUG_TYPE "arith-to-arm-sme" |
23 | |
24 | using namespace mlir; |
25 | |
26 | //===----------------------------------------------------------------------===// |
27 | // Conversion helpers |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | /// Returns true if 'val' is a splat of zero, false otherwise. |
31 | static bool isSplatZero(Type elemType, DenseElementsAttr val) { |
32 | if (llvm::isa<FloatType>(elemType)) |
33 | return val && val.isSplat() && val.getSplatValue<APFloat>().isZero(); |
34 | if (llvm::isa<IntegerType>(elemType)) |
35 | return val && val.isSplat() && val.getSplatValue<APInt>().isZero(); |
36 | return false; |
37 | } |
38 | |
39 | namespace { |
40 | |
41 | //===----------------------------------------------------------------------===// |
42 | // ConstantOp |
43 | //===----------------------------------------------------------------------===// |
44 | |
45 | /// Conversion pattern for dense arith.constant. |
46 | struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { |
47 | using OpRewritePattern<arith::ConstantOp>::OpRewritePattern; |
48 | |
49 | LogicalResult matchAndRewrite(arith::ConstantOp constantOp, |
50 | PatternRewriter &rewriter) const final { |
51 | auto tileType = dyn_cast<VectorType>(constantOp.getType()); |
52 | if (!tileType || !arm_sme::isValidSMETileVectorType(vType: tileType)) |
53 | return failure(); |
54 | |
55 | auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); |
56 | if (!denseAttr || !denseAttr.isSplat()) |
57 | return failure(); |
58 | |
59 | auto tileElementType = tileType.getElementType(); |
60 | |
61 | // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op. |
62 | if (isSplatZero(tileElementType, denseAttr)) { |
63 | rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType); |
64 | return success(); |
65 | } |
66 | |
67 | // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice' |
68 | // ops that broadcast the constant to each tile slice. |
69 | auto loc = constantOp.getLoc(); |
70 | |
71 | // To fill a tile with a constant, we create a 1-D splat of the constant, |
72 | // then move that into each tile slice (the largest unit we can set at once, |
73 | // outside of operations like the outerproduct). |
74 | VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
75 | auto denseAttr1D = DenseElementsAttr::get( |
76 | tileSliceType, denseAttr.getSplatValue<Attribute>()); |
77 | auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D); |
78 | |
79 | auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); |
80 | auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, |
81 | Value currentTile) { |
82 | // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile |
83 | // slice. |
84 | auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>( |
85 | loc, tileType, constantOp1D, currentTile, tileSliceIndex); |
86 | return nextTile.getResult(); |
87 | }; |
88 | auto forOp = mlir::arm_sme::createLoopOverTileSlices( |
89 | rewriter, loc, initTile, makeLoopBody); |
90 | rewriter.replaceOp(constantOp, forOp.getResult(0)); |
91 | |
92 | return success(); |
93 | } |
94 | }; |
95 | |
96 | } // namespace |
97 | |
98 | //===----------------------------------------------------------------------===// |
99 | // Pattern population |
100 | //===----------------------------------------------------------------------===// |
101 | |
102 | void mlir::arith::populateArithToArmSMEConversionPatterns( |
103 | RewritePatternSet &patterns) { |
104 | patterns.add<ConstantOpToArmSMELowering>(arg: patterns.getContext()); |
105 | } |
106 | |
107 | //===----------------------------------------------------------------------===// |
108 | // Pass definition |
109 | //===----------------------------------------------------------------------===// |
110 | |
111 | namespace { |
112 | struct ArithToArmSMEConversionPass final |
113 | : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> { |
114 | using impl::ArithToArmSMEConversionPassBase< |
115 | ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase; |
116 | |
117 | void runOnOperation() override { |
118 | RewritePatternSet patterns(&getContext()); |
119 | arith::populateArithToArmSMEConversionPatterns(patterns); |
120 | if (failed( |
121 | applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) |
122 | return signalPassFailure(); |
123 | } |
124 | }; |
125 | } // namespace |
126 | |