1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
10
11#include "../SPIRVCommon/Pattern.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "llvm/ADT/APInt.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/MathExtras.h"
25#include <cassert>
26#include <memory>
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTARITHTOSPIRV
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33#define DEBUG_TYPE "arith-to-spirv-pattern"
34
35using namespace mlir;
36
37//===----------------------------------------------------------------------===//
38// Conversion Helpers
39//===----------------------------------------------------------------------===//
40
41/// Converts the given `srcAttr` into a boolean attribute if it holds an
42/// integral value. Returns null attribute if conversion fails.
43static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
44 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
45 return boolAttr;
46 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
47 return builder.getBoolAttr(value: intAttr.getValue().getBoolValue());
48 return {};
49}
50
51/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
52/// Returns null attribute if conversion fails.
53static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
54 Builder builder) {
55 // If the source number uses less active bits than the target bitwidth, then
56 // it should be safe to convert.
57 if (srcAttr.getValue().isIntN(dstType.getWidth()))
58 return builder.getIntegerAttr(dstType, srcAttr.getInt());
59
60 // XXX: Try again by interpreting the source number as a signed value.
61 // Although integers in the standard dialect are signless, they can represent
62 // a signed number. It's the operation decides how to interpret. This is
63 // dangerous, but it seems there is no good way of handling this if we still
64 // want to change the bitwidth. Emit a message at least.
65 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
66 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
67 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
68 << dstAttr << "' for type '" << dstType << "'\n");
69 return dstAttr;
70 }
71
72 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
73 << "' illegal: cannot fit into target type '"
74 << dstType << "'\n");
75 return {};
76}
77
78/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
79/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
80static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
81 Builder builder) {
82 // Only support converting to float for now.
83 if (!dstType.isF32())
84 return FloatAttr();
85
86 // Try to convert the source floating-point number to single precision.
87 APFloat dstVal = srcAttr.getValue();
88 bool losesInfo = false;
89 APFloat::opStatus status =
90 dstVal.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmTowardZero, losesInfo: &losesInfo);
91 if (status != APFloat::opOK || losesInfo) {
92 LLVM_DEBUG(llvm::dbgs()
93 << srcAttr << " illegal: cannot fit into converted type '"
94 << dstType << "'\n");
95 return FloatAttr();
96 }
97
98 return builder.getF32FloatAttr(dstVal.convertToFloat());
99}
100
101/// Returns true if the given `type` is a boolean scalar or vector type.
102static bool isBoolScalarOrVector(Type type) {
103 assert(type && "Not a valid type");
104 if (type.isInteger(width: 1))
105 return true;
106
107 if (auto vecType = dyn_cast<VectorType>(type))
108 return vecType.getElementType().isInteger(1);
109
110 return false;
111}
112
113/// Creates a scalar/vector integer constant.
114static Value getScalarOrVectorConstInt(Type type, uint64_t value,
115 OpBuilder &builder, Location loc) {
116 if (auto vectorType = dyn_cast<VectorType>(type)) {
117 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
118 auto attr = SplatElementsAttr::get(vectorType, element);
119 return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
120 }
121
122 if (auto intType = dyn_cast<IntegerType>(type))
123 return builder.create<spirv::ConstantOp>(
124 loc, type, builder.getIntegerAttr(type, value));
125
126 return nullptr;
127}
128
129/// Returns true if scalar/vector type `a` and `b` have the same number of
130/// bitwidth.
131static bool hasSameBitwidth(Type a, Type b) {
132 auto getNumBitwidth = [](Type type) {
133 unsigned bw = 0;
134 if (type.isIntOrFloat())
135 bw = type.getIntOrFloatBitWidth();
136 else if (auto vecType = dyn_cast<VectorType>(type))
137 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
138 return bw;
139 };
140 unsigned aBW = getNumBitwidth(a);
141 unsigned bBW = getNumBitwidth(b);
142 return aBW != 0 && bBW != 0 && aBW == bBW;
143}
144
145/// Returns a source type conversion failure for `srcType` and operation `op`.
146static LogicalResult
147getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
148 Type srcType) {
149 return rewriter.notifyMatchFailure(
150 arg: op->getLoc(),
151 msg: llvm::formatv(Fmt: "failed to convert source type '{0}'", Vals&: srcType));
152}
153
154/// Returns a source type conversion failure for the result type of `op`.
155static LogicalResult
156getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
157 assert(op->getNumResults() == 1);
158 return getTypeConversionFailure(rewriter, op, srcType: op->getResultTypes().front());
159}
160
161// TODO: Move to some common place?
162static std::string getDecorationString(spirv::Decoration decor) {
163 return llvm::convertToSnakeFromCamelCase(input: stringifyDecoration(decor));
164}
165
166namespace {
167
168/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
169/// operations. Op can potentially support overflow flags.
170template <typename Op, typename SPIRVOp>
171struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
172 using OpConversionPattern<Op>::OpConversionPattern;
173
174 LogicalResult
175 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
176 ConversionPatternRewriter &rewriter) const override {
177 assert(adaptor.getOperands().size() <= 3);
178 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
179 Type dstType = converter->convertType(op.getType());
180 if (!dstType) {
181 return rewriter.notifyMatchFailure(
182 op->getLoc(),
183 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
184 }
185
186 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
187 !getElementTypeOrSelf(op.getType()).isIndex() &&
188 dstType != op.getType()) {
189 return op.emitError("bitwidth emulation is not implemented yet on "
190 "unsigned op pattern version");
191 }
192
193 auto overflowFlags = arith::IntegerOverflowFlags::none;
194 if (auto overflowIface =
195 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
196 if (converter->getTargetEnv().allows(
197 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
198 overflowFlags = overflowIface.getOverflowAttr().getValue();
199 }
200
201 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
202 op, dstType, adaptor.getOperands());
203
204 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
205 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
206 rewriter.getUnitAttr());
207
208 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
209 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
210 rewriter.getUnitAttr());
211
212 return success();
213 }
214};
215
216//===----------------------------------------------------------------------===//
217// ConstantOp
218//===----------------------------------------------------------------------===//
219
220/// Converts composite arith.constant operation to spirv.Constant.
221struct ConstantCompositeOpPattern final
222 : public OpConversionPattern<arith::ConstantOp> {
223 using OpConversionPattern::OpConversionPattern;
224
225 LogicalResult
226 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const override {
228 auto srcType = dyn_cast<ShapedType>(constOp.getType());
229 if (!srcType || srcType.getNumElements() == 1)
230 return failure();
231
232 // arith.constant should only have vector or tenor types.
233 assert((isa<VectorType, RankedTensorType>(srcType)));
234
235 Type dstType = getTypeConverter()->convertType(srcType);
236 if (!dstType)
237 return failure();
238
239 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
240 if (!dstElementsAttr)
241 return failure();
242
243 ShapedType dstAttrType = dstElementsAttr.getType();
244
245 // If the composite type has more than one dimensions, perform
246 // linearization.
247 if (srcType.getRank() > 1) {
248 if (isa<RankedTensorType>(srcType)) {
249 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
250 srcType.getElementType());
251 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
252 } else {
253 // TODO: add support for large vectors.
254 return failure();
255 }
256 }
257
258 Type srcElemType = srcType.getElementType();
259 Type dstElemType;
260 // Tensor types are converted to SPIR-V array types; vector types are
261 // converted to SPIR-V vector/array types.
262 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
263 dstElemType = arrayType.getElementType();
264 else
265 dstElemType = cast<VectorType>(dstType).getElementType();
266
267 // If the source and destination element types are different, perform
268 // attribute conversion.
269 if (srcElemType != dstElemType) {
270 SmallVector<Attribute, 8> elements;
271 if (isa<FloatType>(Val: srcElemType)) {
272 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
273 FloatAttr dstAttr =
274 convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
275 if (!dstAttr)
276 return failure();
277 elements.push_back(dstAttr);
278 }
279 } else if (srcElemType.isInteger(width: 1)) {
280 return failure();
281 } else {
282 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
283 IntegerAttr dstAttr = convertIntegerAttr(
284 srcAttr, cast<IntegerType>(dstElemType), rewriter);
285 if (!dstAttr)
286 return failure();
287 elements.push_back(dstAttr);
288 }
289 }
290
291 // Unfortunately, we cannot use dialect-specific types for element
292 // attributes; element attributes only works with builtin types. So we
293 // need to prepare another converted builtin types for the destination
294 // elements attribute.
295 if (isa<RankedTensorType>(dstAttrType))
296 dstAttrType =
297 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
298 else
299 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
300
301 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
302 }
303
304 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
305 dstElementsAttr);
306 return success();
307 }
308};
309
310/// Converts scalar arith.constant operation to spirv.Constant.
311struct ConstantScalarOpPattern final
312 : public OpConversionPattern<arith::ConstantOp> {
313 using OpConversionPattern::OpConversionPattern;
314
315 LogicalResult
316 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
317 ConversionPatternRewriter &rewriter) const override {
318 Type srcType = constOp.getType();
319 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
320 if (shapedType.getNumElements() != 1)
321 return failure();
322 srcType = shapedType.getElementType();
323 }
324 if (!srcType.isIntOrIndexOrFloat())
325 return failure();
326
327 Attribute cstAttr = constOp.getValue();
328 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
329 cstAttr = elementsAttr.getSplatValue<Attribute>();
330
331 Type dstType = getTypeConverter()->convertType(srcType);
332 if (!dstType)
333 return failure();
334
335 // Floating-point types.
336 if (isa<FloatType>(Val: srcType)) {
337 auto srcAttr = cast<FloatAttr>(cstAttr);
338 auto dstAttr = srcAttr;
339
340 // Floating-point types not supported in the target environment are all
341 // converted to float type.
342 if (srcType != dstType) {
343 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
344 if (!dstAttr)
345 return failure();
346 }
347
348 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
349 return success();
350 }
351
352 // Bool type.
353 if (srcType.isInteger(width: 1)) {
354 // arith.constant can use 0/1 instead of true/false for i1 values. We need
355 // to handle that here.
356 auto dstAttr = convertBoolAttr(srcAttr: cstAttr, builder: rewriter);
357 if (!dstAttr)
358 return failure();
359 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
360 return success();
361 }
362
363 // IndexType or IntegerType. Index values are converted to 32-bit integer
364 // values when converting to SPIR-V.
365 auto srcAttr = cast<IntegerAttr>(cstAttr);
366 IntegerAttr dstAttr =
367 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
368 if (!dstAttr)
369 return failure();
370 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
371 return success();
372 }
373};
374
375//===----------------------------------------------------------------------===//
376// RemSIOp
377//===----------------------------------------------------------------------===//
378
379/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
380/// the sign of `signOperand`.
381///
382/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
383/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
384/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
385/// if either operand can be negative. Emulate it via spirv.UMod.
386template <typename SignedAbsOp>
387static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
388 Value signOperand, OpBuilder &builder) {
389 assert(lhs.getType() == rhs.getType());
390 assert(lhs == signOperand || rhs == signOperand);
391
392 Type type = lhs.getType();
393
394 // Calculate the remainder with spirv.UMod.
395 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
396 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
397 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
398
399 // Fix the sign.
400 Value isPositive;
401 if (lhs == signOperand)
402 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
403 else
404 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
405 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
406 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
407}
408
409/// Converts arith.remsi to GLSL SPIR-V ops.
410///
411/// This cannot be merged into the template unary/binary pattern due to Vulkan
412/// restrictions over spirv.SRem and spirv.SMod.
413struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
414 using OpConversionPattern::OpConversionPattern;
415
416 LogicalResult
417 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
418 ConversionPatternRewriter &rewriter) const override {
419 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
420 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
421 adaptor.getOperands()[0], rewriter);
422 rewriter.replaceOp(op, result);
423
424 return success();
425 }
426};
427
428/// Converts arith.remsi to OpenCL SPIR-V ops.
429struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
430 using OpConversionPattern::OpConversionPattern;
431
432 LogicalResult
433 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
434 ConversionPatternRewriter &rewriter) const override {
435 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
436 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
437 adaptor.getOperands()[0], rewriter);
438 rewriter.replaceOp(op, result);
439
440 return success();
441 }
442};
443
444//===----------------------------------------------------------------------===//
445// BitwiseOp
446//===----------------------------------------------------------------------===//
447
448/// Converts bitwise operations to SPIR-V operations. This is a special pattern
449/// other than the BinaryOpPatternPattern because if the operands are boolean
450/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
451/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
452template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
453struct BitwiseOpPattern final : public OpConversionPattern<Op> {
454 using OpConversionPattern<Op>::OpConversionPattern;
455
456 LogicalResult
457 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
458 ConversionPatternRewriter &rewriter) const override {
459 assert(adaptor.getOperands().size() == 2);
460 Type dstType = this->getTypeConverter()->convertType(op.getType());
461 if (!dstType)
462 return getTypeConversionFailure(rewriter, op);
463
464 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
465 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
466 op, dstType, adaptor.getOperands());
467 } else {
468 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
469 op, dstType, adaptor.getOperands());
470 }
471 return success();
472 }
473};
474
475//===----------------------------------------------------------------------===//
476// XOrIOp
477//===----------------------------------------------------------------------===//
478
479/// Converts arith.xori to SPIR-V operations.
480struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
481 using OpConversionPattern::OpConversionPattern;
482
483 LogicalResult
484 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter) const override {
486 assert(adaptor.getOperands().size() == 2);
487
488 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
489 return failure();
490
491 Type dstType = getTypeConverter()->convertType(op.getType());
492 if (!dstType)
493 return getTypeConversionFailure(rewriter, op);
494
495 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
496 adaptor.getOperands());
497
498 return success();
499 }
500};
501
502/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
503/// vector of i1.
504struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
505 using OpConversionPattern::OpConversionPattern;
506
507 LogicalResult
508 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter) const override {
510 assert(adaptor.getOperands().size() == 2);
511
512 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
513 return failure();
514
515 Type dstType = getTypeConverter()->convertType(op.getType());
516 if (!dstType)
517 return getTypeConversionFailure(rewriter, op);
518
519 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
520 op, dstType, adaptor.getOperands());
521 return success();
522 }
523};
524
525//===----------------------------------------------------------------------===//
526// UIToFPOp
527//===----------------------------------------------------------------------===//
528
529/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
530/// of i1.
531struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
532 using OpConversionPattern::OpConversionPattern;
533
534 LogicalResult
535 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
536 ConversionPatternRewriter &rewriter) const override {
537 Type srcType = adaptor.getOperands().front().getType();
538 if (!isBoolScalarOrVector(type: srcType))
539 return failure();
540
541 Type dstType = getTypeConverter()->convertType(op.getType());
542 if (!dstType)
543 return getTypeConversionFailure(rewriter, op);
544
545 Location loc = op.getLoc();
546 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
547 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
548 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
549 op, dstType, adaptor.getOperands().front(), one, zero);
550 return success();
551 }
552};
553
554//===----------------------------------------------------------------------===//
555// ExtSIOp
556//===----------------------------------------------------------------------===//
557
558/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
559/// of i1.
560struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
561 using OpConversionPattern::OpConversionPattern;
562
563 LogicalResult
564 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
565 ConversionPatternRewriter &rewriter) const override {
566 Value operand = adaptor.getIn();
567 if (!isBoolScalarOrVector(type: operand.getType()))
568 return failure();
569
570 Location loc = op.getLoc();
571 Type dstType = getTypeConverter()->convertType(op.getType());
572 if (!dstType)
573 return getTypeConversionFailure(rewriter, op);
574
575 Value allOnes;
576 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
577 unsigned componentBitwidth = intTy.getWidth();
578 allOnes = rewriter.create<spirv::ConstantOp>(
579 loc, intTy,
580 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
581 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
582 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
583 allOnes = rewriter.create<spirv::ConstantOp>(
584 loc, vectorTy,
585 SplatElementsAttr::get(vectorTy,
586 APInt::getAllOnes(componentBitwidth)));
587 } else {
588 return rewriter.notifyMatchFailure(
589 arg&: loc, msg: llvm::formatv(Fmt: "unhandled type: {0}", Vals&: dstType));
590 }
591
592 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
593 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
594 zero);
595 return success();
596 }
597};
598
599/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
600/// vector of i1.
601struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
602 using OpConversionPattern::OpConversionPattern;
603
604 LogicalResult
605 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
606 ConversionPatternRewriter &rewriter) const override {
607 Type srcType = adaptor.getIn().getType();
608 if (isBoolScalarOrVector(type: srcType))
609 return failure();
610
611 Type dstType = getTypeConverter()->convertType(op.getType());
612 if (!dstType)
613 return getTypeConversionFailure(rewriter, op);
614
615 if (dstType == srcType) {
616 // We can have the same source and destination type due to type emulation.
617 // Perform bit shifting to make sure we have the proper leading set bits.
618
619 unsigned srcBW =
620 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
621 unsigned dstBW =
622 getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
623 assert(srcBW < dstBW);
624 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
625 rewriter, op.getLoc());
626
627 // First shift left to sequeeze out all leading bits beyond the original
628 // bitwidth. Here we need to use the original source and result type's
629 // bitwidth.
630 auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
631 op.getLoc(), dstType, adaptor.getIn(), shiftSize);
632
633 // Then we perform arithmetic right shift to make sure we have the right
634 // sign bits for negative values.
635 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
636 op, dstType, shiftLOp, shiftSize);
637 } else {
638 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
639 adaptor.getOperands());
640 }
641
642 return success();
643 }
644};
645
646//===----------------------------------------------------------------------===//
647// ExtUIOp
648//===----------------------------------------------------------------------===//
649
650/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
651/// of i1.
652struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
653 using OpConversionPattern::OpConversionPattern;
654
655 LogicalResult
656 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
657 ConversionPatternRewriter &rewriter) const override {
658 Type srcType = adaptor.getOperands().front().getType();
659 if (!isBoolScalarOrVector(type: srcType))
660 return failure();
661
662 Type dstType = getTypeConverter()->convertType(op.getType());
663 if (!dstType)
664 return getTypeConversionFailure(rewriter, op);
665
666 Location loc = op.getLoc();
667 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
668 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
669 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
670 op, dstType, adaptor.getOperands().front(), one, zero);
671 return success();
672 }
673};
674
675/// Converts arith.extui for cases where the type of source is neither i1 nor
676/// vector of i1.
677struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
678 using OpConversionPattern::OpConversionPattern;
679
680 LogicalResult
681 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
682 ConversionPatternRewriter &rewriter) const override {
683 Type srcType = adaptor.getIn().getType();
684 if (isBoolScalarOrVector(type: srcType))
685 return failure();
686
687 Type dstType = getTypeConverter()->convertType(op.getType());
688 if (!dstType)
689 return getTypeConversionFailure(rewriter, op);
690
691 if (dstType == srcType) {
692 // We can have the same source and destination type due to type emulation.
693 // Perform bit masking to make sure we don't pollute downstream consumers
694 // with unwanted bits. Here we need to use the original source type's
695 // bitwidth.
696 unsigned bitwidth =
697 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
698 Value mask = getScalarOrVectorConstInt(
699 dstType, llvm::maskTrailingOnes<uint64_t>(N: bitwidth), rewriter,
700 op.getLoc());
701 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
702 adaptor.getIn(), mask);
703 } else {
704 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
705 adaptor.getOperands());
706 }
707 return success();
708 }
709};
710
711//===----------------------------------------------------------------------===//
712// TruncIOp
713//===----------------------------------------------------------------------===//
714
715/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
716/// of i1.
717struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
718 using OpConversionPattern::OpConversionPattern;
719
720 LogicalResult
721 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
722 ConversionPatternRewriter &rewriter) const override {
723 Type dstType = getTypeConverter()->convertType(op.getType());
724 if (!dstType)
725 return getTypeConversionFailure(rewriter, op);
726
727 if (!isBoolScalarOrVector(type: dstType))
728 return failure();
729
730 Location loc = op.getLoc();
731 auto srcType = adaptor.getOperands().front().getType();
732 // Check if (x & 1) == 1.
733 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
734 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
735 loc, srcType, adaptor.getOperands()[0], mask);
736 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
737
738 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
739 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
740 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
741 return success();
742 }
743};
744
745/// Converts arith.trunci for cases where the type of result is neither i1
746/// nor vector of i1.
747struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
748 using OpConversionPattern::OpConversionPattern;
749
750 LogicalResult
751 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
752 ConversionPatternRewriter &rewriter) const override {
753 Type srcType = adaptor.getIn().getType();
754 Type dstType = getTypeConverter()->convertType(op.getType());
755 if (!dstType)
756 return getTypeConversionFailure(rewriter, op);
757
758 if (isBoolScalarOrVector(type: dstType))
759 return failure();
760
761 if (dstType == srcType) {
762 // We can have the same source and destination type due to type emulation.
763 // Perform bit masking to make sure we don't pollute downstream consumers
764 // with unwanted bits. Here we need to use the original result type's
765 // bitwidth.
766 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
767 Value mask = getScalarOrVectorConstInt(
768 dstType, llvm::maskTrailingOnes<uint64_t>(N: bw), rewriter, op.getLoc());
769 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
770 adaptor.getIn(), mask);
771 } else {
772 // Given this is truncation, either SConvertOp or UConvertOp works.
773 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
774 adaptor.getOperands());
775 }
776 return success();
777 }
778};
779
780//===----------------------------------------------------------------------===//
781// TypeCastingOp
782//===----------------------------------------------------------------------===//
783
784/// Converts type-casting standard operations to SPIR-V operations.
785template <typename Op, typename SPIRVOp>
786struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
787 using OpConversionPattern<Op>::OpConversionPattern;
788
789 LogicalResult
790 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
791 ConversionPatternRewriter &rewriter) const override {
792 assert(adaptor.getOperands().size() == 1);
793 Type srcType = adaptor.getOperands().front().getType();
794 Type dstType = this->getTypeConverter()->convertType(op.getType());
795 if (!dstType)
796 return getTypeConversionFailure(rewriter, op);
797
798 if (isBoolScalarOrVector(type: srcType) || isBoolScalarOrVector(type: dstType))
799 return failure();
800
801 if (dstType == srcType) {
802 // Due to type conversion, we are seeing the same source and target type.
803 // Then we can just erase this operation by forwarding its operand.
804 rewriter.replaceOp(op, adaptor.getOperands().front());
805 } else {
806 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
807 adaptor.getOperands());
808 if (auto roundingModeOp =
809 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
810 if (arith::RoundingModeAttr roundingMode =
811 roundingModeOp.getRoundingModeAttr()) {
812 // TODO: Perform rounding mode attribute conversion and attach to new
813 // operation when defined in the dialect.
814 return failure();
815 }
816 }
817 }
818 return success();
819 }
820};
821
822//===----------------------------------------------------------------------===//
823// CmpIOp
824//===----------------------------------------------------------------------===//
825
826/// Converts integer compare operation on i1 type operands to SPIR-V ops.
827class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
828public:
829 using OpConversionPattern::OpConversionPattern;
830
831 LogicalResult
832 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
833 ConversionPatternRewriter &rewriter) const override {
834 Type srcType = op.getLhs().getType();
835 if (!isBoolScalarOrVector(type: srcType))
836 return failure();
837 Type dstType = getTypeConverter()->convertType(srcType);
838 if (!dstType)
839 return getTypeConversionFailure(rewriter, op, srcType);
840
841 switch (op.getPredicate()) {
842 case arith::CmpIPredicate::eq: {
843 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
844 adaptor.getRhs());
845 return success();
846 }
847 case arith::CmpIPredicate::ne: {
848 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
849 op, adaptor.getLhs(), adaptor.getRhs());
850 return success();
851 }
852 case arith::CmpIPredicate::uge:
853 case arith::CmpIPredicate::ugt:
854 case arith::CmpIPredicate::ule:
855 case arith::CmpIPredicate::ult: {
856 // There are no direct corresponding instructions in SPIR-V for such
857 // cases. Extend them to 32-bit and do comparision then.
858 Type type = rewriter.getI32Type();
859 if (auto vectorType = dyn_cast<VectorType>(dstType))
860 type = VectorType::get(vectorType.getShape(), type);
861 Value extLhs =
862 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
863 Value extRhs =
864 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
865
866 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
867 extRhs);
868 return success();
869 }
870 default:
871 break;
872 }
873 return failure();
874 }
875};
876
877/// Converts integer compare operation to SPIR-V ops.
878class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
879public:
880 using OpConversionPattern::OpConversionPattern;
881
882 LogicalResult
883 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
884 ConversionPatternRewriter &rewriter) const override {
885 Type srcType = op.getLhs().getType();
886 if (isBoolScalarOrVector(type: srcType))
887 return failure();
888 Type dstType = getTypeConverter()->convertType(srcType);
889 if (!dstType)
890 return getTypeConversionFailure(rewriter, op, srcType);
891
892 switch (op.getPredicate()) {
893#define DISPATCH(cmpPredicate, spirvOp) \
894 case cmpPredicate: \
895 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
896 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
897 !hasSameBitwidth(srcType, dstType)) { \
898 return op.emitError( \
899 "bitwidth emulation is not implemented yet on unsigned op"); \
900 } \
901 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
902 adaptor.getRhs()); \
903 return success();
904
905 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
906 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
907 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
908 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
909 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
910 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
911 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
912 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
913 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
914 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
915
916#undef DISPATCH
917 }
918 return failure();
919 }
920};
921
922//===----------------------------------------------------------------------===//
923// CmpFOpPattern
924//===----------------------------------------------------------------------===//
925
926/// Converts floating-point comparison operations to SPIR-V ops.
927class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
928public:
929 using OpConversionPattern::OpConversionPattern;
930
931 LogicalResult
932 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
933 ConversionPatternRewriter &rewriter) const override {
934 switch (op.getPredicate()) {
935#define DISPATCH(cmpPredicate, spirvOp) \
936 case cmpPredicate: \
937 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
938 adaptor.getRhs()); \
939 return success();
940
941 // Ordered.
942 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
943 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
944 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
945 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
946 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
947 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
948 // Unordered.
949 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
950 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
951 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
952 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
953 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
954 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
955
956#undef DISPATCH
957
958 default:
959 break;
960 }
961 return failure();
962 }
963};
964
965/// Converts floating point NaN check to SPIR-V ops. This pattern requires
966/// Kernel capability.
967class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
968public:
969 using OpConversionPattern::OpConversionPattern;
970
971 LogicalResult
972 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
973 ConversionPatternRewriter &rewriter) const override {
974 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
975 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
976 adaptor.getRhs());
977 return success();
978 }
979
980 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
981 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
982 adaptor.getRhs());
983 return success();
984 }
985
986 return failure();
987 }
988};
989
990/// Converts floating point NaN check to SPIR-V ops. This pattern does not
991/// require additional capability.
992class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
993public:
994 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
995
996 LogicalResult
997 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
998 ConversionPatternRewriter &rewriter) const override {
999 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1000 op.getPredicate() != arith::CmpFPredicate::UNO)
1001 return failure();
1002
1003 Location loc = op.getLoc();
1004
1005 Value replace;
1006 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1007 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1008 // Ordered comparsion checks if neither operand is NaN.
1009 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1010 } else {
1011 // Unordered comparsion checks if either operand is NaN.
1012 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1013 }
1014 } else {
1015 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1016 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1017
1018 replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1019 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1020 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
1021 }
1022
1023 rewriter.replaceOp(op, replace);
1024 return success();
1025 }
1026};
1027
1028//===----------------------------------------------------------------------===//
1029// AddUIExtendedOp
1030//===----------------------------------------------------------------------===//
1031
1032/// Converts arith.addui_extended to spirv.IAddCarry.
1033class AddUIExtendedOpPattern final
1034 : public OpConversionPattern<arith::AddUIExtendedOp> {
1035public:
1036 using OpConversionPattern::OpConversionPattern;
1037 LogicalResult
1038 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1039 ConversionPatternRewriter &rewriter) const override {
1040 Type dstElemTy = adaptor.getLhs().getType();
1041 Location loc = op->getLoc();
1042 Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1043 adaptor.getRhs());
1044
1045 Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
1046 loc, result, llvm::ArrayRef(0));
1047 Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
1048 loc, result, llvm::ArrayRef(1));
1049
1050 // Convert the carry value to boolean.
1051 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1052 Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
1053
1054 rewriter.replaceOp(op, {sumResult, carryResult});
1055 return success();
1056 }
1057};
1058
1059//===----------------------------------------------------------------------===//
1060// MulIExtendedOp
1061//===----------------------------------------------------------------------===//
1062
1063/// Converts arith.mul*i_extended to spirv.*MulExtended.
1064template <typename ArithMulOp, typename SPIRVMulOp>
1065class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1066public:
1067 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1068 LogicalResult
1069 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1070 ConversionPatternRewriter &rewriter) const override {
1071 Location loc = op->getLoc();
1072 Value result =
1073 rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1074
1075 Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1076 llvm::ArrayRef(0));
1077 Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1078 llvm::ArrayRef(1));
1079
1080 rewriter.replaceOp(op, {low, high});
1081 return success();
1082 }
1083};
1084
1085//===----------------------------------------------------------------------===//
1086// SelectOp
1087//===----------------------------------------------------------------------===//
1088
1089/// Converts arith.select to spirv.Select.
1090class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1091public:
1092 using OpConversionPattern::OpConversionPattern;
1093 LogicalResult
1094 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1095 ConversionPatternRewriter &rewriter) const override {
1096 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1097 adaptor.getTrueValue(),
1098 adaptor.getFalseValue());
1099 return success();
1100 }
1101};
1102
1103//===----------------------------------------------------------------------===//
1104// MinimumFOp, MaximumFOp
1105//===----------------------------------------------------------------------===//
1106
1107/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1108/// spirv.CL.fmax/fmin.
1109template <typename Op, typename SPIRVOp>
1110class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1111public:
1112 using OpConversionPattern<Op>::OpConversionPattern;
1113 LogicalResult
1114 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1115 ConversionPatternRewriter &rewriter) const override {
1116 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1117 Type dstType = converter->convertType(op.getType());
1118 if (!dstType)
1119 return getTypeConversionFailure(rewriter, op);
1120
1121 // arith.maximumf/minimumf:
1122 // "if one of the arguments is NaN, then the result is also NaN."
1123 // spirv.GL.FMax/FMin
1124 // "which operand is the result is undefined if one of the operands
1125 // is a NaN."
1126 // spirv.CL.fmax/fmin:
1127 // "If one argument is a NaN, Fmin returns the other argument."
1128
1129 Location loc = op.getLoc();
1130 Value spirvOp =
1131 rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1132
1133 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1134 rewriter.replaceOp(op, spirvOp);
1135 return success();
1136 }
1137
1138 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1139 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1140
1141 Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1142 adaptor.getLhs(), spirvOp);
1143 Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1144 adaptor.getRhs(), select1);
1145
1146 rewriter.replaceOp(op, select2);
1147 return success();
1148 }
1149};
1150
1151//===----------------------------------------------------------------------===//
1152// MinNumFOp, MaxNumFOp
1153//===----------------------------------------------------------------------===//
1154
1155/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1156/// spirv.CL.fmax/fmin.
1157template <typename Op, typename SPIRVOp>
1158class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1159 template <typename TargetOp>
1160 constexpr bool shouldInsertNanGuards() const {
1161 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1162 }
1163
1164public:
1165 using OpConversionPattern<Op>::OpConversionPattern;
1166 LogicalResult
1167 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1168 ConversionPatternRewriter &rewriter) const override {
1169 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1170 Type dstType = converter->convertType(op.getType());
1171 if (!dstType)
1172 return getTypeConversionFailure(rewriter, op);
1173
1174 // arith.maxnumf/minnumf:
1175 // "If one of the arguments is NaN, then the result is the other
1176 // argument."
1177 // spirv.GL.FMax/FMin
1178 // "which operand is the result is undefined if one of the operands
1179 // is a NaN."
1180 // spirv.CL.fmax/fmin:
1181 // "If one argument is a NaN, Fmin returns the other argument."
1182
1183 Location loc = op.getLoc();
1184 Value spirvOp =
1185 rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1186
1187 if (!shouldInsertNanGuards<SPIRVOp>() ||
1188 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1189 rewriter.replaceOp(op, spirvOp);
1190 return success();
1191 }
1192
1193 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1194 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1195
1196 Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1197 adaptor.getRhs(), spirvOp);
1198 Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1199 adaptor.getLhs(), select1);
1200
1201 rewriter.replaceOp(op, select2);
1202 return success();
1203 }
1204};
1205
1206} // namespace
1207
1208//===----------------------------------------------------------------------===//
1209// Pattern Population
1210//===----------------------------------------------------------------------===//
1211
1212void mlir::arith::populateArithToSPIRVPatterns(
1213 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1214 // clang-format off
1215 patterns.add<
1216 ConstantCompositeOpPattern,
1217 ConstantScalarOpPattern,
1218 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1219 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1220 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1221 spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
1222 spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
1223 spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
1224 RemSIOpGLPattern, RemSIOpCLPattern,
1225 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1226 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1227 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1228 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1229 spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
1230 spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
1231 spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
1232 spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
1233 spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
1234 spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
1235 spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
1236 spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
1237 ExtUIPattern, ExtUII1Pattern,
1238 ExtSIPattern, ExtSII1Pattern,
1239 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1240 TruncIPattern, TruncII1Pattern,
1241 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1242 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1243 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1244 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1245 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1246 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1247 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1248 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1249 CmpIOpBooleanPattern, CmpIOpPattern,
1250 CmpFOpNanNonePattern, CmpFOpPattern,
1251 AddUIExtendedOpPattern,
1252 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1253 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1254 SelectOpPattern,
1255
1256 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1257 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1258 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1259 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1260 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
1261 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
1262 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
1263 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,
1264
1265 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1266 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1267 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1268 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1269 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
1270 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
1271 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
1272 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp>
1273 >(typeConverter, patterns.getContext());
1274 // clang-format on
1275
1276 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1277 // capability is available.
1278 patterns.add<CmpFOpNanKernelPattern>(arg&: typeConverter, args: patterns.getContext(),
1279 /*benefit=*/args: 2);
1280}
1281
1282//===----------------------------------------------------------------------===//
1283// Pass Definition
1284//===----------------------------------------------------------------------===//
1285
1286namespace {
1287struct ConvertArithToSPIRVPass
1288 : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
1289 void runOnOperation() override {
1290 Operation *op = getOperation();
1291 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
1292 std::unique_ptr<SPIRVConversionTarget> target =
1293 SPIRVConversionTarget::get(targetAttr);
1294
1295 SPIRVConversionOptions options;
1296 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1297 SPIRVTypeConverter typeConverter(targetAttr, options);
1298
1299 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1300 // in patterns for other dialects.
1301 target->addLegalOp<UnrealizedConversionCastOp>();
1302
1303 // Fail hard when there are any remaining 'arith' ops.
1304 target->addIllegalDialect<arith::ArithDialect>();
1305
1306 RewritePatternSet patterns(&getContext());
1307 arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1308
1309 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1310 signalPassFailure();
1311 }
1312};
1313} // namespace
1314
1315std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
1316 return std::make_unique<ConvertArithToSPIRVPass>();
1317}
1318

source code of mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp