1//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements SPIR-V transforms used when targeting WebGPU.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/Location.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/TypeUtilities.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/FormatVariadic.h"
24
25#include <array>
26#include <cstdint>
27
28namespace mlir {
29namespace spirv {
30#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
31#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32} // namespace spirv
33} // namespace mlir
34
35namespace mlir {
36namespace spirv {
37namespace {
38//===----------------------------------------------------------------------===//
39// Helpers
40//===----------------------------------------------------------------------===//
41static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
42 APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
43 if (auto intTy = dyn_cast<IntegerType>(type))
44 return IntegerAttr::get(intTy, sizedValue);
45
46 return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
47}
48
49static Value lowerExtendedMultiplication(Operation *mulOp,
50 PatternRewriter &rewriter, Value lhs,
51 Value rhs, bool signExtendArguments) {
52 Location loc = mulOp->getLoc();
53 Type argTy = lhs.getType();
54 // Emulate 64-bit multiplication by splitting each input element of type i32
55 // into 2 16-bit digits of type i32. This is so that the intermediate
56 // multiplications and additions do not overflow. We extract these 16-bit
57 // digits from i32 vector elements by masking (low digit) and shifting right
58 // (high digit).
59 //
60 // The multiplication algorithm used is the standard (long) multiplication.
61 // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
62 // digits.
63 // - With zero-extended arguments, we end up emitting only 4 multiplications
64 // and 4 additions after constant folding.
65 // - With sign-extended arguments, we end up emitting 8 multiplications and
66 // and 12 additions after CSE.
67 Value cstLowMask = rewriter.create<ConstantOp>(
68 loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
69 auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
70 return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
71 };
72
73 Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
74 getScalarOrSplatAttr(argTy, 16));
75 auto getHighDigit = [&rewriter, loc, cst16](Value val) {
76 return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
77 };
78
79 auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
80 // We only need to shift arithmetically by 15, but the extra
81 // sign-extension bit will be truncated by the logical shift, so this is
82 // fine. We do not have to introduce an extra constant since any
83 // value in [15, 32) would do.
84 return getHighDigit(
85 rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
86 };
87
88 Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
89 getScalarOrSplatAttr(argTy, 0));
90
91 Value lhsLow = getLowDigit(lhs);
92 Value lhsHigh = getHighDigit(lhs);
93 Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
94 Value rhsLow = getLowDigit(rhs);
95 Value rhsHigh = getHighDigit(rhs);
96 Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;
97
98 std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
99 std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
100 std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
101
102 for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
103 for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
104 if (i + j >= resultDigits.size())
105 continue;
106
107 if (lhsDigit == cst0 || rhsDigit == cst0)
108 continue;
109
110 Value &thisResDigit = resultDigits[i + j];
111 Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
112 Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
113 thisResDigit = getLowDigit(current);
114
115 if (i + j + 1 != resultDigits.size()) {
116 Value &nextResDigit = resultDigits[i + j + 1];
117 Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
118 getHighDigit(current));
119 nextResDigit = carry;
120 }
121 }
122 }
123
124 auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
125 Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
126 return rewriter.create<BitwiseOrOp>(loc, low, highBits);
127 };
128 Value low = combineDigits(resultDigits[0], resultDigits[1]);
129 Value high = combineDigits(resultDigits[2], resultDigits[3]);
130
131 return rewriter.create<CompositeConstructOp>(
132 loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
133}
134
135//===----------------------------------------------------------------------===//
136// Rewrite Patterns
137//===----------------------------------------------------------------------===//
138
139template <typename MulExtendedOp, bool SignExtendArguments>
140struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
141 using OpRewritePattern<MulExtendedOp>::OpRewritePattern;
142
143 LogicalResult matchAndRewrite(MulExtendedOp op,
144 PatternRewriter &rewriter) const override {
145 Location loc = op->getLoc();
146 Value lhs = op.getOperand1();
147 Value rhs = op.getOperand2();
148
149 // Currently, WGSL only supports 32-bit integer types. Any other integer
150 // types should already have been promoted/demoted to i32.
151 auto elemTy = cast<IntegerType>(getElementTypeOrSelf(type: lhs.getType()));
152 if (elemTy.getIntOrFloatBitWidth() != 32)
153 return rewriter.notifyMatchFailure(
154 loc,
155 llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
156
157 Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs,
158 SignExtendArguments);
159 rewriter.replaceOp(op, mul);
160 return success();
161 }
162};
163
164using ExpandSMulExtendedPattern =
165 ExpandMulExtendedPattern<SMulExtendedOp, true>;
166using ExpandUMulExtendedPattern =
167 ExpandMulExtendedPattern<UMulExtendedOp, false>;
168
169struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
170 using OpRewritePattern<IAddCarryOp>::OpRewritePattern;
171
172 LogicalResult matchAndRewrite(IAddCarryOp op,
173 PatternRewriter &rewriter) const override {
174 Location loc = op->getLoc();
175 Value lhs = op.getOperand1();
176 Value rhs = op.getOperand2();
177
178 // Currently, WGSL only supports 32-bit integer types. Any other integer
179 // types should already have been promoted/demoted to i32.
180 Type argTy = lhs.getType();
181 auto elemTy = cast<IntegerType>(getElementTypeOrSelf(type: argTy));
182 if (elemTy.getIntOrFloatBitWidth() != 32)
183 return rewriter.notifyMatchFailure(
184 loc,
185 llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
186
187 Value one =
188 rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
189 Value zero =
190 rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
191
192 // Calculate the carry by checking if the addition resulted in an overflow.
193 Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
194 Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
195 Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
196
197 Value add = rewriter.create<CompositeConstructOp>(
198 loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
199
200 rewriter.replaceOp(op, add);
201 return success();
202 }
203};
204
205struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
206 using OpRewritePattern::OpRewritePattern;
207
208 LogicalResult matchAndRewrite(IsInfOp op,
209 PatternRewriter &rewriter) const override {
210 // We assume values to be finite and turn `IsInf` info `false`.
211 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
212 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
213 return success();
214 }
215};
216
217struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
218 using OpRewritePattern::OpRewritePattern;
219
220 LogicalResult matchAndRewrite(IsNanOp op,
221 PatternRewriter &rewriter) const override {
222 // We assume values to be finite and turn `IsNan` info `false`.
223 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
224 op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
225 return success();
226 }
227};
228
229//===----------------------------------------------------------------------===//
230// Passes
231//===----------------------------------------------------------------------===//
232struct WebGPUPreparePass final
233 : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
234 void runOnOperation() override {
235 RewritePatternSet patterns(&getContext());
236 populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
237 populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
238
239 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
240 signalPassFailure();
241 }
242};
243} // namespace
244
245//===----------------------------------------------------------------------===//
246// Public Interface
247//===----------------------------------------------------------------------===//
248void populateSPIRVExpandExtendedMultiplicationPatterns(
249 RewritePatternSet &patterns) {
250 // WGSL currently does not support extended multiplication ops, see:
251 // https://github.com/gpuweb/gpuweb/issues/1565.
252 patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
253 ExpandAddCarryPattern>(patterns.getContext());
254}
255
256void populateSPIRVExpandNonFiniteArithmeticPatterns(
257 RewritePatternSet &patterns) {
258 // WGSL currently does not support `isInf` and `isNan`, see:
259 // https://github.com/gpuweb/gpuweb/pull/2311.
260 patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
261}
262
263} // namespace spirv
264} // namespace mlir
265

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp