1 | //===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for the ArmSME dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "llvm/ADT/SetVector.h"
15 | |
16 | namespace mlir::arm_sme { |
17 | |
18 | unsigned getSMETileSliceMinNumElts(Type type) { |
19 | assert(isValidSMETileElementType(type) && "invalid tile type!" ); |
20 | return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth(); |
21 | } |
22 | |
23 | bool isValidSMETileElementType(Type type) { |
24 | return type.isInteger(width: 8) || type.isInteger(width: 16) || type.isInteger(width: 32) || |
25 | type.isInteger(width: 64) || type.isInteger(width: 128) || type.isF16() || |
26 | type.isBF16() || type.isF32() || type.isF64() || type.isF128(); |
27 | } |
28 | |
29 | bool isValidSMETileVectorType(VectorType vType) { |
30 | if ((vType.getRank() != 2) || !vType.allDimsScalable()) |
31 | return false; |
32 | |
33 | auto elemType = vType.getElementType(); |
34 | if (!isValidSMETileElementType(elemType)) |
35 | return false; |
36 | |
37 | unsigned minNumElts = getSMETileSliceMinNumElts(elemType); |
38 | if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts})) |
39 | return false; |
40 | |
41 | return true; |
42 | } |
43 | |
44 | std::optional<ArmSMETileType> getSMETileType(VectorType type) { |
45 | if (!isValidSMETileVectorType(type)) |
46 | return {}; |
47 | switch (type.getElementTypeBitWidth()) { |
48 | case 8: |
49 | return ArmSMETileType::ZAB; |
50 | case 16: |
51 | return ArmSMETileType::ZAH; |
52 | case 32: |
53 | return ArmSMETileType::ZAS; |
54 | case 64: |
55 | return ArmSMETileType::ZAD; |
56 | case 128: |
57 | return ArmSMETileType::ZAQ; |
58 | default: |
59 | llvm_unreachable("unknown SME tile type" ); |
60 | } |
61 | } |
62 | |
63 | LogicalResult verifyOperationHasValidTileId(Operation *op) { |
64 | auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op); |
65 | if (!tileOp) |
66 | return success(); // Not a tile op (no need to check). |
67 | auto tileId = tileOp.getTileId(); |
68 | if (!tileId) |
69 | return success(); // Not having a tile ID (yet) is okay. |
70 | if (!tileId.getType().isSignlessInteger(32)) |
71 | return tileOp.emitOpError("tile ID should be a 32-bit signless integer" ); |
72 | return success(); |
73 | } |
74 | |
75 | scf::ForOp createLoopOverTileSlices( |
76 | PatternRewriter &rewriter, Location loc, Value initTile, |
77 | std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) { |
78 | OpBuilder::InsertionGuard g(rewriter); |
79 | auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
80 | auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( |
81 | loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0)); |
82 | auto vscale = |
83 | rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); |
84 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
85 | auto numTileSlices = |
86 | rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); |
87 | auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step, |
88 | ValueRange{initTile}); |
89 | rewriter.setInsertionPointToStart(forOp.getBody()); |
90 | Value nextTile = |
91 | makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(), |
92 | /*currentTile=*/forOp.getRegionIterArg(0)); |
93 | rewriter.create<scf::YieldOp>(loc, nextTile); |
94 | return forOp; |
95 | } |
96 | |
97 | bool isMultipleOfSMETileVectorType(VectorType vType) { |
98 | if (vType.getRank() != 2 || !vType.allDimsScalable()) |
99 | return false; |
100 | |
101 | auto elementType = vType.getElementType(); |
102 | if (!isValidSMETileElementType(elementType)) |
103 | return false; |
104 | |
105 | unsigned minNumElts = getSMETileSliceMinNumElts(elementType); |
106 | |
107 | int64_t vectorRows = vType.getDimSize(0); |
108 | int64_t vectorCols = vType.getDimSize(1); |
109 | |
110 | return (vectorRows > minNumElts || vectorCols > minNumElts) && |
111 | vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0; |
112 | } |
113 | |
114 | VectorType getSMETileTypeForElement(Type elementType) { |
115 | unsigned minNumElts = getSMETileSliceMinNumElts(type: elementType); |
116 | return VectorType::get({minNumElts, minNumElts}, elementType, {true, true}); |
117 | } |
118 | |
119 | void eraseTriviallyDeadTileOps(IRRewriter &rewriter, |
120 | FunctionOpInterface function) { |
121 | SmallVector<Operation *> worklist; |
122 | function->walk([&](Operation *op) { |
123 | auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op); |
124 | if (armSMEOp && isOpTriviallyDead(armSMEOp)) |
125 | worklist.push_back(Elt: armSMEOp); |
126 | }); |
127 | while (!worklist.empty()) { |
128 | Operation *op = worklist.pop_back_val(); |
129 | if (!isOpTriviallyDead(op)) |
130 | continue; |
131 | for (Value value : op->getOperands()) { |
132 | if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>()) |
133 | worklist.push_back(Elt: armSMEOp); |
134 | } |
135 | rewriter.eraseOp(op); |
136 | } |
137 | } |
138 | |
139 | bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) { |
140 | return tileOp && tileOp->getNumResults() == 1 && |
141 | tileOp->getNumOperands() == 0 && isPure(tileOp); |
142 | } |
143 | |
144 | bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) { |
145 | for (Value result : tileOp->getResults()) { |
146 | if (arm_sme::isValidSMETileVectorType(result.getType())) |
147 | return true; |
148 | } |
149 | return false; |
150 | } |
151 | |
152 | OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) { |
153 | if (!tileOp) |
154 | return nullptr; |
155 | auto isTileOperandType = [](OpOperand &operand) { |
156 | return arm_sme::isValidSMETileVectorType(type: operand.get().getType()); |
157 | }; |
158 | assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 && |
159 | "expected at most one tile operand" ); |
160 | OpOperand *tileOperand = |
161 | llvm::find_if(tileOp->getOpOperands(), isTileOperandType); |
162 | if (tileOperand == tileOp->getOpOperands().end()) |
163 | return nullptr; |
164 | return tileOperand; |
165 | } |
166 | |
167 | bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) { |
168 | // Note: This is <= due to how tile types are numbered in ArmSMEOps.td. |
169 | return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB); |
170 | } |
171 | |
172 | } // namespace mlir::arm_sme |
173 | |