1 | //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert Vector dialect to SPIRV dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
20 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
21 | #include "mlir/IR/Attributes.h" |
22 | #include "mlir/IR/BuiltinAttributes.h" |
23 | #include "mlir/IR/BuiltinTypes.h" |
24 | #include "mlir/IR/Location.h" |
25 | #include "mlir/IR/Matchers.h" |
26 | #include "mlir/IR/PatternMatch.h" |
27 | #include "mlir/IR/TypeUtilities.h" |
28 | #include "mlir/Support/LogicalResult.h" |
29 | #include "mlir/Transforms/DialectConversion.h" |
30 | #include "llvm/ADT/ArrayRef.h" |
31 | #include "llvm/ADT/STLExtras.h" |
32 | #include "llvm/ADT/SmallVector.h" |
33 | #include "llvm/ADT/SmallVectorExtras.h" |
34 | #include "llvm/Support/FormatVariadic.h" |
35 | #include <cassert> |
36 | #include <cstdint> |
37 | #include <numeric> |
38 | |
39 | using namespace mlir; |
40 | |
41 | /// Returns the integer value from the first valid input element, assuming Value |
42 | /// inputs are defined by a constant index ops and Attribute inputs are integer |
43 | /// attributes. |
44 | static uint64_t getFirstIntValue(ValueRange values) { |
45 | return values[0].getDefiningOp<arith::ConstantIndexOp>().value(); |
46 | } |
47 | static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) { |
48 | return cast<IntegerAttr>(attr[0]).getInt(); |
49 | } |
50 | static uint64_t getFirstIntValue(ArrayAttr attr) { |
51 | return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); |
52 | } |
53 | static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) { |
54 | auto attr = foldResults[0].dyn_cast<Attribute>(); |
55 | if (attr) |
56 | return getFirstIntValue(attr); |
57 | |
58 | return getFirstIntValue(values: ValueRange{foldResults[0].get<Value>()}); |
59 | } |
60 | |
61 | /// Returns the number of bits for the given scalar/vector type. |
62 | static int getNumBits(Type type) { |
63 | // TODO: This does not take into account any memory layout or widening |
64 | // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even |
65 | // though in practice it will likely be stored as in a 4xi64 vector register. |
66 | if (auto vectorType = dyn_cast<VectorType>(type)) |
67 | return vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); |
68 | return type.getIntOrFloatBitWidth(); |
69 | } |
70 | |
71 | namespace { |
72 | |
73 | struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> { |
74 | using OpConversionPattern::OpConversionPattern; |
75 | |
76 | LogicalResult |
77 | matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor, |
78 | ConversionPatternRewriter &rewriter) const override { |
79 | Type dstType = getTypeConverter()->convertType(shapeCastOp.getType()); |
80 | if (!dstType) |
81 | return failure(); |
82 | |
83 | // If dstType is same as the source type or the vector size is 1, it can be |
84 | // directly replaced by the source. |
85 | if (dstType == adaptor.getSource().getType() || |
86 | shapeCastOp.getResultVectorType().getNumElements() == 1) { |
87 | rewriter.replaceOp(shapeCastOp, adaptor.getSource()); |
88 | return success(); |
89 | } |
90 | |
91 | // Lowering for size-n vectors when n > 1 hasn't been implemented. |
92 | return failure(); |
93 | } |
94 | }; |
95 | |
96 | struct VectorBitcastConvert final |
97 | : public OpConversionPattern<vector::BitCastOp> { |
98 | using OpConversionPattern::OpConversionPattern; |
99 | |
100 | LogicalResult |
101 | matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, |
102 | ConversionPatternRewriter &rewriter) const override { |
103 | Type dstType = getTypeConverter()->convertType(bitcastOp.getType()); |
104 | if (!dstType) |
105 | return failure(); |
106 | |
107 | if (dstType == adaptor.getSource().getType()) { |
108 | rewriter.replaceOp(bitcastOp, adaptor.getSource()); |
109 | return success(); |
110 | } |
111 | |
112 | // Check that the source and destination type have the same bitwidth. |
113 | // Depending on the target environment, we may need to emulate certain |
114 | // types, which can cause issue with bitcast. |
115 | Type srcType = adaptor.getSource().getType(); |
116 | if (getNumBits(type: dstType) != getNumBits(type: srcType)) { |
117 | return rewriter.notifyMatchFailure( |
118 | bitcastOp, |
119 | llvm::formatv(Fmt: "different source ({0}) and target ({1}) bitwidth" , |
120 | Vals&: srcType, Vals&: dstType)); |
121 | } |
122 | |
123 | rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, |
124 | adaptor.getSource()); |
125 | return success(); |
126 | } |
127 | }; |
128 | |
129 | struct VectorBroadcastConvert final |
130 | : public OpConversionPattern<vector::BroadcastOp> { |
131 | using OpConversionPattern::OpConversionPattern; |
132 | |
133 | LogicalResult |
134 | matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor, |
135 | ConversionPatternRewriter &rewriter) const override { |
136 | Type resultType = |
137 | getTypeConverter()->convertType(castOp.getResultVectorType()); |
138 | if (!resultType) |
139 | return failure(); |
140 | |
141 | if (isa<spirv::ScalarType>(Val: resultType)) { |
142 | rewriter.replaceOp(castOp, adaptor.getSource()); |
143 | return success(); |
144 | } |
145 | |
146 | SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(), |
147 | adaptor.getSource()); |
148 | rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( |
149 | castOp, castOp.getResultVectorType(), source); |
150 | return success(); |
151 | } |
152 | }; |
153 | |
154 | struct final |
155 | : public OpConversionPattern<vector::ExtractOp> { |
156 | using OpConversionPattern::OpConversionPattern; |
157 | |
158 | LogicalResult |
159 | matchAndRewrite(vector::ExtractOp , OpAdaptor adaptor, |
160 | ConversionPatternRewriter &rewriter) const override { |
161 | if (extractOp.hasDynamicPosition()) |
162 | return failure(); |
163 | |
164 | Type dstType = getTypeConverter()->convertType(extractOp.getType()); |
165 | if (!dstType) |
166 | return failure(); |
167 | |
168 | if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { |
169 | rewriter.replaceOp(extractOp, adaptor.getVector()); |
170 | return success(); |
171 | } |
172 | |
173 | int32_t id = getFirstIntValue(extractOp.getMixedPosition()); |
174 | rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( |
175 | extractOp, adaptor.getVector(), id); |
176 | return success(); |
177 | } |
178 | }; |
179 | |
180 | struct final |
181 | : public OpConversionPattern<vector::ExtractStridedSliceOp> { |
182 | using OpConversionPattern::OpConversionPattern; |
183 | |
184 | LogicalResult |
185 | matchAndRewrite(vector::ExtractStridedSliceOp , OpAdaptor adaptor, |
186 | ConversionPatternRewriter &rewriter) const override { |
187 | Type dstType = getTypeConverter()->convertType(extractOp.getType()); |
188 | if (!dstType) |
189 | return failure(); |
190 | |
191 | uint64_t offset = getFirstIntValue(extractOp.getOffsets()); |
192 | uint64_t size = getFirstIntValue(extractOp.getSizes()); |
193 | uint64_t stride = getFirstIntValue(extractOp.getStrides()); |
194 | if (stride != 1) |
195 | return failure(); |
196 | |
197 | Value srcVector = adaptor.getOperands().front(); |
198 | |
199 | // Extract vector<1xT> case. |
200 | if (isa<spirv::ScalarType>(Val: dstType)) { |
201 | rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, |
202 | srcVector, offset); |
203 | return success(); |
204 | } |
205 | |
206 | SmallVector<int32_t, 2> indices(size); |
207 | std::iota(first: indices.begin(), last: indices.end(), value: offset); |
208 | |
209 | rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( |
210 | extractOp, dstType, srcVector, srcVector, |
211 | rewriter.getI32ArrayAttr(indices)); |
212 | |
213 | return success(); |
214 | } |
215 | }; |
216 | |
217 | template <class SPIRVFMAOp> |
218 | struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { |
219 | using OpConversionPattern::OpConversionPattern; |
220 | |
221 | LogicalResult |
222 | matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, |
223 | ConversionPatternRewriter &rewriter) const override { |
224 | Type dstType = getTypeConverter()->convertType(fmaOp.getType()); |
225 | if (!dstType) |
226 | return failure(); |
227 | rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(), |
228 | adaptor.getRhs(), adaptor.getAcc()); |
229 | return success(); |
230 | } |
231 | }; |
232 | |
233 | struct VectorInsertOpConvert final |
234 | : public OpConversionPattern<vector::InsertOp> { |
235 | using OpConversionPattern::OpConversionPattern; |
236 | |
237 | LogicalResult |
238 | matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, |
239 | ConversionPatternRewriter &rewriter) const override { |
240 | if (isa<VectorType>(insertOp.getSourceType())) |
241 | return rewriter.notifyMatchFailure(insertOp, "unsupported vector source" ); |
242 | if (!getTypeConverter()->convertType(insertOp.getDestVectorType())) |
243 | return rewriter.notifyMatchFailure(insertOp, |
244 | "unsupported dest vector type" ); |
245 | |
246 | // Special case for inserting scalar values into size-1 vectors. |
247 | if (insertOp.getSourceType().isIntOrFloat() && |
248 | insertOp.getDestVectorType().getNumElements() == 1) { |
249 | rewriter.replaceOp(insertOp, adaptor.getSource()); |
250 | return success(); |
251 | } |
252 | |
253 | int32_t id = getFirstIntValue(insertOp.getMixedPosition()); |
254 | rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( |
255 | insertOp, adaptor.getSource(), adaptor.getDest(), id); |
256 | return success(); |
257 | } |
258 | }; |
259 | |
260 | struct final |
261 | : public OpConversionPattern<vector::ExtractElementOp> { |
262 | using OpConversionPattern::OpConversionPattern; |
263 | |
264 | LogicalResult |
265 | matchAndRewrite(vector::ExtractElementOp , OpAdaptor adaptor, |
266 | ConversionPatternRewriter &rewriter) const override { |
267 | Type resultType = getTypeConverter()->convertType(extractOp.getType()); |
268 | if (!resultType) |
269 | return failure(); |
270 | |
271 | if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { |
272 | rewriter.replaceOp(extractOp, adaptor.getVector()); |
273 | return success(); |
274 | } |
275 | |
276 | APInt cstPos; |
277 | if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) |
278 | rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( |
279 | extractOp, resultType, adaptor.getVector(), |
280 | rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); |
281 | else |
282 | rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( |
283 | extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); |
284 | return success(); |
285 | } |
286 | }; |
287 | |
288 | struct VectorInsertElementOpConvert final |
289 | : public OpConversionPattern<vector::InsertElementOp> { |
290 | using OpConversionPattern::OpConversionPattern; |
291 | |
292 | LogicalResult |
293 | matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, |
294 | ConversionPatternRewriter &rewriter) const override { |
295 | Type vectorType = getTypeConverter()->convertType(insertOp.getType()); |
296 | if (!vectorType) |
297 | return failure(); |
298 | |
299 | if (isa<spirv::ScalarType>(Val: vectorType)) { |
300 | rewriter.replaceOp(insertOp, adaptor.getSource()); |
301 | return success(); |
302 | } |
303 | |
304 | APInt cstPos; |
305 | if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) |
306 | rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( |
307 | insertOp, adaptor.getSource(), adaptor.getDest(), |
308 | cstPos.getSExtValue()); |
309 | else |
310 | rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( |
311 | insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), |
312 | adaptor.getPosition()); |
313 | return success(); |
314 | } |
315 | }; |
316 | |
317 | struct VectorInsertStridedSliceOpConvert final |
318 | : public OpConversionPattern<vector::InsertStridedSliceOp> { |
319 | using OpConversionPattern::OpConversionPattern; |
320 | |
321 | LogicalResult |
322 | matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, |
323 | ConversionPatternRewriter &rewriter) const override { |
324 | Value srcVector = adaptor.getOperands().front(); |
325 | Value dstVector = adaptor.getOperands().back(); |
326 | |
327 | uint64_t stride = getFirstIntValue(insertOp.getStrides()); |
328 | if (stride != 1) |
329 | return failure(); |
330 | uint64_t offset = getFirstIntValue(insertOp.getOffsets()); |
331 | |
332 | if (isa<spirv::ScalarType>(Val: srcVector.getType())) { |
333 | assert(!isa<spirv::ScalarType>(dstVector.getType())); |
334 | rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( |
335 | insertOp, dstVector.getType(), srcVector, dstVector, |
336 | rewriter.getI32ArrayAttr(offset)); |
337 | return success(); |
338 | } |
339 | |
340 | uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements(); |
341 | uint64_t insertSize = |
342 | cast<VectorType>(srcVector.getType()).getNumElements(); |
343 | |
344 | SmallVector<int32_t, 2> indices(totalSize); |
345 | std::iota(first: indices.begin(), last: indices.end(), value: 0); |
346 | std::iota(first: indices.begin() + offset, last: indices.begin() + offset + insertSize, |
347 | value: totalSize); |
348 | |
349 | rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( |
350 | insertOp, dstVector.getType(), dstVector, srcVector, |
351 | rewriter.getI32ArrayAttr(indices)); |
352 | |
353 | return success(); |
354 | } |
355 | }; |
356 | |
357 | static SmallVector<Value> ( |
358 | vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor, |
359 | VectorType srcVectorType, ConversionPatternRewriter &rewriter) { |
360 | int numElements = static_cast<int>(srcVectorType.getDimSize(0)); |
361 | SmallVector<Value> values; |
362 | values.reserve(N: numElements + (adaptor.getAcc() ? 1 : 0)); |
363 | Location loc = reduceOp.getLoc(); |
364 | |
365 | for (int i = 0; i < numElements; ++i) { |
366 | values.push_back(rewriter.create<spirv::CompositeExtractOp>( |
367 | loc, srcVectorType.getElementType(), adaptor.getVector(), |
368 | rewriter.getI32ArrayAttr({i}))); |
369 | } |
370 | if (Value acc = adaptor.getAcc()) |
371 | values.push_back(Elt: acc); |
372 | |
373 | return values; |
374 | } |
375 | |
376 | struct ReductionRewriteInfo { |
377 | Type resultType; |
378 | SmallVector<Value> ; |
379 | }; |
380 | |
381 | FailureOr<ReductionRewriteInfo> static getReductionInfo( |
382 | vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor, |
383 | ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) { |
384 | Type resultType = typeConverter.convertType(op.getType()); |
385 | if (!resultType) |
386 | return failure(); |
387 | |
388 | auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType()); |
389 | if (!srcVectorType || srcVectorType.getRank() != 1) |
390 | return rewriter.notifyMatchFailure(op, "not a 1-D vector source" ); |
391 | |
392 | SmallVector<Value> = |
393 | extractAllElements(op, adaptor, srcVectorType, rewriter); |
394 | |
395 | return ReductionRewriteInfo{.resultType: resultType, .extractedElements: std::move(extractedElements)}; |
396 | } |
397 | |
398 | template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp, |
399 | typename SPIRVSMinOp> |
400 | struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> { |
401 | using OpConversionPattern::OpConversionPattern; |
402 | |
403 | LogicalResult |
404 | matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, |
405 | ConversionPatternRewriter &rewriter) const override { |
406 | auto reductionInfo = |
407 | getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter()); |
408 | if (failed(reductionInfo)) |
409 | return failure(); |
410 | |
411 | auto [resultType, extractedElements] = *reductionInfo; |
412 | Location loc = reduceOp->getLoc(); |
413 | Value result = extractedElements.front(); |
414 | for (Value next : llvm::drop_begin(extractedElements)) { |
415 | switch (reduceOp.getKind()) { |
416 | |
417 | #define INT_AND_FLOAT_CASE(kind, iop, fop) \ |
418 | case vector::CombiningKind::kind: \ |
419 | if (llvm::isa<IntegerType>(resultType)) { \ |
420 | result = rewriter.create<spirv::iop>(loc, resultType, result, next); \ |
421 | } else { \ |
422 | assert(llvm::isa<FloatType>(resultType)); \ |
423 | result = rewriter.create<spirv::fop>(loc, resultType, result, next); \ |
424 | } \ |
425 | break |
426 | |
427 | #define INT_OR_FLOAT_CASE(kind, fop) \ |
428 | case vector::CombiningKind::kind: \ |
429 | result = rewriter.create<fop>(loc, resultType, result, next); \ |
430 | break |
431 | |
432 | INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); |
433 | INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); |
434 | INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp); |
435 | INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp); |
436 | INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp); |
437 | INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp); |
438 | |
439 | case vector::CombiningKind::AND: |
440 | case vector::CombiningKind::OR: |
441 | case vector::CombiningKind::XOR: |
442 | return rewriter.notifyMatchFailure(reduceOp, "unimplemented" ); |
443 | default: |
444 | return rewriter.notifyMatchFailure(reduceOp, "not handled here" ); |
445 | } |
446 | #undef INT_AND_FLOAT_CASE |
447 | #undef INT_OR_FLOAT_CASE |
448 | } |
449 | |
450 | rewriter.replaceOp(reduceOp, result); |
451 | return success(); |
452 | } |
453 | }; |
454 | |
455 | template <typename SPIRVFMaxOp, typename SPIRVFMinOp> |
456 | struct VectorReductionFloatMinMax final |
457 | : OpConversionPattern<vector::ReductionOp> { |
458 | using OpConversionPattern::OpConversionPattern; |
459 | |
460 | LogicalResult |
461 | matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, |
462 | ConversionPatternRewriter &rewriter) const override { |
463 | auto reductionInfo = |
464 | getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter()); |
465 | if (failed(reductionInfo)) |
466 | return failure(); |
467 | |
468 | auto [resultType, extractedElements] = *reductionInfo; |
469 | Location loc = reduceOp->getLoc(); |
470 | Value result = extractedElements.front(); |
471 | for (Value next : llvm::drop_begin(extractedElements)) { |
472 | switch (reduceOp.getKind()) { |
473 | |
474 | #define INT_OR_FLOAT_CASE(kind, fop) \ |
475 | case vector::CombiningKind::kind: \ |
476 | result = rewriter.create<fop>(loc, resultType, result, next); \ |
477 | break |
478 | |
479 | INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); |
480 | INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp); |
481 | INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp); |
482 | INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp); |
483 | |
484 | default: |
485 | return rewriter.notifyMatchFailure(reduceOp, "not handled here" ); |
486 | } |
487 | #undef INT_OR_FLOAT_CASE |
488 | } |
489 | |
490 | rewriter.replaceOp(reduceOp, result); |
491 | return success(); |
492 | } |
493 | }; |
494 | |
495 | class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> { |
496 | public: |
497 | using OpConversionPattern<vector::SplatOp>::OpConversionPattern; |
498 | |
499 | LogicalResult |
500 | matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, |
501 | ConversionPatternRewriter &rewriter) const override { |
502 | Type dstType = getTypeConverter()->convertType(op.getType()); |
503 | if (!dstType) |
504 | return failure(); |
505 | if (isa<spirv::ScalarType>(Val: dstType)) { |
506 | rewriter.replaceOp(op, adaptor.getInput()); |
507 | } else { |
508 | auto dstVecType = cast<VectorType>(dstType); |
509 | SmallVector<Value, 4> source(dstVecType.getNumElements(), |
510 | adaptor.getInput()); |
511 | rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType, |
512 | source); |
513 | } |
514 | return success(); |
515 | } |
516 | }; |
517 | |
518 | struct VectorShuffleOpConvert final |
519 | : public OpConversionPattern<vector::ShuffleOp> { |
520 | using OpConversionPattern::OpConversionPattern; |
521 | |
522 | LogicalResult |
523 | matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, |
524 | ConversionPatternRewriter &rewriter) const override { |
525 | auto oldResultType = shuffleOp.getResultVectorType(); |
526 | Type newResultType = getTypeConverter()->convertType(oldResultType); |
527 | if (!newResultType) |
528 | return rewriter.notifyMatchFailure(shuffleOp, |
529 | "unsupported result vector type" ); |
530 | |
531 | SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>( |
532 | shuffleOp.getMask(), [](Attribute attr) -> int32_t { |
533 | return cast<IntegerAttr>(attr).getValue().getZExtValue(); |
534 | }); |
535 | |
536 | auto oldV1Type = shuffleOp.getV1VectorType(); |
537 | auto oldV2Type = shuffleOp.getV2VectorType(); |
538 | |
539 | // When both operands are SPIR-V vectors, emit a SPIR-V shuffle. |
540 | if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) { |
541 | rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( |
542 | shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), |
543 | rewriter.getI32ArrayAttr(mask)); |
544 | return success(); |
545 | } |
546 | |
547 | // When at least one of the operands becomes a scalar after type conversion |
548 | // for SPIR-V, extract all the required elements and construct the result |
549 | // vector. |
550 | auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()]( |
551 | Value scalarOrVec, int32_t idx) -> Value { |
552 | if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType())) |
553 | return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec, |
554 | idx); |
555 | |
556 | assert(idx == 0 && "Invalid scalar element index" ); |
557 | return scalarOrVec; |
558 | }; |
559 | |
560 | int32_t numV1Elems = oldV1Type.getNumElements(); |
561 | SmallVector<Value> newOperands(mask.size()); |
562 | for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) { |
563 | Value vec = adaptor.getV1(); |
564 | int32_t elementIdx = shuffleIdx; |
565 | if (elementIdx >= numV1Elems) { |
566 | vec = adaptor.getV2(); |
567 | elementIdx -= numV1Elems; |
568 | } |
569 | |
570 | newOperand = getElementAtIdx(vec, elementIdx); |
571 | } |
572 | |
573 | rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( |
574 | shuffleOp, newResultType, newOperands); |
575 | |
576 | return success(); |
577 | } |
578 | }; |
579 | |
580 | struct VectorLoadOpConverter final |
581 | : public OpConversionPattern<vector::LoadOp> { |
582 | using OpConversionPattern::OpConversionPattern; |
583 | |
584 | LogicalResult |
585 | matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, |
586 | ConversionPatternRewriter &rewriter) const override { |
587 | auto memrefType = loadOp.getMemRefType(); |
588 | auto attr = |
589 | dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace()); |
590 | if (!attr) |
591 | return rewriter.notifyMatchFailure( |
592 | loadOp, "expected spirv.storage_class memory space" ); |
593 | |
594 | const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
595 | auto loc = loadOp.getLoc(); |
596 | Value accessChain = |
597 | spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getBase(), |
598 | indices: adaptor.getIndices(), loc: loc, builder&: rewriter); |
599 | if (!accessChain) |
600 | return rewriter.notifyMatchFailure( |
601 | loadOp, "failed to get memref element pointer" ); |
602 | |
603 | spirv::StorageClass storageClass = attr.getValue(); |
604 | auto vectorType = loadOp.getVectorType(); |
605 | auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); |
606 | Value castedAccessChain = |
607 | rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain); |
608 | rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType, |
609 | castedAccessChain); |
610 | |
611 | return success(); |
612 | } |
613 | }; |
614 | |
615 | struct VectorStoreOpConverter final |
616 | : public OpConversionPattern<vector::StoreOp> { |
617 | using OpConversionPattern::OpConversionPattern; |
618 | |
619 | LogicalResult |
620 | matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, |
621 | ConversionPatternRewriter &rewriter) const override { |
622 | auto memrefType = storeOp.getMemRefType(); |
623 | auto attr = |
624 | dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace()); |
625 | if (!attr) |
626 | return rewriter.notifyMatchFailure( |
627 | storeOp, "expected spirv.storage_class memory space" ); |
628 | |
629 | const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
630 | auto loc = storeOp.getLoc(); |
631 | Value accessChain = |
632 | spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getBase(), |
633 | indices: adaptor.getIndices(), loc: loc, builder&: rewriter); |
634 | if (!accessChain) |
635 | return rewriter.notifyMatchFailure( |
636 | storeOp, "failed to get memref element pointer" ); |
637 | |
638 | spirv::StorageClass storageClass = attr.getValue(); |
639 | auto vectorType = storeOp.getVectorType(); |
640 | auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); |
641 | Value castedAccessChain = |
642 | rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain); |
643 | rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain, |
644 | adaptor.getValueToStore()); |
645 | |
646 | return success(); |
647 | } |
648 | }; |
649 | |
650 | struct VectorReductionToIntDotProd final |
651 | : OpRewritePattern<vector::ReductionOp> { |
652 | using OpRewritePattern::OpRewritePattern; |
653 | |
654 | LogicalResult matchAndRewrite(vector::ReductionOp op, |
655 | PatternRewriter &rewriter) const override { |
656 | if (op.getKind() != vector::CombiningKind::ADD) |
657 | return rewriter.notifyMatchFailure(op, "combining kind is not 'add'" ); |
658 | |
659 | auto resultType = dyn_cast<IntegerType>(op.getType()); |
660 | if (!resultType) |
661 | return rewriter.notifyMatchFailure(op, "result is not an integer" ); |
662 | |
663 | int64_t resultBitwidth = resultType.getIntOrFloatBitWidth(); |
664 | if (!llvm::is_contained(Set: {32, 64}, Element: resultBitwidth)) |
665 | return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth" ); |
666 | |
667 | VectorType inVecTy = op.getSourceVectorType(); |
668 | if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) || |
669 | inVecTy.getShape().size() != 1 || inVecTy.isScalable()) |
670 | return rewriter.notifyMatchFailure(op, "unsupported vector shape" ); |
671 | |
672 | auto mul = op.getVector().getDefiningOp<arith::MulIOp>(); |
673 | if (!mul) |
674 | return rewriter.notifyMatchFailure( |
675 | op, "reduction operand is not 'arith.muli'" ); |
676 | |
677 | if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp, |
678 | spirv::SDotAccSatOp, false>(op, mul, rewriter))) |
679 | return success(); |
680 | |
681 | if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp, |
682 | spirv::UDotAccSatOp, false>(op, mul, rewriter))) |
683 | return success(); |
684 | |
685 | if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp, |
686 | spirv::SUDotAccSatOp, false>(op, mul, rewriter))) |
687 | return success(); |
688 | |
689 | if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp, |
690 | spirv::SUDotAccSatOp, true>(op, mul, rewriter))) |
691 | return success(); |
692 | |
693 | return failure(); |
694 | } |
695 | |
696 | private: |
697 | template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp, |
698 | typename DotAccOp, bool SwapOperands> |
699 | static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul, |
700 | PatternRewriter &rewriter) { |
701 | auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>(); |
702 | if (!lhs) |
703 | return failure(); |
704 | Value lhsIn = lhs.getIn(); |
705 | auto lhsInType = cast<VectorType>(lhsIn.getType()); |
706 | if (!lhsInType.getElementType().isInteger(8)) |
707 | return failure(); |
708 | |
709 | auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>(); |
710 | if (!rhs) |
711 | return failure(); |
712 | Value rhsIn = rhs.getIn(); |
713 | auto rhsInType = cast<VectorType>(rhsIn.getType()); |
714 | if (!rhsInType.getElementType().isInteger(8)) |
715 | return failure(); |
716 | |
717 | if (op.getSourceVectorType().getNumElements() == 3) { |
718 | IntegerType i8Type = rewriter.getI8Type(); |
719 | auto v4i8Type = VectorType::get({4}, i8Type); |
720 | Location loc = op.getLoc(); |
721 | Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter); |
722 | lhsIn = rewriter.create<spirv::CompositeConstructOp>( |
723 | loc, v4i8Type, ValueRange{lhsIn, zero}); |
724 | rhsIn = rewriter.create<spirv::CompositeConstructOp>( |
725 | loc, v4i8Type, ValueRange{rhsIn, zero}); |
726 | } |
727 | |
728 | // There's no variant of dot prod ops for unsigned LHS and signed RHS, so |
729 | // we have to swap operands instead in that case. |
730 | if (SwapOperands) |
731 | std::swap(a&: lhsIn, b&: rhsIn); |
732 | |
733 | if (Value acc = op.getAcc()) { |
734 | rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc, |
735 | nullptr); |
736 | } else { |
737 | rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn, |
738 | nullptr); |
739 | } |
740 | |
741 | return success(); |
742 | } |
743 | }; |
744 | |
745 | struct VectorReductionToFPDotProd final |
746 | : OpConversionPattern<vector::ReductionOp> { |
747 | using OpConversionPattern::OpConversionPattern; |
748 | |
749 | LogicalResult |
750 | matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor, |
751 | ConversionPatternRewriter &rewriter) const override { |
752 | if (op.getKind() != vector::CombiningKind::ADD) |
753 | return rewriter.notifyMatchFailure(op, "combining kind is not 'add'" ); |
754 | |
755 | auto resultType = getTypeConverter()->convertType<FloatType>(op.getType()); |
756 | if (!resultType) |
757 | return rewriter.notifyMatchFailure(op, "result is not a float" ); |
758 | |
759 | Value vec = adaptor.getVector(); |
760 | Value acc = adaptor.getAcc(); |
761 | |
762 | auto vectorType = dyn_cast<VectorType>(vec.getType()); |
763 | if (!vectorType) { |
764 | assert(isa<FloatType>(vec.getType()) && |
765 | "Expected the vector to be scalarized" ); |
766 | if (acc) { |
767 | rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec); |
768 | return success(); |
769 | } |
770 | |
771 | rewriter.replaceOp(op, vec); |
772 | return success(); |
773 | } |
774 | |
775 | Location loc = op.getLoc(); |
776 | Value lhs; |
777 | Value rhs; |
778 | if (auto mul = vec.getDefiningOp<arith::MulFOp>()) { |
779 | lhs = mul.getLhs(); |
780 | rhs = mul.getRhs(); |
781 | } else { |
782 | // If the operand is not a mul, use a vector of ones for the dot operand |
783 | // to just sum up all values. |
784 | lhs = vec; |
785 | Attribute oneAttr = |
786 | rewriter.getFloatAttr(vectorType.getElementType(), 1.0); |
787 | oneAttr = SplatElementsAttr::get(vectorType, oneAttr); |
788 | rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr); |
789 | } |
790 | assert(lhs); |
791 | assert(rhs); |
792 | |
793 | Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs); |
794 | if (acc) |
795 | res = rewriter.create<spirv::FAddOp>(loc, acc, res); |
796 | |
797 | rewriter.replaceOp(op, res); |
798 | return success(); |
799 | } |
800 | }; |
801 | |
802 | } // namespace |
803 | #define CL_INT_MAX_MIN_OPS \ |
804 | spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp |
805 | |
806 | #define GL_INT_MAX_MIN_OPS \ |
807 | spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp |
808 | |
809 | #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp |
810 | #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp |
811 | |
812 | void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, |
813 | RewritePatternSet &patterns) { |
814 | patterns.add< |
815 | VectorBitcastConvert, VectorBroadcastConvert, |
816 | VectorExtractElementOpConvert, VectorExtractOpConvert, |
817 | VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, |
818 | VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert, |
819 | VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, |
820 | VectorReductionPattern<CL_INT_MAX_MIN_OPS>, |
821 | VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, |
822 | VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, |
823 | VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, |
824 | VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>( |
825 | typeConverter, patterns.getContext(), PatternBenefit(1)); |
826 | |
827 | // Make sure that the more specialized dot product pattern has higher benefit |
828 | // than the generic one that extracts all elements. |
829 | patterns.add<VectorReductionToFPDotProd>(arg&: typeConverter, args: patterns.getContext(), |
830 | args: PatternBenefit(2)); |
831 | } |
832 | |
833 | void mlir::populateVectorReductionToSPIRVDotProductPatterns( |
834 | RewritePatternSet &patterns) { |
835 | patterns.add<VectorReductionToIntDotProd>(arg: patterns.getContext()); |
836 | } |
837 | |