1//===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of ArmSME operations to SCF.
10//
11//===----------------------------------------------------------------------===//
12#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
16#include "mlir/Dialect/ArmSME/Utils/Utils.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTARMSMETOSCFPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29/// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30/// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31/// rank 2: (indices[0] + tileSliceIndex, indices[1])
32SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33 Value tileSliceIndex,
34 Value tileSliceNumElts, Location loc,
35 PatternRewriter &rewriter) {
36 assert(rank == 2 && "memref has unexpected rank!");
37 SmallVector<Value, 2> outIndices;
38
39 auto tileSliceOffset = tileSliceIndex;
40
41 auto baseIndexPlusTileSliceOffset =
42 rewriter.create<arith::AddIOp>(location: loc, args: indices[0], args&: tileSliceOffset);
43 outIndices.push_back(Elt: baseIndexPlusTileSliceOffset);
44 outIndices.push_back(Elt: indices[1]);
45
46 return outIndices;
47}
48
49/// Creates an scf.for for the load/store of an ArmSME tile.
50FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
51 PatternRewriter &rewriter, Location loc, VectorType tileType,
52 ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
53 function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
54 /*currentTile=*/Value)>
55 makeLoopBody) {
56 PatternRewriter::InsertionGuard guard(rewriter);
57
58 // TODO: This case should be captured and rejected by a verifier.
59 if (memrefIndices.size() != 2)
60 return rewriter.notifyMatchFailure(arg&: loc, msg: "invalid number of indices");
61
62 auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
63 location: loc, args: arm_sme::getSMETileSliceMinNumElts(type: tileType.getElementType()));
64 auto vscale =
65 rewriter.create<vector::VectorScaleOp>(location: loc, args: rewriter.getIndexType());
66 auto predicateType =
67 VectorType::get(shape: tileType.getDimSize(idx: 1), elementType: rewriter.getI1Type(), scalableDims: true);
68
69 // This describes both the number of ZA tile slices and the number of
70 // elements in a vector of SVL bits for a given element type (SVL_B,
71 // SVL_H, ..., SVL_Q).
72 auto numTileSlices =
73 rewriter.create<arith::MulIOp>(location: loc, args&: minTileSlices, args&: vscale);
74
75 Value predicate;
76 Value upperBound;
77 if (mask) {
78 auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
79 auto maskDim0 = createMaskOp.getOperands()[0];
80 auto maskDim1 = createMaskOp.getOperands()[1];
81
82 // The upper bound of the loop must be clamped at `numTileSlices` as
83 // `vector.create_mask` allows operands to be greater than the size of a
84 // dimension.
85 auto numRowI64 = rewriter.create<arith::IndexCastOp>(
86 location: loc, args: rewriter.getI64Type(), args&: maskDim0);
87 auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
88 location: loc, args: rewriter.getI64Type(), args&: numTileSlices);
89 auto upperBoundI64 =
90 rewriter.create<arith::MinSIOp>(location: loc, args&: numRowI64, args&: numTileSlicesI64);
91 upperBound = rewriter.create<arith::IndexCastOp>(
92 location: loc, args: rewriter.getIndexType(), args&: upperBoundI64);
93
94 predicate =
95 rewriter.create<vector::CreateMaskOp>(location: loc, args&: predicateType, args&: maskDim1);
96 } else {
97 upperBound = numTileSlices;
98 // No mask. Create an 'all true' predicate for the tile slice.
99 predicate = rewriter.create<arith::ConstantOp>(
100 location: loc, args: DenseElementsAttr::get(type: predicateType, value: true));
101 }
102
103 bool hasCarriedArgs = bool(initTile);
104 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
105 auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
106 auto forOp = rewriter.create<scf::ForOp>(location: loc, args&: lowerBound, args&: upperBound, args&: step,
107 args: hasCarriedArgs ? ValueRange{initTile}
108 : ValueRange{});
109
110 rewriter.setInsertionPointToStart(forOp.getBody());
111 Value tileSliceIndex = forOp.getInductionVar();
112
113 auto adjustedIndices = getMemrefIndices(
114 indices: memrefIndices, rank: memrefRank, tileSliceIndex, tileSliceNumElts: numTileSlices, loc, rewriter);
115 auto nextTile = makeLoopBody(
116 tileSliceIndex, adjustedIndices, predicate,
117 /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(index: 0) : Value{});
118
119 assert(bool(nextTile) == hasCarriedArgs);
120 if (nextTile)
121 rewriter.create<scf::YieldOp>(location: loc, args&: nextTile);
122
123 return forOp;
124}
125
126FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
127 PatternRewriter &rewriter, Location loc, VectorType tileType,
128 ValueRange memrefIndices, int memrefRank, Value mask,
129 function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
130 makeLoopBody) {
131 return createLoadStoreForOverTileSlices(
132 rewriter, loc, tileType, memrefIndices, memrefRank, mask, initTile: Value{},
133 makeLoopBody: [&](Value index, ValueRange adjustedIndices, Value predicate,
134 Value) -> Value {
135 makeLoopBody(index, adjustedIndices, predicate);
136 return {};
137 });
138}
139
140/// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
141///
142/// With a mask:
143///
144/// BEFORE:
145/// ```mlir
146/// %pad = arith.constant 0 : i32
147/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
148/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
149/// memref<?x?xi32>, vector<[4]x[4]xi32>
150/// ```
151///
152/// AFTER:
153/// ```mlir
154/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
155/// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
156/// %loop_rows = arith.minsi %num_rows, %svl_s : index
157/// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
158/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
159/// %tile_update = arm_sme.load_tile_slice
160/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
161/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
162/// scf.yield %tile_update : vector<[4]x[4]xi32>
163/// }
164/// ```
165///
166/// Without a mask the lowering is pretty much identical. The only difference is
167/// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
168///
169/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
170struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
171 using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
172
173 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
174 PatternRewriter &rewriter) const override {
175 auto loc = tileLoadOp.getLoc();
176 auto tileType = tileLoadOp.getVectorType();
177 auto mask = tileLoadOp.getMask();
178
179 Value initTile;
180 if (mask) {
181 if (!mask.getDefiningOp<vector::CreateMaskOp>())
182 return rewriter.notifyMatchFailure(
183 arg&: loc, msg: "unsupported mask op, only 'vector.create_mask' is "
184 "currently supported");
185 auto padOp = tileLoadOp.getPadding();
186 assert(padOp && "expected padding when masking!");
187
188 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
189 if (!constPadOp || constPadOp.getValue() !=
190 rewriter.getZeroAttr(type: tileType.getElementType()))
191 return rewriter.notifyMatchFailure(
192 arg&: tileLoadOp, msg: "op has non-zero pad, needs non-zero pad pattern");
193
194 // Initialize tile with zero to satisfy padding. Inactive cols will be
195 // zeroed anyway since the loads use zeroing predication. For inactive
196 // rows however, no load will occur so these need to be zeroed.
197 initTile = rewriter.create<arm_sme::ZeroOp>(location: loc, args&: tileType);
198 } else {
199 initTile = rewriter.create<arm_sme::GetTileOp>(location: loc, args&: tileType);
200 }
201
202 // Create a loop to load the active tile slices from memory.
203 auto forOp = createLoadStoreForOverTileSlices(
204 rewriter, loc, tileType, memrefIndices: tileLoadOp.getIndices(),
205 memrefRank: tileLoadOp.getMemRefType().getRank(), mask, initTile,
206 makeLoopBody: [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
207 Value currentTile) -> Value {
208 // Create 'arm_sme.load_tile_slice' to load tile slice from memory
209 // into tile.
210 return rewriter.create<arm_sme::LoadTileSliceOp>(
211 location: loc, args&: tileType, args: tileLoadOp.getBase(), args&: predicate, args&: currentTile,
212 args&: memrefIndices, args&: tileSliceIndex, args: tileLoadOp.getLayout());
213 });
214
215 if (failed(Result: forOp))
216 return forOp;
217
218 // Replace 'arm_sme.tile_load' with the result.
219 rewriter.replaceOp(op: tileLoadOp, newValues: forOp->getResult(i: 0));
220
221 return success();
222 }
223};
224
225/// Lower `arm_sme.tile_load` with mask and non-zero pad.
226///
227/// BEFORE:
228/// ```mlir
229/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
230/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
231/// memref<?x?xi32>, vector<[4]x[4]xi32>
232/// ```
233///
234/// AFTER:
235/// ```mlir
236/// ...
237/// %pad_1d = vector.splat %pad : vector<[4]xi32>
238/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
239/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
240/// ...
241/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
242/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
243/// : memref<?x?xi32>, vector<[4]xi1>,
244/// vector<[4]xi32> into vector<[4]xi32>
245/// // Insert slice into tile
246/// %tile_update = arm_sme.insert_tile_slice
247/// %slice, %iter_tile[%tile_slice_idx] :
248/// vector<[4]xi32> into vector<[4]x[4]xi32>
249/// scf.yield %tile_update : vector<[4]x[4]xi32>
250/// }
251/// ```
252struct TileLoadOpWithMaskAndPadNonZeroConversion
253 : public OpRewritePattern<arm_sme::TileLoadOp> {
254 using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
255
256 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
257 PatternRewriter &rewriter) const override {
258 OpBuilder::InsertionGuard g(rewriter);
259 auto loc = tileLoadOp.getLoc();
260 auto tileType = tileLoadOp.getVectorType();
261 auto tileElementType = tileType.getElementType();
262
263 auto maskOp = tileLoadOp.getMask();
264 if (!maskOp)
265 return rewriter.notifyMatchFailure(
266 arg&: tileLoadOp, msg: "op has no mask, needs unmasked pattern");
267
268 auto padOp = tileLoadOp.getPadding();
269 assert(padOp && "expected padding when masking!");
270
271 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
272 if (!createMaskOp)
273 return rewriter.notifyMatchFailure(
274 arg&: tileLoadOp, msg: "unsupported mask op, only 'vector.create_mask' is "
275 "currently supported");
276
277 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
278 if (constPadOp &&
279 constPadOp.getValue() == rewriter.getZeroAttr(type: tileElementType))
280 return rewriter.notifyMatchFailure(
281 arg&: tileLoadOp, msg: "op has constant zero pad, needs zero pad pattern");
282
283 auto numRows = createMaskOp.getOperands()[0];
284 auto numCols = createMaskOp.getOperands()[1];
285
286 auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
287 location: loc, args: rewriter.getI32Type(), args&: numCols);
288
289 auto initTile = rewriter.create<arm_sme::GetTileOp>(location: loc, args&: tileType);
290
291 // Create a loop that loads each ZA tile slice from memory.
292 auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
293 auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
294 location: loc, args: arm_sme::getSMETileSliceMinNumElts(type: tileElementType));
295 auto vscale =
296 rewriter.create<vector::VectorScaleOp>(location: loc, args: rewriter.getIndexType());
297 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
298 auto numTileSlices =
299 rewriter.create<arith::MulIOp>(location: loc, args&: minTileSlices, args&: vscale);
300 auto forOp = rewriter.create<scf::ForOp>(location: loc, args&: lowerBound, args&: numTileSlices,
301 args&: step, args: ValueRange{initTile});
302
303 rewriter.setInsertionPointToStart(forOp.getBody());
304
305 auto tileSliceIndex = forOp.getInductionVar();
306 auto currentTile = forOp.getRegionIterArg(index: 0);
307
308 // Combine masks.
309 auto rowIsActive = rewriter.create<arith::CmpIOp>(
310 location: loc, args: arith::CmpIPredicate::ult, args&: tileSliceIndex, args&: numRows);
311 auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
312 location: loc, args: rewriter.getI32Type(), args&: rowIsActive);
313 auto mask = rewriter.create<arith::AndIOp>(location: loc, args&: rowIsActiveI32, args&: numColsI32);
314 auto maskIndex =
315 rewriter.create<arith::IndexCastOp>(location: loc, args: rewriter.getIndexType(), args&: mask);
316 auto predicateType =
317 VectorType::get(shape: tileType.getDimSize(idx: 1), elementType: rewriter.getI1Type(), scalableDims: true);
318 auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
319 location: loc, args&: predicateType, args: maskIndex.getResult());
320
321 auto memrefIndices = getMemrefIndices(
322 indices: tileLoadOp.getIndices(), rank: tileLoadOp.getMemRefType().getRank(),
323 tileSliceIndex, tileSliceNumElts: numTileSlices, loc, rewriter);
324
325 // Splat pad into 1-D vector matching type of tile slice.
326 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(pos: 0);
327 auto pad1DOp = rewriter.create<vector::SplatOp>(location: loc, args&: tileSliceType, args&: padOp);
328
329 auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
330 location: loc, args&: tileSliceType, args: tileLoadOp.getBase(), args&: memrefIndices, args&: maskOp1D,
331 /*passthru=*/args&: pad1DOp);
332
333 // Create 'arm_sme.insert_tile_slice' to insert slice into tile.
334 auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
335 location: loc, args&: tileType, args: loadSlice->getResult(idx: 0), args&: currentTile, args&: tileSliceIndex,
336 args: tileLoadOp.getLayout());
337 rewriter.create<scf::YieldOp>(location: loc, args: insertSlice.getResult());
338
339 rewriter.setInsertionPointAfter(forOp);
340
341 // Replace 'arm_sme.tile_load' with the result.
342 rewriter.replaceOp(op: tileLoadOp, newValues: forOp.getResult(i: 0));
343
344 return success();
345 }
346};
347
348/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
349/// slice using `arm_sme.store_tile_slice`.
350///
351/// BEFORE:
352/// ```mlir
353/// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
354/// : memref<?x?xi32>, vector<[4]x[4]xi32
355/// ```
356///
357/// AFTER:
358/// ```mlir
359/// %vscale = vector.vscale
360/// %c0 = arith.constant 0 : index
361/// %c1 = arith.constant 1 : index
362/// %min_svl_s = arith.constant 4 : index
363/// %svl_s = arith.muli %min_svl_s, %vscale : index
364/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
365/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
366/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
367/// }
368/// ```
369struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
370 using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;
371
372 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
373 PatternRewriter &rewriter) const override {
374 if (Value mask = tileStoreOp.getMask()) {
375 if (!mask.getDefiningOp<vector::CreateMaskOp>())
376 return rewriter.notifyMatchFailure(
377 arg: tileStoreOp.getLoc(),
378 msg: "unsupported mask op, only 'vector.create_mask' is "
379 "currently supported");
380 }
381
382 // Create a loop that stores each active ZA tile slice from memory.
383 return createLoadStoreForOverTileSlices(
384 rewriter, loc: tileStoreOp.getLoc(), tileType: tileStoreOp.getVectorType(),
385 memrefIndices: tileStoreOp.getIndices(), memrefRank: tileStoreOp.getMemRefType().getRank(),
386 mask: tileStoreOp.getMask(),
387 makeLoopBody: [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
388 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
389 op: tileStoreOp, args: tileStoreOp.getValueToStore(), args&: tileSliceIndex,
390 args&: predicate, args: tileStoreOp.getBase(), args&: memrefIndices,
391 args: tileStoreOp.getLayout());
392 });
393 }
394};
395
396} // namespace
397
398void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
399 patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
400 TileStoreOpConversion>(arg: patterns.getContext());
401}
402
403namespace {
404
405struct ConvertArmSMEToSCFPass
406 : public impl::ConvertArmSMEToSCFPassBase<ConvertArmSMEToSCFPass> {
407 void runOnOperation() override {
408 RewritePatternSet patterns(&getContext());
409 ConversionTarget target(getContext());
410 populateArmSMEToSCFConversionPatterns(patterns);
411 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
412 arith::ArithDialect, scf::SCFDialect>();
413 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
414 if (failed(Result: applyPartialConversion(op: getOperation(), target,
415 patterns: std::move(patterns))))
416 signalPassFailure();
417 }
418};
419
420} // namespace
421

source code of mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp