1//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Vector dialect to SPIRV dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19#include "mlir/Dialect/Utils/StaticValueUtils.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Location.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/IR/TypeUtilities.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "llvm/ADT/ArrayRef.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/SmallVectorExtras.h"
33#include "llvm/Support/FormatVariadic.h"
34#include <cassert>
35#include <cstdint>
36#include <numeric>
37
38using namespace mlir;
39
40/// Returns the integer value from the first valid input element, assuming Value
41/// inputs are defined by a constant index ops and Attribute inputs are integer
42/// attributes.
43static uint64_t getFirstIntValue(ArrayAttr attr) {
44 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
45}
46
47/// Returns the number of bits for the given scalar/vector type.
48static int getNumBits(Type type) {
49 // TODO: This does not take into account any memory layout or widening
50 // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
51 // though in practice it will likely be stored as in a 4xi64 vector register.
52 if (auto vectorType = dyn_cast<VectorType>(type))
53 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
54 return type.getIntOrFloatBitWidth();
55}
56
57namespace {
58
59struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
60 using OpConversionPattern::OpConversionPattern;
61
62 LogicalResult
63 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
64 ConversionPatternRewriter &rewriter) const override {
65 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
66 if (!dstType)
67 return failure();
68
69 // If dstType is same as the source type or the vector size is 1, it can be
70 // directly replaced by the source.
71 if (dstType == adaptor.getSource().getType() ||
72 shapeCastOp.getResultVectorType().getNumElements() == 1) {
73 rewriter.replaceOp(shapeCastOp, adaptor.getSource());
74 return success();
75 }
76
77 // Lowering for size-n vectors when n > 1 hasn't been implemented.
78 return failure();
79 }
80};
81
82struct VectorBitcastConvert final
83 : public OpConversionPattern<vector::BitCastOp> {
84 using OpConversionPattern::OpConversionPattern;
85
86 LogicalResult
87 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter) const override {
89 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
90 if (!dstType)
91 return failure();
92
93 if (dstType == adaptor.getSource().getType()) {
94 rewriter.replaceOp(bitcastOp, adaptor.getSource());
95 return success();
96 }
97
98 // Check that the source and destination type have the same bitwidth.
99 // Depending on the target environment, we may need to emulate certain
100 // types, which can cause issue with bitcast.
101 Type srcType = adaptor.getSource().getType();
102 if (getNumBits(type: dstType) != getNumBits(type: srcType)) {
103 return rewriter.notifyMatchFailure(
104 bitcastOp,
105 llvm::formatv(Fmt: "different source ({0}) and target ({1}) bitwidth",
106 Vals&: srcType, Vals&: dstType));
107 }
108
109 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
110 adaptor.getSource());
111 return success();
112 }
113};
114
115struct VectorBroadcastConvert final
116 : public OpConversionPattern<vector::BroadcastOp> {
117 using OpConversionPattern::OpConversionPattern;
118
119 LogicalResult
120 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
121 ConversionPatternRewriter &rewriter) const override {
122 Type resultType =
123 getTypeConverter()->convertType(castOp.getResultVectorType());
124 if (!resultType)
125 return failure();
126
127 if (isa<spirv::ScalarType>(Val: resultType)) {
128 rewriter.replaceOp(castOp, adaptor.getSource());
129 return success();
130 }
131
132 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
133 adaptor.getSource());
134 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
135 source);
136 return success();
137 }
138};
139
140// SPIR-V does not have a concept of a poison index for certain instructions,
141// which creates a UB hazard when lowering from otherwise equivalent Vector
142// dialect instructions, because this index will be considered out-of-bounds.
143// To avoid this, this function implements a dynamic sanitization that returns
144// some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
145// (presumably more efficient), and otherwise index 0 (always in-bounds).
146static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
147 Location loc, Value dynamicIndex,
148 int64_t kPoisonIndex, unsigned vectorSize) {
149 if (llvm::isPowerOf2_32(Value: vectorSize)) {
150 Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
151 loc, dynamicIndex.getType(),
152 rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
153 return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex,
154 inBoundsMask);
155 }
156 Value poisonIndex = rewriter.create<spirv::ConstantOp>(
157 loc, dynamicIndex.getType(),
158 rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
159 Value cmpResult =
160 rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
161 return rewriter.create<spirv::SelectOp>(
162 loc, cmpResult,
163 spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
164 dynamicIndex);
165}
166
167struct VectorExtractOpConvert final
168 : public OpConversionPattern<vector::ExtractOp> {
169 using OpConversionPattern::OpConversionPattern;
170
171 LogicalResult
172 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter) const override {
174 Type dstType = getTypeConverter()->convertType(extractOp.getType());
175 if (!dstType)
176 return failure();
177
178 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
179 rewriter.replaceOp(extractOp, adaptor.getVector());
180 return success();
181 }
182
183 if (std::optional<int64_t> id =
184 getConstantIntValue(extractOp.getMixedPosition()[0])) {
185 if (id == vector::ExtractOp::kPoisonIndex)
186 return rewriter.notifyMatchFailure(
187 extractOp,
188 "Static use of poison index handled elsewhere (folded to poison)");
189 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
190 extractOp, dstType, adaptor.getVector(),
191 rewriter.getI32ArrayAttr(id.value()));
192 } else {
193 Value sanitizedIndex = sanitizeDynamicIndex(
194 rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
195 vector::ExtractOp::kPoisonIndex,
196 extractOp.getSourceVectorType().getNumElements());
197 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
198 extractOp, dstType, adaptor.getVector(), sanitizedIndex);
199 }
200 return success();
201 }
202};
203
204struct VectorExtractStridedSliceOpConvert final
205 : public OpConversionPattern<vector::ExtractStridedSliceOp> {
206 using OpConversionPattern::OpConversionPattern;
207
208 LogicalResult
209 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
210 ConversionPatternRewriter &rewriter) const override {
211 Type dstType = getTypeConverter()->convertType(extractOp.getType());
212 if (!dstType)
213 return failure();
214
215 uint64_t offset = getFirstIntValue(extractOp.getOffsets());
216 uint64_t size = getFirstIntValue(extractOp.getSizes());
217 uint64_t stride = getFirstIntValue(extractOp.getStrides());
218 if (stride != 1)
219 return failure();
220
221 Value srcVector = adaptor.getOperands().front();
222
223 // Extract vector<1xT> case.
224 if (isa<spirv::ScalarType>(Val: dstType)) {
225 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
226 srcVector, offset);
227 return success();
228 }
229
230 SmallVector<int32_t, 2> indices(size);
231 std::iota(first: indices.begin(), last: indices.end(), value: offset);
232
233 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
234 extractOp, dstType, srcVector, srcVector,
235 rewriter.getI32ArrayAttr(indices));
236
237 return success();
238 }
239};
240
241template <class SPIRVFMAOp>
242struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
243 using OpConversionPattern::OpConversionPattern;
244
245 LogicalResult
246 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
247 ConversionPatternRewriter &rewriter) const override {
248 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
249 if (!dstType)
250 return failure();
251 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
252 adaptor.getRhs(), adaptor.getAcc());
253 return success();
254 }
255};
256
257struct VectorFromElementsOpConvert final
258 : public OpConversionPattern<vector::FromElementsOp> {
259 using OpConversionPattern::OpConversionPattern;
260
261 LogicalResult
262 matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter) const override {
264 Type resultType = getTypeConverter()->convertType(op.getType());
265 if (!resultType)
266 return failure();
267 OperandRange elements = op.getElements();
268 if (isa<spirv::ScalarType>(Val: resultType)) {
269 // In the case with a single scalar operand / single-element result,
270 // pass through the scalar.
271 rewriter.replaceOp(op, elements[0]);
272 return success();
273 }
274 // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
275 // vector.from_elements cases should not need to be handled, only 1d.
276 assert(cast<VectorType>(resultType).getRank() == 1);
277 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
278 elements);
279 return success();
280 }
281};
282
283struct VectorInsertOpConvert final
284 : public OpConversionPattern<vector::InsertOp> {
285 using OpConversionPattern::OpConversionPattern;
286
287 LogicalResult
288 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 if (isa<VectorType>(insertOp.getValueToStoreType()))
291 return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
292 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
293 return rewriter.notifyMatchFailure(insertOp,
294 "unsupported dest vector type");
295
296 // Special case for inserting scalar values into size-1 vectors.
297 if (insertOp.getValueToStoreType().isIntOrFloat() &&
298 insertOp.getDestVectorType().getNumElements() == 1) {
299 rewriter.replaceOp(insertOp, adaptor.getValueToStore());
300 return success();
301 }
302
303 if (std::optional<int64_t> id =
304 getConstantIntValue(insertOp.getMixedPosition()[0])) {
305 if (id == vector::InsertOp::kPoisonIndex)
306 return rewriter.notifyMatchFailure(
307 insertOp,
308 "Static use of poison index handled elsewhere (folded to poison)");
309 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
310 insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
311 } else {
312 Value sanitizedIndex = sanitizeDynamicIndex(
313 rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
314 vector::InsertOp::kPoisonIndex,
315 insertOp.getDestVectorType().getNumElements());
316 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
317 insertOp, insertOp.getDest(), adaptor.getValueToStore(),
318 sanitizedIndex);
319 }
320 return success();
321 }
322};
323
324struct VectorExtractElementOpConvert final
325 : public OpConversionPattern<vector::ExtractElementOp> {
326 using OpConversionPattern::OpConversionPattern;
327
328 LogicalResult
329 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
330 ConversionPatternRewriter &rewriter) const override {
331 Type resultType = getTypeConverter()->convertType(extractOp.getType());
332 if (!resultType)
333 return failure();
334
335 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
336 rewriter.replaceOp(extractOp, adaptor.getVector());
337 return success();
338 }
339
340 APInt cstPos;
341 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
342 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
343 extractOp, resultType, adaptor.getVector(),
344 rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
345 else
346 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
347 extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
348 return success();
349 }
350};
351
352struct VectorInsertElementOpConvert final
353 : public OpConversionPattern<vector::InsertElementOp> {
354 using OpConversionPattern::OpConversionPattern;
355
356 LogicalResult
357 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 Type vectorType = getTypeConverter()->convertType(insertOp.getType());
360 if (!vectorType)
361 return failure();
362
363 if (isa<spirv::ScalarType>(Val: vectorType)) {
364 rewriter.replaceOp(insertOp, adaptor.getSource());
365 return success();
366 }
367
368 APInt cstPos;
369 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
370 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
371 insertOp, adaptor.getSource(), adaptor.getDest(),
372 cstPos.getSExtValue());
373 else
374 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
375 insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
376 adaptor.getPosition());
377 return success();
378 }
379};
380
381struct VectorInsertStridedSliceOpConvert final
382 : public OpConversionPattern<vector::InsertStridedSliceOp> {
383 using OpConversionPattern::OpConversionPattern;
384
385 LogicalResult
386 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter) const override {
388 Value srcVector = adaptor.getOperands().front();
389 Value dstVector = adaptor.getOperands().back();
390
391 uint64_t stride = getFirstIntValue(insertOp.getStrides());
392 if (stride != 1)
393 return failure();
394 uint64_t offset = getFirstIntValue(insertOp.getOffsets());
395
396 if (isa<spirv::ScalarType>(Val: srcVector.getType())) {
397 assert(!isa<spirv::ScalarType>(dstVector.getType()));
398 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
399 insertOp, dstVector.getType(), srcVector, dstVector,
400 rewriter.getI32ArrayAttr(offset));
401 return success();
402 }
403
404 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
405 uint64_t insertSize =
406 cast<VectorType>(srcVector.getType()).getNumElements();
407
408 SmallVector<int32_t, 2> indices(totalSize);
409 std::iota(first: indices.begin(), last: indices.end(), value: 0);
410 std::iota(first: indices.begin() + offset, last: indices.begin() + offset + insertSize,
411 value: totalSize);
412
413 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
414 insertOp, dstVector.getType(), dstVector, srcVector,
415 rewriter.getI32ArrayAttr(indices));
416
417 return success();
418 }
419};
420
421static SmallVector<Value> extractAllElements(
422 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
423 VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
424 int numElements = static_cast<int>(srcVectorType.getDimSize(0));
425 SmallVector<Value> values;
426 values.reserve(N: numElements + (adaptor.getAcc() ? 1 : 0));
427 Location loc = reduceOp.getLoc();
428
429 for (int i = 0; i < numElements; ++i) {
430 values.push_back(rewriter.create<spirv::CompositeExtractOp>(
431 loc, srcVectorType.getElementType(), adaptor.getVector(),
432 rewriter.getI32ArrayAttr({i})));
433 }
434 if (Value acc = adaptor.getAcc())
435 values.push_back(Elt: acc);
436
437 return values;
438}
439
440struct ReductionRewriteInfo {
441 Type resultType;
442 SmallVector<Value> extractedElements;
443};
444
445FailureOr<ReductionRewriteInfo> static getReductionInfo(
446 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
447 ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
448 Type resultType = typeConverter.convertType(op.getType());
449 if (!resultType)
450 return failure();
451
452 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
453 if (!srcVectorType || srcVectorType.getRank() != 1)
454 return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
455
456 SmallVector<Value> extractedElements =
457 extractAllElements(op, adaptor, srcVectorType, rewriter);
458
459 return ReductionRewriteInfo{.resultType: resultType, .extractedElements: std::move(extractedElements)};
460}
461
462template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
463 typename SPIRVSMinOp>
464struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
465 using OpConversionPattern::OpConversionPattern;
466
467 LogicalResult
468 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter) const override {
470 auto reductionInfo =
471 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
472 if (failed(reductionInfo))
473 return failure();
474
475 auto [resultType, extractedElements] = *reductionInfo;
476 Location loc = reduceOp->getLoc();
477 Value result = extractedElements.front();
478 for (Value next : llvm::drop_begin(extractedElements)) {
479 switch (reduceOp.getKind()) {
480
481#define INT_AND_FLOAT_CASE(kind, iop, fop) \
482 case vector::CombiningKind::kind: \
483 if (llvm::isa<IntegerType>(resultType)) { \
484 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
485 } else { \
486 assert(llvm::isa<FloatType>(resultType)); \
487 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
488 } \
489 break
490
491#define INT_OR_FLOAT_CASE(kind, fop) \
492 case vector::CombiningKind::kind: \
493 result = rewriter.create<fop>(loc, resultType, result, next); \
494 break
495
496 INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
497 INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
498 INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
499 INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
500 INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
501 INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
502
503 case vector::CombiningKind::AND:
504 case vector::CombiningKind::OR:
505 case vector::CombiningKind::XOR:
506 return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
507 default:
508 return rewriter.notifyMatchFailure(reduceOp, "not handled here");
509 }
510#undef INT_AND_FLOAT_CASE
511#undef INT_OR_FLOAT_CASE
512 }
513
514 rewriter.replaceOp(reduceOp, result);
515 return success();
516 }
517};
518
519template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
520struct VectorReductionFloatMinMax final
521 : OpConversionPattern<vector::ReductionOp> {
522 using OpConversionPattern::OpConversionPattern;
523
524 LogicalResult
525 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
526 ConversionPatternRewriter &rewriter) const override {
527 auto reductionInfo =
528 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
529 if (failed(reductionInfo))
530 return failure();
531
532 auto [resultType, extractedElements] = *reductionInfo;
533 Location loc = reduceOp->getLoc();
534 Value result = extractedElements.front();
535 for (Value next : llvm::drop_begin(extractedElements)) {
536 switch (reduceOp.getKind()) {
537
538#define INT_OR_FLOAT_CASE(kind, fop) \
539 case vector::CombiningKind::kind: \
540 result = rewriter.create<fop>(loc, resultType, result, next); \
541 break
542
543 INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
544 INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
545 INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
546 INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
547
548 default:
549 return rewriter.notifyMatchFailure(reduceOp, "not handled here");
550 }
551#undef INT_OR_FLOAT_CASE
552 }
553
554 rewriter.replaceOp(reduceOp, result);
555 return success();
556 }
557};
558
559class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
560public:
561 using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
562
563 LogicalResult
564 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
565 ConversionPatternRewriter &rewriter) const override {
566 Type dstType = getTypeConverter()->convertType(op.getType());
567 if (!dstType)
568 return failure();
569 if (isa<spirv::ScalarType>(Val: dstType)) {
570 rewriter.replaceOp(op, adaptor.getInput());
571 } else {
572 auto dstVecType = cast<VectorType>(dstType);
573 SmallVector<Value, 4> source(dstVecType.getNumElements(),
574 adaptor.getInput());
575 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
576 source);
577 }
578 return success();
579 }
580};
581
582struct VectorShuffleOpConvert final
583 : public OpConversionPattern<vector::ShuffleOp> {
584 using OpConversionPattern::OpConversionPattern;
585
586 LogicalResult
587 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
588 ConversionPatternRewriter &rewriter) const override {
589 VectorType oldResultType = shuffleOp.getResultVectorType();
590 Type newResultType = getTypeConverter()->convertType(oldResultType);
591 if (!newResultType)
592 return rewriter.notifyMatchFailure(shuffleOp,
593 "unsupported result vector type");
594
595 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
596
597 VectorType oldV1Type = shuffleOp.getV1VectorType();
598 VectorType oldV2Type = shuffleOp.getV2VectorType();
599
600 // When both operands and the result are SPIR-V vectors, emit a SPIR-V
601 // shuffle.
602 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
603 oldResultType.getNumElements() > 1) {
604 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
605 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
606 rewriter.getI32ArrayAttr(mask));
607 return success();
608 }
609
610 // When at least one of the operands or the result becomes a scalar after
611 // type conversion for SPIR-V, extract all the required elements and
612 // construct the result vector.
613 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
614 Value scalarOrVec, int32_t idx) -> Value {
615 if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
616 return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
617 idx);
618
619 assert(idx == 0 && "Invalid scalar element index");
620 return scalarOrVec;
621 };
622
623 int32_t numV1Elems = oldV1Type.getNumElements();
624 SmallVector<Value> newOperands(mask.size());
625 for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
626 Value vec = adaptor.getV1();
627 int32_t elementIdx = shuffleIdx;
628 if (elementIdx >= numV1Elems) {
629 vec = adaptor.getV2();
630 elementIdx -= numV1Elems;
631 }
632
633 newOperand = getElementAtIdx(vec, elementIdx);
634 }
635
636 // Handle the scalar result corner case.
637 if (newOperands.size() == 1) {
638 rewriter.replaceOp(shuffleOp, newOperands.front());
639 return success();
640 }
641
642 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
643 shuffleOp, newResultType, newOperands);
644 return success();
645 }
646};
647
648struct VectorInterleaveOpConvert final
649 : public OpConversionPattern<vector::InterleaveOp> {
650 using OpConversionPattern::OpConversionPattern;
651
652 LogicalResult
653 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
654 ConversionPatternRewriter &rewriter) const override {
655 // Check the result vector type.
656 VectorType oldResultType = interleaveOp.getResultVectorType();
657 Type newResultType = getTypeConverter()->convertType(oldResultType);
658 if (!newResultType)
659 return rewriter.notifyMatchFailure(interleaveOp,
660 "unsupported result vector type");
661
662 // Interleave the indices.
663 VectorType sourceType = interleaveOp.getSourceVectorType();
664 int n = sourceType.getNumElements();
665
666 // Input vectors of size 1 are converted to scalars by the type converter.
667 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
668 // use `spirv::CompositeConstructOp`.
669 if (n == 1) {
670 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
671 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
672 interleaveOp, newResultType, newOperands);
673 return success();
674 }
675
676 auto seq = llvm::seq<int64_t>(Size: 2 * n);
677 auto indices = llvm::map_to_vector(
678 seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
679
680 // Emit a SPIR-V shuffle.
681 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
682 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
683 rewriter.getI32ArrayAttr(indices));
684
685 return success();
686 }
687};
688
689struct VectorDeinterleaveOpConvert final
690 : public OpConversionPattern<vector::DeinterleaveOp> {
691 using OpConversionPattern::OpConversionPattern;
692
693 LogicalResult
694 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
695 ConversionPatternRewriter &rewriter) const override {
696
697 // Check the result vector type.
698 VectorType oldResultType = deinterleaveOp.getResultVectorType();
699 Type newResultType = getTypeConverter()->convertType(oldResultType);
700 if (!newResultType)
701 return rewriter.notifyMatchFailure(deinterleaveOp,
702 "unsupported result vector type");
703
704 Location loc = deinterleaveOp->getLoc();
705
706 // Deinterleave the indices.
707 Value sourceVector = adaptor.getSource();
708 VectorType sourceType = deinterleaveOp.getSourceVectorType();
709 int n = sourceType.getNumElements();
710
711 // Output vectors of size 1 are converted to scalars by the type converter.
712 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
713 // use `spirv::CompositeExtractOp`.
714 if (n == 2) {
715 auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
716 loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
717
718 auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
719 loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
720
721 rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
722 return success();
723 }
724
725 // Indices for `shuffleEven` (result 0).
726 auto seqEven = llvm::seq<int64_t>(Size: n / 2);
727 auto indicesEven =
728 llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
729
730 // Indices for `shuffleOdd` (result 1).
731 auto seqOdd = llvm::seq<int64_t>(Size: n / 2);
732 auto indicesOdd =
733 llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
734
735 // Create two SPIR-V shuffles.
736 auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
737 loc, newResultType, sourceVector, sourceVector,
738 rewriter.getI32ArrayAttr(indicesEven));
739
740 auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
741 loc, newResultType, sourceVector, sourceVector,
742 rewriter.getI32ArrayAttr(indicesOdd));
743
744 rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
745 return success();
746 }
747};
748
749struct VectorLoadOpConverter final
750 : public OpConversionPattern<vector::LoadOp> {
751 using OpConversionPattern::OpConversionPattern;
752
753 LogicalResult
754 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
755 ConversionPatternRewriter &rewriter) const override {
756 auto memrefType = loadOp.getMemRefType();
757 auto attr =
758 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
759 if (!attr)
760 return rewriter.notifyMatchFailure(
761 loadOp, "expected spirv.storage_class memory space");
762
763 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
764 auto loc = loadOp.getLoc();
765 Value accessChain =
766 spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getBase(),
767 indices: adaptor.getIndices(), loc: loc, builder&: rewriter);
768 if (!accessChain)
769 return rewriter.notifyMatchFailure(
770 loadOp, "failed to get memref element pointer");
771
772 spirv::StorageClass storageClass = attr.getValue();
773 auto vectorType = loadOp.getVectorType();
774 // Use the converted vector type instead of original (single element vector
775 // would get converted to scalar).
776 auto spirvVectorType = typeConverter.convertType(vectorType);
777 auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
778
779 // For single element vectors, we don't need to bitcast the access chain to
780 // the original vector type. Both is going to be the same, a pointer
781 // to a scalar.
782 Value castedAccessChain = (vectorType.getNumElements() == 1)
783 ? accessChain
784 : rewriter.create<spirv::BitcastOp>(
785 loc, vectorPtrType, accessChain);
786
787 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
788 castedAccessChain);
789
790 return success();
791 }
792};
793
794struct VectorStoreOpConverter final
795 : public OpConversionPattern<vector::StoreOp> {
796 using OpConversionPattern::OpConversionPattern;
797
798 LogicalResult
799 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter) const override {
801 auto memrefType = storeOp.getMemRefType();
802 auto attr =
803 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
804 if (!attr)
805 return rewriter.notifyMatchFailure(
806 storeOp, "expected spirv.storage_class memory space");
807
808 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
809 auto loc = storeOp.getLoc();
810 Value accessChain =
811 spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getBase(),
812 indices: adaptor.getIndices(), loc: loc, builder&: rewriter);
813 if (!accessChain)
814 return rewriter.notifyMatchFailure(
815 storeOp, "failed to get memref element pointer");
816
817 spirv::StorageClass storageClass = attr.getValue();
818 auto vectorType = storeOp.getVectorType();
819 auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
820
821 // For single element vectors, we don't need to bitcast the access chain to
822 // the original vector type. Both is going to be the same, a pointer
823 // to a scalar.
824 Value castedAccessChain = (vectorType.getNumElements() == 1)
825 ? accessChain
826 : rewriter.create<spirv::BitcastOp>(
827 loc, vectorPtrType, accessChain);
828
829 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
830 adaptor.getValueToStore());
831
832 return success();
833 }
834};
835
836struct VectorReductionToIntDotProd final
837 : OpRewritePattern<vector::ReductionOp> {
838 using OpRewritePattern::OpRewritePattern;
839
840 LogicalResult matchAndRewrite(vector::ReductionOp op,
841 PatternRewriter &rewriter) const override {
842 if (op.getKind() != vector::CombiningKind::ADD)
843 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
844
845 auto resultType = dyn_cast<IntegerType>(op.getType());
846 if (!resultType)
847 return rewriter.notifyMatchFailure(op, "result is not an integer");
848
849 int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
850 if (!llvm::is_contained(Set: {32, 64}, Element: resultBitwidth))
851 return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
852
853 VectorType inVecTy = op.getSourceVectorType();
854 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
855 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
856 return rewriter.notifyMatchFailure(op, "unsupported vector shape");
857
858 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
859 if (!mul)
860 return rewriter.notifyMatchFailure(
861 op, "reduction operand is not 'arith.muli'");
862
863 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
864 spirv::SDotAccSatOp, false>(op, mul, rewriter)))
865 return success();
866
867 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
868 spirv::UDotAccSatOp, false>(op, mul, rewriter)))
869 return success();
870
871 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
872 spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
873 return success();
874
875 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
876 spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
877 return success();
878
879 return failure();
880 }
881
882private:
883 template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
884 typename DotAccOp, bool SwapOperands>
885 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
886 PatternRewriter &rewriter) {
887 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
888 if (!lhs)
889 return failure();
890 Value lhsIn = lhs.getIn();
891 auto lhsInType = cast<VectorType>(lhsIn.getType());
892 if (!lhsInType.getElementType().isInteger(8))
893 return failure();
894
895 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
896 if (!rhs)
897 return failure();
898 Value rhsIn = rhs.getIn();
899 auto rhsInType = cast<VectorType>(rhsIn.getType());
900 if (!rhsInType.getElementType().isInteger(8))
901 return failure();
902
903 if (op.getSourceVectorType().getNumElements() == 3) {
904 IntegerType i8Type = rewriter.getI8Type();
905 auto v4i8Type = VectorType::get({4}, i8Type);
906 Location loc = op.getLoc();
907 Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
908 lhsIn = rewriter.create<spirv::CompositeConstructOp>(
909 loc, v4i8Type, ValueRange{lhsIn, zero});
910 rhsIn = rewriter.create<spirv::CompositeConstructOp>(
911 loc, v4i8Type, ValueRange{rhsIn, zero});
912 }
913
914 // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
915 // we have to swap operands instead in that case.
916 if (SwapOperands)
917 std::swap(a&: lhsIn, b&: rhsIn);
918
919 if (Value acc = op.getAcc()) {
920 rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
921 nullptr);
922 } else {
923 rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
924 nullptr);
925 }
926
927 return success();
928 }
929};
930
931struct VectorReductionToFPDotProd final
932 : OpConversionPattern<vector::ReductionOp> {
933 using OpConversionPattern::OpConversionPattern;
934
935 LogicalResult
936 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
937 ConversionPatternRewriter &rewriter) const override {
938 if (op.getKind() != vector::CombiningKind::ADD)
939 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
940
941 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
942 if (!resultType)
943 return rewriter.notifyMatchFailure(op, "result is not a float");
944
945 Value vec = adaptor.getVector();
946 Value acc = adaptor.getAcc();
947
948 auto vectorType = dyn_cast<VectorType>(vec.getType());
949 if (!vectorType) {
950 assert(isa<FloatType>(vec.getType()) &&
951 "Expected the vector to be scalarized");
952 if (acc) {
953 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
954 return success();
955 }
956
957 rewriter.replaceOp(op, vec);
958 return success();
959 }
960
961 Location loc = op.getLoc();
962 Value lhs;
963 Value rhs;
964 if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
965 lhs = mul.getLhs();
966 rhs = mul.getRhs();
967 } else {
968 // If the operand is not a mul, use a vector of ones for the dot operand
969 // to just sum up all values.
970 lhs = vec;
971 Attribute oneAttr =
972 rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
973 oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
974 rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
975 }
976 assert(lhs);
977 assert(rhs);
978
979 Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
980 if (acc)
981 res = rewriter.create<spirv::FAddOp>(loc, acc, res);
982
983 rewriter.replaceOp(op, res);
984 return success();
985 }
986};
987
988struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
989 using OpConversionPattern::OpConversionPattern;
990
991 LogicalResult
992 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
993 ConversionPatternRewriter &rewriter) const override {
994 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
995 Type dstType = typeConverter.convertType(stepOp.getType());
996 if (!dstType)
997 return failure();
998
999 Location loc = stepOp.getLoc();
1000 int64_t numElements = stepOp.getType().getNumElements();
1001 auto intType =
1002 rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
1003
1004 // Input vectors of size 1 are converted to scalars by the type converter.
1005 // We just create a constant in this case.
1006 if (numElements == 1) {
1007 Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
1008 rewriter.replaceOp(stepOp, zero);
1009 return success();
1010 }
1011
1012 SmallVector<Value> source;
1013 source.reserve(N: numElements);
1014 for (int64_t i = 0; i < numElements; ++i) {
1015 Attribute intAttr = rewriter.getIntegerAttr(intType, i);
1016 Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
1017 source.push_back(Elt: constOp);
1018 }
1019 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
1020 source);
1021 return success();
1022 }
1023};
1024
1025} // namespace
1026#define CL_INT_MAX_MIN_OPS \
1027 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
1028
1029#define GL_INT_MAX_MIN_OPS \
1030 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
1031
1032#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
1033#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
1034
1035void mlir::populateVectorToSPIRVPatterns(
1036 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1037 patterns.add<
1038 VectorBitcastConvert, VectorBroadcastConvert,
1039 VectorExtractElementOpConvert, VectorExtractOpConvert,
1040 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1041 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1042 VectorInsertElementOpConvert, VectorInsertOpConvert,
1043 VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1044 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1045 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1046 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
1047 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
1048 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
1049 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
1050 VectorStepOpConvert>(typeConverter, patterns.getContext(),
1051 PatternBenefit(1));
1052
1053 // Make sure that the more specialized dot product pattern has higher benefit
1054 // than the generic one that extracts all elements.
1055 patterns.add<VectorReductionToFPDotProd>(arg: typeConverter, args: patterns.getContext(),
1056 args: PatternBenefit(2));
1057}
1058
1059void mlir::populateVectorReductionToSPIRVDotProductPatterns(
1060 RewritePatternSet &patterns) {
1061 patterns.add<VectorReductionToIntDotProd>(arg: patterns.getContext());
1062}
1063

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp