| 1 | //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" |
| 10 | |
| 11 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
| 12 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
| 13 | #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" |
| 14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 15 | #include "mlir/IR/BuiltinTypes.h" |
| 16 | #include "llvm/Support/Casting.h" |
| 17 | |
| 18 | using namespace mlir; |
| 19 | |
| 20 | namespace { |
| 21 | |
| 22 | /// Conversion pattern for vector.transfer_read. |
| 23 | /// |
| 24 | /// --- |
| 25 | /// |
| 26 | /// Example 1: op with identity permutation map to horizontal |
| 27 | /// arm_sme.tile_load: |
| 28 | /// |
| 29 | /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1) |
| 30 | /// |
| 31 | /// is converted to: |
| 32 | /// |
| 33 | /// arm_sme.tile_load ... |
| 34 | /// |
| 35 | /// --- |
| 36 | /// |
| 37 | /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load |
| 38 | /// (in-flight transpose): |
| 39 | /// |
| 40 | /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0) |
| 41 | /// |
| 42 | /// is converted to: |
| 43 | /// |
| 44 | /// arm_sme.tile_load ... layout<vertical> |
| 45 | struct TransferReadToArmSMELowering |
| 46 | : public OpRewritePattern<vector::TransferReadOp> { |
| 47 | using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; |
| 48 | |
| 49 | LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, |
| 50 | PatternRewriter &rewriter) const final { |
| 51 | // The permutation map must have two results. |
| 52 | if (transferReadOp.getTransferRank() != 2) |
| 53 | return rewriter.notifyMatchFailure(transferReadOp, |
| 54 | "not a 2 result permutation map" ); |
| 55 | |
| 56 | auto vectorType = transferReadOp.getVectorType(); |
| 57 | if (!arm_sme::isValidSMETileVectorType(vectorType)) |
| 58 | return rewriter.notifyMatchFailure(transferReadOp, |
| 59 | "not a valid vector type for SME" ); |
| 60 | |
| 61 | if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType())) |
| 62 | return rewriter.notifyMatchFailure(transferReadOp, "not a memref source" ); |
| 63 | |
| 64 | // Out-of-bounds dims are not supported. |
| 65 | if (transferReadOp.hasOutOfBoundsDim()) |
| 66 | return rewriter.notifyMatchFailure(transferReadOp, |
| 67 | "not inbounds transfer read" ); |
| 68 | |
| 69 | AffineMap map = transferReadOp.getPermutationMap(); |
| 70 | if (!map.isPermutation()) |
| 71 | return rewriter.notifyMatchFailure(transferReadOp, |
| 72 | "unsupported permutation map" ); |
| 73 | |
| 74 | // Note: For 2D vector types the only non-identity permutation is a simple |
| 75 | // transpose [1, 0]. |
| 76 | bool transposed = !map.isIdentity(); |
| 77 | arm_sme::TileSliceLayout layout = |
| 78 | transposed ? arm_sme::TileSliceLayout::Vertical |
| 79 | : arm_sme::TileSliceLayout::Horizontal; |
| 80 | |
| 81 | // Padding isn't optional for transfer_read, but is only used in the case |
| 82 | // of out-of-bounds accesses (not supported here) and/or masking. Mask is |
| 83 | // optional, if it's not present don't pass padding. |
| 84 | auto mask = transferReadOp.getMask(); |
| 85 | auto padding = mask ? transferReadOp.getPadding() : nullptr; |
| 86 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
| 87 | transferReadOp, vectorType, transferReadOp.getBase(), |
| 88 | transferReadOp.getIndices(), padding, mask, layout); |
| 89 | |
| 90 | return success(); |
| 91 | } |
| 92 | }; |
| 93 | |
| 94 | /// Conversion pattern for vector.transfer_write. |
| 95 | /// |
| 96 | /// --- |
| 97 | /// |
| 98 | /// Example 1: op with identity permutation map to horizontal |
| 99 | /// arm_sme.tile_store: |
| 100 | /// |
| 101 | /// vector.transfer_write %vector, %source[%c0, %c0] |
| 102 | /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> |
| 103 | /// |
| 104 | /// is converted to: |
| 105 | /// |
| 106 | /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>, |
| 107 | /// vector<[16]x[16]xi8> |
| 108 | /// --- |
| 109 | /// |
| 110 | /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store |
| 111 | /// (in-flight transpose): |
| 112 | /// |
| 113 | /// vector.transfer_write %vector, %source[%c0, %c0] |
| 114 | /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, |
| 115 | /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> |
| 116 | /// |
| 117 | /// is converted to: |
| 118 | /// |
| 119 | /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical> |
| 120 | /// : memref<?x?xi8>, vector<[16]x[16]xi8> |
| 121 | struct TransferWriteToArmSMELowering |
| 122 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 123 | using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; |
| 124 | |
| 125 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
| 126 | PatternRewriter &rewriter) const final { |
| 127 | auto vType = writeOp.getVectorType(); |
| 128 | if (!arm_sme::isValidSMETileVectorType(vType)) |
| 129 | return failure(); |
| 130 | |
| 131 | if (!llvm::isa<MemRefType>(writeOp.getBase().getType())) |
| 132 | return failure(); |
| 133 | |
| 134 | // Out-of-bounds dims are not supported. |
| 135 | if (writeOp.hasOutOfBoundsDim()) |
| 136 | return rewriter.notifyMatchFailure(writeOp, |
| 137 | "not inbounds transfer write" ); |
| 138 | |
| 139 | AffineMap map = writeOp.getPermutationMap(); |
| 140 | if (!map.isPermutation()) |
| 141 | return rewriter.notifyMatchFailure(writeOp, |
| 142 | "unsupported permutation map" ); |
| 143 | |
| 144 | // Note: For 2D vector types the only non-identity permutation is a simple |
| 145 | // transpose [1, 0]. |
| 146 | bool transposed = !map.isIdentity(); |
| 147 | arm_sme::TileSliceLayout layout = |
| 148 | transposed ? arm_sme::TileSliceLayout::Vertical |
| 149 | : arm_sme::TileSliceLayout::Horizontal; |
| 150 | |
| 151 | rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( |
| 152 | writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(), |
| 153 | writeOp.getMask(), layout); |
| 154 | return success(); |
| 155 | } |
| 156 | }; |
| 157 | |
| 158 | /// Conversion pattern for vector.load. |
| 159 | struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> { |
| 160 | using OpRewritePattern<vector::LoadOp>::OpRewritePattern; |
| 161 | |
| 162 | LogicalResult matchAndRewrite(vector::LoadOp load, |
| 163 | PatternRewriter &rewriter) const override { |
| 164 | if (!arm_sme::isValidSMETileVectorType(load.getVectorType())) |
| 165 | return failure(); |
| 166 | |
| 167 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
| 168 | load, load.getVectorType(), load.getBase(), load.getIndices()); |
| 169 | |
| 170 | return success(); |
| 171 | } |
| 172 | }; |
| 173 | |
| 174 | /// Conversion pattern for vector.store. |
| 175 | struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> { |
| 176 | using OpRewritePattern<vector::StoreOp>::OpRewritePattern; |
| 177 | |
| 178 | LogicalResult matchAndRewrite(vector::StoreOp store, |
| 179 | PatternRewriter &rewriter) const override { |
| 180 | if (!arm_sme::isValidSMETileVectorType(store.getVectorType())) |
| 181 | return failure(); |
| 182 | |
| 183 | rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( |
| 184 | store, store.getValueToStore(), store.getBase(), store.getIndices()); |
| 185 | |
| 186 | return success(); |
| 187 | } |
| 188 | }; |
| 189 | |
| 190 | /// Conversion pattern for vector.broadcast. |
| 191 | /// |
| 192 | /// Example: |
| 193 | /// |
| 194 | /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32> |
| 195 | /// |
| 196 | /// is converted to: |
| 197 | /// |
| 198 | /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> |
| 199 | /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices |
| 200 | /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) |
| 201 | /// { |
| 202 | /// %tile_update = arm_sme.insert_tile_slice |
| 203 | /// %broadcast_to_1d, %iter_tile[%tile_slice_index] : |
| 204 | /// vector<[4]xi32> into vector<[4]x[4]xi32> |
| 205 | /// scf.yield %tile_update : vector<[4]x[4]xi32> |
| 206 | /// } |
| 207 | /// |
| 208 | /// Supports scalar, 0-d vector, and 1-d vector broadcasts. |
| 209 | struct BroadcastOpToArmSMELowering |
| 210 | : public OpRewritePattern<vector::BroadcastOp> { |
| 211 | using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern; |
| 212 | |
| 213 | LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, |
| 214 | PatternRewriter &rewriter) const final { |
| 215 | auto tileType = broadcastOp.getResultVectorType(); |
| 216 | if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) |
| 217 | return failure(); |
| 218 | |
| 219 | auto loc = broadcastOp.getLoc(); |
| 220 | |
| 221 | auto srcType = broadcastOp.getSourceType(); |
| 222 | auto srcVectorType = dyn_cast<VectorType>(srcType); |
| 223 | |
| 224 | Value broadcastOp1D; |
| 225 | if (srcType.isIntOrFloat() || |
| 226 | (srcVectorType && (srcVectorType.getRank() == 0))) { |
| 227 | // Broadcast scalar or 0-d vector to 1-d vector. |
| 228 | VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
| 229 | broadcastOp1D = rewriter.create<vector::BroadcastOp>( |
| 230 | loc, tileSliceType, broadcastOp.getSource()); |
| 231 | } else if (srcVectorType && (srcVectorType.getRank() == 1)) |
| 232 | // Value to broadcast is already a 1-d vector, nothing to do. |
| 233 | broadcastOp1D = broadcastOp.getSource(); |
| 234 | else |
| 235 | return failure(); |
| 236 | |
| 237 | auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); |
| 238 | |
| 239 | auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, |
| 240 | Value currentTile) { |
| 241 | // Create 'arm_sme.insert_tile_slice' to broadcast the value |
| 242 | // to each tile slice. |
| 243 | auto nextTile = b.create<arm_sme::InsertTileSliceOp>( |
| 244 | loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); |
| 245 | return nextTile.getResult(); |
| 246 | }; |
| 247 | |
| 248 | // Create a loop over ZA tile slices. |
| 249 | auto forOp = |
| 250 | createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); |
| 251 | |
| 252 | rewriter.replaceOp(broadcastOp, forOp.getResult(0)); |
| 253 | |
| 254 | return success(); |
| 255 | } |
| 256 | }; |
| 257 | |
| 258 | /// Conversion pattern for vector.splat. |
| 259 | /// |
| 260 | /// Example: |
| 261 | /// |
| 262 | /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> |
| 263 | /// |
| 264 | /// is converted to: |
| 265 | /// |
| 266 | /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> |
| 267 | /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices |
| 268 | /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) |
| 269 | /// { |
| 270 | /// %tile_update = arm_sme.insert_tile_slice |
| 271 | /// %broadcast_to_1d, %iter_tile[%tile_slice_index] : |
| 272 | /// vector<[4]xi32> into vector<[4]x[4]xi32> |
| 273 | /// scf.yield %tile_update : vector<[4]x[4]xi32> |
| 274 | /// } |
| 275 | /// |
| 276 | /// This is identical to vector.broadcast of a scalar. |
| 277 | struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { |
| 278 | using OpRewritePattern<vector::SplatOp>::OpRewritePattern; |
| 279 | |
| 280 | LogicalResult matchAndRewrite(vector::SplatOp splatOp, |
| 281 | PatternRewriter &rewriter) const final { |
| 282 | auto tileType = splatOp.getResult().getType(); |
| 283 | if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) |
| 284 | return failure(); |
| 285 | |
| 286 | auto loc = splatOp.getLoc(); |
| 287 | auto srcType = splatOp.getOperand().getType(); |
| 288 | |
| 289 | assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat" ); |
| 290 | // Avoid unused-variable warning when building without assertions. |
| 291 | (void)srcType; |
| 292 | |
| 293 | // First, broadcast the scalar to a 1-d vector. |
| 294 | VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); |
| 295 | Value broadcastOp1D = rewriter.create<vector::BroadcastOp>( |
| 296 | loc, tileSliceType, splatOp.getInput()); |
| 297 | |
| 298 | auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); |
| 299 | |
| 300 | auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, |
| 301 | Value currentTile) { |
| 302 | auto nextTile = b.create<arm_sme::InsertTileSliceOp>( |
| 303 | loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); |
| 304 | return nextTile.getResult(); |
| 305 | }; |
| 306 | |
| 307 | // Next, create a loop over ZA tile slices and "move" the generated 1-d |
| 308 | // vector to each slice. |
| 309 | auto forOp = |
| 310 | createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); |
| 311 | |
| 312 | rewriter.replaceOp(splatOp, forOp.getResult(0)); |
| 313 | |
| 314 | return success(); |
| 315 | } |
| 316 | }; |
| 317 | |
| 318 | /// Conversion pattern for vector.transpose. |
| 319 | /// |
| 320 | /// Stores the input tile to memory and reloads vertically. |
| 321 | /// |
| 322 | /// Example: |
| 323 | /// |
| 324 | /// %transposed_src = vector.transpose %src, [1, 0] |
| 325 | /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> |
| 326 | /// |
| 327 | /// is converted to: |
| 328 | /// |
| 329 | /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32> |
| 330 | /// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0] |
| 331 | /// : memref<?x?xi32>, vector<[4]x[4]xi32> |
| 332 | /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0] |
| 333 | /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32> |
| 334 | /// |
| 335 | /// NOTE: Transposing via memory is obviously expensive, the current intention |
| 336 | /// is to avoid the transpose if possible, this is therefore intended as a |
| 337 | /// fallback and to provide base support for Vector ops. If it turns out |
| 338 | /// transposes can't be avoided then this should be replaced with a more optimal |
| 339 | /// implementation, perhaps with tile <-> vector (MOVA) ops. |
| 340 | struct TransposeOpToArmSMELowering |
| 341 | : public OpRewritePattern<vector::TransposeOp> { |
| 342 | using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; |
| 343 | |
| 344 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
| 345 | PatternRewriter &rewriter) const final { |
| 346 | auto tileType = transposeOp.getResultVectorType(); |
| 347 | if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) |
| 348 | return failure(); |
| 349 | |
| 350 | // Bail unless this is a true 2-D matrix transpose. |
| 351 | ArrayRef<int64_t> permutation = transposeOp.getPermutation(); |
| 352 | if (permutation[0] != 1 || permutation[1] != 0) |
| 353 | return failure(); |
| 354 | |
| 355 | auto loc = transposeOp.getLoc(); |
| 356 | Value input = transposeOp.getVector(); |
| 357 | |
| 358 | if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>(); |
| 359 | xferOp && xferOp->hasOneUse()) { |
| 360 | // Fold transpose into transfer_read to enable in-flight transpose when |
| 361 | // converting to arm_sme.tile_load. |
| 362 | rewriter.modifyOpInPlace(xferOp, [&]() { |
| 363 | xferOp->setAttr(xferOp.getPermutationMapAttrName(), |
| 364 | AffineMapAttr::get(AffineMap::getPermutationMap( |
| 365 | permutation, transposeOp.getContext()))); |
| 366 | }); |
| 367 | rewriter.replaceOp(transposeOp, xferOp); |
| 368 | return success(); |
| 369 | } |
| 370 | |
| 371 | // Allocate buffer to store input tile to. |
| 372 | Value vscale = |
| 373 | rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); |
| 374 | Value minTileSlices = rewriter.create<arith::ConstantOp>( |
| 375 | loc, rewriter.getIndexAttr(tileType.getDimSize(0))); |
| 376 | Value c0 = |
| 377 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); |
| 378 | Value numTileSlices = |
| 379 | rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices); |
| 380 | auto bufferType = |
| 381 | MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, |
| 382 | tileType.getElementType()); |
| 383 | auto buffer = rewriter.create<memref::AllocaOp>( |
| 384 | loc, bufferType, ValueRange{numTileSlices, numTileSlices}); |
| 385 | |
| 386 | // Store input tile. |
| 387 | auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>( |
| 388 | loc, input, buffer, ValueRange{c0, c0}); |
| 389 | |
| 390 | // Reload input tile vertically. |
| 391 | rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( |
| 392 | transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(), |
| 393 | arm_sme::TileSliceLayout::Vertical); |
| 394 | |
| 395 | return success(); |
| 396 | } |
| 397 | }; |
| 398 | |
| 399 | /// Conversion pattern for vector.outerproduct. |
| 400 | /// |
| 401 | /// If the vector.outerproduct is masked (and the mask is from a |
| 402 | /// vector.create_mask), then the mask is decomposed into two 1-D masks for the |
| 403 | /// operands. |
| 404 | /// |
| 405 | /// Example: |
| 406 | /// |
| 407 | /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1> |
| 408 | /// %result = vector.mask %mask { |
| 409 | /// vector.outerproduct %vecA, %vecB |
| 410 | /// : vector<[4]xf32>, vector<[4]xf32> |
| 411 | /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> |
| 412 | /// |
| 413 | /// is converted to: |
| 414 | /// |
| 415 | /// %maskA = vector.create_mask %dimA : vector<[4]xi1> |
| 416 | /// %maskB = vector.create_mask %dimB : vector<[4]xi1> |
| 417 | /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) |
| 418 | /// : vector<[4]xf32>, vector<[4]xf32> |
| 419 | /// |
| 420 | /// Unmasked outerproducts can be directly replaced with the arm_sme op. |
| 421 | /// |
| 422 | /// Example: |
| 423 | /// |
| 424 | /// %result = vector.outerproduct %vecA, %vecB |
| 425 | /// : vector<[4]xf32>, vector<[4]xf32> |
| 426 | /// |
| 427 | /// is converted to: |
| 428 | /// |
| 429 | /// %result = arm_sme.outerproduct %vecA, %vecB |
| 430 | /// : vector<[4]xf32>, vector<[4]xf32> |
| 431 | /// |
| 432 | struct VectorOuterProductToArmSMELowering |
| 433 | : public OpRewritePattern<vector::OuterProductOp> { |
| 434 | |
| 435 | using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern; |
| 436 | |
| 437 | LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp, |
| 438 | PatternRewriter &rewriter) const override { |
| 439 | |
| 440 | // We don't yet support lowering AXPY operations to SME. These could be |
| 441 | // lowered by masking out all but the first element of the LHS. |
| 442 | if (!isa<VectorType>(outerProductOp.getOperandTypeRHS())) |
| 443 | return rewriter.notifyMatchFailure(outerProductOp, |
| 444 | "AXPY operations not supported" ); |
| 445 | |
| 446 | if (!arm_sme::isValidSMETileVectorType( |
| 447 | outerProductOp.getResultVectorType())) |
| 448 | return rewriter.notifyMatchFailure( |
| 449 | outerProductOp, "outer product does not fit into SME tile" ); |
| 450 | |
| 451 | auto kind = outerProductOp.getKind(); |
| 452 | if (kind != vector::CombiningKind::ADD) |
| 453 | return rewriter.notifyMatchFailure( |
| 454 | outerProductOp, |
| 455 | "unsupported kind (lowering to SME only supports ADD at the moment)" ); |
| 456 | |
| 457 | Value lhsMask = {}; |
| 458 | Value rhsMask = {}; |
| 459 | Operation *rootOp = outerProductOp; |
| 460 | auto loc = outerProductOp.getLoc(); |
| 461 | if (outerProductOp.isMasked()) { |
| 462 | auto maskOp = outerProductOp.getMaskingOp(); |
| 463 | rewriter.setInsertionPoint(maskOp); |
| 464 | rootOp = maskOp; |
| 465 | auto operandMasks = decomposeResultMask(loc: loc, mask: maskOp.getMask(), rewriter); |
| 466 | if (failed(operandMasks)) |
| 467 | return failure(); |
| 468 | std::tie(args&: lhsMask, args&: rhsMask) = *operandMasks; |
| 469 | } |
| 470 | |
| 471 | rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>( |
| 472 | rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(), |
| 473 | outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc()); |
| 474 | |
| 475 | return success(); |
| 476 | } |
| 477 | |
| 478 | static FailureOr<std::pair<Value, Value>> |
| 479 | decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) { |
| 480 | // Attempt to extract masks from vector.create_mask. |
| 481 | // TODO: Add support for other mask sources. |
| 482 | auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>(); |
| 483 | if (!createMaskOp) |
| 484 | return failure(); |
| 485 | |
| 486 | auto maskType = createMaskOp.getVectorType(); |
| 487 | Value lhsMaskDim = createMaskOp.getOperand(0); |
| 488 | Value rhsMaskDim = createMaskOp.getOperand(1); |
| 489 | |
| 490 | VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); |
| 491 | Value lhsMask = |
| 492 | rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim); |
| 493 | Value rhsMask = |
| 494 | rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim); |
| 495 | |
| 496 | return std::make_pair(x&: lhsMask, y&: rhsMask); |
| 497 | } |
| 498 | }; |
| 499 | |
| 500 | /// Lower `vector.extract` using `arm_sme.extract_tile_slice`. |
| 501 | /// |
| 502 | /// Example: |
| 503 | /// ``` |
| 504 | /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32> |
| 505 | /// ``` |
| 506 | /// Becomes: |
| 507 | /// ``` |
| 508 | /// %slice = arm_sme.extract_tile_slice %tile[%row] |
| 509 | /// : vector<[4]xi32> from vector<[4]x[4]xi32> |
| 510 | /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32> |
| 511 | /// ``` |
| 512 | struct |
| 513 | : public OpRewritePattern<vector::ExtractOp> { |
| 514 | using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
| 515 | |
| 516 | LogicalResult matchAndRewrite(vector::ExtractOp , |
| 517 | PatternRewriter &rewriter) const override { |
| 518 | VectorType sourceType = extractOp.getSourceVectorType(); |
| 519 | if (!arm_sme::isValidSMETileVectorType(sourceType)) |
| 520 | return failure(); |
| 521 | |
| 522 | auto loc = extractOp.getLoc(); |
| 523 | auto position = extractOp.getMixedPosition(); |
| 524 | |
| 525 | Value sourceVector = extractOp.getVector(); |
| 526 | |
| 527 | // Extract entire vector. Should be handled by folder, but just to be safe. |
| 528 | if (position.empty()) { |
| 529 | rewriter.replaceOp(extractOp, sourceVector); |
| 530 | return success(); |
| 531 | } |
| 532 | |
| 533 | Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front(); |
| 534 | auto = rewriter.create<arm_sme::ExtractTileSliceOp>( |
| 535 | loc, sourceVector, sliceIndex); |
| 536 | |
| 537 | if (position.size() == 1) { |
| 538 | // Single index case: Extracts a 1D slice. |
| 539 | rewriter.replaceOp(extractOp, extractTileSlice); |
| 540 | return success(); |
| 541 | } |
| 542 | |
| 543 | // Two indices case: Extracts a single element. |
| 544 | assert(position.size() == 2); |
| 545 | rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice, |
| 546 | position[1]); |
| 547 | |
| 548 | return success(); |
| 549 | } |
| 550 | }; |
| 551 | |
| 552 | /// Lower `vector.insert` using `arm_sme.insert_tile_slice` and |
| 553 | /// `arm_sme.extract_tile_slice`. |
| 554 | /// |
| 555 | /// Example: |
| 556 | /// ``` |
| 557 | /// %new_tile = vector.insert %el, %tile[%row, %col] |
| 558 | /// : i32 into vector<[4]x[4]xi32> |
| 559 | /// ``` |
| 560 | /// Becomes: |
| 561 | /// ``` |
| 562 | /// %slice = arm_sme.extract_tile_slice %tile[%row] |
| 563 | /// : vector<[4]xi32> from vector<[4]x[4]xi32> |
| 564 | /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32> |
| 565 | /// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row] |
| 566 | /// : vector<[4]xi32> into vector<[4]x[4]xi32> |
| 567 | /// ``` |
| 568 | struct VectorInsertToArmSMELowering |
| 569 | : public OpRewritePattern<vector::InsertOp> { |
| 570 | using OpRewritePattern<vector::InsertOp>::OpRewritePattern; |
| 571 | |
| 572 | LogicalResult matchAndRewrite(vector::InsertOp insertOp, |
| 573 | PatternRewriter &rewriter) const override { |
| 574 | VectorType resultType = insertOp.getResult().getType(); |
| 575 | |
| 576 | if (!arm_sme::isValidSMETileVectorType(resultType)) |
| 577 | return failure(); |
| 578 | |
| 579 | auto loc = insertOp.getLoc(); |
| 580 | auto position = insertOp.getMixedPosition(); |
| 581 | |
| 582 | Value source = insertOp.getValueToStore(); |
| 583 | |
| 584 | // Overwrite entire vector with value. Should be handled by folder, but |
| 585 | // just to be safe. |
| 586 | if (position.empty()) { |
| 587 | rewriter.replaceOp(insertOp, source); |
| 588 | return success(); |
| 589 | } |
| 590 | |
| 591 | Value tileSlice = source; |
| 592 | Value sliceIndex = vector::getAsValues(builder&: rewriter, loc: loc, foldResults: position[0]).front(); |
| 593 | if (position.size() == 2) { |
| 594 | // Two indices case: Insert single element into tile. |
| 595 | // We need to first extract the existing slice and update the element. |
| 596 | tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( |
| 597 | loc, insertOp.getDest(), sliceIndex); |
| 598 | tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice, |
| 599 | position[1]); |
| 600 | } |
| 601 | |
| 602 | // Insert the slice into the destination tile. |
| 603 | rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>( |
| 604 | insertOp, tileSlice, insertOp.getDest(), sliceIndex); |
| 605 | return success(); |
| 606 | } |
| 607 | }; |
| 608 | |
| 609 | /// Lowers `vector.print` of a tile into a loop over the rows of the tile, |
| 610 | /// extracting them via `arm_sme.extract_tile_slice`, then printing with |
| 611 | /// a 1D `vector.print`. |
| 612 | /// |
| 613 | /// BEFORE: |
| 614 | /// ```mlir |
| 615 | /// vector.print %tile : vector<[4]x[4]xf32> |
| 616 | /// ``` |
| 617 | /// AFTER: |
| 618 | /// ```mlir |
| 619 | /// %c0 = arith.constant 0 : index |
| 620 | /// %c1 = arith.constant 1 : index |
| 621 | /// %c4 = arith.constant 4 : index |
| 622 | /// %vscale = vector.vscale |
| 623 | /// %svl_s = arith.muli %c4, %vscale : index |
| 624 | /// scf.for %i = %c0 to %svl_s step %c1 { |
| 625 | /// %tile_slice = arm_sme.extract_tile_slice %tile[%i] |
| 626 | /// : vector<[4]xf32> from vector<[4]x[4]xf32> |
| 627 | /// vector.print %tile_slice : vector<[4]xf32> |
| 628 | /// } |
| 629 | /// ``` |
| 630 | struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> { |
| 631 | using OpRewritePattern<vector::PrintOp>::OpRewritePattern; |
| 632 | |
| 633 | LogicalResult matchAndRewrite(vector::PrintOp printOp, |
| 634 | PatternRewriter &rewriter) const override { |
| 635 | if (!printOp.getSource()) |
| 636 | return failure(); |
| 637 | |
| 638 | VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType()); |
| 639 | if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType)) |
| 640 | return failure(); |
| 641 | |
| 642 | auto loc = printOp.getLoc(); |
| 643 | |
| 644 | // Create a loop over the rows of the tile. |
| 645 | auto vscale = rewriter.create<vector::VectorScaleOp>(loc); |
| 646 | auto minTileRows = |
| 647 | rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0)); |
| 648 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 649 | auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale); |
| 650 | auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
| 651 | auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); |
| 652 | { |
| 653 | // Loop body. |
| 654 | rewriter.setInsertionPointToStart(forOp.getBody()); |
| 655 | // Extract the current row from the tile. |
| 656 | Value rowIndex = forOp.getInductionVar(); |
| 657 | auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( |
| 658 | loc, printOp.getSource(), rowIndex); |
| 659 | // Print the row with a 1D vector.print. |
| 660 | rewriter.create<vector::PrintOp>(loc, tileSlice, |
| 661 | printOp.getPunctuation()); |
| 662 | } |
| 663 | |
| 664 | rewriter.eraseOp(op: printOp); |
| 665 | return success(); |
| 666 | } |
| 667 | }; |
| 668 | |
| 669 | /// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp. |
| 670 | /// |
| 671 | /// BEFORE: |
| 672 | /// ```mlir |
| 673 | /// %slice = arm_sme.extract_tile_slice %tile[%index] |
| 674 | /// : vector<[4]xf32> from vector<[4]x[4]xf32> |
| 675 | /// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]} |
| 676 | /// : vector<[4]xf32>, memref<?x?xf32> |
| 677 | /// ``` |
| 678 | /// AFTER: |
| 679 | /// ```mlir |
| 680 | /// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j] |
| 681 | /// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32> |
| 682 | /// ``` |
| 683 | struct |
| 684 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 685 | using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; |
| 686 | |
| 687 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
| 688 | PatternRewriter &rewriter) const final { |
| 689 | if (!isa<MemRefType>(writeOp.getBase().getType())) |
| 690 | return rewriter.notifyMatchFailure(writeOp, "destination not a memref" ); |
| 691 | |
| 692 | if (writeOp.hasOutOfBoundsDim()) |
| 693 | return rewriter.notifyMatchFailure(writeOp, |
| 694 | "not inbounds transfer write" ); |
| 695 | |
| 696 | auto = |
| 697 | writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>(); |
| 698 | if (!extractTileSlice) |
| 699 | return rewriter.notifyMatchFailure( |
| 700 | writeOp, "vector to store not from ExtractTileSliceOp" ); |
| 701 | |
| 702 | AffineMap map = writeOp.getPermutationMap(); |
| 703 | if (!map.isMinorIdentity()) |
| 704 | return rewriter.notifyMatchFailure(writeOp, |
| 705 | "unsupported permutation map" ); |
| 706 | |
| 707 | Value mask = writeOp.getMask(); |
| 708 | if (!mask) { |
| 709 | auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type()); |
| 710 | mask = rewriter.create<arith::ConstantOp>( |
| 711 | writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true)); |
| 712 | } |
| 713 | |
| 714 | rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>( |
| 715 | writeOp, extractTileSlice.getTile(), |
| 716 | extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(), |
| 717 | writeOp.getIndices(), extractTileSlice.getLayout()); |
| 718 | return success(); |
| 719 | } |
| 720 | }; |
| 721 | |
| 722 | /// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to |
| 723 | /// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or |
| 724 | /// SVE 2.1), so this is currently the most logical place for this lowering. |
| 725 | /// |
| 726 | /// Example: |
| 727 | /// ```mlir |
| 728 | /// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> |
| 729 | /// %slice = vector.extract %mask[%index] |
| 730 | /// : vector<[8]xi1> from vector<[4]x[8]xi1> |
| 731 | /// ``` |
| 732 | /// Becomes: |
| 733 | /// ``` |
| 734 | /// %mask_rows = vector.create_mask %a : vector<[4]xi1> |
| 735 | /// %mask_cols = vector.create_mask %b : vector<[8]xi1> |
| 736 | /// %slice = arm_sve.psel %mask_cols, %mask_rows[%index] |
| 737 | /// : vector<[8]xi1>, vector<[4]xi1> |
| 738 | /// ``` |
| 739 | struct |
| 740 | : public OpRewritePattern<vector::ExtractOp> { |
| 741 | using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
| 742 | |
| 743 | LogicalResult matchAndRewrite(vector::ExtractOp , |
| 744 | PatternRewriter &rewriter) const override { |
| 745 | if (extractOp.getNumIndices() != 1) |
| 746 | return rewriter.notifyMatchFailure(extractOp, "not single extract index" ); |
| 747 | |
| 748 | auto resultType = extractOp.getResult().getType(); |
| 749 | auto resultVectorType = dyn_cast<VectorType>(resultType); |
| 750 | if (!resultVectorType) |
| 751 | return rewriter.notifyMatchFailure(extractOp, "result not VectorType" ); |
| 752 | |
| 753 | auto createMaskOp = |
| 754 | extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); |
| 755 | if (!createMaskOp) |
| 756 | return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp" ); |
| 757 | |
| 758 | auto maskType = createMaskOp.getVectorType(); |
| 759 | if (maskType.getRank() != 2 || !maskType.allDimsScalable()) |
| 760 | return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask" ); |
| 761 | |
| 762 | auto isSVEPredicateSize = [](int64_t size) { |
| 763 | return size > 0 && size <= 16 && llvm::isPowerOf2_32(Value: uint32_t(size)); |
| 764 | }; |
| 765 | |
| 766 | auto rowsBaseSize = maskType.getDimSize(0); |
| 767 | auto colsBaseSize = maskType.getDimSize(1); |
| 768 | if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize)) |
| 769 | return rewriter.notifyMatchFailure( |
| 770 | createMaskOp, "mask dimensions not SVE predicate-sized" ); |
| 771 | |
| 772 | auto loc = extractOp.getLoc(); |
| 773 | VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1); |
| 774 | VectorType colMaskType = VectorType::Builder(maskType).dropDim(0); |
| 775 | |
| 776 | // Create the two 1-D masks at the location of the 2-D create_mask (which is |
| 777 | // usually outside a loop). This prevents the need for later hoisting. |
| 778 | rewriter.setInsertionPoint(createMaskOp); |
| 779 | auto rowMask = rewriter.create<vector::CreateMaskOp>( |
| 780 | loc, rowMaskType, createMaskOp.getOperand(0)); |
| 781 | auto colMask = rewriter.create<vector::CreateMaskOp>( |
| 782 | loc, colMaskType, createMaskOp.getOperand(1)); |
| 783 | |
| 784 | rewriter.setInsertionPoint(extractOp); |
| 785 | auto position = |
| 786 | vector::getAsValues(builder&: rewriter, loc: loc, foldResults: extractOp.getMixedPosition()); |
| 787 | rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask, |
| 788 | position[0]); |
| 789 | return success(); |
| 790 | } |
| 791 | }; |
| 792 | |
| 793 | } // namespace |
| 794 | |
| 795 | void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, |
| 796 | MLIRContext &ctx) { |
| 797 | patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, |
| 798 | TransferReadToArmSMELowering, TransferWriteToArmSMELowering, |
| 799 | TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, |
| 800 | VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, |
| 801 | VectorExtractToArmSMELowering, VectorInsertToArmSMELowering, |
| 802 | VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice, |
| 803 | ExtractFromCreateMaskToPselLowering>(arg: &ctx); |
| 804 | } |
| 805 | |