1//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
13#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Arith/Utils/Utils.h"
17#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
22#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
23#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
24#include "mlir/IR/BuiltinAttributes.h"
25#include "mlir/IR/BuiltinTypeInterfaces.h"
26#include "mlir/IR/BuiltinTypes.h"
27#include "mlir/IR/TypeUtilities.h"
28#include "mlir/Target/LLVMIR/TypeToLLVM.h"
29#include "mlir/Transforms/DialectConversion.h"
30#include "llvm/ADT/APFloat.h"
31#include "llvm/Support/Casting.h"
32#include <optional>
33
34using namespace mlir;
35using namespace mlir::vector;
36
37// Helper to reduce vector type by *all* but one rank at back.
38static VectorType reducedVectorTypeBack(VectorType tp) {
39 assert((tp.getRank() > 1) && "unlowerable vector type");
40 return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
41 tp.getScalableDims().take_back());
42}
43
44// Helper that picks the proper sequence for inserting.
45static Value insertOne(ConversionPatternRewriter &rewriter,
46 const LLVMTypeConverter &typeConverter, Location loc,
47 Value val1, Value val2, Type llvmType, int64_t rank,
48 int64_t pos) {
49 assert(rank > 0 && "0-D vector corner case should have been handled already");
50 if (rank == 1) {
51 auto idxType = rewriter.getIndexType();
52 auto constant = rewriter.create<LLVM::ConstantOp>(
53 loc, typeConverter.convertType(idxType),
54 rewriter.getIntegerAttr(idxType, pos));
55 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
56 constant);
57 }
58 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
59}
60
61// Helper that picks the proper sequence for extracting.
62static Value extractOne(ConversionPatternRewriter &rewriter,
63 const LLVMTypeConverter &typeConverter, Location loc,
64 Value val, Type llvmType, int64_t rank, int64_t pos) {
65 if (rank <= 1) {
66 auto idxType = rewriter.getIndexType();
67 auto constant = rewriter.create<LLVM::ConstantOp>(
68 loc, typeConverter.convertType(idxType),
69 rewriter.getIntegerAttr(idxType, pos));
70 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
71 constant);
72 }
73 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
74}
75
76// Helper that returns data layout alignment of a memref.
77LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
78 MemRefType memrefType, unsigned &align) {
79 Type elementTy = typeConverter.convertType(memrefType.getElementType());
80 if (!elementTy)
81 return failure();
82
83 // TODO: this should use the MLIR data layout when it becomes available and
84 // stop depending on translation.
85 llvm::LLVMContext llvmContext;
86 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
88 return success();
89}
90
// Check that the memref's last dimension has unit stride and that its memory
// space is one the type converter can handle.
92static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
93 const LLVMTypeConverter &converter) {
94 if (!isLastMemrefDimUnitStride(memRefType))
95 return failure();
  if (failed(converter.getMemRefAddressSpace(memRefType)))
97 return failure();
98 return success();
99}
100
101// Add an index vector component to a base pointer.
102static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
103 const LLVMTypeConverter &typeConverter,
104 MemRefType memRefType, Value llvmMemref, Value base,
105 Value index, uint64_t vLen) {
106 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
107 "unsupported memref type");
108 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
110 return rewriter.create<LLVM::GEPOp>(
111 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
112 base, index);
113}
114
115/// Convert `foldResult` into a Value. Integer attribute is converted to
116/// an LLVM constant op.
117static Value getAsLLVMValue(OpBuilder &builder, Location loc,
118 OpFoldResult foldResult) {
119 if (auto attr = foldResult.dyn_cast<Attribute>()) {
120 auto intAttr = cast<IntegerAttr>(attr);
121 return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
122 }
123
124 return foldResult.get<Value>();
125}
126
127namespace {
128
129/// Trivial Vector to LLVM conversions
130using VectorScaleOpConversion =
131 OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
132
133/// Conversion pattern for a vector.bitcast.
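/// Example (an illustrative sketch; SSA names and shapes are arbitrary):
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<4xi32>
/// ```
/// is converted to:
/// ```
///   %1 = llvm.bitcast %0 : vector<4xf32> to vector<4xi32>
/// ```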
134class VectorBitCastOpConversion
135 : public ConvertOpToLLVMPattern<vector::BitCastOp> {
136public:
137 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
138
139 LogicalResult
140 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
141 ConversionPatternRewriter &rewriter) const override {
142 // Only 0-D and 1-D vectors can be lowered to LLVM.
143 VectorType resultTy = bitCastOp.getResultVectorType();
144 if (resultTy.getRank() > 1)
145 return failure();
146 Type newResultTy = typeConverter->convertType(resultTy);
147 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
148 adaptor.getOperands()[0]);
149 return success();
150 }
151};
152
153/// Conversion pattern for a vector.matrix_multiply.
154/// This is lowered directly to the proper llvm.intr.matrix.multiply.
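/// Example (an illustrative sketch of a 2x3 * 3x4 multiply on row-major
/// flattened operands; SSA names are arbitrary):
/// ```
///   %c = vector.matrix_multiply %a, %b
///     {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///     : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```
/// is converted to:
/// ```
///   %c = llvm.intr.matrix.multiply %a, %b
///     {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///     : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```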
155class VectorMatmulOpConversion
156 : public ConvertOpToLLVMPattern<vector::MatmulOp> {
157public:
158 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
159
160 LogicalResult
161 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
162 ConversionPatternRewriter &rewriter) const override {
163 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
164 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
165 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
166 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
167 return success();
168 }
169};
170
171/// Conversion pattern for a vector.flat_transpose.
172/// This is lowered directly to the proper llvm.intr.matrix.transpose.
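/// Example (an illustrative sketch of a flattened 2x4 row-major matrix):
/// ```
///   %t = vector.flat_transpose %m {rows = 2 : i32, columns = 4 : i32}
///     : vector<8xf32> -> vector<8xf32>
/// ```
/// is converted to:
/// ```
///   %t = llvm.intr.matrix.transpose %m {rows = 2 : i32, columns = 4 : i32}
///     : vector<8xf32> into vector<8xf32>
/// ```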
173class VectorFlatTransposeOpConversion
174 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
175public:
176 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
177
178 LogicalResult
179 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
180 ConversionPatternRewriter &rewriter) const override {
181 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
182 transOp, typeConverter->convertType(transOp.getRes().getType()),
183 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
184 return success();
185 }
186};
187
188/// Overloaded utility that replaces a vector.load, vector.store,
189/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
191static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
192 vector::LoadOpAdaptor adaptor,
193 VectorType vectorTy, Value ptr, unsigned align,
194 ConversionPatternRewriter &rewriter) {
195 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
196 /*volatile_=*/false,
197 loadOp.getNontemporal());
198}
199
200static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
201 vector::MaskedLoadOpAdaptor adaptor,
202 VectorType vectorTy, Value ptr, unsigned align,
203 ConversionPatternRewriter &rewriter) {
204 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
205 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
206}
207
208static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
209 vector::StoreOpAdaptor adaptor,
210 VectorType vectorTy, Value ptr, unsigned align,
211 ConversionPatternRewriter &rewriter) {
212 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
213 ptr, align, /*volatile_=*/false,
214 storeOp.getNontemporal());
215}
216
217static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
218 vector::MaskedStoreOpAdaptor adaptor,
219 VectorType vectorTy, Value ptr, unsigned align,
220 ConversionPatternRewriter &rewriter) {
221 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
222 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
223}
224
225/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
226/// vector.maskedstore.
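/// Example (an illustrative sketch for the unmasked load case; the alignment
/// comes from the data layout of the element type):
/// ```
///   %0 = vector.load %base[%i] : memref<100xf32>, vector<8xf32>
/// ```
/// is converted to a strided element address computation followed by:
/// ```
///   %0 = llvm.load %ptr {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>
/// ```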
227template <class LoadOrStoreOp>
228class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
229public:
230 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
231
232 LogicalResult
233 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
234 typename LoadOrStoreOp::Adaptor adaptor,
235 ConversionPatternRewriter &rewriter) const override {
236 // Only 1-D vectors can be lowered to LLVM.
237 VectorType vectorTy = loadOrStoreOp.getVectorType();
238 if (vectorTy.getRank() > 1)
239 return failure();
240
241 auto loc = loadOrStoreOp->getLoc();
242 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
243
244 // Resolve alignment.
245 unsigned align;
246 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
247 return failure();
248
249 // Resolve address.
250 auto vtype = cast<VectorType>(
251 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
252 Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
253 adaptor.getIndices(), rewriter);
254 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
255 rewriter);
256 return success();
257 }
258};
259
260/// Conversion pattern for a vector.gather.
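/// Example (an illustrative sketch of the 1-D case; SSA names are arbitrary):
/// ```
///   %0 = vector.gather %base[%c0][%idxs], %mask, %pass
///     : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
///       into vector<4xf32>
/// ```
/// is converted to a vector-of-pointers `llvm.getelementptr` followed by:
/// ```
///   %0 = llvm.intr.masked.gather %ptrs, %mask, %pass {alignment = 4 : i32}
///     : (vector<4x!llvm.ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
/// ```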
261class VectorGatherOpConversion
262 : public ConvertOpToLLVMPattern<vector::GatherOp> {
263public:
264 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
265
266 LogicalResult
267 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
268 ConversionPatternRewriter &rewriter) const override {
269 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
270 assert(memRefType && "The base should be bufferized");
271
272 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
273 return failure();
274
275 auto loc = gather->getLoc();
276
277 // Resolve alignment.
278 unsigned align;
279 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
280 return failure();
281
282 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
283 adaptor.getIndices(), rewriter);
284 Value base = adaptor.getBase();
285
286 auto llvmNDVectorTy = adaptor.getIndexVec().getType();
287 // Handle the simple case of 1-D vector.
288 if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
289 auto vType = gather.getVectorType();
290 // Resolve address.
291 Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
292 memRefType, base, ptr, adaptor.getIndexVec(),
293 /*vLen=*/vType.getDimSize(0));
294 // Replace with the gather intrinsic.
295 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
296 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
297 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
298 return success();
299 }
300
301 const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
302 auto callback = [align, memRefType, base, ptr, loc, &rewriter,
303 &typeConverter](Type llvm1DVectorTy,
304 ValueRange vectorOperands) {
305 // Resolve address.
306 Value ptrs = getIndexedPtrs(
307 rewriter, loc, typeConverter, memRefType, base, ptr,
308 /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
310 // Create the gather intrinsic.
311 return rewriter.create<LLVM::masked_gather>(
312 loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
313 /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
314 };
315 SmallVector<Value> vectorOperands = {
316 adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
317 return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
319 }
320};
321
322/// Conversion pattern for a vector.scatter.
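/// Example (an illustrative sketch of the 1-D case):
/// ```
///   vector.scatter %base[%c0][%idxs], %mask, %val
///     : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
/// ```
/// is converted to a vector-of-pointers `llvm.getelementptr` followed by:
/// ```
///   llvm.intr.masked.scatter %val, %ptrs, %mask {alignment = 4 : i32}
///     : vector<4xf32>, vector<4xi1> into vector<4x!llvm.ptr>
/// ```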
323class VectorScatterOpConversion
324 : public ConvertOpToLLVMPattern<vector::ScatterOp> {
325public:
326 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
327
328 LogicalResult
329 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
330 ConversionPatternRewriter &rewriter) const override {
331 auto loc = scatter->getLoc();
332 MemRefType memRefType = scatter.getMemRefType();
333
334 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
335 return failure();
336
337 // Resolve alignment.
338 unsigned align;
339 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
340 return failure();
341
342 // Resolve address.
343 VectorType vType = scatter.getVectorType();
344 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
345 adaptor.getIndices(), rewriter);
346 Value ptrs = getIndexedPtrs(
347 rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
348 ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
349
350 // Replace with the scatter intrinsic.
351 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
352 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
353 rewriter.getI32IntegerAttr(align));
354 return success();
355 }
356};
357
358/// Conversion pattern for a vector.expandload.
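/// Example (an illustrative sketch):
/// ```
///   %0 = vector.expandload %base[%i], %mask, %pass
///     : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```
/// which is lowered to the `llvm.intr.masked.expandload` intrinsic applied to
/// the strided element pointer computed from `%base[%i]`.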
359class VectorExpandLoadOpConversion
360 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
361public:
362 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
363
364 LogicalResult
365 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
366 ConversionPatternRewriter &rewriter) const override {
367 auto loc = expand->getLoc();
368 MemRefType memRefType = expand.getMemRefType();
369
370 // Resolve address.
371 auto vtype = typeConverter->convertType(expand.getVectorType());
372 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
373 adaptor.getIndices(), rewriter);
374
375 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
376 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
377 return success();
378 }
379};
380
381/// Conversion pattern for a vector.compressstore.
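/// Example (an illustrative sketch):
/// ```
///   vector.compressstore %base[%i], %mask, %val
///     : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// ```
/// which is lowered to the `llvm.intr.masked.compressstore` intrinsic applied
/// to the strided element pointer computed from `%base[%i]`.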
382class VectorCompressStoreOpConversion
383 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
384public:
385 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
386
387 LogicalResult
388 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
389 ConversionPatternRewriter &rewriter) const override {
390 auto loc = compress->getLoc();
391 MemRefType memRefType = compress.getMemRefType();
392
393 // Resolve address.
394 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
395 adaptor.getIndices(), rewriter);
396
397 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
398 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
399 return success();
400 }
401};
402
403/// Reduction neutral classes for overloading.
404class ReductionNeutralZero {};
405class ReductionNeutralIntOne {};
406class ReductionNeutralFPOne {};
407class ReductionNeutralAllOnes {};
408class ReductionNeutralSIntMin {};
409class ReductionNeutralUIntMin {};
410class ReductionNeutralSIntMax {};
411class ReductionNeutralUIntMax {};
412class ReductionNeutralFPMin {};
413class ReductionNeutralFPMax {};
414
415/// Create the reduction neutral zero value.
416static Value createReductionNeutralValue(ReductionNeutralZero neutral,
417 ConversionPatternRewriter &rewriter,
418 Location loc, Type llvmType) {
419 return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
420 rewriter.getZeroAttr(llvmType));
421}
422
423/// Create the reduction neutral integer one value.
424static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
425 ConversionPatternRewriter &rewriter,
426 Location loc, Type llvmType) {
427 return rewriter.create<LLVM::ConstantOp>(
428 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
429}
430
431/// Create the reduction neutral fp one value.
432static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
433 ConversionPatternRewriter &rewriter,
434 Location loc, Type llvmType) {
435 return rewriter.create<LLVM::ConstantOp>(
436 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
437}
438
439/// Create the reduction neutral all-ones value.
440static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
441 ConversionPatternRewriter &rewriter,
442 Location loc, Type llvmType) {
443 return rewriter.create<LLVM::ConstantOp>(
444 loc, llvmType,
445 rewriter.getIntegerAttr(
446 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
447}
448
449/// Create the reduction neutral signed int minimum value.
450static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
451 ConversionPatternRewriter &rewriter,
452 Location loc, Type llvmType) {
453 return rewriter.create<LLVM::ConstantOp>(
454 loc, llvmType,
455 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
456 llvmType.getIntOrFloatBitWidth())));
457}
458
459/// Create the reduction neutral unsigned int minimum value.
460static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
461 ConversionPatternRewriter &rewriter,
462 Location loc, Type llvmType) {
463 return rewriter.create<LLVM::ConstantOp>(
464 loc, llvmType,
465 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
466 llvmType.getIntOrFloatBitWidth())));
467}
468
469/// Create the reduction neutral signed int maximum value.
470static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
471 ConversionPatternRewriter &rewriter,
472 Location loc, Type llvmType) {
473 return rewriter.create<LLVM::ConstantOp>(
474 loc, llvmType,
475 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
476 llvmType.getIntOrFloatBitWidth())));
477}
478
479/// Create the reduction neutral unsigned int maximum value.
480static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
481 ConversionPatternRewriter &rewriter,
482 Location loc, Type llvmType) {
483 return rewriter.create<LLVM::ConstantOp>(
484 loc, llvmType,
485 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
486 llvmType.getIntOrFloatBitWidth())));
487}
488
489/// Create the reduction neutral fp minimum value.
490static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
491 ConversionPatternRewriter &rewriter,
492 Location loc, Type llvmType) {
  auto floatType = cast<FloatType>(llvmType);
494 return rewriter.create<LLVM::ConstantOp>(
495 loc, llvmType,
496 rewriter.getFloatAttr(
497 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
498 /*Negative=*/false)));
499}
500
501/// Create the reduction neutral fp maximum value.
502static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
503 ConversionPatternRewriter &rewriter,
504 Location loc, Type llvmType) {
  auto floatType = cast<FloatType>(llvmType);
506 return rewriter.create<LLVM::ConstantOp>(
507 loc, llvmType,
508 rewriter.getFloatAttr(
509 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
510 /*Negative=*/true)));
511}
512
513/// Returns `accumulator` if it has a valid value. Otherwise, creates and
514/// returns a new accumulator value using `ReductionNeutral`.
515template <class ReductionNeutral>
516static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
517 Location loc, Type llvmType,
518 Value accumulator) {
519 if (accumulator)
520 return accumulator;
521
522 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
523 llvmType);
524}
525
526/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
527/// This is used as effective vector length by some intrinsics supporting
528/// dynamic vector lengths at runtime.
529static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
530 Location loc, Type llvmType) {
531 VectorType vType = cast<VectorType>(llvmType);
532 auto vShape = vType.getShape();
533 assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
534
535 return rewriter.create<LLVM::ConstantOp>(
536 loc, rewriter.getI32Type(),
537 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
538}
539
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector
/// reduction intrinsic to use and `ScalarOp` is the scalar operation used to
/// combine the reduction result with the accumulator, when one is provided.
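/// For example (an illustrative sketch; SSA names are arbitrary), an `add`
/// reduction of a `vector<8xi32>` combined with an accumulator becomes:
/// ```
///   %r = llvm.intr.vector.reduce.add(%v) : (vector<8xi32>) -> i32
///   %res = llvm.add %acc, %r : i32
/// ```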
544template <class LLVMRedIntrinOp, class ScalarOp>
545static Value createIntegerReductionArithmeticOpLowering(
546 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
547 Value vectorOperand, Value accumulator) {
548
549 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
550
551 if (accumulator)
552 result = rewriter.create<ScalarOp>(loc, accumulator, result);
553 return result;
554}
555
/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
/// vector reduction intrinsic to use and `predicate` is the predicate used
/// to compare and combine the result with the accumulator, when one is
/// provided.
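/// For example (an illustrative sketch), a `minui` reduction combined with an
/// accumulator becomes:
/// ```
///   %r = llvm.intr.vector.reduce.umin(%v) : (vector<8xi32>) -> i32
///   %c = llvm.icmp "ule" %acc, %r : i32
///   %res = llvm.select %c, %acc, %r : i1, i32
/// ```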
560template <class LLVMRedIntrinOp>
561static Value createIntegerReductionComparisonOpLowering(
562 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
563 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
564 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
565 if (accumulator) {
566 Value cmp =
567 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
568 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
569 }
570 return result;
571}
572
573namespace {
574template <typename Source>
575struct VectorToScalarMapper;
576template <>
577struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
578 using Type = LLVM::MaximumOp;
579};
580template <>
581struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
582 using Type = LLVM::MinimumOp;
583};
584template <>
585struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
586 using Type = LLVM::MaxNumOp;
587};
588template <>
589struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
590 using Type = LLVM::MinNumOp;
591};
592} // namespace
593
594template <class LLVMRedIntrinOp>
595static Value createFPReductionComparisonOpLowering(
596 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
597 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
598 Value result =
599 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
600
601 if (accumulator) {
602 result =
603 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
604 loc, result, accumulator);
605 }
606
607 return result;
608}
609
/// Mask neutral classes for overloading.
611class MaskNeutralFMaximum {};
612class MaskNeutralFMinimum {};
613
614/// Get the mask neutral floating point maximum value
615static llvm::APFloat
616getMaskNeutralValue(MaskNeutralFMaximum,
617 const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
619}
620/// Get the mask neutral floating point minimum value
621static llvm::APFloat
622getMaskNeutralValue(MaskNeutralFMinimum,
623 const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
625}
626
627/// Create the mask neutral floating point MLIR vector constant
628template <typename MaskNeutral>
629static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
630 Location loc, Type llvmType,
631 Type vectorType) {
  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
633 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
634 auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
635 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
636}
637
638/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
639/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
640/// `fmaximum`/`fminimum`.
641/// More information: https://github.com/llvm/llvm-project/issues/64940
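/// For example (an illustrative sketch), a masked `fmaximum` reduction is
/// emitted as a select of the mask-neutral value followed by the regular
/// intrinsic:
/// ```
///   %sel = llvm.select %mask, %v, %neutral : vector<8xi1>, vector<8xf32>
///   %r = llvm.intr.vector.reduce.fmaximum(%sel) : (vector<8xf32>) -> f32
/// ```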
642template <class LLVMRedIntrinOp, class MaskNeutral>
643static Value
644lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
645 Location loc, Type llvmType,
646 Value vectorOperand, Value accumulator,
647 Value mask, LLVM::FastmathFlagsAttr fmf) {
648 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
649 rewriter, loc, llvmType, vectorOperand.getType());
650 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
651 loc, mask, vectorOperand, vectorMaskNeutral);
652 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
653 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
654}
655
656template <class LLVMRedIntrinOp, class ReductionNeutral>
657static Value
658lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
659 Type llvmType, Value vectorOperand,
660 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
661 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
662 llvmType, accumulator);
663 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
664 /*startValue=*/accumulator,
665 vectorOperand, fmf);
666}
667
/// Overloaded methods to lower a *predicated* reduction to an LLVM intrinsic
/// that requires a start value. This start-value form is shared by the
/// unmasked fp reduction intrinsics and all of the masked reduction
/// intrinsics.
671template <class LLVMVPRedIntrinOp, class ReductionNeutral>
672static Value
673lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
674 Location loc, Type llvmType,
675 Value vectorOperand, Value accumulator) {
676 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
677 llvmType, accumulator);
678 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
679 /*startValue=*/accumulator,
680 vectorOperand);
681}
682
683template <class LLVMVPRedIntrinOp, class ReductionNeutral>
684static Value lowerPredicatedReductionWithStartValue(
685 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
686 Value vectorOperand, Value accumulator, Value mask) {
687 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
688 llvmType, accumulator);
689 Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
691 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
692 /*startValue=*/accumulator,
693 vectorOperand, mask, vectorLength);
694}
695
696template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
697 class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
698static Value lowerPredicatedReductionWithStartValue(
699 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
700 Value vectorOperand, Value accumulator, Value mask) {
701 if (llvmType.isIntOrIndex())
702 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
703 IntReductionNeutral>(
704 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
705
706 // FP dispatch.
707 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
708 FPReductionNeutral>(
709 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
710}
711
712/// Conversion pattern for all vector reductions.
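/// Example (an illustrative sketch):
/// ```
///   %0 = vector.reduction <add>, %v, %acc : vector<16xf32> into f32
/// ```
/// is lowered to `llvm.intr.vector.reduce.fadd` with `%acc` as the start value
/// (plus the `reassoc` fastmath flag when `reassociateFPReductions` is set).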
713class VectorReductionOpConversion
714 : public ConvertOpToLLVMPattern<vector::ReductionOp> {
715public:
716 explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
717 bool reassociateFPRed)
718 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
719 reassociateFPReductions(reassociateFPRed) {}
720
721 LogicalResult
722 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
723 ConversionPatternRewriter &rewriter) const override {
724 auto kind = reductionOp.getKind();
725 Type eltType = reductionOp.getDest().getType();
726 Type llvmType = typeConverter->convertType(eltType);
727 Value operand = adaptor.getVector();
728 Value acc = adaptor.getAcc();
729 Location loc = reductionOp.getLoc();
730
731 if (eltType.isIntOrIndex()) {
732 // Integer reductions: add/mul/min/max/and/or/xor.
733 Value result;
734 switch (kind) {
735 case vector::CombiningKind::ADD:
736 result =
737 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
738 LLVM::AddOp>(
739 rewriter, loc, llvmType, operand, acc);
740 break;
741 case vector::CombiningKind::MUL:
742 result =
743 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
744 LLVM::MulOp>(
745 rewriter, loc, llvmType, operand, acc);
746 break;
747 case vector::CombiningKind::MINUI:
748 result = createIntegerReductionComparisonOpLowering<
749 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
750 LLVM::ICmpPredicate::ule);
751 break;
752 case vector::CombiningKind::MINSI:
753 result = createIntegerReductionComparisonOpLowering<
754 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
755 LLVM::ICmpPredicate::sle);
756 break;
757 case vector::CombiningKind::MAXUI:
758 result = createIntegerReductionComparisonOpLowering<
759 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
760 LLVM::ICmpPredicate::uge);
761 break;
762 case vector::CombiningKind::MAXSI:
763 result = createIntegerReductionComparisonOpLowering<
764 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
765 LLVM::ICmpPredicate::sge);
766 break;
767 case vector::CombiningKind::AND:
768 result =
769 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
770 LLVM::AndOp>(
771 rewriter, loc, llvmType, operand, acc);
772 break;
773 case vector::CombiningKind::OR:
774 result =
775 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
776 LLVM::OrOp>(
777 rewriter, loc, llvmType, operand, acc);
778 break;
779 case vector::CombiningKind::XOR:
780 result =
781 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
782 LLVM::XOrOp>(
783 rewriter, loc, llvmType, operand, acc);
784 break;
785 default:
786 return failure();
787 }
788 rewriter.replaceOp(reductionOp, result);
789
790 return success();
791 }
792
    if (!isa<FloatType>(eltType))
794 return failure();
795
796 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
797 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
798 reductionOp.getContext(),
799 convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
800 fmf = LLVM::FastmathFlagsAttr::get(
801 reductionOp.getContext(),
802 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
803 : LLVM::FastmathFlags::none));
804
805 // Floating-point reductions: add/mul/min/max
806 Value result;
807 if (kind == vector::CombiningKind::ADD) {
808 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
809 ReductionNeutralZero>(
810 rewriter, loc, llvmType, operand, acc, fmf);
811 } else if (kind == vector::CombiningKind::MUL) {
812 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
813 ReductionNeutralFPOne>(
814 rewriter, loc, llvmType, operand, acc, fmf);
815 } else if (kind == vector::CombiningKind::MINIMUMF) {
816 result =
817 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
818 rewriter, loc, llvmType, operand, acc, fmf);
819 } else if (kind == vector::CombiningKind::MAXIMUMF) {
820 result =
821 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
822 rewriter, loc, llvmType, operand, acc, fmf);
823 } else if (kind == vector::CombiningKind::MINNUMF) {
824 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
825 rewriter, loc, llvmType, operand, acc, fmf);
826 } else if (kind == vector::CombiningKind::MAXNUMF) {
827 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
828 rewriter, loc, llvmType, operand, acc, fmf);
829 } else
830 return failure();
831
832 rewriter.replaceOp(reductionOp, result);
833 return success();
834 }
835
836private:
837 const bool reassociateFPReductions;
838};
839
840/// Base class to convert a `vector.mask` operation while matching traits
841/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
842/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
843/// method performs a second match against the maskable operation `MaskedOp`.
844/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
845/// implemented by the concrete conversion classes. This method can match
846/// against specific traits of the `vector.mask` and the maskable operation. It
847/// must replace the `vector.mask` operation.
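/// Example of the kind of operation that is matched (an illustrative sketch):
/// ```
///   %0 = vector.mask %mask {
///     vector.reduction <add>, %v : vector<8xf32> into f32
///   } : vector<8xi1> -> f32
/// ```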
848template <class MaskedOp>
849class VectorMaskOpConversionBase
850 : public ConvertOpToLLVMPattern<vector::MaskOp> {
851public:
852 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
853
854 LogicalResult
855 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
856 ConversionPatternRewriter &rewriter) const final {
857 // Match against the maskable operation kind.
858 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
859 if (!maskedOp)
860 return failure();
861 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
862 }
863
864protected:
865 virtual LogicalResult
866 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
867 vector::MaskableOpInterface maskableOp,
868 ConversionPatternRewriter &rewriter) const = 0;
869};
870
871class MaskedReductionOpConversion
872 : public VectorMaskOpConversionBase<vector::ReductionOp> {
873
874public:
875 using VectorMaskOpConversionBase<
876 vector::ReductionOp>::VectorMaskOpConversionBase;
877
878 LogicalResult matchAndRewriteMaskableOp(
879 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
880 ConversionPatternRewriter &rewriter) const override {
881 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
882 auto kind = reductionOp.getKind();
883 Type eltType = reductionOp.getDest().getType();
884 Type llvmType = typeConverter->convertType(eltType);
885 Value operand = reductionOp.getVector();
886 Value acc = reductionOp.getAcc();
887 Location loc = reductionOp.getLoc();
888
889 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
890 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
891 reductionOp.getContext(),
892 convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
893
894 Value result;
895 switch (kind) {
896 case vector::CombiningKind::ADD:
897 result = lowerPredicatedReductionWithStartValue<
898 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
899 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
900 maskOp.getMask());
901 break;
902 case vector::CombiningKind::MUL:
903 result = lowerPredicatedReductionWithStartValue<
904 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
905 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
906 maskOp.getMask());
907 break;
908 case vector::CombiningKind::MINUI:
909 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
910 ReductionNeutralUIntMax>(
911 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
912 break;
913 case vector::CombiningKind::MINSI:
914 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
915 ReductionNeutralSIntMax>(
916 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
917 break;
918 case vector::CombiningKind::MAXUI:
919 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
920 ReductionNeutralUIntMin>(
921 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
922 break;
923 case vector::CombiningKind::MAXSI:
924 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
925 ReductionNeutralSIntMin>(
926 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
927 break;
928 case vector::CombiningKind::AND:
929 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
930 ReductionNeutralAllOnes>(
931 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
932 break;
933 case vector::CombiningKind::OR:
934 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
935 ReductionNeutralZero>(
936 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
937 break;
938 case vector::CombiningKind::XOR:
939 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
940 ReductionNeutralZero>(
941 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
942 break;
943 case vector::CombiningKind::MINNUMF:
944 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
945 ReductionNeutralFPMax>(
946 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
947 break;
948 case vector::CombiningKind::MAXNUMF:
949 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
950 ReductionNeutralFPMin>(
951 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
952 break;
953 case CombiningKind::MAXIMUMF:
954 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
955 MaskNeutralFMaximum>(
956 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
957 break;
958 case CombiningKind::MINIMUMF:
959 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
960 MaskNeutralFMinimum>(
961 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
962 break;
963 }
964
965 // Replace `vector.mask` operation altogether.
966 rewriter.replaceOp(maskOp, result);
967 return success();
968 }
969};
970
971class VectorShuffleOpConversion
972 : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
973public:
974 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
975
976 LogicalResult
977 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
978 ConversionPatternRewriter &rewriter) const override {
979 auto loc = shuffleOp->getLoc();
980 auto v1Type = shuffleOp.getV1VectorType();
981 auto v2Type = shuffleOp.getV2VectorType();
982 auto vectorType = shuffleOp.getResultVectorType();
983 Type llvmType = typeConverter->convertType(vectorType);
984 auto maskArrayAttr = shuffleOp.getMask();
985
986 // Bail if result type cannot be lowered.
987 if (!llvmType)
988 return failure();
989
990 // Get rank and dimension sizes.
991 int64_t rank = vectorType.getRank();
992#ifndef NDEBUG
993 bool wellFormed0DCase =
994 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
995 bool wellFormedNDCase =
996 v1Type.getRank() == rank && v2Type.getRank() == rank;
997 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
998#endif
999
1000 // For rank 0 and 1, where both operands have *exactly* the same vector
1001 // type, there is direct shuffle support in LLVM. Use it!
1002 if (rank <= 1 && v1Type == v2Type) {
1003 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1004 loc, adaptor.getV1(), adaptor.getV2(),
1005 LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1006 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1007 return success();
1008 }
1009
    // For all other cases, extract the source elements and insert them into
    // the result one by one.
1011 int64_t v1Dim = v1Type.getDimSize(0);
1012 Type eltType;
1013 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1014 eltType = arrayType.getElementType();
1015 else
1016 eltType = cast<VectorType>(llvmType).getElementType();
1017 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1018 int64_t insPos = 0;
1019 for (const auto &en : llvm::enumerate(maskArrayAttr)) {
1020 int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1021 Value value = adaptor.getV1();
1022 if (extPos >= v1Dim) {
1023 extPos -= v1Dim;
1024 value = adaptor.getV2();
1025 }
1026 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1027 eltType, rank, extPos);
1028 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1029 llvmType, rank, insPos++);
1030 }
1031 rewriter.replaceOp(shuffleOp, insert);
1032 return success();
1033 }
1034};
1035
1036class VectorExtractElementOpConversion
1037 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1038public:
1039 using ConvertOpToLLVMPattern<
1040 vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1041
1042 LogicalResult
1043 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1044 ConversionPatternRewriter &rewriter) const override {
1045 auto vectorType = extractEltOp.getSourceVectorType();
1046 auto llvmType = typeConverter->convertType(vectorType.getElementType());
1047
1048 // Bail if result type cannot be lowered.
1049 if (!llvmType)
1050 return failure();
1051
1052 if (vectorType.getRank() == 0) {
1053 Location loc = extractEltOp.getLoc();
1054 auto idxType = rewriter.getIndexType();
1055 auto zero = rewriter.create<LLVM::ConstantOp>(
1056 loc, typeConverter->convertType(idxType),
1057 rewriter.getIntegerAttr(idxType, 0));
1058 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1059 extractEltOp, llvmType, adaptor.getVector(), zero);
1060 return success();
1061 }
1062
1063 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1064 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1065 return success();
1066 }
1067};
1068
1069class VectorExtractOpConversion
1070 : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1071public:
1072 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1073
1074 LogicalResult
1075 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1076 ConversionPatternRewriter &rewriter) const override {
1077 auto loc = extractOp->getLoc();
1078 auto resultType = extractOp.getResult().getType();
1079 auto llvmResultType = typeConverter->convertType(resultType);
1080 // Bail if result type cannot be lowered.
1081 if (!llvmResultType)
1082 return failure();
1083
1084 SmallVector<OpFoldResult> positionVec = getMixedValues(
1085 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1086
1087 // Extract entire vector. Should be handled by folder, but just to be safe.
1088 ArrayRef<OpFoldResult> position(positionVec);
1089 if (position.empty()) {
1090 rewriter.replaceOp(extractOp, adaptor.getVector());
1091 return success();
1092 }
1093
1094 // One-shot extraction of vector from array (only requires extractvalue).
1095 if (isa<VectorType>(resultType)) {
1096 if (extractOp.hasDynamicPosition())
1097 return failure();
1098
1099 Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1100 loc, adaptor.getVector(), getAsIntegers(position));
1101 rewriter.replaceOp(extractOp, extracted);
1102 return success();
1103 }
1104
1105 // Potential extraction of 1-D vector from array.
1106 Value extracted = adaptor.getVector();
1107 if (position.size() > 1) {
1108 if (extractOp.hasDynamicPosition())
1109 return failure();
1110
1111 SmallVector<int64_t> nMinusOnePosition =
          getAsIntegers(position.drop_back());
1113 extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1114 nMinusOnePosition);
1115 }
1116
1117 Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
1118 // Remaining extraction of element from 1-D LLVM vector.
1119 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
1120 lastPosition);
1121 return success();
1122 }
1123};
1124
1125/// Conversion pattern that turns a vector.fma on a 1-D vector
1126/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1127/// This does not match vectors of n >= 2 rank.
1128///
1129/// Example:
1130/// ```
1131/// vector.fma %a, %a, %a : vector<8xf32>
1132/// ```
1133/// is converted to:
1134/// ```
1135/// llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm<"<8 x f32>">
1138/// ```
1139class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1140public:
1141 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1142
1143 LogicalResult
1144 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1145 ConversionPatternRewriter &rewriter) const override {
1146 VectorType vType = fmaOp.getVectorType();
1147 if (vType.getRank() > 1)
1148 return failure();
1149
1150 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1151 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1152 return success();
1153 }
1154};
1155
1156class VectorInsertElementOpConversion
1157 : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1158public:
1159 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1160
1161 LogicalResult
1162 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1163 ConversionPatternRewriter &rewriter) const override {
1164 auto vectorType = insertEltOp.getDestVectorType();
1165 auto llvmType = typeConverter->convertType(vectorType);
1166
1167 // Bail if result type cannot be lowered.
1168 if (!llvmType)
1169 return failure();
1170
1171 if (vectorType.getRank() == 0) {
1172 Location loc = insertEltOp.getLoc();
1173 auto idxType = rewriter.getIndexType();
1174 auto zero = rewriter.create<LLVM::ConstantOp>(
1175 loc, typeConverter->convertType(idxType),
1176 rewriter.getIntegerAttr(idxType, 0));
1177 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1178 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1179 return success();
1180 }
1181
1182 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1183 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1184 adaptor.getPosition());
1185 return success();
1186 }
1187};
1188
1189class VectorInsertOpConversion
1190 : public ConvertOpToLLVMPattern<vector::InsertOp> {
1191public:
1192 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1193
1194 LogicalResult
1195 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1196 ConversionPatternRewriter &rewriter) const override {
1197 auto loc = insertOp->getLoc();
1198 auto sourceType = insertOp.getSourceType();
1199 auto destVectorType = insertOp.getDestVectorType();
1200 auto llvmResultType = typeConverter->convertType(destVectorType);
1201 // Bail if result type cannot be lowered.
1202 if (!llvmResultType)
1203 return failure();
1204
1205 SmallVector<OpFoldResult> positionVec = getMixedValues(
1206 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1207
1208 // Overwrite entire vector with value. Should be handled by folder, but
1209 // just to be safe.
1210 ArrayRef<OpFoldResult> position(positionVec);
1211 if (position.empty()) {
1212 rewriter.replaceOp(insertOp, adaptor.getSource());
1213 return success();
1214 }
1215
1216 // One-shot insertion of a vector into an array (only requires insertvalue).
1217 if (isa<VectorType>(sourceType)) {
1218 if (insertOp.hasDynamicPosition())
1219 return failure();
1220
1221 Value inserted = rewriter.create<LLVM::InsertValueOp>(
1222 loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1223 rewriter.replaceOp(insertOp, inserted);
1224 return success();
1225 }
1226
1227 // Potential extraction of 1-D vector from array.
1228 Value extracted = adaptor.getDest();
1229 auto oneDVectorType = destVectorType;
1230 if (position.size() > 1) {
1231 if (insertOp.hasDynamicPosition())
1232 return failure();
1233
1234 oneDVectorType = reducedVectorTypeBack(destVectorType);
1235 extracted = rewriter.create<LLVM::ExtractValueOp>(
1236 loc, extracted, getAsIntegers(position.drop_back()));
1237 }
1238
1239 // Insertion of an element into a 1-D LLVM vector.
1240 Value inserted = rewriter.create<LLVM::InsertElementOp>(
1241 loc, typeConverter->convertType(oneDVectorType), extracted,
1242 adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1243
1244 // Potential insertion of resulting 1-D vector into array.
1245 if (position.size() > 1) {
1246 if (insertOp.hasDynamicPosition())
1247 return failure();
1248
1249 inserted = rewriter.create<LLVM::InsertValueOp>(
1250 loc, adaptor.getDest(), inserted,
1251 getAsIntegers(position.drop_back()));
1252 }
1253
1254 rewriter.replaceOp(insertOp, inserted);
1255 return success();
1256 }
1257};
1258
1259/// Lower vector.scalable.insert ops to LLVM vector.insert
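/// Example (an illustrative sketch):
/// ```
///   %2 = vector.scalable.insert %sub, %dst[0]
///     : vector<4xf32> into vector<[8]xf32>
/// ```
/// which maps 1-1 onto the `llvm.intr.vector.insert` intrinsic.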
1260struct VectorScalableInsertOpLowering
1261 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1262 using ConvertOpToLLVMPattern<
1263 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1264
1265 LogicalResult
1266 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1267 ConversionPatternRewriter &rewriter) const override {
1268 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1269 insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1270 return success();
1271 }
1272};
1273
1274/// Lower vector.scalable.extract ops to LLVM vector.extract
1275struct VectorScalableExtractOpLowering
1276 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1277 using ConvertOpToLLVMPattern<
1278 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1279
1280 LogicalResult
1281 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1282 ConversionPatternRewriter &rewriter) const override {
1283 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1284 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1285 adaptor.getSource(), adaptor.getPos());
1286 return success();
1287 }
1288};
1289
1290/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1291///
1292/// Example:
1293/// ```
1294/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1295/// ```
1296/// is rewritten into:
1297/// ```
1298/// %r = splat %f0: vector<2x4xf32>
1299/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1300/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1301/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1302/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1303/// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1304/// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1305/// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1306/// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1307/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1308/// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1309/// // %r3 holds the final value.
1310/// ```
1311class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1312public:
1313 using OpRewritePattern<FMAOp>::OpRewritePattern;
1314
1315 void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded since the rank strictly decreases.
1318 setHasBoundedRewriteRecursion();
1319 }
1320
1321 LogicalResult matchAndRewrite(FMAOp op,
1322 PatternRewriter &rewriter) const override {
1323 auto vType = op.getVectorType();
1324 if (vType.getRank() < 2)
1325 return failure();
1326
1327 auto loc = op.getLoc();
1328 auto elemType = vType.getElementType();
1329 Value zero = rewriter.create<arith::ConstantOp>(
1330 loc, elemType, rewriter.getZeroAttr(elemType));
1331 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1332 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1333 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1334 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1335 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1336 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1337 desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1338 }
1339 rewriter.replaceOp(op, desc);
1340 return success();
1341 }
1342};
1343
1344/// Returns the strides if the memory underlying `memRefType` has a contiguous
1345/// static layout.
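/// For example (illustrative), `memref<4x8xf32>` has strides `[8, 1]` and is
/// contiguous, whereas a layout such as `strided<[9, 1]>` on the same shape is
/// rejected because `9 != 1 * 8`.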
1346static std::optional<SmallVector<int64_t, 4>>
1347computeContiguousStrides(MemRefType memRefType) {
1348 int64_t offset;
1349 SmallVector<int64_t, 4> strides;
1350 if (failed(getStridesAndOffset(memRefType, strides, offset)))
1351 return std::nullopt;
1352 if (!strides.empty() && strides.back() != 1)
1353 return std::nullopt;
1354 // If no layout or identity layout, this is contiguous by definition.
1355 if (memRefType.getLayout().isIdentity())
1356 return strides;
1357
  // Otherwise, we must determine contiguity from the shapes. This can only
  // ever work in static cases because MemRefType is underspecified: it cannot
  // represent contiguous dynamic shapes other than with an empty/identity
  // layout.
1362 auto sizes = memRefType.getShape();
1363 for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1364 if (ShapedType::isDynamic(sizes[index + 1]) ||
1365 ShapedType::isDynamic(strides[index]) ||
1366 ShapedType::isDynamic(strides[index + 1]))
1367 return std::nullopt;
1368 if (strides[index] != strides[index + 1] * sizes[index + 1])
1369 return std::nullopt;
1370 }
1371 return strides;
1372}
1373
1374class VectorTypeCastOpConversion
1375 : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1376public:
1377 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1378
1379 LogicalResult
1380 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1381 ConversionPatternRewriter &rewriter) const override {
1382 auto loc = castOp->getLoc();
1383 MemRefType sourceMemRefType =
1384 cast<MemRefType>(castOp.getOperand().getType());
1385 MemRefType targetMemRefType = castOp.getType();
1386
1387 // Only static shape casts supported atm.
1388 if (!sourceMemRefType.hasStaticShape() ||
1389 !targetMemRefType.hasStaticShape())
1390 return failure();
1391
1392 auto llvmSourceDescriptorTy =
1393 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1394 if (!llvmSourceDescriptorTy)
1395 return failure();
1396 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1397
1398 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1399 typeConverter->convertType(targetMemRefType));
1400 if (!llvmTargetDescriptorTy)
1401 return failure();
1402
1403 // Only contiguous source buffers supported atm.
1404 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1405 if (!sourceStrides)
1406 return failure();
1407 auto targetStrides = computeContiguousStrides(targetMemRefType);
1408 if (!targetStrides)
1409 return failure();
1410 // Only support static strides for now, regardless of contiguity.
1411 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1412 return failure();
1413
1414 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1415
1416 // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1418 // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1420 desc.setAllocatedPtr(rewriter, loc, allocated);
1421
1422 // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1424 desc.setAlignedPtr(rewriter, loc, ptr);
1425 // Fill offset 0.
1426 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1427 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1428 desc.setOffset(rewriter, loc, zero);
1429
1430 // Fill size and stride descriptors in memref.
1431 for (const auto &indexedSize :
1432 llvm::enumerate(targetMemRefType.getShape())) {
1433 int64_t index = indexedSize.index();
1434 auto sizeAttr =
1435 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1436 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1437 desc.setSize(rewriter, loc, index, size);
1438 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1439 (*targetStrides)[index]);
1440 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1441 desc.setStride(rewriter, loc, index, stride);
1442 }
1443
1444 rewriter.replaceOp(castOp, {desc});
1445 return success();
1446 }
1447};
1448
1449/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1450/// Non-scalable versions of this operation are handled in Vector Transforms.
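/// Example (an illustrative sketch):
/// ```
///   %m = vector.create_mask %n : vector<[4]xi1>
/// ```
/// is rewritten into an `LLVM::StepVectorOp`, a splat of the (index-cast)
/// bound `%n`, and an `arith.cmpi slt` comparison of the two.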
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  const bool force32BitVectorIndices;
};

class VectorPrintOpConversion
    : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Lowering implementation that relies on a small runtime support library,
  // which only needs to provide a few printing methods (single value for all
  // data types, opening/closing bracket, comma, newline). The lowering splits
  // the vector into elementary printing operations. The advantage of this
  // approach is that the library can remain unaware of all low-level
  // implementation details of vectors while still supporting output of any
  // shaped and dimensioned vector.
  //
  // Note: This lowering only handles scalars; n-D vectors are broken into
  // scalar prints in loops by VectorToSCF.
  //
  // TODO: rely solely on libc in future? something else?
  //
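  // For illustration (a sketch): with the default newline punctuation,
  // `vector.print %f : f32` becomes a call to the runtime `printF32`
  // function followed by a call to `printNewline`, while
  // `vector.print punctuation <comma>` maps to a single `printComma` call.
  //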
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto parent = printOp->getParentOfType<ModuleOp>();
    if (!parent)
      return failure();

    auto loc = printOp->getLoc();

    if (auto value = adaptor.getSource()) {
      Type printType = printOp.getPrintType();
      if (isa<VectorType>(printType)) {
        // Vectors should be broken into elementary print ops in VectorToSCF.
        return failure();
      }
      if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
        return failure();
    }

    auto punct = printOp.getPunctuation();
    if (auto stringLiteral = printOp.getStringLiteral()) {
      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
                               *stringLiteral, *getTypeConverter(),
                               /*addNewline=*/false);
    } else if (punct != PrintPunctuation::NoPunctuation) {
      emitCall(rewriter, printOp->getLoc(), [&] {
        switch (punct) {
        case PrintPunctuation::Close:
          return LLVM::lookupOrCreatePrintCloseFn(parent);
        case PrintPunctuation::Open:
          return LLVM::lookupOrCreatePrintOpenFn(parent);
        case PrintPunctuation::Comma:
          return LLVM::lookupOrCreatePrintCommaFn(parent);
        case PrintPunctuation::NewLine:
          return LLVM::lookupOrCreatePrintNewlineFn(parent);
        default:
          llvm_unreachable("unexpected punctuation");
        }
      }());
    }

    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64,
    Bitcast16
    // clang-format on
  };

  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    Operation *printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
    } else if (printType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
    } else if (printType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
    } else if (printType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
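      // For example, an `i8` value is sign-extended to i64 and printed with
      // the function returned by `lookupOrCreatePrintI64Fn`, whereas a `ui8`
      // value is zero-extended and printed with the one returned by
      // `lookupOrCreatePrintU64Fn`.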
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return success();
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats with 0-D and 1-D vector result types are lowered
/// here.
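/// For illustration (a sketch), `%v = vector.splat %x : vector<4xf32>`
/// becomes roughly:
///
///   %undef = llvm.mlir.undef : vector<4xf32>
///   %c0 = llvm.mlir.constant(0 : i32) : i32
///   %ins = llvm.insertelement %x, %undef[%c0 : i32] : vector<4xf32>
///   %v = llvm.shufflevector %ins, %undef [0, 0, 0, 0] : vector<4xf32>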
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = cast<VectorType>(splatOp.getType());
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `shufflevector`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats with result vector types of rank 2 or higher are
/// lowered by VectorSplatNdOpLowering; the 0-D and 1-D cases are handled by
/// VectorSplatOpLowering.
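/// For illustration (a sketch), a splat to `vector<2x4xf32>` first builds one
/// 1-D `vector<4xf32>` splat (insertelement + shufflevector, as above) and
/// then inserts it into each row of an undef `!llvm.array<2 x vector<4xf32>>`
/// value with `llvm.insertvalue`.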
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over the linear index, convert to coordinates, and insert the
    // splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

/// Conversion pattern for a `vector.interleave`.
/// This supports fixed-sized vectors and scalable vectors.
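/// For illustration (a sketch), interleaving two `vector<4xi32>` operands into
/// a `vector<8xi32>` result uses an `llvm.shufflevector` with the mask
/// [0, 4, 1, 5, 2, 6, 3, 7]; scalable results instead use the
/// `experimental_vector_interleave2` intrinsic (see below).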
struct VectorInterleaveOpLowering
    : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = interleaveOp.getResultVectorType();
    // n-D interleaves should have been lowered already.
    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(interleaveOp,
                                         "InterleaveOp not rank 1");
    // If the result is rank 1, then this directly maps to LLVM.
    if (resultType.isScalable()) {
      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
          interleaveOp, typeConverter->convertType(resultType),
          adaptor.getLhs(), adaptor.getRhs());
      return success();
    }
    // Lower fixed-size interleaves to a shufflevector. While the
    // vector.interleave2 intrinsic supports fixed and scalable vectors, the
    // LangRef still recommends that fixed-size vectors use shufflevector, see:
    // https://llvm.org/docs/LangRef.html#id876.
    int64_t resultVectorSize = resultType.getNumElements();
    SmallVector<int32_t> interleaveShuffleMask;
    interleaveShuffleMask.reserve(resultVectorSize);
    for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
      interleaveShuffleMask.push_back(i);
      interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
    }
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
        interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
        interleaveShuffleMask);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
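/// A typical use (illustrative sketch only) from inside a conversion pass:
///
///   LLVMTypeConverter converter(&getContext());
///   RewritePatternSet patterns(&getContext());
///   populateVectorToLLVMConversionPatterns(converter, patterns);
///   LLVMConversionTarget target(getContext());
///   if (failed(applyPartialConversion(getOperation(), target,
///                                     std::move(patterns))))
///     signalPassFailure();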
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add<VectorReductionOpConversion>(converter,
                                            reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
               VectorExtractElementOpConversion, VectorExtractOpConversion,
               VectorFMAOp1DConversion, VectorInsertElementOpConversion,
               VectorInsertOpConversion, VectorPrintOpConversion,
               VectorTypeCastOpConversion, VectorScaleOpConversion,
               VectorLoadStoreConversion<vector::LoadOp>,
               VectorLoadStoreConversion<vector::MaskedLoadOp>,
               VectorLoadStoreConversion<vector::StoreOp>,
               VectorLoadStoreConversion<vector::MaskedStoreOp>,
               VectorGatherOpConversion, VectorScatterOpConversion,
               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
               VectorSplatOpLowering, VectorSplatNdOpLowering,
               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
               MaskedReductionOpConversion, VectorInterleaveOpLowering>(
      converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}
