| 1 | //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the linalg dialect Vectorization transformations. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | #include "mlir/Dialect/Affine/Utils.h" |
| 13 | |
| 14 | #include "mlir/Analysis/SliceAnalysis.h" |
| 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 19 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 20 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 22 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
| 23 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 24 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 25 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 26 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" |
| 27 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 28 | #include "mlir/IR/AffineExpr.h" |
| 29 | #include "mlir/IR/Builders.h" |
| 30 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
| 31 | #include "mlir/IR/BuiltinTypes.h" |
| 32 | #include "mlir/IR/OpDefinition.h" |
| 33 | #include "mlir/IR/PatternMatch.h" |
| 34 | #include "mlir/Support/LLVM.h" |
| 35 | #include "mlir/Transforms/RegionUtils.h" |
| 36 | #include "llvm/ADT/STLExtras.h" |
| 37 | #include "llvm/ADT/Sequence.h" |
| 38 | #include "llvm/ADT/SmallVector.h" |
| 39 | #include "llvm/ADT/TypeSwitch.h" |
| 40 | #include "llvm/ADT/iterator_range.h" |
| 41 | #include "llvm/Support/Debug.h" |
| 42 | #include "llvm/Support/MathExtras.h" |
| 43 | #include "llvm/Support/raw_ostream.h" |
| 44 | #include <optional> |
| 45 | #include <type_traits> |
| 46 | |
| 47 | using namespace mlir; |
| 48 | using namespace mlir::linalg; |
| 49 | |
| 50 | #define DEBUG_TYPE "linalg-vectorization" |
| 51 | |
| 52 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
| 53 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
| 54 | |
| 55 | /// Try to vectorize `convOp` as a convolution. |
| 56 | static FailureOr<Operation *> |
| 57 | vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, |
| 58 | ArrayRef<int64_t> inputVecSizes = {}, |
| 59 | ArrayRef<bool> inputVecScalableFlags = {}, |
| 60 | bool flatten1DDepthwiseConv = false); |
| 61 | |
| 62 | /// Vectorize tensor::InsertSliceOp with: |
| 63 | /// * vector::TransferReadOp + vector::TransferWriteOp |
| 64 | /// The vector sizes are either: |
| 65 | /// * user-provided in `inputVectorSizes`, or |
| 66 | /// * inferred from the static dims in the input and output tensors. |
| 67 | /// Bails out if: |
| 68 | /// * vector sizes are not user-provided, and |
| 69 | /// * at least one dim is dynamic (in both the input and output tensors). |
| 70 | /// |
| 71 | /// Before: |
| 72 | /// !t_in_type = tensor<1x2x3xf32> |
| 73 | /// !t_out_type = tensor<9x8x7x1x2x3xf32> |
| 74 | /// !v_type = vector<1x2x3xf32> |
| 75 | /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type |
| 76 | /// into !t_out_type |
| 77 | /// After: |
| 78 | /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type |
| 79 | /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type |
| 80 | static LogicalResult |
| 81 | vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, |
| 82 | ArrayRef<int64_t> inputVectorSizes, |
| 83 | SmallVectorImpl<Value> &newResults); |
| 84 | |
| 85 | /// Returns the effective Pad value for the input op, provided it's a scalar. |
| 86 | /// |
| 87 | /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If |
| 88 | /// this Op performs padding, retrieve the padding value provided that it's |
| 89 | /// a scalar and static/fixed for all the padded values. Returns an empty value |
| 90 | /// otherwise. |
| 91 | static Value getStaticPadVal(Operation *op); |
| 92 | |
| 93 | /// Return the unique instance of OpType in `block` if it is indeed unique. |
| 94 | /// Return null if none or more than 1 instances exist. |
| 95 | template <typename OpType> |
| 96 | static OpType getSingleOpOfType(Block &block) { |
| 97 | OpType res; |
| 98 | block.walk([&](OpType op) { |
| 99 | if (res) { |
| 100 | res = nullptr; |
| 101 | return WalkResult::interrupt(); |
| 102 | } |
| 103 | res = op; |
| 104 | return WalkResult::advance(); |
| 105 | }); |
| 106 | return res; |
| 107 | } |
| 108 | |
| 109 | /// Helper function to extract the input slices after filter is unrolled along |
| 110 | /// kw. |
| 111 | static SmallVector<Value> |
| 112 | (RewriterBase &rewriter, Location loc, Value input, |
| 113 | int64_t nSize, int64_t wSize, int64_t cSize, |
| 114 | int64_t kwSize, int strideW, int dilationW, |
| 115 | int64_t wSizeStep, bool isSingleChanneled) { |
| 116 | SmallVector<Value> result; |
| 117 | if (isSingleChanneled) { |
| 118 | // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled |
| 119 | // convolution. |
| 120 | SmallVector<int64_t> sizes = {wSizeStep}; |
| 121 | SmallVector<int64_t> strides = {1}; |
| 122 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 123 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 124 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
| 125 | loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides)); |
| 126 | } |
| 127 | } |
| 128 | } else { |
| 129 | // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0] |
| 130 | // for channeled convolution. |
| 131 | SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize}; |
| 132 | SmallVector<int64_t> strides = {1, 1, 1}; |
| 133 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 134 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 135 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
| 136 | loc, input, |
| 137 | /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, |
| 138 | sizes, strides)); |
| 139 | } |
| 140 | } |
| 141 | } |
| 142 | return result; |
| 143 | } |
| 144 | |
| 145 | /// Helper function to extract the filter slices after filter is unrolled along |
| 146 | /// kw. |
| 147 | static SmallVector<Value> (RewriterBase &rewriter, |
| 148 | Location loc, Value filter, |
| 149 | int64_t kwSize) { |
| 150 | SmallVector<Value> result; |
| 151 | // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for |
| 152 | // non-chanelled convolution] @ [kw]. |
| 153 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 154 | result.push_back(rewriter.create<vector::ExtractOp>( |
| 155 | loc, filter, /*offsets=*/ArrayRef<int64_t>{kw})); |
| 156 | } |
| 157 | return result; |
| 158 | } |
| 159 | |
| 160 | /// Helper function to extract the result slices after filter is unrolled along |
| 161 | /// kw. |
| 162 | static SmallVector<Value> |
| 163 | (RewriterBase &rewriter, Location loc, Value res, |
| 164 | int64_t nSize, int64_t wSize, int64_t fSize, |
| 165 | int64_t wSizeStep, bool isSingleChanneled) { |
| 166 | SmallVector<Value> result; |
| 167 | if (isSingleChanneled) { |
| 168 | // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution. |
| 169 | SmallVector<int64_t> sizes = {wSizeStep}; |
| 170 | SmallVector<int64_t> strides = {1}; |
| 171 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 172 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
| 173 | loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides)); |
| 174 | } |
| 175 | } else { |
| 176 | // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled |
| 177 | // convolution. |
| 178 | SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize}; |
| 179 | SmallVector<int64_t> strides = {1, 1, 1}; |
| 180 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 181 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
| 182 | loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides)); |
| 183 | } |
| 184 | } |
| 185 | return result; |
| 186 | } |
| 187 | |
| 188 | /// Helper function to insert the computed result slices. |
| 189 | static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, |
| 190 | Value res, int64_t wSize, int64_t wSizeStep, |
| 191 | SmallVectorImpl<Value> &resVals, |
| 192 | bool isSingleChanneled) { |
| 193 | |
| 194 | if (isSingleChanneled) { |
| 195 | // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution. |
| 196 | // This does not depend on kw. |
| 197 | SmallVector<int64_t> strides = {1}; |
| 198 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 199 | res = rewriter.create<vector::InsertStridedSliceOp>( |
| 200 | loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides); |
| 201 | } |
| 202 | } else { |
| 203 | // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled |
| 204 | // convolution. This does not depend on kw. |
| 205 | SmallVector<int64_t> strides = {1, 1, 1}; |
| 206 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 207 | res = rewriter.create<vector::InsertStridedSliceOp>( |
| 208 | loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, |
| 209 | strides); |
| 210 | } |
| 211 | } |
| 212 | return res; |
| 213 | } |
| 214 | |
| 215 | /// Contains the vectorization state and related methods used across the |
| 216 | /// vectorization process of a given operation. |
| 217 | struct VectorizationState { |
| 218 | VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {} |
| 219 | |
| 220 | /// Initializes the vectorization state, including the computation of the |
| 221 | /// canonical vector shape for vectorization. |
| 222 | LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, |
| 223 | ArrayRef<int64_t> inputVectorSizes, |
| 224 | ArrayRef<bool> inputScalableVecDims); |
| 225 | |
| 226 | /// Returns the canonical vector shape used to vectorize the iteration space. |
| 227 | ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; } |
| 228 | |
| 229 | /// Returns the vector dimensions that are scalable in the canonical vector |
| 230 | /// shape. |
| 231 | ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; } |
| 232 | |
| 233 | /// Returns a vector type of the provided `elementType` with the canonical |
| 234 | /// vector shape and the corresponding fixed/scalable dimensions bit. If |
| 235 | /// `dimPermutation` is provided, the canonical vector dimensions are permuted |
| 236 | /// accordingly. |
| 237 | VectorType getCanonicalVecType( |
| 238 | Type elementType, |
| 239 | std::optional<AffineMap> dimPermutation = std::nullopt) const { |
| 240 | SmallVector<int64_t> vectorShape; |
| 241 | SmallVector<bool> scalableDims; |
| 242 | if (dimPermutation.has_value()) { |
| 243 | vectorShape = |
| 244 | applyPermutationMap<int64_t>(map: *dimPermutation, source: canonicalVecShape); |
| 245 | scalableDims = |
| 246 | applyPermutationMap<bool>(map: *dimPermutation, source: scalableVecDims); |
| 247 | } else { |
| 248 | vectorShape.append(in_start: canonicalVecShape.begin(), in_end: canonicalVecShape.end()); |
| 249 | scalableDims.append(in_start: scalableVecDims.begin(), in_end: scalableVecDims.end()); |
| 250 | } |
| 251 | |
| 252 | return VectorType::get(vectorShape, elementType, scalableDims); |
| 253 | } |
| 254 | |
| 255 | /// Masks an operation with the canonical vector mask if the operation needs |
| 256 | /// masking. Returns the masked operation or the original operation if masking |
| 257 | /// is not needed. If provided, the canonical mask for this operation is |
| 258 | /// permuted using `maybeIndexingMap`. |
| 259 | Operation * |
| 260 | maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, |
| 261 | std::optional<AffineMap> maybeIndexingMap = std::nullopt); |
| 262 | |
| 263 | private: |
| 264 | /// Initializes the iteration space static sizes using the Linalg op |
| 265 | /// information. This may become more complicated in the future. |
| 266 | void initIterSpaceStaticSizes(LinalgOp linalgOp) { |
| 267 | iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges()); |
| 268 | } |
| 269 | |
| 270 | /// Generates 'arith.constant' and 'tensor/memref.dim' operations for |
| 271 | /// all the static and dynamic dimensions of the iteration space to be |
| 272 | /// vectorized and store them in `iterSpaceValueSizes`. |
| 273 | LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter, |
| 274 | LinalgOp linalgOp); |
| 275 | |
| 276 | /// Create or retrieve an existing mask value to mask `opToMask` in the |
| 277 | /// canonical vector iteration space. If `maybeMaskingMap` the mask is |
| 278 | /// permuted using that permutation map. If a new mask is created, it will be |
| 279 | /// cached for future users. |
| 280 | Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask, |
| 281 | LinalgOp linalgOp, |
| 282 | std::optional<AffineMap> maybeMaskingMap); |
| 283 | |
| 284 | /// Check whether this permutation map can be used for masking. At the |
| 285 | /// moment we only make sure that there are no broadcast dimensions, but this |
| 286 | /// might change if indexing maps evolve. |
| 287 | bool isValidMaskingMap(AffineMap maskingMap) { |
| 288 | return maskingMap.getBroadcastDims().size() == 0; |
| 289 | } |
| 290 | |
| 291 | /// Turn the input indexing map into a valid masking map. |
| 292 | /// |
| 293 | /// The input indexing map may contain "zero" results, e.g.: |
| 294 | /// (d0, d1, d2, d3) -> (d2, d1, d0, 0) |
| 295 | /// Applying such maps to canonical vector shapes like this one: |
| 296 | /// (1, 16, 16, 4) |
| 297 | /// would yield an invalid vector shape like this: |
| 298 | /// (16, 16, 1, 0) |
| 299 | /// Instead, drop the broadcasting dims that make no sense for masking perm. |
| 300 | /// maps: |
| 301 | /// (d0, d1, d2, d3) -> (d2, d1, d0) |
| 302 | /// This way, the corresponding vector/mask type will be: |
| 303 | /// vector<16x16x1xty> |
| 304 | /// rather than this invalid Vector type: |
| 305 | /// vector<16x16x1x0xty> |
| 306 | AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) { |
| 307 | return indexingMap.dropZeroResults(); |
| 308 | } |
| 309 | |
| 310 | // Holds the compile-time static sizes of the iteration space to vectorize. |
| 311 | // Dynamic dimensions are represented using ShapedType::kDynamic. |
| 312 | SmallVector<int64_t> iterSpaceStaticSizes; |
| 313 | |
| 314 | /// Holds the value sizes of the iteration space to vectorize. Static |
| 315 | /// dimensions are represented by 'arith.constant' and dynamic |
| 316 | /// dimensions by 'tensor/memref.dim'. |
| 317 | SmallVector<Value> iterSpaceValueSizes; |
| 318 | |
| 319 | /// Holds the canonical vector shape used to vectorize the iteration space. |
| 320 | SmallVector<int64_t> canonicalVecShape; |
| 321 | |
| 322 | /// Holds the vector dimensions that are scalable in the canonical vector |
| 323 | /// shape. |
| 324 | SmallVector<bool> scalableVecDims; |
| 325 | |
| 326 | /// Holds the active masks for permutations of the canonical vector iteration |
| 327 | /// space. |
| 328 | DenseMap<AffineMap, Value> activeMaskCache; |
| 329 | |
| 330 | /// Global vectorization guard for the incoming rewriter. It's initialized |
| 331 | /// when the vectorization state is initialized. |
| 332 | OpBuilder::InsertionGuard rewriterGuard; |
| 333 | }; |
| 334 | |
| 335 | LogicalResult |
| 336 | VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, |
| 337 | LinalgOp linalgOp) { |
| 338 | // TODO: Support 0-d vectors. |
| 339 | for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { |
| 340 | if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) { |
| 341 | // Create constant index op for static dimensions. |
| 342 | iterSpaceValueSizes.push_back(Elt: rewriter.create<arith::ConstantIndexOp>( |
| 343 | linalgOp.getLoc(), iterSpaceStaticSizes[vecDim])); |
| 344 | continue; |
| 345 | } |
| 346 | |
| 347 | // Find an operand defined on this dimension of the iteration space to |
| 348 | // extract the runtime dimension size. |
| 349 | Value operand; |
| 350 | unsigned operandDimPos; |
| 351 | if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand, |
| 352 | operandDimPos))) |
| 353 | return failure(); |
| 354 | |
| 355 | Value dynamicDim = linalgOp.hasPureTensorSemantics() |
| 356 | ? (Value)rewriter.create<tensor::DimOp>( |
| 357 | linalgOp.getLoc(), operand, operandDimPos) |
| 358 | : (Value)rewriter.create<memref::DimOp>( |
| 359 | linalgOp.getLoc(), operand, operandDimPos); |
| 360 | iterSpaceValueSizes.push_back(Elt: dynamicDim); |
| 361 | } |
| 362 | |
| 363 | return success(); |
| 364 | } |
| 365 | |
| 366 | /// Initializes the vectorization state, including the computation of the |
| 367 | /// canonical vector shape for vectorization. |
| 368 | // TODO: Move this to the constructor when we can remove the failure cases. |
| 369 | LogicalResult |
| 370 | VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp, |
| 371 | ArrayRef<int64_t> inputVectorSizes, |
| 372 | ArrayRef<bool> inputScalableVecDims) { |
| 373 | // Initialize the insertion point. |
| 374 | rewriter.setInsertionPoint(linalgOp); |
| 375 | |
| 376 | if (!inputVectorSizes.empty()) { |
| 377 | // Get the canonical vector shape from the input vector sizes provided. This |
| 378 | // path should be taken to vectorize code with dynamic shapes and when using |
| 379 | // vector sizes greater than the iteration space sizes. |
| 380 | canonicalVecShape.append(in_start: inputVectorSizes.begin(), in_end: inputVectorSizes.end()); |
| 381 | scalableVecDims.append(in_start: inputScalableVecDims.begin(), |
| 382 | in_end: inputScalableVecDims.end()); |
| 383 | } else { |
| 384 | // Compute the canonical vector shape from the operation shape. If there are |
| 385 | // dynamic shapes, the operation won't be vectorized. We assume all the |
| 386 | // vector dimensions are fixed. |
| 387 | canonicalVecShape = linalgOp.getStaticLoopRanges(); |
| 388 | scalableVecDims.append(linalgOp.getNumLoops(), false); |
| 389 | } |
| 390 | |
| 391 | LDBG("Canonical vector shape: " ); |
| 392 | LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs())); |
| 393 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 394 | LDBG("Scalable vector dims: " ); |
| 395 | LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs())); |
| 396 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 397 | |
| 398 | if (ShapedType::isDynamicShape(canonicalVecShape)) |
| 399 | return failure(); |
| 400 | |
| 401 | // Initialize iteration space static sizes. |
| 402 | initIterSpaceStaticSizes(linalgOp: linalgOp); |
| 403 | |
| 404 | // Generate 'arith.constant' and 'tensor/memref.dim' operations for |
| 405 | // all the static and dynamic dimensions of the iteration space, needed to |
| 406 | // compute a mask during vectorization. |
| 407 | if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp: linalgOp))) |
| 408 | return failure(); |
| 409 | |
| 410 | return success(); |
| 411 | } |
| 412 | |
| 413 | /// Create or retrieve an existing mask value to mask `opToMask` in the |
| 414 | /// canonical vector iteration space. If `maybeMaskingMap` the mask is permuted |
| 415 | /// using that permutation map. If a new mask is created, it will be cached for |
| 416 | /// future users. |
| 417 | Value VectorizationState::getOrCreateMaskFor( |
| 418 | RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, |
| 419 | std::optional<AffineMap> maybeMaskingMap) { |
| 420 | |
| 421 | assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) && |
| 422 | "Ill-formed masking map." ); |
| 423 | |
| 424 | // No mask is needed if the operation is not maskable. |
| 425 | auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask); |
| 426 | if (!maskableOp) |
| 427 | return Value(); |
| 428 | |
| 429 | assert(!maskableOp.isMasked() && |
| 430 | "Masking an operation that is already masked" ); |
| 431 | |
| 432 | // If no masking map was provided, use an identity map with the loop dims. |
| 433 | assert((!maybeMaskingMap || *maybeMaskingMap) && |
| 434 | "Unexpected null mask permutation map" ); |
| 435 | AffineMap maskingMap = |
| 436 | maybeMaskingMap ? *maybeMaskingMap |
| 437 | : AffineMap::getMultiDimIdentityMap( |
| 438 | numDims: linalgOp.getNumLoops(), context: rewriter.getContext()); |
| 439 | |
| 440 | LDBG("Masking map: " << maskingMap << "\n" ); |
| 441 | |
| 442 | // Return the active mask for the masking map of this operation if it was |
| 443 | // already created. |
| 444 | auto activeMaskIt = activeMaskCache.find(Val: maskingMap); |
| 445 | if (activeMaskIt != activeMaskCache.end()) { |
| 446 | Value mask = activeMaskIt->second; |
| 447 | LDBG("Reusing mask: " << mask << "\n" ); |
| 448 | return mask; |
| 449 | } |
| 450 | |
| 451 | // Compute permuted projection of the iteration space to be masked and the |
| 452 | // corresponding mask shape. If the resulting iteration space dimensions are |
| 453 | // static and identical to the mask shape, masking is not needed for this |
| 454 | // operation. |
| 455 | // TODO: Improve this check. Only projected permutation indexing maps are |
| 456 | // supported. |
| 457 | SmallVector<int64_t> permutedStaticSizes = |
| 458 | applyPermutationMap<int64_t>(map: maskingMap, source: iterSpaceStaticSizes); |
| 459 | auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap); |
| 460 | auto maskShape = maskType.getShape(); |
| 461 | |
| 462 | LDBG("Mask shape: " ); |
| 463 | LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs())); |
| 464 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 465 | |
| 466 | if (permutedStaticSizes == maskShape) { |
| 467 | LDBG("Masking is not needed for masking map: " << maskingMap << "\n" ); |
| 468 | activeMaskCache[maskingMap] = Value(); |
| 469 | return Value(); |
| 470 | } |
| 471 | |
| 472 | // Permute the iteration space value sizes to compute the mask upper bounds. |
| 473 | SmallVector<Value> upperBounds = |
| 474 | applyPermutationMap(map: maskingMap, source: ArrayRef<Value>(iterSpaceValueSizes)); |
| 475 | assert(!maskShape.empty() && !upperBounds.empty() && |
| 476 | "Masked 0-d vectors are not supported yet" ); |
| 477 | |
| 478 | // Create the mask based on the dimension values. |
| 479 | Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(), |
| 480 | maskType, upperBounds); |
| 481 | LDBG("Creating new mask: " << mask << "\n" ); |
| 482 | activeMaskCache[maskingMap] = mask; |
| 483 | return mask; |
| 484 | } |
| 485 | |
| 486 | Operation * |
| 487 | VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, |
| 488 | LinalgOp linalgOp, |
| 489 | std::optional<AffineMap> maybeIndexingMap) { |
| 490 | LDBG("Trying to mask: " << *opToMask << "\n" ); |
| 491 | |
| 492 | std::optional<AffineMap> maybeMaskingMap = std::nullopt; |
| 493 | if (maybeIndexingMap) |
| 494 | maybeMaskingMap = getMaskingMapFromIndexingMap(indexingMap&: *maybeIndexingMap); |
| 495 | |
| 496 | // Create or retrieve mask for this operation. |
| 497 | Value mask = |
| 498 | getOrCreateMaskFor(rewriter, opToMask, linalgOp: linalgOp, maybeMaskingMap); |
| 499 | |
| 500 | if (!mask) { |
| 501 | LDBG("No mask required\n" ); |
| 502 | return opToMask; |
| 503 | } |
| 504 | |
| 505 | // Wrap the operation with a new `vector.mask` and update D-U chain. |
| 506 | assert(opToMask && "Expected a valid operation to mask" ); |
| 507 | auto maskOp = cast<vector::MaskOp>( |
| 508 | mlir::vector::maskOperation(rewriter, opToMask, mask)); |
| 509 | Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back(); |
| 510 | |
| 511 | for (auto [resIdx, resVal] : llvm::enumerate(First: opToMask->getResults())) |
| 512 | rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx), |
| 513 | maskOpTerminator); |
| 514 | |
| 515 | LDBG("Masked operation: " << *maskOp << "\n" ); |
| 516 | return maskOp; |
| 517 | } |
| 518 | |
| 519 | /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a |
| 520 | /// projectedPermutation, compress the unused dimensions to serve as a |
| 521 | /// permutation_map for a vector transfer operation. |
| 522 | /// For example, given a linalg op such as: |
| 523 | /// |
| 524 | /// ``` |
| 525 | /// %0 = linalg.generic { |
| 526 | /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>, |
| 527 | /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)> |
| 528 | /// } |
| 529 | /// ins(%0 : tensor<2x3x4xf32>) |
| 530 | /// outs(%1 : tensor<5x6xf32>) |
| 531 | /// ``` |
| 532 | /// |
| 533 | /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine |
| 534 | /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second |
| 535 | /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`. |
| 536 | static AffineMap reindexIndexingMap(AffineMap map) { |
| 537 | assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) && |
| 538 | "expected projected permutation" ); |
| 539 | auto res = compressUnusedDims(map); |
| 540 | assert(res.getNumDims() == |
| 541 | (res.getNumResults() - res.getNumOfZeroResults()) && |
| 542 | "expected reindexed map with same number of dims and results" ); |
| 543 | return res; |
| 544 | } |
| 545 | |
| 546 | /// Helper enum to represent conv1d input traversal order. |
| 547 | enum class Conv1DOpOrder { |
| 548 | W, // Corresponds to non-channeled 1D convolution operation. |
| 549 | Ncw, // Corresponds to operation that traverses the input in (n, c, w) order. |
| 550 | Nwc // Corresponds to operation that traverses the input in (n, w, c) order. |
| 551 | }; |
| 552 | |
| 553 | /// Helper data structure to represent the result of vectorization. |
| 554 | /// In certain specific cases, like terminators, we do not want to propagate/ |
| 555 | enum VectorizationStatus { |
| 556 | /// Op failed to vectorize. |
| 557 | Failure = 0, |
| 558 | /// Op vectorized and custom function took care of replacement logic |
| 559 | NoReplace, |
| 560 | /// Op vectorized into a new Op whose results will replace original Op's |
| 561 | /// results. |
| 562 | NewOp |
| 563 | // TODO: support values if Op vectorized to Many-Ops whose results we need to |
| 564 | // aggregate for replacement. |
| 565 | }; |
| 566 | struct VectorizationResult { |
| 567 | /// Return status from vectorizing the current op. |
| 568 | enum VectorizationStatus status = VectorizationStatus::Failure; |
| 569 | /// New vectorized operation to replace the current op. |
| 570 | /// Replacement behavior is specified by `status`. |
| 571 | Operation *newOp; |
| 572 | }; |
| 573 | |
| 574 | std::optional<vector::CombiningKind> |
| 575 | mlir::linalg::getCombinerOpKind(Operation *combinerOp) { |
| 576 | using ::mlir::vector::CombiningKind; |
| 577 | |
| 578 | if (!combinerOp) |
| 579 | return std::nullopt; |
| 580 | return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp) |
| 581 | .Case<arith::AddIOp, arith::AddFOp>( |
| 582 | [&](auto op) { return CombiningKind::ADD; }) |
| 583 | .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; }) |
| 584 | .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; }) |
| 585 | .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; }) |
| 586 | .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; }) |
| 587 | .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; }) |
| 588 | .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; }) |
| 589 | .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; }) |
| 590 | .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; }) |
| 591 | .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; }) |
| 592 | .Case<arith::MulIOp, arith::MulFOp>( |
| 593 | [&](auto op) { return CombiningKind::MUL; }) |
| 594 | .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; }) |
| 595 | .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; }) |
| 596 | .Default([&](auto op) { return std::nullopt; }); |
| 597 | } |
| 598 | |
| 599 | /// Check whether `outputOperand` is a reduction with a single combiner |
| 600 | /// operation. Return the combiner operation of the reduction. Return |
| 601 | /// nullptr otherwise. Multiple reduction operations would impose an |
| 602 | /// ordering between reduction dimensions and is currently unsupported in |
| 603 | /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != |
| 604 | /// max(min(X)) |
| 605 | // TODO: use in LinalgOp verification, there is a circular dependency atm. |
| 606 | static Operation *matchLinalgReduction(OpOperand *outputOperand) { |
| 607 | auto linalgOp = cast<LinalgOp>(outputOperand->getOwner()); |
| 608 | unsigned outputPos = |
| 609 | outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs(); |
| 610 | // Only single combiner operations are supported for now. |
| 611 | SmallVector<Operation *, 4> combinerOps; |
| 612 | if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || |
| 613 | combinerOps.size() != 1) |
| 614 | return nullptr; |
| 615 | |
| 616 | // Return the combiner operation. |
| 617 | return combinerOps[0]; |
| 618 | } |
| 619 | |
| 620 | /// Broadcast `value` to a vector of `shape` if possible. Return value |
| 621 | /// otherwise. |
| 622 | static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) { |
| 623 | auto dstVecType = dyn_cast<VectorType>(dstType); |
| 624 | // If no shape to broadcast to, just return `value`. |
| 625 | if (dstVecType.getRank() == 0) |
| 626 | return value; |
| 627 | if (vector::isBroadcastableTo(srcType: value.getType(), dstVectorType: dstVecType) != |
| 628 | vector::BroadcastableToResult::Success) |
| 629 | return value; |
| 630 | Location loc = b.getInsertionPoint()->getLoc(); |
| 631 | return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value); |
| 632 | } |
| 633 | |
| 634 | /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This |
| 635 | /// assumes that `reductionOp` has two operands and one of them is the reduction |
| 636 | /// initial value.buildMultiDimReduce |
| 637 | // Note: this is a true builder that notifies the OpBuilder listener. |
| 638 | // TODO: Consider moving as a static helper on the ReduceOp. |
| 639 | static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, |
| 640 | Value valueToReduce, Value acc, |
| 641 | ArrayRef<bool> dimsToMask) { |
| 642 | auto maybeKind = getCombinerOpKind(reduceOp); |
| 643 | assert(maybeKind && "Failed precondition: could not get reduction kind" ); |
| 644 | return b.create<vector::MultiDimReductionOp>( |
| 645 | reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind); |
| 646 | } |
| 647 | |
| 648 | static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) { |
| 649 | return llvm::to_vector( |
| 650 | llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator)); |
| 651 | } |
| 652 | |
| 653 | /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one |
| 654 | /// reduction iterator. |
| 655 | static bool hasReductionIterator(LinalgOp &op) { |
| 656 | return isa<linalg::ReduceOp>(op) || |
| 657 | (isa<linalg::GenericOp>(op) && |
| 658 | llvm::any_of(op.getIteratorTypesArray(), isReductionIterator)); |
| 659 | } |
| 660 | |
| 661 | /// Build a vector.transfer_write of `value` into `outputOperand` at indices set |
| 662 | /// to all `0`; where `outputOperand` is an output operand of the LinalgOp |
| 663 | /// currently being vectorized. If `dest` has null rank, build an memref.store. |
| 664 | /// Return the produced value or null if no value is produced. |
| 665 | // Note: this is a true builder that notifies the OpBuilder listener. |
| 666 | // TODO: Consider moving as a static helper on the ReduceOp. |
| 667 | static Value buildVectorWrite(RewriterBase &rewriter, Value value, |
| 668 | OpOperand *outputOperand, |
| 669 | VectorizationState &state) { |
| 670 | Location loc = value.getLoc(); |
| 671 | auto linalgOp = cast<LinalgOp>(outputOperand->getOwner()); |
| 672 | AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand); |
| 673 | |
| 674 | // Compute the vector type of the value to store. This type should be an |
| 675 | // identity or projection of the canonical vector type without any permutation |
| 676 | // applied, given that any permutation in a transfer write happens as part of |
| 677 | // the write itself. |
| 678 | AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap( |
| 679 | ctx: opOperandMap.getContext(), numDims: opOperandMap.getNumInputs(), |
| 680 | keepDimFilter: [&](AffineDimExpr dimExpr) -> bool { |
| 681 | return llvm::is_contained(Range: opOperandMap.getResults(), Element: dimExpr); |
| 682 | }); |
| 683 | auto vectorType = state.getCanonicalVecType( |
| 684 | getElementTypeOrSelf(type: outputOperand->get().getType()), vectorTypeMap); |
| 685 | |
| 686 | Operation *write; |
| 687 | if (vectorType.getRank() > 0) { |
| 688 | AffineMap writeMap = inversePermutation(map: reindexIndexingMap(map: opOperandMap)); |
| 689 | SmallVector<Value> indices(linalgOp.getRank(outputOperand), |
| 690 | rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)); |
| 691 | value = broadcastIfNeeded(rewriter, value, vectorType); |
| 692 | assert(value.getType() == vectorType && "Incorrect type" ); |
| 693 | write = rewriter.create<vector::TransferWriteOp>( |
| 694 | loc, value, outputOperand->get(), indices, writeMap); |
| 695 | } else { |
| 696 | // 0-d case is still special: do not invert the reindexing writeMap. |
| 697 | if (!isa<VectorType>(value.getType())) |
| 698 | value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value); |
| 699 | assert(value.getType() == vectorType && "Incorrect type" ); |
| 700 | write = rewriter.create<vector::TransferWriteOp>( |
| 701 | loc, value, outputOperand->get(), ValueRange{}); |
| 702 | } |
| 703 | |
| 704 | write = state.maskOperation(rewriter, opToMask: write, linalgOp: linalgOp, maybeIndexingMap: opOperandMap); |
| 705 | |
| 706 | // If masked, set in-bounds to true. Masking guarantees that the access will |
| 707 | // be in-bounds. |
| 708 | if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) { |
| 709 | auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp()); |
| 710 | SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true); |
| 711 | maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); |
| 712 | } |
| 713 | |
| 714 | LDBG("vectorized op: " << *write << "\n" ); |
| 715 | if (!write->getResults().empty()) |
| 716 | return write->getResult(idx: 0); |
| 717 | return Value(); |
| 718 | } |
| 719 | |
| 720 | // Custom vectorization precondition function type. This is intented to be used |
| 721 | // with CustomVectorizationHook. Returns success if the corresponding custom |
| 722 | // hook can vectorize the op. |
| 723 | using CustomVectorizationPrecondition = |
| 724 | std::function<LogicalResult(Operation *, bool)>; |
| 725 | |
| 726 | // Custom vectorization function type. Produce a vector form of Operation* |
| 727 | // assuming all its vectorized operands are already in the IRMapping. |
| 728 | // Return nullptr if the Operation cannot be vectorized. |
| 729 | using CustomVectorizationHook = |
| 730 | std::function<VectorizationResult(Operation *, const IRMapping &)>; |
| 731 | |
| 732 | /// Helper function to vectorize the terminator of a `linalgOp`. New result |
| 733 | /// vector values are appended to `newResults`. Return |
| 734 | /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it |
| 735 | /// should not try to map produced operations and instead return the results |
| 736 | /// using the `newResults` vector making them available to the vectorization |
| 737 | /// algorithm for RAUW. This function is meant to be used as a |
| 738 | /// CustomVectorizationHook. |
| 739 | static VectorizationResult |
| 740 | vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, |
| 741 | const IRMapping &bvm, VectorizationState &state, |
| 742 | LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) { |
| 743 | auto yieldOp = dyn_cast<linalg::YieldOp>(op); |
| 744 | if (!yieldOp) |
| 745 | return VectorizationResult{.status: VectorizationStatus::Failure, .newOp: nullptr}; |
| 746 | for (const auto &output : llvm::enumerate(yieldOp.getValues())) { |
| 747 | // TODO: Scan for an opportunity for reuse. |
| 748 | // TODO: use a map. |
| 749 | Value vectorValue = bvm.lookup(output.value()); |
| 750 | Value newResult = |
| 751 | buildVectorWrite(rewriter, vectorValue, |
| 752 | linalgOp.getDpsInitOperand(output.index()), state); |
| 753 | if (newResult) |
| 754 | newResults.push_back(newResult); |
| 755 | } |
| 756 | |
| 757 | return VectorizationResult{.status: VectorizationStatus::NoReplace, .newOp: nullptr}; |
| 758 | } |
| 759 | |
| 760 | /// Helper function to vectorize the index operations of a `linalgOp`. Return |
| 761 | /// VectorizationStatus::NewOp to signal the vectorization algorithm that it |
| 762 | /// should map the produced operations. This function is meant to be used as a |
| 763 | /// CustomVectorizationHook. |
| 764 | static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, |
| 765 | VectorizationState &state, |
| 766 | Operation *op, |
| 767 | LinalgOp linalgOp) { |
| 768 | IndexOp indexOp = dyn_cast<linalg::IndexOp>(op); |
| 769 | if (!indexOp) |
| 770 | return VectorizationResult{.status: VectorizationStatus::Failure, .newOp: nullptr}; |
| 771 | auto loc = indexOp.getLoc(); |
| 772 | // Compute the static loop sizes of the index op. |
| 773 | ArrayRef<int64_t> targetShape = state.getCanonicalVecShape(); |
| 774 | auto dim = indexOp.getDim(); |
| 775 | // Compute a one-dimensional index vector for the index op dimension. |
| 776 | auto indexVectorType = |
| 777 | VectorType::get({targetShape[dim]}, rewriter.getIndexType(), |
| 778 | state.getScalableVecDims()[dim]); |
| 779 | auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType); |
| 780 | // Return the one-dimensional index vector if it lives in the trailing |
| 781 | // dimension of the iteration space since the vectorization algorithm in this |
| 782 | // case can handle the broadcast. |
| 783 | if (dim == targetShape.size() - 1) |
| 784 | return VectorizationResult{VectorizationStatus::NewOp, indexSteps}; |
| 785 | // Otherwise permute the targetShape to move the index dimension last, |
| 786 | // broadcast the one-dimensional index vector to the permuted shape, and |
| 787 | // finally transpose the broadcasted index vector to undo the permutation. |
| 788 | auto permPattern = |
| 789 | llvm::to_vector(Range: llvm::seq<unsigned>(Begin: 0, End: targetShape.size())); |
| 790 | std::swap(permPattern[dim], permPattern.back()); |
| 791 | auto permMap = |
| 792 | AffineMap::getPermutationMap(permPattern, linalgOp.getContext()); |
| 793 | |
| 794 | auto broadCastOp = rewriter.create<vector::BroadcastOp>( |
| 795 | loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap), |
| 796 | indexSteps); |
| 797 | SmallVector<int64_t> transposition = |
| 798 | llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops())); |
| 799 | std::swap(transposition.back(), transposition[dim]); |
| 800 | auto transposeOp = |
| 801 | rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition); |
| 802 | return VectorizationResult{VectorizationStatus::NewOp, transposeOp}; |
| 803 | } |
| 804 | |
| 805 | /// Helper function to check if the tensor.extract can be vectorized by the |
| 806 | /// custom hook vectorizeTensorExtract. |
| 807 | static LogicalResult |
| 808 | (Operation *op, bool ) { |
| 809 | tensor::ExtractOp = dyn_cast<tensor::ExtractOp>(op); |
| 810 | if (!extractOp) |
| 811 | return failure(); |
| 812 | |
| 813 | if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract) |
| 814 | return failure(); |
| 815 | |
| 816 | // Check the index type, but only for non 0-d tensors (for which we do need |
| 817 | // access indices). |
| 818 | if (not extractOp.getIndices().empty()) { |
| 819 | if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType())) |
| 820 | return failure(); |
| 821 | } |
| 822 | |
| 823 | if (!llvm::all_of(extractOp->getResultTypes(), |
| 824 | VectorType::isValidElementType)) { |
| 825 | return failure(); |
| 826 | } |
| 827 | |
| 828 | return success(); |
| 829 | } |
| 830 | |
| 831 | /// Calculates the offsets (`$index_vec`) for `vector.gather` operations |
| 832 | /// generated from `tensor.extract`. The offset is calculated as follows |
| 833 | /// (example using scalar values): |
| 834 | /// |
| 835 | /// offset = extractOp.indices[0] |
| 836 | /// for (i = 1; i < numIndices; i++) |
| 837 | /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i]; |
| 838 | /// |
| 839 | /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to: |
| 840 | /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 |
| 841 | static Value calculateGatherOffset(RewriterBase &rewriter, |
| 842 | VectorizationState &state, |
| 843 | tensor::ExtractOp , |
| 844 | const IRMapping &bvm) { |
| 845 | // The vector of indices for GatherOp should be shaped as the output vector. |
| 846 | auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType()); |
| 847 | auto loc = extractOp.getLoc(); |
| 848 | |
| 849 | Value offset = broadcastIfNeeded( |
| 850 | rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType); |
| 851 | |
| 852 | const size_t numIndices = extractOp.getIndices().size(); |
| 853 | for (size_t i = 1; i < numIndices; i++) { |
| 854 | Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i); |
| 855 | |
| 856 | auto dimSize = broadcastIfNeeded( |
| 857 | rewriter, |
| 858 | rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx), |
| 859 | indexVecType); |
| 860 | |
| 861 | offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize); |
| 862 | |
| 863 | auto = broadcastIfNeeded( |
| 864 | rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType); |
| 865 | |
| 866 | offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset); |
| 867 | } |
| 868 | |
| 869 | return offset; |
| 870 | } |
| 871 | |
| 872 | enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather }; |
| 873 | |
| 874 | /// Find the index of the trailing non-unit dim in linalgOp. This hook is used |
| 875 | /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op) |
| 876 | /// represents a contiguous load operation. |
| 877 | /// |
| 878 | /// Note that when calling this hook, it is assumed that the output vector is |
| 879 | /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been |
| 880 | /// labelled as a gather load before entering this method. |
| 881 | /// |
| 882 | /// Following on from the above, it is assumed that: |
| 883 | /// * for statically shaped loops, when no masks are used, only one dim is != |
| 884 | /// 1 (that's what the shape of the output vector is based on). |
| 885 | /// * for dynamically shaped loops, there might be more non-unit dims |
| 886 | /// as the output vector type is user-specified. |
| 887 | /// |
| 888 | /// TODO: Statically shaped loops + vector masking |
| 889 | static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) { |
| 890 | SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges(); |
| 891 | assert( |
| 892 | (linalgOp.hasDynamicShape() || |
| 893 | llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) && |
| 894 | "For statically shaped Linalg Ops, only one " |
| 895 | "non-unit loop dim is expected" ); |
| 896 | assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse." ); |
| 897 | |
| 898 | size_t idx = loopRanges.size() - 1; |
| 899 | for (; idx != 0; idx--) |
| 900 | if (loopRanges[idx] != 1) |
| 901 | break; |
| 902 | |
| 903 | return idx; |
| 904 | } |
| 905 | |
| 906 | /// Checks whether `val` can be used for calculating a loop invariant index. |
| 907 | static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, |
| 908 | VectorType resType) { |
| 909 | |
| 910 | assert(((llvm::count_if(resType.getShape(), |
| 911 | [](int64_t dimSize) { return dimSize > 1; }) == 1)) && |
| 912 | "n-D vectors are not yet supported" ); |
| 913 | |
| 914 | // Blocks outside _this_ linalg.generic are effectively loop invariant. |
| 915 | // However, analysing block arguments for _this_ linalg.generic Op is a bit |
| 916 | // tricky. Just bail out in the latter case. |
| 917 | // TODO: We could try analysing the corresponding affine map here. |
| 918 | auto *block = linalgOp.getBlock(); |
| 919 | if (isa<BlockArgument>(Val: val)) |
| 920 | return llvm::all_of(block->getArguments(), |
| 921 | [&val](Value v) { return (v != val); }); |
| 922 | |
| 923 | Operation *defOp = val.getDefiningOp(); |
| 924 | assert(defOp && "This is neither a block argument nor an operation result" ); |
| 925 | |
| 926 | // IndexOp is loop invariant as long as its result remains constant across |
| 927 | // iterations. Note that for dynamic shapes, the corresponding dim will also |
| 928 | // be conservatively treated as != 1. |
| 929 | if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) { |
| 930 | return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1; |
| 931 | } |
| 932 | |
| 933 | auto *ancestor = block->findAncestorOpInBlock(*defOp); |
| 934 | |
| 935 | // Values define outside `linalgOp` are loop invariant. |
| 936 | if (!ancestor) |
| 937 | return true; |
| 938 | |
| 939 | // Values defined inside `linalgOp`, which are constant, are loop invariant. |
| 940 | if (isa<arith::ConstantOp>(ancestor)) |
| 941 | return true; |
| 942 | |
| 943 | bool result = true; |
| 944 | for (auto op : ancestor->getOperands()) |
| 945 | result &= isLoopInvariantIdx(linalgOp, op, resType); |
| 946 | |
| 947 | return result; |
| 948 | } |
| 949 | |
| 950 | /// Check whether `val` could be used for calculating the trailing index for a |
| 951 | /// contiguous load operation. |
| 952 | /// |
| 953 | /// There are currently 3 types of values that are allowed here: |
| 954 | /// 1. loop-invariant values, |
| 955 | /// 2. values that increment by 1 with every loop iteration, |
| 956 | /// 3. results of basic arithmetic operations (linear and continuous) |
| 957 | /// involving 1., 2. and 3. |
| 958 | /// This method returns True if indeed only such values are used in calculating |
| 959 | /// `val.` |
| 960 | /// |
| 961 | /// Additionally, the trailing index for a contiguous load operation should |
| 962 | /// increment by 1 with every loop iteration, i.e. be based on: |
| 963 | /// * `linalg.index <dim>` , |
| 964 | /// where <dim> is the trailing non-unit dim of the iteration space (this way, |
| 965 | /// `linalg.index <dim>` increments by 1 with every loop iteration). |
| 966 | /// `foundIndexOp` is updated to `true` when such Op is found. |
| 967 | static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, |
| 968 | bool &foundIndexOp, VectorType resType) { |
| 969 | |
| 970 | assert(((llvm::count_if(resType.getShape(), |
| 971 | [](int64_t dimSize) { return dimSize > 1; }) == 1)) && |
| 972 | "n-D vectors are not yet supported" ); |
| 973 | |
| 974 | // Blocks outside _this_ linalg.generic are effectively loop invariant. |
| 975 | // However, analysing block arguments for _this_ linalg.generic Op is a bit |
| 976 | // tricky. Just bail out in the latter case. |
| 977 | // TODO: We could try analysing the corresponding affine map here. |
| 978 | auto *block = linalgOp.getBlock(); |
| 979 | if (isa<BlockArgument>(Val: val)) |
| 980 | return llvm::all_of(block->getArguments(), |
| 981 | [&val](Value v) { return (v != val); }); |
| 982 | |
| 983 | Operation *defOp = val.getDefiningOp(); |
| 984 | assert(defOp && "This is neither a block argument nor an operation result" ); |
| 985 | |
| 986 | if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) { |
| 987 | auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp); |
| 988 | |
| 989 | foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne); |
| 990 | return true; |
| 991 | } |
| 992 | |
| 993 | auto *ancestor = block->findAncestorOpInBlock(*defOp); |
| 994 | |
| 995 | if (!ancestor) |
| 996 | return false; |
| 997 | |
| 998 | // Conservatively reject Ops that could lead to indices with stride other |
| 999 | // than 1. |
| 1000 | if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor)) |
| 1001 | return false; |
| 1002 | |
| 1003 | bool result = false; |
| 1004 | for (auto op : ancestor->getOperands()) |
| 1005 | result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType); |
| 1006 | |
| 1007 | return result; |
| 1008 | } |
| 1009 | |
| 1010 | /// Infer the memory access pattern for the input ExtractOp |
| 1011 | /// |
| 1012 | /// Based on the ExtratOp result shape and the access indices, decides whether |
| 1013 | /// this Op corresponds to a contiguous load (including a broadcast of a scalar) |
| 1014 | /// or a gather load. When analysing the ExtractOp indices (to identify |
| 1015 | /// contiguous laods), this method looks for "loop" invariant indices (e.g. |
| 1016 | /// block arguments) and indices that change linearly (e.g. via `linalg.index` |
| 1017 | /// Op). |
| 1018 | /// |
| 1019 | /// Note that it is always safe to use gather load operations for contiguous |
| 1020 | /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume |
| 1021 | /// that `extractOp` is a gather load. |
| 1022 | static VectorMemoryAccessKind |
| 1023 | (tensor::ExtractOp , |
| 1024 | LinalgOp &linalgOp, VectorType resType) { |
| 1025 | |
| 1026 | auto inputShape = cast<ShapedType>(extractOp.getTensor().getType()); |
| 1027 | |
| 1028 | // 0. Is this a 0-D vector? If yes then this is a scalar broadcast. |
| 1029 | if (inputShape.getShape().empty()) |
| 1030 | return VectorMemoryAccessKind::ScalarBroadcast; |
| 1031 | |
| 1032 | // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false |
| 1033 | // otherwise. |
| 1034 | bool isOutput1DVector = |
| 1035 | (llvm::count_if(resType.getShape(), |
| 1036 | [](int64_t dimSize) { return dimSize > 1; }) == 1); |
| 1037 | // 1. Assume that it's a gather load when reading non-1D vector. |
| 1038 | if (!isOutput1DVector) |
| 1039 | return VectorMemoryAccessKind::Gather; |
| 1040 | |
| 1041 | bool leadingIdxsLoopInvariant = true; |
| 1042 | |
| 1043 | // 2. Analyze the leading indices of `extractOp`. |
| 1044 | // Look at the way each index is calculated and decide whether it is suitable |
| 1045 | // for a contiguous load, i.e. whether it's loop invariant. If not, it's a |
| 1046 | // gather load. |
| 1047 | auto indices = extractOp.getIndices(); |
| 1048 | auto leadIndices = indices.drop_back(1); |
| 1049 | |
| 1050 | for (auto [i, indexVal] : llvm::enumerate(leadIndices)) { |
| 1051 | if (inputShape.getShape()[i] == 1) |
| 1052 | continue; |
| 1053 | |
| 1054 | leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType); |
| 1055 | } |
| 1056 | |
| 1057 | if (!leadingIdxsLoopInvariant) { |
| 1058 | LDBG("Found gather load: " << extractOp); |
| 1059 | return VectorMemoryAccessKind::Gather; |
| 1060 | } |
| 1061 | |
| 1062 | // 3. Analyze the trailing index for `extractOp`. |
| 1063 | // At this point we know that the leading indices are loop invariant. This |
| 1064 | // means that is potentially a scalar or a contiguous load. We can decide |
| 1065 | // based on the trailing idx. |
| 1066 | auto = indices.back(); |
| 1067 | |
| 1068 | // 3a. Scalar broadcast load |
| 1069 | // If the trailing index is loop invariant then this is a scalar load. |
| 1070 | if (leadingIdxsLoopInvariant && |
| 1071 | isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) { |
| 1072 | LDBG("Found scalar broadcast load: " << extractOp); |
| 1073 | |
| 1074 | return VectorMemoryAccessKind::ScalarBroadcast; |
| 1075 | } |
| 1076 | |
| 1077 | // 3b. Contiguous loads |
| 1078 | // The trailing `extractOp` index should increment with every loop iteration. |
| 1079 | // This effectively means that it must be based on the trailing loop index. |
| 1080 | // This is what the following bool captures. |
| 1081 | bool foundIndexOp = false; |
| 1082 | bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, |
| 1083 | foundIndexOp, resType); |
| 1084 | // TODO: Support generating contiguous loads for column vectors - that will |
| 1085 | // require adding a permutation map to tranfer_read Ops. |
| 1086 | bool isRowVector = resType.getShape().back() != 1; |
| 1087 | isContiguousLoad &= (foundIndexOp && isRowVector); |
| 1088 | |
| 1089 | if (isContiguousLoad) { |
| 1090 | LDBG("Found contigous load: " << extractOp); |
| 1091 | return VectorMemoryAccessKind::Contiguous; |
| 1092 | } |
| 1093 | |
| 1094 | // 4. Fallback case - gather load. |
| 1095 | LDBG("Found gather load: " << extractOp); |
| 1096 | return VectorMemoryAccessKind::Gather; |
| 1097 | } |
| 1098 | |
| 1099 | /// Helper function to vectorize the tensor.extract operations. Returns |
| 1100 | /// VectorizationStatus::NewOp to signal the vectorization algorithm that it |
| 1101 | /// should map the produced operations. This function is meant to be used as a |
| 1102 | /// CustomVectorizationHook. |
| 1103 | static VectorizationResult |
| 1104 | (RewriterBase &rewriter, VectorizationState &state, |
| 1105 | Operation *op, LinalgOp linalgOp, const IRMapping &bvm) { |
| 1106 | tensor::ExtractOp = dyn_cast<tensor::ExtractOp>(op); |
| 1107 | if (!extractOp) |
| 1108 | return VectorizationResult{.status: VectorizationStatus::Failure, .newOp: nullptr}; |
| 1109 | auto loc = extractOp.getLoc(); |
| 1110 | |
| 1111 | // Compute the static loop sizes of the extract op. |
| 1112 | auto resultType = state.getCanonicalVecType(extractOp.getResult().getType()); |
| 1113 | auto maskConstantOp = rewriter.create<arith::ConstantOp>( |
| 1114 | loc, |
| 1115 | DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()), |
| 1116 | /*value=*/true)); |
| 1117 | auto passThruConstantOp = |
| 1118 | rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType)); |
| 1119 | |
| 1120 | // Base indices are currently set to 0. We will need to re-visit if more |
| 1121 | // generic scenarios are to be supported. |
| 1122 | SmallVector<Value> baseIndices( |
| 1123 | extractOp.getIndices().size(), |
| 1124 | rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
| 1125 | |
| 1126 | VectorMemoryAccessKind memAccessKind = |
| 1127 | getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType); |
| 1128 | |
| 1129 | // 1. Handle gather access |
| 1130 | if (memAccessKind == VectorMemoryAccessKind::Gather) { |
| 1131 | Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm); |
| 1132 | |
| 1133 | // Generate the gather load |
| 1134 | Operation *gatherOp = rewriter.create<vector::GatherOp>( |
| 1135 | loc, resultType, extractOp.getTensor(), baseIndices, offset, |
| 1136 | maskConstantOp, passThruConstantOp); |
| 1137 | gatherOp = state.maskOperation(rewriter, opToMask: gatherOp, linalgOp: linalgOp); |
| 1138 | |
| 1139 | LDBG("Vectorised as gather load: " << extractOp << "\n" ); |
| 1140 | return VectorizationResult{.status: VectorizationStatus::NewOp, .newOp: gatherOp}; |
| 1141 | } |
| 1142 | |
| 1143 | // 2. Handle: |
| 1144 | // a. scalar loads + broadcast, |
| 1145 | // b. contiguous loads. |
| 1146 | // Both cases use vector.transfer_read. |
| 1147 | |
| 1148 | // Collect indices for `vector.transfer_read`. At this point, the indices will |
| 1149 | // either be scalars or would have been broadcast to vectors matching the |
| 1150 | // result type. For indices that are vectors, there are two options: |
| 1151 | // * for non-trailing indices, all elements are identical (contiguous |
| 1152 | // loads are identified by looking for non-trailing indices that are |
| 1153 | // invariant with respect to the corresponding linalg.generic), or |
| 1154 | // * for trailing indices, the index vector will contain values with stride |
| 1155 | // one, but for `vector.transfer_read` only the first (i.e. 0th) index is |
| 1156 | // needed. |
| 1157 | // This means that |
| 1158 | // * for scalar indices - just re-use it, |
| 1159 | // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom |
| 1160 | // (0th) element and use that. |
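  // Concretely, an illustrative sketch (hypothetical shapes, types partially
  // elided) of what the loop below emits for a vector index:
  //
  //   %idx_1d = vector.shape_cast %idx : vector<1x1x4xindex> to vector<4xindex>
  //   %idx_0 = vector.extract %idx_1d[0]
  //
  // i.e. only the 0th element of the trailing index vector is kept.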
| 1161 | SmallVector<Value> transferReadIdxs; |
| 1162 | for (size_t i = 0; i < extractOp.getIndices().size(); i++) { |
| 1163 | Value idx = bvm.lookup(extractOp.getIndices()[i]); |
| 1164 | if (idx.getType().isIndex()) { |
      transferReadIdxs.push_back(idx);
| 1166 | continue; |
| 1167 | } |
| 1168 | |
| 1169 | auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>( |
| 1170 | loc, |
| 1171 | VectorType::get(resultType.getShape().back(), rewriter.getIndexType(), |
| 1172 | resultType.getScalableDims().back()), |
| 1173 | idx); |
| 1174 | transferReadIdxs.push_back( |
| 1175 | rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0)); |
| 1176 | } |
| 1177 | |
| 1178 | // `tensor.extract_element` is always in-bounds, hence the following holds. |
| 1179 | auto dstRank = resultType.getRank(); |
| 1180 | auto srcRank = extractOp.getTensor().getType().getRank(); |
| 1181 | SmallVector<bool> inBounds(dstRank, true); |
| 1182 | |
| 1183 | // 2a. Handle scalar broadcast access. |
| 1184 | if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) { |
| 1185 | MLIRContext *ctx = rewriter.getContext(); |
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
| 1187 | auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx); |
| 1188 | |
| 1189 | auto transferReadOp = rewriter.create<vector::TransferReadOp>( |
| 1190 | loc, resultType, extractOp.getTensor(), transferReadIdxs, |
| 1191 | permutationMap, inBounds); |
| 1192 | |
| 1193 | // Mask this broadcasting xfer_read here rather than relying on the generic |
| 1194 | // path (the generic path assumes identity masking map, which wouldn't be |
| 1195 | // valid here). |
| 1196 | SmallVector<int64_t> readMaskShape = {1}; |
| 1197 | auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type()); |
| 1198 | auto allTrue = rewriter.create<vector::ConstantMaskOp>( |
| 1199 | loc, readMaskType, vector::ConstantMaskKind::AllTrue); |
| 1200 | auto *maskedReadOp = |
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
| 1202 | |
| 1203 | LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n" ); |
| 1204 | return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp}; |
| 1205 | } |
| 1206 | |
| 1207 | // 2b. Handle contiguous access. |
| 1208 | auto permutationMap = AffineMap::getMinorIdentityMap( |
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());
| 1210 | |
| 1211 | int32_t rankDiff = dstRank - srcRank; |
| 1212 | // When dstRank > srcRank, broadcast the source tensor to the unitary leading |
| 1213 | // dims so that the ranks match. This is done by extending the map with 0s. |
| 1214 | // For example, for dstRank = 3, srcRank = 2, the following map created |
| 1215 | // above: |
| 1216 | // (d0, d1) --> (d0, d1) |
| 1217 | // is extended as: |
| 1218 | // (d0, d1) --> (0, d0, d1) |
| 1219 | while (rankDiff > 0) { |
| 1220 | permutationMap = permutationMap.insertResult( |
        mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
| 1222 | rankDiff--; |
| 1223 | } |
| 1224 | |
| 1225 | auto transferReadOp = rewriter.create<vector::TransferReadOp>( |
| 1226 | loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap, |
| 1227 | inBounds); |
| 1228 | |
| 1229 | LDBG("Vectorised as contiguous load: " << extractOp); |
| 1230 | return VectorizationResult{VectorizationStatus::NewOp, transferReadOp}; |
| 1231 | } |
| 1232 | |
/// Emit reduction operations if the shape of the value to reduce is different
/// from the result shape.
| 1235 | // Note: this is a true builder that notifies the OpBuilder listener. |
| 1236 | // TODO: Consider moving as a static helper on the ReduceOp. |
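//
// As an illustrative sketch (hypothetical shapes; the combining kind is taken
// from the matched reduction op), reducing a vector<4x8xf32> value into a
// vector<4xf32> accumulator would produce roughly:
//
//   %red = vector.multi_reduction <add>, %reduceVec, %outputVec [1]
//       : vector<4x8xf32> to vector<4xf32>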
| 1237 | static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, |
| 1238 | Value reduceValue, Value initialValue, |
| 1239 | const IRMapping &bvm) { |
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
| 1242 | auto reduceType = dyn_cast<VectorType>(reduceVec.getType()); |
| 1243 | auto outputType = dyn_cast<VectorType>(outputVec.getType()); |
  // Reduce only if needed as the value may already have been reduced for
| 1245 | // contraction vectorization. |
| 1246 | if (!reduceType || |
| 1247 | (outputType && reduceType.getShape() == outputType.getShape())) |
| 1248 | return nullptr; |
| 1249 | SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp); |
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
| 1251 | } |
| 1252 | |
| 1253 | /// Generic vectorization for a single operation `op`, given already vectorized |
| 1254 | /// operands carried by `bvm`. Vectorization occurs as follows: |
| 1255 | /// 1. Try to apply any of the `customVectorizationHooks` and return its |
| 1256 | /// result on success. |
| 1257 | /// 2. Clone any constant in the current scope without vectorization: each |
| 1258 | /// consumer of the constant will later determine the shape to which the |
| 1259 | /// constant needs to be broadcast to. |
| 1260 | /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose |
| 1261 | /// of the `customVectorizationHooks` to cover such cases. |
| 1262 | /// 4. Clone `op` in vector form to a vector of shape prescribed by the first |
| 1263 | /// operand of maximal rank. Other operands have smaller rank and are |
| 1264 | /// broadcast accordingly. It is assumed this broadcast is always legal, |
| 1265 | /// otherwise, it means one of the `customVectorizationHooks` is incorrect. |
| 1266 | /// |
| 1267 | /// This function assumes all operands of `op` have been vectorized and are in |
| 1268 | /// the `bvm` mapping. As a consequence, this function is meant to be called on |
| 1269 | /// a topologically-sorted list of ops. |
| 1270 | /// This function does not update `bvm` but returns a VectorizationStatus that |
| 1271 | /// instructs the caller what `bvm` update needs to occur. |
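///
/// Illustrative sketch of step 4 (hypothetical types, not tied to a particular
/// test): if `%s : f32` maps to itself in `bvm` and `%t` maps to
/// `%v : vector<4x8xf32>`, then `arith.addf %s, %t : f32` in the body is
/// rewritten roughly as:
///
///   %b = vector.broadcast %s : f32 to vector<4x8xf32>
///   %r = arith.addf %b, %v : vector<4x8xf32>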
| 1272 | static VectorizationResult |
| 1273 | vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, |
| 1274 | LinalgOp linalgOp, Operation *op, const IRMapping &bvm, |
| 1275 | ArrayRef<CustomVectorizationHook> customVectorizationHooks) { |
| 1276 | LDBG("vectorize op " << *op << "\n" ); |
| 1277 | |
| 1278 | // 1. Try to apply any CustomVectorizationHook. |
| 1279 | if (!customVectorizationHooks.empty()) { |
| 1280 | for (auto &customFunc : customVectorizationHooks) { |
| 1281 | VectorizationResult result = customFunc(op, bvm); |
| 1282 | if (result.status == VectorizationStatus::Failure) |
| 1283 | continue; |
| 1284 | return result; |
| 1285 | } |
| 1286 | } |
| 1287 | |
| 1288 | // 2. Constant ops don't get vectorized but rather broadcasted at their users. |
  // Clone so that the constant is not confined to the linalgOp block.
| 1290 | if (isa<arith::ConstantOp, func::ConstantOp>(op)) |
    return VectorizationResult{VectorizationStatus::NewOp,
                               rewriter.clone(*op)};
| 1292 | |
| 1293 | // 3. Only ElementwiseMappable are allowed in the generic vectorization. |
| 1294 | if (!OpTrait::hasElementwiseMappableTraits(op)) |
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
| 1296 | |
  // 4. Check if the operation is a reduction.
| 1298 | SmallVector<std::pair<Value, Value>> reductionOperands; |
| 1299 | for (Value operand : op->getOperands()) { |
    auto blockArg = dyn_cast<BlockArgument>(operand);
| 1301 | if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() || |
| 1302 | blockArg.getArgNumber() < linalgOp.getNumDpsInputs()) |
| 1303 | continue; |
| 1304 | SmallVector<Operation *> reductionOps; |
| 1305 | Value reduceValue = matchReduction( |
| 1306 | linalgOp.getRegionOutputArgs(), |
| 1307 | blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps); |
| 1308 | if (!reduceValue) |
| 1309 | continue; |
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
| 1311 | } |
| 1312 | if (!reductionOperands.empty()) { |
| 1313 | assert(reductionOperands.size() == 1); |
| 1314 | Operation *reduceOp = |
| 1315 | reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first, |
| 1316 | reductionOperands[0].second, bvm); |
| 1317 | if (reduceOp) |
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
| 1319 | } |
| 1320 | |
| 1321 | // 5. Generic vectorization path for ElementwiseMappable ops. |
| 1322 | // a. Get the first max ranked shape. |
| 1323 | VectorType firstMaxRankedType; |
| 1324 | for (Value operand : op->getOperands()) { |
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");
| 1327 | |
| 1328 | auto vecType = dyn_cast<VectorType>(vecOperand.getType()); |
| 1329 | if (vecType && (!firstMaxRankedType || |
| 1330 | firstMaxRankedType.getRank() < vecType.getRank())) |
| 1331 | firstMaxRankedType = vecType; |
| 1332 | } |
| 1333 | // b. Broadcast each op if needed. |
| 1334 | SmallVector<Value> vecOperands; |
| 1335 | for (Value scalarOperand : op->getOperands()) { |
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");
| 1338 | |
| 1339 | if (firstMaxRankedType) { |
| 1340 | auto vecType = VectorType::get(firstMaxRankedType.getShape(), |
| 1341 | getElementTypeOrSelf(vecOperand.getType()), |
| 1342 | firstMaxRankedType.getScalableDims()); |
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
| 1344 | } else { |
      vecOperands.push_back(vecOperand);
| 1346 | } |
| 1347 | } |
| 1348 | // c. for elementwise, the result is the vector with the firstMaxRankedShape |
| 1349 | SmallVector<Type> resultTypes; |
| 1350 | for (Type resultType : op->getResultTypes()) { |
| 1351 | resultTypes.push_back( |
| 1352 | firstMaxRankedType |
| 1353 | ? VectorType::get(firstMaxRankedType.getShape(), resultType, |
| 1354 | firstMaxRankedType.getScalableDims()) |
| 1355 | : resultType); |
| 1356 | } |
| 1357 | // d. Build and return the new op. |
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
| 1362 | } |
| 1363 | |
| 1364 | /// Generic vectorization function that rewrites the body of a `linalgOp` into |
| 1365 | /// vector form. Generic vectorization proceeds as follows: |
| 1366 | /// 1. Verify the `linalgOp` has one non-empty region. |
| 1367 | /// 2. Values defined above the region are mapped to themselves and will be |
| 1368 | /// broadcasted on a per-need basis by their consumers. |
| 1369 | /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d |
| 1370 | /// load). |
| 1371 | /// TODO: Reuse opportunities for RAR dependencies. |
| 1372 | /// 4a. Register CustomVectorizationHook for YieldOp to capture the results. |
/// 4b. Register CustomVectorizationHook for IndexOp to access the
| 1374 | /// iteration indices. |
| 1375 | /// 5. Iteratively call vectorizeOneOp on the region operations. |
| 1376 | /// |
| 1377 | /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is |
| 1378 | /// performed to the maximal common vector size implied by the `linalgOp` |
| 1379 | /// iteration space. This eager broadcasting is introduced in the |
| 1380 | /// permutation_map of the vector.transfer_read operations. The eager |
/// broadcasting makes it trivial to determine where broadcasts, transposes and
| 1382 | /// reductions should occur, without any bookkeeping. The tradeoff is that, in |
| 1383 | /// the absence of good canonicalizations, the amount of work increases. |
| 1384 | /// This is not deemed a problem as we expect canonicalizations and foldings to |
| 1385 | /// aggressively clean up the useless work. |
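///
/// Illustrative sketch (hypothetical 1-D elementwise example; masking and
/// canonicalizations omitted): a generic computing `math.exp` over
/// tensor<8xf32> is rewritten roughly as:
///
///   %v = vector.transfer_read %in[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %e = math.exp %v : vector<8xf32>
///   %r = vector.transfer_write %e, %out[%c0] : vector<8xf32>, tensor<8xf32>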
| 1386 | static LogicalResult |
| 1387 | vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, |
| 1388 | LinalgOp linalgOp, |
| 1389 | SmallVectorImpl<Value> &newResults) { |
| 1390 | LDBG("Vectorizing operation as linalg generic\n" ); |
| 1391 | Block *block = linalgOp.getBlock(); |
| 1392 | |
| 1393 | // 2. Values defined above the region can only be broadcast for now. Make them |
| 1394 | // map to themselves. |
| 1395 | IRMapping bvm; |
| 1396 | SetVector<Value> valuesSet; |
| 1397 | mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet); |
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
| 1399 | |
| 1400 | if (linalgOp.getNumDpsInits() == 0) |
| 1401 | return failure(); |
| 1402 | |
| 1403 | // 3. Turn all BBArgs into vector.transfer_read / load. |
| 1404 | Location loc = linalgOp.getLoc(); |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
| 1406 | for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { |
| 1407 | BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand); |
| 1408 | if (linalgOp.isScalar(opOperand)) { |
| 1409 | bvm.map(bbarg, opOperand->get()); |
| 1410 | continue; |
| 1411 | } |
| 1412 | |
| 1413 | // 3.a. Convert the indexing map for this input/output to a transfer read |
| 1414 | // permutation map and masking map. |
| 1415 | AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); |
| 1416 | |
| 1417 | AffineMap readMap; |
| 1418 | VectorType readType; |
| 1419 | Type elemType = getElementTypeOrSelf(opOperand->get()); |
| 1420 | if (linalgOp.isDpsInput(opOperand)) { |
| 1421 | // 3.a.i. For input reads we use the canonical vector shape. |
| 1422 | readMap = inverseAndBroadcastProjectedPermutation(indexingMap); |
| 1423 | readType = state.getCanonicalVecType(elemType); |
| 1424 | } else { |
| 1425 | // 3.a.ii. For output reads (iteration-carried dependence, e.g., |
| 1426 | // reductions), the vector shape is computed by mapping the canonical |
| 1427 | // vector shape to the output domain and back to the canonical domain. |
| 1428 | readMap = inversePermutation(reindexIndexingMap(indexingMap)); |
| 1429 | readType = |
| 1430 | state.getCanonicalVecType(elemType, readMap.compose(indexingMap)); |
| 1431 | } |
| 1432 | |
| 1433 | SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero); |
| 1434 | |
| 1435 | Operation *read = rewriter.create<vector::TransferReadOp>( |
| 1436 | loc, readType, opOperand->get(), indices, readMap); |
| 1437 | read = state.maskOperation(rewriter, read, linalgOp, indexingMap); |
| 1438 | Value readValue = read->getResult(0); |
| 1439 | |
| 1440 | // 3.b. If masked, set in-bounds to true. Masking guarantees that the access |
| 1441 | // will be in-bounds. |
| 1442 | if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) { |
| 1443 | SmallVector<bool> inBounds(readType.getRank(), true); |
| 1444 | cast<vector::TransferReadOp>(maskOp.getMaskableOp()) |
| 1445 | .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); |
| 1446 | } |
| 1447 | |
| 1448 | // 3.c. Not all ops support 0-d vectors, extract the scalar for now. |
| 1449 | // TODO: remove this. |
| 1450 | if (readType.getRank() == 0) |
| 1451 | readValue = rewriter.create<vector::ExtractOp>(loc, readValue, |
| 1452 | ArrayRef<int64_t>()); |
| 1453 | |
| 1454 | LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue |
| 1455 | << "\n" ); |
| 1456 | bvm.map(bbarg, readValue); |
| 1457 | bvm.map(opOperand->get(), readValue); |
| 1458 | } |
| 1459 | |
| 1460 | SmallVector<CustomVectorizationHook> hooks; |
| 1461 | // 4a. Register CustomVectorizationHook for yieldOp. |
| 1462 | CustomVectorizationHook vectorizeYield = |
| 1463 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
| 1464 | return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults); |
| 1465 | }; |
  hooks.push_back(vectorizeYield);
| 1467 | |
| 1468 | // 4b. Register CustomVectorizationHook for indexOp. |
| 1469 | CustomVectorizationHook vectorizeIndex = |
| 1470 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
| 1471 | return vectorizeLinalgIndex(rewriter, state, op, linalgOp); |
| 1472 | }; |
  hooks.push_back(vectorizeIndex);
| 1474 | |
| 1475 | // 4c. Register CustomVectorizationHook for extractOp. |
  CustomVectorizationHook vectorizeExtract =
| 1477 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
| 1478 | return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm); |
| 1479 | }; |
  hooks.push_back(vectorizeExtract);
| 1481 | |
| 1482 | // 5. Iteratively call `vectorizeOneOp` to each op in the slice. |
| 1483 | for (Operation &op : block->getOperations()) { |
| 1484 | VectorizationResult result = |
| 1485 | vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks); |
| 1486 | if (result.status == VectorizationStatus::Failure) { |
| 1487 | LDBG("failed to vectorize: " << op << "\n" ); |
| 1488 | return failure(); |
| 1489 | } |
| 1490 | if (result.status == VectorizationStatus::NewOp) { |
| 1491 | Operation *maybeMaskedOp = |
| 1492 | state.maskOperation(rewriter, result.newOp, linalgOp); |
| 1493 | LDBG("New vector op: " << *maybeMaskedOp << "\n" ); |
| 1494 | bvm.map(op.getResults(), maybeMaskedOp->getResults()); |
| 1495 | } |
| 1496 | } |
| 1497 | |
| 1498 | return success(); |
| 1499 | } |
| 1500 | |
| 1501 | /// Given a linalg::PackOp, return the `dest` shape before any packing |
| 1502 | /// permutations. |
| 1503 | static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp, |
| 1504 | ArrayRef<int64_t> destShape) { |
| 1505 | return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp)); |
| 1506 | } |
| 1507 | |
| 1508 | /// Determines whether a mask for xfer_write is trivially "all true" |
| 1509 | /// |
| 1510 | /// Given all the inputs required to generate a mask (mask sizes and shapes), |
| 1511 | /// and an xfer_write operation (write indices and the destination tensor |
| 1512 | /// shape), determines whether the corresponding mask would be trivially |
| 1513 | /// foldable (i.e., trivially "all true"). |
| 1514 | /// |
/// Use this method to avoid generating spurious masks and relying on
| 1516 | /// vectorization post-processing to remove them. |
| 1517 | /// |
| 1518 | /// Pre-conditions for a mask to be trivially foldable: |
| 1519 | /// * All involved shapes (mask + destination tensor) are static. |
| 1520 | /// * All write indices are constant. |
| 1521 | /// * All mask sizes are constant (including `arith.constant`). |
| 1522 | /// |
| 1523 | /// If the pre-conditions are met, the method checks for each destination |
| 1524 | /// dimension `d`: |
| 1525 | /// (1) destDimSize[rankDiff + d] <= maskShape[d] |
| 1526 | /// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d] |
| 1527 | /// |
| 1528 | /// rankDiff = rank(dest) - rank(mask). |
| 1529 | /// |
| 1530 | /// This method takes a conservative view: it may return false even if the mask |
| 1531 | /// is technically foldable. |
| 1532 | /// |
| 1533 | /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape |
| 1534 | /// of the dest tensor): |
| 1535 | /// %c0 = arith.constant 0 : index |
| 1536 | /// %mask = vector.create_mask 5, 1 |
| 1537 | /// vector.mask %mask { |
| 1538 | /// vector.transfer_write %vecToStore_1, %dest{[%c0, %c0] |
| 1539 | /// {in_bounds = [true, true]} |
| 1540 | /// : vector<5x1xi32>, tensor<5x1xi32> |
| 1541 | /// } |
| 1542 | /// |
| 1543 | /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape, |
| 1544 | /// mask is required to avoid out-of-bounds write): |
| 1545 | /// %c0 = arith.constant 0 : index |
| 1546 | /// %mask = vector.create_mask 5, 1 |
| 1547 | /// vector.mask %mask { |
| 1548 | /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0] |
| 1549 | /// {in_bounds = [true, true]} |
| 1550 | /// : vector<8x1xi32>, tensor<5x1xi32> |
| 1551 | /// } |
| 1552 | /// |
| 1553 | /// TODO: Re-use in createReadOrMaskedRead |
| 1554 | static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes, |
| 1555 | SmallVector<Value> &writeIdxs, |
| 1556 | ArrayRef<int64_t> destShape, |
| 1557 | ArrayRef<int64_t> maskShape) { |
| 1558 | // Masking is unavoidable in the case of dynamic tensors. |
| 1559 | if (ShapedType::isDynamicShape(destShape)) |
| 1560 | return false; |
| 1561 | |
| 1562 | // Collect all constant mask sizes. |
| 1563 | SmallVector<int64_t, 4> cstMaskSizes; |
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      cstMaskSizes.push_back(*intSize);
| 1567 | } |
| 1568 | } |
| 1569 | |
| 1570 | // If any of the mask sizes is non-constant, bail out. |
| 1571 | if (cstMaskSizes.size() != maskShape.size()) |
| 1572 | return false; |
| 1573 | |
| 1574 | // Collect all constant write indices. |
| 1575 | SmallVector<int64_t, 4> cstWriteIdxs; |
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
| 1577 | APSInt intVal; |
| 1578 | if (matchPattern(idx, m_ConstantInt(&intVal))) { |
      cstWriteIdxs.push_back(intVal.getSExtValue());
| 1580 | } |
| 1581 | } |
| 1582 | |
| 1583 | // If any of the write indices is non-constant, bail out. |
| 1584 | if (cstWriteIdxs.size() != destShape.size()) |
| 1585 | return false; |
| 1586 | |
| 1587 | // Go over all destination dims and check (1) and (2). Take into account that: |
| 1588 | // * The number of mask sizes will match the rank of the vector to store. |
| 1589 | // This could be lower than the rank of the destination tensor. |
| 1590 | // * Mask sizes could be larger than the corresponding mask shape (hence |
| 1591 | // `clamp`). |
| 1592 | // TODO: The 2nd item should be rejected by the verifier. |
| 1593 | int64_t rankDiff = destShape.size() - cstMaskSizes.size(); |
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
| 1595 | if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] || |
| 1596 | /*(2)*/ destShape[rankDiff + i] < |
                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
| 1598 | cstWriteIdxs[i])) |
| 1599 | return false; |
| 1600 | } |
| 1601 | |
| 1602 | return true; |
| 1603 | } |
| 1604 | |
| 1605 | /// Creates an optionally masked TransferWriteOp |
| 1606 | /// |
| 1607 | /// Generates the following operation: |
| 1608 | /// %res = vector.transfer_write %vecToStore into %dest |
| 1609 | /// |
| 1610 | /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness: |
| 1611 | /// |
| 1612 | /// %mask = vector.create_mask(%destShape) : %vecToStoreShape |
| 1613 | /// %res = vector.mask %mask { |
| 1614 | /// vector.transfer_write %vecToStore into %dest |
| 1615 | /// } |
| 1616 | /// |
| 1617 | /// The mask shape is identical to `vecToStore` (with the element type == |
| 1618 | /// i1), and the mask values are based on the shape of the `dest` tensor. |
| 1619 | /// |
| 1620 | /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute |
| 1621 | /// is used instead of masking: |
| 1622 | /// |
| 1623 | /// %write = vector.transfer_write %vecToStore into %dest |
| 1624 | /// in_bounds_flags = (...) |
| 1625 | /// %res = vector.transfer_write %input into %dest |
| 1626 | /// {in_bounds = in_bounds_flags} |
| 1627 | /// |
| 1628 | /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices |
| 1629 | /// are set to 0. |
| 1630 | static Operation * |
| 1631 | createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, |
| 1632 | Value dest, SmallVector<Value> writeIndices = {}, |
| 1633 | bool useInBoundsInsteadOfMasking = false) { |
| 1634 | |
| 1635 | ShapedType destType = cast<ShapedType>(dest.getType()); |
| 1636 | int64_t destRank = destType.getRank(); |
| 1637 | auto destShape = destType.getShape(); |
| 1638 | |
| 1639 | VectorType vecToStoreType = cast<VectorType>(vecToStore.getType()); |
| 1640 | int64_t vecToStoreRank = vecToStoreType.getRank(); |
| 1641 | auto vecToStoreShape = vecToStoreType.getShape(); |
| 1642 | |
| 1643 | // Compute the in_bounds attribute |
| 1644 | SmallVector<bool> inBoundsVal(vecToStoreRank, true); |
| 1645 | if (useInBoundsInsteadOfMasking) { |
| 1646 | // Update the inBounds attribute. |
| 1647 | // FIXME: This computation is too weak - it ignores the write indices. |
| 1648 | for (unsigned i = 0; i < vecToStoreRank; i++) |
| 1649 | inBoundsVal[i] = |
| 1650 | (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) && |
| 1651 | !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]); |
| 1652 | } |
| 1653 | |
| 1654 | // If missing, initialize the write indices to 0. |
| 1655 | assert(writeIndices.empty() || |
| 1656 | writeIndices.size() == static_cast<size_t>(destRank) && |
| 1657 | "Invalid number of write indices!" ); |
| 1658 | if (writeIndices.empty()) { |
    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
| 1660 | writeIndices.assign(destRank, zero); |
| 1661 | } |
| 1662 | |
| 1663 | // Generate the xfer_write Op |
| 1664 | Operation *write = |
| 1665 | builder.create<vector::TransferWriteOp>(loc, |
| 1666 | /*vector=*/vecToStore, |
| 1667 | /*source=*/dest, |
| 1668 | /*indices=*/writeIndices, |
| 1669 | /*inBounds=*/inBoundsVal); |
| 1670 | |
| 1671 | // If masking is disabled, exit. |
| 1672 | if (useInBoundsInsteadOfMasking) |
| 1673 | return write; |
| 1674 | |
| 1675 | // Check if masking is needed. If not, exit. |
| 1676 | if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank))) |
| 1677 | return write; |
| 1678 | |
| 1679 | // Compute the mask and mask the write Op. |
| 1680 | auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type()); |
| 1681 | |
| 1682 | SmallVector<OpFoldResult> destSizes = |
      tensor::getMixedSizes(builder, loc, dest);
| 1684 | SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank, |
| 1685 | destSizes.end()); |
| 1686 | |
| 1687 | if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape, |
| 1688 | vecToStoreShape)) |
| 1689 | return write; |
| 1690 | |
| 1691 | Value maskForWrite = |
| 1692 | builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes); |
  return mlir::vector::maskOperation(builder, write, maskForWrite);
| 1694 | } |
| 1695 | |
| 1696 | /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant |
| 1697 | /// padding value and (3) input vector sizes into: |
| 1698 | /// |
| 1699 | /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds |
| 1700 | /// |
| 1701 | /// As in the following example: |
| 1702 | /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2] |
| 1703 | /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> |
| 1704 | /// |
| 1705 | /// This pack would be vectorized to: |
| 1706 | /// |
| 1707 | /// %load = vector.mask %mask { |
| 1708 | /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst |
| 1709 | /// {in_bounds = [true, true, true]} : |
| 1710 | /// tensor<32x7x16xf32>, vector<32x8x16xf32> |
| 1711 | /// } : vector<32x8x16xi1> -> vector<32x8x16xf32> |
| 1712 | /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32> |
| 1713 | /// to vector<32x4x2x1x16xf32> |
| 1714 | /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2] |
| 1715 | /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> |
| 1716 | /// %write = vector.transfer_write %transpose, |
| 1717 | /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0] |
| 1718 | /// {in_bounds = [true, true, true, true, true]} |
| 1719 | /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> |
| 1720 | /// |
| 1721 | /// If the (3) input vector sizes are not provided, the vector sizes are |
| 1722 | /// determined by the result tensor shape and the `in_bounds` |
| 1723 | /// attribute is used instead of masking to mark out-of-bounds accesses. |
| 1724 | /// |
| 1725 | /// NOTE: The input vector sizes specify the dimensions corresponding to the |
| 1726 | /// outer dimensions of the output tensor. The remaining dimensions are |
| 1727 | /// computed based on, e.g., the static inner tiles. |
| 1728 | /// Supporting dynamic inner tiles will require the user to specify the |
| 1729 | /// missing vector sizes. This is left as a TODO. |
| 1730 | static LogicalResult |
| 1731 | vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, |
| 1732 | ArrayRef<int64_t> inputVectorSizes, |
| 1733 | SmallVectorImpl<Value> &newResults) { |
| 1734 | // TODO: Introduce a parent class that will handle the insertion point update. |
| 1735 | OpBuilder::InsertionGuard g(rewriter); |
| 1736 | rewriter.setInsertionPoint(packOp); |
| 1737 | |
| 1738 | Location loc = packOp.getLoc(); |
| 1739 | auto padValue = packOp.getPaddingValue(); |
| 1740 | if (!padValue) { |
| 1741 | padValue = rewriter.create<arith::ConstantOp>( |
| 1742 | loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType())); |
| 1743 | } |
| 1744 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
| 1745 | LogicalResult status = |
| 1746 | cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation()) |
| 1747 | .reifyResultShapes(rewriter, reifiedReturnShapes); |
| 1748 | (void)status; // prevent unused variable warning on non-assert builds. |
| 1749 | assert(succeeded(status) && "failed to reify result shapes" ); |
| 1750 | |
| 1751 | // If the input vector sizes are not provided, then the vector sizes are |
| 1752 | // determined by the result tensor shape. In case the vector sizes aren't |
| 1753 | // provided, we update the inBounds attribute instead of masking. |
| 1754 | bool useInBoundsInsteadOfMasking = false; |
| 1755 | if (inputVectorSizes.empty()) { |
| 1756 | ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); |
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
| 1758 | useInBoundsInsteadOfMasking = true; |
| 1759 | } |
| 1760 | |
| 1761 | // Create masked TransferReadOp. |
| 1762 | SmallVector<int64_t> inputShape(inputVectorSizes); |
| 1763 | auto innerTiles = packOp.getStaticInnerTiles(); |
| 1764 | auto innerDimsPos = packOp.getInnerDimsPos(); |
| 1765 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
| 1766 | if (!outerDimsPerm.empty()) |
| 1767 | applyPermutationToVector(inputShape, |
| 1768 | invertPermutationVector(outerDimsPerm)); |
| 1769 | for (auto [idx, size] : enumerate(innerTiles)) |
| 1770 | inputShape[innerDimsPos[idx]] *= size; |
| 1771 | auto maskedRead = vector::createReadOrMaskedRead( |
      rewriter, loc, packOp.getSource(), inputShape, padValue,
| 1773 | useInBoundsInsteadOfMasking); |
| 1774 | |
| 1775 | // Create ShapeCastOp. |
| 1776 | SmallVector<int64_t> destShape(inputVectorSizes); |
| 1777 | destShape.append(innerTiles.begin(), innerTiles.end()); |
| 1778 | auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape), |
| 1779 | packOp.getDestType().getElementType()); |
| 1780 | auto shapeCastOp = |
| 1781 | rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead); |
| 1782 | |
| 1783 | // Create TransposeOp. |
| 1784 | auto destPermutation = |
| 1785 | invertPermutationVector(getPackInverseDestPerm(packOp)); |
| 1786 | auto transposeOp = rewriter.create<vector::TransposeOp>( |
| 1787 | loc, shapeCastOp.getResult(), destPermutation); |
| 1788 | |
| 1789 | // Create TransferWriteOp. |
| 1790 | Value dest = rewriter.create<tensor::EmptyOp>( |
| 1791 | loc, reifiedReturnShapes[0], |
| 1792 | transposeOp.getResult().getType().getElementType()); |
| 1793 | Operation *write = |
| 1794 | createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest); |
  newResults.push_back(write->getResult(0));
| 1796 | return success(); |
| 1797 | } |
| 1798 | |
| 1799 | /// Vectorize a `linalg::UnPackOp` to these 4 Ops: |
| 1800 | /// Vector::TransferReadOp - Reads a vector from the source tensor |
| 1801 | /// vector::TransposeOp - Transpose the Source tensor |
| 1802 | /// ShapeCastOp - Reshape the data based on the target. |
| 1803 | /// vector::TransferWriteOp. - Write the result vector back to the destination |
| 1804 | /// tensor. |
| 1805 | /// If the vector sizes are not provided: |
| 1806 | /// * the vector sizes are determined by the input operand and attributes, |
| 1807 | /// * update the inBounds attribute instead of masking. |
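///
/// Illustrative sketch (hypothetical shapes; indices, masks and attributes
/// elided): unpacking tensor<2x2x8x8xf32> with inner_dims_pos = [0, 1] and
/// inner_tiles = [8, 8] into tensor<16x16xf32> becomes roughly:
///
///   %read = vector.transfer_read %src ... : ..., vector<2x2x8x8xf32>
///   %tr = vector.transpose %read, [0, 2, 1, 3]
///       : vector<2x2x8x8xf32> to vector<2x8x2x8xf32>
///   %sc = vector.shape_cast %tr : vector<2x8x2x8xf32> to vector<16x16xf32>
///   %write = vector.transfer_write %sc, %dest ... : vector<16x16xf32>, ...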
| 1808 | static LogicalResult |
| 1809 | vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, |
| 1810 | ArrayRef<int64_t> inputVectorSizes, |
| 1811 | SmallVectorImpl<Value> &newResults) { |
| 1812 | |
| 1813 | // TODO: Introduce a parent class that will handle the insertion point update. |
| 1814 | OpBuilder::InsertionGuard g(rewriter); |
| 1815 | rewriter.setInsertionPoint(unpackOp); |
| 1816 | |
| 1817 | RankedTensorType unpackTensorType = unpackOp.getSourceType(); |
| 1818 | |
| 1819 | ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); |
| 1820 | ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles(); |
| 1821 | ArrayRef<int64_t> sourceShape = unpackTensorType.getShape(); |
| 1822 | bool useInBoundsInsteadOfMasking = false; |
| 1823 | ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); |
| 1824 | |
| 1825 | auto destSize = unpackOp.getDestRank(); |
| 1826 | |
| 1827 | if (!inputVectorSizes.empty()) |
| 1828 | assert(inputVectorSizes.size() == destSize && |
| 1829 | "Incorrect number of input vector sizes" ); |
| 1830 | |
| 1831 | // vectorSizes is the shape of the vector that will be used to do final |
| 1832 | // write on the destination tensor. It is set like this: Let's say the |
| 1833 | // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M. |
| 1834 | // Thus: |
| 1835 | // 1. vectorSizes = sourceShape.take_front(N) |
| 1836 | // 2. if outer_dims_perms is present: do that permutation on vectorSizes. |
| 1837 | // 3. multiply all the locations in vectorSize pointed by innerDimPos by the |
| 1838 | // innerTiles attribute value. |
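  // For example (hypothetical shapes): sourceShape = <4x8x16x2>, N = 2,
  // outer_dims_perm = [1, 0], inner_dims_pos = [0], inner_tiles = [2]:
  //   1. vectorSizes = [4, 8]
  //   2. after outer_dims_perm: [8, 4]
  //   3. after scaling by inner_tiles: [16, 4]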
| 1839 | SmallVector<int64_t> vectorSizes(inputVectorSizes); |
| 1840 | if (vectorSizes.empty()) { |
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
| 1842 | if (!outerDimsPerm.empty()) |
      applyPermutationToVector(vectorSizes, outerDimsPerm);
| 1844 | for (auto [i, pos] : llvm::enumerate(innerDimPos)) |
| 1845 | vectorSizes[pos] *= innerTiles[i]; |
| 1846 | |
| 1847 | useInBoundsInsteadOfMasking = true; |
| 1848 | } |
| 1849 | |
| 1850 | // readVectorSizes is the size of tensor used to read and apply mask. It is |
| 1851 | // set like this: Let's say the vectorSize (VS) array is size 'N' and |
| 1852 | // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of |
| 1853 | // size M-N |
| 1854 | // Thus: |
| 1855 | // - initially: readVectorSizes = vectorInputSizes |
| 1856 | // - Divide all the readMaskShape locations pointed by innerDimPos |
| 1857 | // by the innerTileSize attribute value. |
| 1858 | // - if outer_dims_perms is present: do that permutation on readVectorSizes. |
| 1859 | // - Append the remaining shape from SS |
  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>
| 1861 | // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512, |
| 1862 | // 128] and outer_dims_perm is [1, 0] then read shape is: |
| 1863 | // ReadVectorSizes(initial): [512, 128] |
| 1864 | // Final Value(after innerDim Adjustment): [512/32, 128/16] |
| 1865 | // = [16, 8] |
| 1866 | // After applying outer_dims_perm: [8, 16] |
| 1867 | // After appending the rest of the sourceShape: [8, 16, 32, 16] |
| 1868 | |
| 1869 | SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end()); |
| 1870 | |
| 1871 | for (auto [index, size] : enumerate(innerTiles)) { |
| 1872 | readVectorSizes[innerDimPos[index]] = |
| 1873 | llvm::divideCeil(readVectorSizes[innerDimPos[index]], size); |
| 1874 | } |
| 1875 | if (!outerDimsPerm.empty()) { |
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
| 1877 | } |
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());
| 1880 | |
| 1881 | ReifiedRankedShapedTypeDims reifiedRetShapes; |
| 1882 | LogicalResult status = |
| 1883 | cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation()) |
| 1884 | .reifyResultShapes(rewriter, reifiedRetShapes); |
| 1885 | if (status.failed()) { |
| 1886 | LDBG("Unable to reify result shapes of " << unpackOp); |
| 1887 | return failure(); |
| 1888 | } |
| 1889 | Location loc = unpackOp->getLoc(); |
| 1890 | |
| 1891 | auto padValue = rewriter.create<arith::ConstantOp>( |
| 1892 | loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType())); |
| 1893 | |
| 1894 | // Read result, mask if necessary. If transferReadOp shape is not equal |
| 1895 | // to shape of source, then a mask is necessary. |
| 1896 | Value readResult = vector::createReadOrMaskedRead( |
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
| 1898 | /*useInBoundsInsteadOfMasking=*/false); |
| 1899 | |
| 1900 | PackingMetadata packMetadata; |
| 1901 | SmallVector<int64_t> lastDimToInsertPosPerm = |
| 1902 | getUnPackInverseSrcPerm(unpackOp, packMetadata); |
| 1903 | ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType()); |
| 1904 | SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape()); |
| 1905 | mlir::Type stripMineElemType = maskedOpShapedType.getElementType(); |
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
| 1907 | RankedTensorType stripMineTensorType = |
| 1908 | RankedTensorType::get(stripMineShape, stripMineElemType); |
| 1909 | // Transpose the appropriate rows to match output. |
| 1910 | vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>( |
| 1911 | loc, readResult, lastDimToInsertPosPerm); |
| 1912 | |
| 1913 | // Collapse the vector to the size required by result. |
| 1914 | RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( |
| 1915 | stripMineTensorType, packMetadata.reassociations); |
| 1916 | mlir::VectorType vecCollapsedType = |
| 1917 | VectorType::get(collapsedType.getShape(), collapsedType.getElementType()); |
| 1918 | vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>( |
| 1919 | loc, vecCollapsedType, transposeOp->getResult(0)); |
| 1920 | |
| 1921 | // writeVectorSizes had to match the shapecast shape for dynamic sizes, |
| 1922 | // otherwise the validator complains that the mask size is invalid. |
| 1923 | SmallVector<int64_t> writeVectorSizes( |
| 1924 | unpackOp.getDestType().hasStaticShape() |
| 1925 | ? vectorSizes |
| 1926 | : shapeCastOp.getResultVectorType().getShape()); |
| 1927 | Value dest = rewriter.create<tensor::EmptyOp>( |
| 1928 | loc, reifiedRetShapes[0], |
| 1929 | shapeCastOp.getResult().getType().getElementType()); |
| 1930 | Operation *write = createWriteOrMaskedWrite( |
| 1931 | rewriter, loc, shapeCastOp.getResult(), dest, |
| 1932 | /*writeIndices=*/{}, useInBoundsInsteadOfMasking); |
  newResults.push_back(write->getResult(0));
| 1934 | return success(); |
| 1935 | } |
| 1936 | |
| 1937 | /// Vectorize a `padOp` with (1) static result type, (2) constant padding value |
| 1938 | /// and (3) all-zero lowPad to |
| 1939 | /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`. |
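///
/// Illustrative sketch (hypothetical shapes): padding tensor<5xf32> to
/// tensor<8xf32> with a constant %pad and input vector sizes [8] becomes
/// roughly:
///
///   %mask = vector.create_mask %c5 : vector<8xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0], %pad : tensor<5xf32>, vector<8xf32>
///   } : vector<8xi1> -> vector<8xf32>
///   %res = vector.transfer_write %read, %init[%c0]
///       {in_bounds = [true]} : vector<8xf32>, tensor<8xf32>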
| 1940 | static LogicalResult |
| 1941 | vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, |
| 1942 | ArrayRef<int64_t> inputVectorSizes, |
| 1943 | SmallVectorImpl<Value> &newResults) { |
| 1944 | auto padValue = padOp.getConstantPaddingValue(); |
| 1945 | Location loc = padOp.getLoc(); |
| 1946 | |
| 1947 | // TODO: Introduce a parent class that will handle the insertion point update. |
| 1948 | OpBuilder::InsertionGuard g(rewriter); |
| 1949 | rewriter.setInsertionPoint(padOp); |
| 1950 | |
| 1951 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
| 1952 | LogicalResult status = |
| 1953 | cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation()) |
| 1954 | .reifyResultShapes(rewriter, reifiedReturnShapes); |
| 1955 | (void)status; // prevent unused variable warning on non-assert builds |
| 1956 | assert(succeeded(status) && "failed to reify result shapes" ); |
| 1957 | auto maskedRead = vector::createReadOrMaskedRead( |
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
| 1959 | /*useInBoundsInsteadOfMasking=*/false); |
| 1960 | |
| 1961 | // Create Xfer write Op |
| 1962 | Value dest = rewriter.create<tensor::EmptyOp>( |
| 1963 | loc, reifiedReturnShapes[0], padOp.getResultType().getElementType()); |
| 1964 | Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest); |
  newResults.push_back(write->getResult(0));
| 1966 | return success(); |
| 1967 | } |
| 1968 | |
| 1969 | // TODO: probably need some extra checks for reduction followed by consumer |
| 1970 | // ops that may not commute (e.g. linear reduction + non-linear instructions). |
| 1971 | static LogicalResult reductionPreconditions(LinalgOp op) { |
| 1972 | if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) { |
| 1973 | LDBG("reduction precondition failed: no reduction iterator\n" ); |
| 1974 | return failure(); |
| 1975 | } |
| 1976 | for (OpOperand &opOperand : op.getDpsInitsMutable()) { |
| 1977 | AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand); |
| 1978 | if (indexingMap.isPermutation()) |
| 1979 | continue; |
| 1980 | |
| 1981 | Operation *reduceOp = matchLinalgReduction(&opOperand); |
| 1982 | if (!reduceOp || !getCombinerOpKind(reduceOp)) { |
| 1983 | LDBG("reduction precondition failed: reduction detection failed\n" ); |
| 1984 | return failure(); |
| 1985 | } |
| 1986 | } |
| 1987 | return success(); |
| 1988 | } |
| 1989 | |
| 1990 | static LogicalResult |
| 1991 | vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, |
| 1992 | bool flatten1DDepthwiseConv) { |
| 1993 | if (flatten1DDepthwiseConv) { |
| 1994 | LDBG("Vectorization of flattened convs with dynamic shapes is not " |
| 1995 | "supported\n" ); |
| 1996 | return failure(); |
| 1997 | } |
| 1998 | |
| 1999 | if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) { |
| 2000 | LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n" ); |
| 2001 | return failure(); |
| 2002 | } |
| 2003 | |
| 2004 | // Support dynamic shapes in 1D depthwise convolution, but only in the |
| 2005 | // _channel_ dimension. |
| 2006 | Value lhs = conv.getDpsInputOperand(0)->get(); |
| 2007 | ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape(); |
  auto shapeWithoutCh = lhsShape.drop_back(1);
| 2009 | if (ShapedType::isDynamicShape(shapeWithoutCh)) { |
| 2010 | LDBG("Dynamically-shaped op vectorization precondition failed: only " |
| 2011 | "channel dim can be dynamic\n" ); |
| 2012 | return failure(); |
| 2013 | } |
| 2014 | |
| 2015 | return success(); |
| 2016 | } |
| 2017 | |
| 2018 | static LogicalResult |
| 2019 | vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, |
| 2020 | bool flatten1DDepthwiseConv) { |
| 2021 | if (isa<ConvolutionOpInterface>(op.getOperation())) |
| 2022 | return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv); |
| 2023 | |
| 2024 | if (hasReductionIterator(op)) |
| 2025 | return reductionPreconditions(op); |
| 2026 | |
| 2027 | // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops, |
| 2028 | // linalg.copy ops and ops that implement ContractionOpInterface for now. |
| 2029 | if (!isElementwise(op) && |
| 2030 | !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>( |
| 2031 | op.getOperation())) |
| 2032 | return failure(); |
| 2033 | |
| 2034 | LDBG("Dynamically-shaped op meets vectorization pre-conditions\n" ); |
| 2035 | return success(); |
| 2036 | } |
| 2037 | |
| 2038 | /// Need to check if the inner-tiles are static/constant. |
| 2039 | static LogicalResult |
| 2040 | vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, |
| 2041 | ArrayRef<int64_t> inputVectorSizes) { |
| 2042 | |
| 2043 | if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) { |
        return !getConstantIntValue(res).has_value();
| 2045 | })) { |
| 2046 | LDBG("Inner-tiles must be constant: " << unpackOp << "\n" ); |
| 2047 | return failure(); |
| 2048 | } |
| 2049 | ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape(); |
| 2050 | bool satisfyEmptyCond = inputVectorSizes.empty() && |
| 2051 | unpackOp.getDestType().hasStaticShape() && |
| 2052 | unpackOp.getSourceType().hasStaticShape(); |
| 2053 | if (!satisfyEmptyCond && |
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
| 2055 | return failure(); |
| 2056 | |
| 2057 | return success(); |
| 2058 | } |
| 2059 | |
| 2060 | static LogicalResult |
| 2061 | vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, |
| 2062 | ArrayRef<int64_t> inputVectorSizes) { |
| 2063 | |
| 2064 | TypedValue<RankedTensorType> source = sliceOp.getSource(); |
| 2065 | auto sourceType = source.getType(); |
| 2066 | if (!VectorType::isValidElementType(sourceType.getElementType())) |
| 2067 | return failure(); |
| 2068 | |
| 2069 | // Get the pad value. |
| 2070 | // TransferReadOp (which is used to vectorize InsertSliceOp), requires a |
| 2071 | // scalar padding value. Note that: |
| 2072 | // * for in-bounds accesses, |
| 2073 | // the value is actually irrelevant. There are 2 cases in which xfer.read |
| 2074 | // accesses are known to be in-bounds: |
| 2075 | // 1. The source shape is static (output vector sizes would be based on |
| 2076 | // the source shape and hence all memory accesses would be in-bounds), |
| 2077 | // 2. Masking is used, i.e. the output vector sizes are user-provided. In |
| 2078 | // this case it is safe to assume that all memory accesses are in-bounds. |
| 2079 | // |
| 2080 | // When the value is not known and not needed, use 0. Otherwise, bail out. |
| 2081 | Value padValue = getStaticPadVal(sliceOp); |
| 2082 | bool isOutOfBoundsRead = |
| 2083 | !sourceType.hasStaticShape() && inputVectorSizes.empty(); |
| 2084 | |
| 2085 | if (!padValue && isOutOfBoundsRead) { |
| 2086 | LDBG("Failed to get a pad value for out-of-bounds read access\n" ); |
| 2087 | return failure(); |
| 2088 | } |
| 2089 | return success(); |
| 2090 | } |
| 2091 | |
| 2092 | namespace { |
| 2093 | enum class ConvOperationKind { Conv, Pool }; |
| 2094 | } // namespace |
| 2095 | |
| 2096 | static bool isCastOfBlockArgument(Operation *op) { |
| 2097 | return isa<CastOpInterface>(op) && op->getNumOperands() == 1 && |
| 2098 | isa<BlockArgument>(op->getOperand(0)); |
| 2099 | } |
| 2100 | |
| 2101 | // Returns the ConvOperationKind of the op using reduceOp of the generic |
| 2102 | // payload. If it is neither a convolution nor a pooling, it returns |
| 2103 | // std::nullopt. |
| 2104 | // |
| 2105 | // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction |
| 2106 | // + yield) and rhs is not used) then it is the body of a pooling |
| 2107 | // If conv, check for single `mul` predecessor. The `mul` operands must be |
| 2108 | // block arguments or extension of block arguments. |
| 2109 | // Otherwise, check for one or zero `ext` predecessor. The `ext` operands |
| 2110 | // must be block arguments or extension of block arguments. |
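//
// Illustrative sketch of the two payload shapes being matched (types elided):
//
//   // Convolution: a single `mul` feeds the reduction.
//   ^bb0(%in, %filter, %acc):
//     %m = arith.mulf %in, %filter
//     %r = arith.addf %acc, %m
//     linalg.yield %r
//
//   // Pooling: the reduction consumes two block arguments directly.
//   ^bb0(%in, %acc):
//     %r = arith.addf %acc, %in
//     linalg.yield %r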
| 2111 | static std::optional<ConvOperationKind> |
| 2112 | getConvOperationKind(Operation *reduceOp) { |
| 2113 | int numBlockArguments = |
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
| 2115 | |
| 2116 | switch (numBlockArguments) { |
| 2117 | case 1: { |
| 2118 | // Will be convolution if feeder is a MulOp. |
| 2119 | // A strength reduced version of MulOp for i1 type is AndOp which is also |
| 2120 | // supported. Otherwise, it can be pooling. This strength reduction logic |
| 2121 | // is in `buildBinaryFn` helper in the Linalg dialect. |
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
| 2124 | assert(feedValIt != reduceOp->operand_end() && |
| 2125 | "Expected a non-block argument operand" ); |
| 2126 | Operation *feedOp = (*feedValIt).getDefiningOp(); |
    if (isCastOfBlockArgument(feedOp)) {
| 2128 | return ConvOperationKind::Pool; |
| 2129 | } |
| 2130 | |
| 2131 | if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) || |
| 2132 | (isa<arith::AndIOp>(feedOp) && |
| 2133 | feedOp->getResultTypes()[0].isInteger(1))) && |
| 2134 | llvm::all_of(feedOp->getOperands(), [](Value v) { |
| 2135 | if (isa<BlockArgument>(v)) |
| 2136 | return true; |
| 2137 | if (Operation *op = v.getDefiningOp()) |
| 2138 | return isCastOfBlockArgument(op); |
| 2139 | return false; |
| 2140 | }))) { |
| 2141 | return std::nullopt; |
| 2142 | } |
| 2143 | |
| 2144 | return ConvOperationKind::Conv; |
| 2145 | } |
| 2146 | case 2: |
| 2147 | // Must be pooling |
| 2148 | return ConvOperationKind::Pool; |
| 2149 | default: |
| 2150 | return std::nullopt; |
| 2151 | } |
| 2152 | } |
| 2153 | |
| 2154 | static bool isSupportedPoolKind(vector::CombiningKind kind) { |
| 2155 | switch (kind) { |
| 2156 | case vector::CombiningKind::ADD: |
| 2157 | case vector::CombiningKind::MAXNUMF: |
| 2158 | case vector::CombiningKind::MAXIMUMF: |
| 2159 | case vector::CombiningKind::MAXSI: |
| 2160 | case vector::CombiningKind::MAXUI: |
| 2161 | case vector::CombiningKind::MINNUMF: |
| 2162 | case vector::CombiningKind::MINIMUMF: |
| 2163 | case vector::CombiningKind::MINSI: |
| 2164 | case vector::CombiningKind::MINUI: |
| 2165 | return true; |
| 2166 | default: |
| 2167 | return false; |
| 2168 | } |
| 2169 | } |
| 2170 | |
| 2171 | static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) { |
| 2172 | auto getOperandType = [&](auto operand) { |
| 2173 | return dyn_cast<ShapedType>((operand->get()).getType()); |
| 2174 | }; |
| 2175 | ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0)); |
| 2176 | ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1)); |
| 2177 | ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0)); |
| 2178 | // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR |
| 2179 | // (non-channeled convolution -> LHS and RHS both have single dimensions). |
| 2180 | // Note that this also ensures 2D and 3D convolutions are rejected. |
| 2181 | if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) && |
| 2182 | (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1)) |
| 2183 | return failure(); |
| 2184 | |
| 2185 | Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0)); |
| 2186 | if (!reduceOp) |
| 2187 | return failure(); |
| 2188 | |
| 2189 | auto maybeOper = getConvOperationKind(reduceOp); |
| 2190 | if (!maybeOper.has_value()) |
| 2191 | return failure(); |
| 2192 | |
| 2193 | auto maybeKind = getCombinerOpKind(reduceOp); |
| 2194 | // Typically convolution will have a `Add` CombiningKind but for i1 type it |
| 2195 | // can get strength reduced to `OR` which is also supported. This strength |
| 2196 | // reduction logic is in `buildBinaryFn` helper in the Linalg dialect. |
| 2197 | if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD && |
| 2198 | *maybeKind != vector::CombiningKind::OR) && |
| 2199 | (*maybeOper != ConvOperationKind::Pool || |
| 2200 | !isSupportedPoolKind(*maybeKind)))) { |
| 2201 | return failure(); |
| 2202 | } |
| 2203 | |
| 2204 | auto rhsRank = rhsShapedType.getRank(); |
| 2205 | if (*maybeOper == ConvOperationKind::Pool) { |
| 2206 | if (rhsRank != 1) |
| 2207 | return failure(); |
| 2208 | } else { |
| 2209 | if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) |
| 2210 | return failure(); |
| 2211 | } |
| 2212 | |
| 2213 | return success(); |
| 2214 | } |
| 2215 | |
| 2216 | static LogicalResult vectorizeLinalgOpPrecondition( |
| 2217 | LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes, |
    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
| 2219 | // tensor with dimension of 0 cannot be vectorized. |
| 2220 | if (llvm::is_contained(linalgOp.getStaticShape(), 0)) |
| 2221 | return failure(); |
| 2222 | // Check API contract for input vector sizes. |
| 2223 | if (!inputVectorSizes.empty() && |
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
| 2225 | inputVectorSizes))) |
| 2226 | return failure(); |
| 2227 | |
| 2228 | if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition( |
| 2229 | linalgOp, flatten1DDepthwiseConv))) { |
| 2230 | LDBG("Dynamically-shaped op failed vectorization pre-conditions\n" ); |
| 2231 | return failure(); |
| 2232 | } |
| 2233 | |
| 2234 | SmallVector<CustomVectorizationPrecondition> customPreconditions; |
| 2235 | |
| 2236 | // Register CustomVectorizationPrecondition for extractOp. |
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
| 2238 | |
| 2239 | // All types in the body should be a supported element type for VectorType. |
| 2240 | for (Operation &innerOp : linalgOp->getRegion(0).front()) { |
| 2241 | // Check if any custom hook can vectorize the inner op. |
| 2242 | if (llvm::any_of( |
| 2243 | customPreconditions, |
| 2244 | [&](const CustomVectorizationPrecondition &customPrecondition) { |
| 2245 | return succeeded( |
| 2246 | customPrecondition(&innerOp, vectorizeNDExtract)); |
| 2247 | })) { |
| 2248 | continue; |
| 2249 | } |
| 2250 | if (!llvm::all_of(innerOp.getOperandTypes(), |
| 2251 | VectorType::isValidElementType)) { |
| 2252 | return failure(); |
| 2253 | } |
| 2254 | if (!llvm::all_of(innerOp.getResultTypes(), |
| 2255 | VectorType::isValidElementType)) { |
| 2256 | return failure(); |
| 2257 | } |
| 2258 | } |
| 2259 | if (isElementwise(linalgOp)) |
| 2260 | return success(); |
| 2261 | |
| 2262 | // TODO: isaConvolutionOpInterface that can also infer from generic |
| 2263 | // features. But we will still need stride/dilation attributes that will be |
| 2264 | // annoying to reverse-engineer... |
| 2265 | if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) |
| 2266 | return vectorizeConvOpPrecondition(linalgOp); |
| 2267 | |
| 2268 | // TODO: the common vector shape is equal to the static loop sizes only when |
| 2269 | // all indexing maps are projected permutations. For convs and stencils the |
| 2270 | // logic will need to evolve. |
| 2271 | if (!allIndexingsAreProjectedPermutation(linalgOp)) { |
| 2272 | LDBG("precondition failed: not projected permutations\n" ); |
| 2273 | return failure(); |
| 2274 | } |
| 2275 | if (failed(reductionPreconditions(linalgOp))) { |
| 2276 | LDBG("precondition failed: reduction preconditions\n" ); |
| 2277 | return failure(); |
| 2278 | } |
| 2279 | return success(); |
| 2280 | } |
| 2281 | |
| 2282 | static LogicalResult |
| 2283 | vectorizePackOpPrecondition(linalg::PackOp packOp, |
| 2284 | ArrayRef<int64_t> inputVectorSizes) { |
| 2285 | auto padValue = packOp.getPaddingValue(); |
| 2286 | Attribute cstAttr; |
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
| 2288 | LDBG("pad value is not constant: " << packOp << "\n" ); |
| 2289 | return failure(); |
| 2290 | } |
| 2291 | ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); |
| 2292 | bool satisfyEmptyCond = true; |
| 2293 | if (inputVectorSizes.empty()) { |
| 2294 | if (!packOp.getDestType().hasStaticShape() || |
| 2295 | !packOp.getSourceType().hasStaticShape()) |
| 2296 | satisfyEmptyCond = false; |
| 2297 | } |
| 2298 | |
| 2299 | if (!satisfyEmptyCond && |
| 2300 | failed(vector::isValidMaskedInputVector( |
          resultTensorShape.take_front(packOp.getSourceRank()),
| 2302 | inputVectorSizes))) |
| 2303 | return failure(); |
| 2304 | |
| 2305 | if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) { |
        return !getConstantIntValue(v).has_value();
| 2307 | })) { |
| 2308 | LDBG("inner_tiles must be constant: " << packOp << "\n" ); |
| 2309 | return failure(); |
| 2310 | } |
| 2311 | |
| 2312 | return success(); |
| 2313 | } |
| 2314 | |
| 2315 | static LogicalResult |
| 2316 | vectorizePadOpPrecondition(tensor::PadOp padOp, |
| 2317 | ArrayRef<int64_t> inputVectorSizes) { |
| 2318 | auto padValue = padOp.getConstantPaddingValue(); |
| 2319 | if (!padValue) { |
| 2320 | LDBG("pad value is not constant: " << padOp << "\n" ); |
| 2321 | return failure(); |
| 2322 | } |
| 2323 | |
| 2324 | ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape(); |
| 2325 | if (failed(vector::isValidMaskedInputVector(resultTensorShape,
| 2326 | inputVectorSizes))) |
| 2327 | return failure(); |
| 2328 | |
| 2329 | // Padding with non-zero low pad values is not supported, unless the
| 2330 | // corresponding result dim is 1, as that would require shifting the results
| 2331 | // to the right for the low-padded dims by the amount of low padding.
| 2332 | // However, low padding is supported when the low-padded dims have a result
| 2333 | // size of 1: the input size of such a dim is then dynamically zero (the low
| 2334 | // pad and the input dim size must sum to one), so the lowering creates a
| 2335 | // zero mask for it (the mask only covers the input dim size, which is zero
| 2336 | // here) and the pad value is loaded instead - which is exactly what we want.
| 2337 | // If the low pad is dynamically zero, the lowering is correct as well, since
| 2338 | // no shift is necessary.
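|      | // Illustrative example (shapes and values assumed, not taken from a test):
|      | //   %p = tensor.pad %src low[%l, 0] high[0, 3]
|      | //        : tensor<?x29xf32> to tensor<1x32xf32>
|      | // Dim 0 has a (possibly non-zero) low pad but a unit result size, so it is
|      | // accepted; dim 1 is only accepted because its low pad is statically zero.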
| 2340 | if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) { |
| 2341 | Value padValue = en.value(); |
| 2342 | unsigned pos = en.index(); |
| 2343 | std::optional<int64_t> pad = getConstantIntValue(padValue);
| 2344 | return (!pad.has_value() || pad.value() != 0) && |
| 2345 | resultTensorShape[pos] != 1; |
| 2346 | })) { |
| 2347 | LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n" ); |
| 2348 | return failure(); |
| 2349 | } |
| 2350 | |
| 2351 | return success(); |
| 2352 | } |
| 2353 | |
| 2354 | /// Preconditions for scalable vectors. This is quite restrictive - it models |
| 2355 | /// the fact that in practice we would only make selected dimensions scalable. |
| 2356 | static LogicalResult |
| 2357 | vectorizeScalableVectorPrecondition(Operation *op, |
| 2358 | ArrayRef<int64_t> inputVectorSizes, |
| 2359 | ArrayRef<bool> inputScalableVecDims) { |
| 2360 | assert(inputVectorSizes.size() == inputScalableVecDims.size() && |
| 2361 | "Number of input vector sizes and scalable dims doesn't match" ); |
| 2362 | |
| 2363 | size_t numOfScalableDims = |
| 2364 | llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
| 2365 | |
| 2366 | if (numOfScalableDims == 0) |
| 2367 | return success(); |
| 2368 | |
| 2369 | auto linalgOp = dyn_cast<LinalgOp>(op); |
| 2370 | |
| 2371 | // Cond 1: There's been no need for scalable vectorisation of |
| 2372 | // non-linalg Ops so far |
| 2373 | if (!linalgOp) |
| 2374 | return failure(); |
| 2375 | |
| 2376 | // Cond 2: There's been no need for more than 2 scalable dims so far |
| 2377 | if (numOfScalableDims > 2) |
| 2378 | return failure(); |
| 2379 | |
| 2380 | // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that |
| 2381 | // it matches one of the supported cases: |
| 2382 | // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim |
| 2383 | // (*). |
| 2384 | // 2. Exactly 2 dims are scalable and those are the _last two adjacent_ |
| 2385 | // parallel dims. |
| 2386 | // 3. Exactly 1 reduction dim is scalable and that's the last (innermost) |
| 2387 | // dim. |
| 2388 | // The 2nd restriction above means that only Matmul-like Ops are supported |
| 2389 | // when 2 dims are scalable, e.g. : |
| 2390 | // * iterators = [parallel, parallel, reduction] |
| 2391 | // * scalable flags = [true, true, false] |
| 2392 | // |
| 2393 | // (*) Unit dims get folded away in practice.
| 2394 | // TODO: Relax these conditions as good motivating examples are identified. |
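|      | // For instance (assumed sizes, for illustration only), a matvec-like op with
|      | //  * iterators      = [parallel, reduction]
|      | //  * scalable flags = [true, false]
|      | // matches case 1., whereas flags = [false, true] falls under case 3.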
| 2395 | |
| 2396 | // Find the first scalable flag. |
| 2397 | bool seenNonUnitParallel = false; |
| 2398 | auto iterators = linalgOp.getIteratorTypesArray(); |
| 2399 | SmallVector<bool> scalableFlags(inputScalableVecDims); |
| 2400 | int64_t idx = scalableFlags.size() - 1; |
| 2401 | while (!scalableFlags[idx]) { |
| 2402 | bool isNonUnitDim = (inputVectorSizes[idx] != 1); |
| 2403 | seenNonUnitParallel |= |
| 2404 | (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim); |
| 2405 | |
| 2406 | iterators.pop_back(); |
| 2407 | scalableFlags.pop_back(); |
| 2408 | --idx; |
| 2409 | } |
| 2410 | |
| 2411 | // Analyze the iterator corresponding to the first scalable dim. |
| 2412 | switch (iterators.back()) { |
| 2413 | case utils::IteratorType::reduction: { |
| 2414 | // Check 3. above is met. |
| 2415 | if (iterators.size() != inputVectorSizes.size()) { |
| 2416 | LDBG("Non-trailing reduction dim requested for scalable " |
| 2417 | "vectorization\n" ); |
| 2418 | return failure(); |
| 2419 | } |
| 2420 | if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) { |
| 2421 | LDBG("Scalable vectorization of the reduction dim in Matmul-like ops " |
| 2422 | "is not supported\n" ); |
| 2423 | return failure(); |
| 2424 | } |
| 2425 | break; |
| 2426 | } |
| 2427 | case utils::IteratorType::parallel: { |
| 2428 | // Check 1. and 2. above are met. |
| 2429 | if (seenNonUnitParallel) { |
| 2430 | LDBG("Inner parallel dim not requested for scalable " |
| 2431 | "vectorization\n" ); |
| 2432 | return failure(); |
| 2433 | } |
| 2434 | break; |
| 2435 | } |
| 2436 | } |
| 2437 | |
| 2438 | // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
| 2439 | // supported, for which we expect the following config:
| 2440 | // * iterators = [parallel, parallel, reduction] |
| 2441 | // * scalable flags = [true, true, false] |
| 2442 | if (numOfScalableDims == 2) { |
| 2443 | // Disallow the case below, which breaks 3. above:
| 2444 | // * iterators = [..., parallel, reduction] |
| 2445 | // * scalable flags = [..., true, true] |
| 2446 | if (iterators.back() == utils::IteratorType::reduction) { |
| 2447 | LDBG("Higher dim than the trailing reduction dim requested for scalable " |
| 2448 | "vectorization\n" ); |
| 2449 | return failure(); |
| 2450 | } |
| 2451 | scalableFlags.pop_back(); |
| 2452 | iterators.pop_back(); |
| 2453 | |
| 2454 | if (!scalableFlags.back() || |
| 2455 | (iterators.back() != utils::IteratorType::parallel)) |
| 2456 | return failure(); |
| 2457 | } |
| 2458 | |
| 2459 | // Bail out on matmuls with extended semantics, i.e. with user-defined
| 2460 | // indexing maps, as this transform does not support them.
| 2461 | if (linalgOp.hasUserDefinedMaps()) |
| 2462 | return failure(); |
| 2463 | |
| 2464 | // Cond 4: Only the following ops are supported in the |
| 2465 | // presence of scalable vectors |
| 2466 | return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) || |
| 2467 | isa<linalg::MatmulTransposeAOp>(op) || |
| 2468 | isa<linalg::DepthwiseConv1DNwcWcOp>(op) || |
| 2469 | isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp)); |
| 2470 | } |
| 2471 | |
| 2472 | LogicalResult mlir::linalg::vectorizeOpPrecondition( |
| 2473 | Operation *op, ArrayRef<int64_t> inputVectorSizes, |
| 2474 | ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
| 2475 | bool flatten1DDepthwiseConv) { |
| 2476 | |
| 2477 | if (!hasVectorizationImpl(op)) |
| 2478 | return failure(); |
| 2479 | |
| 2480 | if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
| 2481 | inputScalableVecDims))) |
| 2482 | return failure(); |
| 2483 | |
| 2484 | return TypeSwitch<Operation *, LogicalResult>(op) |
| 2485 | .Case<linalg::LinalgOp>([&](auto linalgOp) { |
| 2486 | return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes, |
| 2487 | vectorizeNDExtract, |
| 2488 | flatten1DDepthwiseConv); |
| 2489 | }) |
| 2490 | .Case<tensor::PadOp>([&](auto padOp) { |
| 2491 | return vectorizePadOpPrecondition(padOp, inputVectorSizes); |
| 2492 | }) |
| 2493 | .Case<linalg::PackOp>([&](auto packOp) { |
| 2494 | return vectorizePackOpPrecondition(packOp, inputVectorSizes); |
| 2495 | }) |
| 2496 | .Case<linalg::UnPackOp>([&](auto unpackOp) { |
| 2497 | return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes); |
| 2498 | }) |
| 2499 | .Case<tensor::InsertSliceOp>([&](auto sliceOp) { |
| 2500 | return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes); |
| 2501 | }) |
| 2502 | .Default([](auto) { return failure(); }); |
| 2503 | } |
| 2504 | |
| 2505 | /// Converts affine.apply Ops to arithmetic operations. |
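|      | ///
|      | /// For example (illustrative IR, operand names assumed):
|      | ///   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]
|      | /// is expanded into plain arithmetic such as:
|      | ///   %0 = arith.addi %i, %j : index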
| 2506 | static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) { |
| 2507 | OpBuilder::InsertionGuard g(rewriter); |
| 2508 | auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>(); |
| 2509 | |
| 2510 | for (auto op : make_early_inc_range(toReplace)) { |
| 2511 | rewriter.setInsertionPoint(op); |
| 2512 | auto expanded = affine::expandAffineExpr( |
| 2513 | rewriter, op->getLoc(), op.getAffineMap().getResult(0), |
| 2514 | op.getOperands().take_front(op.getAffineMap().getNumDims()), |
| 2515 | op.getOperands().take_back(op.getAffineMap().getNumSymbols())); |
| 2516 | rewriter.replaceOp(op, expanded); |
| 2517 | } |
| 2518 | } |
| 2519 | |
| 2520 | bool mlir::linalg::hasVectorizationImpl(Operation *op) { |
| 2521 | return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp, |
| 2522 | tensor::InsertSliceOp>(op); |
| 2523 | } |
| 2524 | |
| 2525 | /// Emit a suitable vector form for an operation. If provided, |
| 2526 | /// `inputVectorSizes` are used to vectorize this operation. |
| 2527 | /// `inputVectorSizes` must match the rank of the iteration space of the |
| 2528 | /// operation and the input vector sizes must be greater than or equal to |
| 2529 | /// their counterpart iteration space sizes, if static. `inputVectorSizes`
| 2530 | /// also allows the vectorization of operations with dynamic shapes. |
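|      | ///
|      | /// For example (sizes assumed for illustration): for an op whose iteration
|      | /// space is (4, ?, 16), passing `inputVectorSizes` = {4, 8, 16} is valid -
|      | /// each size covers its static counterpart and fixes the dynamic dim to 8.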
| 2531 | LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, |
| 2532 | ArrayRef<int64_t> inputVectorSizes, |
| 2533 | ArrayRef<bool> inputScalableVecDims, |
| 2534 | bool vectorizeNDExtract,
| 2535 | bool flatten1DDepthwiseConv) { |
| 2536 | LDBG("Attempting to vectorize:\n" << *op << "\n" ); |
| 2537 | LDBG("Input vector sizes: " ); |
| 2538 | LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); |
| 2539 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 2540 | LDBG("Input scalable vector dims: " ); |
| 2541 | LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); |
| 2542 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 2543 | |
| 2544 | if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
| 2545 | vectorizeNDExtract, |
| 2546 | flatten1DDepthwiseConv))) { |
| 2547 | LDBG("Vectorization pre-conditions failed\n" ); |
| 2548 | return failure(); |
| 2549 | } |
| 2550 | |
| 2551 | // Initialize vectorization state. |
| 2552 | VectorizationState state(rewriter); |
| 2553 | if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) { |
| 2554 | if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
| 2555 | inputScalableVecDims))) { |
| 2556 | LDBG("Vectorization state couldn't be initialized\n" ); |
| 2557 | return failure(); |
| 2558 | } |
| 2559 | } |
| 2560 | |
| 2561 | SmallVector<Value> results; |
| 2562 | auto vectorizeResult = |
| 2563 | TypeSwitch<Operation *, LogicalResult>(op) |
| 2564 | .Case<linalg::LinalgOp>([&](auto linalgOp) { |
| 2565 | // TODO: isaConvolutionOpInterface that can also infer from |
| 2566 | // generic features. Will require stride/dilation attributes |
| 2567 | // inference. |
| 2568 | if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) { |
| 2569 | FailureOr<Operation *> convOr = vectorizeConvolution( |
| 2570 | rewriter, linalgOp, inputVectorSizes, inputScalableVecDims, |
| 2571 | flatten1DDepthwiseConv); |
| 2572 | if (succeeded(convOr)) { |
| 2573 | llvm::append_range(results, (*convOr)->getResults()); |
| 2574 | return success(); |
| 2575 | } |
| 2576 | |
| 2577 | LDBG("Unsupported convolution can't be vectorized.\n" ); |
| 2578 | return failure(); |
| 2579 | } |
| 2580 | |
| 2581 | LDBG("Vectorize generic by broadcasting to the canonical vector " |
| 2582 | "shape\n" ); |
| 2583 | |
| 2584 | // Pre-process before proceeding. |
| 2585 | convertAffineApply(rewriter, linalgOp); |
| 2586 | |
| 2587 | // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted |
| 2588 | // to 'OpBuilder' when it is passed over to some methods like |
| 2589 | // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we |
| 2590 | // erase an op within these methods, the actual rewriter won't be |
| 2591 | // notified and we will end up with read-after-free issues! |
| 2592 | return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results); |
| 2593 | }) |
| 2594 | .Case<tensor::PadOp>([&](auto padOp) { |
| 2595 | return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, |
| 2596 | results); |
| 2597 | }) |
| 2598 | .Case<linalg::PackOp>([&](auto packOp) { |
| 2599 | return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, |
| 2600 | results); |
| 2601 | }) |
| 2602 | .Case<linalg::UnPackOp>([&](auto unpackOp) { |
| 2603 | return vectorizeAsTensorUnpackOp(rewriter, unpackOp, |
| 2604 | inputVectorSizes, results); |
| 2605 | }) |
| 2606 | .Case<tensor::InsertSliceOp>([&](auto sliceOp) { |
| 2607 | return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes, |
| 2608 | results); |
| 2609 | }) |
| 2610 | .Default([](auto) { return failure(); }); |
| 2611 | |
| 2612 | if (failed(vectorizeResult)) { |
| 2613 | LDBG("Vectorization failed\n" ); |
| 2614 | return failure(); |
| 2615 | } |
| 2616 | |
| 2617 | if (!results.empty()) |
| 2618 | rewriter.replaceOp(op, results);
| 2619 | else |
| 2620 | rewriter.eraseOp(op); |
| 2621 | |
| 2622 | return success(); |
| 2623 | } |
| 2624 | |
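|      | /// Vectorize a statically-shaped memref.copy as a full-size
|      | /// vector.transfer_read + vector.transfer_write pair. Roughly (shapes assumed
|      | /// for illustration, details elided):
|      | ///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
|      | /// becomes:
|      | ///   %v = vector.transfer_read %src[%c0, %c0], ... : memref<4x8xf32>, vector<4x8xf32>
|      | ///   vector.transfer_write %v, %dst[%c0, %c0] : vector<4x8xf32>, memref<4x8xf32>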
| 2625 | LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, |
| 2626 | memref::CopyOp copyOp) { |
| 2627 | auto srcType = cast<MemRefType>(copyOp.getSource().getType()); |
| 2628 | auto dstType = cast<MemRefType>(copyOp.getTarget().getType()); |
| 2629 | if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) |
| 2630 | return failure(); |
| 2631 | |
| 2632 | auto srcElementType = getElementTypeOrSelf(srcType); |
| 2633 | auto dstElementType = getElementTypeOrSelf(dstType); |
| 2634 | if (!VectorType::isValidElementType(srcElementType) || |
| 2635 | !VectorType::isValidElementType(dstElementType)) |
| 2636 | return failure(); |
| 2637 | |
| 2638 | auto readType = VectorType::get(srcType.getShape(), srcElementType); |
| 2639 | auto writeType = VectorType::get(dstType.getShape(), dstElementType); |
| 2640 | |
| 2641 | Location loc = copyOp->getLoc(); |
| 2642 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
| 2643 | SmallVector<Value> indices(srcType.getRank(), zero); |
| 2644 | |
| 2645 | Value readValue = rewriter.create<vector::TransferReadOp>( |
| 2646 | loc, readType, copyOp.getSource(), indices, |
| 2647 | rewriter.getMultiDimIdentityMap(srcType.getRank()));
| 2648 | if (cast<VectorType>(readValue.getType()).getRank() == 0) { |
| 2649 | readValue = |
| 2650 | rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>()); |
| 2651 | readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue); |
| 2652 | } |
| 2653 | Operation *writeValue = rewriter.create<vector::TransferWriteOp>( |
| 2654 | loc, readValue, copyOp.getTarget(), indices, |
| 2655 | rewriter.getMultiDimIdentityMap(srcType.getRank()));
| 2656 | rewriter.replaceOp(copyOp, writeValue->getResults()); |
| 2657 | return success(); |
| 2658 | } |
| 2659 | |
| 2660 | //----------------------------------------------------------------------------// |
| 2661 | // Misc. vectorization patterns. |
| 2662 | //----------------------------------------------------------------------------// |
| 2663 | /// Base pattern for rewriting tensor::PadOps whose result is consumed by a |
| 2664 | /// given operation type OpTy. |
| 2665 | template <typename OpTy> |
| 2666 | struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> { |
| 2667 | using OpRewritePattern<tensor::PadOp>::OpRewritePattern; |
| 2668 | |
| 2669 | LogicalResult matchAndRewrite(tensor::PadOp padOp, |
| 2670 | PatternRewriter &rewriter) const final { |
| 2671 | bool changed = false; |
| 2672 | // Insert users in vector, because some users may be replaced/removed. |
| 2673 | for (auto *user : llvm::to_vector<4>(padOp->getUsers())) |
| 2674 | if (auto op = dyn_cast<OpTy>(user)) |
| 2675 | changed |= rewriteUser(rewriter, padOp, op).succeeded(); |
| 2676 | return success(changed);
| 2677 | } |
| 2678 | |
| 2679 | protected: |
| 2680 | virtual LogicalResult rewriteUser(PatternRewriter &rewriter, |
| 2681 | tensor::PadOp padOp, OpTy op) const = 0; |
| 2682 | }; |
| 2683 | |
| 2684 | /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.: |
| 2685 | /// ``` |
| 2686 | /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32> |
| 2687 | /// %r = vector.transfer_read %0[%c0, %c0], %cst |
| 2688 | /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32> |
| 2689 | /// ``` |
| 2690 | /// is rewritten to: |
| 2691 | /// ``` |
| 2692 | /// %r = vector.transfer_read %src[%c0, %c0], %padding |
| 2693 | /// {in_bounds = [true, true]} |
| 2694 | /// : tensor<?x?xf32>, vector<17x5xf32> |
| 2695 | /// ``` |
| 2696 | /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be |
| 2697 | /// sure that the original padding value %cst was never used. |
| 2698 | /// |
| 2699 | /// This rewrite is possible if: |
| 2700 | /// - `xferOp` has no out-of-bounds dims or mask. |
| 2701 | /// - Low padding is static 0. |
| 2702 | /// - Single, scalar padding value. |
| 2703 | struct PadOpVectorizationWithTransferReadPattern |
| 2704 | : public VectorizePadOpUserPattern<vector::TransferReadOp> { |
| 2705 | using VectorizePadOpUserPattern< |
| 2706 | vector::TransferReadOp>::VectorizePadOpUserPattern; |
| 2707 | |
| 2708 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
| 2709 | vector::TransferReadOp xferOp) const override { |
| 2710 | // Low padding must be static 0. |
| 2711 | if (!padOp.hasZeroLowPad()) |
| 2712 | return failure(); |
| 2713 | // Pad value must be a constant. |
| 2714 | auto padValue = padOp.getConstantPaddingValue(); |
| 2715 | if (!padValue) |
| 2716 | return failure(); |
| 2717 | // Padding value of existing `xferOp` is unused. |
| 2718 | if (xferOp.hasOutOfBoundsDim() || xferOp.getMask()) |
| 2719 | return failure(); |
| 2720 | |
| 2721 | rewriter.modifyOpInPlace(xferOp, [&]() { |
| 2722 | SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false); |
| 2723 | xferOp->setAttr(xferOp.getInBoundsAttrName(), |
| 2724 | rewriter.getBoolArrayAttr(inBounds)); |
| 2725 | xferOp.getBaseMutable().assign(padOp.getSource()); |
| 2726 | xferOp.getPaddingMutable().assign(padValue); |
| 2727 | }); |
| 2728 | |
| 2729 | return success(); |
| 2730 | } |
| 2731 | }; |
| 2732 | |
| 2733 | /// Rewrite use of tensor::PadOp result in TransferWriteOp. |
| 2734 | /// This pattern rewrites TransferWriteOps that write to a padded tensor |
| 2735 | /// value, where the same amount of padding is immediately removed again after |
| 2736 | /// the write. In such cases, the TransferWriteOp can write to the non-padded |
| 2737 | /// tensor value and apply out-of-bounds masking. E.g.: |
| 2738 | /// ``` |
| 2739 | /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] |
| 2740 | /// : tensor<...> to tensor<?x?xf32> |
| 2741 | /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32> |
| 2742 | /// %2 = vector.transfer_write %vec, %1[...] |
| 2743 | /// : vector<17x5xf32>, tensor<17x5xf32> |
| 2744 | /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1] |
| 2745 | /// : tensor<17x5xf32> to tensor<?x?xf32> |
| 2746 | /// ``` |
| 2747 | /// is rewritten to: |
| 2748 | /// ``` |
| 2749 | /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] |
| 2750 | /// : tensor<...> to tensor<?x?xf32> |
| 2751 | /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, |
| 2752 | /// tensor<?x?xf32> |
| 2753 | /// ``` |
| 2754 | /// Note: It is important that the ExtractSliceOp %r resizes the result of the |
| 2755 | /// TransferWriteOp to the same size as the input of the TensorPadOp (or an |
| 2756 | /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ |
| 2757 | /// from %r's old dimensions. |
| 2758 | /// |
| 2759 | /// This rewrite is possible if: |
| 2760 | /// - Low padding is static 0. |
| 2761 | /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This |
| 2762 | /// ExtractSliceOp trims the same amount of padding that was added |
| 2763 | /// beforehand. |
| 2764 | /// - Single, scalar padding value. |
| 2765 | struct PadOpVectorizationWithTransferWritePattern |
| 2766 | : public VectorizePadOpUserPattern<vector::TransferWriteOp> { |
| 2767 | using VectorizePadOpUserPattern< |
| 2768 | vector::TransferWriteOp>::VectorizePadOpUserPattern; |
| 2769 | |
| 2770 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
| 2771 | vector::TransferWriteOp xferOp) const override { |
| 2772 | // TODO: support 0-d corner case. |
| 2773 | if (xferOp.getTransferRank() == 0) |
| 2774 | return failure(); |
| 2775 | |
| 2776 | // Low padding must be static 0. |
| 2777 | if (!padOp.hasZeroLowPad()) |
| 2778 | return failure(); |
| 2779 | // Pad value must be a constant. |
| 2780 | auto padValue = padOp.getConstantPaddingValue(); |
| 2781 | if (!padValue) |
| 2782 | return failure(); |
| 2783 | // TransferWriteOp result must be directly consumed by an ExtractSliceOp. |
| 2784 | if (!xferOp->hasOneUse()) |
| 2785 | return failure(); |
| 2786 | auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin()); |
| 2787 | if (!trimPadding) |
| 2788 | return failure(); |
| 2789 | // Only static zero offsets supported when trimming padding. |
| 2790 | if (!trimPadding.hasZeroOffset()) |
| 2791 | return failure(); |
| 2792 | // trimPadding must remove the amount of padding that was added earlier. |
| 2793 | if (!hasSameTensorSize(padOp.getSource(), trimPadding))
| 2794 | return failure(); |
| 2795 | |
| 2796 | // Insert the new TransferWriteOp at position of the old TransferWriteOp. |
| 2797 | rewriter.setInsertionPoint(xferOp); |
| 2798 | |
| 2799 | SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false); |
| 2800 | auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
| 2801 | xferOp, padOp.getSource().getType(), xferOp.getVector(), |
| 2802 | padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(), |
| 2803 | xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds)); |
| 2804 | rewriter.replaceOp(trimPadding, newXferOp->getResult(0)); |
| 2805 | |
| 2806 | return success(); |
| 2807 | } |
| 2808 | |
| 2809 | /// Check if `beforePadding` and `afterTrimming` have the same tensor size, |
| 2810 | /// i.e., same dimensions. |
| 2811 | /// |
| 2812 | /// Dimensions may be static, dynamic or mix of both. In case of dynamic |
| 2813 | /// dimensions, this function tries to infer the (static) tensor size by |
| 2814 | /// looking at the defining op and utilizing op-specific knowledge. |
| 2815 | /// |
| 2816 | /// This is a conservative analysis. In case equal tensor sizes cannot be |
| 2817 | /// proven statically, this analysis returns `false` even though the tensor |
| 2818 | /// sizes may turn out to be equal at runtime. |
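|      | ///
|      | /// For example (illustrative): if `beforePadding` is
|      | ///   %0 = tensor.extract_slice %t[0, 0] [%s, 7] [1, 1] : ... to tensor<?x7xf32>
|      | /// and `afterTrimming` is
|      | ///   %r = tensor.extract_slice %2[0, 0] [%s, 7] [1, 1] : ... to tensor<?x7xf32>
|      | /// then the static dims match and both dynamic sizes are the same SSA value
|      | /// %s, so equality of the tensor sizes can be proven.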
| 2819 | bool hasSameTensorSize(Value beforePadding,
| 2820 | tensor::ExtractSliceOp afterTrimming) const { |
| 2821 | // If the input to tensor::PadOp is a CastOp, try with both CastOp |
| 2822 | // result and CastOp operand. |
| 2823 | if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>()) |
| 2824 | if (hasSameTensorSize(castOp.getSource(), afterTrimming))
| 2825 | return true; |
| 2826 | |
| 2827 | auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType()); |
| 2828 | auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType()); |
| 2829 | // Only RankedTensorType supported. |
| 2830 | if (!t1 || !t2) |
| 2831 | return false; |
| 2832 | // Rank of both values must be the same. |
| 2833 | if (t1.getRank() != t2.getRank()) |
| 2834 | return false; |
| 2835 | |
| 2836 | // All static dimensions must be the same. Mixed cases (e.g., dimension |
| 2837 | // static in `t1` but dynamic in `t2`) are not supported. |
| 2838 | for (unsigned i = 0; i < t1.getRank(); ++i) { |
| 2839 | if (t1.isDynamicDim(i) != t2.isDynamicDim(i)) |
| 2840 | return false; |
| 2841 | if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i)) |
| 2842 | return false; |
| 2843 | } |
| 2844 | |
| 2845 | // Nothing more to check if all dimensions are static. |
| 2846 | if (t1.getNumDynamicDims() == 0) |
| 2847 | return true; |
| 2848 | |
| 2849 | // All dynamic sizes must be the same. The only supported case at the |
| 2850 | // moment is when `beforePadding` is an ExtractSliceOp (or a cast |
| 2851 | // thereof). |
| 2852 | |
| 2853 | // Apart from CastOp, only ExtractSliceOp is supported. |
| 2854 | auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>(); |
| 2855 | if (!beforeSlice) |
| 2856 | return false; |
| 2857 | |
| 2858 | assert(static_cast<size_t>(t1.getRank()) == |
| 2859 | beforeSlice.getMixedSizes().size()); |
| 2860 | assert(static_cast<size_t>(t2.getRank()) == |
| 2861 | afterTrimming.getMixedSizes().size()); |
| 2862 | |
| 2863 | for (unsigned i = 0; i < t1.getRank(); ++i) { |
| 2864 | // Skip static dimensions. |
| 2865 | if (!t1.isDynamicDim(i)) |
| 2866 | continue; |
| 2867 | auto size1 = beforeSlice.getMixedSizes()[i]; |
| 2868 | auto size2 = afterTrimming.getMixedSizes()[i]; |
| 2869 | |
| 2870 | // Case 1: Same value or same constant int. |
| 2871 | if (isEqualConstantIntOrValue(size1, size2)) |
| 2872 | continue; |
| 2873 | |
| 2874 | // Other cases: Take a deeper look at defining ops of values. |
| 2875 | auto v1 = llvm::dyn_cast_if_present<Value>(size1); |
| 2876 | auto v2 = llvm::dyn_cast_if_present<Value>(size2); |
| 2877 | if (!v1 || !v2) |
| 2878 | return false; |
| 2879 | |
| 2880 | // Case 2: Both values are identical AffineMinOps. (Should not happen if |
| 2881 | // CSE is run.) |
| 2882 | auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>(); |
| 2883 | auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>(); |
| 2884 | if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() && |
| 2885 | minOp1.getOperands() == minOp2.getOperands()) |
| 2886 | continue; |
| 2887 | |
| 2888 | // Add additional cases as needed. |
| 2889 | } |
| 2890 | |
| 2891 | // All tests passed. |
| 2892 | return true; |
| 2893 | } |
| 2894 | }; |
| 2895 | |
| 2896 | /// Returns the effective Pad value for the input op, provided it's a scalar. |
| 2897 | /// |
| 2898 | /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If |
| 2899 | /// this Op performs padding, retrieve the padding value provided that it's |
| 2900 | /// a scalar and static/fixed for all the padded values. Returns an empty value |
| 2901 | /// otherwise. |
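|      | ///
|      | /// For example (illustrative IR, value names assumed):
|      | ///   %cst  = arith.constant 0.0 : f32
|      | ///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8x8xf32>) -> tensor<8x8xf32>
|      | ///   %ins  = tensor.insert_slice %src into %fill ...
|      | /// querying the defining op of %ins returns %cst (via cases 5. and 2. below).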
| 2902 | /// |
| 2903 | /// TODO: This is used twice (when checking vectorization pre-conditions and |
| 2904 | /// when vectorizing). Cache results instead of re-running. |
| 2905 | static Value getStaticPadVal(Operation *op) { |
| 2906 | if (!op) |
| 2907 | return {}; |
| 2908 | |
| 2909 | // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's |
| 2910 | // being broadcast, provided that it's a scalar. |
| 2911 | if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) { |
| 2912 | auto source = bcast.getSource(); |
| 2913 | if (llvm::dyn_cast<VectorType>(source.getType())) |
| 2914 | return {}; |
| 2915 | |
| 2916 | return source; |
| 2917 | } |
| 2918 | |
| 2919 | // 2. linalg.fill - use the scalar input value that used to fill the output |
| 2920 | // tensor. |
| 2921 | if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) { |
| 2922 | return fill.getInputs()[0]; |
| 2923 | } |
| 2924 | |
| 2925 | // 3. tensor.generate - can't guarantee that the value is fixed without
| 2926 | // analysing the body region, so bail out.
| 2927 | if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) { |
| 2928 | return {}; |
| 2929 | } |
| 2930 | |
| 2931 | // 4. vector.transfer_write - inspect the vector that's being written. If it
| 2932 | // contains a single value that has been broadcast (e.g. via
| 2933 | // vector.broadcast), extract it, fail otherwise.
| 2934 | if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op)) |
| 2935 | return getStaticPadVal(xferWrite.getVector().getDefiningOp()); |
| 2936 | |
| 2937 | // 5. tensor.insert_slice - inspect the destination tensor. If it's larger |
| 2938 | // than the input tensor, then, provided it's constant, we'll extract the |
| 2939 | // value that was used to generate it (via e.g. linalg.fill), fail otherwise. |
| 2940 | // TODO: Clarify the semantics when the input tensor is larger than the |
| 2941 | // destination. |
| 2942 | if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op)) |
| 2943 | return getStaticPadVal(slice.getDest().getDefiningOp()); |
| 2944 | |
| 2945 | return {}; |
| 2946 | } |
| 2947 | |
| 2948 | static LogicalResult |
| 2949 | vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, |
| 2950 | ArrayRef<int64_t> inputVectorSizes, |
| 2951 | SmallVectorImpl<Value> &newResults) { |
| 2952 | // TODO: Introduce a parent class that will handle the insertion point update. |
| 2953 | OpBuilder::InsertionGuard g(rewriter); |
| 2954 | rewriter.setInsertionPoint(sliceOp); |
| 2955 | |
| 2956 | TypedValue<RankedTensorType> source = sliceOp.getSource(); |
| 2957 | auto sourceType = source.getType(); |
| 2958 | auto resultType = sliceOp.getResultType(); |
| 2959 | |
| 2960 | Value padValue = getStaticPadVal(sliceOp); |
| 2961 | |
| 2962 | if (!padValue) { |
| 2963 | auto elemType = sourceType.getElementType(); |
| 2964 | padValue = rewriter.create<arith::ConstantOp>( |
| 2965 | sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType)); |
| 2966 | } |
| 2967 | |
| 2968 | // 2. Get the vector shape |
| 2969 | SmallVector<int64_t> vecShape; |
| 2970 | size_t rankDiff = resultType.getRank() - sourceType.getRank(); |
| 2971 | for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) { |
| 2972 | if (!inputVectorSizes.empty()) { |
| 2973 | vecShape.push_back(inputVectorSizes[i]);
| 2974 | } else if (!sourceType.isDynamicDim(i)) { |
| 2975 | vecShape.push_back(sourceType.getDimSize(i));
| 2976 | } else if (!resultType.isDynamicDim(i)) { |
| 2977 | // Source shape is not statically known, but result shape is. |
| 2978 | // Vectorize with size of result shape. This may be larger than the |
| 2979 | // source size. |
| 2980 | // FIXME: Using rankDiff implies that the source tensor is inserted at |
| 2981 | // the end of the destination tensor. However, that's not required. |
| 2982 | vecShape.push_back(resultType.getDimSize(rankDiff + i));
| 2983 | } else { |
| 2984 | // Neither the source nor the result dim is static. Cannot vectorize
| 2985 | // the copy.
| 2986 | return failure(); |
| 2987 | } |
| 2988 | } |
| 2989 | auto vecType = VectorType::get(vecShape, sourceType.getElementType()); |
| 2990 | |
| 2991 | // 3. Generate TransferReadOp + TransferWriteOp |
| 2992 | auto loc = sliceOp.getLoc(); |
| 2993 | |
| 2994 | // Create read |
| 2995 | SmallVector<Value> readIndices( |
| 2996 | vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
| 2997 | Value read = mlir::vector::createReadOrMaskedRead( |
| 2998 | rewriter, loc, source, vecType.getShape(), padValue,
| 2999 | /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); |
| 3000 | |
| 3001 | // Create write |
| 3002 | auto writeIndices = |
| 3003 | getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets()); |
| 3004 | Operation *write = |
| 3005 | createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(), |
| 3006 | writeIndices, inputVectorSizes.empty()); |
| 3007 | |
| 3008 | // 4. Finalize |
| 3009 | newResults.push_back(write->getResult(0));
| 3010 | |
| 3011 | return success(); |
| 3012 | } |
| 3013 | |
| 3014 | /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.: |
| 3015 | /// ``` |
| 3016 | /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32> |
| 3017 | /// %r = tensor.insert_slice %0 |
| 3018 | /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1] |
| 3019 | /// : tensor<17x5xf32> into tensor<?x?x17x5xf32> |
| 3020 | /// ``` |
| 3021 | /// is rewritten to: |
| 3022 | /// ``` |
| 3023 | /// %0 = vector.transfer_read %src[%c0, %c0], %padding |
| 3024 | /// : tensor<?x?xf32>, vector<17x5xf32> |
| 3025 | /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0] |
| 3026 | /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32> |
| 3027 | /// ``` |
| 3028 | /// |
| 3029 | /// This rewrite is possible if: |
| 3030 | /// - Low padding is static 0. |
| 3031 | /// - `padOp` result shape is static. |
| 3032 | /// - The entire padded tensor is inserted. |
| 3033 | /// (Implies that sizes of `insertOp` are all static.) |
| 3034 | /// - Only unit strides in `insertOp`. |
| 3035 | /// - Single, scalar padding value. |
| 3036 | /// - `padOp` result not used as destination. |
| 3037 | struct PadOpVectorizationWithInsertSlicePattern |
| 3038 | : public VectorizePadOpUserPattern<tensor::InsertSliceOp> { |
| 3039 | using VectorizePadOpUserPattern< |
| 3040 | tensor::InsertSliceOp>::VectorizePadOpUserPattern; |
| 3041 | |
| 3042 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
| 3043 | tensor::InsertSliceOp insertOp) const override { |
| 3044 | // Low padding must be static 0. |
| 3045 | if (!padOp.hasZeroLowPad()) |
| 3046 | return failure(); |
| 3047 | // Only unit stride supported. |
| 3048 | if (!insertOp.hasUnitStride()) |
| 3049 | return failure(); |
| 3050 | // Pad value must be a constant. |
| 3051 | auto padValue = padOp.getConstantPaddingValue(); |
| 3052 | if (!padValue) |
| 3053 | return failure(); |
| 3054 | // Dynamic shapes not supported. |
| 3055 | if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape()) |
| 3056 | return failure(); |
| 3057 | // Pad result not used as destination. |
| 3058 | if (insertOp.getDest() == padOp.getResult()) |
| 3059 | return failure(); |
| 3060 | |
| 3061 | auto vecType = VectorType::get(padOp.getType().getShape(), |
| 3062 | padOp.getType().getElementType()); |
| 3063 | unsigned vecRank = vecType.getRank(); |
| 3064 | unsigned tensorRank = insertOp.getType().getRank(); |
| 3065 | |
| 3066 | // Check if sizes match: Insert the entire tensor into most minor dims. |
| 3067 | // (No permutations allowed.) |
| 3068 | SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1); |
| 3069 | expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end()); |
| 3070 | if (!llvm::all_of( |
| 3071 | llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) { |
| 3072 | return getConstantIntValue(std::get<0>(it)) == std::get<1>(it); |
| 3073 | })) |
| 3074 | return failure(); |
| 3075 | |
| 3076 | // Insert the TransferReadOp and TransferWriteOp at the position of the |
| 3077 | // InsertSliceOp. |
| 3078 | rewriter.setInsertionPoint(insertOp); |
| 3079 | |
| 3080 | // Generate TransferReadOp: Read entire source tensor and add high |
| 3081 | // padding. |
| 3082 | SmallVector<Value> readIndices( |
| 3083 | vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0)); |
| 3084 | auto read = rewriter.create<vector::TransferReadOp>( |
| 3085 | padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue); |
| 3086 | |
| 3087 | // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at |
| 3088 | // specified offsets. The write is fully in-bounds because an InsertSliceOp's
| 3089 | // source must fit into the destination at the specified offsets. |
| 3090 | auto writeIndices = getValueOrCreateConstantIndexOp( |
| 3091 | rewriter, padOp.getLoc(), insertOp.getMixedOffsets()); |
| 3092 | SmallVector<bool> inBounds(vecRank, true); |
| 3093 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
| 3094 | insertOp, read, insertOp.getDest(), writeIndices, |
| 3095 | ArrayRef<bool>{inBounds}); |
| 3096 | |
| 3097 | return success(); |
| 3098 | } |
| 3099 | }; |
| 3100 | |
| 3101 | void mlir::linalg::populatePadOpVectorizationPatterns( |
| 3102 | RewritePatternSet &patterns, PatternBenefit baseBenefit) { |
| 3103 | patterns.add<PadOpVectorizationWithTransferReadPattern, |
| 3104 | PadOpVectorizationWithTransferWritePattern, |
| 3105 | PadOpVectorizationWithInsertSlicePattern>( |
| 3106 | patterns.getContext(), baseBenefit.getBenefit() + 1);
| 3107 | } |
| 3108 | |
| 3109 | //----------------------------------------------------------------------------// |
| 3110 | // Forwarding patterns |
| 3111 | //----------------------------------------------------------------------------// |
| 3112 | |
| 3113 | /// Check whether there is any interleaved use of any `values` between |
| 3114 | /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value |
| 3115 | /// is in a different block. |
| 3116 | static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, |
| 3117 | ValueRange values) { |
| 3118 | if (firstOp->getBlock() != secondOp->getBlock() || |
| 3119 | !firstOp->isBeforeInBlock(secondOp)) {
| 3120 | LDBG("interleavedUses precondition failed, firstOp: " |
| 3121 | << *firstOp << ", second op: " << *secondOp << "\n" ); |
| 3122 | return true; |
| 3123 | } |
| 3124 | for (auto v : values) { |
| 3125 | for (auto &u : v.getUses()) { |
| 3126 | Operation *owner = u.getOwner(); |
| 3127 | if (owner == firstOp || owner == secondOp) |
| 3128 | continue; |
| 3129 | // TODO: this is too conservative, use dominance info in the future. |
| 3130 | if (owner->getBlock() == firstOp->getBlock() && |
| 3131 | (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
| 3132 | continue; |
| 3133 | LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp |
| 3134 | << ", second op: " << *secondOp << "\n" ); |
| 3135 | return true; |
| 3136 | } |
| 3137 | } |
| 3138 | return false; |
| 3139 | } |
| 3140 | |
| 3141 | /// Return the unique subview use of `v` if it is indeed unique, null |
| 3142 | /// otherwise. |
| 3143 | static memref::SubViewOp getSubViewUseIfUnique(Value v) { |
| 3144 | memref::SubViewOp subViewOp; |
| 3145 | for (auto &u : v.getUses()) { |
| 3146 | if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) { |
| 3147 | if (subViewOp) |
| 3148 | return memref::SubViewOp(); |
| 3149 | subViewOp = newSubViewOp; |
| 3150 | } |
| 3151 | } |
| 3152 | return subViewOp; |
| 3153 | } |
| 3154 | |
| 3155 | /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, |
| 3156 | /// when available. |
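|      | ///
|      | /// Sketch of the forwarding performed here (illustrative, types and details
|      | /// elided):
|      | ///   linalg.fill ins(%cst : f32) outs(%alloc : memref<...>)
|      | ///   memref.copy %in, %subview_of_alloc
|      | ///   %v = vector.transfer_read %alloc[...], %cst
|      | /// becomes, with a conservatively reset in_bounds attribute:
|      | ///   %v = vector.transfer_read %in[...], %cst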
| 3157 | LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( |
| 3158 | vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { |
| 3159 | |
| 3160 | // TODO: support mask. |
| 3161 | if (xferOp.getMask()) |
| 3162 | return rewriter.notifyMatchFailure(xferOp, "unsupported mask" ); |
| 3163 | |
| 3164 | // Transfer into `view`. |
| 3165 | Value viewOrAlloc = xferOp.getBase(); |
| 3166 | if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() && |
| 3167 | !viewOrAlloc.getDefiningOp<memref::AllocOp>()) |
| 3168 | return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc" ); |
| 3169 | |
| 3170 | // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. |
| 3171 | memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); |
| 3172 | if (!subViewOp) |
| 3173 | return rewriter.notifyMatchFailure(xferOp, "no subview found" ); |
| 3174 | Value subView = subViewOp.getResult(); |
| 3175 | |
| 3176 | // Find the copy into `subView` without interleaved uses. |
| 3177 | memref::CopyOp copyOp; |
| 3178 | for (auto &u : subView.getUses()) { |
| 3179 | if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) { |
| 3180 | assert(isa<MemRefType>(newCopyOp.getTarget().getType())); |
| 3181 | if (newCopyOp.getTarget() != subView) |
| 3182 | continue; |
| 3183 | if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) |
| 3184 | continue; |
| 3185 | copyOp = newCopyOp; |
| 3186 | break; |
| 3187 | } |
| 3188 | } |
| 3189 | if (!copyOp) |
| 3190 | return rewriter.notifyMatchFailure(xferOp, "no copy found" ); |
| 3191 | |
| 3192 | // Find the fill into `viewOrAlloc` without interleaved uses before the |
| 3193 | // copy. |
| 3194 | FillOp maybeFillOp; |
| 3195 | for (auto &u : viewOrAlloc.getUses()) { |
| 3196 | if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) { |
| 3197 | assert(isa<MemRefType>(newFillOp.output().getType())); |
| 3198 | if (newFillOp.output() != viewOrAlloc) |
| 3199 | continue; |
| 3200 | if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) |
| 3201 | continue; |
| 3202 | maybeFillOp = newFillOp; |
| 3203 | break; |
| 3204 | } |
| 3205 | } |
| 3206 | // Ensure padding matches. |
| 3207 | if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value()) |
| 3208 | return rewriter.notifyMatchFailure(xferOp, |
| 3209 | "padding value does not match fill" ); |
| 3210 | |
| 3211 | // `in` is the subview that memref.copy reads. Replace it. |
| 3212 | Value in = copyOp.getSource(); |
| 3213 | |
| 3214 | // memref.copy + linalg.fill can be used to create a padded local buffer.
| 3215 | // The in_bounds attribute is only valid w.r.t. this padded buffer, so when
| 3216 | // forwarding to vector.transfer_read it must be reset conservatively (all
| 3217 | // dims potentially out-of-bounds).
| 3218 | auto vectorType = xferOp.getVectorType(); |
| 3219 | Value res = rewriter.create<vector::TransferReadOp>( |
| 3220 | xferOp.getLoc(), vectorType, in, xferOp.getIndices(), |
| 3221 | xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(), |
| 3222 | rewriter.getBoolArrayAttr( |
| 3223 | SmallVector<bool>(vectorType.getRank(), false))); |
| 3224 | |
| 3225 | if (maybeFillOp) |
| 3226 | rewriter.eraseOp(maybeFillOp);
| 3227 | rewriter.eraseOp(copyOp);
| 3228 | rewriter.replaceOp(xferOp, res); |
| 3229 | |
| 3230 | return success(); |
| 3231 | } |
| 3232 | |
| 3233 | /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, |
| 3234 | /// when available. |
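|      | ///
|      | /// Sketch of the forwarding performed here (illustrative, types and details
|      | /// elided):
|      | ///   vector.transfer_write %v, %alloc[...]
|      | ///   memref.copy %subview_of_alloc, %out
|      | /// becomes, with a conservatively reset in_bounds attribute:
|      | ///   vector.transfer_write %v, %out[...]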
| 3235 | LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( |
| 3236 | vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { |
| 3237 | // TODO: support mask. |
| 3238 | if (xferOp.getMask()) |
| 3239 | return rewriter.notifyMatchFailure(xferOp, "unsupported mask" ); |
| 3240 | |
| 3241 | // Transfer into `viewOrAlloc`. |
| 3242 | Value viewOrAlloc = xferOp.getBase(); |
| 3243 | if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() && |
| 3244 | !viewOrAlloc.getDefiningOp<memref::AllocOp>()) |
| 3245 | return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc" ); |
| 3246 | |
| 3247 | // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. |
| 3248 | memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); |
| 3249 | if (!subViewOp) |
| 3250 | return rewriter.notifyMatchFailure(xferOp, "no subview found" ); |
| 3251 | Value subView = subViewOp.getResult(); |
| 3252 | |
| 3253 | // Find the copy from `subView` without interleaved uses. |
| 3254 | memref::CopyOp copyOp; |
| 3255 | for (auto &u : subViewOp.getResult().getUses()) { |
| 3256 | if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) { |
| 3257 | if (newCopyOp.getSource() != subView) |
| 3258 | continue; |
| 3259 | if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView})) |
| 3260 | continue; |
| 3261 | copyOp = newCopyOp; |
| 3262 | break; |
| 3263 | } |
| 3264 | } |
| 3265 | if (!copyOp) |
| 3266 | return rewriter.notifyMatchFailure(xferOp, "no copy found" ); |
| 3267 | |
| 3268 | // `out` is the subview copied into that we replace. |
| 3269 | assert(isa<MemRefType>(copyOp.getTarget().getType())); |
| 3270 | Value out = copyOp.getTarget(); |
| 3271 | |
| 3272 | // Forward vector.transfer into copy. |
| 3273 | // memref.copy + linalg.fill can be used to create a padded local buffer.
| 3274 | // The in_bounds attribute is only valid w.r.t. this padded buffer, so when
| 3275 | // forwarding to vector.transfer_write it must be reset conservatively (all
| 3276 | // dims potentially out-of-bounds).
| 3277 | auto vector = xferOp.getVector(); |
| 3278 | rewriter.create<vector::TransferWriteOp>( |
| 3279 | xferOp.getLoc(), vector, out, xferOp.getIndices(), |
| 3280 | xferOp.getPermutationMapAttr(), xferOp.getMask(), |
| 3281 | rewriter.getBoolArrayAttr(SmallVector<bool>( |
| 3282 | dyn_cast<VectorType>(vector.getType()).getRank(), false))); |
| 3283 | |
| 3284 | rewriter.eraseOp(copyOp);
| 3285 | rewriter.eraseOp(xferOp);
| 3286 | |
| 3287 | return success(); |
| 3288 | } |
| 3289 | |
| 3290 | //===----------------------------------------------------------------------===// |
| 3291 | // Convolution vectorization patterns |
| 3292 | //===----------------------------------------------------------------------===// |
| 3293 | |
| 3294 | template <int N> |
| 3295 | static void bindShapeDims(ShapedType shapedType) {} |
| 3296 | |
| 3297 | template <int N, typename IntTy, typename... IntTy2> |
| 3298 | static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) { |
| 3299 | val = shapedType.getShape()[N]; |
| 3300 | bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...); |
| 3301 | } |
| 3302 | |
| 3303 | /// Bind a pack of int& to the leading dimensions of shapedType.getShape(). |
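|      | /// For example (usage sketch, assuming a 3-D shape):
|      | ///   int64_t n, w, c;
|      | ///   bindShapeDims(shapedType, n, w, c); // n = dim 0, w = dim 1, c = dim 2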
| 3304 | template <typename... IntTy> |
| 3305 | static void bindShapeDims(ShapedType shapedType, IntTy &...vals) { |
| 3306 | bindShapeDims<0>(shapedType, vals...); |
| 3307 | } |
| 3308 | |
| 3309 | namespace { |
| 3310 | /// Generate a vector implementation for either: |
| 3311 | /// ``` |
| 3312 | /// Op def: ( w, kw ) |
| 3313 | /// Iters: ({Par(), Red()}) |
| 3314 | /// Layout: {{w + kw}, {kw}, {w}} |
| 3315 | /// ``` |
| 3316 | /// kw is unrolled. |
| 3317 | /// |
| 3318 | /// or |
| 3319 | /// |
| 3320 | /// ``` |
| 3321 | /// Op def: ( n, w, c, kw, f ) |
| 3322 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
| 3323 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
| 3324 | /// ``` |
| 3325 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
| 3326 | /// |
| 3327 | /// or |
| 3328 | /// |
| 3329 | /// ``` |
| 3330 | /// Op def: ( n, c, w, f, kw ) |
| 3331 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
| 3332 | /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}} |
| 3333 | /// ``` |
| 3334 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
| 3335 | /// |
| 3336 | /// or |
| 3337 | /// |
| 3338 | /// ``` |
| 3339 | /// Op def: ( n, w, c, kw ) |
| 3340 | /// Iters: ({Par(), Par(), Par(), Red()}) |
| 3341 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
| 3342 | /// ``` |
| 3343 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
| 3344 | struct Conv1DGenerator |
| 3345 | : public StructuredGenerator<LinalgOp, utils::IteratorType> { |
| 3346 | Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) |
| 3347 | : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) { |
| 3348 | |
| 3349 | lhsShaped = linalgOp.getDpsInputOperand(0)->get(); |
| 3350 | rhsShaped = linalgOp.getDpsInputOperand(1)->get(); |
| 3351 | resShaped = linalgOp.getDpsInitOperand(0)->get(); |
| 3352 | lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType()); |
| 3353 | rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType()); |
| 3354 | resShapedType = dyn_cast<ShapedType>(resShaped.getType()); |
| 3355 | |
| 3356 | Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0)); |
| 3357 | redOp = reduceOp->getName().getIdentifier(); |
| 3358 | |
| 3359 | setConvOperationKind(reduceOp); |
| 3360 | |
| 3361 | auto maybeKind = getCombinerOpKind(reduceOp); |
| 3362 | reductionKind = maybeKind.value(); |
| 3363 | |
| 3364 | // The ConvolutionOpInterface guarantees that strides/dilations exist.
| 3365 | // However, we do not need to rely on that: we simply use them if present,
| 3366 | // and otherwise fall back to the defaults and let the generic conv matcher
| 3367 | // in the ConvGenerator succeed or fail.
| 3368 | auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides" ); |
| 3369 | auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations" ); |
| 3370 | strideW = strides ? *strides.getValues<uint64_t>().begin() : 1; |
| 3371 | dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1; |
| 3372 | } |
| 3373 | |
| 3374 | /// Generate a vector implementation for: |
| 3375 | /// ``` |
| 3376 | /// Op def: ( w, kw ) |
| 3377 | /// Iters: ({Par(), Red()}) |
| 3378 | /// Layout: {{w + kw}, {kw}, {w}} |
| 3379 | /// ``` |
| 3380 | /// kw is always unrolled. |
| 3381 | /// |
| 3382 | /// or |
| 3383 | /// |
| 3384 | /// ``` |
| 3385 | /// Op def: ( n, w, c, kw, f ) |
| 3386 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
| 3387 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
| 3388 | /// ``` |
| 3389 | /// kw is always unrolled. |
| 3390 | /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
| 3391 | /// > 1.
| 3392 | FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) { |
| 3393 | int64_t nSize, wSize, cSize, kwSize, fSize; |
| 3394 | SmallVector<int64_t, 3> lhsShape, rhsShape, resShape; |
| 3395 | bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W); |
| 3396 | switch (conv1DOpOrder) { |
| 3397 | case Conv1DOpOrder::W: |
| 3398 | // Initialize unused dimensions |
| 3399 | nSize = fSize = cSize = 0; |
| 3400 | // out{W} |
| 3401 | bindShapeDims(resShapedType, wSize); |
| 3402 | // kernel{kw} |
| 3403 | bindShapeDims(rhsShapedType, kwSize); |
| 3404 | lhsShape = {// iw = ow + kw - 1 |
| 3405 | // (i.e. 16 convolved with 3 -> 14) |
| 3406 | (wSize + kwSize - 1)}; |
| 3407 | rhsShape = {kwSize}; |
| 3408 | resShape = {wSize}; |
| 3409 | break; |
| 3410 | case Conv1DOpOrder::Nwc: |
| 3411 | // out{n, w, f} |
| 3412 | bindShapeDims(resShapedType, nSize, wSize, fSize); |
| 3413 | switch (oper) { |
| 3414 | case ConvOperationKind::Conv: |
| 3415 | // kernel{kw, c, f} |
| 3416 | bindShapeDims(rhsShapedType, kwSize, cSize); |
| 3417 | break; |
| 3418 | case ConvOperationKind::Pool: |
| 3419 | // kernel{kw} |
| 3420 | bindShapeDims(rhsShapedType, kwSize); |
| 3421 | cSize = fSize; |
| 3422 | break; |
| 3423 | } |
| 3424 | lhsShape = {nSize, |
| 3425 | // iw = ow * sw + kw * dw - 1 |
| 3426 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
| 3427 | // Perform the proper inclusive -> exclusive -> inclusive. |
| 3428 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - |
| 3429 | 1, |
| 3430 | cSize}; |
| 3431 | switch (oper) { |
| 3432 | case ConvOperationKind::Conv: |
| 3433 | rhsShape = {kwSize, cSize, fSize}; |
| 3434 | break; |
| 3435 | case ConvOperationKind::Pool: |
| 3436 | rhsShape = {kwSize}; |
| 3437 | break; |
| 3438 | } |
| 3439 | resShape = {nSize, wSize, fSize}; |
| 3440 | break; |
| 3441 | case Conv1DOpOrder::Ncw: |
| 3442 | // out{n, f, w} |
| 3443 | bindShapeDims(resShapedType, nSize, fSize, wSize); |
| 3444 | switch (oper) { |
| 3445 | case ConvOperationKind::Conv: |
| 3446 | // kernel{f, c, kw} |
| 3447 | bindShapeDims(rhsShapedType, fSize, cSize, kwSize); |
| 3448 | break; |
| 3449 | case ConvOperationKind::Pool: |
| 3450 | // kernel{kw} |
| 3451 | bindShapeDims(rhsShapedType, kwSize); |
| 3452 | cSize = fSize; |
| 3453 | break; |
| 3454 | } |
| 3455 | lhsShape = {nSize, cSize, |
| 3456 | // iw = ow * sw + kw * dw - 1 |
| 3457 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
| 3458 | // Perform the proper inclusive -> exclusive -> inclusive. |
| 3459 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - |
| 3460 | 1}; |
| 3461 | switch (oper) { |
| 3462 | case ConvOperationKind::Conv: |
| 3463 | rhsShape = {fSize, cSize, kwSize}; |
| 3464 | break; |
| 3465 | case ConvOperationKind::Pool: |
| 3466 | rhsShape = {kwSize}; |
| 3467 | break; |
| 3468 | } |
| 3469 | resShape = {nSize, fSize, wSize}; |
| 3470 | break; |
| 3471 | } |
| 3472 | |
| 3473 | vector::TransferWriteOp write; |
| 3474 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 3475 | |
| 3476 | // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. |
| 3477 | // When strideW == 1, we can batch the contiguous loads and avoid |
| 3478 | // unrolling |
| 3479 | int64_t wSizeStep = strideW == 1 ? wSize : 1; |
| 3480 | |
| 3481 | Type lhsEltType = lhsShapedType.getElementType(); |
| 3482 | Type rhsEltType = rhsShapedType.getElementType(); |
| 3483 | Type resEltType = resShapedType.getElementType(); |
| 3484 | auto lhsType = VectorType::get(lhsShape, lhsEltType); |
| 3485 | auto rhsType = VectorType::get(rhsShape, rhsEltType); |
| 3486 | auto resType = VectorType::get(resShape, resEltType); |
| 3487 | // Zero padding with the corresponding dimensions for lhs, rhs and res. |
| 3488 | SmallVector<Value> lhsPadding(lhsShape.size(), zero); |
| 3489 | SmallVector<Value> rhsPadding(rhsShape.size(), zero); |
| 3490 | SmallVector<Value> resPadding(resShape.size(), zero); |
| 3491 | |
| 3492 | // Read the whole lhs, rhs and res in one shot (with zero padding). |
| 3493 | Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped, |
| 3494 | lhsPadding); |
| 3495 | // This is needed only for Conv. |
| 3496 | Value rhs = nullptr; |
| 3497 | if (oper == ConvOperationKind::Conv) |
| 3498 | rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped, |
| 3499 | rhsPadding); |
| 3500 | Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped, |
| 3501 | resPadding); |
| 3502 | |
| 3503 | // The base vectorization case for channeled convolution is input:
| 3504 | // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
| 3505 | // vectorization case, we pre-transpose the input, weight, and output.
| 3506 | switch (conv1DOpOrder) { |
| 3507 | case Conv1DOpOrder::W: |
| 3508 | case Conv1DOpOrder::Nwc: |
| 3509 | // Base case, so no transposes necessary. |
| 3510 | break; |
| 3511 | case Conv1DOpOrder::Ncw: { |
| 3512 | // To match base vectorization case, we pre-transpose current case. |
| 3513 | // ncw -> nwc |
| 3514 | static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1}; |
| 3515 | lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs); |
| 3516 | // fcw -> wcf |
| 3517 | static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0}; |
| 3518 | |
| 3519 | // This is needed only for Conv. |
| 3520 | if (oper == ConvOperationKind::Conv) |
| 3521 | rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs); |
| 3522 | // nfw -> nwf |
| 3523 | static constexpr std::array<int64_t, 3> permRes = {0, 2, 1}; |
| 3524 | res = rewriter.create<vector::TransposeOp>(loc, res, permRes); |
| 3525 | break; |
| 3526 | } |
| 3527 | } |
| 3528 | |
| 3529 | //===------------------------------------------------------------------===// |
| 3530 | // Begin vector-only rewrite part |
| 3531 | //===------------------------------------------------------------------===// |
| 3532 | // Unroll along kw and read slices of lhs and rhs. |
| 3533 | SmallVector<Value> lhsVals, rhsVals, resVals; |
| 3534 | lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize, |
| 3535 | kwSize, strideW, dilationW, wSizeStep, |
| 3536 | isSingleChanneled); |
| 3537 | // Do not do for pooling. |
| 3538 | if (oper == ConvOperationKind::Conv) |
| 3539 | rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize); |
| 3540 | resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize, |
| 3541 | wSizeStep, isSingleChanneled); |
| 3542 | |
| 3543 | auto linearIndex = [&](int64_t kw, int64_t w) { |
| 3544 | return kw * (wSize / wSizeStep) + w; |
| 3545 | }; |
| 3546 | |
// Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
// an outer product for non-channeled convolution, or a simple arith
// operation for pooling.
| 3550 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 3551 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3552 | switch (oper) { |
| 3553 | case ConvOperationKind::Conv: |
| 3554 | if (isSingleChanneled) { |
| 3555 | resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc, |
| 3556 | lhsVals[linearIndex(kw, w)], |
| 3557 | rhsVals[kw], resVals[w]); |
| 3558 | } else { |
| 3559 | resVals[w] = conv1dSliceAsContraction(rewriter, loc, |
| 3560 | lhsVals[linearIndex(kw, w)], |
| 3561 | rhsVals[kw], resVals[w]); |
| 3562 | } |
| 3563 | break; |
| 3564 | case ConvOperationKind::Pool: |
| 3565 | resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)], |
| 3566 | resVals[w]); |
| 3567 | break; |
| 3568 | } |
| 3569 | } |
| 3570 | } |
| 3571 | |
| 3572 | res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals, |
| 3573 | isSingleChanneled); |
| 3574 | //===------------------------------------------------------------------===// |
| 3575 | // End vector-only rewrite part |
| 3576 | //===------------------------------------------------------------------===// |
| 3577 | |
// The base vectorization case for channeled convolution produces output
// {n,w,f}. To reuse the result from the base pattern vectorization case, we
// post-transpose the base case result.
| 3581 | switch (conv1DOpOrder) { |
| 3582 | case Conv1DOpOrder::W: |
| 3583 | case Conv1DOpOrder::Nwc: |
| 3584 | // Base case, so no transposes necessary. |
| 3585 | break; |
| 3586 | case Conv1DOpOrder::Ncw: { |
| 3587 | // nwf -> nfw |
| 3588 | static constexpr std::array<int64_t, 3> perm = {0, 2, 1}; |
| 3589 | res = rewriter.create<vector::TransposeOp>(loc, res, perm); |
| 3590 | break; |
| 3591 | } |
| 3592 | } |
| 3593 | |
| 3594 | return rewriter |
| 3595 | .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding) |
| 3596 | .getOperation(); |
| 3597 | } |
| 3598 | |
// Take a value and widen it to have the same element type as `ty`.
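// E.g. (illustrative): promoting a vector<4xi8> value to match a
// vector<4xf32> accumulator emits
//   arith.sitofp %val : vector<4xi8> to vector<4xf32>
// while f16 -> f32 uses arith.extf and i8 -> i32 uses arith.extsi.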
| 3600 | Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) { |
const Type srcElementType = getElementTypeOrSelf(val.getType());
const Type dstElementType = getElementTypeOrSelf(ty);
| 3603 | assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType)); |
| 3604 | if (srcElementType == dstElementType) |
| 3605 | return val; |
| 3606 | |
| 3607 | const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth(); |
| 3608 | const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth(); |
| 3609 | const Type dstType = |
| 3610 | cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType); |
| 3611 | |
if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
| 3613 | return rewriter.create<arith::SIToFPOp>(loc, dstType, val); |
| 3614 | } |
| 3615 | |
| 3616 | if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) && |
| 3617 | srcWidth < dstWidth) |
| 3618 | return rewriter.create<arith::ExtFOp>(loc, dstType, val); |
| 3619 | |
| 3620 | if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) && |
| 3621 | srcWidth < dstWidth) |
| 3622 | return rewriter.create<arith::ExtSIOp>(loc, dstType, val); |
| 3623 | |
| 3624 | assert(false && "unhandled promotion case" ); |
| 3625 | return nullptr; |
| 3626 | } |
| 3627 | |
| 3628 | // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} |
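// E.g. (illustrative, assuming f32 operands and an 'add' reduction), this
// builds something like:
//   // (d0, d1, d2, d3) = (n, w, f, c)
//   vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
//                                     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
//                                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
//                    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
//                    kind = #vector.kind<add>}
//     %lhs, %rhs, %res : vector<1x4x3xf32>, vector<3x8xf32> into vector<1x4x8xf32>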
| 3629 | Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, |
| 3630 | Value lhs, Value rhs, Value res) { |
| 3631 | vector::IteratorType par = vector::IteratorType::parallel; |
| 3632 | vector::IteratorType red = vector::IteratorType::reduction; |
| 3633 | AffineExpr n, w, f, c; |
| 3634 | bindDims(ctx, n, w, f, c); |
lhs = promote(rewriter, loc, lhs, res.getType());
rhs = promote(rewriter, loc, rhs, res.getType());
auto contractionOp = rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, res,
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
/*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
contractionOp.setKind(reductionKind);
return contractionOp;
| 3643 | } |
| 3644 | |
| 3645 | // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel |
| 3646 | // convolution. |
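// E.g. (illustrative, made-up shapes): for the single-channel case rhs is a
// scalar filter element, so this is the "AXPY" form, roughly:
//   %r = vector.outerproduct %lhs, %rhs, %res {kind = #vector.kind<add>}
//       : vector<4xf32>, f32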
| 3647 | Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc, |
| 3648 | Value lhs, Value rhs, Value res) { |
| 3649 | return rewriter.create<vector::OuterProductOp>( |
| 3650 | loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD); |
| 3651 | } |
| 3652 | |
| 3653 | // Create a reduction: lhs{n, w, c} -> res{n, w, c} |
| 3654 | Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs, |
| 3655 | Value res) { |
| 3656 | if (isPoolExt) |
| 3657 | lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0); |
| 3658 | return rewriter |
| 3659 | .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType()) |
| 3660 | ->getResult(0); |
| 3661 | } |
| 3662 | |
| 3663 | /// Generate a vector implementation for: |
| 3664 | /// ``` |
| 3665 | /// Op def: ( n, w, c, kw) |
| 3666 | /// Iters: ({Par(), Par(), Par(), Red()}) |
| 3667 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
| 3668 | /// ``` |
| 3669 | /// kw is always unrolled. |
/// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
| 3672 | FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize, |
| 3673 | bool channelDimScalableFlag, |
| 3674 | bool flatten) { |
| 3675 | bool scalableChDim = false; |
| 3676 | bool useMasking = false; |
| 3677 | int64_t nSize, wSize, cSize, kwSize; |
| 3678 | // kernel{kw, c} |
| 3679 | bindShapeDims(rhsShapedType, kwSize, cSize); |
| 3680 | if (ShapedType::isDynamic(cSize)) { |
assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
| 3682 | cSize = channelDimVecSize; |
| 3683 | // Scalable vectors are only used when both conditions are met: |
| 3684 | // 1. channel dim is dynamic |
| 3685 | // 2. channelDimScalableFlag is set |
| 3686 | scalableChDim = channelDimScalableFlag; |
| 3687 | useMasking = true; |
| 3688 | } |
| 3689 | |
| 3690 | assert(!(useMasking && flatten) && |
| 3691 | "Unsupported flattened conv with dynamic shapes" ); |
| 3692 | |
| 3693 | // out{n, w, c} |
| 3694 | bindShapeDims(resShapedType, nSize, wSize); |
| 3695 | |
| 3696 | vector::TransferWriteOp write; |
| 3697 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 3698 | |
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. When strideW == 1,
// we can batch the contiguous loads and avoid unrolling.
| 3702 | int64_t wSizeStep = strideW == 1 ? wSize : 1; |
| 3703 | |
| 3704 | Type lhsEltType = lhsShapedType.getElementType(); |
| 3705 | Type rhsEltType = rhsShapedType.getElementType(); |
| 3706 | Type resEltType = resShapedType.getElementType(); |
| 3707 | VectorType lhsType = VectorType::get( |
| 3708 | {nSize, |
// iw = (ow - 1) * sw + (kw - 1) * dw + 1
// (e.g. an output of 14 with a filter of 3 @ stride 1, dilation 1 needs an
// input of 16)
| 3711 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1, |
| 3712 | cSize}, |
| 3713 | lhsEltType, /*scalableDims=*/{false, false, scalableChDim}); |
| 3714 | VectorType rhsType = |
| 3715 | VectorType::get({kwSize, cSize}, rhsEltType, |
| 3716 | /*scalableDims=*/{false, scalableChDim}); |
| 3717 | VectorType resType = |
| 3718 | VectorType::get({nSize, wSize, cSize}, resEltType, |
| 3719 | /*scalableDims=*/{false, false, scalableChDim}); |
| 3720 | |
// Masks the input xfer Op along the channel dim, iff the channel dim is
// dynamic (i.e. `useMasking` is set).
| 3723 | auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape, |
| 3724 | ArrayRef<bool> scalableDims, |
| 3725 | Operation *opToMask) { |
| 3726 | if (!useMasking) |
| 3727 | return opToMask; |
| 3728 | auto maskType = |
| 3729 | VectorType::get(maskShape, rewriter.getI1Type(), scalableDims); |
| 3730 | |
| 3731 | SmallVector<bool> inBounds(maskShape.size(), true); |
| 3732 | auto xferOp = cast<VectorTransferOpInterface>(opToMask); |
| 3733 | xferOp->setAttr(xferOp.getInBoundsAttrName(), |
| 3734 | rewriter.getBoolArrayAttr(inBounds)); |
| 3735 | |
| 3736 | SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer( |
| 3737 | cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter); |
| 3738 | |
| 3739 | Value maskOp = |
| 3740 | rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims); |
| 3741 | |
| 3742 | return mlir::vector::maskOperation(rewriter, opToMask, maskOp); |
| 3743 | }; |
| 3744 | |
| 3745 | // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, |
| 3746 | // 0]. |
| 3747 | Value lhs = rewriter.create<vector::TransferReadOp>( |
| 3748 | loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); |
| 3749 | auto maybeMaskedLhs = maybeMaskXferOp( |
| 3750 | lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp()); |
| 3751 | |
| 3752 | // Read rhs slice of size {kw, c} @ [0, 0]. |
| 3753 | Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped, |
| 3754 | ValueRange{zero, zero}); |
| 3755 | auto maybeMaskedRhs = maybeMaskXferOp( |
| 3756 | rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp()); |
| 3757 | |
| 3758 | // Read res slice of size {n, w, c} @ [0, 0, 0]. |
| 3759 | Value res = rewriter.create<vector::TransferReadOp>( |
| 3760 | loc, resType, resShaped, ValueRange{zero, zero, zero}); |
| 3761 | auto maybeMaskedRes = maybeMaskXferOp( |
| 3762 | resType.getShape(), resType.getScalableDims(), res.getDefiningOp()); |
| 3763 | |
| 3764 | //===------------------------------------------------------------------===// |
| 3765 | // Begin vector-only rewrite part |
| 3766 | //===------------------------------------------------------------------===// |
| 3767 | // Unroll along kw and read slices of lhs and rhs. |
| 3768 | SmallVector<Value> lhsVals, rhsVals, resVals; |
| 3769 | SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize}; |
| 3770 | SmallVector<int64_t> inOutStrides = {1, 1, 1}; |
| 3771 | |
| 3772 | // Extract lhs slice of size {n, wSizeStep, c} |
| 3773 | // @ [0, sw * w + dw * kw, 0]. |
| 3774 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 3775 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3776 | lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
| 3777 | loc, maybeMaskedLhs->getResult(0), |
| 3778 | /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, |
| 3779 | inOutSliceSizes, inOutStrides)); |
| 3780 | } |
| 3781 | } |
| 3782 | // Extract rhs slice of size {c} @ [kw]. |
| 3783 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 3784 | rhsVals.push_back(rewriter.create<vector::ExtractOp>( |
| 3785 | loc, maybeMaskedRhs->getResult(0), |
| 3786 | /*offsets=*/ArrayRef<int64_t>{kw})); |
| 3787 | } |
| 3788 | // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. |
| 3789 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3790 | resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
| 3791 | loc, maybeMaskedRes->getResult(0), |
| 3792 | /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes, |
| 3793 | inOutStrides)); |
| 3794 | } |
| 3795 | |
| 3796 | auto linearIndex = [&](int64_t kw, int64_t w) { |
| 3797 | return kw * (wSize / wSizeStep) + w; |
| 3798 | }; |
| 3799 | |
| 3800 | // Note - the scalable flags are ignored as flattening combined with |
| 3801 | // scalable vectorization is not supported. |
| 3802 | SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize}; |
| 3803 | auto lhsTypeAfterFlattening = |
| 3804 | VectorType::get(inOutFlattenSliceSizes, lhsEltType); |
| 3805 | auto resTypeAfterFlattening = |
| 3806 | VectorType::get(inOutFlattenSliceSizes, resEltType); |
| 3807 | |
| 3808 | // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} |
| 3809 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 3810 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3811 | Value lhsVal = lhsVals[linearIndex(kw, w)]; |
| 3812 | Value resVal = resVals[w]; |
| 3813 | if (flatten) { |
| 3814 | // Flatten the input and output vectors (collapse the channel |
| 3815 | // dimension) |
| 3816 | lhsVal = rewriter.create<vector::ShapeCastOp>( |
| 3817 | loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]); |
| 3818 | resVal = rewriter.create<vector::ShapeCastOp>( |
| 3819 | loc, resTypeAfterFlattening, resVals[w]); |
| 3820 | } |
| 3821 | resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal, |
| 3822 | rhsVals[kw], resVal, flatten); |
| 3823 | if (flatten) { |
| 3824 | // Un-flatten the output vector (restore the channel dimension) |
| 3825 | resVals[w] = rewriter.create<vector::ShapeCastOp>( |
| 3826 | loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]); |
| 3827 | } |
| 3828 | } |
| 3829 | } |
| 3830 | |
// It's possible we failed to create the FMA.
if (!llvm::all_of(resVals, [](Value v) { return v; })) {
| 3833 | // Manually revert (in reverse order) to avoid leaving a bad IR state. |
| 3834 | for (auto &collection : |
| 3835 | {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}}) |
| 3836 | for (Value v : collection) |
| 3837 | rewriter.eraseOp(v.getDefiningOp()); |
return rewriter.notifyMatchFailure(op, "failed to create FMA");
| 3839 | } |
| 3840 | |
| 3841 | // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]. |
| 3842 | // This does not depend on kw. |
| 3843 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3844 | maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>( |
| 3845 | loc, resVals[w], maybeMaskedRes->getResult(0), |
| 3846 | /*offsets=*/ArrayRef<int64_t>{0, w, 0}, |
| 3847 | /*strides=*/ArrayRef<int64_t>{1, 1, 1}); |
| 3848 | } |
| 3849 | //===------------------------------------------------------------------===// |
| 3850 | // End vector-only rewrite part |
| 3851 | //===------------------------------------------------------------------===// |
| 3852 | |
| 3853 | // Write back res slice of size {n, w, c} @ [0, 0, 0]. |
| 3854 | Operation *resOut = rewriter.create<vector::TransferWriteOp>( |
| 3855 | loc, maybeMaskedRes->getResult(0), resShaped, |
| 3856 | ValueRange{zero, zero, zero}); |
| 3857 | return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(), |
| 3858 | resOut); |
| 3859 | } |
| 3860 | |
| 3861 | /// Lower: |
| 3862 | /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false) |
| 3863 | /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true) |
| 3864 | /// to MulAcc. |
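/// E.g. (illustrative): for f32 element types this lowers to a single
/// `vector.fma %lhs, %rhs_bcast, %res : vector<3x2x4xf32>` (with %rhs_bcast a
/// made-up name for the broadcast filter); integer element types lower to
/// arith.muli followed by arith.addi.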
| 3865 | Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc, |
| 3866 | Value lhs, Value rhs, Value res, |
| 3867 | bool flatten) { |
| 3868 | auto rhsTy = cast<ShapedType>(rhs.getType()); |
| 3869 | auto resTy = cast<ShapedType>(res.getType()); |
| 3870 | |
// TODO(suderman): Change this to use a vector.fma intrinsic.
lhs = promote(rewriter, loc, lhs, resTy);
| 3873 | |
| 3874 | if (flatten) { |
// NOTE: The following logic won't work for scalable vectors. For this
| 3876 | // reason, "flattening" is not supported when shapes are dynamic (this |
| 3877 | // should be captured by one of the pre-conditions). |
| 3878 | |
| 3879 | // There are two options for handling the filter: |
| 3880 | // * shape_cast(broadcast(filter)) |
| 3881 | // * broadcast(shuffle(filter)) |
| 3882 | // Opt for the option without shape_cast to simplify the codegen. |
| 3883 | auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0]; |
| 3884 | auto resSize = cast<VectorType>(res.getType()).getShape()[1]; |
| 3885 | |
| 3886 | SmallVector<int64_t, 16> indices; |
| 3887 | for (int i = 0; i < resSize / rhsSize; ++i) { |
| 3888 | for (int j = 0; j < rhsSize; ++j) |
indices.push_back(j);
| 3890 | } |
| 3891 | |
| 3892 | rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices); |
| 3893 | } |
| 3894 | // Broadcast the filter to match the output vector |
| 3895 | rhs = rewriter.create<vector::BroadcastOp>( |
| 3896 | loc, resTy.clone(rhsTy.getElementType()), rhs); |
| 3897 | |
rhs = promote(rewriter, loc, rhs, resTy);
| 3899 | |
| 3900 | if (!lhs || !rhs) |
| 3901 | return nullptr; |
| 3902 | |
| 3903 | if (isa<FloatType>(resTy.getElementType())) |
| 3904 | return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res); |
| 3905 | |
| 3906 | auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs); |
| 3907 | return rewriter.create<arith::AddIOp>(loc, mul, res); |
| 3908 | } |
| 3909 | |
| 3910 | /// Entry point for non-channeled convolution: |
| 3911 | /// {{w + kw}, {kw}, {w}} |
| 3912 | FailureOr<Operation *> generateNonChanneledConv() { |
| 3913 | AffineExpr w, kw; |
| 3914 | bindDims(ctx, w, kw); |
| 3915 | if (!iters({Par(), Red()})) |
| 3916 | return rewriter.notifyMatchFailure(op, |
| 3917 | "failed to match conv::W 1-par 1-red" ); |
| 3918 | |
| 3919 | // No transposition needed. |
| 3920 | if (layout({/*lhsIndex*/ {w + kw}, |
| 3921 | /*rhsIndex*/ {kw}, |
| 3922 | /*resIndex*/ {w}})) |
return conv(Conv1DOpOrder::W);

return rewriter.notifyMatchFailure(op, "not a conv::W layout");
| 3926 | } |
| 3927 | |
| 3928 | /// Entry point that transposes into the common form: |
| 3929 | /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
| 3930 | FailureOr<Operation *> generateNwcConv() { |
| 3931 | AffineExpr n, w, f, kw, c; |
| 3932 | bindDims(ctx, n, w, f, kw, c); |
| 3933 | if (!iters({Par(), Par(), Par(), Red(), Red()})) |
| 3934 | return rewriter.notifyMatchFailure( |
| 3935 | op, "failed to match conv::Nwc 3-par 2-red" ); |
| 3936 | |
| 3937 | // No transposition needed. |
| 3938 | if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
| 3939 | /*rhsIndex*/ {kw, c, f}, |
| 3940 | /*resIndex*/ {n, w, f}})) |
return conv(Conv1DOpOrder::Nwc);

return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
| 3944 | } |
| 3945 | |
| 3946 | /// Entry point that transposes into the common form: |
| 3947 | /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}} |
| 3948 | FailureOr<Operation *> generateNcwConv() { |
| 3949 | AffineExpr n, w, f, kw, c; |
| 3950 | bindDims(ctx, n, f, w, c, kw); |
| 3951 | if (!iters({Par(), Par(), Par(), Red(), Red()})) |
| 3952 | return rewriter.notifyMatchFailure( |
| 3953 | op, "failed to match conv::Ncw 3-par 2-red" ); |
| 3954 | |
| 3955 | if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, |
| 3956 | /*rhsIndex*/ {f, c, kw}, |
| 3957 | /*resIndex*/ {n, f, w}})) |
return conv(Conv1DOpOrder::Ncw);

return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
| 3961 | } |
| 3962 | |
| 3963 | /// Entry point that transposes into the common form: |
| 3964 | /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling |
| 3965 | FailureOr<Operation *> generateNwcPooling() { |
| 3966 | AffineExpr n, w, c, kw; |
| 3967 | bindDims(ctx, n, w, c, kw); |
| 3968 | if (!iters({Par(), Par(), Par(), Red()})) |
| 3969 | return rewriter.notifyMatchFailure(op, |
| 3970 | "failed to match pooling 3-par 1-red" ); |
| 3971 | |
| 3972 | // No transposition needed. |
| 3973 | if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
| 3974 | /*rhsIndex*/ {kw}, |
| 3975 | /*resIndex*/ {n, w, c}})) |
return conv(Conv1DOpOrder::Nwc);

return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
| 3979 | } |
| 3980 | |
| 3981 | /// Entry point that transposes into the common form: |
| 3982 | /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling |
| 3983 | FailureOr<Operation *> generateNcwPooling() { |
| 3984 | AffineExpr n, w, c, kw; |
| 3985 | bindDims(ctx, n, c, w, kw); |
| 3986 | if (!iters({Par(), Par(), Par(), Red()})) |
| 3987 | return rewriter.notifyMatchFailure(op, |
| 3988 | "failed to match pooling 3-par 1-red" ); |
| 3989 | |
| 3990 | if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, |
| 3991 | /*rhsIndex*/ {kw}, |
| 3992 | /*resIndex*/ {n, c, w}})) |
return conv(Conv1DOpOrder::Ncw);

return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
| 3996 | } |
| 3997 | |
| 3998 | /// Entry point that transposes into the common form: |
| 3999 | /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
| 4000 | FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0, |
| 4001 | bool vecChDimScalableFlag = false, |
| 4002 | bool flatten = false) { |
| 4003 | AffineExpr n, w, c, kw; |
| 4004 | bindDims(ctx, n, w, c, kw); |
| 4005 | if (!iters({Par(), Par(), Par(), Red()})) |
| 4006 | return rewriter.notifyMatchFailure( |
| 4007 | op, "failed to match depthwise::Nwc conv 3-par 1-red" ); |
| 4008 | |
| 4009 | // No transposition needed. |
| 4010 | if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
| 4011 | /*rhsIndex*/ {kw, c}, |
| 4012 | /*resIndex*/ {n, w, c}})) |
return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
| 4016 | } |
| 4017 | |
| 4018 | private: |
| 4019 | ConvOperationKind oper = ConvOperationKind::Conv; |
| 4020 | StringAttr redOp; |
| 4021 | StringAttr poolExtOp; |
| 4022 | bool isPoolExt = false; |
| 4023 | int strideW, dilationW; |
| 4024 | Value lhsShaped, rhsShaped, resShaped; |
| 4025 | ShapedType lhsShapedType, rhsShapedType, resShapedType; |
| 4026 | vector::CombiningKind reductionKind; |
| 4027 | |
| 4028 | // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops. |
| 4029 | void setConvOperationKind(Operation *reduceOp) { |
| 4030 | int numBlockArguments = |
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
| 4032 | if (numBlockArguments == 1) { |
| 4033 | // Will be convolution if feeder is a MulOp. |
| 4034 | // A strength reduced version of MulOp for i1 type is AndOp which is also |
| 4035 | // supported. Otherwise, it can be pooling. This strength reduction logic |
| 4036 | // is in `buildBinaryFn` helper in the Linalg dialect. |
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
if (isCastOfBlockArgument(feedOp)) {
| 4041 | oper = ConvOperationKind::Pool; |
| 4042 | isPoolExt = true; |
| 4043 | poolExtOp = feedOp->getName().getIdentifier(); |
| 4044 | return; |
| 4045 | } |
| 4046 | oper = ConvOperationKind::Conv; |
| 4047 | return; |
| 4048 | } |
// numBlockArguments == 2 and this is a pooling op.
| 4050 | oper = ConvOperationKind::Pool; |
| 4051 | isPoolExt = false; |
| 4052 | } |
| 4053 | }; |
| 4054 | } // namespace |
| 4055 | |
| 4056 | /// Helper function to vectorize a LinalgOp with convolution semantics. |
| 4057 | // TODO: extend the generic vectorization to support windows and drop this. |
| 4058 | static FailureOr<Operation *> vectorizeConvolution( |
| 4059 | RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes, |
| 4060 | ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) { |
| 4061 | Conv1DGenerator conv1dGen(rewriter, op); |
| 4062 | auto res = conv1dGen.generateNonChanneledConv(); |
| 4063 | if (succeeded(res)) |
| 4064 | return res; |
| 4065 | res = conv1dGen.generateNwcConv(); |
| 4066 | if (succeeded(res)) |
| 4067 | return res; |
| 4068 | res = conv1dGen.generateNcwConv(); |
| 4069 | if (succeeded(res)) |
| 4070 | return res; |
| 4071 | res = conv1dGen.generateNwcPooling(); |
| 4072 | if (succeeded(res)) |
| 4073 | return res; |
| 4074 | res = conv1dGen.generateNcwPooling(); |
| 4075 | if (succeeded(res)) |
| 4076 | return res; |
| 4077 | |
| 4078 | // Only depthwise 1D NWC convs are left - these can be vectorized using masks |
| 4079 | // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e. |
| 4080 | // masked/scalable) is the channel dim (i.e. the trailing dim). |
| 4081 | uint64_t vecChDimSize = ShapedType::kDynamic; |
| 4082 | bool vecChDimScalableFlag = false; |
| 4083 | if (!inputVecSizes.empty()) { |
| 4084 | // Only use the input vector size corresponding to the channel dim. Other |
| 4085 | // vector dims will be inferred from the Ops. |
| 4086 | assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) || |
| 4087 | isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) && |
| 4088 | "Not a 1D depthwise conv!" ); |
| 4089 | size_t chDimIdx = |
| 4090 | TypeSwitch<Operation *, size_t>(op) |
| 4091 | .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; }) |
| 4092 | .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; }); |
| 4093 | |
| 4094 | vecChDimSize = inputVecSizes[chDimIdx]; |
| 4095 | vecChDimScalableFlag = inputScalableVecDims[chDimIdx]; |
| 4096 | } |
| 4097 | return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag, |
flatten1DDepthwiseConv);
| 4099 | } |
| 4100 | |
| 4101 | struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> { |
| 4102 | using OpInterfaceRewritePattern::OpInterfaceRewritePattern; |
| 4103 | |
| 4104 | LogicalResult matchAndRewrite(LinalgOp op, |
| 4105 | PatternRewriter &rewriter) const override { |
| 4106 | FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op); |
if (failed(resultOrFail))
| 4108 | return failure(); |
| 4109 | Operation *newOp = *resultOrFail; |
| 4110 | if (newOp->getNumResults() == 0) { |
rewriter.eraseOp(op.getOperation());
| 4112 | return success(); |
| 4113 | } |
assert(newOp->getNumResults() == 1 && "expected single result");
rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
| 4116 | return success(); |
| 4117 | } |
| 4118 | }; |
| 4119 | |
| 4120 | void mlir::linalg::populateConvolutionVectorizationPatterns( |
| 4121 | RewritePatternSet &patterns, PatternBenefit benefit) { |
patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
| 4123 | } |
| 4124 | |