1//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the ArmSME dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/ArmSME/Utils/Utils.h"
14#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
15
16namespace mlir::arm_sme {
17
18unsigned getSMETileSliceMinNumElts(Type type) {
19 assert(isValidSMETileElementType(type) && "invalid tile type!");
20 return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
21}
22
23bool isValidSMETileElementType(Type type) {
24 return type.isInteger(width: 8) || type.isInteger(width: 16) || type.isInteger(width: 32) ||
25 type.isInteger(width: 64) || type.isInteger(width: 128) || type.isF16() ||
26 type.isBF16() || type.isF32() || type.isF64() || type.isF128();
27}
28
29bool isValidSMETileVectorType(VectorType vType) {
30 if ((vType.getRank() != 2) || !vType.allDimsScalable())
31 return false;
32
33 auto elemType = vType.getElementType();
34 if (!isValidSMETileElementType(elemType))
35 return false;
36
37 unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
38 if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
39 return false;
40
41 return true;
42}
43
44std::optional<ArmSMETileType> getSMETileType(VectorType type) {
45 if (!isValidSMETileVectorType(type))
46 return {};
47 switch (type.getElementTypeBitWidth()) {
48 case 8:
49 return ArmSMETileType::ZAB;
50 case 16:
51 return ArmSMETileType::ZAH;
52 case 32:
53 return ArmSMETileType::ZAS;
54 case 64:
55 return ArmSMETileType::ZAD;
56 case 128:
57 return ArmSMETileType::ZAQ;
58 default:
59 llvm_unreachable("unknown SME tile type");
60 }
61}
62
63LogicalResult verifyOperationHasValidTileId(Operation *op) {
64 auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
65 if (!tileOp)
66 return success(); // Not a tile op (no need to check).
67 auto tileId = tileOp.getTileId();
68 if (!tileId)
69 return success(); // Not having a tile ID (yet) is okay.
70 if (!tileId.getType().isSignlessInteger(32))
71 return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
72 return success();
73}
74
75scf::ForOp createLoopOverTileSlices(
76 PatternRewriter &rewriter, Location loc, Value initTile,
77 std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
78 OpBuilder::InsertionGuard g(rewriter);
79 auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
80 auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
81 loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
82 auto vscale =
83 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
84 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
85 auto numTileSlices =
86 rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
87 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
88 ValueRange{initTile});
89 rewriter.setInsertionPointToStart(forOp.getBody());
90 Value nextTile =
91 makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
92 /*currentTile=*/forOp.getRegionIterArg(0));
93 rewriter.create<scf::YieldOp>(loc, nextTile);
94 return forOp;
95}
96
97bool isMultipleOfSMETileVectorType(VectorType vType) {
98 if (vType.getRank() != 2 || !vType.allDimsScalable())
99 return false;
100
101 auto elementType = vType.getElementType();
102 if (!isValidSMETileElementType(elementType))
103 return false;
104
105 unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
106
107 int64_t vectorRows = vType.getDimSize(0);
108 int64_t vectorCols = vType.getDimSize(1);
109
110 return (vectorRows > minNumElts || vectorCols > minNumElts) &&
111 vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
112}
113
114VectorType getSMETileTypeForElement(Type elementType) {
115 unsigned minNumElts = getSMETileSliceMinNumElts(type: elementType);
116 return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
117}
118
119} // namespace mlir::arm_sme
120

source code of mlir/lib/Dialect/ArmSME/IR/Utils.cpp