1//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Vector dialect to SPIRV dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19#include "mlir/Dialect/Utils/StaticValueUtils.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Location.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/IR/TypeUtilities.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "llvm/ADT/ArrayRef.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/SmallVectorExtras.h"
33#include "llvm/Support/FormatVariadic.h"
34#include <cassert>
35#include <cstdint>
36#include <numeric>
37
38using namespace mlir;
39
40/// Returns the integer value from the first valid input element, assuming Value
41/// inputs are defined by a constant index ops and Attribute inputs are integer
42/// attributes.
43static uint64_t getFirstIntValue(ArrayAttr attr) {
44 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
45}
46
47/// Returns the number of bits for the given scalar/vector type.
48static int getNumBits(Type type) {
49 // TODO: This does not take into account any memory layout or widening
50 // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
51 // though in practice it will likely be stored as in a 4xi64 vector register.
52 if (auto vectorType = dyn_cast<VectorType>(Val&: type))
53 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
54 return type.getIntOrFloatBitWidth();
55}
56
57namespace {
58
59struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
60 using OpConversionPattern::OpConversionPattern;
61
62 LogicalResult
63 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
64 ConversionPatternRewriter &rewriter) const override {
65 Type dstType = getTypeConverter()->convertType(t: shapeCastOp.getType());
66 if (!dstType)
67 return failure();
68
69 // If dstType is same as the source type or the vector size is 1, it can be
70 // directly replaced by the source.
71 if (dstType == adaptor.getSource().getType() ||
72 shapeCastOp.getResultVectorType().getNumElements() == 1) {
73 rewriter.replaceOp(op: shapeCastOp, newValues: adaptor.getSource());
74 return success();
75 }
76
77 // Lowering for size-n vectors when n > 1 hasn't been implemented.
78 return failure();
79 }
80};
81
82struct VectorBitcastConvert final
83 : public OpConversionPattern<vector::BitCastOp> {
84 using OpConversionPattern::OpConversionPattern;
85
86 LogicalResult
87 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter) const override {
89 Type dstType = getTypeConverter()->convertType(t: bitcastOp.getType());
90 if (!dstType)
91 return failure();
92
93 if (dstType == adaptor.getSource().getType()) {
94 rewriter.replaceOp(op: bitcastOp, newValues: adaptor.getSource());
95 return success();
96 }
97
98 // Check that the source and destination type have the same bitwidth.
99 // Depending on the target environment, we may need to emulate certain
100 // types, which can cause issue with bitcast.
101 Type srcType = adaptor.getSource().getType();
102 if (getNumBits(type: dstType) != getNumBits(type: srcType)) {
103 return rewriter.notifyMatchFailure(
104 arg&: bitcastOp,
105 msg: llvm::formatv(Fmt: "different source ({0}) and target ({1}) bitwidth",
106 Vals&: srcType, Vals&: dstType));
107 }
108
109 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(op: bitcastOp, args&: dstType,
110 args: adaptor.getSource());
111 return success();
112 }
113};
114
115struct VectorBroadcastConvert final
116 : public OpConversionPattern<vector::BroadcastOp> {
117 using OpConversionPattern::OpConversionPattern;
118
119 LogicalResult
120 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
121 ConversionPatternRewriter &rewriter) const override {
122 Type resultType =
123 getTypeConverter()->convertType(t: castOp.getResultVectorType());
124 if (!resultType)
125 return failure();
126
127 if (isa<spirv::ScalarType>(Val: resultType)) {
128 rewriter.replaceOp(op: castOp, newValues: adaptor.getSource());
129 return success();
130 }
131
132 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
133 adaptor.getSource());
134 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op: castOp, args&: resultType,
135 args&: source);
136 return success();
137 }
138};
139
140// SPIR-V does not have a concept of a poison index for certain instructions,
141// which creates a UB hazard when lowering from otherwise equivalent Vector
142// dialect instructions, because this index will be considered out-of-bounds.
143// To avoid this, this function implements a dynamic sanitization that returns
144// some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
145// (presumably more efficient), and otherwise index 0 (always in-bounds).
146static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
147 Location loc, Value dynamicIndex,
148 int64_t kPoisonIndex, unsigned vectorSize) {
149 if (llvm::isPowerOf2_32(Value: vectorSize)) {
150 Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
151 location: loc, args: dynamicIndex.getType(),
152 args: rewriter.getIntegerAttr(type: dynamicIndex.getType(), value: vectorSize - 1));
153 return rewriter.create<spirv::BitwiseAndOp>(location: loc, args&: dynamicIndex,
154 args&: inBoundsMask);
155 }
156 Value poisonIndex = rewriter.create<spirv::ConstantOp>(
157 location: loc, args: dynamicIndex.getType(),
158 args: rewriter.getIntegerAttr(type: dynamicIndex.getType(), value: kPoisonIndex));
159 Value cmpResult =
160 rewriter.create<spirv::IEqualOp>(location: loc, args&: dynamicIndex, args&: poisonIndex);
161 return rewriter.create<spirv::SelectOp>(
162 location: loc, args&: cmpResult,
163 args: spirv::ConstantOp::getZero(type: dynamicIndex.getType(), loc, builder&: rewriter),
164 args&: dynamicIndex);
165}
166
167struct VectorExtractOpConvert final
168 : public OpConversionPattern<vector::ExtractOp> {
169 using OpConversionPattern::OpConversionPattern;
170
171 LogicalResult
172 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter) const override {
174 Type dstType = getTypeConverter()->convertType(t: extractOp.getType());
175 if (!dstType)
176 return failure();
177
178 if (isa<spirv::ScalarType>(Val: adaptor.getVector().getType())) {
179 rewriter.replaceOp(op: extractOp, newValues: adaptor.getVector());
180 return success();
181 }
182
183 if (std::optional<int64_t> id =
184 getConstantIntValue(ofr: extractOp.getMixedPosition()[0])) {
185 if (id == vector::ExtractOp::kPoisonIndex)
186 return rewriter.notifyMatchFailure(
187 arg&: extractOp,
188 msg: "Static use of poison index handled elsewhere (folded to poison)");
189 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
190 op: extractOp, args&: dstType, args: adaptor.getVector(),
191 args: rewriter.getI32ArrayAttr(values: id.value()));
192 } else {
193 Value sanitizedIndex = sanitizeDynamicIndex(
194 rewriter, loc: extractOp.getLoc(), dynamicIndex: adaptor.getDynamicPosition()[0],
195 kPoisonIndex: vector::ExtractOp::kPoisonIndex,
196 vectorSize: extractOp.getSourceVectorType().getNumElements());
197 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
198 op: extractOp, args&: dstType, args: adaptor.getVector(), args&: sanitizedIndex);
199 }
200 return success();
201 }
202};
203
204struct VectorExtractStridedSliceOpConvert final
205 : public OpConversionPattern<vector::ExtractStridedSliceOp> {
206 using OpConversionPattern::OpConversionPattern;
207
208 LogicalResult
209 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
210 ConversionPatternRewriter &rewriter) const override {
211 Type dstType = getTypeConverter()->convertType(t: extractOp.getType());
212 if (!dstType)
213 return failure();
214
215 uint64_t offset = getFirstIntValue(attr: extractOp.getOffsets());
216 uint64_t size = getFirstIntValue(attr: extractOp.getSizes());
217 uint64_t stride = getFirstIntValue(attr: extractOp.getStrides());
218 if (stride != 1)
219 return failure();
220
221 Value srcVector = adaptor.getOperands().front();
222
223 // Extract vector<1xT> case.
224 if (isa<spirv::ScalarType>(Val: dstType)) {
225 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(op: extractOp,
226 args&: srcVector, args&: offset);
227 return success();
228 }
229
230 SmallVector<int32_t, 2> indices(size);
231 std::iota(first: indices.begin(), last: indices.end(), value: offset);
232
233 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
234 op: extractOp, args&: dstType, args&: srcVector, args&: srcVector,
235 args: rewriter.getI32ArrayAttr(values: indices));
236
237 return success();
238 }
239};
240
241template <class SPIRVFMAOp>
242struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
243 using OpConversionPattern::OpConversionPattern;
244
245 LogicalResult
246 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
247 ConversionPatternRewriter &rewriter) const override {
248 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
249 if (!dstType)
250 return failure();
251 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
252 adaptor.getRhs(), adaptor.getAcc());
253 return success();
254 }
255};
256
257struct VectorFromElementsOpConvert final
258 : public OpConversionPattern<vector::FromElementsOp> {
259 using OpConversionPattern::OpConversionPattern;
260
261 LogicalResult
262 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter) const override {
264 Type resultType = getTypeConverter()->convertType(t: op.getType());
265 if (!resultType)
266 return failure();
267 OperandRange elements = op.getElements();
268 if (isa<spirv::ScalarType>(Val: resultType)) {
269 // In the case with a single scalar operand / single-element result,
270 // pass through the scalar.
271 rewriter.replaceOp(op, newValues: elements[0]);
272 return success();
273 }
274 // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
275 // vector.from_elements cases should not need to be handled, only 1d.
276 assert(cast<VectorType>(resultType).getRank() == 1);
277 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, args&: resultType,
278 args&: elements);
279 return success();
280 }
281};
282
283struct VectorInsertOpConvert final
284 : public OpConversionPattern<vector::InsertOp> {
285 using OpConversionPattern::OpConversionPattern;
286
287 LogicalResult
288 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 if (isa<VectorType>(Val: insertOp.getValueToStoreType()))
291 return rewriter.notifyMatchFailure(arg&: insertOp, msg: "unsupported vector source");
292 if (!getTypeConverter()->convertType(t: insertOp.getDestVectorType()))
293 return rewriter.notifyMatchFailure(arg&: insertOp,
294 msg: "unsupported dest vector type");
295
296 // Special case for inserting scalar values into size-1 vectors.
297 if (insertOp.getValueToStoreType().isIntOrFloat() &&
298 insertOp.getDestVectorType().getNumElements() == 1) {
299 rewriter.replaceOp(op: insertOp, newValues: adaptor.getValueToStore());
300 return success();
301 }
302
303 if (std::optional<int64_t> id =
304 getConstantIntValue(ofr: insertOp.getMixedPosition()[0])) {
305 if (id == vector::InsertOp::kPoisonIndex)
306 return rewriter.notifyMatchFailure(
307 arg&: insertOp,
308 msg: "Static use of poison index handled elsewhere (folded to poison)");
309 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
310 op: insertOp, args: adaptor.getValueToStore(), args: adaptor.getDest(), args&: id.value());
311 } else {
312 Value sanitizedIndex = sanitizeDynamicIndex(
313 rewriter, loc: insertOp.getLoc(), dynamicIndex: adaptor.getDynamicPosition()[0],
314 kPoisonIndex: vector::InsertOp::kPoisonIndex,
315 vectorSize: insertOp.getDestVectorType().getNumElements());
316 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
317 op: insertOp, args: insertOp.getDest(), args: adaptor.getValueToStore(),
318 args&: sanitizedIndex);
319 }
320 return success();
321 }
322};
323
324struct VectorExtractElementOpConvert final
325 : public OpConversionPattern<vector::ExtractElementOp> {
326 using OpConversionPattern::OpConversionPattern;
327
328 LogicalResult
329 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
330 ConversionPatternRewriter &rewriter) const override {
331 Type resultType = getTypeConverter()->convertType(t: extractOp.getType());
332 if (!resultType)
333 return failure();
334
335 if (isa<spirv::ScalarType>(Val: adaptor.getVector().getType())) {
336 rewriter.replaceOp(op: extractOp, newValues: adaptor.getVector());
337 return success();
338 }
339
340 APInt cstPos;
341 if (matchPattern(value: adaptor.getPosition(), pattern: m_ConstantInt(bind_value: &cstPos)))
342 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
343 op: extractOp, args&: resultType, args: adaptor.getVector(),
344 args: rewriter.getI32ArrayAttr(values: {static_cast<int>(cstPos.getSExtValue())}));
345 else
346 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
347 op: extractOp, args&: resultType, args: adaptor.getVector(), args: adaptor.getPosition());
348 return success();
349 }
350};
351
352struct VectorInsertElementOpConvert final
353 : public OpConversionPattern<vector::InsertElementOp> {
354 using OpConversionPattern::OpConversionPattern;
355
356 LogicalResult
357 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 Type vectorType = getTypeConverter()->convertType(t: insertOp.getType());
360 if (!vectorType)
361 return failure();
362
363 if (isa<spirv::ScalarType>(Val: vectorType)) {
364 rewriter.replaceOp(op: insertOp, newValues: adaptor.getSource());
365 return success();
366 }
367
368 APInt cstPos;
369 if (matchPattern(value: adaptor.getPosition(), pattern: m_ConstantInt(bind_value: &cstPos)))
370 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
371 op: insertOp, args: adaptor.getSource(), args: adaptor.getDest(),
372 args: cstPos.getSExtValue());
373 else
374 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
375 op: insertOp, args&: vectorType, args: insertOp.getDest(), args: adaptor.getSource(),
376 args: adaptor.getPosition());
377 return success();
378 }
379};
380
381struct VectorInsertStridedSliceOpConvert final
382 : public OpConversionPattern<vector::InsertStridedSliceOp> {
383 using OpConversionPattern::OpConversionPattern;
384
385 LogicalResult
386 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter) const override {
388 Value srcVector = adaptor.getOperands().front();
389 Value dstVector = adaptor.getOperands().back();
390
391 uint64_t stride = getFirstIntValue(attr: insertOp.getStrides());
392 if (stride != 1)
393 return failure();
394 uint64_t offset = getFirstIntValue(attr: insertOp.getOffsets());
395
396 if (isa<spirv::ScalarType>(Val: srcVector.getType())) {
397 assert(!isa<spirv::ScalarType>(dstVector.getType()));
398 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
399 op: insertOp, args: dstVector.getType(), args&: srcVector, args&: dstVector,
400 args: rewriter.getI32ArrayAttr(values: offset));
401 return success();
402 }
403
404 uint64_t totalSize = cast<VectorType>(Val: dstVector.getType()).getNumElements();
405 uint64_t insertSize =
406 cast<VectorType>(Val: srcVector.getType()).getNumElements();
407
408 SmallVector<int32_t, 2> indices(totalSize);
409 std::iota(first: indices.begin(), last: indices.end(), value: 0);
410 std::iota(first: indices.begin() + offset, last: indices.begin() + offset + insertSize,
411 value: totalSize);
412
413 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
414 op: insertOp, args: dstVector.getType(), args&: dstVector, args&: srcVector,
415 args: rewriter.getI32ArrayAttr(values: indices));
416
417 return success();
418 }
419};
420
421static SmallVector<Value> extractAllElements(
422 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
423 VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
424 int numElements = static_cast<int>(srcVectorType.getDimSize(idx: 0));
425 SmallVector<Value> values;
426 values.reserve(N: numElements + (adaptor.getAcc() ? 1 : 0));
427 Location loc = reduceOp.getLoc();
428
429 for (int i = 0; i < numElements; ++i) {
430 values.push_back(Elt: rewriter.create<spirv::CompositeExtractOp>(
431 location: loc, args: srcVectorType.getElementType(), args: adaptor.getVector(),
432 args: rewriter.getI32ArrayAttr(values: {i})));
433 }
434 if (Value acc = adaptor.getAcc())
435 values.push_back(Elt: acc);
436
437 return values;
438}
439
/// Bundle returned by getReductionInfo() and consumed by the reduction
/// lowering patterns. Field order matters: it is built by aggregate
/// initialization and unpacked with structured bindings.
struct ReductionRewriteInfo {
  // Type-converted type of the reduction result.
  Type resultType;
  // Values to combine, in order: each element of the source vector, followed
  // by the accumulator operand when the reduction has one.
  SmallVector<Value> extractedElements;
};
444
445FailureOr<ReductionRewriteInfo> static getReductionInfo(
446 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
447 ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
448 Type resultType = typeConverter.convertType(t: op.getType());
449 if (!resultType)
450 return failure();
451
452 auto srcVectorType = dyn_cast<VectorType>(Val: adaptor.getVector().getType());
453 if (!srcVectorType || srcVectorType.getRank() != 1)
454 return rewriter.notifyMatchFailure(arg&: op, msg: "not a 1-D vector source");
455
456 SmallVector<Value> extractedElements =
457 extractAllElements(reduceOp: op, adaptor, srcVectorType, rewriter);
458
459 return ReductionRewriteInfo{.resultType: resultType, .extractedElements: std::move(extractedElements)};
460}
461
462template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
463 typename SPIRVSMinOp>
464struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
465 using OpConversionPattern::OpConversionPattern;
466
467 LogicalResult
468 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter) const override {
470 auto reductionInfo =
471 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
472 if (failed(reductionInfo))
473 return failure();
474
475 auto [resultType, extractedElements] = *reductionInfo;
476 Location loc = reduceOp->getLoc();
477 Value result = extractedElements.front();
478 for (Value next : llvm::drop_begin(extractedElements)) {
479 switch (reduceOp.getKind()) {
480
481#define INT_AND_FLOAT_CASE(kind, iop, fop) \
482 case vector::CombiningKind::kind: \
483 if (llvm::isa<IntegerType>(resultType)) { \
484 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
485 } else { \
486 assert(llvm::isa<FloatType>(resultType)); \
487 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
488 } \
489 break
490
491#define INT_OR_FLOAT_CASE(kind, fop) \
492 case vector::CombiningKind::kind: \
493 result = rewriter.create<fop>(loc, resultType, result, next); \
494 break
495
496 INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
497 INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
498 INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
499 INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
500 INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
501 INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
502
503 case vector::CombiningKind::AND:
504 case vector::CombiningKind::OR:
505 case vector::CombiningKind::XOR:
506 return rewriter.notifyMatchFailure(arg&: reduceOp, msg: "unimplemented");
507 default:
508 return rewriter.notifyMatchFailure(arg&: reduceOp, msg: "not handled here");
509 }
510#undef INT_AND_FLOAT_CASE
511#undef INT_OR_FLOAT_CASE
512 }
513
514 rewriter.replaceOp(op: reduceOp, newValues: result);
515 return success();
516 }
517};
518
519template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
520struct VectorReductionFloatMinMax final
521 : OpConversionPattern<vector::ReductionOp> {
522 using OpConversionPattern::OpConversionPattern;
523
524 LogicalResult
525 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
526 ConversionPatternRewriter &rewriter) const override {
527 auto reductionInfo =
528 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
529 if (failed(reductionInfo))
530 return failure();
531
532 auto [resultType, extractedElements] = *reductionInfo;
533 Location loc = reduceOp->getLoc();
534 Value result = extractedElements.front();
535 for (Value next : llvm::drop_begin(extractedElements)) {
536 switch (reduceOp.getKind()) {
537
538#define INT_OR_FLOAT_CASE(kind, fop) \
539 case vector::CombiningKind::kind: \
540 result = rewriter.create<fop>(loc, resultType, result, next); \
541 break
542
543 INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
544 INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
545 INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
546 INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
547
548 default:
549 return rewriter.notifyMatchFailure(arg&: reduceOp, msg: "not handled here");
550 }
551#undef INT_OR_FLOAT_CASE
552 }
553
554 rewriter.replaceOp(op: reduceOp, newValues: result);
555 return success();
556 }
557};
558
559class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
560public:
561 using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
562
563 LogicalResult
564 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
565 ConversionPatternRewriter &rewriter) const override {
566 Type dstType = getTypeConverter()->convertType(t: op.getType());
567 if (!dstType)
568 return failure();
569 if (isa<spirv::ScalarType>(Val: dstType)) {
570 rewriter.replaceOp(op, newValues: adaptor.getInput());
571 } else {
572 auto dstVecType = cast<VectorType>(Val&: dstType);
573 SmallVector<Value, 4> source(dstVecType.getNumElements(),
574 adaptor.getInput());
575 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, args&: dstType,
576 args&: source);
577 }
578 return success();
579 }
580};
581
582struct VectorShuffleOpConvert final
583 : public OpConversionPattern<vector::ShuffleOp> {
584 using OpConversionPattern::OpConversionPattern;
585
586 LogicalResult
587 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
588 ConversionPatternRewriter &rewriter) const override {
589 VectorType oldResultType = shuffleOp.getResultVectorType();
590 Type newResultType = getTypeConverter()->convertType(t: oldResultType);
591 if (!newResultType)
592 return rewriter.notifyMatchFailure(arg&: shuffleOp,
593 msg: "unsupported result vector type");
594
595 auto mask = llvm::to_vector_of<int32_t>(Range: shuffleOp.getMask());
596
597 VectorType oldV1Type = shuffleOp.getV1VectorType();
598 VectorType oldV2Type = shuffleOp.getV2VectorType();
599
600 // When both operands and the result are SPIR-V vectors, emit a SPIR-V
601 // shuffle.
602 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
603 oldResultType.getNumElements() > 1) {
604 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
605 op: shuffleOp, args&: newResultType, args: adaptor.getV1(), args: adaptor.getV2(),
606 args: rewriter.getI32ArrayAttr(values: mask));
607 return success();
608 }
609
610 // When at least one of the operands or the result becomes a scalar after
611 // type conversion for SPIR-V, extract all the required elements and
612 // construct the result vector.
613 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
614 Value scalarOrVec, int32_t idx) -> Value {
615 if (auto vecTy = dyn_cast<VectorType>(Val: scalarOrVec.getType()))
616 return rewriter.create<spirv::CompositeExtractOp>(location: loc, args&: scalarOrVec,
617 args&: idx);
618
619 assert(idx == 0 && "Invalid scalar element index");
620 return scalarOrVec;
621 };
622
623 int32_t numV1Elems = oldV1Type.getNumElements();
624 SmallVector<Value> newOperands(mask.size());
625 for (auto [shuffleIdx, newOperand] : llvm::zip_equal(t&: mask, u&: newOperands)) {
626 Value vec = adaptor.getV1();
627 int32_t elementIdx = shuffleIdx;
628 if (elementIdx >= numV1Elems) {
629 vec = adaptor.getV2();
630 elementIdx -= numV1Elems;
631 }
632
633 newOperand = getElementAtIdx(vec, elementIdx);
634 }
635
636 // Handle the scalar result corner case.
637 if (newOperands.size() == 1) {
638 rewriter.replaceOp(op: shuffleOp, newValues: newOperands.front());
639 return success();
640 }
641
642 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
643 op: shuffleOp, args&: newResultType, args&: newOperands);
644 return success();
645 }
646};
647
648struct VectorInterleaveOpConvert final
649 : public OpConversionPattern<vector::InterleaveOp> {
650 using OpConversionPattern::OpConversionPattern;
651
652 LogicalResult
653 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
654 ConversionPatternRewriter &rewriter) const override {
655 // Check the result vector type.
656 VectorType oldResultType = interleaveOp.getResultVectorType();
657 Type newResultType = getTypeConverter()->convertType(t: oldResultType);
658 if (!newResultType)
659 return rewriter.notifyMatchFailure(arg&: interleaveOp,
660 msg: "unsupported result vector type");
661
662 // Interleave the indices.
663 VectorType sourceType = interleaveOp.getSourceVectorType();
664 int n = sourceType.getNumElements();
665
666 // Input vectors of size 1 are converted to scalars by the type converter.
667 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
668 // use `spirv::CompositeConstructOp`.
669 if (n == 1) {
670 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
671 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
672 op: interleaveOp, args&: newResultType, args&: newOperands);
673 return success();
674 }
675
676 auto seq = llvm::seq<int64_t>(Size: 2 * n);
677 auto indices = llvm::map_to_vector(
678 C&: seq, F: [n](int i) { return (i % 2 ? n : 0) + i / 2; });
679
680 // Emit a SPIR-V shuffle.
681 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
682 op: interleaveOp, args&: newResultType, args: adaptor.getLhs(), args: adaptor.getRhs(),
683 args: rewriter.getI32ArrayAttr(values: indices));
684
685 return success();
686 }
687};
688
689struct VectorDeinterleaveOpConvert final
690 : public OpConversionPattern<vector::DeinterleaveOp> {
691 using OpConversionPattern::OpConversionPattern;
692
693 LogicalResult
694 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
695 ConversionPatternRewriter &rewriter) const override {
696
697 // Check the result vector type.
698 VectorType oldResultType = deinterleaveOp.getResultVectorType();
699 Type newResultType = getTypeConverter()->convertType(t: oldResultType);
700 if (!newResultType)
701 return rewriter.notifyMatchFailure(arg&: deinterleaveOp,
702 msg: "unsupported result vector type");
703
704 Location loc = deinterleaveOp->getLoc();
705
706 // Deinterleave the indices.
707 Value sourceVector = adaptor.getSource();
708 VectorType sourceType = deinterleaveOp.getSourceVectorType();
709 int n = sourceType.getNumElements();
710
711 // Output vectors of size 1 are converted to scalars by the type converter.
712 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
713 // use `spirv::CompositeExtractOp`.
714 if (n == 2) {
715 auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
716 location: loc, args&: newResultType, args&: sourceVector, args: rewriter.getI32ArrayAttr(values: {0}));
717
718 auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
719 location: loc, args&: newResultType, args&: sourceVector, args: rewriter.getI32ArrayAttr(values: {1}));
720
721 rewriter.replaceOp(op: deinterleaveOp, newValues: {elem0, elem1});
722 return success();
723 }
724
725 // Indices for `shuffleEven` (result 0).
726 auto seqEven = llvm::seq<int64_t>(Size: n / 2);
727 auto indicesEven =
728 llvm::map_to_vector(C&: seqEven, F: [](int i) { return i * 2; });
729
730 // Indices for `shuffleOdd` (result 1).
731 auto seqOdd = llvm::seq<int64_t>(Size: n / 2);
732 auto indicesOdd =
733 llvm::map_to_vector(C&: seqOdd, F: [](int i) { return i * 2 + 1; });
734
735 // Create two SPIR-V shuffles.
736 auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
737 location: loc, args&: newResultType, args&: sourceVector, args&: sourceVector,
738 args: rewriter.getI32ArrayAttr(values: indicesEven));
739
740 auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
741 location: loc, args&: newResultType, args&: sourceVector, args&: sourceVector,
742 args: rewriter.getI32ArrayAttr(values: indicesOdd));
743
744 rewriter.replaceOp(op: deinterleaveOp, newValues: {shuffleEven, shuffleOdd});
745 return success();
746 }
747};
748
749struct VectorLoadOpConverter final
750 : public OpConversionPattern<vector::LoadOp> {
751 using OpConversionPattern::OpConversionPattern;
752
753 LogicalResult
754 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
755 ConversionPatternRewriter &rewriter) const override {
756 auto memrefType = loadOp.getMemRefType();
757 auto attr =
758 dyn_cast_or_null<spirv::StorageClassAttr>(Val: memrefType.getMemorySpace());
759 if (!attr)
760 return rewriter.notifyMatchFailure(
761 arg&: loadOp, msg: "expected spirv.storage_class memory space");
762
763 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
764 auto loc = loadOp.getLoc();
765 Value accessChain =
766 spirv::getElementPtr(typeConverter, baseType: memrefType, basePtr: adaptor.getBase(),
767 indices: adaptor.getIndices(), loc, builder&: rewriter);
768 if (!accessChain)
769 return rewriter.notifyMatchFailure(
770 arg&: loadOp, msg: "failed to get memref element pointer");
771
772 spirv::StorageClass storageClass = attr.getValue();
773 auto vectorType = loadOp.getVectorType();
774 // Use the converted vector type instead of original (single element vector
775 // would get converted to scalar).
776 auto spirvVectorType = typeConverter.convertType(t: vectorType);
777 auto vectorPtrType = spirv::PointerType::get(pointeeType: spirvVectorType, storageClass);
778
779 // For single element vectors, we don't need to bitcast the access chain to
780 // the original vector type. Both is going to be the same, a pointer
781 // to a scalar.
782 Value castedAccessChain = (vectorType.getNumElements() == 1)
783 ? accessChain
784 : rewriter.create<spirv::BitcastOp>(
785 location: loc, args&: vectorPtrType, args&: accessChain);
786
787 rewriter.replaceOpWithNewOp<spirv::LoadOp>(op: loadOp, args&: spirvVectorType,
788 args&: castedAccessChain);
789
790 return success();
791 }
792};
793
794struct VectorStoreOpConverter final
795 : public OpConversionPattern<vector::StoreOp> {
796 using OpConversionPattern::OpConversionPattern;
797
798 LogicalResult
799 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter) const override {
801 auto memrefType = storeOp.getMemRefType();
802 auto attr =
803 dyn_cast_or_null<spirv::StorageClassAttr>(Val: memrefType.getMemorySpace());
804 if (!attr)
805 return rewriter.notifyMatchFailure(
806 arg&: storeOp, msg: "expected spirv.storage_class memory space");
807
808 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
809 auto loc = storeOp.getLoc();
810 Value accessChain =
811 spirv::getElementPtr(typeConverter, baseType: memrefType, basePtr: adaptor.getBase(),
812 indices: adaptor.getIndices(), loc, builder&: rewriter);
813 if (!accessChain)
814 return rewriter.notifyMatchFailure(
815 arg&: storeOp, msg: "failed to get memref element pointer");
816
817 spirv::StorageClass storageClass = attr.getValue();
818 auto vectorType = storeOp.getVectorType();
819 auto vectorPtrType = spirv::PointerType::get(pointeeType: vectorType, storageClass);
820
821 // For single element vectors, we don't need to bitcast the access chain to
822 // the original vector type. Both is going to be the same, a pointer
823 // to a scalar.
824 Value castedAccessChain = (vectorType.getNumElements() == 1)
825 ? accessChain
826 : rewriter.create<spirv::BitcastOp>(
827 location: loc, args&: vectorPtrType, args&: accessChain);
828
829 rewriter.replaceOpWithNewOp<spirv::StoreOp>(op: storeOp, args&: castedAccessChain,
830 args: adaptor.getValueToStore());
831
832 return success();
833 }
834};
835
836struct VectorReductionToIntDotProd final
837 : OpRewritePattern<vector::ReductionOp> {
838 using OpRewritePattern::OpRewritePattern;
839
840 LogicalResult matchAndRewrite(vector::ReductionOp op,
841 PatternRewriter &rewriter) const override {
842 if (op.getKind() != vector::CombiningKind::ADD)
843 return rewriter.notifyMatchFailure(arg&: op, msg: "combining kind is not 'add'");
844
845 auto resultType = dyn_cast<IntegerType>(Val: op.getType());
846 if (!resultType)
847 return rewriter.notifyMatchFailure(arg&: op, msg: "result is not an integer");
848
849 int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
850 if (!llvm::is_contained(Set: {32, 64}, Element: resultBitwidth))
851 return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported integer bitwidth");
852
853 VectorType inVecTy = op.getSourceVectorType();
854 if (!llvm::is_contained(Set: {4, 3}, Element: inVecTy.getNumElements()) ||
855 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
856 return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported vector shape");
857
858 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
859 if (!mul)
860 return rewriter.notifyMatchFailure(
861 arg&: op, msg: "reduction operand is not 'arith.muli'");
862
863 if (succeeded(Result: handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
864 spirv::SDotAccSatOp, false>(op, mul, rewriter)))
865 return success();
866
867 if (succeeded(Result: handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
868 spirv::UDotAccSatOp, false>(op, mul, rewriter)))
869 return success();
870
871 if (succeeded(Result: handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
872 spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
873 return success();
874
875 if (succeeded(Result: handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
876 spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
877 return success();
878
879 return failure();
880 }
881
882private:
883 template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
884 typename DotAccOp, bool SwapOperands>
885 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
886 PatternRewriter &rewriter) {
887 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
888 if (!lhs)
889 return failure();
890 Value lhsIn = lhs.getIn();
891 auto lhsInType = cast<VectorType>(Val: lhsIn.getType());
892 if (!lhsInType.getElementType().isInteger(width: 8))
893 return failure();
894
895 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
896 if (!rhs)
897 return failure();
898 Value rhsIn = rhs.getIn();
899 auto rhsInType = cast<VectorType>(Val: rhsIn.getType());
900 if (!rhsInType.getElementType().isInteger(width: 8))
901 return failure();
902
903 if (op.getSourceVectorType().getNumElements() == 3) {
904 IntegerType i8Type = rewriter.getI8Type();
905 auto v4i8Type = VectorType::get(shape: {4}, elementType: i8Type);
906 Location loc = op.getLoc();
907 Value zero = spirv::ConstantOp::getZero(type: i8Type, loc, builder&: rewriter);
908 lhsIn = rewriter.create<spirv::CompositeConstructOp>(
909 location: loc, args&: v4i8Type, args: ValueRange{lhsIn, zero});
910 rhsIn = rewriter.create<spirv::CompositeConstructOp>(
911 location: loc, args&: v4i8Type, args: ValueRange{rhsIn, zero});
912 }
913
914 // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
915 // we have to swap operands instead in that case.
916 if (SwapOperands)
917 std::swap(a&: lhsIn, b&: rhsIn);
918
919 if (Value acc = op.getAcc()) {
920 rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
921 nullptr);
922 } else {
923 rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
924 nullptr);
925 }
926
927 return success();
928 }
929};
930
931struct VectorReductionToFPDotProd final
932 : OpConversionPattern<vector::ReductionOp> {
933 using OpConversionPattern::OpConversionPattern;
934
935 LogicalResult
936 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
937 ConversionPatternRewriter &rewriter) const override {
938 if (op.getKind() != vector::CombiningKind::ADD)
939 return rewriter.notifyMatchFailure(arg&: op, msg: "combining kind is not 'add'");
940
941 auto resultType = getTypeConverter()->convertType<FloatType>(t: op.getType());
942 if (!resultType)
943 return rewriter.notifyMatchFailure(arg&: op, msg: "result is not a float");
944
945 Value vec = adaptor.getVector();
946 Value acc = adaptor.getAcc();
947
948 auto vectorType = dyn_cast<VectorType>(Val: vec.getType());
949 if (!vectorType) {
950 assert(isa<FloatType>(vec.getType()) &&
951 "Expected the vector to be scalarized");
952 if (acc) {
953 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, args&: acc, args&: vec);
954 return success();
955 }
956
957 rewriter.replaceOp(op, newValues: vec);
958 return success();
959 }
960
961 Location loc = op.getLoc();
962 Value lhs;
963 Value rhs;
964 if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
965 lhs = mul.getLhs();
966 rhs = mul.getRhs();
967 } else {
968 // If the operand is not a mul, use a vector of ones for the dot operand
969 // to just sum up all values.
970 lhs = vec;
971 Attribute oneAttr =
972 rewriter.getFloatAttr(type: vectorType.getElementType(), value: 1.0);
973 oneAttr = SplatElementsAttr::get(type: vectorType, values: oneAttr);
974 rhs = rewriter.create<spirv::ConstantOp>(location: loc, args&: vectorType, args&: oneAttr);
975 }
976 assert(lhs);
977 assert(rhs);
978
979 Value res = rewriter.create<spirv::DotOp>(location: loc, args&: resultType, args&: lhs, args&: rhs);
980 if (acc)
981 res = rewriter.create<spirv::FAddOp>(location: loc, args&: acc, args&: res);
982
983 rewriter.replaceOp(op, newValues: res);
984 return success();
985 }
986};
987
988struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
989 using OpConversionPattern::OpConversionPattern;
990
991 LogicalResult
992 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
993 ConversionPatternRewriter &rewriter) const override {
994 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
995 Type dstType = typeConverter.convertType(t: stepOp.getType());
996 if (!dstType)
997 return failure();
998
999 Location loc = stepOp.getLoc();
1000 int64_t numElements = stepOp.getType().getNumElements();
1001 auto intType =
1002 rewriter.getIntegerType(width: typeConverter.getIndexTypeBitwidth());
1003
1004 // Input vectors of size 1 are converted to scalars by the type converter.
1005 // We just create a constant in this case.
1006 if (numElements == 1) {
1007 Value zero = spirv::ConstantOp::getZero(type: intType, loc, builder&: rewriter);
1008 rewriter.replaceOp(op: stepOp, newValues: zero);
1009 return success();
1010 }
1011
1012 SmallVector<Value> source;
1013 source.reserve(N: numElements);
1014 for (int64_t i = 0; i < numElements; ++i) {
1015 Attribute intAttr = rewriter.getIntegerAttr(type: intType, value: i);
1016 Value constOp = rewriter.create<spirv::ConstantOp>(location: loc, args&: intType, args&: intAttr);
1017 source.push_back(Elt: constOp);
1018 }
1019 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op: stepOp, args&: dstType,
1020 args&: source);
1021 return success();
1022 }
1023};
1024
1025struct VectorToElementOpConvert final
1026 : OpConversionPattern<vector::ToElementsOp> {
1027 using OpConversionPattern::OpConversionPattern;
1028
1029 LogicalResult
1030 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1031 ConversionPatternRewriter &rewriter) const override {
1032
1033 SmallVector<Value> results(toElementsOp->getNumResults());
1034 Location loc = toElementsOp.getLoc();
1035
1036 // Input vectors of size 1 are converted to scalars by the type converter.
1037 // We cannot use `spirv::CompositeExtractOp` directly in this case.
1038 // For a scalar source, the result is just the scalar itself.
1039 if (isa<spirv::ScalarType>(Val: adaptor.getSource().getType())) {
1040 results[0] = adaptor.getSource();
1041 rewriter.replaceOp(op: toElementsOp, newValues: results);
1042 return success();
1043 }
1044
1045 Type srcElementType = toElementsOp.getElements().getType().front();
1046 Type elementType = getTypeConverter()->convertType(t: srcElementType);
1047 if (!elementType)
1048 return rewriter.notifyMatchFailure(
1049 arg&: toElementsOp,
1050 msg: llvm::formatv(Fmt: "failed to convert element type '{0}' to SPIR-V",
1051 Vals&: srcElementType));
1052
1053 for (auto [idx, element] : llvm::enumerate(First: toElementsOp.getElements())) {
1054 // Create an CompositeExtract operation only for results that are not
1055 // dead.
1056 if (element.use_empty())
1057 continue;
1058
1059 Value result = rewriter.create<spirv::CompositeExtractOp>(
1060 location: loc, args&: elementType, args: adaptor.getSource(),
1061 args: rewriter.getI32ArrayAttr(values: {static_cast<int32_t>(idx)}));
1062 results[idx] = result;
1063 }
1064
1065 rewriter.replaceOp(op: toElementsOp, newValues: results);
1066 return success();
1067 }
1068};
1069
1070} // namespace
1071#define CL_INT_MAX_MIN_OPS \
1072 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1073
1074#define GL_INT_MAX_MIN_OPS \
1075 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1076
1077#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1078#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1079
1080void mlir::populateVectorToSPIRVPatterns(
1081 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1082 patterns.add<
1083 VectorBitcastConvert, VectorBroadcastConvert,
1084 VectorExtractElementOpConvert, VectorExtractOpConvert,
1085 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1086 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1087 VectorToElementOpConvert, VectorInsertElementOpConvert,
1088 VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1089 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1090 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1091 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1092 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1093 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1094 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
1095 VectorStepOpConvert>(arg: typeConverter, args: patterns.getContext(),
1096 args: PatternBenefit(1));
1097
1098 // Make sure that the more specialized dot product pattern has higher benefit
1099 // than the generic one that extracts all elements.
1100 patterns.add<VectorReductionToFPDotProd>(arg: typeConverter, args: patterns.getContext(),
1101 args: PatternBenefit(2));
1102}
1103
1104void mlir::populateVectorReductionToSPIRVDotProductPatterns(
1105 RewritePatternSet &patterns) {
1106 patterns.add<VectorReductionToIntDotProd>(arg: patterns.getContext());
1107}
1108

source code of mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp