1//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Math dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "../SPIRVCommon/Pattern.h"
14#include "mlir/Dialect/Math/IR/Math.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/TypeUtilities.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/FormatVariadic.h"
22
23#define DEBUG_TYPE "math-to-spirv-pattern"
24
25using namespace mlir;
26
27//===----------------------------------------------------------------------===//
28// Utility functions
29//===----------------------------------------------------------------------===//
30
31/// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
32/// given type is not a 32-bit scalar/vector type.
33static Value getScalarOrVectorI32Constant(Type type, int value,
34 OpBuilder &builder, Location loc) {
35 if (auto vectorType = dyn_cast<VectorType>(Val&: type)) {
36 if (!vectorType.getElementType().isInteger(width: 32))
37 return nullptr;
38 SmallVector<int> values(vectorType.getNumElements(), value);
39 return builder.create<spirv::ConstantOp>(location: loc, args&: type,
40 args: builder.getI32VectorAttr(values));
41 }
42 if (type.isInteger(width: 32))
43 return builder.create<spirv::ConstantOp>(location: loc, args&: type,
44 args: builder.getI32IntegerAttr(value));
45
46 return nullptr;
47}
48
49/// Check if the type is supported by math-to-spirv conversion. We expect to
50/// only see scalars and vectors at this point, with higher-level types already
51/// lowered.
52static bool isSupportedSourceType(Type originalType) {
53 if (originalType.isIntOrIndexOrFloat())
54 return true;
55
56 if (auto vecTy = dyn_cast<VectorType>(Val&: originalType)) {
57 if (!vecTy.getElementType().isIntOrIndexOrFloat())
58 return false;
59 if (vecTy.isScalable())
60 return false;
61 if (vecTy.getRank() > 1)
62 return false;
63
64 return true;
65 }
66
67 return false;
68}
69
70/// Check if all `sourceOp` types are supported by math-to-spirv conversion.
71/// Notify of a match failure othwerise and return a `failure` result.
72/// This is intended to simplify type checks in `OpConversionPattern`s.
73static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
74 Operation *sourceOp) {
75 auto allTypes = llvm::to_vector(Range: sourceOp->getOperandTypes());
76 llvm::append_range(C&: allTypes, R: sourceOp->getResultTypes());
77
78 for (Type ty : allTypes) {
79 if (!isSupportedSourceType(originalType: ty)) {
80 return rewriter.notifyMatchFailure(
81 arg&: sourceOp,
82 msg: llvm::formatv(
83 Fmt: "unsupported source type for Math to SPIR-V conversion: {0}",
84 Vals&: ty));
85 }
86 }
87
88 return success();
89}
90
91//===----------------------------------------------------------------------===//
92// Operation conversion
93//===----------------------------------------------------------------------===//
94
95// Note that DRR cannot be used for the patterns in this file: we may need to
96// convert type along the way, which requires ConversionPattern. DRR generates
97// normal RewritePattern.
98
99namespace {
100/// Converts elementwise unary, binary, and ternary standard operations to
101/// SPIR-V operations. Checks that source `Op` types are supported.
102template <typename Op, typename SPIRVOp>
103struct CheckedElementwiseOpPattern final
104 : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
105 using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
106 using BasePattern::BasePattern;
107
108 LogicalResult
109 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
110 ConversionPatternRewriter &rewriter) const override {
111 if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(Result: res))
112 return res;
113
114 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
115 }
116};
117
118/// Converts math.copysign to SPIR-V ops.
119struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
120 using OpConversionPattern::OpConversionPattern;
121
122 LogicalResult
123 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter) const override {
125 if (LogicalResult res = checkSourceOpTypes(rewriter, sourceOp: copySignOp);
126 failed(Result: res))
127 return res;
128
129 Type type = getTypeConverter()->convertType(t: copySignOp.getType());
130 if (!type)
131 return failure();
132
133 FloatType floatType;
134 if (auto scalarType = dyn_cast<FloatType>(Val: copySignOp.getType())) {
135 floatType = scalarType;
136 } else if (auto vectorType = dyn_cast<VectorType>(Val: copySignOp.getType())) {
137 floatType = cast<FloatType>(Val: vectorType.getElementType());
138 } else {
139 return failure();
140 }
141
142 Location loc = copySignOp.getLoc();
143 int bitwidth = floatType.getWidth();
144 Type intType = rewriter.getIntegerType(width: bitwidth);
145 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
146
147 Value signMask = rewriter.create<spirv::ConstantOp>(
148 location: loc, args&: intType, args: rewriter.getIntegerAttr(type: intType, value: intValue));
149 Value valueMask = rewriter.create<spirv::ConstantOp>(
150 location: loc, args&: intType, args: rewriter.getIntegerAttr(type: intType, value: intValue - 1u));
151
152 if (auto vectorType = dyn_cast<VectorType>(Val&: type)) {
153 assert(vectorType.getRank() == 1);
154 int count = vectorType.getNumElements();
155 intType = VectorType::get(shape: count, elementType: intType);
156
157 SmallVector<Value> signSplat(count, signMask);
158 signMask =
159 rewriter.create<spirv::CompositeConstructOp>(location: loc, args&: intType, args&: signSplat);
160
161 SmallVector<Value> valueSplat(count, valueMask);
162 valueMask = rewriter.create<spirv::CompositeConstructOp>(location: loc, args&: intType,
163 args&: valueSplat);
164 }
165
166 Value lhsCast =
167 rewriter.create<spirv::BitcastOp>(location: loc, args&: intType, args: adaptor.getLhs());
168 Value rhsCast =
169 rewriter.create<spirv::BitcastOp>(location: loc, args&: intType, args: adaptor.getRhs());
170
171 Value value = rewriter.create<spirv::BitwiseAndOp>(
172 location: loc, args&: intType, args: ValueRange{lhsCast, valueMask});
173 Value sign = rewriter.create<spirv::BitwiseAndOp>(
174 location: loc, args&: intType, args: ValueRange{rhsCast, signMask});
175
176 Value result = rewriter.create<spirv::BitwiseOrOp>(location: loc, args&: intType,
177 args: ValueRange{value, sign});
178 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(op: copySignOp, args&: type, args&: result);
179 return success();
180 }
181};
182
183/// Converts math.ctlz to SPIR-V ops.
184///
185/// SPIR-V does not have a direct operations for counting leading zeros. If
186/// Shader capability is supported, we can leverage GL FindUMsb to calculate
187/// it.
188struct CountLeadingZerosPattern final
189 : public OpConversionPattern<math::CountLeadingZerosOp> {
190 using OpConversionPattern::OpConversionPattern;
191
192 LogicalResult
193 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
194 ConversionPatternRewriter &rewriter) const override {
195 if (LogicalResult res = checkSourceOpTypes(rewriter, sourceOp: countOp); failed(Result: res))
196 return res;
197
198 Type type = getTypeConverter()->convertType(t: countOp.getType());
199 if (!type)
200 return failure();
201
202 // We can only support 32-bit integer types for now.
203 unsigned bitwidth = 0;
204 if (isa<IntegerType>(Val: type))
205 bitwidth = type.getIntOrFloatBitWidth();
206 if (auto vectorType = dyn_cast<VectorType>(Val&: type))
207 bitwidth = vectorType.getElementTypeBitWidth();
208 if (bitwidth != 32)
209 return failure();
210
211 Location loc = countOp.getLoc();
212 Value input = adaptor.getOperand();
213 Value val1 = getScalarOrVectorI32Constant(type, value: 1, builder&: rewriter, loc);
214 Value val31 = getScalarOrVectorI32Constant(type, value: 31, builder&: rewriter, loc);
215 Value val32 = getScalarOrVectorI32Constant(type, value: 32, builder&: rewriter, loc);
216
217 Value msb = rewriter.create<spirv::GLFindUMsbOp>(location: loc, args&: input);
218 // We need to subtract from 31 given that the index returned by GLSL
219 // FindUMsb is counted from the least significant bit. Theoretically this
220 // also gives the correct result even if the integer has all zero bits, in
221 // which case GL FindUMsb would return -1.
222 Value subMsb = rewriter.create<spirv::ISubOp>(location: loc, args&: val31, args&: msb);
223 // However, certain Vulkan implementations have driver bugs for the corner
224 // case where the input is zero. And.. it can be smart to optimize a select
225 // only involving the corner case. So separately compute the result when the
226 // input is either zero or one.
227 Value subInput = rewriter.create<spirv::ISubOp>(location: loc, args&: val32, args&: input);
228 Value cmp = rewriter.create<spirv::ULessThanEqualOp>(location: loc, args&: input, args&: val1);
229 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op: countOp, args&: cmp, args&: subInput,
230 args&: subMsb);
231 return success();
232 }
233};
234
235/// Converts math.expm1 to SPIR-V ops.
236///
237/// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
238/// these operations.
239template <typename ExpOp>
240struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
241 using OpConversionPattern::OpConversionPattern;
242
243 LogicalResult
244 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override {
246 assert(adaptor.getOperands().size() == 1);
247 if (LogicalResult res = checkSourceOpTypes(rewriter, sourceOp: operation);
248 failed(Result: res))
249 return res;
250
251 Location loc = operation.getLoc();
252 Type type = this->getTypeConverter()->convertType(operation.getType());
253 if (!type)
254 return failure();
255
256 Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
257 auto one = spirv::ConstantOp::getOne(type, loc, builder&: rewriter);
258 rewriter.replaceOpWithNewOp<spirv::FSubOp>(op: operation, args&: exp, args&: one);
259 return success();
260 }
261};
262
263/// Converts math.log1p to SPIR-V ops.
264///
265/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
266/// these operations.
267template <typename LogOp>
268struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
269 using OpConversionPattern::OpConversionPattern;
270
271 LogicalResult
272 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
274 assert(adaptor.getOperands().size() == 1);
275 if (LogicalResult res = checkSourceOpTypes(rewriter, sourceOp: operation);
276 failed(Result: res))
277 return res;
278
279 Location loc = operation.getLoc();
280 Type type = this->getTypeConverter()->convertType(operation.getType());
281 if (!type)
282 return failure();
283
284 auto one = spirv::ConstantOp::getOne(type, loc: operation.getLoc(), builder&: rewriter);
285 Value onePlus =
286 rewriter.create<spirv::FAddOp>(location: loc, args&: one, args: adaptor.getOperand());
287 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
288 return success();
289 }
290};
291
292/// Converts math.log2 and math.log10 to SPIR-V ops.
293///
294/// SPIR-V does not have direct operations for log2 and log10. Explicitly
295/// lower to these operations using:
296/// log2(x) = log(x) * 1/log(2)
297/// log10(x) = log(x) * 1/log(10)
298
299template <typename MathLogOp, typename SpirvLogOp>
300struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
301 using OpConversionPattern<MathLogOp>::OpConversionPattern;
302 using typename OpConversionPattern<MathLogOp>::OpAdaptor;
303
304 static constexpr double log2Reciprocal =
305 1.442695040888963407359924681001892137426645954152985934135449407;
306 static constexpr double log10Reciprocal =
307 0.4342944819032518276511289189166050822943970058036665661144537832;
308
309 LogicalResult
310 matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
311 ConversionPatternRewriter &rewriter) const override {
312 assert(adaptor.getOperands().size() == 1);
313 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
314 failed(Result: res))
315 return res;
316
317 Location loc = operation.getLoc();
318 Type type = this->getTypeConverter()->convertType(operation.getType());
319 if (!type)
320 return rewriter.notifyMatchFailure(operation, "type conversion failed");
321
322 auto getConstantValue = [&](double value) {
323 if (auto floatType = dyn_cast<FloatType>(Val&: type)) {
324 return rewriter.create<spirv::ConstantOp>(
325 location: loc, args&: type, args: rewriter.getFloatAttr(type: floatType, value));
326 }
327 if (auto vectorType = dyn_cast<VectorType>(Val&: type)) {
328 Type elemType = vectorType.getElementType();
329
330 if (isa<FloatType>(Val: elemType)) {
331 return rewriter.create<spirv::ConstantOp>(
332 location: loc, args&: type,
333 args: DenseFPElementsAttr::get(
334 type: vectorType, arg: FloatAttr::get(type: elemType, value).getValue()));
335 }
336 }
337
338 llvm_unreachable("unimplemented types for log2/log10");
339 };
340
341 Value constantValue = getConstantValue(
342 std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
343 : log10Reciprocal);
344 Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
345 rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
346 constantValue);
347 return success();
348 }
349};
350
351/// Converts math.powf to SPIRV-Ops.
352struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
353 using OpConversionPattern::OpConversionPattern;
354
355 LogicalResult
356 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
357 ConversionPatternRewriter &rewriter) const override {
358 if (LogicalResult res = checkSourceOpTypes(rewriter, sourceOp: powfOp); failed(Result: res))
359 return res;
360
361 Type dstType = getTypeConverter()->convertType(t: powfOp.getType());
362 if (!dstType)
363 return failure();
364
365 // Get the scalar float type.
366 FloatType scalarFloatType;
367 if (auto scalarType = dyn_cast<FloatType>(Val: powfOp.getType())) {
368 scalarFloatType = scalarType;
369 } else if (auto vectorType = dyn_cast<VectorType>(Val: powfOp.getType())) {
370 scalarFloatType = cast<FloatType>(Val: vectorType.getElementType());
371 } else {
372 return failure();
373 }
374
375 // Get int type of the same shape as the float type.
376 Type scalarIntType = rewriter.getIntegerType(width: 32);
377 Type intType = scalarIntType;
378 auto operandType = adaptor.getRhs().getType();
379 if (auto vectorType = dyn_cast<VectorType>(Val&: operandType)) {
380 auto shape = vectorType.getShape();
381 intType = VectorType::get(shape, elementType: scalarIntType);
382 }
383
384 // Per GL Pow extended instruction spec:
385 // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
386 Location loc = powfOp.getLoc();
387 Value zero = spirv::ConstantOp::getZero(type: operandType, loc, builder&: rewriter);
388 Value lessThan =
389 rewriter.create<spirv::FOrdLessThanOp>(location: loc, args: adaptor.getLhs(), args&: zero);
390
391 // Per C/C++ spec:
392 // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
393 // > finite and negative and exponent is finite and non-integer.
394 // Calculate the reminder from the exponent and check whether it is zero.
395 Value floatOne = spirv::ConstantOp::getOne(type: operandType, loc, builder&: rewriter);
396 Value expRem =
397 rewriter.create<spirv::FRemOp>(location: loc, args: adaptor.getRhs(), args&: floatOne);
398 Value expRemNonZero =
399 rewriter.create<spirv::FOrdNotEqualOp>(location: loc, args&: expRem, args&: zero);
400 Value cmpNegativeWithFractionalExp =
401 rewriter.create<spirv::LogicalAndOp>(location: loc, args&: expRemNonZero, args&: lessThan);
402 // Create NaN result and replace base value if conditions are met.
403 const auto &floatSemantics = scalarFloatType.getFloatSemantics();
404 const auto nan = APFloat::getNaN(Sem: floatSemantics);
405 Attribute nanAttr = rewriter.getFloatAttr(type: scalarFloatType, value: nan);
406 if (auto vectorType = dyn_cast<VectorType>(Val&: operandType))
407 nanAttr = DenseElementsAttr::get(type: vectorType, values: nan);
408
409 Value NanValue =
410 rewriter.create<spirv::ConstantOp>(location: loc, args&: operandType, args&: nanAttr);
411 Value lhs = rewriter.create<spirv::SelectOp>(
412 location: loc, args&: cmpNegativeWithFractionalExp, args&: NanValue, args: adaptor.getLhs());
413 Value abs = rewriter.create<spirv::GLFAbsOp>(location: loc, args&: lhs);
414
415 // TODO: The following just forcefully casts y into an integer value in
416 // order to properly propagate the sign, assuming integer y cases. It
417 // doesn't cover other cases and should be fixed.
418
419 // Cast exponent to integer and calculate exponent % 2 != 0.
420 Value intRhs =
421 rewriter.create<spirv::ConvertFToSOp>(location: loc, args&: intType, args: adaptor.getRhs());
422 Value intOne = spirv::ConstantOp::getOne(type: intType, loc, builder&: rewriter);
423 Value bitwiseAndOne =
424 rewriter.create<spirv::BitwiseAndOp>(location: loc, args&: intRhs, args&: intOne);
425 Value isOdd = rewriter.create<spirv::IEqualOp>(location: loc, args&: bitwiseAndOne, args&: intOne);
426
427 // calculate pow based on abs(lhs)^rhs.
428 Value pow = rewriter.create<spirv::GLPowOp>(location: loc, args&: abs, args: adaptor.getRhs());
429 Value negate = rewriter.create<spirv::FNegateOp>(location: loc, args&: pow);
430 // if the exponent is odd and lhs < 0, negate the result.
431 Value shouldNegate =
432 rewriter.create<spirv::LogicalAndOp>(location: loc, args&: lessThan, args&: isOdd);
433 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op: powfOp, args&: shouldNegate, args&: negate,
434 args&: pow);
435 return success();
436 }
437};
438
439/// Converts math.round to GLSL SPIRV extended ops.
440struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
441 using OpConversionPattern::OpConversionPattern;
442
443 LogicalResult
444 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
445 ConversionPatternRewriter &rewriter) const override {
446 if (LogicalResult res = checkSourceOpTypes(rewriter, sourceOp: roundOp); failed(Result: res))
447 return res;
448
449 Location loc = roundOp.getLoc();
450 Value operand = roundOp.getOperand();
451 Type ty = operand.getType();
452 Type ety = getElementTypeOrSelf(type: ty);
453
454 auto zero = spirv::ConstantOp::getZero(type: ty, loc, builder&: rewriter);
455 auto one = spirv::ConstantOp::getOne(type: ty, loc, builder&: rewriter);
456 Value half;
457 if (VectorType vty = dyn_cast<VectorType>(Val&: ty)) {
458 half = rewriter.create<spirv::ConstantOp>(
459 location: loc, args&: vty,
460 args: DenseElementsAttr::get(type: vty,
461 values: rewriter.getFloatAttr(type: ety, value: 0.5).getValue()));
462 } else {
463 half = rewriter.create<spirv::ConstantOp>(
464 location: loc, args&: ty, args: rewriter.getFloatAttr(type: ety, value: 0.5));
465 }
466
467 auto abs = rewriter.create<spirv::GLFAbsOp>(location: loc, args&: operand);
468 auto floor = rewriter.create<spirv::GLFloorOp>(location: loc, args&: abs);
469 auto sub = rewriter.create<spirv::FSubOp>(location: loc, args&: abs, args&: floor);
470 auto greater =
471 rewriter.create<spirv::FOrdGreaterThanEqualOp>(location: loc, args&: sub, args&: half);
472 auto select = rewriter.create<spirv::SelectOp>(location: loc, args&: greater, args&: one, args&: zero);
473 auto add = rewriter.create<spirv::FAddOp>(location: loc, args&: floor, args&: select);
474 rewriter.replaceOpWithNewOp<math::CopySignOp>(op: roundOp, args&: add, args&: operand);
475 return success();
476 }
477};
478
479} // namespace
480
481//===----------------------------------------------------------------------===//
482// Pattern population
483//===----------------------------------------------------------------------===//
484
485namespace mlir {
486void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
487 RewritePatternSet &patterns) {
488 // Core patterns
489 patterns.add<CopySignPattern>(arg: typeConverter, args: patterns.getContext());
490
491 // GLSL patterns
492 patterns
493 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
494 Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
495 Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
496 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
497 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
498 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
499 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
500 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
501 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
502 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
503 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
504 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
505 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
506 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
507 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
508 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
509 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
510 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
511 CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
512 CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
513 CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
514 CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
515 CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
516 CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
517 CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
518 CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
519 arg: typeConverter, args: patterns.getContext());
520
521 // OpenCL patterns
522 patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
523 Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
524 Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
525 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
526 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
527 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
528 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
529 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
530 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
531 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
532 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
533 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
534 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
535 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
536 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
537 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
538 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
539 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
540 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
541 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
542 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
543 CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
544 CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
545 CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
546 CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
547 CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
548 CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
549 CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
550 CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
551 arg: typeConverter, args: patterns.getContext());
552}
553
554} // namespace mlir
555

source code of mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp