1//===- VectorLinearize.cpp - vector linearization transforms --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns and pass for linearizing ND vectors into 1D.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>
26
27using namespace mlir;
28
29static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
30 auto resultTypes = op->getResultTypes();
31 for (auto resType : resultTypes) {
32 VectorType vecType = dyn_cast<VectorType>(resType);
33 // Reject index since getElementTypeBitWidth will abort for Index types.
34 if (!vecType || vecType.getElementType().isIndex())
35 return false;
36 // There are no dimension to fold if it is a 0-D vector.
37 if (vecType.getRank() == 0)
38 return false;
39 unsigned trailingVecDimBitWidth =
40 vecType.getShape().back() * vecType.getElementTypeBitWidth();
41 if (trailingVecDimBitWidth >= targetBitWidth)
42 return false;
43 }
44 return true;
45}
46
47namespace {
48struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
49 using OpConversionPattern::OpConversionPattern;
50 LinearizeConstant(
51 const TypeConverter &typeConverter, MLIRContext *context,
52 unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
53 PatternBenefit benefit = 1)
54 : OpConversionPattern(typeConverter, context, benefit),
55 targetVectorBitWidth(targetVectBitWidth) {}
56 LogicalResult
57 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
58 ConversionPatternRewriter &rewriter) const override {
59 Location loc = constOp.getLoc();
60 auto resType =
61 getTypeConverter()->convertType<VectorType>(constOp.getType());
62
63 if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
64 return rewriter.notifyMatchFailure(
65 arg&: loc,
66 msg: "Cannot linearize a constant scalable vector that's not a splat");
67
68 if (!resType)
69 return rewriter.notifyMatchFailure(arg&: loc, msg: "can't convert return type");
70 if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
71 return rewriter.notifyMatchFailure(
72 arg&: loc, msg: "Can't flatten since targetBitWidth <= OpSize");
73 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
74 if (!dstElementsAttr)
75 return rewriter.notifyMatchFailure(arg&: loc, msg: "unsupported attr type");
76
77 dstElementsAttr = dstElementsAttr.reshape(resType);
78 rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
79 dstElementsAttr);
80 return success();
81 }
82
83private:
84 unsigned targetVectorBitWidth;
85};
86
87struct LinearizeVectorizable final
88 : OpTraitConversionPattern<OpTrait::Vectorizable> {
89 using OpTraitConversionPattern::OpTraitConversionPattern;
90
91public:
92 LinearizeVectorizable(
93 const TypeConverter &typeConverter, MLIRContext *context,
94 unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
95 PatternBenefit benefit = 1)
96 : OpTraitConversionPattern(typeConverter, context, benefit),
97 targetVectorBitWidth(targetVectBitWidth) {}
98 LogicalResult
99 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
100 ConversionPatternRewriter &rewriter) const override {
101 if (!isLessThanTargetBitWidth(op, targetBitWidth: targetVectorBitWidth))
102 return rewriter.notifyMatchFailure(
103 arg: op->getLoc(), msg: "Can't flatten since targetBitWidth <= OpSize");
104 FailureOr<Operation *> newOp =
105 convertOpResultTypes(op, operands, converter: *getTypeConverter(), rewriter);
106 if (failed(result: newOp))
107 return failure();
108
109 rewriter.replaceOp(op, newValues: (*newOp)->getResults());
110 return success();
111 }
112
113private:
114 unsigned targetVectorBitWidth;
115};
116
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// Following,
///   vector.extract_strided_slice %source
///         { offsets = [..], strides = [..], sizes = [..] }
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction. Only unit strides are supported.
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(extractOp.getType());
    assert(!(extractOp.getVector().getType().isScalable() ||
             dstType.cast<VectorType>().isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    // NOTE(review): only strides[0] is checked here — presumably the op
    // verifier guarantees all strides match; confirm before relying on it.
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for nD source vector (n > k), the granularity
    // of the extraction is greater than 1. In this case last (n-k) dimensions
    // form the extraction granularity.
    // Example :
    //  vector.extract_strided_slice %src {
    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //      vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get total number of extracted slices (product of the `sizes` entries).
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= size.cast<IntegerAttr>().getInt();
    }
    // Compute the strides of the source vector considering first k dimensions.
    // Each stride is expressed in scalars, hence the extractGranularitySize
    // seed for the innermost of the k dimensions.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // Final shuffle indices has nExtractedSlices * extractGranularitySize
    // elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    // Compute extractedStrides (row-major strides over the `sizes` shape).
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index (de-linearize `i`
      // over the extracted shape).
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array form linearizedIndex to linearizedIndex +
      // extractGranularitySize (the slice is contiguous in the source).
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector,
        rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  /// Ops whose trailing-dim bit width reaches this bound are left untouched.
  unsigned targetVectorBitWidth;
};
229
230/// This pattern converts the ShuffleOp that works on nD (n > 1)
231/// vectors to a ShuffleOp that works on linearized vectors.
232/// Following,
233/// vector.shuffle %v1, %v2 [ shuffle_indices ]
234/// is converted to :
235/// %v1_1d = vector.shape_cast %v1
236/// %v2_1d = vector.shape_cast %v2
237/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
238/// %out_nd = vector.shape_cast %out_1d
239// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
240/// of the original shuffle operation.
241struct LinearizeVectorShuffle final
242 : public OpConversionPattern<vector::ShuffleOp> {
243 using OpConversionPattern::OpConversionPattern;
244 LinearizeVectorShuffle(
245 const TypeConverter &typeConverter, MLIRContext *context,
246 unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
247 PatternBenefit benefit = 1)
248 : OpConversionPattern(typeConverter, context, benefit),
249 targetVectorBitWidth(targetVectBitWidth) {}
250
251 LogicalResult
252 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
253 ConversionPatternRewriter &rewriter) const override {
254 Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
255 assert(!(shuffleOp.getV1VectorType().isScalable() ||
256 shuffleOp.getV2VectorType().isScalable() ||
257 dstType.cast<VectorType>().isScalable()) &&
258 "scalable vectors are not supported.");
259 if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
260 return rewriter.notifyMatchFailure(
261 shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
262
263 Value vec1 = adaptor.getV1();
264 Value vec2 = adaptor.getV2();
265 int shuffleSliceLen = 1;
266 int rank = shuffleOp.getV1().getType().getRank();
267
268 // If rank > 1, we need to do the shuffle in the granularity of slices
269 // instead of scalars. Size of the slice is equal to the rank-1 innermost
270 // dims. Mask of the shuffle op specifies which slice to take from the
271 // outermost dim.
272 if (rank > 1) {
273 llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
274 for (unsigned i = 1; i < shape.size(); ++i) {
275 shuffleSliceLen *= shape[i];
276 }
277 }
278
279 // For each value in the mask, we generate the indices of the source vectors
280 // that needs to be shuffled to the destination vector. If shuffleSliceLen >
281 // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
282 // elements) instead of scalars.
283 ArrayAttr mask = shuffleOp.getMask();
284 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
285 llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
286 for (auto [i, value] :
287 llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
288
289 int64_t v = value.getZExtValue();
290 std::iota(indices.begin() + shuffleSliceLen * i,
291 indices.begin() + shuffleSliceLen * (i + 1),
292 shuffleSliceLen * v);
293 }
294
295 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
296 shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
297 return success();
298 }
299
300private:
301 unsigned targetVectorBitWidth;
302};
303
304/// This pattern converts the ExtractOp to a ShuffleOp that works on a
305/// linearized vector.
306/// Following,
307/// vector.extract %source [ position ]
308/// is converted to :
309/// %source_1d = vector.shape_cast %source
310/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
311/// %out_nd = vector.shape_cast %out_1d
312/// `shuffle_indices_1d` is computed using the position of the original extract.
313struct LinearizeVectorExtract final
314 : public OpConversionPattern<vector::ExtractOp> {
315 using OpConversionPattern::OpConversionPattern;
316 LinearizeVectorExtract(
317 const TypeConverter &typeConverter, MLIRContext *context,
318 unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
319 PatternBenefit benefit = 1)
320 : OpConversionPattern(typeConverter, context, benefit),
321 targetVectorBitWidth(targetVectBitWidth) {}
322 LogicalResult
323 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
324 ConversionPatternRewriter &rewriter) const override {
325 Type dstTy = getTypeConverter()->convertType(extractOp.getType());
326 assert(!(extractOp.getVector().getType().isScalable() ||
327 dstTy.cast<VectorType>().isScalable()) &&
328 "scalable vectors are not supported.");
329 if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
330 return rewriter.notifyMatchFailure(
331 extractOp, "Can't flatten since targetBitWidth <= OpSize");
332
333 // Dynamic position is not supported.
334 if (extractOp.hasDynamicPosition())
335 return rewriter.notifyMatchFailure(extractOp,
336 "dynamic position is not supported.");
337
338 llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
339 int64_t size = extractOp.getVector().getType().getNumElements();
340
341 // Compute linearized offset.
342 int64_t linearizedOffset = 0;
343 llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
344 for (auto [i, off] : llvm::enumerate(offsets)) {
345 size /= shape[i];
346 linearizedOffset += offsets[i] * size;
347 }
348
349 llvm::SmallVector<int64_t, 2> indices(size);
350 std::iota(first: indices.begin(), last: indices.end(), value: linearizedOffset);
351 rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
352 extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
353 rewriter.getI64ArrayAttr(indices));
354
355 return success();
356 }
357
358private:
359 unsigned targetVectorBitWidth;
360};
361} // namespace
362
363void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
364 TypeConverter &typeConverter, RewritePatternSet &patterns,
365 ConversionTarget &target, unsigned targetBitWidth) {
366
367 typeConverter.addConversion(callback: [](VectorType type) -> std::optional<Type> {
368 if (!isLinearizableVector(type))
369 return type;
370
371 return VectorType::get(type.getNumElements(), type.getElementType(),
372 type.isScalable());
373 });
374
375 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
376 Location loc) -> Value {
377 if (inputs.size() != 1 || !isa<VectorType>(Val: inputs.front().getType()) ||
378 !isa<VectorType>(Val: type))
379 return nullptr;
380
381 return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
382 };
383 typeConverter.addArgumentMaterialization(callback&: materializeCast);
384 typeConverter.addSourceMaterialization(callback&: materializeCast);
385 typeConverter.addTargetMaterialization(callback&: materializeCast);
386 target.markUnknownOpDynamicallyLegal(
387 fn: [=](Operation *op) -> std::optional<bool> {
388 if ((isa<arith::ConstantOp>(op) ||
389 op->hasTrait<OpTrait::Vectorizable>())) {
390 return (isLessThanTargetBitWidth(op, targetBitWidth)
391 ? typeConverter.isLegal(op)
392 : true);
393 }
394 return std::nullopt;
395 });
396
397 patterns.add<LinearizeConstant, LinearizeVectorizable>(
398 arg&: typeConverter, args: patterns.getContext(), args&: targetBitWidth);
399}
400
401void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
402 TypeConverter &typeConverter, RewritePatternSet &patterns,
403 ConversionTarget &target, unsigned int targetBitWidth) {
404 target.addDynamicallyLegalOp<vector::ShuffleOp>(
405 [=](vector::ShuffleOp shuffleOp) -> bool {
406 return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
407 ? (typeConverter.isLegal(shuffleOp) &&
408 shuffleOp.getResult()
409 .getType()
410 .cast<mlir::VectorType>()
411 .getRank() == 1)
412 : true;
413 });
414 patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
415 LinearizeVectorExtractStridedSlice>(
416 arg&: typeConverter, args: patterns.getContext(), args&: targetBitWidth);
417}
418

source code of mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp