1//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
14#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Arith/Utils/Utils.h"
18#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/Vector/IR/VectorOps.h"
22#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
25#include "mlir/IR/BuiltinAttributes.h"
26#include "mlir/IR/BuiltinTypeInterfaces.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/TypeUtilities.h"
29#include "mlir/Target/LLVMIR/TypeToLLVM.h"
30#include "mlir/Transforms/DialectConversion.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/Support/Casting.h"
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::vector;
37
38// Helper that picks the proper sequence for inserting.
39static Value insertOne(ConversionPatternRewriter &rewriter,
40 const LLVMTypeConverter &typeConverter, Location loc,
41 Value val1, Value val2, Type llvmType, int64_t rank,
42 int64_t pos) {
43 assert(rank > 0 && "0-D vector corner case should have been handled already");
44 if (rank == 1) {
45 auto idxType = rewriter.getIndexType();
46 auto constant = rewriter.create<LLVM::ConstantOp>(
47 location: loc, args: typeConverter.convertType(t: idxType),
48 args: rewriter.getIntegerAttr(type: idxType, value: pos));
49 return rewriter.create<LLVM::InsertElementOp>(location: loc, args&: llvmType, args&: val1, args&: val2,
50 args&: constant);
51 }
52 return rewriter.create<LLVM::InsertValueOp>(location: loc, args&: val1, args&: val2, args&: pos);
53}
54
55// Helper that picks the proper sequence for extracting.
56static Value extractOne(ConversionPatternRewriter &rewriter,
57 const LLVMTypeConverter &typeConverter, Location loc,
58 Value val, Type llvmType, int64_t rank, int64_t pos) {
59 if (rank <= 1) {
60 auto idxType = rewriter.getIndexType();
61 auto constant = rewriter.create<LLVM::ConstantOp>(
62 location: loc, args: typeConverter.convertType(t: idxType),
63 args: rewriter.getIntegerAttr(type: idxType, value: pos));
64 return rewriter.create<LLVM::ExtractElementOp>(location: loc, args&: llvmType, args&: val,
65 args&: constant);
66 }
67 return rewriter.create<LLVM::ExtractValueOp>(location: loc, args&: val, args&: pos);
68}
69
70// Helper that returns data layout alignment of a vector.
71LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
72 VectorType vectorType, unsigned &align) {
73 Type convertedVectorTy = typeConverter.convertType(t: vectorType);
74 if (!convertedVectorTy)
75 return failure();
76
77 llvm::LLVMContext llvmContext;
78 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
79 .getPreferredAlignment(type: convertedVectorTy,
80 layout: typeConverter.getDataLayout());
81
82 return success();
83}
84
85// Helper that returns data layout alignment of a memref.
86LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
87 MemRefType memrefType, unsigned &align) {
88 Type elementTy = typeConverter.convertType(t: memrefType.getElementType());
89 if (!elementTy)
90 return failure();
91
92 // TODO: this should use the MLIR data layout when it becomes available and
93 // stop depending on translation.
94 llvm::LLVMContext llvmContext;
95 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
96 .getPreferredAlignment(type: elementTy, layout: typeConverter.getDataLayout());
97 return success();
98}
99
100// Helper to resolve the alignment for vector load/store, gather and scatter
101// ops. If useVectorAlignment is true, get the preferred alignment for the
102// vector type in the operation. This option is used for hardware backends with
103// vectorization. Otherwise, use the preferred alignment of the element type of
104// the memref. Note that if you choose to use vector alignment, the shape of the
105// vector type must be resolved before the ConvertVectorToLLVM pass is run.
106LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
107 VectorType vectorType,
108 MemRefType memrefType, unsigned &align,
109 bool useVectorAlignment) {
110 if (useVectorAlignment) {
111 if (failed(Result: getVectorAlignment(typeConverter, vectorType, align))) {
112 return failure();
113 }
114 } else {
115 if (failed(Result: getMemRefAlignment(typeConverter, memrefType, align))) {
116 return failure();
117 }
118 }
119 return success();
120}
121
122// Check if the last stride is non-unit and has a valid memory space.
123static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
124 const LLVMTypeConverter &converter) {
125 if (!memRefType.isLastDimUnitStride())
126 return failure();
127 if (failed(Result: converter.getMemRefAddressSpace(type: memRefType)))
128 return failure();
129 return success();
130}
131
132// Add an index vector component to a base pointer.
133static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
134 const LLVMTypeConverter &typeConverter,
135 MemRefType memRefType, Value llvmMemref, Value base,
136 Value index, VectorType vectorType) {
137 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
138 "unsupported memref type");
139 assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
140 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
141 auto ptrsType =
142 LLVM::getVectorType(elementType: pType, numElements: vectorType.getDimSize(idx: 0),
143 /*isScalable=*/vectorType.getScalableDims()[0]);
144 return rewriter.create<LLVM::GEPOp>(
145 location: loc, args&: ptrsType, args: typeConverter.convertType(t: memRefType.getElementType()),
146 args&: base, args&: index);
147}
148
149/// Convert `foldResult` into a Value. Integer attribute is converted to
150/// an LLVM constant op.
151static Value getAsLLVMValue(OpBuilder &builder, Location loc,
152 OpFoldResult foldResult) {
153 if (auto attr = dyn_cast<Attribute>(Val&: foldResult)) {
154 auto intAttr = cast<IntegerAttr>(Val&: attr);
155 return builder.create<LLVM::ConstantOp>(location: loc, args&: intAttr).getResult();
156 }
157
158 return cast<Value>(Val&: foldResult);
159}
160
161namespace {
162
/// Trivial Vector to LLVM conversions: vector.vscale maps one-to-one to the
/// llvm.vscale intrinsic.
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
166
167/// Conversion pattern for a vector.bitcast.
168class VectorBitCastOpConversion
169 : public ConvertOpToLLVMPattern<vector::BitCastOp> {
170public:
171 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
172
173 LogicalResult
174 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
175 ConversionPatternRewriter &rewriter) const override {
176 // Only 0-D and 1-D vectors can be lowered to LLVM.
177 VectorType resultTy = bitCastOp.getResultVectorType();
178 if (resultTy.getRank() > 1)
179 return failure();
180 Type newResultTy = typeConverter->convertType(t: resultTy);
181 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op: bitCastOp, args&: newResultTy,
182 args: adaptor.getOperands()[0]);
183 return success();
184 }
185};
186
187/// Conversion pattern for a vector.matrix_multiply.
188/// This is lowered directly to the proper llvm.intr.matrix.multiply.
189class VectorMatmulOpConversion
190 : public ConvertOpToLLVMPattern<vector::MatmulOp> {
191public:
192 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
193
194 LogicalResult
195 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
196 ConversionPatternRewriter &rewriter) const override {
197 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
198 op: matmulOp, args: typeConverter->convertType(t: matmulOp.getRes().getType()),
199 args: adaptor.getLhs(), args: adaptor.getRhs(), args: matmulOp.getLhsRows(),
200 args: matmulOp.getLhsColumns(), args: matmulOp.getRhsColumns());
201 return success();
202 }
203};
204
205/// Conversion pattern for a vector.flat_transpose.
206/// This is lowered directly to the proper llvm.intr.matrix.transpose.
207class VectorFlatTransposeOpConversion
208 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
209public:
210 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
211
212 LogicalResult
213 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
214 ConversionPatternRewriter &rewriter) const override {
215 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
216 op: transOp, args: typeConverter->convertType(t: transOp.getRes().getType()),
217 args: adaptor.getMatrix(), args: transOp.getRows(), args: transOp.getColumns());
218 return success();
219 }
220};
221
222/// Overloaded utility that replaces a vector.load, vector.store,
223/// vector.maskedload and vector.maskedstore with their respective LLVM
224/// couterparts.
225static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
226 vector::LoadOpAdaptor adaptor,
227 VectorType vectorTy, Value ptr, unsigned align,
228 ConversionPatternRewriter &rewriter) {
229 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op: loadOp, args&: vectorTy, args&: ptr, args&: align,
230 /*volatile_=*/args: false,
231 args: loadOp.getNontemporal());
232}
233
234static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
235 vector::MaskedLoadOpAdaptor adaptor,
236 VectorType vectorTy, Value ptr, unsigned align,
237 ConversionPatternRewriter &rewriter) {
238 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
239 op: loadOp, args&: vectorTy, args&: ptr, args: adaptor.getMask(), args: adaptor.getPassThru(), args&: align);
240}
241
242static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
243 vector::StoreOpAdaptor adaptor,
244 VectorType vectorTy, Value ptr, unsigned align,
245 ConversionPatternRewriter &rewriter) {
246 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op: storeOp, args: adaptor.getValueToStore(),
247 args&: ptr, args&: align, /*volatile_=*/args: false,
248 args: storeOp.getNontemporal());
249}
250
251static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
252 vector::MaskedStoreOpAdaptor adaptor,
253 VectorType vectorTy, Value ptr, unsigned align,
254 ConversionPatternRewriter &rewriter) {
255 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
256 op: storeOp, args: adaptor.getValueToStore(), args&: ptr, args: adaptor.getMask(), args&: align);
257}
258
259/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
260/// vector.maskedstore.
261template <class LoadOrStoreOp>
262class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
263public:
264 explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
265 bool useVectorAlign)
266 : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
267 useVectorAlignment(useVectorAlign) {}
268 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
269
270 LogicalResult
271 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
272 typename LoadOrStoreOp::Adaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
274 // Only 1-D vectors can be lowered to LLVM.
275 VectorType vectorTy = loadOrStoreOp.getVectorType();
276 if (vectorTy.getRank() > 1)
277 return failure();
278
279 auto loc = loadOrStoreOp->getLoc();
280 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
281
282 // Resolve alignment.
283 unsigned align;
284 if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
285 memRefTy, align, useVectorAlignment)))
286 return rewriter.notifyMatchFailure(loadOrStoreOp,
287 "could not resolve alignment");
288
289 // Resolve address.
290 auto vtype = cast<VectorType>(
291 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
292 Value dataPtr = this->getStridedElementPtr(
293 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
294 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
295 rewriter);
296 return success();
297 }
298
299private:
300 // If true, use the preferred alignment of the vector type.
301 // If false, use the preferred alignment of the element type
302 // of the memref. This flag is intended for use with hardware
303 // backends that require alignment of vector operations.
304 const bool useVectorAlignment;
305};
306
307/// Conversion pattern for a vector.gather.
308class VectorGatherOpConversion
309 : public ConvertOpToLLVMPattern<vector::GatherOp> {
310public:
311 explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
312 bool useVectorAlign)
313 : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
314 useVectorAlignment(useVectorAlign) {}
315 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
316
317 LogicalResult
318 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
319 ConversionPatternRewriter &rewriter) const override {
320 Location loc = gather->getLoc();
321 MemRefType memRefType = dyn_cast<MemRefType>(Val: gather.getBaseType());
322 assert(memRefType && "The base should be bufferized");
323
324 if (failed(Result: isMemRefTypeSupported(memRefType, converter: *this->getTypeConverter())))
325 return rewriter.notifyMatchFailure(arg&: gather, msg: "memref type not supported");
326
327 VectorType vType = gather.getVectorType();
328 if (vType.getRank() > 1) {
329 return rewriter.notifyMatchFailure(
330 arg&: gather, msg: "only 1-D vectors can be lowered to LLVM");
331 }
332
333 // Resolve alignment.
334 unsigned align;
335 if (failed(Result: getVectorToLLVMAlignment(typeConverter: *this->getTypeConverter(), vectorType: vType,
336 memrefType: memRefType, align, useVectorAlignment)))
337 return rewriter.notifyMatchFailure(arg&: gather, msg: "could not resolve alignment");
338
339 // Resolve address.
340 Value ptr = getStridedElementPtr(rewriter, loc, type: memRefType,
341 memRefDesc: adaptor.getBase(), indices: adaptor.getIndices());
342 Value base = adaptor.getBase();
343 Value ptrs =
344 getIndexedPtrs(rewriter, loc, typeConverter: *this->getTypeConverter(), memRefType,
345 llvmMemref: base, base: ptr, index: adaptor.getIndexVec(), vectorType: vType);
346
347 // Replace with the gather intrinsic.
348 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
349 op: gather, args: typeConverter->convertType(t: vType), args&: ptrs, args: adaptor.getMask(),
350 args: adaptor.getPassThru(), args: rewriter.getI32IntegerAttr(value: align));
351 return success();
352 }
353
354private:
355 // If true, use the preferred alignment of the vector type.
356 // If false, use the preferred alignment of the element type
357 // of the memref. This flag is intended for use with hardware
358 // backends that require alignment of vector operations.
359 const bool useVectorAlignment;
360};
361
362/// Conversion pattern for a vector.scatter.
363class VectorScatterOpConversion
364 : public ConvertOpToLLVMPattern<vector::ScatterOp> {
365public:
366 explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
367 bool useVectorAlign)
368 : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
369 useVectorAlignment(useVectorAlign) {}
370
371 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
372
373 LogicalResult
374 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
375 ConversionPatternRewriter &rewriter) const override {
376 auto loc = scatter->getLoc();
377 MemRefType memRefType = scatter.getMemRefType();
378
379 if (failed(Result: isMemRefTypeSupported(memRefType, converter: *this->getTypeConverter())))
380 return rewriter.notifyMatchFailure(arg&: scatter, msg: "memref type not supported");
381
382 VectorType vType = scatter.getVectorType();
383 if (vType.getRank() > 1) {
384 return rewriter.notifyMatchFailure(
385 arg&: scatter, msg: "only 1-D vectors can be lowered to LLVM");
386 }
387
388 // Resolve alignment.
389 unsigned align;
390 if (failed(Result: getVectorToLLVMAlignment(typeConverter: *this->getTypeConverter(), vectorType: vType,
391 memrefType: memRefType, align, useVectorAlignment)))
392 return rewriter.notifyMatchFailure(arg&: scatter,
393 msg: "could not resolve alignment");
394
395 // Resolve address.
396 Value ptr = getStridedElementPtr(rewriter, loc, type: memRefType,
397 memRefDesc: adaptor.getBase(), indices: adaptor.getIndices());
398 Value ptrs =
399 getIndexedPtrs(rewriter, loc, typeConverter: *this->getTypeConverter(), memRefType,
400 llvmMemref: adaptor.getBase(), base: ptr, index: adaptor.getIndexVec(), vectorType: vType);
401
402 // Replace with the scatter intrinsic.
403 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
404 op: scatter, args: adaptor.getValueToStore(), args&: ptrs, args: adaptor.getMask(),
405 args: rewriter.getI32IntegerAttr(value: align));
406 return success();
407 }
408
409private:
410 // If true, use the preferred alignment of the vector type.
411 // If false, use the preferred alignment of the element type
412 // of the memref. This flag is intended for use with hardware
413 // backends that require alignment of vector operations.
414 const bool useVectorAlignment;
415};
416
417/// Conversion pattern for a vector.expandload.
418class VectorExpandLoadOpConversion
419 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
420public:
421 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
422
423 LogicalResult
424 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter) const override {
426 auto loc = expand->getLoc();
427 MemRefType memRefType = expand.getMemRefType();
428
429 // Resolve address.
430 auto vtype = typeConverter->convertType(t: expand.getVectorType());
431 Value ptr = getStridedElementPtr(rewriter, loc, type: memRefType,
432 memRefDesc: adaptor.getBase(), indices: adaptor.getIndices());
433
434 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
435 op: expand, args&: vtype, args&: ptr, args: adaptor.getMask(), args: adaptor.getPassThru());
436 return success();
437 }
438};
439
440/// Conversion pattern for a vector.compressstore.
441class VectorCompressStoreOpConversion
442 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
443public:
444 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
445
446 LogicalResult
447 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
448 ConversionPatternRewriter &rewriter) const override {
449 auto loc = compress->getLoc();
450 MemRefType memRefType = compress.getMemRefType();
451
452 // Resolve address.
453 Value ptr = getStridedElementPtr(rewriter, loc, type: memRefType,
454 memRefDesc: adaptor.getBase(), indices: adaptor.getIndices());
455
456 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
457 op: compress, args: adaptor.getValueToStore(), args&: ptr, args: adaptor.getMask());
458 return success();
459 }
460};
461
/// Reduction neutral classes for overloading. Each empty tag type selects,
/// through overload resolution of createReductionNeutralValue below, which
/// neutral (identity) element gets materialized for a reduction.
class ReductionNeutralZero {};     // zero constant
class ReductionNeutralIntOne {};   // integer 1
class ReductionNeutralFPOne {};    // floating-point 1.0
class ReductionNeutralAllOnes {};  // all-ones bit pattern
class ReductionNeutralSIntMin {};  // signed integer minimum
class ReductionNeutralUIntMin {};  // unsigned integer minimum
class ReductionNeutralSIntMax {};  // signed integer maximum
class ReductionNeutralUIntMax {};  // unsigned integer maximum
class ReductionNeutralFPMin {};    // quiet NaN (positive)
class ReductionNeutralFPMax {};    // quiet NaN (negative)
473
474/// Create the reduction neutral zero value.
475static Value createReductionNeutralValue(ReductionNeutralZero neutral,
476 ConversionPatternRewriter &rewriter,
477 Location loc, Type llvmType) {
478 return rewriter.create<LLVM::ConstantOp>(location: loc, args&: llvmType,
479 args: rewriter.getZeroAttr(type: llvmType));
480}
481
482/// Create the reduction neutral integer one value.
483static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
484 ConversionPatternRewriter &rewriter,
485 Location loc, Type llvmType) {
486 return rewriter.create<LLVM::ConstantOp>(
487 location: loc, args&: llvmType, args: rewriter.getIntegerAttr(type: llvmType, value: 1));
488}
489
490/// Create the reduction neutral fp one value.
491static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
492 ConversionPatternRewriter &rewriter,
493 Location loc, Type llvmType) {
494 return rewriter.create<LLVM::ConstantOp>(
495 location: loc, args&: llvmType, args: rewriter.getFloatAttr(type: llvmType, value: 1.0));
496}
497
498/// Create the reduction neutral all-ones value.
499static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
500 ConversionPatternRewriter &rewriter,
501 Location loc, Type llvmType) {
502 return rewriter.create<LLVM::ConstantOp>(
503 location: loc, args&: llvmType,
504 args: rewriter.getIntegerAttr(
505 type: llvmType, value: llvm::APInt::getAllOnes(numBits: llvmType.getIntOrFloatBitWidth())));
506}
507
508/// Create the reduction neutral signed int minimum value.
509static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
510 ConversionPatternRewriter &rewriter,
511 Location loc, Type llvmType) {
512 return rewriter.create<LLVM::ConstantOp>(
513 location: loc, args&: llvmType,
514 args: rewriter.getIntegerAttr(type: llvmType, value: llvm::APInt::getSignedMinValue(
515 numBits: llvmType.getIntOrFloatBitWidth())));
516}
517
518/// Create the reduction neutral unsigned int minimum value.
519static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
520 ConversionPatternRewriter &rewriter,
521 Location loc, Type llvmType) {
522 return rewriter.create<LLVM::ConstantOp>(
523 location: loc, args&: llvmType,
524 args: rewriter.getIntegerAttr(type: llvmType, value: llvm::APInt::getMinValue(
525 numBits: llvmType.getIntOrFloatBitWidth())));
526}
527
528/// Create the reduction neutral signed int maximum value.
529static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
530 ConversionPatternRewriter &rewriter,
531 Location loc, Type llvmType) {
532 return rewriter.create<LLVM::ConstantOp>(
533 location: loc, args&: llvmType,
534 args: rewriter.getIntegerAttr(type: llvmType, value: llvm::APInt::getSignedMaxValue(
535 numBits: llvmType.getIntOrFloatBitWidth())));
536}
537
538/// Create the reduction neutral unsigned int maximum value.
539static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
540 ConversionPatternRewriter &rewriter,
541 Location loc, Type llvmType) {
542 return rewriter.create<LLVM::ConstantOp>(
543 location: loc, args&: llvmType,
544 args: rewriter.getIntegerAttr(type: llvmType, value: llvm::APInt::getMaxValue(
545 numBits: llvmType.getIntOrFloatBitWidth())));
546}
547
548/// Create the reduction neutral fp minimum value.
549static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
550 ConversionPatternRewriter &rewriter,
551 Location loc, Type llvmType) {
552 auto floatType = cast<FloatType>(Val&: llvmType);
553 return rewriter.create<LLVM::ConstantOp>(
554 location: loc, args&: llvmType,
555 args: rewriter.getFloatAttr(
556 type: llvmType, value: llvm::APFloat::getQNaN(Sem: floatType.getFloatSemantics(),
557 /*Negative=*/false)));
558}
559
560/// Create the reduction neutral fp maximum value.
561static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
562 ConversionPatternRewriter &rewriter,
563 Location loc, Type llvmType) {
564 auto floatType = cast<FloatType>(Val&: llvmType);
565 return rewriter.create<LLVM::ConstantOp>(
566 location: loc, args&: llvmType,
567 args: rewriter.getFloatAttr(
568 type: llvmType, value: llvm::APFloat::getQNaN(Sem: floatType.getFloatSemantics(),
569 /*Negative=*/true)));
570}
571
572/// Returns `accumulator` if it has a valid value. Otherwise, creates and
573/// returns a new accumulator value using `ReductionNeutral`.
574template <class ReductionNeutral>
575static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
576 Location loc, Type llvmType,
577 Value accumulator) {
578 if (accumulator)
579 return accumulator;
580
581 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
582 llvmType);
583}
584
585/// Creates a value with the 1-D vector shape provided in `llvmType`.
586/// This is used as effective vector length by some intrinsics supporting
587/// dynamic vector lengths at runtime.
588static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
589 Location loc, Type llvmType) {
590 VectorType vType = cast<VectorType>(Val&: llvmType);
591 auto vShape = vType.getShape();
592 assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
593
594 Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
595 location: loc, args: rewriter.getI32Type(),
596 args: rewriter.getIntegerAttr(type: rewriter.getI32Type(), value: vShape[0]));
597
598 if (!vType.getScalableDims()[0])
599 return baseVecLength;
600
601 // For a scalable vector type, create and return `vScale * baseVecLength`.
602 Value vScale = rewriter.create<vector::VectorScaleOp>(location: loc);
603 vScale =
604 rewriter.create<arith::IndexCastOp>(location: loc, args: rewriter.getI32Type(), args&: vScale);
605 Value scalableVecLength =
606 rewriter.create<arith::MulIOp>(location: loc, args&: baseVecLength, args&: vScale);
607 return scalableVecLength;
608}
609
610/// Helper method to lower a `vector.reduction` op that performs an arithmetic
611/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
612/// and `ScalarOp` is the scalar operation used to add the accumulation value if
613/// non-null.
614template <class LLVMRedIntrinOp, class ScalarOp>
615static Value createIntegerReductionArithmeticOpLowering(
616 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
617 Value vectorOperand, Value accumulator) {
618
619 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
620
621 if (accumulator)
622 result = rewriter.create<ScalarOp>(loc, accumulator, result);
623 return result;
624}
625
626/// Helper method to lower a `vector.reduction` operation that performs
627/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
628/// intrinsic to use and `predicate` is the predicate to use to compare+combine
629/// the accumulator value if non-null.
630template <class LLVMRedIntrinOp>
631static Value createIntegerReductionComparisonOpLowering(
632 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
633 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
634 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
635 if (accumulator) {
636 Value cmp =
637 rewriter.create<LLVM::ICmpOp>(location: loc, args&: predicate, args&: accumulator, args&: result);
638 result = rewriter.create<LLVM::SelectOp>(location: loc, args&: cmp, args&: accumulator, args&: result);
639 }
640 return result;
641}
642
namespace {
/// Maps an LLVM vector reduction intrinsic to the scalar LLVM op used to fold
/// its result into an accumulator value. Specialized for the four fp
/// min/max-style reductions below; unspecialized uses fail to compile.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace
663
664template <class LLVMRedIntrinOp>
665static Value createFPReductionComparisonOpLowering(
666 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
667 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
668 Value result =
669 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
670
671 if (accumulator) {
672 result =
673 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
674 loc, result, accumulator);
675 }
676
677 return result;
678}
679
/// Mask neutral classes for overloading: tag types that select, via the
/// getMaskNeutralValue overloads below, the value substituted for masked-off
/// lanes.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
683
684/// Get the mask neutral floating point maximum value
685static llvm::APFloat
686getMaskNeutralValue(MaskNeutralFMaximum,
687 const llvm::fltSemantics &floatSemantics) {
688 return llvm::APFloat::getSmallest(Sem: floatSemantics, /*Negative=*/true);
689}
690/// Get the mask neutral floating point minimum value
691static llvm::APFloat
692getMaskNeutralValue(MaskNeutralFMinimum,
693 const llvm::fltSemantics &floatSemantics) {
694 return llvm::APFloat::getLargest(Sem: floatSemantics, /*Negative=*/false);
695}
696
697/// Create the mask neutral floating point MLIR vector constant
698template <typename MaskNeutral>
699static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
700 Location loc, Type llvmType,
701 Type vectorType) {
702 const auto &floatSemantics = cast<FloatType>(Val&: llvmType).getFloatSemantics();
703 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
704 auto denseValue = DenseElementsAttr::get(cast<ShapedType>(Val&: vectorType), value);
705 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
706}
707
708/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
709/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
710/// `fmaximum`/`fminimum`.
711/// More information: https://github.com/llvm/llvm-project/issues/64940
712template <class LLVMRedIntrinOp, class MaskNeutral>
713static Value
714lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
715 Location loc, Type llvmType,
716 Value vectorOperand, Value accumulator,
717 Value mask, LLVM::FastmathFlagsAttr fmf) {
718 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
719 rewriter, loc, llvmType, vectorOperand.getType());
720 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
721 location: loc, args&: mask, args&: vectorOperand, args: vectorMaskNeutral);
722 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
723 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
724}
725
726template <class LLVMRedIntrinOp, class ReductionNeutral>
727static Value
728lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
729 Type llvmType, Value vectorOperand,
730 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
731 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
732 llvmType, accumulator);
733 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
734 /*startValue=*/accumulator,
735 vectorOperand, fmf);
736}
737
738/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
739/// that requires a start value. This start value format spans across fp
740/// reductions without mask and all the masked reduction intrinsics.
741template <class LLVMVPRedIntrinOp, class ReductionNeutral>
742static Value
743lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
744 Location loc, Type llvmType,
745 Value vectorOperand, Value accumulator) {
746 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
747 llvmType, accumulator);
748 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
749 /*startValue=*/accumulator,
750 vectorOperand);
751}
752
753template <class LLVMVPRedIntrinOp, class ReductionNeutral>
754static Value lowerPredicatedReductionWithStartValue(
755 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
756 Value vectorOperand, Value accumulator, Value mask) {
757 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
758 llvmType, accumulator);
759 Value vectorLength =
760 createVectorLengthValue(rewriter, loc, llvmType: vectorOperand.getType());
761 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
762 /*startValue=*/accumulator,
763 vectorOperand, mask, vectorLength);
764}
765
766template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
767 class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
768static Value lowerPredicatedReductionWithStartValue(
769 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
770 Value vectorOperand, Value accumulator, Value mask) {
771 if (llvmType.isIntOrIndex())
772 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
773 IntReductionNeutral>(
774 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
775
776 // FP dispatch.
777 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
778 FPReductionNeutral>(
779 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
780}
781
/// Conversion pattern for all vector reductions (vector.reduction).
/// Integer reductions lower to `llvm.vector.reduce.*` intrinsics (with any
/// accumulator combined via helper lowerings); floating-point reductions
/// lower to the corresponding fp reduce intrinsics, optionally tagged with
/// the `reassoc` fastmath flag when `reassociateFPReductions` is set.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  /// `reassociateFPRed` requests adding the `reassoc` fastmath flag to the
  /// generated fp reduction intrinsics, allowing LLVM to reassociate them.
  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    // The scalar result type selects the integer vs floating-point path.
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(t: eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc,
                                      predicate: LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc,
                                      predicate: LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc,
                                      predicate: LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc,
                                      predicate: LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc);
        break;
      default:
        // Remaining kinds are fp-only and thus invalid on integer elements.
        return failure();
      }
      rewriter.replaceOp(op: reductionOp, newValues: result);

      return success();
    }

    if (!isa<FloatType>(Val: eltType))
      return failure();

    // Translate the op's arith fastmath flags to LLVM's, then OR in
    // `reassoc` when reassociation of fp reductions was requested.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        context: reductionOp.getContext(),
        value: convertArithFastMathFlagsToLLVM(arithFMF: fMFAttr.getValue()));
    fmf = LLVM::FastmathFlagsAttr::get(
        context: reductionOp.getContext(),
        value: fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
                                      : LLVM::FastmathFlags::none));

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, fmf);
    } else if (kind == vector::CombiningKind::MINNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, fmf);
    } else if (kind == vector::CombiningKind::MAXNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, fmf);
    } else {
      // Integer-only kinds (and/or/xor/...) cannot apply to fp elements.
      return failure();
    }

    rewriter.replaceOp(op: reductionOp, newValues: result);
    return success();
  }

private:
  // When true, emitted fp reduce intrinsics carry the `reassoc` flag.
  const bool reassociateFPReductions;
};
910
911/// Base class to convert a `vector.mask` operation while matching traits
912/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
913/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
914/// method performs a second match against the maskable operation `MaskedOp`.
915/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
916/// implemented by the concrete conversion classes. This method can match
917/// against specific traits of the `vector.mask` and the maskable operation. It
918/// must replace the `vector.mask` operation.
919template <class MaskedOp>
920class VectorMaskOpConversionBase
921 : public ConvertOpToLLVMPattern<vector::MaskOp> {
922public:
923 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
924
925 LogicalResult
926 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
927 ConversionPatternRewriter &rewriter) const final {
928 // Match against the maskable operation kind.
929 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
930 if (!maskedOp)
931 return failure();
932 return matchAndRewriteMaskableOp(maskOp, maskableOp: maskedOp, rewriter);
933 }
934
935protected:
936 virtual LogicalResult
937 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
938 vector::MaskableOpInterface maskableOp,
939 ConversionPatternRewriter &rewriter) const = 0;
940};
941
/// Converts a `vector.mask { vector.reduction ... }` pair into LLVM
/// vector-predication (`llvm.vp.reduce.*`) intrinsics. The minimumf/maximumf
/// kinds have no VP counterpart and are instead emulated by selecting a
/// mask-neutral value per lane before a regular reduce intrinsic.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(Val: maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(t: eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // Only consumed by the MINIMUMF/MAXIMUMF fallback below; the VP
    // intrinsics used for all other kinds take no fastmath flags.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        context: reductionOp.getContext(),
        value: convertArithFastMathFlagsToLLVM(arithFMF: fMFAttr.getValue()));

    // NOTE: no default case — this switch is expected to cover every
    // CombiningKind value, so `result` is always assigned.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc,
                                mask: maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc,
                                 mask: maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    case vector::CombiningKind::MINNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    case vector::CombiningKind::MAXNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask());
      break;
    // No VP intrinsic exists for minimumf/maximumf: mask inactive lanes to a
    // neutral value, then use the regular reduce intrinsic.
    case CombiningKind::MAXIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, vectorOperand: operand, accumulator: acc, mask: maskOp.getMask(), fmf);
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(op: maskOp, newValues: result);
    return success();
  }
};
1041
/// Converts vector.shuffle to LLVM. Rank-0/1 shuffles whose operands share
/// exactly the same vector type map directly onto `llvm.shufflevector`; all
/// other cases are expanded into a per-element extract/insert sequence.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(t: vectorType);
    ArrayRef<int64_t> mask = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    // Shuffling two 0-D vectors yields a 1-D vector; otherwise all three
    // ranks must agree.
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          location: loc, args: adaptor.getV1(), args: adaptor.getV2(),
          args: llvm::to_vector_of<int32_t>(Range&: mask));
      rewriter.replaceOp(op: shuffleOp, newValues: llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(idx: 0);
    Type eltType;
    // N-D vectors lower to LLVM arrays of 1-D vectors; peel one level to get
    // the element type of the outermost dimension.
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(Val&: llvmType))
      eltType = arrayType.getElementType();
    else
      eltType = cast<VectorType>(Val&: llvmType).getElementType();
    Value insert = rewriter.create<LLVM::PoisonOp>(location: loc, args&: llvmType);
    int64_t insPos = 0;
    for (int64_t extPos : mask) {
      // Mask entries at or past v1's leading dim index into v2 instead.
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, typeConverter: *getTypeConverter(), loc, val: value,
                                 llvmType: eltType, rank, pos: extPos);
      insert = insertOne(rewriter, typeConverter: *getTypeConverter(), loc, val1: insert, val2: extract,
                         llvmType, rank, pos: insPos++);
    }
    rewriter.replaceOp(op: shuffleOp, newValues: insert);
    return success();
  }
};
1105
1106class VectorExtractElementOpConversion
1107 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1108public:
1109 using ConvertOpToLLVMPattern<
1110 vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1111
1112 LogicalResult
1113 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1114 ConversionPatternRewriter &rewriter) const override {
1115 auto vectorType = extractEltOp.getSourceVectorType();
1116 auto llvmType = typeConverter->convertType(t: vectorType.getElementType());
1117
1118 // Bail if result type cannot be lowered.
1119 if (!llvmType)
1120 return failure();
1121
1122 if (vectorType.getRank() == 0) {
1123 Location loc = extractEltOp.getLoc();
1124 auto idxType = rewriter.getIndexType();
1125 auto zero = rewriter.create<LLVM::ConstantOp>(
1126 location: loc, args: typeConverter->convertType(t: idxType),
1127 args: rewriter.getIntegerAttr(type: idxType, value: 0));
1128 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1129 op: extractEltOp, args&: llvmType, args: adaptor.getVector(), args&: zero);
1130 return success();
1131 }
1132
1133 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1134 op: extractEltOp, args&: llvmType, args: adaptor.getVector(), args: adaptor.getPosition());
1135 return success();
1136 }
1137};
1138
/// Converts vector.extract to a combination of `llvm.extractvalue` (to peel
/// nested-aggregate dimensions of an N-D vector) and `llvm.extractelement`
/// (to pull a scalar out of the resulting 1-D vector). Fails when aggregate
/// member extraction would need a dynamic position.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(t: resultType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Merge static and dynamic position operands into one list.
    SmallVector<OpFoldResult> positionVec = getMixedValues(
        staticValues: adaptor.getStaticPosition(), dynamicValues: adaptor.getDynamicPosition(), b&: rewriter);

    // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
    // 1-d vectors. This nesting is modeled using arrays. We do this conversion
    // from a N-d vector extract to a nested aggregate vector extract in two
    // steps:
    // - Extract a member from the nested aggregate. The result can be
    // a lower rank nested aggregate or a vector (1-D). This is done using
    // `llvm.extractvalue`.
    // - Extract a scalar out of the vector if needed. This is done using
    // `llvm.extractelement`.

    // Determine if we need to extract a member out of the aggregate. We
    // always need to extract a member if the input rank >= 2.
    bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
    // Determine if we need to extract a scalar as the result. We extract
    // a scalar if the extract is full rank, i.e., the number of indices is
    // equal to source vector rank.
    bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
                          extractOp.getSourceVectorType().getRank();

    // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
    // need to add a position for this change.
    if (extractOp.getSourceVectorType().getRank() == 0) {
      Type idxType = typeConverter->convertType(t: rewriter.getIndexType());
      positionVec.push_back(Elt: rewriter.getZeroAttr(type: idxType));
    }

    Value extracted = adaptor.getVector();
    if (extractsAggregate) {
      ArrayRef<OpFoldResult> position(positionVec);
      if (extractsScalar) {
        // If we are extracting a scalar from the extracted member, we drop
        // the last index, which will be used to extract the scalar out of the
        // vector.
        position = position.drop_back();
      }
      // llvm.extractvalue does not support dynamic dimensions.
      if (!llvm::all_of(Range&: position, P: llvm::IsaPred<Attribute>)) {
        return failure();
      }
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          location: loc, args&: extracted, args: getAsIntegers(foldResults: position));
    }

    if (extractsScalar) {
      // The last (possibly dynamic) index selects the scalar lane.
      extracted = rewriter.create<LLVM::ExtractElementOp>(
          location: loc, args&: extracted, args: getAsLLVMValue(builder&: rewriter, loc, foldResult: positionVec.back()));
    }

    rewriter.replaceOp(op: extractOp, newValues: extracted);
    return success();
  }
};
1209
1210/// Conversion pattern that turns a vector.fma on a 1-D vector
1211/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1212/// This does not match vectors of n >= 2 rank.
1213///
1214/// Example:
1215/// ```
1216/// vector.fma %a, %a, %a : vector<8xf32>
1217/// ```
1218/// is converted to:
1219/// ```
1220/// llvm.intr.fmuladd %va, %va, %va:
1221/// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1222/// -> !llvm."<8 x f32>">
1223/// ```
1224class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1225public:
1226 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1227
1228 LogicalResult
1229 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1230 ConversionPatternRewriter &rewriter) const override {
1231 VectorType vType = fmaOp.getVectorType();
1232 if (vType.getRank() > 1)
1233 return failure();
1234
1235 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1236 op: fmaOp, args: adaptor.getLhs(), args: adaptor.getRhs(), args: adaptor.getAcc());
1237 return success();
1238 }
1239};
1240
1241class VectorInsertElementOpConversion
1242 : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1243public:
1244 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1245
1246 LogicalResult
1247 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1248 ConversionPatternRewriter &rewriter) const override {
1249 auto vectorType = insertEltOp.getDestVectorType();
1250 auto llvmType = typeConverter->convertType(t: vectorType);
1251
1252 // Bail if result type cannot be lowered.
1253 if (!llvmType)
1254 return failure();
1255
1256 if (vectorType.getRank() == 0) {
1257 Location loc = insertEltOp.getLoc();
1258 auto idxType = rewriter.getIndexType();
1259 auto zero = rewriter.create<LLVM::ConstantOp>(
1260 location: loc, args: typeConverter->convertType(t: idxType),
1261 args: rewriter.getIntegerAttr(type: idxType, value: 0));
1262 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1263 op: insertEltOp, args&: llvmType, args: adaptor.getDest(), args: adaptor.getSource(), args&: zero);
1264 return success();
1265 }
1266
1267 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1268 op: insertEltOp, args&: llvmType, args: adaptor.getDest(), args: adaptor.getSource(),
1269 args: adaptor.getPosition());
1270 return success();
1271 }
1272};
1273
/// Converts vector.insert to a combination of `llvm.extractvalue`,
/// `llvm.insertelement` and `llvm.insertvalue`, mirroring the decomposition
/// used by VectorExtractOpConversion. Fails when aggregate member access
/// would need a dynamic position.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(t: destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Merge static and dynamic position operands into one list.
    SmallVector<OpFoldResult> positionVec = getMixedValues(
        staticValues: adaptor.getStaticPosition(), dynamicValues: adaptor.getDynamicPosition(), b&: rewriter);

    // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
    // its explanatory comment about how N-D vectors are converted as nested
    // aggregates (llvm.array's) of 1D vectors.
    //
    // The innermost dimension of the destination vector, when converted to a
    // nested aggregate form, will always be a 1D vector.
    //
    // * If the insertion is happening into the innermost dimension of the
    //   destination vector:
    //   - If the destination is a nested aggregate, extract a 1D vector out of
    //     the aggregate. This can be done using llvm.extractvalue. The
    //     destination is now guaranteed to be a 1D vector, to which we are
    //     inserting.
    //   - Do the insertion into the 1D destination vector, and make the result
    //     the new source nested aggregate. This can be done using
    //     llvm.insertelement.
    // * Insert the source nested aggregate into the destination nested
    //   aggregate.

    // Determine if we need to extract/insert a 1D vector out of the aggregate.
    bool isNestedAggregate = isa<LLVM::LLVMArrayType>(Val: llvmResultType);
    // Determine if we need to insert a scalar into the 1D vector.
    bool insertIntoInnermostDim =
        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();

    // All positions except (when inserting a scalar) the last one address the
    // 1-D vector within the aggregate.
    ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
        positionVec.begin(),
        insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
    OpFoldResult positionOfScalarWithin1DVector;
    if (destVectorType.getRank() == 0) {
      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
      // need to create a 0 here as the position into the 1D vector.
      Type idxType = typeConverter->convertType(t: rewriter.getIndexType());
      positionOfScalarWithin1DVector = rewriter.getZeroAttr(type: idxType);
    } else if (insertIntoInnermostDim) {
      positionOfScalarWithin1DVector = positionVec.back();
    }

    // We are going to mutate this 1D vector until it is either the final
    // result (in the non-aggregate case) or the value that needs to be
    // inserted into the aggregate result.
    Value sourceAggregate = adaptor.getValueToStore();
    if (insertIntoInnermostDim) {
      // Scalar-into-1D-vector case, so we know we will have to create a
      // InsertElementOp. The question is into what destination.
      if (isNestedAggregate) {
        // Aggregate case: the destination for the InsertElementOp needs to be
        // extracted from the aggregate.
        if (!llvm::all_of(Range&: positionOf1DVectorWithinAggregate,
                          P: llvm::IsaPred<Attribute>)) {
          // llvm.extractvalue does not support dynamic dimensions.
          return failure();
        }
        sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
            location: loc, args: adaptor.getDest(),
            args: getAsIntegers(foldResults: positionOf1DVectorWithinAggregate));
      } else {
        // No-aggregate case. The destination for the InsertElementOp is just
        // the insertOp's destination.
        sourceAggregate = adaptor.getDest();
      }
      // Insert the scalar into the 1D vector.
      sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
          location: loc, args: sourceAggregate.getType(), args&: sourceAggregate,
          args: adaptor.getValueToStore(),
          args: getAsLLVMValue(builder&: rewriter, loc, foldResult: positionOfScalarWithin1DVector));
    }

    Value result = sourceAggregate;
    if (isNestedAggregate) {
      // Write the (possibly updated) 1-D vector back into the aggregate.
      result = rewriter.create<LLVM::InsertValueOp>(
          location: loc, args: adaptor.getDest(), args&: sourceAggregate,
          args: getAsIntegers(foldResults: positionOf1DVectorWithinAggregate));
    }

    rewriter.replaceOp(op: insertOp, newValues: result);
    return success();
  }
};
1371
1372/// Lower vector.scalable.insert ops to LLVM vector.insert
1373struct VectorScalableInsertOpLowering
1374 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1375 using ConvertOpToLLVMPattern<
1376 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1377
1378 LogicalResult
1379 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1380 ConversionPatternRewriter &rewriter) const override {
1381 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1382 op: insOp, args: adaptor.getDest(), args: adaptor.getValueToStore(), args: adaptor.getPos());
1383 return success();
1384 }
1385};
1386
1387/// Lower vector.scalable.extract ops to LLVM vector.extract
1388struct VectorScalableExtractOpLowering
1389 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1390 using ConvertOpToLLVMPattern<
1391 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1392
1393 LogicalResult
1394 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1395 ConversionPatternRewriter &rewriter) const override {
1396 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1397 op: extOp, args: typeConverter->convertType(t: extOp.getResultVectorType()),
1398 args: adaptor.getSource(), args: adaptor.getPos());
1399 return success();
1400 }
1401};
1402
1403/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1404///
1405/// Example:
1406/// ```
1407/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1408/// ```
1409/// is rewritten into:
1410/// ```
1411/// %r = splat %f0: vector<2x4xf32>
1412/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1413/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1414/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1415/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1416/// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1417/// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1418/// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1419/// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1420/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1421/// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1422/// // %r3 holds the final value.
1423/// ```
1424class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1425public:
1426 using OpRewritePattern<FMAOp>::OpRewritePattern;
1427
1428 void initialize() {
1429 // This pattern recursively unpacks one dimension at a time. The recursion
1430 // bounded as the rank is strictly decreasing.
1431 setHasBoundedRewriteRecursion();
1432 }
1433
1434 LogicalResult matchAndRewrite(FMAOp op,
1435 PatternRewriter &rewriter) const override {
1436 auto vType = op.getVectorType();
1437 if (vType.getRank() < 2)
1438 return failure();
1439
1440 auto loc = op.getLoc();
1441 auto elemType = vType.getElementType();
1442 Value zero = rewriter.create<arith::ConstantOp>(
1443 location: loc, args&: elemType, args: rewriter.getZeroAttr(type: elemType));
1444 Value desc = rewriter.create<vector::SplatOp>(location: loc, args&: vType, args&: zero);
1445 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1446 Value extrLHS = rewriter.create<ExtractOp>(location: loc, args: op.getLhs(), args&: i);
1447 Value extrRHS = rewriter.create<ExtractOp>(location: loc, args: op.getRhs(), args&: i);
1448 Value extrACC = rewriter.create<ExtractOp>(location: loc, args: op.getAcc(), args&: i);
1449 Value fma = rewriter.create<FMAOp>(location: loc, args&: extrLHS, args&: extrRHS, args&: extrACC);
1450 desc = rewriter.create<InsertOp>(location: loc, args&: fma, args&: desc, args&: i);
1451 }
1452 rewriter.replaceOp(op, newValues: desc);
1453 return success();
1454 }
1455};
1456
1457/// Returns the strides if the memory underlying `memRefType` has a contiguous
1458/// static layout.
1459static std::optional<SmallVector<int64_t, 4>>
1460computeContiguousStrides(MemRefType memRefType) {
1461 int64_t offset;
1462 SmallVector<int64_t, 4> strides;
1463 if (failed(Result: memRefType.getStridesAndOffset(strides, offset)))
1464 return std::nullopt;
1465 if (!strides.empty() && strides.back() != 1)
1466 return std::nullopt;
1467 // If no layout or identity layout, this is contiguous by definition.
1468 if (memRefType.getLayout().isIdentity())
1469 return strides;
1470
1471 // Otherwise, we must determine contiguity form shapes. This can only ever
1472 // work in static cases because MemRefType is underspecified to represent
1473 // contiguous dynamic shapes in other ways than with just empty/identity
1474 // layout.
1475 auto sizes = memRefType.getShape();
1476 for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1477 if (ShapedType::isDynamic(dValue: sizes[index + 1]) ||
1478 ShapedType::isDynamic(dValue: strides[index]) ||
1479 ShapedType::isDynamic(dValue: strides[index + 1]))
1480 return std::nullopt;
1481 if (strides[index] != strides[index + 1] * sizes[index + 1])
1482 return std::nullopt;
1483 }
1484 return strides;
1485}
1486
1487class VectorTypeCastOpConversion
1488 : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1489public:
1490 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1491
1492 LogicalResult
1493 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1494 ConversionPatternRewriter &rewriter) const override {
1495 auto loc = castOp->getLoc();
1496 MemRefType sourceMemRefType =
1497 cast<MemRefType>(Val: castOp.getOperand().getType());
1498 MemRefType targetMemRefType = castOp.getType();
1499
1500 // Only static shape casts supported atm.
1501 if (!sourceMemRefType.hasStaticShape() ||
1502 !targetMemRefType.hasStaticShape())
1503 return failure();
1504
1505 auto llvmSourceDescriptorTy =
1506 dyn_cast<LLVM::LLVMStructType>(Val: adaptor.getOperands()[0].getType());
1507 if (!llvmSourceDescriptorTy)
1508 return failure();
1509 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1510
1511 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1512 Val: typeConverter->convertType(t: targetMemRefType));
1513 if (!llvmTargetDescriptorTy)
1514 return failure();
1515
1516 // Only contiguous source buffers supported atm.
1517 auto sourceStrides = computeContiguousStrides(memRefType: sourceMemRefType);
1518 if (!sourceStrides)
1519 return failure();
1520 auto targetStrides = computeContiguousStrides(memRefType: targetMemRefType);
1521 if (!targetStrides)
1522 return failure();
1523 // Only support static strides for now, regardless of contiguity.
1524 if (llvm::any_of(Range&: *targetStrides, P: ShapedType::isDynamic))
1525 return failure();
1526
1527 auto int64Ty = IntegerType::get(context: rewriter.getContext(), width: 64);
1528
1529 // Create descriptor.
1530 auto desc = MemRefDescriptor::poison(builder&: rewriter, loc, descriptorType: llvmTargetDescriptorTy);
1531 // Set allocated ptr.
1532 Value allocated = sourceMemRef.allocatedPtr(builder&: rewriter, loc);
1533 desc.setAllocatedPtr(builder&: rewriter, loc, ptr: allocated);
1534
1535 // Set aligned ptr.
1536 Value ptr = sourceMemRef.alignedPtr(builder&: rewriter, loc);
1537 desc.setAlignedPtr(builder&: rewriter, loc, ptr);
1538 // Fill offset 0.
1539 auto attr = rewriter.getIntegerAttr(type: rewriter.getIndexType(), value: 0);
1540 auto zero = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int64Ty, args&: attr);
1541 desc.setOffset(builder&: rewriter, loc, offset: zero);
1542
1543 // Fill size and stride descriptors in memref.
1544 for (const auto &indexedSize :
1545 llvm::enumerate(First: targetMemRefType.getShape())) {
1546 int64_t index = indexedSize.index();
1547 auto sizeAttr =
1548 rewriter.getIntegerAttr(type: rewriter.getIndexType(), value: indexedSize.value());
1549 auto size = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int64Ty, args&: sizeAttr);
1550 desc.setSize(builder&: rewriter, loc, pos: index, size);
1551 auto strideAttr = rewriter.getIntegerAttr(type: rewriter.getIndexType(),
1552 value: (*targetStrides)[index]);
1553 auto stride = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int64Ty, args&: strideAttr);
1554 desc.setStride(builder&: rewriter, loc, pos: index, stride);
1555 }
1556
1557 rewriter.replaceOp(op: castOp, newValues: {desc});
1558 return success();
1559 }
1560};
1561
1562/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1563/// Non-scalable versions of this operation are handled in Vector Transforms.
1564class VectorCreateMaskOpConversion
1565 : public OpConversionPattern<vector::CreateMaskOp> {
1566public:
1567 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1568 bool enableIndexOpt)
1569 : OpConversionPattern<vector::CreateMaskOp>(context),
1570 force32BitVectorIndices(enableIndexOpt) {}
1571
1572 LogicalResult
1573 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1574 ConversionPatternRewriter &rewriter) const override {
1575 auto dstType = op.getType();
1576 if (dstType.getRank() != 1 || !cast<VectorType>(Val&: dstType).isScalable())
1577 return failure();
1578 IntegerType idxType =
1579 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1580 auto loc = op->getLoc();
1581 Value indices = rewriter.create<LLVM::StepVectorOp>(
1582 location: loc, args: LLVM::getVectorType(elementType: idxType, numElements: dstType.getShape()[0],
1583 /*isScalable=*/true));
1584 auto bound = getValueOrCreateCastToIndexLike(b&: rewriter, loc, targetType: idxType,
1585 value: adaptor.getOperands()[0]);
1586 Value bounds = rewriter.create<SplatOp>(location: loc, args: indices.getType(), args&: bound);
1587 Value comp = rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::slt,
1588 args&: indices, args&: bounds);
1589 rewriter.replaceOp(op, newValues: comp);
1590 return success();
1591 }
1592
1593private:
1594 const bool force32BitVectorIndices;
1595};
1596
/// Conversion pattern for `vector.print`, lowered to calls into a small
/// runtime print-support library (see comment on matchAndRewrite below).
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
  // Optional cache used when looking up / creating the runtime print
  // functions in the surrounding module; may be null.
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit VectorPrintOpConversion(
      const LLVMTypeConverter &typeConverter,
      SymbolTableCollection *symbolTables = nullptr)
      : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
        symbolTables(symbolTables) {}

  // Lowering implementation that relies on a small runtime support library,
  // which only needs to provide a few printing methods (single value for all
  // data types, opening/closing bracket, comma, newline). The lowering splits
  // the vector into elementary printing operations. The advantage of this
  // approach is that the library can remain unaware of all low-level
  // implementation details of vectors while still supporting output of any
  // shaped and dimensioned vector.
  //
  // Note: This lowering only handles scalars, n-D vectors are broken into
  // printing scalars in loops in VectorToSCF.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The runtime print functions are declared at module scope.
    auto parent = printOp->getParentOfType<ModuleOp>();
    if (!parent)
      return failure();

    auto loc = printOp->getLoc();

    // First print the (optional) scalar source value.
    if (auto value = adaptor.getSource()) {
      Type printType = printOp.getPrintType();
      if (isa<VectorType>(Val: printType)) {
        // Vectors should be broken into elementary print ops in VectorToSCF.
        return failure();
      }
      if (failed(Result: emitScalarPrint(rewriter, parent, loc, printType, value)))
        return failure();
    }

    // Then emit the punctuation: either a string literal (takes precedence)
    // or one of the fixed punctuation runtime calls.
    auto punct = printOp.getPunctuation();
    if (auto stringLiteral = printOp.getStringLiteral()) {
      auto createResult =
          LLVM::createPrintStrCall(builder&: rewriter, loc, moduleOp: parent, symbolName: "vector_print_str",
                                   string: *stringLiteral, typeConverter: *getTypeConverter(),
                                   /*addNewline=*/false);
      if (createResult.failed())
        return failure();

    } else if (punct != PrintPunctuation::NoPunctuation) {
      // Look up (or declare) the runtime function matching the punctuation.
      FailureOr<LLVM::LLVMFuncOp> op = [&]() {
        switch (punct) {
        case PrintPunctuation::Close:
          return LLVM::lookupOrCreatePrintCloseFn(b&: rewriter, moduleOp: parent,
                                                  symbolTables);
        case PrintPunctuation::Open:
          return LLVM::lookupOrCreatePrintOpenFn(b&: rewriter, moduleOp: parent,
                                                 symbolTables);
        case PrintPunctuation::Comma:
          return LLVM::lookupOrCreatePrintCommaFn(b&: rewriter, moduleOp: parent,
                                                  symbolTables);
        case PrintPunctuation::NewLine:
          return LLVM::lookupOrCreatePrintNewlineFn(b&: rewriter, moduleOp: parent,
                                                    symbolTables);
        default:
          llvm_unreachable("unexpected punctuation");
        }
      }();
      if (failed(Result: op))
        return failure();
      emitCall(rewriter, loc: printOp->getLoc(), ref: op.value());
    }

    rewriter.eraseOp(op: printOp);
    return success();
  }

private:
  // How the scalar operand must be adapted before being passed to the
  // runtime print function.
  enum class PrintConversion {
    // clang-format off
    None,       // pass the value through unchanged
    ZeroExt64,  // zero-extend to i64 (unsigned ints, booleans)
    SignExt64,  // sign-extend to i64 (signed/signless ints < 64 bits)
    Bitcast16   // reinterpret f16/bf16 as i16 bits
    // clang-format on
  };

  /// Emits a call to the runtime function that prints a single scalar of
  /// type `printType`, inserting the required extension/bitcast first.
  /// Fails for unsupported element types (e.g. integers wider than 64 bits).
  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    if (typeConverter->convertType(t: printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    FailureOr<Operation *> printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(b&: rewriter, moduleOp: parent, symbolTables);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(b&: rewriter, moduleOp: parent, symbolTables);
    } else if (printType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(b&: rewriter, moduleOp: parent, symbolTables);
    } else if (printType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(b&: rewriter, moduleOp: parent, symbolTables);
    } else if (printType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(b&: rewriter, moduleOp: parent, symbolTables);
    } else if (auto intTy = dyn_cast<IntegerType>(Val&: printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer =
              LLVM::lookupOrCreatePrintU64Fn(b&: rewriter, moduleOp: parent, symbolTables);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer =
              LLVM::lookupOrCreatePrintI64Fn(b&: rewriter, moduleOp: parent, symbolTables);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }
    if (failed(Result: printer))
      return failure();

    // Adapt the operand to the ABI of the selected print function.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = rewriter.create<arith::ExtUIOp>(
          location: loc, args: IntegerType::get(context: rewriter.getContext(), width: 64), args&: value);
      break;
    case PrintConversion::SignExt64:
      value = rewriter.create<arith::ExtSIOp>(
          location: loc, args: IntegerType::get(context: rewriter.getContext(), width: 64), args&: value);
      break;
    case PrintConversion::Bitcast16:
      value = rewriter.create<LLVM::BitcastOp>(
          location: loc, args: IntegerType::get(context: rewriter.getContext(), width: 16), args&: value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, ref: printer.value(), params: value);
    return success();
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(location: loc, args: TypeRange(), args: SymbolRefAttr::get(symbol: ref),
                                  args&: params);
  }
};
1769
1770/// The Splat operation is lowered to an insertelement + a shufflevector
1771/// operation. Splat to only 0-d and 1-d vector result types are lowered.
1772struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1773 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1774
1775 LogicalResult
1776 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1777 ConversionPatternRewriter &rewriter) const override {
1778 VectorType resultType = cast<VectorType>(Val: splatOp.getType());
1779 if (resultType.getRank() > 1)
1780 return failure();
1781
1782 // First insert it into a poison vector so we can shuffle it.
1783 auto vectorType = typeConverter->convertType(t: splatOp.getType());
1784 Value poison =
1785 rewriter.create<LLVM::PoisonOp>(location: splatOp.getLoc(), args&: vectorType);
1786 auto zero = rewriter.create<LLVM::ConstantOp>(
1787 location: splatOp.getLoc(),
1788 args: typeConverter->convertType(t: rewriter.getIntegerType(width: 32)),
1789 args: rewriter.getZeroAttr(type: rewriter.getIntegerType(width: 32)));
1790
1791 // For 0-d vector, we simply do `insertelement`.
1792 if (resultType.getRank() == 0) {
1793 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1794 op: splatOp, args&: vectorType, args&: poison, args: adaptor.getInput(), args&: zero);
1795 return success();
1796 }
1797
1798 // For 1-d vector, we additionally do a `vectorshuffle`.
1799 auto v = rewriter.create<LLVM::InsertElementOp>(
1800 location: splatOp.getLoc(), args&: vectorType, args&: poison, args: adaptor.getInput(), args&: zero);
1801
1802 int64_t width = cast<VectorType>(Val: splatOp.getType()).getDimSize(idx: 0);
1803 SmallVector<int32_t> zeroValues(width, 0);
1804
1805 // Shuffle the value across the desired number of elements.
1806 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op: splatOp, args&: v, args&: poison,
1807 args&: zeroValues);
1808 return success();
1809 }
1810};
1811
1812/// The Splat operation is lowered to an insertelement + a shufflevector
1813/// operation. Splat to only 2+-d vector result types are lowered by the
1814/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1815struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1816 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1817
1818 LogicalResult
1819 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1820 ConversionPatternRewriter &rewriter) const override {
1821 VectorType resultType = splatOp.getType();
1822 if (resultType.getRank() <= 1)
1823 return failure();
1824
1825 // First insert it into an undef vector so we can shuffle it.
1826 auto loc = splatOp.getLoc();
1827 auto vectorTypeInfo =
1828 LLVM::detail::extractNDVectorTypeInfo(vectorType: resultType, converter: *getTypeConverter());
1829 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1830 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1831 if (!llvmNDVectorTy || !llvm1DVectorTy)
1832 return failure();
1833
1834 // Construct returned value.
1835 Value desc = rewriter.create<LLVM::PoisonOp>(location: loc, args&: llvmNDVectorTy);
1836
1837 // Construct a 1-D vector with the splatted value that we insert in all the
1838 // places within the returned descriptor.
1839 Value vdesc = rewriter.create<LLVM::PoisonOp>(location: loc, args&: llvm1DVectorTy);
1840 auto zero = rewriter.create<LLVM::ConstantOp>(
1841 location: loc, args: typeConverter->convertType(t: rewriter.getIntegerType(width: 32)),
1842 args: rewriter.getZeroAttr(type: rewriter.getIntegerType(width: 32)));
1843 Value v = rewriter.create<LLVM::InsertElementOp>(location: loc, args&: llvm1DVectorTy, args&: vdesc,
1844 args: adaptor.getInput(), args&: zero);
1845
1846 // Shuffle the value across the desired number of elements.
1847 int64_t width = resultType.getDimSize(idx: resultType.getRank() - 1);
1848 SmallVector<int32_t> zeroValues(width, 0);
1849 v = rewriter.create<LLVM::ShuffleVectorOp>(location: loc, args&: v, args&: v, args&: zeroValues);
1850
1851 // Iterate of linear index, convert to coords space and insert splatted 1-D
1852 // vector in each position.
1853 nDVectorIterate(info: vectorTypeInfo, builder&: rewriter, fun: [&](ArrayRef<int64_t> position) {
1854 desc = rewriter.create<LLVM::InsertValueOp>(location: loc, args&: desc, args&: v, args&: position);
1855 });
1856 rewriter.replaceOp(op: splatOp, newValues: desc);
1857 return success();
1858 }
1859};
1860
1861/// Conversion pattern for a `vector.interleave`.
1862/// This supports fixed-sized vectors and scalable vectors.
1863struct VectorInterleaveOpLowering
1864 : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1865 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1866
1867 LogicalResult
1868 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1869 ConversionPatternRewriter &rewriter) const override {
1870 VectorType resultType = interleaveOp.getResultVectorType();
1871 // n-D interleaves should have been lowered already.
1872 if (resultType.getRank() != 1)
1873 return rewriter.notifyMatchFailure(arg&: interleaveOp,
1874 msg: "InterleaveOp not rank 1");
1875 // If the result is rank 1, then this directly maps to LLVM.
1876 if (resultType.isScalable()) {
1877 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1878 op: interleaveOp, args: typeConverter->convertType(t: resultType),
1879 args: adaptor.getLhs(), args: adaptor.getRhs());
1880 return success();
1881 }
1882 // Lower fixed-size interleaves to a shufflevector. While the
1883 // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1884 // langref still recommends fixed-vectors use shufflevector, see:
1885 // https://llvm.org/docs/LangRef.html#id876.
1886 int64_t resultVectorSize = resultType.getNumElements();
1887 SmallVector<int32_t> interleaveShuffleMask;
1888 interleaveShuffleMask.reserve(N: resultVectorSize);
1889 for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1890 interleaveShuffleMask.push_back(Elt: i);
1891 interleaveShuffleMask.push_back(Elt: (resultVectorSize / 2) + i);
1892 }
1893 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1894 op: interleaveOp, args: adaptor.getLhs(), args: adaptor.getRhs(),
1895 args&: interleaveShuffleMask);
1896 return success();
1897 }
1898};
1899
1900/// Conversion pattern for a `vector.deinterleave`.
1901/// This supports fixed-sized vectors and scalable vectors.
1902struct VectorDeinterleaveOpLowering
1903 : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1904 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1905
1906 LogicalResult
1907 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1908 ConversionPatternRewriter &rewriter) const override {
1909 VectorType resultType = deinterleaveOp.getResultVectorType();
1910 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1911 auto loc = deinterleaveOp.getLoc();
1912
1913 // Note: n-D deinterleave operations should be lowered to the 1-D before
1914 // converting to LLVM.
1915 if (resultType.getRank() != 1)
1916 return rewriter.notifyMatchFailure(arg&: deinterleaveOp,
1917 msg: "DeinterleaveOp not rank 1");
1918
1919 if (resultType.isScalable()) {
1920 auto llvmTypeConverter = this->getTypeConverter();
1921 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1922 auto packedOpResults =
1923 llvmTypeConverter->packOperationResults(types: deinterleaveResults);
1924 auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
1925 location: loc, args&: packedOpResults, args: adaptor.getSource());
1926
1927 auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
1928 location: loc, args: intrinsic->getResult(idx: 0), args: 0);
1929 auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
1930 location: loc, args: intrinsic->getResult(idx: 0), args: 1);
1931
1932 rewriter.replaceOp(op: deinterleaveOp, newValues: ValueRange{evenResult, oddResult});
1933 return success();
1934 }
1935 // Lower fixed-size deinterleave to two shufflevectors. While the
1936 // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1937 // langref still recommends fixed-vectors use shufflevector, see:
1938 // https://llvm.org/docs/LangRef.html#id889.
1939 int64_t resultVectorSize = resultType.getNumElements();
1940 SmallVector<int32_t> evenShuffleMask;
1941 SmallVector<int32_t> oddShuffleMask;
1942
1943 evenShuffleMask.reserve(N: resultVectorSize);
1944 oddShuffleMask.reserve(N: resultVectorSize);
1945
1946 for (int i = 0; i < sourceType.getNumElements(); ++i) {
1947 if (i % 2 == 0)
1948 evenShuffleMask.push_back(Elt: i);
1949 else
1950 oddShuffleMask.push_back(Elt: i);
1951 }
1952
1953 auto poison = rewriter.create<LLVM::PoisonOp>(location: loc, args&: sourceType);
1954 auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1955 location: loc, args: adaptor.getSource(), args&: poison, args&: evenShuffleMask);
1956 auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1957 location: loc, args: adaptor.getSource(), args&: poison, args&: oddShuffleMask);
1958
1959 rewriter.replaceOp(op: deinterleaveOp, newValues: ValueRange{evenShuffle, oddShuffle});
1960 return success();
1961 }
1962};
1963
1964/// Conversion pattern for a `vector.from_elements`.
1965struct VectorFromElementsLowering
1966 : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1967 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1968
1969 LogicalResult
1970 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1971 ConversionPatternRewriter &rewriter) const override {
1972 Location loc = fromElementsOp.getLoc();
1973 VectorType vectorType = fromElementsOp.getType();
1974 // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1975 // Such ops should be handled in the same way as vector.insert.
1976 if (vectorType.getRank() > 1)
1977 return rewriter.notifyMatchFailure(arg&: fromElementsOp,
1978 msg: "rank > 1 vectors are not supported");
1979 Type llvmType = typeConverter->convertType(t: vectorType);
1980 Value result = rewriter.create<LLVM::PoisonOp>(location: loc, args&: llvmType);
1981 for (auto [idx, val] : llvm::enumerate(First: adaptor.getElements()))
1982 result = rewriter.create<vector::InsertOp>(location: loc, args&: val, args&: result, args&: idx);
1983 rewriter.replaceOp(op: fromElementsOp, newValues: result);
1984 return success();
1985 }
1986};
1987
1988/// Conversion pattern for a `vector.to_elements`.
1989struct VectorToElementsLowering
1990 : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
1991 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1992
1993 LogicalResult
1994 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1995 ConversionPatternRewriter &rewriter) const override {
1996 Location loc = toElementsOp.getLoc();
1997 auto idxType = typeConverter->convertType(t: rewriter.getIndexType());
1998 Value source = adaptor.getSource();
1999
2000 SmallVector<Value> results(toElementsOp->getNumResults());
2001 for (auto [idx, element] : llvm::enumerate(First: toElementsOp.getElements())) {
2002 // Create an extractelement operation only for results that are not dead.
2003 if (element.use_empty())
2004 continue;
2005
2006 auto constIdx = rewriter.create<LLVM::ConstantOp>(
2007 location: loc, args&: idxType, args: rewriter.getIntegerAttr(type: idxType, value: idx));
2008 auto llvmType = typeConverter->convertType(t: element.getType());
2009
2010 Value result = rewriter.create<LLVM::ExtractElementOp>(location: loc, args&: llvmType,
2011 args&: source, args&: constIdx);
2012 results[idx] = result;
2013 }
2014
2015 rewriter.replaceOp(op: toElementsOp, newValues: results);
2016 return success();
2017 }
2018};
2019
2020/// Conversion pattern for vector.step.
2021struct VectorScalableStepOpLowering
2022 : public ConvertOpToLLVMPattern<vector::StepOp> {
2023 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2024
2025 LogicalResult
2026 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
2027 ConversionPatternRewriter &rewriter) const override {
2028 auto resultType = cast<VectorType>(Val: stepOp.getType());
2029 if (!resultType.isScalable()) {
2030 return failure();
2031 }
2032 Type llvmType = typeConverter->convertType(t: stepOp.getType());
2033 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(op: stepOp, args&: llvmType);
2034 return success();
2035 }
2036};
2037
2038} // namespace
2039
2040void mlir::vector::populateVectorRankReducingFMAPattern(
2041 RewritePatternSet &patterns) {
2042 patterns.add<VectorFMAOpNDRewritePattern>(arg: patterns.getContext());
2043}
2044
/// Populate the given list with patterns that convert from Vector to LLVM.
/// `reassociateFPReductions` allows FP reduction reassociation,
/// `force32BitVectorIndices` uses i32 lane indices in create_mask lowering,
/// and `useVectorAlignment` is forwarded to the memory-op patterns.
void mlir::populateVectorToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices,
    bool useVectorAlignment) {
  // This function populates only ConversionPatterns, not RewritePatterns.
  MLIRContext *ctx = converter.getDialect()->getContext();
  // Patterns that take an extra configuration flag.
  patterns.add<VectorReductionOpConversion>(arg: converter, args&: reassociateFPReductions);
  patterns.add<VectorCreateMaskOpConversion>(arg&: ctx, args&: force32BitVectorIndices);
  // Memory access patterns honor the requested alignment strategy.
  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
               VectorLoadStoreConversion<vector::MaskedLoadOp>,
               VectorLoadStoreConversion<vector::StoreOp>,
               VectorLoadStoreConversion<vector::MaskedStoreOp>,
               VectorGatherOpConversion, VectorScatterOpConversion>(
      arg: converter, args&: useVectorAlignment);
  // Remaining patterns only need the type converter.
  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
               VectorExtractElementOpConversion, VectorExtractOpConversion,
               VectorFMAOp1DConversion, VectorInsertElementOpConversion,
               VectorInsertOpConversion, VectorPrintOpConversion,
               VectorTypeCastOpConversion, VectorScaleOpConversion,
               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
               VectorSplatOpLowering, VectorSplatNdOpLowering,
               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
               MaskedReductionOpConversion, VectorInterleaveOpLowering,
               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
               VectorToElementsLowering, VectorScalableStepOpLowering>(
      arg: converter);
}
2073
2074void mlir::populateVectorToLLVMMatrixConversionPatterns(
2075 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
2076 patterns.add<VectorMatmulOpConversion>(arg: converter);
2077 patterns.add<VectorFlatTransposeOpConversion>(arg: converter);
2078}
2079
2080namespace {
/// Implements the generic convert-to-LLVM interface for the Vector dialect,
/// so `--convert-to-llvm` picks up the vector lowering patterns.
struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  // The lowering produces LLVM-dialect ops, so that dialect must be loaded.
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  // Note: uses the default values for the boolean configuration flags of
  // populateVectorToLLVMConversionPatterns.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateVectorToLLVMConversionPatterns(converter: typeConverter, patterns);
  }
};
2095} // namespace
2096
2097void mlir::vector::registerConvertVectorToLLVMInterface(
2098 DialectRegistry &registry) {
2099 registry.addExtension(extensionFn: +[](MLIRContext *ctx, vector::VectorDialect *dialect) {
2100 dialect->addInterfaces<VectorToLLVMDialectInterface>();
2101 });
2102}
2103

source code of mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp