1//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the ArmSME dialect.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/ArmSME/Utils/Utils.h"

#include "llvm/ADT/SetVector.h"
14
15namespace mlir::arm_sme {
16
17unsigned getSMETileSliceMinNumElts(Type type) {
18 assert(isValidSMETileElementType(type) && "invalid tile type!");
19 return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
20}
21
22bool isValidSMETileElementType(Type type) {
23 return type.isInteger(width: 8) || type.isInteger(width: 16) || type.isInteger(width: 32) ||
24 type.isInteger(width: 64) || type.isInteger(width: 128) || type.isF16() ||
25 type.isBF16() || type.isF32() || type.isF64() || type.isF128();
26}
27
28bool isValidSMETileVectorType(VectorType vType) {
29 if ((vType.getRank() != 2) || !vType.allDimsScalable())
30 return false;
31
32 auto elemType = vType.getElementType();
33 if (!isValidSMETileElementType(type: elemType))
34 return false;
35
36 unsigned minNumElts = getSMETileSliceMinNumElts(type: elemType);
37 if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
38 return false;
39
40 return true;
41}
42
43std::optional<ArmSMETileType> getSMETileType(VectorType type) {
44 if (!isValidSMETileVectorType(vType: type))
45 return {};
46 switch (type.getElementTypeBitWidth()) {
47 case 8:
48 return ArmSMETileType::ZAB;
49 case 16:
50 return ArmSMETileType::ZAH;
51 case 32:
52 return ArmSMETileType::ZAS;
53 case 64:
54 return ArmSMETileType::ZAD;
55 case 128:
56 return ArmSMETileType::ZAQ;
57 default:
58 llvm_unreachable("unknown SME tile type");
59 }
60}
61
62LogicalResult verifyOperationHasValidTileId(Operation *op) {
63 auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(Val: op);
64 if (!tileOp)
65 return success(); // Not a tile op (no need to check).
66 auto tileId = tileOp.getTileId();
67 if (!tileId)
68 return success(); // Not having a tile ID (yet) is okay.
69 if (!tileId.getType().isSignlessInteger(width: 32))
70 return tileOp.emitOpError(message: "tile ID should be a 32-bit signless integer");
71 return success();
72}
73
74scf::ForOp createLoopOverTileSlices(
75 PatternRewriter &rewriter, Location loc, Value initTile,
76 std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
77 OpBuilder::InsertionGuard g(rewriter);
78 auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
79 auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
80 location: loc, args: llvm::cast<VectorType>(Val: initTile.getType()).getDimSize(idx: 0));
81 auto vscale =
82 rewriter.create<vector::VectorScaleOp>(location: loc, args: rewriter.getIndexType());
83 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
84 auto numTileSlices =
85 rewriter.create<arith::MulIOp>(location: loc, args&: minTileSlices, args&: vscale);
86 auto forOp = rewriter.create<scf::ForOp>(location: loc, args&: lowerBound, args&: numTileSlices, args&: step,
87 args: ValueRange{initTile});
88 rewriter.setInsertionPointToStart(forOp.getBody());
89 Value nextTile =
90 makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
91 /*currentTile=*/forOp.getRegionIterArg(index: 0));
92 rewriter.create<scf::YieldOp>(location: loc, args&: nextTile);
93 return forOp;
94}
95
96bool isMultipleOfSMETileVectorType(VectorType vType) {
97 if (vType.getRank() != 2 || !vType.allDimsScalable())
98 return false;
99
100 auto elementType = vType.getElementType();
101 if (!isValidSMETileElementType(type: elementType))
102 return false;
103
104 unsigned minNumElts = getSMETileSliceMinNumElts(type: elementType);
105
106 int64_t vectorRows = vType.getDimSize(idx: 0);
107 int64_t vectorCols = vType.getDimSize(idx: 1);
108
109 return (vectorRows > minNumElts || vectorCols > minNumElts) &&
110 vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
111}
112
113VectorType getSMETileTypeForElement(Type elementType) {
114 unsigned minNumElts = getSMETileSliceMinNumElts(type: elementType);
115 return VectorType::get(shape: {minNumElts, minNumElts}, elementType, scalableDims: {true, true});
116}
117
118void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
119 FunctionOpInterface function) {
120 SmallVector<Operation *> worklist;
121 function->walk(callback: [&](Operation *op) {
122 auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(Val: op);
123 if (armSMEOp && isOpTriviallyDead(op: armSMEOp))
124 worklist.push_back(Elt: armSMEOp);
125 });
126 while (!worklist.empty()) {
127 Operation *op = worklist.pop_back_val();
128 if (!isOpTriviallyDead(op))
129 continue;
130 for (Value value : op->getOperands()) {
131 if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
132 worklist.push_back(Elt: armSMEOp);
133 }
134 rewriter.eraseOp(op);
135 }
136}
137
138bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
139 return tileOp && tileOp->getNumResults() == 1 &&
140 tileOp->getNumOperands() == 0 && isPure(op: tileOp);
141}
142
143bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
144 for (Value result : tileOp->getResults()) {
145 if (arm_sme::isValidSMETileVectorType(type: result.getType()))
146 return true;
147 }
148 return false;
149}
150
151OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
152 if (!tileOp)
153 return nullptr;
154 auto isTileOperandType = [](OpOperand &operand) {
155 return arm_sme::isValidSMETileVectorType(type: operand.get().getType());
156 };
157 assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
158 "expected at most one tile operand");
159 OpOperand *tileOperand =
160 llvm::find_if(Range: tileOp->getOpOperands(), P: isTileOperandType);
161 if (tileOperand == tileOp->getOpOperands().end())
162 return nullptr;
163 return tileOperand;
164}
165
166bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
167 // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
168 return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
169}
170
171} // namespace mlir::arm_sme
172

source code of mlir/lib/Dialect/ArmSME/IR/Utils.cpp