1//===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of ArmSME operations to SCF.
10//
11//===----------------------------------------------------------------------===//
12#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
16#include "mlir/Dialect/ArmSME/Utils/Utils.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTARMSMETOSCF
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29/// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30/// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31/// rank 2: (indices[0] + tileSliceIndex, indices[1])
32SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
33 Value tileSliceIndex,
34 Value tileSliceNumElts, Location loc,
35 PatternRewriter &rewriter) {
36 assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
37 SmallVector<Value, 2> outIndices;
38
39 auto tileSliceOffset = tileSliceIndex;
40 if (rank == 1)
41 tileSliceOffset =
42 rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
43
44 auto baseIndexPlusTileSliceOffset =
45 rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
46 outIndices.push_back(baseIndexPlusTileSliceOffset);
47
48 if (rank == 2)
49 outIndices.push_back(indices[1]);
50
51 return outIndices;
52}
53
54/// Creates an scf.for for the load/store of an ArmSME tile.
55FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
56 PatternRewriter &rewriter, Location loc, VectorType tileType,
57 ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
58 function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
59 /*currentTile=*/Value)>
60 makeLoopBody) {
61 PatternRewriter::InsertionGuard guard(rewriter);
62
63 auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
64 loc, arm_sme::getSMETileSliceMinNumElts(type: tileType.getElementType()));
65 auto vscale =
66 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
67 auto predicateType =
68 VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
69
70 // This describes both the number of ZA tile slices and the number of
71 // elements in a vector of SVL bits for a given element type (SVL_B,
72 // SVL_H, ..., SVL_Q).
73 auto numTileSlices =
74 rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
75
76 Value predicate;
77 Value upperBound;
78 if (mask) {
79 auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
80 if (!createMaskOp)
81 return rewriter.notifyMatchFailure(
82 loc, "unsupported mask op, only 'vector.create_mask' is "
83 "currently supported");
84
85 auto maskDim0 = createMaskOp.getOperands()[0];
86 auto maskDim1 = createMaskOp.getOperands()[1];
87
88 // The upper bound of the loop must be clamped at `numTileSlices` as
89 // `vector.create_mask` allows operands to be greater than the size of a
90 // dimension.
91 auto numRowI64 = rewriter.create<arith::IndexCastOp>(
92 loc, rewriter.getI64Type(), maskDim0);
93 auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
94 loc, rewriter.getI64Type(), numTileSlices);
95 auto upperBoundI64 =
96 rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
97 upperBound = rewriter.create<arith::IndexCastOp>(
98 loc, rewriter.getIndexType(), upperBoundI64);
99
100 predicate =
101 rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
102 } else {
103 upperBound = numTileSlices;
104 // No mask. Create an 'all true' predicate for the tile slice.
105 predicate = rewriter.create<arith::ConstantOp>(
106 loc, DenseElementsAttr::get(predicateType, true));
107 }
108
109 bool hasCarriedArgs = bool(initTile);
110 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
111 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
112 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
113 hasCarriedArgs ? ValueRange{initTile}
114 : ValueRange{});
115
116 rewriter.setInsertionPointToStart(forOp.getBody());
117 Value tileSliceIndex = forOp.getInductionVar();
118
119 auto adjustedIndices = getMemrefIndices(
120 memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
121 auto nextTile = makeLoopBody(
122 tileSliceIndex, adjustedIndices, predicate,
123 /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
124
125 assert(bool(nextTile) == hasCarriedArgs);
126 if (nextTile)
127 rewriter.create<scf::YieldOp>(loc, nextTile);
128
129 return forOp;
130}
131
132FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
133 PatternRewriter &rewriter, Location loc, VectorType tileType,
134 ValueRange memrefIndices, int memrefRank, Value mask,
135 function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
136 makeLoopBody) {
137 return createLoadStoreForOverTileSlices(
138 rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
139 [&](Value index, ValueRange adjustedIndices, Value predicate,
140 Value) -> Value {
141 makeLoopBody(index, adjustedIndices, predicate);
142 return {};
143 });
144}
145
146/// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
147///
148/// With a mask:
149///
150/// BEFORE:
151/// ```mlir
152/// %pad = arith.constant 0 : i32
153/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
154/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
155/// memref<?x?xi32>, vector<[4]x[4]xi32>
156/// ```
157///
158/// AFTER:
159/// ```mlir
160/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
161/// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
162/// %loop_rows = arith.minsi %num_rows, %svl_s : index
163/// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
164/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
165/// %tile_update = arm_sme.load_tile_slice
166/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
167/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
168/// scf.yield %tile_update : vector<[4]x[4]xi32>
169/// }
170/// ```
171///
172/// Without a mask the lowering is pretty much identical. The only difference is
173/// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
174///
175/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
176struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
177 using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
178
179 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
180 PatternRewriter &rewriter) const override {
181 auto loc = tileLoadOp.getLoc();
182 auto tileType = tileLoadOp.getVectorType();
183 auto mask = tileLoadOp.getMask();
184
185 Value initTile;
186 if (mask) {
187 auto padOp = tileLoadOp.getPadding();
188 assert(padOp && "expected padding when masking!");
189
190 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
191 if (!constPadOp || constPadOp.getValue() !=
192 rewriter.getZeroAttr(type: tileType.getElementType()))
193 return rewriter.notifyMatchFailure(
194 tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
195
196 // Initialize tile with zero to satisfy padding. Inactive cols will be
197 // zeroed anyway since the loads use zeroing predication. For inactive
198 // rows however, no load will occur so these need to be zeroed.
199 initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
200 rewriter, loc, tileType);
201 } else {
202 // Allocate a new SME tile.
203 initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
204 rewriter, loc, tileType);
205 }
206
207 // Create a loop to load the active tile slices from memory.
208 auto forOp = createLoadStoreForOverTileSlices(
209 rewriter, loc, tileType, tileLoadOp.getIndices(),
210 tileLoadOp.getMemRefType().getRank(), mask, initTile,
211 [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
212 Value currentTile) -> Value {
213 // Create 'arm_sme.load_tile_slice' to load tile slice from memory
214 // into tile.
215 return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
216 rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
217 currentTile, memrefIndices, tileSliceIndex,
218 tileLoadOp.getLayout());
219 });
220
221 if (failed(forOp))
222 return forOp;
223
224 // Replace 'arm_sme.tile_load' with the result.
225 rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
226
227 return success();
228 }
229};
230
231/// Lower `arm_sme.tile_load` with mask and non-zero pad.
232///
233/// BEFORE:
234/// ```mlir
235/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
236/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
237/// memref<?x?xi32>, vector<[4]x[4]xi32>
238/// ```
239///
240/// AFTER:
241/// ```mlir
242/// ...
243/// %pad_1d = vector.splat %pad : vector<[4]xi32>
244/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
245/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
246/// ...
247/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
248/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
249/// : memref<?x?xi32>, vector<[4]xi1>,
250/// vector<[4]xi32> into vector<[4]xi32>
251/// // Insert slice into tile
252/// %tile_update = arm_sme.move_vector_to_tile_slice
253/// %slice, %iter_tile, %tile_slice_idx :
254/// vector<[4]xi32> into vector<[4]x[4]xi32>
255/// scf.yield %tile_update : vector<[4]x[4]xi32>
256/// }
257/// ```
258struct TileLoadOpWithMaskAndPadNonZeroConversion
259 : public OpRewritePattern<arm_sme::TileLoadOp> {
260 using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
261
262 LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
263 PatternRewriter &rewriter) const override {
264 OpBuilder::InsertionGuard g(rewriter);
265 auto loc = tileLoadOp.getLoc();
266 auto tileType = tileLoadOp.getVectorType();
267 auto tileElementType = tileType.getElementType();
268
269 auto maskOp = tileLoadOp.getMask();
270 if (!maskOp)
271 return rewriter.notifyMatchFailure(
272 tileLoadOp, "op has no mask, needs unmasked pattern");
273
274 auto padOp = tileLoadOp.getPadding();
275 assert(padOp && "expected padding when masking!");
276
277 auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
278 if (!createMaskOp)
279 return rewriter.notifyMatchFailure(
280 tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
281 "currently supported");
282
283 auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
284 if (constPadOp &&
285 constPadOp.getValue() == rewriter.getZeroAttr(type: tileElementType))
286 return rewriter.notifyMatchFailure(
287 tileLoadOp, "op has constant zero pad, needs zero pad pattern");
288
289 auto numRows = createMaskOp.getOperands()[0];
290 auto numCols = createMaskOp.getOperands()[1];
291
292 auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
293 loc, rewriter.getI32Type(), numCols);
294
295 // Allocate a new SME tile.
296 auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
297 rewriter, loc, tileType);
298
299 // Create a loop that loads each ZA tile slice from memory.
300 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
301 auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
302 loc, arm_sme::getSMETileSliceMinNumElts(type: tileElementType));
303 auto vscale =
304 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
305 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
306 auto numTileSlices =
307 rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
308 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
309 step, ValueRange{initTile});
310
311 rewriter.setInsertionPointToStart(forOp.getBody());
312
313 auto tileSliceIndex = forOp.getInductionVar();
314 auto currentTile = forOp.getRegionIterArg(0);
315
316 // Combine masks.
317 auto rowIsActive = rewriter.create<arith::CmpIOp>(
318 loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
319 auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
320 loc, rewriter.getI32Type(), rowIsActive);
321 auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
322 auto maskIndex =
323 rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
324 auto predicateType =
325 VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
326 auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
327 loc, predicateType, maskIndex.getResult());
328
329 auto memrefIndices = getMemrefIndices(
330 tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
331 tileSliceIndex, numTileSlices, loc, rewriter);
332
333 // Splat pad into 1-D vector matching type of tile slice.
334 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
335 auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
336
337 auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
338 loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
339 /*passthru=*/pad1DOp);
340
341 // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
342 auto moveSlice =
343 tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
344 rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
345 tileSliceIndex, tileLoadOp.getLayout());
346 rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
347
348 rewriter.setInsertionPointAfter(forOp);
349
350 // Replace 'arm_sme.tile_load' with the result.
351 rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
352
353 return success();
354 }
355};
356
357/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
358/// slice using `arm_sme.store_tile_slice`.
359///
360/// BEFORE:
361/// ```mlir
362/// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
363/// : memref<?x?xi32>, vector<[4]x[4]xi32
364/// ```
365///
366/// AFTER:
367/// ```mlir
368/// %vscale = vector.vscale
369/// %c0 = arith.constant 0 : index
370/// %c1 = arith.constant 1 : index
371/// %min_svl_s = arith.constant 4 : index
372/// %svl_s = arith.muli %min_svl_s, %vscale : index
373/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
374/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
375/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
376/// }
377/// ```
378struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
379 using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;
380
381 LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
382 PatternRewriter &rewriter) const override {
383 // Create a loop that stores each active ZA tile slice from memory.
384 return createLoadStoreForOverTileSlices(
385 rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
386 tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
387 tileStoreOp.getMask(),
388 [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
389 tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
390 rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
391 predicate, tileStoreOp.getBase(), memrefIndices,
392 tileStoreOp.getLayout());
393 });
394 }
395};
396
397} // namespace
398
399void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
400 patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
401 TileStoreOpConversion>(arg: patterns.getContext());
402}
403
404namespace {
405
406struct ConvertArmSMEToSCFPass
407 : public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
408 void runOnOperation() override {
409 RewritePatternSet patterns(&getContext());
410 ConversionTarget target(getContext());
411 populateArmSMEToSCFConversionPatterns(patterns);
412 target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
413 arith::ArithDialect, scf::SCFDialect>();
414 target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
415 if (failed(applyPartialConversion(getOperation(), target,
416 std::move(patterns))))
417 signalPassFailure();
418 }
419};
420
421} // namespace
422
423std::unique_ptr<Pass> mlir::createConvertArmSMEToSCFPass() {
424 return std::make_unique<ConvertArmSMEToSCFPass>();
425}
426

source code of mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp