1//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
13#include "mlir/Dialect/ArmSME/Utils/Utils.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16
17namespace mlir {
18#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19#include "mlir/Conversion/Passes.h.inc"
20} // namespace mlir
21
22#define DEBUG_TYPE "arith-to-arm-sme"
23
24using namespace mlir;
25
26//===----------------------------------------------------------------------===//
27// Conversion helpers
28//===----------------------------------------------------------------------===//
29
30/// Returns true if 'val' is a splat of zero, false otherwise.
31static bool isSplatZero(Type elemType, DenseElementsAttr val) {
32 if (llvm::isa<FloatType>(elemType))
33 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
34 if (llvm::isa<IntegerType>(elemType))
35 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
36 return false;
37}
38
39namespace {
40
41//===----------------------------------------------------------------------===//
42// ConstantOp
43//===----------------------------------------------------------------------===//
44
45/// Conversion pattern for dense arith.constant.
46struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
47 using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
48
49 LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50 PatternRewriter &rewriter) const final {
51 auto tileType = dyn_cast<VectorType>(constantOp.getType());
52 if (!tileType || !arm_sme::isValidSMETileVectorType(vType: tileType))
53 return failure();
54
55 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56 if (!denseAttr || !denseAttr.isSplat())
57 return failure();
58
59 auto tileElementType = tileType.getElementType();
60
61 // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62 if (isSplatZero(tileElementType, denseAttr)) {
63 rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
64 return success();
65 }
66
67 // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
68 // ops that broadcast the constant to each tile slice.
69 auto loc = constantOp.getLoc();
70
71 // To fill a tile with a constant, we create a 1-D splat of the constant,
72 // then move that into each tile slice (the largest unit we can set at once,
73 // outside of operations like the outerproduct).
74 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
75 auto denseAttr1D = DenseElementsAttr::get(
76 tileSliceType, denseAttr.getSplatValue<Attribute>());
77 auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
78
79 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
80 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
81 Value currentTile) {
82 // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
83 // slice.
84 auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
85 loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86 return nextTile.getResult();
87 };
88 auto forOp = mlir::arm_sme::createLoopOverTileSlices(
89 rewriter, loc, initTile, makeLoopBody);
90 rewriter.replaceOp(constantOp, forOp.getResult(0));
91
92 return success();
93 }
94};
95
96} // namespace
97
98//===----------------------------------------------------------------------===//
99// Pattern population
100//===----------------------------------------------------------------------===//
101
102void mlir::arith::populateArithToArmSMEConversionPatterns(
103 RewritePatternSet &patterns) {
104 patterns.add<ConstantOpToArmSMELowering>(arg: patterns.getContext());
105}
106
107//===----------------------------------------------------------------------===//
108// Pass definition
109//===----------------------------------------------------------------------===//
110
111namespace {
112struct ArithToArmSMEConversionPass final
113 : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
114 using impl::ArithToArmSMEConversionPassBase<
115 ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
116
117 void runOnOperation() override {
118 RewritePatternSet patterns(&getContext());
119 arith::populateArithToArmSMEConversionPatterns(patterns);
120 if (failed(
121 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
122 return signalPassFailure();
123 }
124};
125} // namespace
126

source code of mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp