//===- TestLowerToArmSME.cpp - Test lowering to ArmSME as a sink pass -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass pipeline for testing the lowering to ArmSME as
// a generally usable sink pipeline.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
15#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
16#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
17#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
18#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
19#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
20#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Pass/PassManager.h"
24#include "mlir/Pass/PassOptions.h"
25#include "mlir/Transforms/Passes.h"
26
27using namespace mlir;
28
29namespace {
30struct TestLowerToArmSMEOptions
31 : public PassPipelineOptions<TestLowerToArmSMEOptions> {
32 PassOptions::Option<bool> fuseOuterProducts{
33 *this, "fuse-outer-products",
34 llvm::cl::desc("Fuse outer product operations via "
35 "'-arm-sme-outer-product-fusion' pass"),
36 llvm::cl::init(Val: true)};
37};
38
39void buildTestLowerToArmSME(OpPassManager &pm,
40 const TestLowerToArmSMEOptions &options) {
41 // Legalize vector operations so they can be converted to ArmSME.
42 pm.addPass(arm_sme::pass: createVectorLegalizationPass());
43
44 // Sprinkle some cleanups.
45 pm.addPass(pass: createCanonicalizerPass());
46 pm.addPass(pass: createCSEPass());
47
48 // Passes that convert operations on vectors to ArmSME operations.
49
50 // Convert Arith to ArmSME.
51 pm.addPass(pass: createArithToArmSMEConversionPass());
52 // Convert Vector to ArmSME.
53 pm.addPass(pass: createConvertVectorToArmSMEPass());
54
55 // Fuse outer products.
56 if (options.fuseOuterProducts)
57 pm.addPass(arm_sme::pass: createOuterProductFusionPass());
58
59 // Convert operations on high-level vectors to loops.
60
61 // Convert ArmSME to SCF.
62 pm.addPass(pass: createConvertArmSMEToSCFPass());
63
64 // Convert Vector to SCF (with full unroll enabled).
65 pm.addPass(createConvertVectorToSCFPass(
66 VectorTransferToSCFOptions().enableFullUnroll()));
67
68 // Allocate tiles for ArmSME operations.
69 //
70 // Later passes may create further ArmSME ops that implement the
71 // ArmSMETileOpInterface, but tiles are allocated for root operations,
72 // all of which should now exist.
73 pm.addPass(arm_sme::pass: createTileAllocationPass());
74
75 // Enable streaming-mode and ZA.
76 pm.addPass(arm_sme::createEnableArmStreamingPass(
77 arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
78 /*onlyIfRequiredByOps=*/true));
79
80 // Convert ArmSME to LLVM.
81 pm.addPass(pass: createConvertArmSMEToLLVMPass());
82
83 // Sprinkle some cleanups.
84 pm.addPass(pass: createCanonicalizerPass());
85 pm.addPass(pass: createCSEPass());
86}
87} // namespace
88
89namespace mlir {
90namespace test {
91void registerTestLowerToArmSME() {
92 PassPipelineRegistration<TestLowerToArmSMEOptions>(
93 "test-lower-to-arm-sme",
94 "An example pipeline to lower operations on vectors (arith, vector) to "
95 "LLVM via ArmSME.",
96 buildTestLowerToArmSME);
97}
98} // namespace test
99} // namespace mlir
// Source: mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp