1 | //===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for the ArmSME dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
14 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
15 | |
16 | namespace mlir::arm_sme { |
17 | |
18 | unsigned getSMETileSliceMinNumElts(Type type) { |
19 | assert(isValidSMETileElementType(type) && "invalid tile type!" ); |
20 | return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth(); |
21 | } |
22 | |
23 | bool isValidSMETileElementType(Type type) { |
24 | return type.isInteger(width: 8) || type.isInteger(width: 16) || type.isInteger(width: 32) || |
25 | type.isInteger(width: 64) || type.isInteger(width: 128) || type.isF16() || |
26 | type.isBF16() || type.isF32() || type.isF64() || type.isF128(); |
27 | } |
28 | |
29 | bool isValidSMETileVectorType(VectorType vType) { |
30 | if ((vType.getRank() != 2) || !vType.allDimsScalable()) |
31 | return false; |
32 | |
33 | auto elemType = vType.getElementType(); |
34 | if (!isValidSMETileElementType(elemType)) |
35 | return false; |
36 | |
37 | unsigned minNumElts = getSMETileSliceMinNumElts(elemType); |
38 | if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts})) |
39 | return false; |
40 | |
41 | return true; |
42 | } |
43 | |
44 | std::optional<ArmSMETileType> getSMETileType(VectorType type) { |
45 | if (!isValidSMETileVectorType(type)) |
46 | return {}; |
47 | switch (type.getElementTypeBitWidth()) { |
48 | case 8: |
49 | return ArmSMETileType::ZAB; |
50 | case 16: |
51 | return ArmSMETileType::ZAH; |
52 | case 32: |
53 | return ArmSMETileType::ZAS; |
54 | case 64: |
55 | return ArmSMETileType::ZAD; |
56 | case 128: |
57 | return ArmSMETileType::ZAQ; |
58 | default: |
59 | llvm_unreachable("unknown SME tile type" ); |
60 | } |
61 | } |
62 | |
63 | LogicalResult verifyOperationHasValidTileId(Operation *op) { |
64 | auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op); |
65 | if (!tileOp) |
66 | return success(); // Not a tile op (no need to check). |
67 | auto tileId = tileOp.getTileId(); |
68 | if (!tileId) |
69 | return success(); // Not having a tile ID (yet) is okay. |
70 | if (!tileId.getType().isSignlessInteger(32)) |
71 | return tileOp.emitOpError("tile ID should be a 32-bit signless integer" ); |
72 | return success(); |
73 | } |
74 | |
75 | scf::ForOp createLoopOverTileSlices( |
76 | PatternRewriter &rewriter, Location loc, Value initTile, |
77 | std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) { |
78 | OpBuilder::InsertionGuard g(rewriter); |
79 | auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
80 | auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( |
81 | loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0)); |
82 | auto vscale = |
83 | rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); |
84 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
85 | auto numTileSlices = |
86 | rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); |
87 | auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step, |
88 | ValueRange{initTile}); |
89 | rewriter.setInsertionPointToStart(forOp.getBody()); |
90 | Value nextTile = |
91 | makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(), |
92 | /*currentTile=*/forOp.getRegionIterArg(0)); |
93 | rewriter.create<scf::YieldOp>(loc, nextTile); |
94 | return forOp; |
95 | } |
96 | |
97 | bool isMultipleOfSMETileVectorType(VectorType vType) { |
98 | if (vType.getRank() != 2 || !vType.allDimsScalable()) |
99 | return false; |
100 | |
101 | auto elementType = vType.getElementType(); |
102 | if (!isValidSMETileElementType(elementType)) |
103 | return false; |
104 | |
105 | unsigned minNumElts = getSMETileSliceMinNumElts(elementType); |
106 | |
107 | int64_t vectorRows = vType.getDimSize(0); |
108 | int64_t vectorCols = vType.getDimSize(1); |
109 | |
110 | return (vectorRows > minNumElts || vectorCols > minNumElts) && |
111 | vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0; |
112 | } |
113 | |
114 | VectorType getSMETileTypeForElement(Type elementType) { |
115 | unsigned minNumElts = getSMETileSliceMinNumElts(type: elementType); |
116 | return VectorType::get({minNumElts, minNumElts}, elementType, {true, true}); |
117 | } |
118 | |
119 | } // namespace mlir::arm_sme |
120 | |