1 | //===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements SPIR-V transforms used when targeting WebGPU.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h" |
14 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
15 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h" |
16 | #include "mlir/IR/BuiltinAttributes.h" |
17 | #include "mlir/IR/Location.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/IR/TypeUtilities.h" |
20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
21 | #include "llvm/ADT/ArrayRef.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/Support/FormatVariadic.h" |
24 | |
25 | #include <array> |
26 | #include <cstdint> |
27 | |
28 | namespace mlir { |
29 | namespace spirv { |
30 | #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS |
31 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" |
32 | } // namespace spirv |
33 | } // namespace mlir |
34 | |
35 | namespace mlir { |
36 | namespace spirv { |
37 | namespace { |
38 | //===----------------------------------------------------------------------===// |
39 | // Helpers |
40 | //===----------------------------------------------------------------------===// |
41 | static Attribute getScalarOrSplatAttr(Type type, int64_t value) { |
42 | APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value); |
43 | if (auto intTy = dyn_cast<IntegerType>(type)) |
44 | return IntegerAttr::get(intTy, sizedValue); |
45 | |
46 | return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue); |
47 | } |
48 | |
49 | static Value lowerExtendedMultiplication(Operation *mulOp, |
50 | PatternRewriter &rewriter, Value lhs, |
51 | Value rhs, bool signExtendArguments) { |
52 | Location loc = mulOp->getLoc(); |
53 | Type argTy = lhs.getType(); |
54 | // Emulate 64-bit multiplication by splitting each input element of type i32 |
55 | // into 2 16-bit digits of type i32. This is so that the intermediate |
56 | // multiplications and additions do not overflow. We extract these 16-bit |
57 | // digits from i32 vector elements by masking (low digit) and shifting right |
58 | // (high digit). |
59 | // |
60 | // The multiplication algorithm used is the standard (long) multiplication. |
61 | // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit |
62 | // digits. |
63 | // - With zero-extended arguments, we end up emitting only 4 multiplications |
64 | // and 4 additions after constant folding. |
65 | // - With sign-extended arguments, we end up emitting 8 multiplications and |
66 | // and 12 additions after CSE. |
67 | Value cstLowMask = rewriter.create<ConstantOp>( |
68 | loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1)); |
69 | auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) { |
70 | return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask); |
71 | }; |
72 | |
73 | Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(), |
74 | getScalarOrSplatAttr(argTy, 16)); |
75 | auto getHighDigit = [&rewriter, loc, cst16](Value val) { |
76 | return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16); |
77 | }; |
78 | |
79 | auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) { |
80 | // We only need to shift arithmetically by 15, but the extra |
81 | // sign-extension bit will be truncated by the logical shift, so this is |
82 | // fine. We do not have to introduce an extra constant since any |
83 | // value in [15, 32) would do. |
84 | return getHighDigit( |
85 | rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16)); |
86 | }; |
87 | |
88 | Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(), |
89 | getScalarOrSplatAttr(argTy, 0)); |
90 | |
91 | Value lhsLow = getLowDigit(lhs); |
92 | Value lhsHigh = getHighDigit(lhs); |
93 | Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0; |
94 | Value rhsLow = getLowDigit(rhs); |
95 | Value rhsHigh = getHighDigit(rhs); |
96 | Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0; |
97 | |
98 | std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt}; |
99 | std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt}; |
100 | std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0}; |
101 | |
102 | for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) { |
103 | for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) { |
104 | if (i + j >= resultDigits.size()) |
105 | continue; |
106 | |
107 | if (lhsDigit == cst0 || rhsDigit == cst0) |
108 | continue; |
109 | |
110 | Value &thisResDigit = resultDigits[i + j]; |
111 | Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit); |
112 | Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul); |
113 | thisResDigit = getLowDigit(current); |
114 | |
115 | if (i + j + 1 != resultDigits.size()) { |
116 | Value &nextResDigit = resultDigits[i + j + 1]; |
117 | Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit, |
118 | getHighDigit(current)); |
119 | nextResDigit = carry; |
120 | } |
121 | } |
122 | } |
123 | |
124 | auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) { |
125 | Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16); |
126 | return rewriter.create<BitwiseOrOp>(loc, low, highBits); |
127 | }; |
128 | Value low = combineDigits(resultDigits[0], resultDigits[1]); |
129 | Value high = combineDigits(resultDigits[2], resultDigits[3]); |
130 | |
131 | return rewriter.create<CompositeConstructOp>( |
132 | loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high})); |
133 | } |
134 | |
135 | //===----------------------------------------------------------------------===// |
136 | // Rewrite Patterns |
137 | //===----------------------------------------------------------------------===// |
138 | |
139 | template <typename MulExtendedOp, bool SignExtendArguments> |
140 | struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> { |
141 | using OpRewritePattern<MulExtendedOp>::OpRewritePattern; |
142 | |
143 | LogicalResult matchAndRewrite(MulExtendedOp op, |
144 | PatternRewriter &rewriter) const override { |
145 | Location loc = op->getLoc(); |
146 | Value lhs = op.getOperand1(); |
147 | Value rhs = op.getOperand2(); |
148 | |
149 | // Currently, WGSL only supports 32-bit integer types. Any other integer |
150 | // types should already have been promoted/demoted to i32. |
151 | auto elemTy = cast<IntegerType>(getElementTypeOrSelf(type: lhs.getType())); |
152 | if (elemTy.getIntOrFloatBitWidth() != 32) |
153 | return rewriter.notifyMatchFailure( |
154 | loc, |
155 | llvm::formatv("Unexpected integer type for WebGPU: '{0}'" , elemTy)); |
156 | |
157 | Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs, |
158 | SignExtendArguments); |
159 | rewriter.replaceOp(op, mul); |
160 | return success(); |
161 | } |
162 | }; |
163 | |
164 | using ExpandSMulExtendedPattern = |
165 | ExpandMulExtendedPattern<SMulExtendedOp, true>; |
166 | using ExpandUMulExtendedPattern = |
167 | ExpandMulExtendedPattern<UMulExtendedOp, false>; |
168 | |
169 | struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> { |
170 | using OpRewritePattern<IAddCarryOp>::OpRewritePattern; |
171 | |
172 | LogicalResult matchAndRewrite(IAddCarryOp op, |
173 | PatternRewriter &rewriter) const override { |
174 | Location loc = op->getLoc(); |
175 | Value lhs = op.getOperand1(); |
176 | Value rhs = op.getOperand2(); |
177 | |
178 | // Currently, WGSL only supports 32-bit integer types. Any other integer |
179 | // types should already have been promoted/demoted to i32. |
180 | Type argTy = lhs.getType(); |
181 | auto elemTy = cast<IntegerType>(getElementTypeOrSelf(type: argTy)); |
182 | if (elemTy.getIntOrFloatBitWidth() != 32) |
183 | return rewriter.notifyMatchFailure( |
184 | loc, |
185 | llvm::formatv("Unexpected integer type for WebGPU: '{0}'" , elemTy)); |
186 | |
187 | Value one = |
188 | rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1)); |
189 | Value zero = |
190 | rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0)); |
191 | |
192 | // Calculate the carry by checking if the addition resulted in an overflow. |
193 | Value out = rewriter.create<IAddOp>(loc, lhs, rhs); |
194 | Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs); |
195 | Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero); |
196 | |
197 | Value add = rewriter.create<CompositeConstructOp>( |
198 | loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry})); |
199 | |
200 | rewriter.replaceOp(op, add); |
201 | return success(); |
202 | } |
203 | }; |
204 | |
205 | struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> { |
206 | using OpRewritePattern::OpRewritePattern; |
207 | |
208 | LogicalResult matchAndRewrite(IsInfOp op, |
209 | PatternRewriter &rewriter) const override { |
210 | // We assume values to be finite and turn `IsInf` info `false`. |
211 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>( |
212 | op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); |
213 | return success(); |
214 | } |
215 | }; |
216 | |
217 | struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> { |
218 | using OpRewritePattern::OpRewritePattern; |
219 | |
220 | LogicalResult matchAndRewrite(IsNanOp op, |
221 | PatternRewriter &rewriter) const override { |
222 | // We assume values to be finite and turn `IsNan` info `false`. |
223 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>( |
224 | op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); |
225 | return success(); |
226 | } |
227 | }; |
228 | |
229 | //===----------------------------------------------------------------------===// |
230 | // Passes |
231 | //===----------------------------------------------------------------------===// |
232 | struct WebGPUPreparePass final |
233 | : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> { |
234 | void runOnOperation() override { |
235 | RewritePatternSet patterns(&getContext()); |
236 | populateSPIRVExpandExtendedMultiplicationPatterns(patterns); |
237 | populateSPIRVExpandNonFiniteArithmeticPatterns(patterns); |
238 | |
239 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
240 | signalPassFailure(); |
241 | } |
242 | }; |
243 | } // namespace |
244 | |
245 | //===----------------------------------------------------------------------===// |
246 | // Public Interface |
247 | //===----------------------------------------------------------------------===// |
248 | void populateSPIRVExpandExtendedMultiplicationPatterns( |
249 | RewritePatternSet &patterns) { |
250 | // WGSL currently does not support extended multiplication ops, see: |
251 | // https://github.com/gpuweb/gpuweb/issues/1565. |
252 | patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern, |
253 | ExpandAddCarryPattern>(patterns.getContext()); |
254 | } |
255 | |
256 | void populateSPIRVExpandNonFiniteArithmeticPatterns( |
257 | RewritePatternSet &patterns) { |
258 | // WGSL currently does not support `isInf` and `isNan`, see: |
259 | // https://github.com/gpuweb/gpuweb/pull/2311. |
260 | patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext()); |
261 | } |
262 | |
263 | } // namespace spirv |
264 | } // namespace mlir |
265 | |