1//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Vector dialect to SPIRV dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Location.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/IR/TypeUtilities.h"
28#include "mlir/Support/LogicalResult.h"
29#include "mlir/Transforms/DialectConversion.h"
30#include "llvm/ADT/ArrayRef.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/SmallVectorExtras.h"
34#include "llvm/Support/FormatVariadic.h"
35#include <cassert>
36#include <cstdint>
37#include <numeric>
38
39using namespace mlir;
40
41/// Returns the integer value from the first valid input element, assuming Value
42/// inputs are defined by a constant index ops and Attribute inputs are integer
43/// attributes.
44static uint64_t getFirstIntValue(ValueRange values) {
45 return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
46}
47static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
48 return cast<IntegerAttr>(attr[0]).getInt();
49}
50static uint64_t getFirstIntValue(ArrayAttr attr) {
51 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
52}
53static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
54 auto attr = foldResults[0].dyn_cast<Attribute>();
55 if (attr)
56 return getFirstIntValue(attr);
57
58 return getFirstIntValue(values: ValueRange{foldResults[0].get<Value>()});
59}
60
61/// Returns the number of bits for the given scalar/vector type.
62static int getNumBits(Type type) {
63 // TODO: This does not take into account any memory layout or widening
64 // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
65 // though in practice it will likely be stored as in a 4xi64 vector register.
66 if (auto vectorType = dyn_cast<VectorType>(type))
67 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
68 return type.getIntOrFloatBitWidth();
69}
70
71namespace {
72
73struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
74 using OpConversionPattern::OpConversionPattern;
75
76 LogicalResult
77 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
78 ConversionPatternRewriter &rewriter) const override {
79 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
80 if (!dstType)
81 return failure();
82
83 // If dstType is same as the source type or the vector size is 1, it can be
84 // directly replaced by the source.
85 if (dstType == adaptor.getSource().getType() ||
86 shapeCastOp.getResultVectorType().getNumElements() == 1) {
87 rewriter.replaceOp(shapeCastOp, adaptor.getSource());
88 return success();
89 }
90
91 // Lowering for size-n vectors when n > 1 hasn't been implemented.
92 return failure();
93 }
94};
95
96struct VectorBitcastConvert final
97 : public OpConversionPattern<vector::BitCastOp> {
98 using OpConversionPattern::OpConversionPattern;
99
100 LogicalResult
101 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
102 ConversionPatternRewriter &rewriter) const override {
103 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
104 if (!dstType)
105 return failure();
106
107 if (dstType == adaptor.getSource().getType()) {
108 rewriter.replaceOp(bitcastOp, adaptor.getSource());
109 return success();
110 }
111
112 // Check that the source and destination type have the same bitwidth.
113 // Depending on the target environment, we may need to emulate certain
114 // types, which can cause issue with bitcast.
115 Type srcType = adaptor.getSource().getType();
116 if (getNumBits(type: dstType) != getNumBits(type: srcType)) {
117 return rewriter.notifyMatchFailure(
118 bitcastOp,
119 llvm::formatv(Fmt: "different source ({0}) and target ({1}) bitwidth",
120 Vals&: srcType, Vals&: dstType));
121 }
122
123 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
124 adaptor.getSource());
125 return success();
126 }
127};
128
129struct VectorBroadcastConvert final
130 : public OpConversionPattern<vector::BroadcastOp> {
131 using OpConversionPattern::OpConversionPattern;
132
133 LogicalResult
134 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
135 ConversionPatternRewriter &rewriter) const override {
136 Type resultType =
137 getTypeConverter()->convertType(castOp.getResultVectorType());
138 if (!resultType)
139 return failure();
140
141 if (isa<spirv::ScalarType>(Val: resultType)) {
142 rewriter.replaceOp(castOp, adaptor.getSource());
143 return success();
144 }
145
146 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
147 adaptor.getSource());
148 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
149 castOp, castOp.getResultVectorType(), source);
150 return success();
151 }
152};
153
154struct VectorExtractOpConvert final
155 : public OpConversionPattern<vector::ExtractOp> {
156 using OpConversionPattern::OpConversionPattern;
157
158 LogicalResult
159 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
160 ConversionPatternRewriter &rewriter) const override {
161 if (extractOp.hasDynamicPosition())
162 return failure();
163
164 Type dstType = getTypeConverter()->convertType(extractOp.getType());
165 if (!dstType)
166 return failure();
167
168 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
169 rewriter.replaceOp(extractOp, adaptor.getVector());
170 return success();
171 }
172
173 int32_t id = getFirstIntValue(extractOp.getMixedPosition());
174 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
175 extractOp, adaptor.getVector(), id);
176 return success();
177 }
178};
179
180struct VectorExtractStridedSliceOpConvert final
181 : public OpConversionPattern<vector::ExtractStridedSliceOp> {
182 using OpConversionPattern::OpConversionPattern;
183
184 LogicalResult
185 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter) const override {
187 Type dstType = getTypeConverter()->convertType(extractOp.getType());
188 if (!dstType)
189 return failure();
190
191 uint64_t offset = getFirstIntValue(extractOp.getOffsets());
192 uint64_t size = getFirstIntValue(extractOp.getSizes());
193 uint64_t stride = getFirstIntValue(extractOp.getStrides());
194 if (stride != 1)
195 return failure();
196
197 Value srcVector = adaptor.getOperands().front();
198
199 // Extract vector<1xT> case.
200 if (isa<spirv::ScalarType>(Val: dstType)) {
201 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
202 srcVector, offset);
203 return success();
204 }
205
206 SmallVector<int32_t, 2> indices(size);
207 std::iota(first: indices.begin(), last: indices.end(), value: offset);
208
209 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
210 extractOp, dstType, srcVector, srcVector,
211 rewriter.getI32ArrayAttr(indices));
212
213 return success();
214 }
215};
216
217template <class SPIRVFMAOp>
218struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
219 using OpConversionPattern::OpConversionPattern;
220
221 LogicalResult
222 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
223 ConversionPatternRewriter &rewriter) const override {
224 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
225 if (!dstType)
226 return failure();
227 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
228 adaptor.getRhs(), adaptor.getAcc());
229 return success();
230 }
231};
232
233struct VectorInsertOpConvert final
234 : public OpConversionPattern<vector::InsertOp> {
235 using OpConversionPattern::OpConversionPattern;
236
237 LogicalResult
238 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter) const override {
240 if (isa<VectorType>(insertOp.getSourceType()))
241 return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
242 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
243 return rewriter.notifyMatchFailure(insertOp,
244 "unsupported dest vector type");
245
246 // Special case for inserting scalar values into size-1 vectors.
247 if (insertOp.getSourceType().isIntOrFloat() &&
248 insertOp.getDestVectorType().getNumElements() == 1) {
249 rewriter.replaceOp(insertOp, adaptor.getSource());
250 return success();
251 }
252
253 int32_t id = getFirstIntValue(insertOp.getMixedPosition());
254 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
255 insertOp, adaptor.getSource(), adaptor.getDest(), id);
256 return success();
257 }
258};
259
260struct VectorExtractElementOpConvert final
261 : public OpConversionPattern<vector::ExtractElementOp> {
262 using OpConversionPattern::OpConversionPattern;
263
264 LogicalResult
265 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
266 ConversionPatternRewriter &rewriter) const override {
267 Type resultType = getTypeConverter()->convertType(extractOp.getType());
268 if (!resultType)
269 return failure();
270
271 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
272 rewriter.replaceOp(extractOp, adaptor.getVector());
273 return success();
274 }
275
276 APInt cstPos;
277 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
278 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
279 extractOp, resultType, adaptor.getVector(),
280 rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
281 else
282 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
283 extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
284 return success();
285 }
286};
287
288struct VectorInsertElementOpConvert final
289 : public OpConversionPattern<vector::InsertElementOp> {
290 using OpConversionPattern::OpConversionPattern;
291
292 LogicalResult
293 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
294 ConversionPatternRewriter &rewriter) const override {
295 Type vectorType = getTypeConverter()->convertType(insertOp.getType());
296 if (!vectorType)
297 return failure();
298
299 if (isa<spirv::ScalarType>(Val: vectorType)) {
300 rewriter.replaceOp(insertOp, adaptor.getSource());
301 return success();
302 }
303
304 APInt cstPos;
305 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
306 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
307 insertOp, adaptor.getSource(), adaptor.getDest(),
308 cstPos.getSExtValue());
309 else
310 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
311 insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
312 adaptor.getPosition());
313 return success();
314 }
315};
316
317struct VectorInsertStridedSliceOpConvert final
318 : public OpConversionPattern<vector::InsertStridedSliceOp> {
319 using OpConversionPattern::OpConversionPattern;
320
321 LogicalResult
322 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter) const override {
324 Value srcVector = adaptor.getOperands().front();
325 Value dstVector = adaptor.getOperands().back();
326
327 uint64_t stride = getFirstIntValue(insertOp.getStrides());
328 if (stride != 1)
329 return failure();
330 uint64_t offset = getFirstIntValue(insertOp.getOffsets());
331
332 if (isa<spirv::ScalarType>(Val: srcVector.getType())) {
333 assert(!isa<spirv::ScalarType>(dstVector.getType()));
334 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
335 insertOp, dstVector.getType(), srcVector, dstVector,
336 rewriter.getI32ArrayAttr(offset));
337 return success();
338 }
339
340 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
341 uint64_t insertSize =
342 cast<VectorType>(srcVector.getType()).getNumElements();
343
344 SmallVector<int32_t, 2> indices(totalSize);
345 std::iota(first: indices.begin(), last: indices.end(), value: 0);
346 std::iota(first: indices.begin() + offset, last: indices.begin() + offset + insertSize,
347 value: totalSize);
348
349 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
350 insertOp, dstVector.getType(), dstVector, srcVector,
351 rewriter.getI32ArrayAttr(indices));
352
353 return success();
354 }
355};
356
357static SmallVector<Value> extractAllElements(
358 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
359 VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
360 int numElements = static_cast<int>(srcVectorType.getDimSize(0));
361 SmallVector<Value> values;
362 values.reserve(N: numElements + (adaptor.getAcc() ? 1 : 0));
363 Location loc = reduceOp.getLoc();
364
365 for (int i = 0; i < numElements; ++i) {
366 values.push_back(rewriter.create<spirv::CompositeExtractOp>(
367 loc, srcVectorType.getElementType(), adaptor.getVector(),
368 rewriter.getI32ArrayAttr({i})));
369 }
370 if (Value acc = adaptor.getAcc())
371 values.push_back(Elt: acc);
372
373 return values;
374}
375
376struct ReductionRewriteInfo {
377 Type resultType;
378 SmallVector<Value> extractedElements;
379};
380
381FailureOr<ReductionRewriteInfo> static getReductionInfo(
382 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
383 ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
384 Type resultType = typeConverter.convertType(op.getType());
385 if (!resultType)
386 return failure();
387
388 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
389 if (!srcVectorType || srcVectorType.getRank() != 1)
390 return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
391
392 SmallVector<Value> extractedElements =
393 extractAllElements(op, adaptor, srcVectorType, rewriter);
394
395 return ReductionRewriteInfo{.resultType: resultType, .extractedElements: std::move(extractedElements)};
396}
397
398template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
399 typename SPIRVSMinOp>
400struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
401 using OpConversionPattern::OpConversionPattern;
402
403 LogicalResult
404 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override {
406 auto reductionInfo =
407 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
408 if (failed(reductionInfo))
409 return failure();
410
411 auto [resultType, extractedElements] = *reductionInfo;
412 Location loc = reduceOp->getLoc();
413 Value result = extractedElements.front();
414 for (Value next : llvm::drop_begin(extractedElements)) {
415 switch (reduceOp.getKind()) {
416
417#define INT_AND_FLOAT_CASE(kind, iop, fop) \
418 case vector::CombiningKind::kind: \
419 if (llvm::isa<IntegerType>(resultType)) { \
420 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
421 } else { \
422 assert(llvm::isa<FloatType>(resultType)); \
423 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
424 } \
425 break
426
427#define INT_OR_FLOAT_CASE(kind, fop) \
428 case vector::CombiningKind::kind: \
429 result = rewriter.create<fop>(loc, resultType, result, next); \
430 break
431
432 INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
433 INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
434 INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
435 INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
436 INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
437 INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
438
439 case vector::CombiningKind::AND:
440 case vector::CombiningKind::OR:
441 case vector::CombiningKind::XOR:
442 return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
443 default:
444 return rewriter.notifyMatchFailure(reduceOp, "not handled here");
445 }
446#undef INT_AND_FLOAT_CASE
447#undef INT_OR_FLOAT_CASE
448 }
449
450 rewriter.replaceOp(reduceOp, result);
451 return success();
452 }
453};
454
455template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
456struct VectorReductionFloatMinMax final
457 : OpConversionPattern<vector::ReductionOp> {
458 using OpConversionPattern::OpConversionPattern;
459
460 LogicalResult
461 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
462 ConversionPatternRewriter &rewriter) const override {
463 auto reductionInfo =
464 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
465 if (failed(reductionInfo))
466 return failure();
467
468 auto [resultType, extractedElements] = *reductionInfo;
469 Location loc = reduceOp->getLoc();
470 Value result = extractedElements.front();
471 for (Value next : llvm::drop_begin(extractedElements)) {
472 switch (reduceOp.getKind()) {
473
474#define INT_OR_FLOAT_CASE(kind, fop) \
475 case vector::CombiningKind::kind: \
476 result = rewriter.create<fop>(loc, resultType, result, next); \
477 break
478
479 INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
480 INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
481 INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
482 INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
483
484 default:
485 return rewriter.notifyMatchFailure(reduceOp, "not handled here");
486 }
487#undef INT_OR_FLOAT_CASE
488 }
489
490 rewriter.replaceOp(reduceOp, result);
491 return success();
492 }
493};
494
495class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
496public:
497 using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
498
499 LogicalResult
500 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
501 ConversionPatternRewriter &rewriter) const override {
502 Type dstType = getTypeConverter()->convertType(op.getType());
503 if (!dstType)
504 return failure();
505 if (isa<spirv::ScalarType>(Val: dstType)) {
506 rewriter.replaceOp(op, adaptor.getInput());
507 } else {
508 auto dstVecType = cast<VectorType>(dstType);
509 SmallVector<Value, 4> source(dstVecType.getNumElements(),
510 adaptor.getInput());
511 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
512 source);
513 }
514 return success();
515 }
516};
517
518struct VectorShuffleOpConvert final
519 : public OpConversionPattern<vector::ShuffleOp> {
520 using OpConversionPattern::OpConversionPattern;
521
522 LogicalResult
523 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
524 ConversionPatternRewriter &rewriter) const override {
525 auto oldResultType = shuffleOp.getResultVectorType();
526 Type newResultType = getTypeConverter()->convertType(oldResultType);
527 if (!newResultType)
528 return rewriter.notifyMatchFailure(shuffleOp,
529 "unsupported result vector type");
530
531 SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
532 shuffleOp.getMask(), [](Attribute attr) -> int32_t {
533 return cast<IntegerAttr>(attr).getValue().getZExtValue();
534 });
535
536 auto oldV1Type = shuffleOp.getV1VectorType();
537 auto oldV2Type = shuffleOp.getV2VectorType();
538
539 // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
540 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
541 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
542 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
543 rewriter.getI32ArrayAttr(mask));
544 return success();
545 }
546
547 // When at least one of the operands becomes a scalar after type conversion
548 // for SPIR-V, extract all the required elements and construct the result
549 // vector.
550 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
551 Value scalarOrVec, int32_t idx) -> Value {
552 if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
553 return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
554 idx);
555
556 assert(idx == 0 && "Invalid scalar element index");
557 return scalarOrVec;
558 };
559
560 int32_t numV1Elems = oldV1Type.getNumElements();
561 SmallVector<Value> newOperands(mask.size());
562 for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
563 Value vec = adaptor.getV1();
564 int32_t elementIdx = shuffleIdx;
565 if (elementIdx >= numV1Elems) {
566 vec = adaptor.getV2();
567 elementIdx -= numV1Elems;
568 }
569
570 newOperand = getElementAtIdx(vec, elementIdx);
571 }
572
573 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
574 shuffleOp, newResultType, newOperands);
575
576 return success();
577 }
578};
579
580struct VectorLoadOpConverter final
581 : public OpConversionPattern<vector::LoadOp> {
582 using OpConversionPattern::OpConversionPattern;
583
584 LogicalResult
585 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter) const override {
587 auto memrefType = loadOp.getMemRefType();
588 auto attr =
589 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
590 if (!attr)
591 return rewriter.notifyMatchFailure(
592 loadOp, "expected spirv.storage_class memory space");
593
594 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
595 auto loc = loadOp.getLoc();
596 Value accessChain =
597 spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getBase(),
598 indices: adaptor.getIndices(), loc: loc, builder&: rewriter);
599 if (!accessChain)
600 return rewriter.notifyMatchFailure(
601 loadOp, "failed to get memref element pointer");
602
603 spirv::StorageClass storageClass = attr.getValue();
604 auto vectorType = loadOp.getVectorType();
605 auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
606 Value castedAccessChain =
607 rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
608 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
609 castedAccessChain);
610
611 return success();
612 }
613};
614
615struct VectorStoreOpConverter final
616 : public OpConversionPattern<vector::StoreOp> {
617 using OpConversionPattern::OpConversionPattern;
618
619 LogicalResult
620 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
621 ConversionPatternRewriter &rewriter) const override {
622 auto memrefType = storeOp.getMemRefType();
623 auto attr =
624 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
625 if (!attr)
626 return rewriter.notifyMatchFailure(
627 storeOp, "expected spirv.storage_class memory space");
628
629 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
630 auto loc = storeOp.getLoc();
631 Value accessChain =
632 spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getBase(),
633 indices: adaptor.getIndices(), loc: loc, builder&: rewriter);
634 if (!accessChain)
635 return rewriter.notifyMatchFailure(
636 storeOp, "failed to get memref element pointer");
637
638 spirv::StorageClass storageClass = attr.getValue();
639 auto vectorType = storeOp.getVectorType();
640 auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
641 Value castedAccessChain =
642 rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
643 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
644 adaptor.getValueToStore());
645
646 return success();
647 }
648};
649
650struct VectorReductionToIntDotProd final
651 : OpRewritePattern<vector::ReductionOp> {
652 using OpRewritePattern::OpRewritePattern;
653
654 LogicalResult matchAndRewrite(vector::ReductionOp op,
655 PatternRewriter &rewriter) const override {
656 if (op.getKind() != vector::CombiningKind::ADD)
657 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
658
659 auto resultType = dyn_cast<IntegerType>(op.getType());
660 if (!resultType)
661 return rewriter.notifyMatchFailure(op, "result is not an integer");
662
663 int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
664 if (!llvm::is_contained(Set: {32, 64}, Element: resultBitwidth))
665 return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
666
667 VectorType inVecTy = op.getSourceVectorType();
668 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
669 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
670 return rewriter.notifyMatchFailure(op, "unsupported vector shape");
671
672 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
673 if (!mul)
674 return rewriter.notifyMatchFailure(
675 op, "reduction operand is not 'arith.muli'");
676
677 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
678 spirv::SDotAccSatOp, false>(op, mul, rewriter)))
679 return success();
680
681 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
682 spirv::UDotAccSatOp, false>(op, mul, rewriter)))
683 return success();
684
685 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
686 spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
687 return success();
688
689 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
690 spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
691 return success();
692
693 return failure();
694 }
695
696private:
697 template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
698 typename DotAccOp, bool SwapOperands>
699 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
700 PatternRewriter &rewriter) {
701 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
702 if (!lhs)
703 return failure();
704 Value lhsIn = lhs.getIn();
705 auto lhsInType = cast<VectorType>(lhsIn.getType());
706 if (!lhsInType.getElementType().isInteger(8))
707 return failure();
708
709 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
710 if (!rhs)
711 return failure();
712 Value rhsIn = rhs.getIn();
713 auto rhsInType = cast<VectorType>(rhsIn.getType());
714 if (!rhsInType.getElementType().isInteger(8))
715 return failure();
716
717 if (op.getSourceVectorType().getNumElements() == 3) {
718 IntegerType i8Type = rewriter.getI8Type();
719 auto v4i8Type = VectorType::get({4}, i8Type);
720 Location loc = op.getLoc();
721 Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
722 lhsIn = rewriter.create<spirv::CompositeConstructOp>(
723 loc, v4i8Type, ValueRange{lhsIn, zero});
724 rhsIn = rewriter.create<spirv::CompositeConstructOp>(
725 loc, v4i8Type, ValueRange{rhsIn, zero});
726 }
727
728 // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
729 // we have to swap operands instead in that case.
730 if (SwapOperands)
731 std::swap(a&: lhsIn, b&: rhsIn);
732
733 if (Value acc = op.getAcc()) {
734 rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
735 nullptr);
736 } else {
737 rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
738 nullptr);
739 }
740
741 return success();
742 }
743};
744
745struct VectorReductionToFPDotProd final
746 : OpConversionPattern<vector::ReductionOp> {
747 using OpConversionPattern::OpConversionPattern;
748
749 LogicalResult
750 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
751 ConversionPatternRewriter &rewriter) const override {
752 if (op.getKind() != vector::CombiningKind::ADD)
753 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
754
755 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
756 if (!resultType)
757 return rewriter.notifyMatchFailure(op, "result is not a float");
758
759 Value vec = adaptor.getVector();
760 Value acc = adaptor.getAcc();
761
762 auto vectorType = dyn_cast<VectorType>(vec.getType());
763 if (!vectorType) {
764 assert(isa<FloatType>(vec.getType()) &&
765 "Expected the vector to be scalarized");
766 if (acc) {
767 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
768 return success();
769 }
770
771 rewriter.replaceOp(op, vec);
772 return success();
773 }
774
775 Location loc = op.getLoc();
776 Value lhs;
777 Value rhs;
778 if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
779 lhs = mul.getLhs();
780 rhs = mul.getRhs();
781 } else {
782 // If the operand is not a mul, use a vector of ones for the dot operand
783 // to just sum up all values.
784 lhs = vec;
785 Attribute oneAttr =
786 rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
787 oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
788 rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
789 }
790 assert(lhs);
791 assert(rhs);
792
793 Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
794 if (acc)
795 res = rewriter.create<spirv::FAddOp>(loc, acc, res);
796
797 rewriter.replaceOp(op, res);
798 return success();
799 }
800};
801
802} // namespace
// Template-argument lists that pick the extended-instruction-set ops used by
// the reduction patterns: CL* (OpenCL.std) and GL* (GLSL.std.450) variants.
// The integer lists are ordered UMax, UMin, SMax, SMin to match the template
// parameters of VectorReductionPattern. The float lists are ordered FMax, FMin
// to match VectorReductionFloatMinMax.
#define CL_INT_MAX_MIN_OPS \
  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp

#define GL_INT_MAX_MIN_OPS \
  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp

#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
811
812void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
813 RewritePatternSet &patterns) {
814 patterns.add<
815 VectorBitcastConvert, VectorBroadcastConvert,
816 VectorExtractElementOpConvert, VectorExtractOpConvert,
817 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
818 VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
819 VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
820 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
821 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
822 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
823 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
824 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
825 typeConverter, patterns.getContext(), PatternBenefit(1));
826
827 // Make sure that the more specialized dot product pattern has higher benefit
828 // than the generic one that extracts all elements.
829 patterns.add<VectorReductionToFPDotProd>(arg&: typeConverter, args: patterns.getContext(),
830 args: PatternBenefit(2));
831}
832
833void mlir::populateVectorReductionToSPIRVDotProductPatterns(
834 RewritePatternSet &patterns) {
835 patterns.add<VectorReductionToIntDotProd>(arg: patterns.getContext());
836}
837

source code of mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp