1//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Math dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "../SPIRVCommon/Pattern.h"
14#include "mlir/Dialect/Math/IR/Math.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/TypeUtilities.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/FormatVariadic.h"
24
25#define DEBUG_TYPE "math-to-spirv-pattern"
26
27using namespace mlir;
28
29//===----------------------------------------------------------------------===//
30// Utility functions
31//===----------------------------------------------------------------------===//
32
33/// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
34/// given type is not a 32-bit scalar/vector type.
35static Value getScalarOrVectorI32Constant(Type type, int value,
36 OpBuilder &builder, Location loc) {
37 if (auto vectorType = dyn_cast<VectorType>(type)) {
38 if (!vectorType.getElementType().isInteger(32))
39 return nullptr;
40 SmallVector<int> values(vectorType.getNumElements(), value);
41 return builder.create<spirv::ConstantOp>(loc, type,
42 builder.getI32VectorAttr(values));
43 }
44 if (type.isInteger(32))
45 return builder.create<spirv::ConstantOp>(loc, type,
46 builder.getI32IntegerAttr(value));
47
48 return nullptr;
49}
50
51/// Check if the type is supported by math-to-spirv conversion. We expect to
52/// only see scalars and vectors at this point, with higher-level types already
53/// lowered.
54static bool isSupportedSourceType(Type originalType) {
55 if (originalType.isIntOrIndexOrFloat())
56 return true;
57
58 if (auto vecTy = dyn_cast<VectorType>(originalType)) {
59 if (!vecTy.getElementType().isIntOrIndexOrFloat())
60 return false;
61 if (vecTy.isScalable())
62 return false;
63 if (vecTy.getRank() > 1)
64 return false;
65
66 return true;
67 }
68
69 return false;
70}
71
72/// Check if all `sourceOp` types are supported by math-to-spirv conversion.
73/// Notify of a match failure othwerise and return a `failure` result.
74/// This is intended to simplify type checks in `OpConversionPattern`s.
75static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
76 Operation *sourceOp) {
77 auto allTypes = llvm::to_vector(Range: sourceOp->getOperandTypes());
78 llvm::append_range(C&: allTypes, R: sourceOp->getResultTypes());
79
80 for (Type ty : allTypes) {
81 if (!isSupportedSourceType(originalType: ty)) {
82 return rewriter.notifyMatchFailure(
83 arg&: sourceOp,
84 msg: llvm::formatv(
85 Fmt: "unsupported source type for Math to SPIR-V conversion: {0}",
86 Vals&: ty));
87 }
88 }
89
90 return success();
91}
92
93//===----------------------------------------------------------------------===//
94// Operation conversion
95//===----------------------------------------------------------------------===//
96
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// generates plain RewritePatterns.
100
101namespace {
102/// Converts elementwise unary, binary, and ternary standard operations to
103/// SPIR-V operations. Checks that source `Op` types are supported.
104template <typename Op, typename SPIRVOp>
105struct CheckedElementwiseOpPattern final
106 : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
107 using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
108 using BasePattern::BasePattern;
109
110 LogicalResult
111 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
112 ConversionPatternRewriter &rewriter) const override {
113 if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(result: res))
114 return res;
115
116 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
117 }
118};
119
120/// Converts math.copysign to SPIR-V ops.
121struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
122 using OpConversionPattern::OpConversionPattern;
123
124 LogicalResult
125 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
126 ConversionPatternRewriter &rewriter) const override {
127 if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
128 failed(result: res))
129 return res;
130
131 Type type = getTypeConverter()->convertType(copySignOp.getType());
132 if (!type)
133 return failure();
134
135 FloatType floatType;
136 if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137 floatType = scalarType;
138 } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139 floatType = cast<FloatType>(vectorType.getElementType());
140 } else {
141 return failure();
142 }
143
144 Location loc = copySignOp.getLoc();
145 int bitwidth = floatType.getWidth();
146 Type intType = rewriter.getIntegerType(bitwidth);
147 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
148
149 Value signMask = rewriter.create<spirv::ConstantOp>(
150 loc, intType, rewriter.getIntegerAttr(intType, intValue));
151 Value valueMask = rewriter.create<spirv::ConstantOp>(
152 loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
153
154 if (auto vectorType = dyn_cast<VectorType>(type)) {
155 assert(vectorType.getRank() == 1);
156 int count = vectorType.getNumElements();
157 intType = VectorType::get(count, intType);
158
159 SmallVector<Value> signSplat(count, signMask);
160 signMask =
161 rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
162
163 SmallVector<Value> valueSplat(count, valueMask);
164 valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
165 valueSplat);
166 }
167
168 Value lhsCast =
169 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
170 Value rhsCast =
171 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
172
173 Value value = rewriter.create<spirv::BitwiseAndOp>(
174 loc, intType, ValueRange{lhsCast, valueMask});
175 Value sign = rewriter.create<spirv::BitwiseAndOp>(
176 loc, intType, ValueRange{rhsCast, signMask});
177
178 Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
179 ValueRange{value, sign});
180 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
181 return success();
182 }
183};
184
185/// Converts math.ctlz to SPIR-V ops.
186///
187/// SPIR-V does not have a direct operations for counting leading zeros. If
188/// Shader capability is supported, we can leverage GL FindUMsb to calculate
189/// it.
190struct CountLeadingZerosPattern final
191 : public OpConversionPattern<math::CountLeadingZerosOp> {
192 using OpConversionPattern::OpConversionPattern;
193
194 LogicalResult
195 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
196 ConversionPatternRewriter &rewriter) const override {
197 if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(result: res))
198 return res;
199
200 Type type = getTypeConverter()->convertType(countOp.getType());
201 if (!type)
202 return failure();
203
204 // We can only support 32-bit integer types for now.
205 unsigned bitwidth = 0;
206 if (isa<IntegerType>(Val: type))
207 bitwidth = type.getIntOrFloatBitWidth();
208 if (auto vectorType = dyn_cast<VectorType>(type))
209 bitwidth = vectorType.getElementTypeBitWidth();
210 if (bitwidth != 32)
211 return failure();
212
213 Location loc = countOp.getLoc();
214 Value input = adaptor.getOperand();
215 Value val1 = getScalarOrVectorI32Constant(type, value: 1, builder&: rewriter, loc);
216 Value val31 = getScalarOrVectorI32Constant(type, value: 31, builder&: rewriter, loc);
217 Value val32 = getScalarOrVectorI32Constant(type, value: 32, builder&: rewriter, loc);
218
219 Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
220 // We need to subtract from 31 given that the index returned by GLSL
221 // FindUMsb is counted from the least significant bit. Theoretically this
222 // also gives the correct result even if the integer has all zero bits, in
223 // which case GL FindUMsb would return -1.
224 Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
225 // However, certain Vulkan implementations have driver bugs for the corner
226 // case where the input is zero. And.. it can be smart to optimize a select
227 // only involving the corner case. So separately compute the result when the
228 // input is either zero or one.
229 Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
230 Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
231 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
232 subMsb);
233 return success();
234 }
235};
236
237/// Converts math.expm1 to SPIR-V ops.
238///
239/// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
240/// these operations.
241template <typename ExpOp>
242struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
243 using OpConversionPattern::OpConversionPattern;
244
245 LogicalResult
246 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
247 ConversionPatternRewriter &rewriter) const override {
248 assert(adaptor.getOperands().size() == 1);
249 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
250 failed(result: res))
251 return res;
252
253 Location loc = operation.getLoc();
254 Type type = this->getTypeConverter()->convertType(operation.getType());
255 if (!type)
256 return failure();
257
258 Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
259 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
260 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
261 return success();
262 }
263};
264
265/// Converts math.log1p to SPIR-V ops.
266///
267/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
268/// these operations.
269template <typename LogOp>
270struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
271 using OpConversionPattern::OpConversionPattern;
272
273 LogicalResult
274 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const override {
276 assert(adaptor.getOperands().size() == 1);
277 if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
278 failed(result: res))
279 return res;
280
281 Location loc = operation.getLoc();
282 Type type = this->getTypeConverter()->convertType(operation.getType());
283 if (!type)
284 return failure();
285
286 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
287 Value onePlus =
288 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
289 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
290 return success();
291 }
292};
293
294/// Converts math.powf to SPIRV-Ops.
295struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
296 using OpConversionPattern::OpConversionPattern;
297
298 LogicalResult
299 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
300 ConversionPatternRewriter &rewriter) const override {
301 if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(result: res))
302 return res;
303
304 Type dstType = getTypeConverter()->convertType(powfOp.getType());
305 if (!dstType)
306 return failure();
307
308 // Get the scalar float type.
309 FloatType scalarFloatType;
310 if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
311 scalarFloatType = scalarType;
312 } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
313 scalarFloatType = cast<FloatType>(vectorType.getElementType());
314 } else {
315 return failure();
316 }
317
318 // Get int type of the same shape as the float type.
319 Type scalarIntType = rewriter.getIntegerType(32);
320 Type intType = scalarIntType;
321 if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
322 auto shape = vectorType.getShape();
323 intType = VectorType::get(shape, scalarIntType);
324 }
325
326 // Per GL Pow extended instruction spec:
327 // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
328 Location loc = powfOp.getLoc();
329 Value zero =
330 spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
331 Value lessThan =
332 rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
333 Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
334
335 // TODO: The following just forcefully casts y into an integer value in
336 // order to properly propagate the sign, assuming integer y cases. It
337 // doesn't cover other cases and should be fixed.
338
339 // Cast exponent to integer and calculate exponent % 2 != 0.
340 Value intRhs =
341 rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
342 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
343 Value bitwiseAndOne =
344 rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
345 Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
346
347 // calculate pow based on abs(lhs)^rhs.
348 Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
349 Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
350 // if the exponent is odd and lhs < 0, negate the result.
351 Value shouldNegate =
352 rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
353 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
354 pow);
355 return success();
356 }
357};
358
359/// Converts math.round to GLSL SPIRV extended ops.
360struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
361 using OpConversionPattern::OpConversionPattern;
362
363 LogicalResult
364 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
365 ConversionPatternRewriter &rewriter) const override {
366 if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(result: res))
367 return res;
368
369 Location loc = roundOp.getLoc();
370 Value operand = roundOp.getOperand();
371 Type ty = operand.getType();
372 Type ety = getElementTypeOrSelf(type: ty);
373
374 auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
375 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
376 Value half;
377 if (VectorType vty = dyn_cast<VectorType>(ty)) {
378 half = rewriter.create<spirv::ConstantOp>(
379 loc, vty,
380 DenseElementsAttr::get(vty,
381 rewriter.getFloatAttr(ety, 0.5).getValue()));
382 } else {
383 half = rewriter.create<spirv::ConstantOp>(
384 loc, ty, rewriter.getFloatAttr(ety, 0.5));
385 }
386
387 auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
388 auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
389 auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
390 auto greater =
391 rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
392 auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
393 auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
394 rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
395 return success();
396 }
397};
398
399} // namespace
400
401//===----------------------------------------------------------------------===//
402// Pattern population
403//===----------------------------------------------------------------------===//
404
405namespace mlir {
406void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
407 RewritePatternSet &patterns) {
408 // Core patterns
409 patterns.add<CopySignPattern>(arg&: typeConverter, args: patterns.getContext());
410
411 // GLSL patterns
412 patterns
413 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
414 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
415 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
416 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
417 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
418 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
419 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
420 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
421 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
422 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
423 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
424 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
425 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
426 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
427 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
428 typeConverter, patterns.getContext());
429
430 // OpenCL patterns
431 patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
432 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
433 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
434 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
435 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
436 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
437 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
438 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
439 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
440 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
441 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
442 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
443 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
444 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
445 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
446 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
447 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
448 typeConverter, patterns.getContext());
449}
450
451} // namespace mlir
452

source code of mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp