1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
10
11#include "../SPIRVCommon/Pattern.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/DialectResourceBlobManager.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
26#include <cassert>
27#include <memory>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34#define DEBUG_TYPE "arith-to-spirv-pattern"
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Conversion Helpers
40//===----------------------------------------------------------------------===//
41
42/// Converts the given `srcAttr` into a boolean attribute if it holds an
43/// integral value. Returns null attribute if conversion fails.
44static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 return boolAttr;
47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.getBoolAttr(value: intAttr.getValue().getBoolValue());
49 return {};
50}
51
52/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53/// Returns null attribute if conversion fails.
54static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55 Builder builder) {
56 // If the source number uses less active bits than the target bitwidth, then
57 // it should be safe to convert.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
59 return builder.getIntegerAttr(dstType, srcAttr.getInt());
60
61 // XXX: Try again by interpreting the source number as a signed value.
62 // Although integers in the standard dialect are signless, they can represent
63 // a signed number. It's the operation decides how to interpret. This is
64 // dangerous, but it seems there is no good way of handling this if we still
65 // want to change the bitwidth. Emit a message at least.
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69 << dstAttr << "' for type '" << dstType << "'\n");
70 return dstAttr;
71 }
72
73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74 << "' illegal: cannot fit into target type '"
75 << dstType << "'\n");
76 return {};
77}
78
79/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82 Builder builder) {
83 // Only support converting to float for now.
84 if (!dstType.isF32())
85 return FloatAttr();
86
87 // Try to convert the source floating-point number to single precision.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo = false;
90 APFloat::opStatus status =
91 dstVal.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmTowardZero, losesInfo: &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr << " illegal: cannot fit into converted type '"
95 << dstType << "'\n");
96 return FloatAttr();
97 }
98
99 return builder.getF32FloatAttr(dstVal.convertToFloat());
100}
101
102/// Returns true if the given `type` is a boolean scalar or vector type.
103static bool isBoolScalarOrVector(Type type) {
104 assert(type && "Not a valid type");
105 if (type.isInteger(width: 1))
106 return true;
107
108 if (auto vecType = dyn_cast<VectorType>(type))
109 return vecType.getElementType().isInteger(1);
110
111 return false;
112}
113
114/// Creates a scalar/vector integer constant.
115static Value getScalarOrVectorConstInt(Type type, uint64_t value,
116 OpBuilder &builder, Location loc) {
117 if (auto vectorType = dyn_cast<VectorType>(type)) {
118 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
119 auto attr = SplatElementsAttr::get(vectorType, element);
120 return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
121 }
122
123 if (auto intType = dyn_cast<IntegerType>(type))
124 return builder.create<spirv::ConstantOp>(
125 loc, type, builder.getIntegerAttr(type, value));
126
127 return nullptr;
128}
129
130/// Returns true if scalar/vector type `a` and `b` have the same number of
131/// bitwidth.
132static bool hasSameBitwidth(Type a, Type b) {
133 auto getNumBitwidth = [](Type type) {
134 unsigned bw = 0;
135 if (type.isIntOrFloat())
136 bw = type.getIntOrFloatBitWidth();
137 else if (auto vecType = dyn_cast<VectorType>(type))
138 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
139 return bw;
140 };
141 unsigned aBW = getNumBitwidth(a);
142 unsigned bBW = getNumBitwidth(b);
143 return aBW != 0 && bBW != 0 && aBW == bBW;
144}
145
146/// Returns a source type conversion failure for `srcType` and operation `op`.
147static LogicalResult
148getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
149 Type srcType) {
150 return rewriter.notifyMatchFailure(
151 arg: op->getLoc(),
152 msg: llvm::formatv(Fmt: "failed to convert source type '{0}'", Vals&: srcType));
153}
154
155/// Returns a source type conversion failure for the result type of `op`.
156static LogicalResult
157getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
158 assert(op->getNumResults() == 1);
159 return getTypeConversionFailure(rewriter, op, srcType: op->getResultTypes().front());
160}
161
162// TODO: Move to some common place?
163static std::string getDecorationString(spirv::Decoration decor) {
164 return llvm::convertToSnakeFromCamelCase(input: stringifyDecoration(decor));
165}
166
167namespace {
168
169/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
170/// operations. Op can potentially support overflow flags.
171template <typename Op, typename SPIRVOp>
172struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
173 using OpConversionPattern<Op>::OpConversionPattern;
174
175 LogicalResult
176 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
177 ConversionPatternRewriter &rewriter) const override {
178 assert(adaptor.getOperands().size() <= 3);
179 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
180 Type dstType = converter->convertType(op.getType());
181 if (!dstType) {
182 return rewriter.notifyMatchFailure(
183 op->getLoc(),
184 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
185 }
186
187 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
188 !getElementTypeOrSelf(op.getType()).isIndex() &&
189 dstType != op.getType()) {
190 return op.emitError("bitwidth emulation is not implemented yet on "
191 "unsigned op pattern version");
192 }
193
194 auto overflowFlags = arith::IntegerOverflowFlags::none;
195 if (auto overflowIface =
196 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197 if (converter->getTargetEnv().allows(
198 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199 overflowFlags = overflowIface.getOverflowAttr().getValue();
200 }
201
202 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203 op, dstType, adaptor.getOperands());
204
205 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
206 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
207 rewriter.getUnitAttr());
208
209 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
210 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
211 rewriter.getUnitAttr());
212
213 return success();
214 }
215};
216
217//===----------------------------------------------------------------------===//
218// ConstantOp
219//===----------------------------------------------------------------------===//
220
221/// Converts composite arith.constant operation to spirv.Constant.
222struct ConstantCompositeOpPattern final
223 : public OpConversionPattern<arith::ConstantOp> {
224 using OpConversionPattern::OpConversionPattern;
225
226 LogicalResult
227 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter) const override {
229 auto srcType = dyn_cast<ShapedType>(constOp.getType());
230 if (!srcType || srcType.getNumElements() == 1)
231 return failure();
232
233 // arith.constant should only have vector or tensor types. This is a MLIR
234 // wide problem at the moment.
235 if (!isa<VectorType, RankedTensorType>(srcType))
236 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
237
238 Type dstType = getTypeConverter()->convertType(srcType);
239 if (!dstType)
240 return failure();
241
242 // Import the resource into the IR to make use of the special handling of
243 // element types later on.
244 mlir::DenseElementsAttr dstElementsAttr;
245 if (auto denseElementsAttr =
246 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247 dstElementsAttr = denseElementsAttr;
248 } else if (auto resourceAttr =
249 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
250
251 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
252 if (!blob)
253 return constOp->emitError("could not find resource blob");
254
255 ArrayRef<char> ptr = blob->getData();
256
257 // Check that the buffer meets the requirements to get converted to a
258 // DenseElementsAttr
259 bool detectedSplat = false;
260 if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
261 return constOp->emitError("resource is not a valid buffer");
262
263 dstElementsAttr =
264 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
265 } else {
266 return constOp->emitError("unsupported elements attribute");
267 }
268
269 ShapedType dstAttrType = dstElementsAttr.getType();
270
271 // If the composite type has more than one dimensions, perform
272 // linearization.
273 if (srcType.getRank() > 1) {
274 if (isa<RankedTensorType>(srcType)) {
275 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
276 srcType.getElementType());
277 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
278 } else {
279 // TODO: add support for large vectors.
280 return failure();
281 }
282 }
283
284 Type srcElemType = srcType.getElementType();
285 Type dstElemType;
286 // Tensor types are converted to SPIR-V array types; vector types are
287 // converted to SPIR-V vector/array types.
288 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289 dstElemType = arrayType.getElementType();
290 else
291 dstElemType = cast<VectorType>(dstType).getElementType();
292
293 // If the source and destination element types are different, perform
294 // attribute conversion.
295 if (srcElemType != dstElemType) {
296 SmallVector<Attribute, 8> elements;
297 if (isa<FloatType>(Val: srcElemType)) {
298 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
299 FloatAttr dstAttr =
300 convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
301 if (!dstAttr)
302 return failure();
303 elements.push_back(dstAttr);
304 }
305 } else if (srcElemType.isInteger(width: 1)) {
306 return failure();
307 } else {
308 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
309 IntegerAttr dstAttr = convertIntegerAttr(
310 srcAttr, cast<IntegerType>(dstElemType), rewriter);
311 if (!dstAttr)
312 return failure();
313 elements.push_back(dstAttr);
314 }
315 }
316
317 // Unfortunately, we cannot use dialect-specific types for element
318 // attributes; element attributes only works with builtin types. So we
319 // need to prepare another converted builtin types for the destination
320 // elements attribute.
321 if (isa<RankedTensorType>(dstAttrType))
322 dstAttrType =
323 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
324 else
325 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
326
327 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
328 }
329
330 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
331 dstElementsAttr);
332 return success();
333 }
334};
335
336/// Converts scalar arith.constant operation to spirv.Constant.
337struct ConstantScalarOpPattern final
338 : public OpConversionPattern<arith::ConstantOp> {
339 using OpConversionPattern::OpConversionPattern;
340
341 LogicalResult
342 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
343 ConversionPatternRewriter &rewriter) const override {
344 Type srcType = constOp.getType();
345 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
346 if (shapedType.getNumElements() != 1)
347 return failure();
348 srcType = shapedType.getElementType();
349 }
350 if (!srcType.isIntOrIndexOrFloat())
351 return failure();
352
353 Attribute cstAttr = constOp.getValue();
354 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355 cstAttr = elementsAttr.getSplatValue<Attribute>();
356
357 Type dstType = getTypeConverter()->convertType(srcType);
358 if (!dstType)
359 return failure();
360
361 // Floating-point types.
362 if (isa<FloatType>(Val: srcType)) {
363 auto srcAttr = cast<FloatAttr>(cstAttr);
364 auto dstAttr = srcAttr;
365
366 // Floating-point types not supported in the target environment are all
367 // converted to float type.
368 if (srcType != dstType) {
369 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
370 if (!dstAttr)
371 return failure();
372 }
373
374 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
375 return success();
376 }
377
378 // Bool type.
379 if (srcType.isInteger(width: 1)) {
380 // arith.constant can use 0/1 instead of true/false for i1 values. We need
381 // to handle that here.
382 auto dstAttr = convertBoolAttr(srcAttr: cstAttr, builder: rewriter);
383 if (!dstAttr)
384 return failure();
385 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
386 return success();
387 }
388
389 // IndexType or IntegerType. Index values are converted to 32-bit integer
390 // values when converting to SPIR-V.
391 auto srcAttr = cast<IntegerAttr>(cstAttr);
392 IntegerAttr dstAttr =
393 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
394 if (!dstAttr)
395 return failure();
396 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
397 return success();
398 }
399};
400
401//===----------------------------------------------------------------------===//
402// RemSIOp
403//===----------------------------------------------------------------------===//
404
405/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
406/// the sign of `signOperand`.
407///
408/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
409/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
410/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
411/// if either operand can be negative. Emulate it via spirv.UMod.
412template <typename SignedAbsOp>
413static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
414 Value signOperand, OpBuilder &builder) {
415 assert(lhs.getType() == rhs.getType());
416 assert(lhs == signOperand || rhs == signOperand);
417
418 Type type = lhs.getType();
419
420 // Calculate the remainder with spirv.UMod.
421 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
422 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
423 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
424
425 // Fix the sign.
426 Value isPositive;
427 if (lhs == signOperand)
428 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
429 else
430 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
432 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
433}
434
435/// Converts arith.remsi to GLSL SPIR-V ops.
436///
437/// This cannot be merged into the template unary/binary pattern due to Vulkan
438/// restrictions over spirv.SRem and spirv.SMod.
439struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
440 using OpConversionPattern::OpConversionPattern;
441
442 LogicalResult
443 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
444 ConversionPatternRewriter &rewriter) const override {
445 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447 adaptor.getOperands()[0], rewriter);
448 rewriter.replaceOp(op, result);
449
450 return success();
451 }
452};
453
454/// Converts arith.remsi to OpenCL SPIR-V ops.
455struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
456 using OpConversionPattern::OpConversionPattern;
457
458 LogicalResult
459 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
460 ConversionPatternRewriter &rewriter) const override {
461 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463 adaptor.getOperands()[0], rewriter);
464 rewriter.replaceOp(op, result);
465
466 return success();
467 }
468};
469
470//===----------------------------------------------------------------------===//
471// BitwiseOp
472//===----------------------------------------------------------------------===//
473
474/// Converts bitwise operations to SPIR-V operations. This is a special pattern
475/// other than the BinaryOpPatternPattern because if the operands are boolean
476/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
477/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
478template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
479struct BitwiseOpPattern final : public OpConversionPattern<Op> {
480 using OpConversionPattern<Op>::OpConversionPattern;
481
482 LogicalResult
483 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
484 ConversionPatternRewriter &rewriter) const override {
485 assert(adaptor.getOperands().size() == 2);
486 Type dstType = this->getTypeConverter()->convertType(op.getType());
487 if (!dstType)
488 return getTypeConversionFailure(rewriter, op);
489
490 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
491 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492 op, dstType, adaptor.getOperands());
493 } else {
494 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495 op, dstType, adaptor.getOperands());
496 }
497 return success();
498 }
499};
500
501//===----------------------------------------------------------------------===//
502// XOrIOp
503//===----------------------------------------------------------------------===//
504
505/// Converts arith.xori to SPIR-V operations.
506struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
507 using OpConversionPattern::OpConversionPattern;
508
509 LogicalResult
510 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
511 ConversionPatternRewriter &rewriter) const override {
512 assert(adaptor.getOperands().size() == 2);
513
514 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
515 return failure();
516
517 Type dstType = getTypeConverter()->convertType(op.getType());
518 if (!dstType)
519 return getTypeConversionFailure(rewriter, op);
520
521 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
522 adaptor.getOperands());
523
524 return success();
525 }
526};
527
528/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
529/// vector of i1.
530struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
531 using OpConversionPattern::OpConversionPattern;
532
533 LogicalResult
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535 ConversionPatternRewriter &rewriter) const override {
536 assert(adaptor.getOperands().size() == 2);
537
538 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539 return failure();
540
541 Type dstType = getTypeConverter()->convertType(op.getType());
542 if (!dstType)
543 return getTypeConversionFailure(rewriter, op);
544
545 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
546 op, dstType, adaptor.getOperands());
547 return success();
548 }
549};
550
551//===----------------------------------------------------------------------===//
552// UIToFPOp
553//===----------------------------------------------------------------------===//
554
555/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
556/// of i1.
557struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
558 using OpConversionPattern::OpConversionPattern;
559
560 LogicalResult
561 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
562 ConversionPatternRewriter &rewriter) const override {
563 Type srcType = adaptor.getOperands().front().getType();
564 if (!isBoolScalarOrVector(type: srcType))
565 return failure();
566
567 Type dstType = getTypeConverter()->convertType(op.getType());
568 if (!dstType)
569 return getTypeConversionFailure(rewriter, op);
570
571 Location loc = op.getLoc();
572 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
573 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
574 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
575 op, dstType, adaptor.getOperands().front(), one, zero);
576 return success();
577 }
578};
579
580//===----------------------------------------------------------------------===//
581// ExtSIOp
582//===----------------------------------------------------------------------===//
583
584/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
585/// of i1.
586struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
587 using OpConversionPattern::OpConversionPattern;
588
589 LogicalResult
590 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
591 ConversionPatternRewriter &rewriter) const override {
592 Value operand = adaptor.getIn();
593 if (!isBoolScalarOrVector(type: operand.getType()))
594 return failure();
595
596 Location loc = op.getLoc();
597 Type dstType = getTypeConverter()->convertType(op.getType());
598 if (!dstType)
599 return getTypeConversionFailure(rewriter, op);
600
601 Value allOnes;
602 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
603 unsigned componentBitwidth = intTy.getWidth();
604 allOnes = rewriter.create<spirv::ConstantOp>(
605 loc, intTy,
606 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
608 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609 allOnes = rewriter.create<spirv::ConstantOp>(
610 loc, vectorTy,
611 SplatElementsAttr::get(vectorTy,
612 APInt::getAllOnes(componentBitwidth)));
613 } else {
614 return rewriter.notifyMatchFailure(
615 arg&: loc, msg: llvm::formatv(Fmt: "unhandled type: {0}", Vals&: dstType));
616 }
617
618 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
619 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
620 zero);
621 return success();
622 }
623};
624
625/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
626/// vector of i1.
627struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
628 using OpConversionPattern::OpConversionPattern;
629
630 LogicalResult
631 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
632 ConversionPatternRewriter &rewriter) const override {
633 Type srcType = adaptor.getIn().getType();
634 if (isBoolScalarOrVector(type: srcType))
635 return failure();
636
637 Type dstType = getTypeConverter()->convertType(op.getType());
638 if (!dstType)
639 return getTypeConversionFailure(rewriter, op);
640
641 if (dstType == srcType) {
642 // We can have the same source and destination type due to type emulation.
643 // Perform bit shifting to make sure we have the proper leading set bits.
644
645 unsigned srcBW =
646 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
647 unsigned dstBW =
648 getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
649 assert(srcBW < dstBW);
650 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
651 rewriter, op.getLoc());
652
653 // First shift left to sequeeze out all leading bits beyond the original
654 // bitwidth. Here we need to use the original source and result type's
655 // bitwidth.
656 auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
657 op.getLoc(), dstType, adaptor.getIn(), shiftSize);
658
659 // Then we perform arithmetic right shift to make sure we have the right
660 // sign bits for negative values.
661 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
662 op, dstType, shiftLOp, shiftSize);
663 } else {
664 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
665 adaptor.getOperands());
666 }
667
668 return success();
669 }
670};
671
672//===----------------------------------------------------------------------===//
673// ExtUIOp
674//===----------------------------------------------------------------------===//
675
676/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
677/// of i1.
678struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
679 using OpConversionPattern::OpConversionPattern;
680
681 LogicalResult
682 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
683 ConversionPatternRewriter &rewriter) const override {
684 Type srcType = adaptor.getOperands().front().getType();
685 if (!isBoolScalarOrVector(type: srcType))
686 return failure();
687
688 Type dstType = getTypeConverter()->convertType(op.getType());
689 if (!dstType)
690 return getTypeConversionFailure(rewriter, op);
691
692 Location loc = op.getLoc();
693 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696 op, dstType, adaptor.getOperands().front(), one, zero);
697 return success();
698 }
699};
700
701/// Converts arith.extui for cases where the type of source is neither i1 nor
702/// vector of i1.
703struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
704 using OpConversionPattern::OpConversionPattern;
705
706 LogicalResult
707 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
708 ConversionPatternRewriter &rewriter) const override {
709 Type srcType = adaptor.getIn().getType();
710 if (isBoolScalarOrVector(type: srcType))
711 return failure();
712
713 Type dstType = getTypeConverter()->convertType(op.getType());
714 if (!dstType)
715 return getTypeConversionFailure(rewriter, op);
716
717 if (dstType == srcType) {
718 // We can have the same source and destination type due to type emulation.
719 // Perform bit masking to make sure we don't pollute downstream consumers
720 // with unwanted bits. Here we need to use the original source type's
721 // bitwidth.
722 unsigned bitwidth =
723 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
724 Value mask = getScalarOrVectorConstInt(
725 dstType, llvm::maskTrailingOnes<uint64_t>(N: bitwidth), rewriter,
726 op.getLoc());
727 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
728 adaptor.getIn(), mask);
729 } else {
730 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
731 adaptor.getOperands());
732 }
733 return success();
734 }
735};
736
737//===----------------------------------------------------------------------===//
738// TruncIOp
739//===----------------------------------------------------------------------===//
740
741/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
742/// of i1.
743struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
744 using OpConversionPattern::OpConversionPattern;
745
746 LogicalResult
747 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
748 ConversionPatternRewriter &rewriter) const override {
749 Type dstType = getTypeConverter()->convertType(op.getType());
750 if (!dstType)
751 return getTypeConversionFailure(rewriter, op);
752
753 if (!isBoolScalarOrVector(type: dstType))
754 return failure();
755
756 Location loc = op.getLoc();
757 auto srcType = adaptor.getOperands().front().getType();
758 // Check if (x & 1) == 1.
759 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
761 loc, srcType, adaptor.getOperands()[0], mask);
762 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
763
764 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
765 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
766 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
767 return success();
768 }
769};
770
771/// Converts arith.trunci for cases where the type of result is neither i1
772/// nor vector of i1.
773struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
774 using OpConversionPattern::OpConversionPattern;
775
776 LogicalResult
777 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
778 ConversionPatternRewriter &rewriter) const override {
779 Type srcType = adaptor.getIn().getType();
780 Type dstType = getTypeConverter()->convertType(op.getType());
781 if (!dstType)
782 return getTypeConversionFailure(rewriter, op);
783
784 if (isBoolScalarOrVector(type: dstType))
785 return failure();
786
787 if (dstType == srcType) {
788 // We can have the same source and destination type due to type emulation.
789 // Perform bit masking to make sure we don't pollute downstream consumers
790 // with unwanted bits. Here we need to use the original result type's
791 // bitwidth.
792 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
793 Value mask = getScalarOrVectorConstInt(
794 dstType, llvm::maskTrailingOnes<uint64_t>(N: bw), rewriter, op.getLoc());
795 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
796 adaptor.getIn(), mask);
797 } else {
798 // Given this is truncation, either SConvertOp or UConvertOp works.
799 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
800 adaptor.getOperands());
801 }
802 return success();
803 }
804};
805
806//===----------------------------------------------------------------------===//
807// TypeCastingOp
808//===----------------------------------------------------------------------===//
809
810static std::optional<spirv::FPRoundingMode>
811convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812 switch (roundingMode) {
813 case arith::RoundingMode::downward:
814 return spirv::FPRoundingMode::RTN;
815 case arith::RoundingMode::to_nearest_even:
816 return spirv::FPRoundingMode::RTE;
817 case arith::RoundingMode::toward_zero:
818 return spirv::FPRoundingMode::RTZ;
819 case arith::RoundingMode::upward:
820 return spirv::FPRoundingMode::RTP;
821 case arith::RoundingMode::to_nearest_away:
822 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
823 // (as of SPIR-V 1.6)
824 return std::nullopt;
825 }
826 llvm_unreachable("Unhandled rounding mode");
827}
828
829/// Converts type-casting standard operations to SPIR-V operations.
830template <typename Op, typename SPIRVOp>
831struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
832 using OpConversionPattern<Op>::OpConversionPattern;
833
834 LogicalResult
835 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
836 ConversionPatternRewriter &rewriter) const override {
837 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
838 Type dstType = this->getTypeConverter()->convertType(op.getType());
839 if (!dstType)
840 return getTypeConversionFailure(rewriter, op);
841
842 if (isBoolScalarOrVector(type: srcType) || isBoolScalarOrVector(type: dstType))
843 return failure();
844
845 if (dstType == srcType) {
846 // Due to type conversion, we are seeing the same source and target type.
847 // Then we can just erase this operation by forwarding its operand.
848 rewriter.replaceOp(op, adaptor.getOperands().front());
849 } else {
850 // Compute new rounding mode (if any).
851 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
852 if (auto roundingModeOp =
853 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
854 if (arith::RoundingModeAttr roundingMode =
855 roundingModeOp.getRoundingModeAttr()) {
856 if (!(rm =
857 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
858 return rewriter.notifyMatchFailure(
859 op->getLoc(),
860 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
861 }
862 }
863 }
864 // Create replacement op and attach rounding mode attribute (if any).
865 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
866 op, dstType, adaptor.getOperands());
867 if (rm) {
868 newOp->setAttr(
869 getDecorationString(spirv::Decoration::FPRoundingMode),
870 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
871 }
872 }
873 return success();
874 }
875};
876
877//===----------------------------------------------------------------------===//
878// CmpIOp
879//===----------------------------------------------------------------------===//
880
881/// Converts integer compare operation on i1 type operands to SPIR-V ops.
882class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
883public:
884 using OpConversionPattern::OpConversionPattern;
885
886 LogicalResult
887 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
888 ConversionPatternRewriter &rewriter) const override {
889 Type srcType = op.getLhs().getType();
890 if (!isBoolScalarOrVector(type: srcType))
891 return failure();
892 Type dstType = getTypeConverter()->convertType(srcType);
893 if (!dstType)
894 return getTypeConversionFailure(rewriter, op, srcType);
895
896 switch (op.getPredicate()) {
897 case arith::CmpIPredicate::eq: {
898 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
899 adaptor.getRhs());
900 return success();
901 }
902 case arith::CmpIPredicate::ne: {
903 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
904 op, adaptor.getLhs(), adaptor.getRhs());
905 return success();
906 }
907 case arith::CmpIPredicate::uge:
908 case arith::CmpIPredicate::ugt:
909 case arith::CmpIPredicate::ule:
910 case arith::CmpIPredicate::ult: {
911 // There are no direct corresponding instructions in SPIR-V for such
912 // cases. Extend them to 32-bit and do comparision then.
913 Type type = rewriter.getI32Type();
914 if (auto vectorType = dyn_cast<VectorType>(dstType))
915 type = VectorType::get(vectorType.getShape(), type);
916 Value extLhs =
917 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
918 Value extRhs =
919 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
920
921 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
922 extRhs);
923 return success();
924 }
925 default:
926 break;
927 }
928 return failure();
929 }
930};
931
932/// Converts integer compare operation to SPIR-V ops.
933class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
934public:
935 using OpConversionPattern::OpConversionPattern;
936
937 LogicalResult
938 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
939 ConversionPatternRewriter &rewriter) const override {
940 Type srcType = op.getLhs().getType();
941 if (isBoolScalarOrVector(type: srcType))
942 return failure();
943 Type dstType = getTypeConverter()->convertType(srcType);
944 if (!dstType)
945 return getTypeConversionFailure(rewriter, op, srcType);
946
947 switch (op.getPredicate()) {
948#define DISPATCH(cmpPredicate, spirvOp) \
949 case cmpPredicate: \
950 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
951 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
952 !hasSameBitwidth(srcType, dstType)) { \
953 return op.emitError( \
954 "bitwidth emulation is not implemented yet on unsigned op"); \
955 } \
956 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
957 adaptor.getRhs()); \
958 return success();
959
960 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
961 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
962 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
963 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
964 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
965 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
966 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
967 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
968 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
969 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
970
971#undef DISPATCH
972 }
973 return failure();
974 }
975};
976
977//===----------------------------------------------------------------------===//
978// CmpFOpPattern
979//===----------------------------------------------------------------------===//
980
981/// Converts floating-point comparison operations to SPIR-V ops.
982class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
983public:
984 using OpConversionPattern::OpConversionPattern;
985
986 LogicalResult
987 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
988 ConversionPatternRewriter &rewriter) const override {
989 switch (op.getPredicate()) {
990#define DISPATCH(cmpPredicate, spirvOp) \
991 case cmpPredicate: \
992 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
993 adaptor.getRhs()); \
994 return success();
995
996 // Ordered.
997 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
998 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
999 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1000 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1001 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1002 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1003 // Unordered.
1004 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1005 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1006 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1007 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1008 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1009 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1010
1011#undef DISPATCH
1012
1013 default:
1014 break;
1015 }
1016 return failure();
1017 }
1018};
1019
1020/// Converts floating point NaN check to SPIR-V ops. This pattern requires
1021/// Kernel capability.
1022class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1023public:
1024 using OpConversionPattern::OpConversionPattern;
1025
1026 LogicalResult
1027 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1028 ConversionPatternRewriter &rewriter) const override {
1029 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1030 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1031 adaptor.getRhs());
1032 return success();
1033 }
1034
1035 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1036 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1037 adaptor.getRhs());
1038 return success();
1039 }
1040
1041 return failure();
1042 }
1043};
1044
1045/// Converts floating point NaN check to SPIR-V ops. This pattern does not
1046/// require additional capability.
1047class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1048public:
1049 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
1050
1051 LogicalResult
1052 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1053 ConversionPatternRewriter &rewriter) const override {
1054 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1055 op.getPredicate() != arith::CmpFPredicate::UNO)
1056 return failure();
1057
1058 Location loc = op.getLoc();
1059
1060 Value replace;
1061 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1062 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1063 // Ordered comparsion checks if neither operand is NaN.
1064 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1065 } else {
1066 // Unordered comparsion checks if either operand is NaN.
1067 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1068 }
1069 } else {
1070 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1071 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1072
1073 replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1074 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1075 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
1076 }
1077
1078 rewriter.replaceOp(op, replace);
1079 return success();
1080 }
1081};
1082
1083//===----------------------------------------------------------------------===//
1084// AddUIExtendedOp
1085//===----------------------------------------------------------------------===//
1086
1087/// Converts arith.addui_extended to spirv.IAddCarry.
1088class AddUIExtendedOpPattern final
1089 : public OpConversionPattern<arith::AddUIExtendedOp> {
1090public:
1091 using OpConversionPattern::OpConversionPattern;
1092 LogicalResult
1093 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1094 ConversionPatternRewriter &rewriter) const override {
1095 Type dstElemTy = adaptor.getLhs().getType();
1096 Location loc = op->getLoc();
1097 Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1098 adaptor.getRhs());
1099
1100 Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
1101 loc, result, llvm::ArrayRef(0));
1102 Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
1103 loc, result, llvm::ArrayRef(1));
1104
1105 // Convert the carry value to boolean.
1106 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1107 Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
1108
1109 rewriter.replaceOp(op, {sumResult, carryResult});
1110 return success();
1111 }
1112};
1113
1114//===----------------------------------------------------------------------===//
1115// MulIExtendedOp
1116//===----------------------------------------------------------------------===//
1117
1118/// Converts arith.mul*i_extended to spirv.*MulExtended.
1119template <typename ArithMulOp, typename SPIRVMulOp>
1120class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1121public:
1122 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1123 LogicalResult
1124 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1125 ConversionPatternRewriter &rewriter) const override {
1126 Location loc = op->getLoc();
1127 Value result =
1128 rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1129
1130 Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1131 llvm::ArrayRef(0));
1132 Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1133 llvm::ArrayRef(1));
1134
1135 rewriter.replaceOp(op, {low, high});
1136 return success();
1137 }
1138};
1139
1140//===----------------------------------------------------------------------===//
1141// SelectOp
1142//===----------------------------------------------------------------------===//
1143
1144/// Converts arith.select to spirv.Select.
1145class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1146public:
1147 using OpConversionPattern::OpConversionPattern;
1148 LogicalResult
1149 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1150 ConversionPatternRewriter &rewriter) const override {
1151 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1152 adaptor.getTrueValue(),
1153 adaptor.getFalseValue());
1154 return success();
1155 }
1156};
1157
1158//===----------------------------------------------------------------------===//
1159// MinimumFOp, MaximumFOp
1160//===----------------------------------------------------------------------===//
1161
1162/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1163/// spirv.CL.fmax/fmin.
1164template <typename Op, typename SPIRVOp>
1165class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1166public:
1167 using OpConversionPattern<Op>::OpConversionPattern;
1168 LogicalResult
1169 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1170 ConversionPatternRewriter &rewriter) const override {
1171 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1172 Type dstType = converter->convertType(op.getType());
1173 if (!dstType)
1174 return getTypeConversionFailure(rewriter, op);
1175
1176 // arith.maximumf/minimumf:
1177 // "if one of the arguments is NaN, then the result is also NaN."
1178 // spirv.GL.FMax/FMin
1179 // "which operand is the result is undefined if one of the operands
1180 // is a NaN."
1181 // spirv.CL.fmax/fmin:
1182 // "If one argument is a NaN, Fmin returns the other argument."
1183
1184 Location loc = op.getLoc();
1185 Value spirvOp =
1186 rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1187
1188 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1189 rewriter.replaceOp(op, spirvOp);
1190 return success();
1191 }
1192
1193 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1194 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1195
1196 Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1197 adaptor.getLhs(), spirvOp);
1198 Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1199 adaptor.getRhs(), select1);
1200
1201 rewriter.replaceOp(op, select2);
1202 return success();
1203 }
1204};
1205
1206//===----------------------------------------------------------------------===//
1207// MinNumFOp, MaxNumFOp
1208//===----------------------------------------------------------------------===//
1209
1210/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1211/// spirv.CL.fmax/fmin.
1212template <typename Op, typename SPIRVOp>
1213class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1214 template <typename TargetOp>
1215 constexpr bool shouldInsertNanGuards() const {
1216 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1217 }
1218
1219public:
1220 using OpConversionPattern<Op>::OpConversionPattern;
1221 LogicalResult
1222 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1223 ConversionPatternRewriter &rewriter) const override {
1224 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1225 Type dstType = converter->convertType(op.getType());
1226 if (!dstType)
1227 return getTypeConversionFailure(rewriter, op);
1228
1229 // arith.maxnumf/minnumf:
1230 // "If one of the arguments is NaN, then the result is the other
1231 // argument."
1232 // spirv.GL.FMax/FMin
1233 // "which operand is the result is undefined if one of the operands
1234 // is a NaN."
1235 // spirv.CL.fmax/fmin:
1236 // "If one argument is a NaN, Fmin returns the other argument."
1237
1238 Location loc = op.getLoc();
1239 Value spirvOp =
1240 rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1241
1242 if (!shouldInsertNanGuards<SPIRVOp>() ||
1243 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1244 rewriter.replaceOp(op, spirvOp);
1245 return success();
1246 }
1247
1248 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1249 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1250
1251 Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1252 adaptor.getRhs(), spirvOp);
1253 Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1254 adaptor.getLhs(), select1);
1255
1256 rewriter.replaceOp(op, select2);
1257 return success();
1258 }
1259};
1260
1261} // namespace
1262
1263//===----------------------------------------------------------------------===//
1264// Pattern Population
1265//===----------------------------------------------------------------------===//
1266
1267void mlir::arith::populateArithToSPIRVPatterns(
1268 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1269 // clang-format off
1270 patterns.add<
1271 ConstantCompositeOpPattern,
1272 ConstantScalarOpPattern,
1273 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1274 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1275 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1276 spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
1277 spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
1278 spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
1279 RemSIOpGLPattern, RemSIOpCLPattern,
1280 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1281 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1282 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1283 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1284 spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
1285 spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
1286 spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
1287 spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
1288 spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
1289 spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
1290 spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
1291 spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
1292 ExtUIPattern, ExtUII1Pattern,
1293 ExtSIPattern, ExtSII1Pattern,
1294 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1295 TruncIPattern, TruncII1Pattern,
1296 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1297 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1298 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1299 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1300 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1301 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1302 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1303 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1304 CmpIOpBooleanPattern, CmpIOpPattern,
1305 CmpFOpNanNonePattern, CmpFOpPattern,
1306 AddUIExtendedOpPattern,
1307 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1308 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1309 SelectOpPattern,
1310
1311 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1312 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1313 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1314 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1315 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
1316 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
1317 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
1318 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,
1319
1320 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1321 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1322 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1323 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1324 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
1325 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
1326 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
1327 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp>
1328 >(typeConverter, patterns.getContext());
1329 // clang-format on
1330
1331 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1332 // capability is available.
1333 patterns.add<CmpFOpNanKernelPattern>(arg: typeConverter, args: patterns.getContext(),
1334 /*benefit=*/args: 2);
1335}
1336
1337//===----------------------------------------------------------------------===//
1338// Pass Definition
1339//===----------------------------------------------------------------------===//
1340
1341namespace {
1342struct ConvertArithToSPIRVPass
1343 : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1344 using Base::Base;
1345
1346 void runOnOperation() override {
1347 Operation *op = getOperation();
1348 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
1349 std::unique_ptr<SPIRVConversionTarget> target =
1350 SPIRVConversionTarget::get(targetAttr);
1351
1352 SPIRVConversionOptions options;
1353 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1354 SPIRVTypeConverter typeConverter(targetAttr, options);
1355
1356 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1357 // in patterns for other dialects.
1358 target->addLegalOp<UnrealizedConversionCastOp>();
1359
1360 // Fail hard when there are any remaining 'arith' ops.
1361 target->addIllegalDialect<arith::ArithDialect>();
1362
1363 RewritePatternSet patterns(&getContext());
1364 arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1365
1366 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1367 signalPassFailure();
1368 }
1369};
1370} // namespace
1371

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp