1//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
14#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Arith/Utils/Utils.h"
18#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/Vector/IR/VectorOps.h"
22#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
25#include "mlir/IR/BuiltinAttributes.h"
26#include "mlir/IR/BuiltinTypeInterfaces.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/TypeUtilities.h"
29#include "mlir/Target/LLVMIR/TypeToLLVM.h"
30#include "mlir/Transforms/DialectConversion.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/Support/Casting.h"
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::vector;
37
38// Helper that picks the proper sequence for inserting.
39static Value insertOne(ConversionPatternRewriter &rewriter,
40 const LLVMTypeConverter &typeConverter, Location loc,
41 Value val1, Value val2, Type llvmType, int64_t rank,
42 int64_t pos) {
43 assert(rank > 0 && "0-D vector corner case should have been handled already");
44 if (rank == 1) {
45 auto idxType = rewriter.getIndexType();
46 auto constant = rewriter.create<LLVM::ConstantOp>(
47 loc, typeConverter.convertType(idxType),
48 rewriter.getIntegerAttr(idxType, pos));
49 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
50 constant);
51 }
52 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
53}
54
55// Helper that picks the proper sequence for extracting.
56static Value extractOne(ConversionPatternRewriter &rewriter,
57 const LLVMTypeConverter &typeConverter, Location loc,
58 Value val, Type llvmType, int64_t rank, int64_t pos) {
59 if (rank <= 1) {
60 auto idxType = rewriter.getIndexType();
61 auto constant = rewriter.create<LLVM::ConstantOp>(
62 loc, typeConverter.convertType(idxType),
63 rewriter.getIntegerAttr(idxType, pos));
64 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
65 constant);
66 }
67 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
68}
69
70// Helper that returns data layout alignment of a vector.
71LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
72 VectorType vectorType, unsigned &align) {
73 Type convertedVectorTy = typeConverter.convertType(vectorType);
74 if (!convertedVectorTy)
75 return failure();
76
77 llvm::LLVMContext llvmContext;
78 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
79 .getPreferredAlignment(type: convertedVectorTy,
80 layout: typeConverter.getDataLayout());
81
82 return success();
83}
84
85// Helper that returns data layout alignment of a memref.
86LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
87 MemRefType memrefType, unsigned &align) {
88 Type elementTy = typeConverter.convertType(memrefType.getElementType());
89 if (!elementTy)
90 return failure();
91
92 // TODO: this should use the MLIR data layout when it becomes available and
93 // stop depending on translation.
94 llvm::LLVMContext llvmContext;
95 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
96 .getPreferredAlignment(type: elementTy, layout: typeConverter.getDataLayout());
97 return success();
98}
99
100// Helper to resolve the alignment for vector load/store, gather and scatter
101// ops. If useVectorAlignment is true, get the preferred alignment for the
102// vector type in the operation. This option is used for hardware backends with
103// vectorization. Otherwise, use the preferred alignment of the element type of
104// the memref. Note that if you choose to use vector alignment, the shape of the
105// vector type must be resolved before the ConvertVectorToLLVM pass is run.
106LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
107 VectorType vectorType,
108 MemRefType memrefType, unsigned &align,
109 bool useVectorAlignment) {
110 if (useVectorAlignment) {
111 if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
112 return failure();
113 }
114 } else {
115 if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
116 return failure();
117 }
118 }
119 return success();
120}
121
122// Check if the last stride is non-unit and has a valid memory space.
123static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
124 const LLVMTypeConverter &converter) {
125 if (!memRefType.isLastDimUnitStride())
126 return failure();
127 if (failed(converter.getMemRefAddressSpace(type: memRefType)))
128 return failure();
129 return success();
130}
131
132// Add an index vector component to a base pointer.
133static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
134 const LLVMTypeConverter &typeConverter,
135 MemRefType memRefType, Value llvmMemref, Value base,
136 Value index, VectorType vectorType) {
137 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
138 "unsupported memref type");
139 assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
140 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
141 auto ptrsType =
142 LLVM::getVectorType(pType, vectorType.getDimSize(0),
143 /*isScalable=*/vectorType.getScalableDims()[0]);
144 return rewriter.create<LLVM::GEPOp>(
145 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
146 base, index);
147}
148
149/// Convert `foldResult` into a Value. Integer attribute is converted to
150/// an LLVM constant op.
151static Value getAsLLVMValue(OpBuilder &builder, Location loc,
152 OpFoldResult foldResult) {
153 if (auto attr = dyn_cast<Attribute>(Val&: foldResult)) {
154 auto intAttr = cast<IntegerAttr>(attr);
155 return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
156 }
157
158 return cast<Value>(Val&: foldResult);
159}
160
161namespace {
162
/// Trivial Vector to LLVM conversions: vector.vscale lowers one-to-one onto
/// the llvm.vscale intrinsic.
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
166
167/// Conversion pattern for a vector.bitcast.
168class VectorBitCastOpConversion
169 : public ConvertOpToLLVMPattern<vector::BitCastOp> {
170public:
171 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
172
173 LogicalResult
174 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
175 ConversionPatternRewriter &rewriter) const override {
176 // Only 0-D and 1-D vectors can be lowered to LLVM.
177 VectorType resultTy = bitCastOp.getResultVectorType();
178 if (resultTy.getRank() > 1)
179 return failure();
180 Type newResultTy = typeConverter->convertType(resultTy);
181 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
182 adaptor.getOperands()[0]);
183 return success();
184 }
185};
186
187/// Conversion pattern for a vector.matrix_multiply.
188/// This is lowered directly to the proper llvm.intr.matrix.multiply.
189class VectorMatmulOpConversion
190 : public ConvertOpToLLVMPattern<vector::MatmulOp> {
191public:
192 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
193
194 LogicalResult
195 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
196 ConversionPatternRewriter &rewriter) const override {
197 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
198 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
199 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
200 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
201 return success();
202 }
203};
204
205/// Conversion pattern for a vector.flat_transpose.
206/// This is lowered directly to the proper llvm.intr.matrix.transpose.
207class VectorFlatTransposeOpConversion
208 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
209public:
210 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
211
212 LogicalResult
213 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
214 ConversionPatternRewriter &rewriter) const override {
215 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
216 transOp, typeConverter->convertType(transOp.getRes().getType()),
217 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
218 return success();
219 }
220};
221
222/// Overloaded utility that replaces a vector.load, vector.store,
223/// vector.maskedload and vector.maskedstore with their respective LLVM
224/// couterparts.
225static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
226 vector::LoadOpAdaptor adaptor,
227 VectorType vectorTy, Value ptr, unsigned align,
228 ConversionPatternRewriter &rewriter) {
229 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
230 /*volatile_=*/false,
231 loadOp.getNontemporal());
232}
233
234static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
235 vector::MaskedLoadOpAdaptor adaptor,
236 VectorType vectorTy, Value ptr, unsigned align,
237 ConversionPatternRewriter &rewriter) {
238 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
239 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
240}
241
242static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
243 vector::StoreOpAdaptor adaptor,
244 VectorType vectorTy, Value ptr, unsigned align,
245 ConversionPatternRewriter &rewriter) {
246 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
247 ptr, align, /*volatile_=*/false,
248 storeOp.getNontemporal());
249}
250
251static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
252 vector::MaskedStoreOpAdaptor adaptor,
253 VectorType vectorTy, Value ptr, unsigned align,
254 ConversionPatternRewriter &rewriter) {
255 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
256 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
257}
258
259/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
260/// vector.maskedstore.
261template <class LoadOrStoreOp>
262class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
263public:
264 explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
265 bool useVectorAlign)
266 : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
267 useVectorAlignment(useVectorAlign) {}
268 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
269
270 LogicalResult
271 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
272 typename LoadOrStoreOp::Adaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
274 // Only 1-D vectors can be lowered to LLVM.
275 VectorType vectorTy = loadOrStoreOp.getVectorType();
276 if (vectorTy.getRank() > 1)
277 return failure();
278
279 auto loc = loadOrStoreOp->getLoc();
280 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
281
282 // Resolve alignment.
283 unsigned align;
284 if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
285 memRefTy, align, useVectorAlignment)))
286 return rewriter.notifyMatchFailure(loadOrStoreOp,
287 "could not resolve alignment");
288
289 // Resolve address.
290 auto vtype = cast<VectorType>(
291 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
292 Value dataPtr = this->getStridedElementPtr(
293 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
294 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
295 rewriter);
296 return success();
297 }
298
299private:
300 // If true, use the preferred alignment of the vector type.
301 // If false, use the preferred alignment of the element type
302 // of the memref. This flag is intended for use with hardware
303 // backends that require alignment of vector operations.
304 const bool useVectorAlignment;
305};
306
/// Conversion pattern for a vector.gather.
/// Lowers a 1-D gather to llvm.intr.masked.gather: computes a vector of
/// addresses from the base memref and the index vector, then emits the
/// intrinsic with the converted mask/pass-through operands.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
                                    bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gather->getLoc();
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    // Require unit last stride and a convertible address space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(gather, "memref type not supported");

    VectorType vType = gather.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          gather, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    unsigned align;
    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

    // Resolve address: scalar start pointer, then a vector of per-lane
    // pointers via the index vector.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getIndices());
    Value base = adaptor.getBase();
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       base, ptr, adaptor.getIndexVec(), vType);

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
361
/// Conversion pattern for a vector.scatter.
/// Mirrors the gather lowering: a 1-D scatter becomes llvm.intr.masked.scatter
/// on a vector of per-lane addresses.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Require unit last stride and a convertible address space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(scatter, "memref type not supported");

    VectorType vType = scatter.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          scatter, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    unsigned align;
    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(scatter,
                                         "could not resolve alignment");

    // Resolve address: scalar start pointer, then per-lane pointers.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getIndices());
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
416
417/// Conversion pattern for a vector.expandload.
418class VectorExpandLoadOpConversion
419 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
420public:
421 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
422
423 LogicalResult
424 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter) const override {
426 auto loc = expand->getLoc();
427 MemRefType memRefType = expand.getMemRefType();
428
429 // Resolve address.
430 auto vtype = typeConverter->convertType(expand.getVectorType());
431 Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
432 adaptor.getBase(), adaptor.getIndices());
433
434 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
435 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
436 return success();
437 }
438};
439
440/// Conversion pattern for a vector.compressstore.
441class VectorCompressStoreOpConversion
442 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
443public:
444 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
445
446 LogicalResult
447 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
448 ConversionPatternRewriter &rewriter) const override {
449 auto loc = compress->getLoc();
450 MemRefType memRefType = compress.getMemRefType();
451
452 // Resolve address.
453 Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
454 adaptor.getBase(), adaptor.getIndices());
455
456 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
457 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
458 return success();
459 }
460};
461
/// Reduction neutral classes for overloading. These are empty tag types used
/// purely to select the matching createReductionNeutralValue overload below.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
473
474/// Create the reduction neutral zero value.
475static Value createReductionNeutralValue(ReductionNeutralZero neutral,
476 ConversionPatternRewriter &rewriter,
477 Location loc, Type llvmType) {
478 return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
479 rewriter.getZeroAttr(llvmType));
480}
481
482/// Create the reduction neutral integer one value.
483static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
484 ConversionPatternRewriter &rewriter,
485 Location loc, Type llvmType) {
486 return rewriter.create<LLVM::ConstantOp>(
487 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
488}
489
490/// Create the reduction neutral fp one value.
491static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
492 ConversionPatternRewriter &rewriter,
493 Location loc, Type llvmType) {
494 return rewriter.create<LLVM::ConstantOp>(
495 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
496}
497
498/// Create the reduction neutral all-ones value.
499static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
500 ConversionPatternRewriter &rewriter,
501 Location loc, Type llvmType) {
502 return rewriter.create<LLVM::ConstantOp>(
503 loc, llvmType,
504 rewriter.getIntegerAttr(
505 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
506}
507
508/// Create the reduction neutral signed int minimum value.
509static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
510 ConversionPatternRewriter &rewriter,
511 Location loc, Type llvmType) {
512 return rewriter.create<LLVM::ConstantOp>(
513 loc, llvmType,
514 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
515 llvmType.getIntOrFloatBitWidth())));
516}
517
518/// Create the reduction neutral unsigned int minimum value.
519static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
520 ConversionPatternRewriter &rewriter,
521 Location loc, Type llvmType) {
522 return rewriter.create<LLVM::ConstantOp>(
523 loc, llvmType,
524 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
525 llvmType.getIntOrFloatBitWidth())));
526}
527
528/// Create the reduction neutral signed int maximum value.
529static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
530 ConversionPatternRewriter &rewriter,
531 Location loc, Type llvmType) {
532 return rewriter.create<LLVM::ConstantOp>(
533 loc, llvmType,
534 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
535 llvmType.getIntOrFloatBitWidth())));
536}
537
538/// Create the reduction neutral unsigned int maximum value.
539static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
540 ConversionPatternRewriter &rewriter,
541 Location loc, Type llvmType) {
542 return rewriter.create<LLVM::ConstantOp>(
543 loc, llvmType,
544 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
545 llvmType.getIntOrFloatBitWidth())));
546}
547
548/// Create the reduction neutral fp minimum value.
549static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
550 ConversionPatternRewriter &rewriter,
551 Location loc, Type llvmType) {
552 auto floatType = cast<FloatType>(llvmType);
553 return rewriter.create<LLVM::ConstantOp>(
554 loc, llvmType,
555 rewriter.getFloatAttr(
556 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
557 /*Negative=*/false)));
558}
559
560/// Create the reduction neutral fp maximum value.
561static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
562 ConversionPatternRewriter &rewriter,
563 Location loc, Type llvmType) {
564 auto floatType = cast<FloatType>(llvmType);
565 return rewriter.create<LLVM::ConstantOp>(
566 loc, llvmType,
567 rewriter.getFloatAttr(
568 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
569 /*Negative=*/true)));
570}
571
572/// Returns `accumulator` if it has a valid value. Otherwise, creates and
573/// returns a new accumulator value using `ReductionNeutral`.
574template <class ReductionNeutral>
575static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
576 Location loc, Type llvmType,
577 Value accumulator) {
578 if (accumulator)
579 return accumulator;
580
581 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
582 llvmType);
583}
584
585/// Creates a value with the 1-D vector shape provided in `llvmType`.
586/// This is used as effective vector length by some intrinsics supporting
587/// dynamic vector lengths at runtime.
588static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
589 Location loc, Type llvmType) {
590 VectorType vType = cast<VectorType>(llvmType);
591 auto vShape = vType.getShape();
592 assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
593
594 Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
595 loc, rewriter.getI32Type(),
596 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
597
598 if (!vType.getScalableDims()[0])
599 return baseVecLength;
600
601 // For a scalable vector type, create and return `vScale * baseVecLength`.
602 Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
603 vScale =
604 rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
605 Value scalableVecLength =
606 rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
607 return scalableVecLength;
608}
609
610/// Helper method to lower a `vector.reduction` op that performs an arithmetic
611/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
612/// and `ScalarOp` is the scalar operation used to add the accumulation value if
613/// non-null.
614template <class LLVMRedIntrinOp, class ScalarOp>
615static Value createIntegerReductionArithmeticOpLowering(
616 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
617 Value vectorOperand, Value accumulator) {
618
619 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
620
621 if (accumulator)
622 result = rewriter.create<ScalarOp>(loc, accumulator, result);
623 return result;
624}
625
626/// Helper method to lower a `vector.reduction` operation that performs
627/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
628/// intrinsic to use and `predicate` is the predicate to use to compare+combine
629/// the accumulator value if non-null.
630template <class LLVMRedIntrinOp>
631static Value createIntegerReductionComparisonOpLowering(
632 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
633 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
634 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
635 if (accumulator) {
636 Value cmp =
637 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
638 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
639 }
640 return result;
641}
642
namespace {
/// Maps an LLVM floating-point vector-reduction intrinsic to the scalar LLVM
/// op used to combine the reduced value with an accumulator.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace
663
664template <class LLVMRedIntrinOp>
665static Value createFPReductionComparisonOpLowering(
666 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
667 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
668 Value result =
669 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
670
671 if (accumulator) {
672 result =
673 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
674 loc, result, accumulator);
675 }
676
677 return result;
678}
679
/// Reduction neutral classes for overloading: tag types selecting the value
/// that masked-off lanes are replaced with below.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
683
684/// Get the mask neutral floating point maximum value
685static llvm::APFloat
686getMaskNeutralValue(MaskNeutralFMaximum,
687 const llvm::fltSemantics &floatSemantics) {
688 return llvm::APFloat::getSmallest(Sem: floatSemantics, /*Negative=*/true);
689}
690/// Get the mask neutral floating point minimum value
691static llvm::APFloat
692getMaskNeutralValue(MaskNeutralFMinimum,
693 const llvm::fltSemantics &floatSemantics) {
694 return llvm::APFloat::getLargest(Sem: floatSemantics, /*Negative=*/false);
695}
696
697/// Create the mask neutral floating point MLIR vector constant
698template <typename MaskNeutral>
699static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
700 Location loc, Type llvmType,
701 Type vectorType) {
702 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
703 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
704 auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
705 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
706}
707
708/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
709/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
710/// `fmaximum`/`fminimum`.
711/// More information: https://github.com/llvm/llvm-project/issues/64940
712template <class LLVMRedIntrinOp, class MaskNeutral>
713static Value
714lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
715 Location loc, Type llvmType,
716 Value vectorOperand, Value accumulator,
717 Value mask, LLVM::FastmathFlagsAttr fmf) {
718 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
719 rewriter, loc, llvmType, vectorOperand.getType());
720 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
721 loc, mask, vectorOperand, vectorMaskNeutral);
722 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
723 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
724}
725
726template <class LLVMRedIntrinOp, class ReductionNeutral>
727static Value
728lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
729 Type llvmType, Value vectorOperand,
730 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
731 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
732 llvmType, accumulator);
733 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
734 /*startValue=*/accumulator,
735 vectorOperand, fmf);
736}
737
738/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
739/// that requires a start value. This start value format spans across fp
740/// reductions without mask and all the masked reduction intrinsics.
741template <class LLVMVPRedIntrinOp, class ReductionNeutral>
742static Value
743lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
744 Location loc, Type llvmType,
745 Value vectorOperand, Value accumulator) {
746 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
747 llvmType, accumulator);
748 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
749 /*startValue=*/accumulator,
750 vectorOperand);
751}
752
753template <class LLVMVPRedIntrinOp, class ReductionNeutral>
754static Value lowerPredicatedReductionWithStartValue(
755 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
756 Value vectorOperand, Value accumulator, Value mask) {
757 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
758 llvmType, accumulator);
759 Value vectorLength =
760 createVectorLengthValue(rewriter, loc, llvmType: vectorOperand.getType());
761 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
762 /*startValue=*/accumulator,
763 vectorOperand, mask, vectorLength);
764}
765
766template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
767 class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
768static Value lowerPredicatedReductionWithStartValue(
769 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
770 Value vectorOperand, Value accumulator, Value mask) {
771 if (llvmType.isIntOrIndex())
772 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
773 IntReductionNeutral>(
774 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
775
776 // FP dispatch.
777 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
778 FPReductionNeutral>(
779 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
780}
781
782/// Conversion pattern for all vector reductions.
783class VectorReductionOpConversion
784 : public ConvertOpToLLVMPattern<vector::ReductionOp> {
785public:
786 explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
787 bool reassociateFPRed)
788 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
789 reassociateFPReductions(reassociateFPRed) {}
790
791 LogicalResult
792 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
793 ConversionPatternRewriter &rewriter) const override {
794 auto kind = reductionOp.getKind();
795 Type eltType = reductionOp.getDest().getType();
796 Type llvmType = typeConverter->convertType(eltType);
797 Value operand = adaptor.getVector();
798 Value acc = adaptor.getAcc();
799 Location loc = reductionOp.getLoc();
800
801 if (eltType.isIntOrIndex()) {
802 // Integer reductions: add/mul/min/max/and/or/xor.
803 Value result;
804 switch (kind) {
805 case vector::CombiningKind::ADD:
806 result =
807 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
808 LLVM::AddOp>(
809 rewriter, loc, llvmType, operand, acc);
810 break;
811 case vector::CombiningKind::MUL:
812 result =
813 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
814 LLVM::MulOp>(
815 rewriter, loc, llvmType, operand, acc);
816 break;
817 case vector::CombiningKind::MINUI:
818 result = createIntegerReductionComparisonOpLowering<
819 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
820 LLVM::ICmpPredicate::ule);
821 break;
822 case vector::CombiningKind::MINSI:
823 result = createIntegerReductionComparisonOpLowering<
824 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
825 LLVM::ICmpPredicate::sle);
826 break;
827 case vector::CombiningKind::MAXUI:
828 result = createIntegerReductionComparisonOpLowering<
829 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
830 LLVM::ICmpPredicate::uge);
831 break;
832 case vector::CombiningKind::MAXSI:
833 result = createIntegerReductionComparisonOpLowering<
834 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
835 LLVM::ICmpPredicate::sge);
836 break;
837 case vector::CombiningKind::AND:
838 result =
839 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
840 LLVM::AndOp>(
841 rewriter, loc, llvmType, operand, acc);
842 break;
843 case vector::CombiningKind::OR:
844 result =
845 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
846 LLVM::OrOp>(
847 rewriter, loc, llvmType, operand, acc);
848 break;
849 case vector::CombiningKind::XOR:
850 result =
851 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
852 LLVM::XOrOp>(
853 rewriter, loc, llvmType, operand, acc);
854 break;
855 default:
856 return failure();
857 }
858 rewriter.replaceOp(reductionOp, result);
859
860 return success();
861 }
862
863 if (!isa<FloatType>(Val: eltType))
864 return failure();
865
866 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
867 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
868 reductionOp.getContext(),
869 convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
870 fmf = LLVM::FastmathFlagsAttr::get(
871 reductionOp.getContext(),
872 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
873 : LLVM::FastmathFlags::none));
874
875 // Floating-point reductions: add/mul/min/max
876 Value result;
877 if (kind == vector::CombiningKind::ADD) {
878 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
879 ReductionNeutralZero>(
880 rewriter, loc, llvmType, operand, acc, fmf);
881 } else if (kind == vector::CombiningKind::MUL) {
882 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
883 ReductionNeutralFPOne>(
884 rewriter, loc, llvmType, operand, acc, fmf);
885 } else if (kind == vector::CombiningKind::MINIMUMF) {
886 result =
887 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
888 rewriter, loc, llvmType, operand, acc, fmf);
889 } else if (kind == vector::CombiningKind::MAXIMUMF) {
890 result =
891 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
892 rewriter, loc, llvmType, operand, acc, fmf);
893 } else if (kind == vector::CombiningKind::MINNUMF) {
894 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
895 rewriter, loc, llvmType, operand, acc, fmf);
896 } else if (kind == vector::CombiningKind::MAXNUMF) {
897 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
898 rewriter, loc, llvmType, operand, acc, fmf);
899 } else {
900 return failure();
901 }
902
903 rewriter.replaceOp(reductionOp, result);
904 return success();
905 }
906
907private:
908 const bool reassociateFPReductions;
909};
910
911/// Base class to convert a `vector.mask` operation while matching traits
912/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
913/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
914/// method performs a second match against the maskable operation `MaskedOp`.
915/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
916/// implemented by the concrete conversion classes. This method can match
917/// against specific traits of the `vector.mask` and the maskable operation. It
918/// must replace the `vector.mask` operation.
919template <class MaskedOp>
920class VectorMaskOpConversionBase
921 : public ConvertOpToLLVMPattern<vector::MaskOp> {
922public:
923 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
924
925 LogicalResult
926 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
927 ConversionPatternRewriter &rewriter) const final {
928 // Match against the maskable operation kind.
929 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
930 if (!maskedOp)
931 return failure();
932 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
933 }
934
935protected:
936 virtual LogicalResult
937 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
938 vector::MaskableOpInterface maskableOp,
939 ConversionPatternRewriter &rewriter) const = 0;
940};
941
/// Converts `vector.mask` wrapping a `vector.reduction` into LLVM VP
/// (vector-predication) reduction intrinsics where one exists for the
/// combining kind; minimumf/maximumf fall back to a regular reduction over a
/// mask-neutralized operand.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    // The base class guarantees the nested op is a vector.reduction.
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // Carry the reduction's fast-math flags over to the LLVM ops.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));

    // Dispatch on the combining kind. Int/FP variants are selected inside
    // the four-template-argument helper based on `llvmType`.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    // minimumf/maximumf have no VP reduction intrinsic counterpart; lower
    // via a regular reduction over a mask-neutralized operand instead.
    case CombiningKind::MAXIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};
1041
1042class VectorShuffleOpConversion
1043 : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
1044public:
1045 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
1046
1047 LogicalResult
1048 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1049 ConversionPatternRewriter &rewriter) const override {
1050 auto loc = shuffleOp->getLoc();
1051 auto v1Type = shuffleOp.getV1VectorType();
1052 auto v2Type = shuffleOp.getV2VectorType();
1053 auto vectorType = shuffleOp.getResultVectorType();
1054 Type llvmType = typeConverter->convertType(vectorType);
1055 ArrayRef<int64_t> mask = shuffleOp.getMask();
1056
1057 // Bail if result type cannot be lowered.
1058 if (!llvmType)
1059 return failure();
1060
1061 // Get rank and dimension sizes.
1062 int64_t rank = vectorType.getRank();
1063#ifndef NDEBUG
1064 bool wellFormed0DCase =
1065 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1066 bool wellFormedNDCase =
1067 v1Type.getRank() == rank && v2Type.getRank() == rank;
1068 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1069#endif
1070
1071 // For rank 0 and 1, where both operands have *exactly* the same vector
1072 // type, there is direct shuffle support in LLVM. Use it!
1073 if (rank <= 1 && v1Type == v2Type) {
1074 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1075 loc, adaptor.getV1(), adaptor.getV2(),
1076 llvm::to_vector_of<int32_t>(mask));
1077 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1078 return success();
1079 }
1080
1081 // For all other cases, insert the individual values individually.
1082 int64_t v1Dim = v1Type.getDimSize(0);
1083 Type eltType;
1084 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1085 eltType = arrayType.getElementType();
1086 else
1087 eltType = cast<VectorType>(llvmType).getElementType();
1088 Value insert = rewriter.create<LLVM::PoisonOp>(loc, llvmType);
1089 int64_t insPos = 0;
1090 for (int64_t extPos : mask) {
1091 Value value = adaptor.getV1();
1092 if (extPos >= v1Dim) {
1093 extPos -= v1Dim;
1094 value = adaptor.getV2();
1095 }
1096 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1097 eltType, rank, extPos);
1098 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1099 llvmType, rank, insPos++);
1100 }
1101 rewriter.replaceOp(shuffleOp, insert);
1102 return success();
1103 }
1104};
1105
1106class VectorExtractElementOpConversion
1107 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1108public:
1109 using ConvertOpToLLVMPattern<
1110 vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1111
1112 LogicalResult
1113 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1114 ConversionPatternRewriter &rewriter) const override {
1115 auto vectorType = extractEltOp.getSourceVectorType();
1116 auto llvmType = typeConverter->convertType(vectorType.getElementType());
1117
1118 // Bail if result type cannot be lowered.
1119 if (!llvmType)
1120 return failure();
1121
1122 if (vectorType.getRank() == 0) {
1123 Location loc = extractEltOp.getLoc();
1124 auto idxType = rewriter.getIndexType();
1125 auto zero = rewriter.create<LLVM::ConstantOp>(
1126 loc, typeConverter->convertType(idxType),
1127 rewriter.getIntegerAttr(idxType, 0));
1128 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1129 extractEltOp, llvmType, adaptor.getVector(), zero);
1130 return success();
1131 }
1132
1133 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1134 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1135 return success();
1136 }
1137};
1138
1139class VectorExtractOpConversion
1140 : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1141public:
1142 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1143
1144 LogicalResult
1145 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1146 ConversionPatternRewriter &rewriter) const override {
1147 auto loc = extractOp->getLoc();
1148 auto resultType = extractOp.getResult().getType();
1149 auto llvmResultType = typeConverter->convertType(resultType);
1150 // Bail if result type cannot be lowered.
1151 if (!llvmResultType)
1152 return failure();
1153
1154 SmallVector<OpFoldResult> positionVec = getMixedValues(
1155 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1156
1157 // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1158 // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1159 // from a N-d vector extract to a nested aggregate vector extract in two
1160 // steps:
1161 // - Extract a member from the nested aggregate. The result can be
1162 // a lower rank nested aggregate or a vector (1-D). This is done using
1163 // `llvm.extractvalue`.
1164 // - Extract a scalar out of the vector if needed. This is done using
1165 // `llvm.extractelement`.
1166
1167 // Determine if we need to extract a member out of the aggregate. We
1168 // always need to extract a member if the input rank >= 2.
1169 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1170 // Determine if we need to extract a scalar as the result. We extract
1171 // a scalar if the extract is full rank, i.e., the number of indices is
1172 // equal to source vector rank.
1173 bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
1174 extractOp.getSourceVectorType().getRank();
1175
1176 // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1177 // need to add a position for this change.
1178 if (extractOp.getSourceVectorType().getRank() == 0) {
1179 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1180 positionVec.push_back(rewriter.getZeroAttr(idxType));
1181 }
1182
1183 Value extracted = adaptor.getVector();
1184 if (extractsAggregate) {
1185 ArrayRef<OpFoldResult> position(positionVec);
1186 if (extractsScalar) {
1187 // If we are extracting a scalar from the extracted member, we drop
1188 // the last index, which will be used to extract the scalar out of the
1189 // vector.
1190 position = position.drop_back();
1191 }
1192 // llvm.extractvalue does not support dynamic dimensions.
1193 if (!llvm::all_of(Range&: position, P: llvm::IsaPred<Attribute>)) {
1194 return failure();
1195 }
1196 extracted = rewriter.create<LLVM::ExtractValueOp>(
1197 loc, extracted, getAsIntegers(position));
1198 }
1199
1200 if (extractsScalar) {
1201 extracted = rewriter.create<LLVM::ExtractElementOp>(
1202 loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
1203 }
1204
1205 rewriter.replaceOp(extractOp, extracted);
1206 return success();
1207 }
1208};
1209
1210/// Conversion pattern that turns a vector.fma on a 1-D vector
1211/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1212/// This does not match vectors of n >= 2 rank.
1213///
1214/// Example:
1215/// ```
1216/// vector.fma %a, %a, %a : vector<8xf32>
1217/// ```
1218/// is converted to:
1219/// ```
1220/// llvm.intr.fmuladd %va, %va, %va:
1221/// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1222/// -> !llvm."<8 x f32>">
1223/// ```
1224class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1225public:
1226 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1227
1228 LogicalResult
1229 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1230 ConversionPatternRewriter &rewriter) const override {
1231 VectorType vType = fmaOp.getVectorType();
1232 if (vType.getRank() > 1)
1233 return failure();
1234
1235 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1236 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1237 return success();
1238 }
1239};
1240
1241class VectorInsertElementOpConversion
1242 : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1243public:
1244 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1245
1246 LogicalResult
1247 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1248 ConversionPatternRewriter &rewriter) const override {
1249 auto vectorType = insertEltOp.getDestVectorType();
1250 auto llvmType = typeConverter->convertType(vectorType);
1251
1252 // Bail if result type cannot be lowered.
1253 if (!llvmType)
1254 return failure();
1255
1256 if (vectorType.getRank() == 0) {
1257 Location loc = insertEltOp.getLoc();
1258 auto idxType = rewriter.getIndexType();
1259 auto zero = rewriter.create<LLVM::ConstantOp>(
1260 loc, typeConverter->convertType(idxType),
1261 rewriter.getIntegerAttr(idxType, 0));
1262 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1263 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1264 return success();
1265 }
1266
1267 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1268 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1269 adaptor.getPosition());
1270 return success();
1271 }
1272};
1273
1274class VectorInsertOpConversion
1275 : public ConvertOpToLLVMPattern<vector::InsertOp> {
1276public:
1277 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1278
1279 LogicalResult
1280 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1281 ConversionPatternRewriter &rewriter) const override {
1282 auto loc = insertOp->getLoc();
1283 auto destVectorType = insertOp.getDestVectorType();
1284 auto llvmResultType = typeConverter->convertType(destVectorType);
1285 // Bail if result type cannot be lowered.
1286 if (!llvmResultType)
1287 return failure();
1288
1289 SmallVector<OpFoldResult> positionVec = getMixedValues(
1290 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1291
1292 // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1293 // its explanatory comment about how N-D vectors are converted as nested
1294 // aggregates (llvm.array's) of 1D vectors.
1295 //
1296 // The innermost dimension of the destination vector, when converted to a
1297 // nested aggregate form, will always be a 1D vector.
1298 //
1299 // * If the insertion is happening into the innermost dimension of the
1300 // destination vector:
1301 // - If the destination is a nested aggregate, extract a 1D vector out of
1302 // the aggregate. This can be done using llvm.extractvalue. The
1303 // destination is now guaranteed to be a 1D vector, to which we are
1304 // inserting.
1305 // - Do the insertion into the 1D destination vector, and make the result
1306 // the new source nested aggregate. This can be done using
1307 // llvm.insertelement.
1308 // * Insert the source nested aggregate into the destination nested
1309 // aggregate.
1310
1311 // Determine if we need to extract/insert a 1D vector out of the aggregate.
1312 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1313 // Determine if we need to insert a scalar into the 1D vector.
1314 bool insertIntoInnermostDim =
1315 static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
1316
1317 ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1318 positionVec.begin(),
1319 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1320 OpFoldResult positionOfScalarWithin1DVector;
1321 if (destVectorType.getRank() == 0) {
1322 // Since the LLVM type converter converts 0D vectors to 1D vectors, we
1323 // need to create a 0 here as the position into the 1D vector.
1324 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1325 positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1326 } else if (insertIntoInnermostDim) {
1327 positionOfScalarWithin1DVector = positionVec.back();
1328 }
1329
1330 // We are going to mutate this 1D vector until it is either the final
1331 // result (in the non-aggregate case) or the value that needs to be
1332 // inserted into the aggregate result.
1333 Value sourceAggregate = adaptor.getValueToStore();
1334 if (insertIntoInnermostDim) {
1335 // Scalar-into-1D-vector case, so we know we will have to create a
1336 // InsertElementOp. The question is into what destination.
1337 if (isNestedAggregate) {
1338 // Aggregate case: the destination for the InsertElementOp needs to be
1339 // extracted from the aggregate.
1340 if (!llvm::all_of(Range&: positionOf1DVectorWithinAggregate,
1341 P: llvm::IsaPred<Attribute>)) {
1342 // llvm.extractvalue does not support dynamic dimensions.
1343 return failure();
1344 }
1345 sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
1346 loc, adaptor.getDest(),
1347 getAsIntegers(positionOf1DVectorWithinAggregate));
1348 } else {
1349 // No-aggregate case. The destination for the InsertElementOp is just
1350 // the insertOp's destination.
1351 sourceAggregate = adaptor.getDest();
1352 }
1353 // Insert the scalar into the 1D vector.
1354 sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
1355 loc, sourceAggregate.getType(), sourceAggregate,
1356 adaptor.getValueToStore(),
1357 getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
1358 }
1359
1360 Value result = sourceAggregate;
1361 if (isNestedAggregate) {
1362 result = rewriter.create<LLVM::InsertValueOp>(
1363 loc, adaptor.getDest(), sourceAggregate,
1364 getAsIntegers(positionOf1DVectorWithinAggregate));
1365 }
1366
1367 rewriter.replaceOp(insertOp, result);
1368 return success();
1369 }
1370};
1371
1372/// Lower vector.scalable.insert ops to LLVM vector.insert
1373struct VectorScalableInsertOpLowering
1374 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1375 using ConvertOpToLLVMPattern<
1376 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1377
1378 LogicalResult
1379 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1380 ConversionPatternRewriter &rewriter) const override {
1381 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1382 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1383 return success();
1384 }
1385};
1386
1387/// Lower vector.scalable.extract ops to LLVM vector.extract
1388struct VectorScalableExtractOpLowering
1389 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1390 using ConvertOpToLLVMPattern<
1391 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1392
1393 LogicalResult
1394 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1395 ConversionPatternRewriter &rewriter) const override {
1396 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1397 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1398 adaptor.getSource(), adaptor.getPos());
1399 return success();
1400 }
1401};
1402
1403/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1404///
1405/// Example:
1406/// ```
1407/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1408/// ```
1409/// is rewritten into:
1410/// ```
1411/// %r = splat %f0: vector<2x4xf32>
1412/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1413/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1414/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1415/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1416/// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1417/// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1418/// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1419/// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1420/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1421/// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1422/// // %r3 holds the final value.
1423/// ```
1424class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1425public:
1426 using OpRewritePattern<FMAOp>::OpRewritePattern;
1427
1428 void initialize() {
1429 // This pattern recursively unpacks one dimension at a time. The recursion
1430 // bounded as the rank is strictly decreasing.
1431 setHasBoundedRewriteRecursion();
1432 }
1433
1434 LogicalResult matchAndRewrite(FMAOp op,
1435 PatternRewriter &rewriter) const override {
1436 auto vType = op.getVectorType();
1437 if (vType.getRank() < 2)
1438 return failure();
1439
1440 auto loc = op.getLoc();
1441 auto elemType = vType.getElementType();
1442 Value zero = rewriter.create<arith::ConstantOp>(
1443 loc, elemType, rewriter.getZeroAttr(elemType));
1444 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1445 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1446 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1447 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1448 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1449 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1450 desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1451 }
1452 rewriter.replaceOp(op, desc);
1453 return success();
1454 }
1455};
1456
1457/// Returns the strides if the memory underlying `memRefType` has a contiguous
1458/// static layout.
1459static std::optional<SmallVector<int64_t, 4>>
1460computeContiguousStrides(MemRefType memRefType) {
1461 int64_t offset;
1462 SmallVector<int64_t, 4> strides;
1463 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1464 return std::nullopt;
1465 if (!strides.empty() && strides.back() != 1)
1466 return std::nullopt;
1467 // If no layout or identity layout, this is contiguous by definition.
1468 if (memRefType.getLayout().isIdentity())
1469 return strides;
1470
1471 // Otherwise, we must determine contiguity form shapes. This can only ever
1472 // work in static cases because MemRefType is underspecified to represent
1473 // contiguous dynamic shapes in other ways than with just empty/identity
1474 // layout.
1475 auto sizes = memRefType.getShape();
1476 for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1477 if (ShapedType::isDynamic(sizes[index + 1]) ||
1478 ShapedType::isDynamic(strides[index]) ||
1479 ShapedType::isDynamic(strides[index + 1]))
1480 return std::nullopt;
1481 if (strides[index] != strides[index + 1] * sizes[index + 1])
1482 return std::nullopt;
1483 }
1484 return strides;
1485}
1486
1487class VectorTypeCastOpConversion
1488 : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1489public:
1490 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1491
1492 LogicalResult
1493 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1494 ConversionPatternRewriter &rewriter) const override {
1495 auto loc = castOp->getLoc();
1496 MemRefType sourceMemRefType =
1497 cast<MemRefType>(castOp.getOperand().getType());
1498 MemRefType targetMemRefType = castOp.getType();
1499
1500 // Only static shape casts supported atm.
1501 if (!sourceMemRefType.hasStaticShape() ||
1502 !targetMemRefType.hasStaticShape())
1503 return failure();
1504
1505 auto llvmSourceDescriptorTy =
1506 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1507 if (!llvmSourceDescriptorTy)
1508 return failure();
1509 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1510
1511 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1512 typeConverter->convertType(targetMemRefType));
1513 if (!llvmTargetDescriptorTy)
1514 return failure();
1515
1516 // Only contiguous source buffers supported atm.
1517 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1518 if (!sourceStrides)
1519 return failure();
1520 auto targetStrides = computeContiguousStrides(targetMemRefType);
1521 if (!targetStrides)
1522 return failure();
1523 // Only support static strides for now, regardless of contiguity.
1524 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1525 return failure();
1526
1527 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1528
1529 // Create descriptor.
1530 auto desc = MemRefDescriptor::poison(builder&: rewriter, loc: loc, descriptorType: llvmTargetDescriptorTy);
1531 // Set allocated ptr.
1532 Value allocated = sourceMemRef.allocatedPtr(builder&: rewriter, loc: loc);
1533 desc.setAllocatedPtr(rewriter, loc, allocated);
1534
1535 // Set aligned ptr.
1536 Value ptr = sourceMemRef.alignedPtr(builder&: rewriter, loc: loc);
1537 desc.setAlignedPtr(rewriter, loc, ptr);
1538 // Fill offset 0.
1539 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1540 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1541 desc.setOffset(rewriter, loc, zero);
1542
1543 // Fill size and stride descriptors in memref.
1544 for (const auto &indexedSize :
1545 llvm::enumerate(targetMemRefType.getShape())) {
1546 int64_t index = indexedSize.index();
1547 auto sizeAttr =
1548 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1549 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1550 desc.setSize(rewriter, loc, index, size);
1551 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1552 (*targetStrides)[index]);
1553 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1554 desc.setStride(rewriter, loc, index, stride);
1555 }
1556
1557 rewriter.replaceOp(castOp, {desc});
1558 return success();
1559 }
1560};
1561
1562/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1563/// Non-scalable versions of this operation are handled in Vector Transforms.
1564class VectorCreateMaskOpConversion
1565 : public OpConversionPattern<vector::CreateMaskOp> {
1566public:
1567 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1568 bool enableIndexOpt)
1569 : OpConversionPattern<vector::CreateMaskOp>(context),
1570 force32BitVectorIndices(enableIndexOpt) {}
1571
1572 LogicalResult
1573 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1574 ConversionPatternRewriter &rewriter) const override {
1575 auto dstType = op.getType();
1576 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1577 return failure();
1578 IntegerType idxType =
1579 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1580 auto loc = op->getLoc();
1581 Value indices = rewriter.create<LLVM::StepVectorOp>(
1582 loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1583 /*isScalable=*/true));
1584 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1585 adaptor.getOperands()[0]);
1586 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1587 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1588 indices, bounds);
1589 rewriter.replaceOp(op, comp);
1590 return success();
1591 }
1592
1593private:
1594 const bool force32BitVectorIndices;
1595};
1596
1597class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1598public:
1599 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1600
1601 // Lowering implementation that relies on a small runtime support library,
1602 // which only needs to provide a few printing methods (single value for all
1603 // data types, opening/closing bracket, comma, newline). The lowering splits
1604 // the vector into elementary printing operations. The advantage of this
1605 // approach is that the library can remain unaware of all low-level
1606 // implementation details of vectors while still supporting output of any
1607 // shaped and dimensioned vector.
1608 //
1609 // Note: This lowering only handles scalars, n-D vectors are broken into
1610 // printing scalars in loops in VectorToSCF.
1611 //
1612 // TODO: rely solely on libc in future? something else?
1613 //
1614 LogicalResult
1615 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1616 ConversionPatternRewriter &rewriter) const override {
1617 auto parent = printOp->getParentOfType<ModuleOp>();
1618 if (!parent)
1619 return failure();
1620
1621 auto loc = printOp->getLoc();
1622
1623 if (auto value = adaptor.getSource()) {
1624 Type printType = printOp.getPrintType();
1625 if (isa<VectorType>(Val: printType)) {
1626 // Vectors should be broken into elementary print ops in VectorToSCF.
1627 return failure();
1628 }
1629 if (failed(emitScalarPrint(rewriter, parent: parent, loc: loc, printType, value: value)))
1630 return failure();
1631 }
1632
1633 auto punct = printOp.getPunctuation();
1634 if (auto stringLiteral = printOp.getStringLiteral()) {
1635 auto createResult =
1636 LLVM::createPrintStrCall(builder&: rewriter, loc: loc, moduleOp: parent, symbolName: "vector_print_str",
1637 string: *stringLiteral, typeConverter: *getTypeConverter(),
1638 /*addNewline=*/false);
1639 if (createResult.failed())
1640 return failure();
1641
1642 } else if (punct != PrintPunctuation::NoPunctuation) {
1643 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1644 switch (punct) {
1645 case PrintPunctuation::Close:
1646 return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
1647 case PrintPunctuation::Open:
1648 return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
1649 case PrintPunctuation::Comma:
1650 return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
1651 case PrintPunctuation::NewLine:
1652 return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
1653 default:
1654 llvm_unreachable("unexpected punctuation");
1655 }
1656 }();
1657 if (failed(op))
1658 return failure();
1659 emitCall(rewriter, loc: printOp->getLoc(), ref: op.value());
1660 }
1661
1662 rewriter.eraseOp(op: printOp);
1663 return success();
1664 }
1665
1666private:
1667 enum class PrintConversion {
1668 // clang-format off
1669 None,
1670 ZeroExt64,
1671 SignExt64,
1672 Bitcast16
1673 // clang-format on
1674 };
1675
1676 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1677 ModuleOp parent, Location loc, Type printType,
1678 Value value) const {
1679 if (typeConverter->convertType(printType) == nullptr)
1680 return failure();
1681
1682 // Make sure element type has runtime support.
1683 PrintConversion conversion = PrintConversion::None;
1684 FailureOr<Operation *> printer;
1685 if (printType.isF32()) {
1686 printer = LLVM::lookupOrCreatePrintF32Fn(b&: rewriter, moduleOp: parent);
1687 } else if (printType.isF64()) {
1688 printer = LLVM::lookupOrCreatePrintF64Fn(b&: rewriter, moduleOp: parent);
1689 } else if (printType.isF16()) {
1690 conversion = PrintConversion::Bitcast16; // bits!
1691 printer = LLVM::lookupOrCreatePrintF16Fn(b&: rewriter, moduleOp: parent);
1692 } else if (printType.isBF16()) {
1693 conversion = PrintConversion::Bitcast16; // bits!
1694 printer = LLVM::lookupOrCreatePrintBF16Fn(b&: rewriter, moduleOp: parent);
1695 } else if (printType.isIndex()) {
1696 printer = LLVM::lookupOrCreatePrintU64Fn(b&: rewriter, moduleOp: parent);
1697 } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1698 // Integers need a zero or sign extension on the operand
1699 // (depending on the source type) as well as a signed or
1700 // unsigned print method. Up to 64-bit is supported.
1701 unsigned width = intTy.getWidth();
1702 if (intTy.isUnsigned()) {
1703 if (width <= 64) {
1704 if (width < 64)
1705 conversion = PrintConversion::ZeroExt64;
1706 printer = LLVM::lookupOrCreatePrintU64Fn(b&: rewriter, moduleOp: parent);
1707 } else {
1708 return failure();
1709 }
1710 } else {
1711 assert(intTy.isSignless() || intTy.isSigned());
1712 if (width <= 64) {
1713 // Note that we *always* zero extend booleans (1-bit integers),
1714 // so that true/false is printed as 1/0 rather than -1/0.
1715 if (width == 1)
1716 conversion = PrintConversion::ZeroExt64;
1717 else if (width < 64)
1718 conversion = PrintConversion::SignExt64;
1719 printer = LLVM::lookupOrCreatePrintI64Fn(b&: rewriter, moduleOp: parent);
1720 } else {
1721 return failure();
1722 }
1723 }
1724 } else {
1725 return failure();
1726 }
1727 if (failed(Result: printer))
1728 return failure();
1729
1730 switch (conversion) {
1731 case PrintConversion::ZeroExt64:
1732 value = rewriter.create<arith::ExtUIOp>(
1733 loc, IntegerType::get(rewriter.getContext(), 64), value);
1734 break;
1735 case PrintConversion::SignExt64:
1736 value = rewriter.create<arith::ExtSIOp>(
1737 loc, IntegerType::get(rewriter.getContext(), 64), value);
1738 break;
1739 case PrintConversion::Bitcast16:
1740 value = rewriter.create<LLVM::BitcastOp>(
1741 loc, IntegerType::get(rewriter.getContext(), 16), value);
1742 break;
1743 case PrintConversion::None:
1744 break;
1745 }
1746 emitCall(rewriter, loc, ref: printer.value(), params: value);
1747 return success();
1748 }
1749
1750 // Helper to emit a call.
1751 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1752 Operation *ref, ValueRange params = ValueRange()) {
1753 rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1754 params);
1755 }
1756};
1757
1758/// The Splat operation is lowered to an insertelement + a shufflevector
1759/// operation. Splat to only 0-d and 1-d vector result types are lowered.
1760struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1761 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1762
1763 LogicalResult
1764 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1765 ConversionPatternRewriter &rewriter) const override {
1766 VectorType resultType = cast<VectorType>(splatOp.getType());
1767 if (resultType.getRank() > 1)
1768 return failure();
1769
1770 // First insert it into a poison vector so we can shuffle it.
1771 auto vectorType = typeConverter->convertType(splatOp.getType());
1772 Value poison =
1773 rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
1774 auto zero = rewriter.create<LLVM::ConstantOp>(
1775 splatOp.getLoc(),
1776 typeConverter->convertType(rewriter.getIntegerType(32)),
1777 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1778
1779 // For 0-d vector, we simply do `insertelement`.
1780 if (resultType.getRank() == 0) {
1781 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1782 splatOp, vectorType, poison, adaptor.getInput(), zero);
1783 return success();
1784 }
1785
1786 // For 1-d vector, we additionally do a `vectorshuffle`.
1787 auto v = rewriter.create<LLVM::InsertElementOp>(
1788 splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
1789
1790 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1791 SmallVector<int32_t> zeroValues(width, 0);
1792
1793 // Shuffle the value across the desired number of elements.
1794 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
1795 zeroValues);
1796 return success();
1797 }
1798};
1799
1800/// The Splat operation is lowered to an insertelement + a shufflevector
1801/// operation. Splat to only 2+-d vector result types are lowered by the
1802/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1803struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1804 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1805
1806 LogicalResult
1807 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1808 ConversionPatternRewriter &rewriter) const override {
1809 VectorType resultType = splatOp.getType();
1810 if (resultType.getRank() <= 1)
1811 return failure();
1812
1813 // First insert it into an undef vector so we can shuffle it.
1814 auto loc = splatOp.getLoc();
1815 auto vectorTypeInfo =
1816 LLVM::detail::extractNDVectorTypeInfo(vectorType: resultType, converter: *getTypeConverter());
1817 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1818 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1819 if (!llvmNDVectorTy || !llvm1DVectorTy)
1820 return failure();
1821
1822 // Construct returned value.
1823 Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
1824
1825 // Construct a 1-D vector with the splatted value that we insert in all the
1826 // places within the returned descriptor.
1827 Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
1828 auto zero = rewriter.create<LLVM::ConstantOp>(
1829 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1830 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1831 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1832 adaptor.getInput(), zero);
1833
1834 // Shuffle the value across the desired number of elements.
1835 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1836 SmallVector<int32_t> zeroValues(width, 0);
1837 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1838
1839 // Iterate of linear index, convert to coords space and insert splatted 1-D
1840 // vector in each position.
1841 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1842 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1843 });
1844 rewriter.replaceOp(splatOp, desc);
1845 return success();
1846 }
1847};
1848
1849/// Conversion pattern for a `vector.interleave`.
1850/// This supports fixed-sized vectors and scalable vectors.
1851struct VectorInterleaveOpLowering
1852 : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1853 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1854
1855 LogicalResult
1856 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1857 ConversionPatternRewriter &rewriter) const override {
1858 VectorType resultType = interleaveOp.getResultVectorType();
1859 // n-D interleaves should have been lowered already.
1860 if (resultType.getRank() != 1)
1861 return rewriter.notifyMatchFailure(interleaveOp,
1862 "InterleaveOp not rank 1");
1863 // If the result is rank 1, then this directly maps to LLVM.
1864 if (resultType.isScalable()) {
1865 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1866 interleaveOp, typeConverter->convertType(resultType),
1867 adaptor.getLhs(), adaptor.getRhs());
1868 return success();
1869 }
1870 // Lower fixed-size interleaves to a shufflevector. While the
1871 // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1872 // langref still recommends fixed-vectors use shufflevector, see:
1873 // https://llvm.org/docs/LangRef.html#id876.
1874 int64_t resultVectorSize = resultType.getNumElements();
1875 SmallVector<int32_t> interleaveShuffleMask;
1876 interleaveShuffleMask.reserve(N: resultVectorSize);
1877 for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1878 interleaveShuffleMask.push_back(Elt: i);
1879 interleaveShuffleMask.push_back(Elt: (resultVectorSize / 2) + i);
1880 }
1881 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1882 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1883 interleaveShuffleMask);
1884 return success();
1885 }
1886};
1887
1888/// Conversion pattern for a `vector.deinterleave`.
1889/// This supports fixed-sized vectors and scalable vectors.
1890struct VectorDeinterleaveOpLowering
1891 : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1892 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1893
1894 LogicalResult
1895 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1896 ConversionPatternRewriter &rewriter) const override {
1897 VectorType resultType = deinterleaveOp.getResultVectorType();
1898 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1899 auto loc = deinterleaveOp.getLoc();
1900
1901 // Note: n-D deinterleave operations should be lowered to the 1-D before
1902 // converting to LLVM.
1903 if (resultType.getRank() != 1)
1904 return rewriter.notifyMatchFailure(deinterleaveOp,
1905 "DeinterleaveOp not rank 1");
1906
1907 if (resultType.isScalable()) {
1908 auto llvmTypeConverter = this->getTypeConverter();
1909 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1910 auto packedOpResults =
1911 llvmTypeConverter->packOperationResults(deinterleaveResults);
1912 auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
1913 loc, packedOpResults, adaptor.getSource());
1914
1915 auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
1916 loc, intrinsic->getResult(0), 0);
1917 auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
1918 loc, intrinsic->getResult(0), 1);
1919
1920 rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1921 return success();
1922 }
1923 // Lower fixed-size deinterleave to two shufflevectors. While the
1924 // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1925 // langref still recommends fixed-vectors use shufflevector, see:
1926 // https://llvm.org/docs/LangRef.html#id889.
1927 int64_t resultVectorSize = resultType.getNumElements();
1928 SmallVector<int32_t> evenShuffleMask;
1929 SmallVector<int32_t> oddShuffleMask;
1930
1931 evenShuffleMask.reserve(N: resultVectorSize);
1932 oddShuffleMask.reserve(N: resultVectorSize);
1933
1934 for (int i = 0; i < sourceType.getNumElements(); ++i) {
1935 if (i % 2 == 0)
1936 evenShuffleMask.push_back(Elt: i);
1937 else
1938 oddShuffleMask.push_back(Elt: i);
1939 }
1940
1941 auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
1942 auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1943 loc, adaptor.getSource(), poison, evenShuffleMask);
1944 auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1945 loc, adaptor.getSource(), poison, oddShuffleMask);
1946
1947 rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1948 return success();
1949 }
1950};
1951
1952/// Conversion pattern for a `vector.from_elements`.
1953struct VectorFromElementsLowering
1954 : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1955 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1956
1957 LogicalResult
1958 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1959 ConversionPatternRewriter &rewriter) const override {
1960 Location loc = fromElementsOp.getLoc();
1961 VectorType vectorType = fromElementsOp.getType();
1962 // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1963 // Such ops should be handled in the same way as vector.insert.
1964 if (vectorType.getRank() > 1)
1965 return rewriter.notifyMatchFailure(fromElementsOp,
1966 "rank > 1 vectors are not supported");
1967 Type llvmType = typeConverter->convertType(vectorType);
1968 Value result = rewriter.create<LLVM::PoisonOp>(loc, llvmType);
1969 for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
1970 result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
1971 rewriter.replaceOp(fromElementsOp, result);
1972 return success();
1973 }
1974};
1975
1976/// Conversion pattern for vector.step.
1977struct VectorScalableStepOpLowering
1978 : public ConvertOpToLLVMPattern<vector::StepOp> {
1979 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1980
1981 LogicalResult
1982 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1983 ConversionPatternRewriter &rewriter) const override {
1984 auto resultType = cast<VectorType>(stepOp.getType());
1985 if (!resultType.isScalable()) {
1986 return failure();
1987 }
1988 Type llvmType = typeConverter->convertType(stepOp.getType());
1989 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1990 return success();
1991 }
1992};
1993
1994} // namespace
1995
1996void mlir::vector::populateVectorRankReducingFMAPattern(
1997 RewritePatternSet &patterns) {
1998 patterns.add<VectorFMAOpNDRewritePattern>(arg: patterns.getContext());
1999}
2000
2001/// Populate the given list with patterns that convert from Vector to LLVM.
2002void mlir::populateVectorToLLVMConversionPatterns(
2003 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2004 bool reassociateFPReductions, bool force32BitVectorIndices,
2005 bool useVectorAlignment) {
2006 // This function populates only ConversionPatterns, not RewritePatterns.
2007 MLIRContext *ctx = converter.getDialect()->getContext();
2008 patterns.add<VectorReductionOpConversion>(arg: converter, args&: reassociateFPReductions);
2009 patterns.add<VectorCreateMaskOpConversion>(arg&: ctx, args&: force32BitVectorIndices);
2010 patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2011 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2012 VectorLoadStoreConversion<vector::StoreOp>,
2013 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2014 VectorGatherOpConversion, VectorScatterOpConversion>(
2015 converter, useVectorAlignment);
2016 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2017 VectorExtractElementOpConversion, VectorExtractOpConversion,
2018 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
2019 VectorInsertOpConversion, VectorPrintOpConversion,
2020 VectorTypeCastOpConversion, VectorScaleOpConversion,
2021 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2022 VectorSplatOpLowering, VectorSplatNdOpLowering,
2023 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2024 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2025 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2026 VectorScalableStepOpLowering>(converter);
2027}
2028
2029void mlir::populateVectorToLLVMMatrixConversionPatterns(
2030 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
2031 patterns.add<VectorMatmulOpConversion>(arg: converter);
2032 patterns.add<VectorFlatTransposeOpConversion>(arg: converter);
2033}
2034
2035namespace {
2036struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2037 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
2038 void loadDependentDialects(MLIRContext *context) const final {
2039 context->loadDialect<LLVM::LLVMDialect>();
2040 }
2041
2042 /// Hook for derived dialect interface to provide conversion patterns
2043 /// and mark dialect legal for the conversion target.
2044 void populateConvertToLLVMConversionPatterns(
2045 ConversionTarget &target, LLVMTypeConverter &typeConverter,
2046 RewritePatternSet &patterns) const final {
2047 populateVectorToLLVMConversionPatterns(converter: typeConverter, patterns);
2048 }
2049};
2050} // namespace
2051
2052void mlir::vector::registerConvertVectorToLLVMInterface(
2053 DialectRegistry &registry) {
2054 registry.addExtension(extensionFn: +[](MLIRContext *ctx, vector::VectorDialect *dialect) {
2055 dialect->addInterfaces<VectorToLLVMDialectInterface>();
2056 });
2057}
2058

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp