| 1 | //===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE |
| 10 | // (SSVE) mode [1][2] by adding either of the following attributes to |
| 11 | // 'func.func' ops: |
| 12 | // |
| 13 | // * 'arm_streaming' (default) |
| 14 | // * 'arm_locally_streaming' |
| 15 | // |
| 16 | // It can also optionally enable the ZA storage array. |
| 17 | // |
| 18 | // Streaming-mode is part of the interface (ABI) for functions with the |
| 19 | // first attribute and it's the responsibility of the caller to manage |
| 20 | // PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM |
| 21 | // backend will emit 'smstart sm' / 'smstop sm' [4] around calls to |
| 22 | // streaming functions. |
| 23 | // |
| 24 | // In locally streaming functions PSTATE.SM is kept internal and managed by |
| 25 | // the callee on entry/exit. The LLVM backend will emit 'smstart sm' / |
| 26 | // 'smstop sm' in the prologue / epilogue for functions with this |
| 27 | // attribute. |
| 28 | // |
| 29 | // [1] https://developer.arm.com/documentation/ddi0616/aa |
| 30 | // [2] https://llvm.org/docs/AArch64SME.html |
| 31 | // [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces |
| 32 | // [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate-- |
| 33 | // |
| 34 | //===----------------------------------------------------------------------===// |
| 35 | |
| 36 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
| 37 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
| 38 | #include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc" |
| 39 | |
| 40 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 41 | |
| 42 | #define DEBUG_TYPE "enable-arm-streaming" |
| 43 | |
| 44 | namespace mlir { |
| 45 | namespace arm_sme { |
| 46 | #define GEN_PASS_DEF_ENABLEARMSTREAMING |
| 47 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
| 48 | } // namespace arm_sme |
| 49 | } // namespace mlir |
| 50 | |
| 51 | using namespace mlir; |
| 52 | using namespace mlir::arm_sme; |
| 53 | namespace { |
| 54 | |
| 55 | constexpr StringLiteral |
| 56 | kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore" ); |
| 57 | |
| 58 | template <typename... Ops> |
| 59 | constexpr auto opList() { |
| 60 | return std::array{TypeID::get<Ops>()...}; |
| 61 | } |
| 62 | |
| 63 | bool isScalableVector(Type type) { |
| 64 | if (auto vectorType = dyn_cast<VectorType>(type)) |
| 65 | return vectorType.isScalable(); |
| 66 | return false; |
| 67 | } |
| 68 | |
| 69 | struct EnableArmStreamingPass |
| 70 | : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> { |
| 71 | EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode, |
| 72 | bool ifRequiredByOps, bool ifScalableAndSupported) { |
| 73 | this->streamingMode = streamingMode; |
| 74 | this->zaMode = zaMode; |
| 75 | this->ifRequiredByOps = ifRequiredByOps; |
| 76 | this->ifScalableAndSupported = ifScalableAndSupported; |
| 77 | } |
| 78 | void runOnOperation() override { |
| 79 | auto function = getOperation(); |
| 80 | |
| 81 | if (ifRequiredByOps && ifScalableAndSupported) { |
| 82 | function->emitOpError( |
| 83 | "enable-arm-streaming: `if-required-by-ops` and " |
| 84 | "`if-scalable-and-supported` are mutually exclusive" ); |
| 85 | return signalPassFailure(); |
| 86 | } |
| 87 | |
| 88 | if (ifRequiredByOps) { |
| 89 | bool foundTileOp = false; |
| 90 | function.walk([&](Operation *op) { |
| 91 | if (llvm::isa<ArmSMETileOpInterface>(op)) { |
| 92 | foundTileOp = true; |
| 93 | return WalkResult::interrupt(); |
| 94 | } |
| 95 | return WalkResult::advance(); |
| 96 | }); |
| 97 | if (!foundTileOp) |
| 98 | return; |
| 99 | } |
| 100 | |
| 101 | if (ifScalableAndSupported) { |
| 102 | // FIXME: This should be based on target information (i.e., the presence |
| 103 | // of FEAT_SME_FA64). This currently errs on the side of caution. If |
| 104 | // possible gathers/scatters should be lowered regular vector loads/stores |
| 105 | // before invoking this pass. |
| 106 | auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>(); |
| 107 | bool isCompatibleScalableFunction = false; |
| 108 | function.walk([&](Operation *op) { |
| 109 | if (llvm::is_contained(disallowedOperations, |
| 110 | op->getName().getTypeID())) { |
| 111 | isCompatibleScalableFunction = false; |
| 112 | return WalkResult::interrupt(); |
| 113 | } |
| 114 | if (!isCompatibleScalableFunction && |
| 115 | (llvm::any_of(op->getOperandTypes(), isScalableVector) || |
| 116 | llvm::any_of(op->getResultTypes(), isScalableVector))) { |
| 117 | isCompatibleScalableFunction = true; |
| 118 | } |
| 119 | return WalkResult::advance(); |
| 120 | }); |
| 121 | if (!isCompatibleScalableFunction) |
| 122 | return; |
| 123 | } |
| 124 | |
| 125 | if (function->getAttr(kEnableArmStreamingIgnoreAttr) || |
| 126 | streamingMode == ArmStreamingMode::Disabled) |
| 127 | return; |
| 128 | |
| 129 | auto unitAttr = UnitAttr::get(&getContext()); |
| 130 | |
| 131 | function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr); |
| 132 | |
| 133 | // The pass currently only supports enabling ZA when in streaming-mode, but |
| 134 | // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in |
| 135 | // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth |
| 136 | // supporting this later. |
| 137 | if (zaMode != ArmZaMode::Disabled) |
| 138 | function->setAttr(stringifyArmZaMode(zaMode), unitAttr); |
| 139 | } |
| 140 | }; |
| 141 | } // namespace |
| 142 | |
| 143 | std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass( |
| 144 | const ArmStreamingMode streamingMode, const ArmZaMode zaMode, |
| 145 | bool ifRequiredByOps, bool ifScalableAndSupported) { |
| 146 | return std::make_unique<EnableArmStreamingPass>( |
| 147 | streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported); |
| 148 | } |
| 149 | |