1 | //===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE |
10 | // (SSVE) mode [1][2] by adding either of the following attributes to |
11 | // 'func.func' ops: |
12 | // |
13 | // * 'arm_streaming' (default) |
14 | // * 'arm_locally_streaming' |
15 | // |
16 | // It can also optionally enable the ZA storage array. |
17 | // |
18 | // Streaming-mode is part of the interface (ABI) for functions with the |
19 | // first attribute and it's the responsibility of the caller to manage |
20 | // PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM |
21 | // backend will emit 'smstart sm' / 'smstop sm' [4] around calls to |
22 | // streaming functions. |
23 | // |
24 | // In locally streaming functions PSTATE.SM is kept internal and managed by |
25 | // the callee on entry/exit. The LLVM backend will emit 'smstart sm' / |
26 | // 'smstop sm' in the prologue / epilogue for functions with this |
27 | // attribute. |
28 | // |
29 | // [1] https://developer.arm.com/documentation/ddi0616/aa |
30 | // [2] https://llvm.org/docs/AArch64SME.html |
31 | // [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces |
32 | // [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate-- |
33 | // |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
37 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
38 | #include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc" |
39 | |
40 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
41 | |
42 | #define DEBUG_TYPE "enable-arm-streaming" |
43 | |
44 | namespace mlir { |
45 | namespace arm_sme { |
46 | #define GEN_PASS_DEF_ENABLEARMSTREAMING |
47 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
48 | } // namespace arm_sme |
49 | } // namespace mlir |
50 | |
51 | using namespace mlir; |
52 | using namespace mlir::arm_sme; |
53 | namespace { |
54 | |
55 | constexpr StringLiteral |
56 | kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore" ); |
57 | |
58 | template <typename... Ops> |
59 | constexpr auto opList() { |
60 | return std::array{TypeID::get<Ops>()...}; |
61 | } |
62 | |
63 | bool isScalableVector(Type type) { |
64 | if (auto vectorType = dyn_cast<VectorType>(type)) |
65 | return vectorType.isScalable(); |
66 | return false; |
67 | } |
68 | |
69 | struct EnableArmStreamingPass |
70 | : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> { |
71 | EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode, |
72 | bool ifRequiredByOps, bool ifScalableAndSupported) { |
73 | this->streamingMode = streamingMode; |
74 | this->zaMode = zaMode; |
75 | this->ifRequiredByOps = ifRequiredByOps; |
76 | this->ifScalableAndSupported = ifScalableAndSupported; |
77 | } |
78 | void runOnOperation() override { |
79 | auto function = getOperation(); |
80 | |
81 | if (ifRequiredByOps && ifScalableAndSupported) { |
82 | function->emitOpError( |
83 | "enable-arm-streaming: `if-required-by-ops` and " |
84 | "`if-scalable-and-supported` are mutually exclusive" ); |
85 | return signalPassFailure(); |
86 | } |
87 | |
88 | if (ifRequiredByOps) { |
89 | bool foundTileOp = false; |
90 | function.walk([&](Operation *op) { |
91 | if (llvm::isa<ArmSMETileOpInterface>(op)) { |
92 | foundTileOp = true; |
93 | return WalkResult::interrupt(); |
94 | } |
95 | return WalkResult::advance(); |
96 | }); |
97 | if (!foundTileOp) |
98 | return; |
99 | } |
100 | |
101 | if (ifScalableAndSupported) { |
102 | // FIXME: This should be based on target information (i.e., the presence |
103 | // of FEAT_SME_FA64). This currently errs on the side of caution. If |
104 | // possible gathers/scatters should be lowered regular vector loads/stores |
105 | // before invoking this pass. |
106 | auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>(); |
107 | bool isCompatibleScalableFunction = false; |
108 | function.walk([&](Operation *op) { |
109 | if (llvm::is_contained(disallowedOperations, |
110 | op->getName().getTypeID())) { |
111 | isCompatibleScalableFunction = false; |
112 | return WalkResult::interrupt(); |
113 | } |
114 | if (!isCompatibleScalableFunction && |
115 | (llvm::any_of(op->getOperandTypes(), isScalableVector) || |
116 | llvm::any_of(op->getResultTypes(), isScalableVector))) { |
117 | isCompatibleScalableFunction = true; |
118 | } |
119 | return WalkResult::advance(); |
120 | }); |
121 | if (!isCompatibleScalableFunction) |
122 | return; |
123 | } |
124 | |
125 | if (function->getAttr(kEnableArmStreamingIgnoreAttr) || |
126 | streamingMode == ArmStreamingMode::Disabled) |
127 | return; |
128 | |
129 | auto unitAttr = UnitAttr::get(&getContext()); |
130 | |
131 | function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr); |
132 | |
133 | // The pass currently only supports enabling ZA when in streaming-mode, but |
134 | // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in |
135 | // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth |
136 | // supporting this later. |
137 | if (zaMode != ArmZaMode::Disabled) |
138 | function->setAttr(stringifyArmZaMode(zaMode), unitAttr); |
139 | } |
140 | }; |
141 | } // namespace |
142 | |
143 | std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass( |
144 | const ArmStreamingMode streamingMode, const ArmZaMode zaMode, |
145 | bool ifRequiredByOps, bool ifScalableAndSupported) { |
146 | return std::make_unique<EnableArmStreamingPass>( |
147 | streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported); |
148 | } |
149 | |