1//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
10#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/Dialect/Vector/IR/VectorOps.h"
14#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15
16namespace mlir::arm_sve {
17#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
18#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
19} // namespace mlir::arm_sve
20
21using namespace mlir;
22using namespace mlir::arm_sve;
23
24// A tag to mark unrealized_conversions produced by this pass. This is used to
25// detect IR this pass failed to completely legalize, and report an error.
26// If everything was successfully legalized, no tagged ops will remain after
27// this pass.
28constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__");
29
30/// Definitions:
31///
32/// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE
33/// predicate registers. A full predicate is the smallest quantity that can be
34/// loaded/stored.
35///
/// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing
/// dimension matches the size of a legal SVE vector size (such as
/// vector<[4]xi1>), but it is too small to be stored to memory (i.e. smaller
/// than an svbool).
40
41namespace {
42
43/// Checks if a vector type is a SVE mask [2].
44bool isSVEMaskType(VectorType type) {
45 return type.getRank() > 0 && type.getElementType().isInteger(1) &&
46 type.getScalableDims().back() && type.getShape().back() < 16 &&
47 llvm::isPowerOf2_32(Value: type.getShape().back()) &&
48 !llvm::is_contained(type.getScalableDims().drop_back(), true);
49}
50
51VectorType widenScalableMaskTypeToSvbool(VectorType type) {
52 assert(isSVEMaskType(type));
53 return VectorType::Builder(type).setDim(type.getRank() - 1, 16);
54}
55
56/// A helper for cloning an op and replacing it will a new version, updated by a
57/// callback.
58template <typename TOp, typename TLegalizerCallback>
59void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
60 TLegalizerCallback callback) {
61 // Clone the previous op to preserve any properties/attributes.
62 auto newOp = op.clone();
63 rewriter.insert(op: newOp);
64 rewriter.replaceOp(op, callback(newOp));
65}
66
67/// A helper for cloning an op and replacing it with a new version, updated by a
68/// callback, and an unrealized conversion back to the type of the replaced op.
69template <typename TOp, typename TLegalizerCallback>
70void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
71 TLegalizerCallback callback) {
72 replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
73 // Mark our `unrealized_conversion_casts` with a pass label.
74 return rewriter.create<UnrealizedConversionCastOp>(
75 op.getLoc(), TypeRange{op.getResult().getType()},
76 ValueRange{callback(newOp)},
77 NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag),
78 rewriter.getUnitAttr()));
79 });
80}
81
82/// Extracts the widened SVE memref value (that's legal to store/load) from the
83/// `unrealized_conversion_cast`s added by this pass.
84static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) {
85 Operation *definingOp = illegalMemref.getDefiningOp();
86 if (!definingOp || !definingOp->hasAttr(name: kSVELegalizerTag))
87 return failure();
88 auto unrealizedConversion =
89 llvm::cast<UnrealizedConversionCastOp>(definingOp);
90 return unrealizedConversion.getOperand(0);
91}
92
93/// The default alignment of an alloca in LLVM may request overaligned sizes for
94/// SVE types, which will fail during stack frame allocation. This rewrite
95/// explicitly adds a reasonable alignment to allocas of scalable types.
96struct RelaxScalableVectorAllocaAlignment
97 : public OpRewritePattern<memref::AllocaOp> {
98 using OpRewritePattern::OpRewritePattern;
99
100 LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
101 PatternRewriter &rewriter) const override {
102 auto memrefElementType = allocaOp.getType().getElementType();
103 auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
104 if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
105 return failure();
106
107 // Set alignment based on the defaults for SVE vectors and predicates.
108 unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
109 rewriter.modifyOpInPlace(allocaOp,
110 [&] { allocaOp.setAlignment(aligment); });
111
112 return success();
113 }
114};
115
116/// Replaces allocations of SVE predicates smaller than an svbool [1] (_illegal_
117/// to load/store) with a wider allocation of svbool (_legal_ to load/store)
118/// followed by a tagged unrealized conversion to the original type.
119///
120/// Example
121/// ```
122/// %alloca = memref.alloca() : memref<vector<[4]xi1>>
123/// ```
124/// is rewritten into:
125/// ```
126/// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
127/// %alloca = builtin.unrealized_conversion_cast %widened
128/// : memref<vector<[16]xi1>> to memref<vector<[4]xi1>>
129/// {__arm_sve_legalize_vector_storage__}
130/// ```
131template <typename AllocLikeOp>
132struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> {
133 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
134
135 LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
136 PatternRewriter &rewriter) const override {
137 auto vectorType =
138 llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
139
140 if (!vectorType || !isSVEMaskType(vectorType))
141 return failure();
142
143 // Replace this alloc-like op of an SVE mask [2] with one of a (storable)
144 // svbool mask [1]. A temporary unrealized_conversion_cast is added to the
145 // old type to allow local rewrites.
146 replaceOpWithUnrealizedConversion(
147 rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
148 newAllocLikeOp.getResult().setType(
149 llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
150 {}, widenScalableMaskTypeToSvbool(vectorType))));
151 return newAllocLikeOp;
152 });
153
154 return success();
155 }
156};
157
158/// Replaces vector.type_casts of unrealized conversions to SVE predicate memref
159/// types that are _illegal_ to load/store from (!= svbool [1]), with type casts
160/// of memref types that are _legal_ to load/store, followed by unrealized
161/// conversions.
162///
163/// Example:
164/// ```
165/// %alloca = builtin.unrealized_conversion_cast %widened
166/// : memref<vector<[16]xi1>> to memref<vector<[8]xi1>>
167/// {__arm_sve_legalize_vector_storage__}
168/// %cast = vector.type_cast %alloca
169/// : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
170/// ```
171/// is rewritten into:
172/// ```
173/// %widened_cast = vector.type_cast %widened
174/// : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
175/// %cast = builtin.unrealized_conversion_cast %widened_cast
176/// : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>>
177/// {__arm_sve_legalize_vector_storage__}
178/// ```
179struct LegalizeSVEMaskTypeCastConversion
180 : public OpRewritePattern<vector::TypeCastOp> {
181 using OpRewritePattern::OpRewritePattern;
182
183 LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
184 PatternRewriter &rewriter) const override {
185 auto resultType = typeCastOp.getResultMemRefType();
186 auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
187
188 if (!vectorType || !isSVEMaskType(vectorType))
189 return failure();
190
191 auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
192 if (failed(legalMemref))
193 return failure();
194
195 // Replace this vector.type_cast with one of a (storable) svbool mask [1].
196 replaceOpWithUnrealizedConversion(
197 rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
198 newTypeCast.setOperand(*legalMemref);
199 newTypeCast.getResult().setType(
200 llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
201 {}, widenScalableMaskTypeToSvbool(vectorType))));
202 return newTypeCast;
203 });
204
205 return success();
206 }
207};
208
209/// Replaces stores to unrealized conversions to SVE predicate memref types that
210/// are _illegal_ to load/store from (!= svbool [1]), with
211/// `arm_sve.convert_to_svbool`s followed by (legal) wider stores.
212///
213/// Example:
214/// ```
215/// memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
216/// ```
217/// is rewritten into:
218/// ```
219/// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1>
220/// memref.store %svbool, %widened[] : memref<vector<[16]xi1>>
221/// ```
222struct LegalizeSVEMaskStoreConversion
223 : public OpRewritePattern<memref::StoreOp> {
224 using OpRewritePattern::OpRewritePattern;
225
226 LogicalResult matchAndRewrite(memref::StoreOp storeOp,
227 PatternRewriter &rewriter) const override {
228 auto loc = storeOp.getLoc();
229
230 Value valueToStore = storeOp.getValueToStore();
231 auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
232
233 if (!vectorType || !isSVEMaskType(vectorType))
234 return failure();
235
236 auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
237 if (failed(legalMemref))
238 return failure();
239
240 auto legalMaskType = widenScalableMaskTypeToSvbool(
241 llvm::cast<VectorType>(valueToStore.getType()));
242 auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>(
243 loc, legalMaskType, valueToStore);
244 // Replace this store with a conversion to a storable svbool mask [1],
245 // followed by a wider store.
246 replaceOpWithLegalizedOp(rewriter, storeOp,
247 [&](memref::StoreOp newStoreOp) {
248 newStoreOp.setOperand(0, convertToSvbool);
249 newStoreOp.setOperand(1, *legalMemref);
250 return newStoreOp;
251 });
252
253 return success();
254 }
255};
256
257/// Replaces loads from unrealized conversions to SVE predicate memref types
258/// that are _illegal_ to load/store from (!= svbool [1]), types with (legal)
259/// wider loads, followed by `arm_sve.convert_from_svbool`s.
260///
261/// Example:
262/// ```
263/// %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
264/// ```
265/// is rewritten into:
266/// ```
267/// %svbool = memref.load %widened[] : memref<vector<[16]xi1>>
268/// %reload = arm_sve.convert_from_svbool %reload : vector<[4]xi1>
269/// ```
270struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
271 using OpRewritePattern::OpRewritePattern;
272
273 LogicalResult matchAndRewrite(memref::LoadOp loadOp,
274 PatternRewriter &rewriter) const override {
275 auto loc = loadOp.getLoc();
276
277 Value loadedMask = loadOp.getResult();
278 auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType());
279
280 if (!vectorType || !isSVEMaskType(vectorType))
281 return failure();
282
283 auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
284 if (failed(legalMemref))
285 return failure();
286
287 auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
288 // Replace this load with a legal load of an svbool type, followed by a
289 // conversion back to the original type.
290 replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
291 newLoadOp.setMemRef(*legalMemref);
292 newLoadOp.getResult().setType(legalMaskType);
293 return rewriter.create<arm_sve::ConvertFromSvboolOp>(
294 loc, loadedMask.getType(), newLoadOp);
295 });
296
297 return success();
298 }
299};
300
301} // namespace
302
303void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
304 RewritePatternSet &patterns) {
305 patterns.add<RelaxScalableVectorAllocaAlignment,
306 LegalizeSVEMaskAllocation<memref::AllocaOp>,
307 LegalizeSVEMaskAllocation<memref::AllocOp>,
308 LegalizeSVEMaskTypeCastConversion,
309 LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
310 patterns.getContext());
311}
312
313namespace {
314struct LegalizeVectorStorage
315 : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
316
317 void runOnOperation() override {
318 RewritePatternSet patterns(&getContext());
319 populateLegalizeVectorStoragePatterns(patterns);
320 if (failed(applyPatternsAndFoldGreedily(getOperation(),
321 std::move(patterns)))) {
322 signalPassFailure();
323 }
324 ConversionTarget target(getContext());
325 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
326 [](UnrealizedConversionCastOp unrealizedConversion) {
327 return !unrealizedConversion->hasAttr(kSVELegalizerTag);
328 });
329 // This detects if we failed to completely legalize the IR.
330 if (failed(applyPartialConversion(getOperation(), target, {})))
331 signalPassFailure();
332 }
333};
334
335} // namespace
336
337std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() {
338 return std::make_unique<LegalizeVectorStorage>();
339}
340

source code of mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp