1 | //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" |
10 | |
11 | #include "../SPIRVCommon/Pattern.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
14 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
17 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
18 | #include "mlir/IR/BuiltinAttributes.h" |
19 | #include "mlir/IR/BuiltinTypes.h" |
20 | #include "llvm/ADT/APInt.h" |
21 | #include "llvm/ADT/ArrayRef.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/Support/Debug.h" |
24 | #include "llvm/Support/MathExtras.h" |
25 | #include <cassert> |
26 | #include <memory> |
27 | |
28 | namespace mlir { |
29 | #define GEN_PASS_DEF_CONVERTARITHTOSPIRV |
30 | #include "mlir/Conversion/Passes.h.inc" |
31 | } // namespace mlir |
32 | |
33 | #define DEBUG_TYPE "arith-to-spirv-pattern" |
34 | |
35 | using namespace mlir; |
36 | |
37 | //===----------------------------------------------------------------------===// |
38 | // Conversion Helpers |
39 | //===----------------------------------------------------------------------===// |
40 | |
41 | /// Converts the given `srcAttr` into a boolean attribute if it holds an |
42 | /// integral value. Returns null attribute if conversion fails. |
43 | static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { |
44 | if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr)) |
45 | return boolAttr; |
46 | if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr)) |
47 | return builder.getBoolAttr(value: intAttr.getValue().getBoolValue()); |
48 | return {}; |
49 | } |
50 | |
51 | /// Converts the given `srcAttr` to a new attribute of the given `dstType`. |
52 | /// Returns null attribute if conversion fails. |
53 | static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, |
54 | Builder builder) { |
55 | // If the source number uses less active bits than the target bitwidth, then |
56 | // it should be safe to convert. |
57 | if (srcAttr.getValue().isIntN(dstType.getWidth())) |
58 | return builder.getIntegerAttr(dstType, srcAttr.getInt()); |
59 | |
60 | // XXX: Try again by interpreting the source number as a signed value. |
61 | // Although integers in the standard dialect are signless, they can represent |
62 | // a signed number. It's the operation decides how to interpret. This is |
63 | // dangerous, but it seems there is no good way of handling this if we still |
64 | // want to change the bitwidth. Emit a message at least. |
65 | if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { |
66 | auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); |
67 | LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" |
68 | << dstAttr << "' for type '" << dstType << "'\n" ); |
69 | return dstAttr; |
70 | } |
71 | |
72 | LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr |
73 | << "' illegal: cannot fit into target type '" |
74 | << dstType << "'\n" ); |
75 | return {}; |
76 | } |
77 | |
78 | /// Converts the given `srcAttr` to a new attribute of the given `dstType`. |
79 | /// Returns null attribute if `dstType` is not 32-bit or conversion fails. |
80 | static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, |
81 | Builder builder) { |
82 | // Only support converting to float for now. |
83 | if (!dstType.isF32()) |
84 | return FloatAttr(); |
85 | |
86 | // Try to convert the source floating-point number to single precision. |
87 | APFloat dstVal = srcAttr.getValue(); |
88 | bool losesInfo = false; |
89 | APFloat::opStatus status = |
90 | dstVal.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmTowardZero, losesInfo: &losesInfo); |
91 | if (status != APFloat::opOK || losesInfo) { |
92 | LLVM_DEBUG(llvm::dbgs() |
93 | << srcAttr << " illegal: cannot fit into converted type '" |
94 | << dstType << "'\n" ); |
95 | return FloatAttr(); |
96 | } |
97 | |
98 | return builder.getF32FloatAttr(dstVal.convertToFloat()); |
99 | } |
100 | |
101 | /// Returns true if the given `type` is a boolean scalar or vector type. |
102 | static bool isBoolScalarOrVector(Type type) { |
103 | assert(type && "Not a valid type" ); |
104 | if (type.isInteger(width: 1)) |
105 | return true; |
106 | |
107 | if (auto vecType = dyn_cast<VectorType>(type)) |
108 | return vecType.getElementType().isInteger(1); |
109 | |
110 | return false; |
111 | } |
112 | |
113 | /// Creates a scalar/vector integer constant. |
114 | static Value getScalarOrVectorConstInt(Type type, uint64_t value, |
115 | OpBuilder &builder, Location loc) { |
116 | if (auto vectorType = dyn_cast<VectorType>(type)) { |
117 | Attribute element = IntegerAttr::get(vectorType.getElementType(), value); |
118 | auto attr = SplatElementsAttr::get(vectorType, element); |
119 | return builder.create<spirv::ConstantOp>(loc, vectorType, attr); |
120 | } |
121 | |
122 | if (auto intType = dyn_cast<IntegerType>(type)) |
123 | return builder.create<spirv::ConstantOp>( |
124 | loc, type, builder.getIntegerAttr(type, value)); |
125 | |
126 | return nullptr; |
127 | } |
128 | |
129 | /// Returns true if scalar/vector type `a` and `b` have the same number of |
130 | /// bitwidth. |
131 | static bool hasSameBitwidth(Type a, Type b) { |
132 | auto getNumBitwidth = [](Type type) { |
133 | unsigned bw = 0; |
134 | if (type.isIntOrFloat()) |
135 | bw = type.getIntOrFloatBitWidth(); |
136 | else if (auto vecType = dyn_cast<VectorType>(type)) |
137 | bw = vecType.getElementTypeBitWidth() * vecType.getNumElements(); |
138 | return bw; |
139 | }; |
140 | unsigned aBW = getNumBitwidth(a); |
141 | unsigned bBW = getNumBitwidth(b); |
142 | return aBW != 0 && bBW != 0 && aBW == bBW; |
143 | } |
144 | |
145 | /// Returns a source type conversion failure for `srcType` and operation `op`. |
146 | static LogicalResult |
147 | getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, |
148 | Type srcType) { |
149 | return rewriter.notifyMatchFailure( |
150 | arg: op->getLoc(), |
151 | msg: llvm::formatv(Fmt: "failed to convert source type '{0}'" , Vals&: srcType)); |
152 | } |
153 | |
154 | /// Returns a source type conversion failure for the result type of `op`. |
155 | static LogicalResult |
156 | getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) { |
157 | assert(op->getNumResults() == 1); |
158 | return getTypeConversionFailure(rewriter, op, srcType: op->getResultTypes().front()); |
159 | } |
160 | |
161 | // TODO: Move to some common place? |
162 | static std::string getDecorationString(spirv::Decoration decor) { |
163 | return llvm::convertToSnakeFromCamelCase(input: stringifyDecoration(decor)); |
164 | } |
165 | |
166 | namespace { |
167 | |
168 | /// Converts elementwise unary, binary and ternary arith operations to SPIR-V |
169 | /// operations. Op can potentially support overflow flags. |
170 | template <typename Op, typename SPIRVOp> |
171 | struct ElementwiseArithOpPattern final : OpConversionPattern<Op> { |
172 | using OpConversionPattern<Op>::OpConversionPattern; |
173 | |
174 | LogicalResult |
175 | matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
176 | ConversionPatternRewriter &rewriter) const override { |
177 | assert(adaptor.getOperands().size() <= 3); |
178 | auto converter = this->template getTypeConverter<SPIRVTypeConverter>(); |
179 | Type dstType = converter->convertType(op.getType()); |
180 | if (!dstType) { |
181 | return rewriter.notifyMatchFailure( |
182 | op->getLoc(), |
183 | llvm::formatv("failed to convert type {0} for SPIR-V" , op.getType())); |
184 | } |
185 | |
186 | if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && |
187 | !getElementTypeOrSelf(op.getType()).isIndex() && |
188 | dstType != op.getType()) { |
189 | return op.emitError("bitwidth emulation is not implemented yet on " |
190 | "unsigned op pattern version" ); |
191 | } |
192 | |
193 | auto overflowFlags = arith::IntegerOverflowFlags::none; |
194 | if (auto overflowIface = |
195 | dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) { |
196 | if (converter->getTargetEnv().allows( |
197 | spirv::Extension::SPV_KHR_no_integer_wrap_decoration)) |
198 | overflowFlags = overflowIface.getOverflowAttr().getValue(); |
199 | } |
200 | |
201 | auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>( |
202 | op, dstType, adaptor.getOperands()); |
203 | |
204 | if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw)) |
205 | newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap), |
206 | rewriter.getUnitAttr()); |
207 | |
208 | if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw)) |
209 | newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap), |
210 | rewriter.getUnitAttr()); |
211 | |
212 | return success(); |
213 | } |
214 | }; |
215 | |
216 | //===----------------------------------------------------------------------===// |
217 | // ConstantOp |
218 | //===----------------------------------------------------------------------===// |
219 | |
220 | /// Converts composite arith.constant operation to spirv.Constant. |
221 | struct ConstantCompositeOpPattern final |
222 | : public OpConversionPattern<arith::ConstantOp> { |
223 | using OpConversionPattern::OpConversionPattern; |
224 | |
225 | LogicalResult |
226 | matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, |
227 | ConversionPatternRewriter &rewriter) const override { |
228 | auto srcType = dyn_cast<ShapedType>(constOp.getType()); |
229 | if (!srcType || srcType.getNumElements() == 1) |
230 | return failure(); |
231 | |
232 | // arith.constant should only have vector or tenor types. |
233 | assert((isa<VectorType, RankedTensorType>(srcType))); |
234 | |
235 | Type dstType = getTypeConverter()->convertType(srcType); |
236 | if (!dstType) |
237 | return failure(); |
238 | |
239 | auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue()); |
240 | if (!dstElementsAttr) |
241 | return failure(); |
242 | |
243 | ShapedType dstAttrType = dstElementsAttr.getType(); |
244 | |
245 | // If the composite type has more than one dimensions, perform |
246 | // linearization. |
247 | if (srcType.getRank() > 1) { |
248 | if (isa<RankedTensorType>(srcType)) { |
249 | dstAttrType = RankedTensorType::get(srcType.getNumElements(), |
250 | srcType.getElementType()); |
251 | dstElementsAttr = dstElementsAttr.reshape(dstAttrType); |
252 | } else { |
253 | // TODO: add support for large vectors. |
254 | return failure(); |
255 | } |
256 | } |
257 | |
258 | Type srcElemType = srcType.getElementType(); |
259 | Type dstElemType; |
260 | // Tensor types are converted to SPIR-V array types; vector types are |
261 | // converted to SPIR-V vector/array types. |
262 | if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType)) |
263 | dstElemType = arrayType.getElementType(); |
264 | else |
265 | dstElemType = cast<VectorType>(dstType).getElementType(); |
266 | |
267 | // If the source and destination element types are different, perform |
268 | // attribute conversion. |
269 | if (srcElemType != dstElemType) { |
270 | SmallVector<Attribute, 8> elements; |
271 | if (isa<FloatType>(Val: srcElemType)) { |
272 | for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { |
273 | FloatAttr dstAttr = |
274 | convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); |
275 | if (!dstAttr) |
276 | return failure(); |
277 | elements.push_back(dstAttr); |
278 | } |
279 | } else if (srcElemType.isInteger(width: 1)) { |
280 | return failure(); |
281 | } else { |
282 | for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { |
283 | IntegerAttr dstAttr = convertIntegerAttr( |
284 | srcAttr, cast<IntegerType>(dstElemType), rewriter); |
285 | if (!dstAttr) |
286 | return failure(); |
287 | elements.push_back(dstAttr); |
288 | } |
289 | } |
290 | |
291 | // Unfortunately, we cannot use dialect-specific types for element |
292 | // attributes; element attributes only works with builtin types. So we |
293 | // need to prepare another converted builtin types for the destination |
294 | // elements attribute. |
295 | if (isa<RankedTensorType>(dstAttrType)) |
296 | dstAttrType = |
297 | RankedTensorType::get(dstAttrType.getShape(), dstElemType); |
298 | else |
299 | dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); |
300 | |
301 | dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); |
302 | } |
303 | |
304 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, |
305 | dstElementsAttr); |
306 | return success(); |
307 | } |
308 | }; |
309 | |
310 | /// Converts scalar arith.constant operation to spirv.Constant. |
311 | struct ConstantScalarOpPattern final |
312 | : public OpConversionPattern<arith::ConstantOp> { |
313 | using OpConversionPattern::OpConversionPattern; |
314 | |
315 | LogicalResult |
316 | matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, |
317 | ConversionPatternRewriter &rewriter) const override { |
318 | Type srcType = constOp.getType(); |
319 | if (auto shapedType = dyn_cast<ShapedType>(srcType)) { |
320 | if (shapedType.getNumElements() != 1) |
321 | return failure(); |
322 | srcType = shapedType.getElementType(); |
323 | } |
324 | if (!srcType.isIntOrIndexOrFloat()) |
325 | return failure(); |
326 | |
327 | Attribute cstAttr = constOp.getValue(); |
328 | if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr)) |
329 | cstAttr = elementsAttr.getSplatValue<Attribute>(); |
330 | |
331 | Type dstType = getTypeConverter()->convertType(srcType); |
332 | if (!dstType) |
333 | return failure(); |
334 | |
335 | // Floating-point types. |
336 | if (isa<FloatType>(Val: srcType)) { |
337 | auto srcAttr = cast<FloatAttr>(cstAttr); |
338 | auto dstAttr = srcAttr; |
339 | |
340 | // Floating-point types not supported in the target environment are all |
341 | // converted to float type. |
342 | if (srcType != dstType) { |
343 | dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); |
344 | if (!dstAttr) |
345 | return failure(); |
346 | } |
347 | |
348 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
349 | return success(); |
350 | } |
351 | |
352 | // Bool type. |
353 | if (srcType.isInteger(width: 1)) { |
354 | // arith.constant can use 0/1 instead of true/false for i1 values. We need |
355 | // to handle that here. |
356 | auto dstAttr = convertBoolAttr(srcAttr: cstAttr, builder: rewriter); |
357 | if (!dstAttr) |
358 | return failure(); |
359 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
360 | return success(); |
361 | } |
362 | |
363 | // IndexType or IntegerType. Index values are converted to 32-bit integer |
364 | // values when converting to SPIR-V. |
365 | auto srcAttr = cast<IntegerAttr>(cstAttr); |
366 | IntegerAttr dstAttr = |
367 | convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter); |
368 | if (!dstAttr) |
369 | return failure(); |
370 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
371 | return success(); |
372 | } |
373 | }; |
374 | |
375 | //===----------------------------------------------------------------------===// |
376 | // RemSIOp |
377 | //===----------------------------------------------------------------------===// |
378 | |
379 | /// Returns signed remainder for `lhs` and `rhs` and lets the result follow |
380 | /// the sign of `signOperand`. |
381 | /// |
382 | /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment |
383 | /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative |
384 | /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod |
385 | /// if either operand can be negative. Emulate it via spirv.UMod. |
386 | template <typename SignedAbsOp> |
387 | static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, |
388 | Value signOperand, OpBuilder &builder) { |
389 | assert(lhs.getType() == rhs.getType()); |
390 | assert(lhs == signOperand || rhs == signOperand); |
391 | |
392 | Type type = lhs.getType(); |
393 | |
394 | // Calculate the remainder with spirv.UMod. |
395 | Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); |
396 | Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); |
397 | Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); |
398 | |
399 | // Fix the sign. |
400 | Value isPositive; |
401 | if (lhs == signOperand) |
402 | isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); |
403 | else |
404 | isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); |
405 | Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); |
406 | return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); |
407 | } |
408 | |
409 | /// Converts arith.remsi to GLSL SPIR-V ops. |
410 | /// |
411 | /// This cannot be merged into the template unary/binary pattern due to Vulkan |
412 | /// restrictions over spirv.SRem and spirv.SMod. |
413 | struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> { |
414 | using OpConversionPattern::OpConversionPattern; |
415 | |
416 | LogicalResult |
417 | matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, |
418 | ConversionPatternRewriter &rewriter) const override { |
419 | Value result = emulateSignedRemainder<spirv::CLSAbsOp>( |
420 | op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], |
421 | adaptor.getOperands()[0], rewriter); |
422 | rewriter.replaceOp(op, result); |
423 | |
424 | return success(); |
425 | } |
426 | }; |
427 | |
428 | /// Converts arith.remsi to OpenCL SPIR-V ops. |
429 | struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> { |
430 | using OpConversionPattern::OpConversionPattern; |
431 | |
432 | LogicalResult |
433 | matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, |
434 | ConversionPatternRewriter &rewriter) const override { |
435 | Value result = emulateSignedRemainder<spirv::GLSAbsOp>( |
436 | op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], |
437 | adaptor.getOperands()[0], rewriter); |
438 | rewriter.replaceOp(op, result); |
439 | |
440 | return success(); |
441 | } |
442 | }; |
443 | |
444 | //===----------------------------------------------------------------------===// |
445 | // BitwiseOp |
446 | //===----------------------------------------------------------------------===// |
447 | |
448 | /// Converts bitwise operations to SPIR-V operations. This is a special pattern |
449 | /// other than the BinaryOpPatternPattern because if the operands are boolean |
450 | /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For |
451 | /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. |
452 | template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> |
453 | struct BitwiseOpPattern final : public OpConversionPattern<Op> { |
454 | using OpConversionPattern<Op>::OpConversionPattern; |
455 | |
456 | LogicalResult |
457 | matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
458 | ConversionPatternRewriter &rewriter) const override { |
459 | assert(adaptor.getOperands().size() == 2); |
460 | Type dstType = this->getTypeConverter()->convertType(op.getType()); |
461 | if (!dstType) |
462 | return getTypeConversionFailure(rewriter, op); |
463 | |
464 | if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { |
465 | rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>( |
466 | op, dstType, adaptor.getOperands()); |
467 | } else { |
468 | rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>( |
469 | op, dstType, adaptor.getOperands()); |
470 | } |
471 | return success(); |
472 | } |
473 | }; |
474 | |
475 | //===----------------------------------------------------------------------===// |
476 | // XOrIOp |
477 | //===----------------------------------------------------------------------===// |
478 | |
479 | /// Converts arith.xori to SPIR-V operations. |
480 | struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { |
481 | using OpConversionPattern::OpConversionPattern; |
482 | |
483 | LogicalResult |
484 | matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, |
485 | ConversionPatternRewriter &rewriter) const override { |
486 | assert(adaptor.getOperands().size() == 2); |
487 | |
488 | if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) |
489 | return failure(); |
490 | |
491 | Type dstType = getTypeConverter()->convertType(op.getType()); |
492 | if (!dstType) |
493 | return getTypeConversionFailure(rewriter, op); |
494 | |
495 | rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, |
496 | adaptor.getOperands()); |
497 | |
498 | return success(); |
499 | } |
500 | }; |
501 | |
502 | /// Converts arith.xori to SPIR-V operations if the type of source is i1 or |
503 | /// vector of i1. |
504 | struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { |
505 | using OpConversionPattern::OpConversionPattern; |
506 | |
507 | LogicalResult |
508 | matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, |
509 | ConversionPatternRewriter &rewriter) const override { |
510 | assert(adaptor.getOperands().size() == 2); |
511 | |
512 | if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) |
513 | return failure(); |
514 | |
515 | Type dstType = getTypeConverter()->convertType(op.getType()); |
516 | if (!dstType) |
517 | return getTypeConversionFailure(rewriter, op); |
518 | |
519 | rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>( |
520 | op, dstType, adaptor.getOperands()); |
521 | return success(); |
522 | } |
523 | }; |
524 | |
525 | //===----------------------------------------------------------------------===// |
526 | // UIToFPOp |
527 | //===----------------------------------------------------------------------===// |
528 | |
529 | /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector |
530 | /// of i1. |
531 | struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { |
532 | using OpConversionPattern::OpConversionPattern; |
533 | |
534 | LogicalResult |
535 | matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, |
536 | ConversionPatternRewriter &rewriter) const override { |
537 | Type srcType = adaptor.getOperands().front().getType(); |
538 | if (!isBoolScalarOrVector(type: srcType)) |
539 | return failure(); |
540 | |
541 | Type dstType = getTypeConverter()->convertType(op.getType()); |
542 | if (!dstType) |
543 | return getTypeConversionFailure(rewriter, op); |
544 | |
545 | Location loc = op.getLoc(); |
546 | Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
547 | Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
548 | rewriter.replaceOpWithNewOp<spirv::SelectOp>( |
549 | op, dstType, adaptor.getOperands().front(), one, zero); |
550 | return success(); |
551 | } |
552 | }; |
553 | |
554 | //===----------------------------------------------------------------------===// |
555 | // ExtSIOp |
556 | //===----------------------------------------------------------------------===// |
557 | |
558 | /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector |
559 | /// of i1. |
560 | struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> { |
561 | using OpConversionPattern::OpConversionPattern; |
562 | |
563 | LogicalResult |
564 | matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, |
565 | ConversionPatternRewriter &rewriter) const override { |
566 | Value operand = adaptor.getIn(); |
567 | if (!isBoolScalarOrVector(type: operand.getType())) |
568 | return failure(); |
569 | |
570 | Location loc = op.getLoc(); |
571 | Type dstType = getTypeConverter()->convertType(op.getType()); |
572 | if (!dstType) |
573 | return getTypeConversionFailure(rewriter, op); |
574 | |
575 | Value allOnes; |
576 | if (auto intTy = dyn_cast<IntegerType>(dstType)) { |
577 | unsigned componentBitwidth = intTy.getWidth(); |
578 | allOnes = rewriter.create<spirv::ConstantOp>( |
579 | loc, intTy, |
580 | rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); |
581 | } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) { |
582 | unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); |
583 | allOnes = rewriter.create<spirv::ConstantOp>( |
584 | loc, vectorTy, |
585 | SplatElementsAttr::get(vectorTy, |
586 | APInt::getAllOnes(componentBitwidth))); |
587 | } else { |
588 | return rewriter.notifyMatchFailure( |
589 | arg&: loc, msg: llvm::formatv(Fmt: "unhandled type: {0}" , Vals&: dstType)); |
590 | } |
591 | |
592 | Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
593 | rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes, |
594 | zero); |
595 | return success(); |
596 | } |
597 | }; |
598 | |
599 | /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor |
600 | /// vector of i1. |
601 | struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> { |
602 | using OpConversionPattern::OpConversionPattern; |
603 | |
604 | LogicalResult |
605 | matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, |
606 | ConversionPatternRewriter &rewriter) const override { |
607 | Type srcType = adaptor.getIn().getType(); |
608 | if (isBoolScalarOrVector(type: srcType)) |
609 | return failure(); |
610 | |
611 | Type dstType = getTypeConverter()->convertType(op.getType()); |
612 | if (!dstType) |
613 | return getTypeConversionFailure(rewriter, op); |
614 | |
615 | if (dstType == srcType) { |
616 | // We can have the same source and destination type due to type emulation. |
617 | // Perform bit shifting to make sure we have the proper leading set bits. |
618 | |
619 | unsigned srcBW = |
620 | getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); |
621 | unsigned dstBW = |
622 | getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); |
623 | assert(srcBW < dstBW); |
624 | Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW, |
625 | rewriter, op.getLoc()); |
626 | |
627 | // First shift left to sequeeze out all leading bits beyond the original |
628 | // bitwidth. Here we need to use the original source and result type's |
629 | // bitwidth. |
630 | auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>( |
631 | op.getLoc(), dstType, adaptor.getIn(), shiftSize); |
632 | |
633 | // Then we perform arithmetic right shift to make sure we have the right |
634 | // sign bits for negative values. |
635 | rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>( |
636 | op, dstType, shiftLOp, shiftSize); |
637 | } else { |
638 | rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType, |
639 | adaptor.getOperands()); |
640 | } |
641 | |
642 | return success(); |
643 | } |
644 | }; |
645 | |
646 | //===----------------------------------------------------------------------===// |
647 | // ExtUIOp |
648 | //===----------------------------------------------------------------------===// |
649 | |
650 | /// Converts arith.extui to spirv.Select if the type of source is i1 or vector |
651 | /// of i1. |
652 | struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { |
653 | using OpConversionPattern::OpConversionPattern; |
654 | |
655 | LogicalResult |
656 | matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, |
657 | ConversionPatternRewriter &rewriter) const override { |
658 | Type srcType = adaptor.getOperands().front().getType(); |
659 | if (!isBoolScalarOrVector(type: srcType)) |
660 | return failure(); |
661 | |
662 | Type dstType = getTypeConverter()->convertType(op.getType()); |
663 | if (!dstType) |
664 | return getTypeConversionFailure(rewriter, op); |
665 | |
666 | Location loc = op.getLoc(); |
667 | Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
668 | Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
669 | rewriter.replaceOpWithNewOp<spirv::SelectOp>( |
670 | op, dstType, adaptor.getOperands().front(), one, zero); |
671 | return success(); |
672 | } |
673 | }; |
674 | |
675 | /// Converts arith.extui for cases where the type of source is neither i1 nor |
676 | /// vector of i1. |
677 | struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> { |
678 | using OpConversionPattern::OpConversionPattern; |
679 | |
680 | LogicalResult |
681 | matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, |
682 | ConversionPatternRewriter &rewriter) const override { |
683 | Type srcType = adaptor.getIn().getType(); |
684 | if (isBoolScalarOrVector(type: srcType)) |
685 | return failure(); |
686 | |
687 | Type dstType = getTypeConverter()->convertType(op.getType()); |
688 | if (!dstType) |
689 | return getTypeConversionFailure(rewriter, op); |
690 | |
691 | if (dstType == srcType) { |
692 | // We can have the same source and destination type due to type emulation. |
693 | // Perform bit masking to make sure we don't pollute downstream consumers |
694 | // with unwanted bits. Here we need to use the original source type's |
695 | // bitwidth. |
696 | unsigned bitwidth = |
697 | getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); |
698 | Value mask = getScalarOrVectorConstInt( |
699 | dstType, llvm::maskTrailingOnes<uint64_t>(N: bitwidth), rewriter, |
700 | op.getLoc()); |
701 | rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, |
702 | adaptor.getIn(), mask); |
703 | } else { |
704 | rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType, |
705 | adaptor.getOperands()); |
706 | } |
707 | return success(); |
708 | } |
709 | }; |
710 | |
711 | //===----------------------------------------------------------------------===// |
712 | // TruncIOp |
713 | //===----------------------------------------------------------------------===// |
714 | |
715 | /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector |
716 | /// of i1. |
717 | struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { |
718 | using OpConversionPattern::OpConversionPattern; |
719 | |
720 | LogicalResult |
721 | matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, |
722 | ConversionPatternRewriter &rewriter) const override { |
723 | Type dstType = getTypeConverter()->convertType(op.getType()); |
724 | if (!dstType) |
725 | return getTypeConversionFailure(rewriter, op); |
726 | |
727 | if (!isBoolScalarOrVector(type: dstType)) |
728 | return failure(); |
729 | |
730 | Location loc = op.getLoc(); |
731 | auto srcType = adaptor.getOperands().front().getType(); |
732 | // Check if (x & 1) == 1. |
733 | Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); |
734 | Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( |
735 | loc, srcType, adaptor.getOperands()[0], mask); |
736 | Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); |
737 | |
738 | Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
739 | Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
740 | rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); |
741 | return success(); |
742 | } |
743 | }; |
744 | |
745 | /// Converts arith.trunci for cases where the type of result is neither i1 |
746 | /// nor vector of i1. |
747 | struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> { |
748 | using OpConversionPattern::OpConversionPattern; |
749 | |
750 | LogicalResult |
751 | matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, |
752 | ConversionPatternRewriter &rewriter) const override { |
753 | Type srcType = adaptor.getIn().getType(); |
754 | Type dstType = getTypeConverter()->convertType(op.getType()); |
755 | if (!dstType) |
756 | return getTypeConversionFailure(rewriter, op); |
757 | |
758 | if (isBoolScalarOrVector(type: dstType)) |
759 | return failure(); |
760 | |
761 | if (dstType == srcType) { |
762 | // We can have the same source and destination type due to type emulation. |
763 | // Perform bit masking to make sure we don't pollute downstream consumers |
764 | // with unwanted bits. Here we need to use the original result type's |
765 | // bitwidth. |
766 | unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); |
767 | Value mask = getScalarOrVectorConstInt( |
768 | dstType, llvm::maskTrailingOnes<uint64_t>(N: bw), rewriter, op.getLoc()); |
769 | rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, |
770 | adaptor.getIn(), mask); |
771 | } else { |
772 | // Given this is truncation, either SConvertOp or UConvertOp works. |
773 | rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType, |
774 | adaptor.getOperands()); |
775 | } |
776 | return success(); |
777 | } |
778 | }; |
779 | |
780 | //===----------------------------------------------------------------------===// |
781 | // TypeCastingOp |
782 | //===----------------------------------------------------------------------===// |
783 | |
784 | /// Converts type-casting standard operations to SPIR-V operations. |
785 | template <typename Op, typename SPIRVOp> |
786 | struct TypeCastingOpPattern final : public OpConversionPattern<Op> { |
787 | using OpConversionPattern<Op>::OpConversionPattern; |
788 | |
789 | LogicalResult |
790 | matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
791 | ConversionPatternRewriter &rewriter) const override { |
792 | assert(adaptor.getOperands().size() == 1); |
793 | Type srcType = adaptor.getOperands().front().getType(); |
794 | Type dstType = this->getTypeConverter()->convertType(op.getType()); |
795 | if (!dstType) |
796 | return getTypeConversionFailure(rewriter, op); |
797 | |
798 | if (isBoolScalarOrVector(type: srcType) || isBoolScalarOrVector(type: dstType)) |
799 | return failure(); |
800 | |
801 | if (dstType == srcType) { |
802 | // Due to type conversion, we are seeing the same source and target type. |
803 | // Then we can just erase this operation by forwarding its operand. |
804 | rewriter.replaceOp(op, adaptor.getOperands().front()); |
805 | } else { |
806 | rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, |
807 | adaptor.getOperands()); |
808 | if (auto roundingModeOp = |
809 | dyn_cast<arith::ArithRoundingModeInterface>(*op)) { |
810 | if (arith::RoundingModeAttr roundingMode = |
811 | roundingModeOp.getRoundingModeAttr()) { |
812 | // TODO: Perform rounding mode attribute conversion and attach to new |
813 | // operation when defined in the dialect. |
814 | return failure(); |
815 | } |
816 | } |
817 | } |
818 | return success(); |
819 | } |
820 | }; |
821 | |
822 | //===----------------------------------------------------------------------===// |
823 | // CmpIOp |
824 | //===----------------------------------------------------------------------===// |
825 | |
826 | /// Converts integer compare operation on i1 type operands to SPIR-V ops. |
827 | class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { |
828 | public: |
829 | using OpConversionPattern::OpConversionPattern; |
830 | |
831 | LogicalResult |
832 | matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
833 | ConversionPatternRewriter &rewriter) const override { |
834 | Type srcType = op.getLhs().getType(); |
835 | if (!isBoolScalarOrVector(type: srcType)) |
836 | return failure(); |
837 | Type dstType = getTypeConverter()->convertType(srcType); |
838 | if (!dstType) |
839 | return getTypeConversionFailure(rewriter, op, srcType); |
840 | |
841 | switch (op.getPredicate()) { |
842 | case arith::CmpIPredicate::eq: { |
843 | rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(), |
844 | adaptor.getRhs()); |
845 | return success(); |
846 | } |
847 | case arith::CmpIPredicate::ne: { |
848 | rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>( |
849 | op, adaptor.getLhs(), adaptor.getRhs()); |
850 | return success(); |
851 | } |
852 | case arith::CmpIPredicate::uge: |
853 | case arith::CmpIPredicate::ugt: |
854 | case arith::CmpIPredicate::ule: |
855 | case arith::CmpIPredicate::ult: { |
856 | // There are no direct corresponding instructions in SPIR-V for such |
857 | // cases. Extend them to 32-bit and do comparision then. |
858 | Type type = rewriter.getI32Type(); |
859 | if (auto vectorType = dyn_cast<VectorType>(dstType)) |
860 | type = VectorType::get(vectorType.getShape(), type); |
861 | Value extLhs = |
862 | rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs()); |
863 | Value extRhs = |
864 | rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs()); |
865 | |
866 | rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs, |
867 | extRhs); |
868 | return success(); |
869 | } |
870 | default: |
871 | break; |
872 | } |
873 | return failure(); |
874 | } |
875 | }; |
876 | |
877 | /// Converts integer compare operation to SPIR-V ops. |
878 | class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { |
879 | public: |
880 | using OpConversionPattern::OpConversionPattern; |
881 | |
882 | LogicalResult |
883 | matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
884 | ConversionPatternRewriter &rewriter) const override { |
885 | Type srcType = op.getLhs().getType(); |
886 | if (isBoolScalarOrVector(type: srcType)) |
887 | return failure(); |
888 | Type dstType = getTypeConverter()->convertType(srcType); |
889 | if (!dstType) |
890 | return getTypeConversionFailure(rewriter, op, srcType); |
891 | |
892 | switch (op.getPredicate()) { |
893 | #define DISPATCH(cmpPredicate, spirvOp) \ |
894 | case cmpPredicate: \ |
895 | if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ |
896 | !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \ |
897 | !hasSameBitwidth(srcType, dstType)) { \ |
898 | return op.emitError( \ |
899 | "bitwidth emulation is not implemented yet on unsigned op"); \ |
900 | } \ |
901 | rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ |
902 | adaptor.getRhs()); \ |
903 | return success(); |
904 | |
905 | DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); |
906 | DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); |
907 | DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); |
908 | DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); |
909 | DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); |
910 | DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); |
911 | DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); |
912 | DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); |
913 | DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); |
914 | DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); |
915 | |
916 | #undef DISPATCH |
917 | } |
918 | return failure(); |
919 | } |
920 | }; |
921 | |
922 | //===----------------------------------------------------------------------===// |
923 | // CmpFOpPattern |
924 | //===----------------------------------------------------------------------===// |
925 | |
926 | /// Converts floating-point comparison operations to SPIR-V ops. |
927 | class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { |
928 | public: |
929 | using OpConversionPattern::OpConversionPattern; |
930 | |
931 | LogicalResult |
932 | matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
933 | ConversionPatternRewriter &rewriter) const override { |
934 | switch (op.getPredicate()) { |
935 | #define DISPATCH(cmpPredicate, spirvOp) \ |
936 | case cmpPredicate: \ |
937 | rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ |
938 | adaptor.getRhs()); \ |
939 | return success(); |
940 | |
941 | // Ordered. |
942 | DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); |
943 | DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); |
944 | DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); |
945 | DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); |
946 | DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); |
947 | DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); |
948 | // Unordered. |
949 | DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); |
950 | DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); |
951 | DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); |
952 | DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); |
953 | DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); |
954 | DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); |
955 | |
956 | #undef DISPATCH |
957 | |
958 | default: |
959 | break; |
960 | } |
961 | return failure(); |
962 | } |
963 | }; |
964 | |
965 | /// Converts floating point NaN check to SPIR-V ops. This pattern requires |
966 | /// Kernel capability. |
967 | class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { |
968 | public: |
969 | using OpConversionPattern::OpConversionPattern; |
970 | |
971 | LogicalResult |
972 | matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
973 | ConversionPatternRewriter &rewriter) const override { |
974 | if (op.getPredicate() == arith::CmpFPredicate::ORD) { |
975 | rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), |
976 | adaptor.getRhs()); |
977 | return success(); |
978 | } |
979 | |
980 | if (op.getPredicate() == arith::CmpFPredicate::UNO) { |
981 | rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), |
982 | adaptor.getRhs()); |
983 | return success(); |
984 | } |
985 | |
986 | return failure(); |
987 | } |
988 | }; |
989 | |
990 | /// Converts floating point NaN check to SPIR-V ops. This pattern does not |
991 | /// require additional capability. |
992 | class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { |
993 | public: |
994 | using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; |
995 | |
996 | LogicalResult |
997 | matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
998 | ConversionPatternRewriter &rewriter) const override { |
999 | if (op.getPredicate() != arith::CmpFPredicate::ORD && |
1000 | op.getPredicate() != arith::CmpFPredicate::UNO) |
1001 | return failure(); |
1002 | |
1003 | Location loc = op.getLoc(); |
1004 | |
1005 | Value replace; |
1006 | if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { |
1007 | if (op.getPredicate() == arith::CmpFPredicate::ORD) { |
1008 | // Ordered comparsion checks if neither operand is NaN. |
1009 | replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter); |
1010 | } else { |
1011 | // Unordered comparsion checks if either operand is NaN. |
1012 | replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); |
1013 | } |
1014 | } else { |
1015 | Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); |
1016 | Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); |
1017 | |
1018 | replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); |
1019 | if (op.getPredicate() == arith::CmpFPredicate::ORD) |
1020 | replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); |
1021 | } |
1022 | |
1023 | rewriter.replaceOp(op, replace); |
1024 | return success(); |
1025 | } |
1026 | }; |
1027 | |
1028 | //===----------------------------------------------------------------------===// |
1029 | // AddUIExtendedOp |
1030 | //===----------------------------------------------------------------------===// |
1031 | |
1032 | /// Converts arith.addui_extended to spirv.IAddCarry. |
1033 | class AddUIExtendedOpPattern final |
1034 | : public OpConversionPattern<arith::AddUIExtendedOp> { |
1035 | public: |
1036 | using OpConversionPattern::OpConversionPattern; |
1037 | LogicalResult |
1038 | matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, |
1039 | ConversionPatternRewriter &rewriter) const override { |
1040 | Type dstElemTy = adaptor.getLhs().getType(); |
1041 | Location loc = op->getLoc(); |
1042 | Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(), |
1043 | adaptor.getRhs()); |
1044 | |
1045 | Value sumResult = rewriter.create<spirv::CompositeExtractOp>( |
1046 | loc, result, llvm::ArrayRef(0)); |
1047 | Value carryValue = rewriter.create<spirv::CompositeExtractOp>( |
1048 | loc, result, llvm::ArrayRef(1)); |
1049 | |
1050 | // Convert the carry value to boolean. |
1051 | Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); |
1052 | Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one); |
1053 | |
1054 | rewriter.replaceOp(op, {sumResult, carryResult}); |
1055 | return success(); |
1056 | } |
1057 | }; |
1058 | |
1059 | //===----------------------------------------------------------------------===// |
1060 | // MulIExtendedOp |
1061 | //===----------------------------------------------------------------------===// |
1062 | |
1063 | /// Converts arith.mul*i_extended to spirv.*MulExtended. |
1064 | template <typename ArithMulOp, typename SPIRVMulOp> |
1065 | class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> { |
1066 | public: |
1067 | using OpConversionPattern<ArithMulOp>::OpConversionPattern; |
1068 | LogicalResult |
1069 | matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, |
1070 | ConversionPatternRewriter &rewriter) const override { |
1071 | Location loc = op->getLoc(); |
1072 | Value result = |
1073 | rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs()); |
1074 | |
1075 | Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result, |
1076 | llvm::ArrayRef(0)); |
1077 | Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result, |
1078 | llvm::ArrayRef(1)); |
1079 | |
1080 | rewriter.replaceOp(op, {low, high}); |
1081 | return success(); |
1082 | } |
1083 | }; |
1084 | |
1085 | //===----------------------------------------------------------------------===// |
1086 | // SelectOp |
1087 | //===----------------------------------------------------------------------===// |
1088 | |
1089 | /// Converts arith.select to spirv.Select. |
1090 | class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { |
1091 | public: |
1092 | using OpConversionPattern::OpConversionPattern; |
1093 | LogicalResult |
1094 | matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, |
1095 | ConversionPatternRewriter &rewriter) const override { |
1096 | rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(), |
1097 | adaptor.getTrueValue(), |
1098 | adaptor.getFalseValue()); |
1099 | return success(); |
1100 | } |
1101 | }; |
1102 | |
1103 | //===----------------------------------------------------------------------===// |
1104 | // MinimumFOp, MaximumFOp |
1105 | //===----------------------------------------------------------------------===// |
1106 | |
1107 | /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or |
1108 | /// spirv.CL.fmax/fmin. |
1109 | template <typename Op, typename SPIRVOp> |
1110 | class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> { |
1111 | public: |
1112 | using OpConversionPattern<Op>::OpConversionPattern; |
1113 | LogicalResult |
1114 | matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
1115 | ConversionPatternRewriter &rewriter) const override { |
1116 | auto *converter = this->template getTypeConverter<SPIRVTypeConverter>(); |
1117 | Type dstType = converter->convertType(op.getType()); |
1118 | if (!dstType) |
1119 | return getTypeConversionFailure(rewriter, op); |
1120 | |
1121 | // arith.maximumf/minimumf: |
1122 | // "if one of the arguments is NaN, then the result is also NaN." |
1123 | // spirv.GL.FMax/FMin |
1124 | // "which operand is the result is undefined if one of the operands |
1125 | // is a NaN." |
1126 | // spirv.CL.fmax/fmin: |
1127 | // "If one argument is a NaN, Fmin returns the other argument." |
1128 | |
1129 | Location loc = op.getLoc(); |
1130 | Value spirvOp = |
1131 | rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); |
1132 | |
1133 | if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { |
1134 | rewriter.replaceOp(op, spirvOp); |
1135 | return success(); |
1136 | } |
1137 | |
1138 | Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); |
1139 | Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); |
1140 | |
1141 | Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, |
1142 | adaptor.getLhs(), spirvOp); |
1143 | Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, |
1144 | adaptor.getRhs(), select1); |
1145 | |
1146 | rewriter.replaceOp(op, select2); |
1147 | return success(); |
1148 | } |
1149 | }; |
1150 | |
1151 | //===----------------------------------------------------------------------===// |
1152 | // MinNumFOp, MaxNumFOp |
1153 | //===----------------------------------------------------------------------===// |
1154 | |
1155 | /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or |
1156 | /// spirv.CL.fmax/fmin. |
1157 | template <typename Op, typename SPIRVOp> |
1158 | class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> { |
1159 | template <typename TargetOp> |
1160 | constexpr bool shouldInsertNanGuards() const { |
1161 | return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value; |
1162 | } |
1163 | |
1164 | public: |
1165 | using OpConversionPattern<Op>::OpConversionPattern; |
1166 | LogicalResult |
1167 | matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
1168 | ConversionPatternRewriter &rewriter) const override { |
1169 | auto *converter = this->template getTypeConverter<SPIRVTypeConverter>(); |
1170 | Type dstType = converter->convertType(op.getType()); |
1171 | if (!dstType) |
1172 | return getTypeConversionFailure(rewriter, op); |
1173 | |
1174 | // arith.maxnumf/minnumf: |
1175 | // "If one of the arguments is NaN, then the result is the other |
1176 | // argument." |
1177 | // spirv.GL.FMax/FMin |
1178 | // "which operand is the result is undefined if one of the operands |
1179 | // is a NaN." |
1180 | // spirv.CL.fmax/fmin: |
1181 | // "If one argument is a NaN, Fmin returns the other argument." |
1182 | |
1183 | Location loc = op.getLoc(); |
1184 | Value spirvOp = |
1185 | rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); |
1186 | |
1187 | if (!shouldInsertNanGuards<SPIRVOp>() || |
1188 | bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { |
1189 | rewriter.replaceOp(op, spirvOp); |
1190 | return success(); |
1191 | } |
1192 | |
1193 | Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); |
1194 | Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); |
1195 | |
1196 | Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, |
1197 | adaptor.getRhs(), spirvOp); |
1198 | Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, |
1199 | adaptor.getLhs(), select1); |
1200 | |
1201 | rewriter.replaceOp(op, select2); |
1202 | return success(); |
1203 | } |
1204 | }; |
1205 | |
1206 | } // namespace |
1207 | |
1208 | //===----------------------------------------------------------------------===// |
1209 | // Pattern Population |
1210 | //===----------------------------------------------------------------------===// |
1211 | |
1212 | void mlir::arith::populateArithToSPIRVPatterns( |
1213 | SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { |
1214 | // clang-format off |
1215 | patterns.add< |
1216 | ConstantCompositeOpPattern, |
1217 | ConstantScalarOpPattern, |
1218 | ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>, |
1219 | ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>, |
1220 | ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>, |
1221 | spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>, |
1222 | spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>, |
1223 | spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>, |
1224 | RemSIOpGLPattern, RemSIOpCLPattern, |
1225 | BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, |
1226 | BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, |
1227 | XOrIOpLogicalPattern, XOrIOpBooleanPattern, |
1228 | ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, |
1229 | spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, |
1230 | spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, |
1231 | spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>, |
1232 | spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>, |
1233 | spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>, |
1234 | spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>, |
1235 | spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>, |
1236 | spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>, |
1237 | ExtUIPattern, ExtUII1Pattern, |
1238 | ExtSIPattern, ExtSII1Pattern, |
1239 | TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, |
1240 | TruncIPattern, TruncII1Pattern, |
1241 | TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, |
1242 | TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, |
1243 | TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, |
1244 | TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>, |
1245 | TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, |
1246 | TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, |
1247 | TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>, |
1248 | TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, |
1249 | CmpIOpBooleanPattern, CmpIOpPattern, |
1250 | CmpFOpNanNonePattern, CmpFOpPattern, |
1251 | AddUIExtendedOpPattern, |
1252 | MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>, |
1253 | MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>, |
1254 | SelectOpPattern, |
1255 | |
1256 | MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>, |
1257 | MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>, |
1258 | MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>, |
1259 | MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>, |
1260 | spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>, |
1261 | spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>, |
1262 | spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>, |
1263 | spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>, |
1264 | |
1265 | MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>, |
1266 | MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>, |
1267 | MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>, |
1268 | MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>, |
1269 | spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>, |
1270 | spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>, |
1271 | spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>, |
1272 | spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp> |
1273 | >(typeConverter, patterns.getContext()); |
1274 | // clang-format on |
1275 | |
1276 | // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel |
1277 | // capability is available. |
1278 | patterns.add<CmpFOpNanKernelPattern>(arg&: typeConverter, args: patterns.getContext(), |
1279 | /*benefit=*/args: 2); |
1280 | } |
1281 | |
1282 | //===----------------------------------------------------------------------===// |
1283 | // Pass Definition |
1284 | //===----------------------------------------------------------------------===// |
1285 | |
1286 | namespace { |
1287 | struct ConvertArithToSPIRVPass |
1288 | : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> { |
1289 | void runOnOperation() override { |
1290 | Operation *op = getOperation(); |
1291 | spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); |
1292 | std::unique_ptr<SPIRVConversionTarget> target = |
1293 | SPIRVConversionTarget::get(targetAttr); |
1294 | |
1295 | SPIRVConversionOptions options; |
1296 | options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; |
1297 | SPIRVTypeConverter typeConverter(targetAttr, options); |
1298 | |
1299 | // Use UnrealizedConversionCast as the bridge so that we don't need to pull |
1300 | // in patterns for other dialects. |
1301 | target->addLegalOp<UnrealizedConversionCastOp>(); |
1302 | |
1303 | // Fail hard when there are any remaining 'arith' ops. |
1304 | target->addIllegalDialect<arith::ArithDialect>(); |
1305 | |
1306 | RewritePatternSet patterns(&getContext()); |
1307 | arith::populateArithToSPIRVPatterns(typeConverter, patterns); |
1308 | |
1309 | if (failed(applyPartialConversion(op, *target, std::move(patterns)))) |
1310 | signalPassFailure(); |
1311 | } |
1312 | }; |
1313 | } // namespace |
1314 | |
1315 | std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() { |
1316 | return std::make_unique<ConvertArithToSPIRVPass>(); |
1317 | } |
1318 | |