//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements SPIR-V transforms used when targeting WebGPU.
//
//===----------------------------------------------------------------------===//
| 12 | |
| 13 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h" |
| 14 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 15 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h" |
| 16 | #include "mlir/IR/BuiltinAttributes.h" |
| 17 | #include "mlir/IR/Location.h" |
| 18 | #include "mlir/IR/PatternMatch.h" |
| 19 | #include "mlir/IR/TypeUtilities.h" |
| 20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 21 | #include "llvm/ADT/ArrayRef.h" |
| 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | #include "llvm/Support/FormatVariadic.h" |
| 24 | |
| 25 | #include <array> |
| 26 | #include <cstdint> |
| 27 | |
| 28 | namespace mlir { |
| 29 | namespace spirv { |
| 30 | #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS |
| 31 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" |
| 32 | } // namespace spirv |
| 33 | } // namespace mlir |
| 34 | |
| 35 | namespace mlir { |
| 36 | namespace spirv { |
| 37 | namespace { |
| 38 | //===----------------------------------------------------------------------===// |
| 39 | // Helpers |
| 40 | //===----------------------------------------------------------------------===// |
| 41 | static Attribute getScalarOrSplatAttr(Type type, int64_t value) { |
| 42 | APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value); |
| 43 | if (auto intTy = dyn_cast<IntegerType>(type)) |
| 44 | return IntegerAttr::get(intTy, sizedValue); |
| 45 | |
| 46 | return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue); |
| 47 | } |
| 48 | |
/// Expands a 32-bit extended multiplication into plain 32-bit SPIR-V
/// arithmetic and returns the value that replaces `mulOp`.
///
/// \param mulOp the op being lowered. Its location is used for every created
///        op, and its first result type is the type of the returned composite.
/// \param rewriter used to create all replacement ops.
/// \param lhs the first operand; its type is used for all intermediate
///        constants.
/// \param rhs the second operand.
/// \param signExtendArguments when true, the operands are treated as signed
///        (sign-extended to 64 bits). When false, they are treated as
///        unsigned (zero-extended).
/// \returns a `CompositeConstructOp` holding {low 32 bits, high 32 bits} of
///          the 64-bit product.
static Value lowerExtendedMultiplication(Operation *mulOp,
                                         PatternRewriter &rewriter, Value lhs,
                                         Value rhs, bool signExtendArguments) {
  Location loc = mulOp->getLoc();
  Type argTy = lhs.getType();
  // Emulate 64-bit multiplication by splitting each input element of type i32
  // into 2 16-bit digits of type i32. This is so that the intermediate
  // multiplications and additions do not overflow. We extract these 16-bit
  // digits from i32 vector elements by masking (low digit) and shifting right
  // (high digit).
  //
  // The multiplication algorithm used is the standard (long) multiplication.
  // Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
  // digits.
  // - With zero-extended arguments, we end up emitting only 4 multiplications
  //   and 4 additions after constant folding.
  // - With sign-extended arguments, we end up emitting 8 multiplications and
  //   12 additions after CSE.
  Value cstLowMask = rewriter.create<ConstantOp>(
      loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
  // Low digit: `val & 0xFFFF`.
  auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
    return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
  };

  Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                            getScalarOrSplatAttr(argTy, 16));
  // High digit: `val >> 16` using a logical shift, so the top bits are zero.
  auto getHighDigit = [&rewriter, loc, cst16](Value val) {
    return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
  };

  // Sign digit: 0xFFFF when `val` is negative, 0 otherwise. This equals digits
  // 2 and 3 of `val` sign-extended to 64 bits.
  auto getSignDigit = [&rewriter, loc, cst16, &getHighDigit](Value val) {
    // We only need to shift arithmetically by 15, but the extra
    // sign-extension bit will be truncated by the logical shift, so this is
    // fine. We do not have to introduce an extra constant since any
    // value in [15, 32) would do.
    return getHighDigit(
        rewriter.create<ShiftRightArithmeticOp>(loc, val, cst16));
  };

  Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
                                           getScalarOrSplatAttr(argTy, 0));

  Value lhsLow = getLowDigit(lhs);
  Value lhsHigh = getHighDigit(lhs);
  Value lhsExt = signExtendArguments ? getSignDigit(lhs) : cst0;
  Value rhsLow = getLowDigit(rhs);
  Value rhsHigh = getHighDigit(rhs);
  Value rhsExt = signExtendArguments ? getSignDigit(rhs) : cst0;

  // Base-2^16 digits (least significant first) of each operand extended to
  // 64 bits, plus the accumulator for the 4 digits of the product.
  std::array<Value, 4> lhsDigits = {lhsLow, lhsHigh, lhsExt, lhsExt};
  std::array<Value, 4> rhsDigits = {rhsLow, rhsHigh, rhsExt, rhsExt};
  std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};

  for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
    for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
      // Partial products landing at digit position >= 4 only contribute to
      // bits >= 64, so they are dropped.
      if (i + j >= resultDigits.size())
        continue;

      // Skip partial products known to be zero. In the zero-extended case the
      // extension digits are `cst0`.
      if (lhsDigit == cst0 || rhsDigit == cst0)
        continue;

      Value &thisResDigit = resultDigits[i + j];
      Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
      Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
      // Keep the low 16 bits in this position and carry the high 16 bits
      // into the next position (if any).
      thisResDigit = getLowDigit(current);

      if (i + j + 1 != resultDigits.size()) {
        Value &nextResDigit = resultDigits[i + j + 1];
        Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
                                                    getHighDigit(current));
        nextResDigit = carry;
      }
    }
  }

  // Reassemble two 16-bit digits into one 32-bit word: `low | (high << 16)`.
  auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
    Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
    return rewriter.create<BitwiseOrOp>(loc, low, highBits);
  };
  Value low = combineDigits(resultDigits[0], resultDigits[1]);
  Value high = combineDigits(resultDigits[2], resultDigits[3]);

  // Pack {low word, high word} into the op's (first) result type.
  return rewriter.create<CompositeConstructOp>(
      loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
}
| 134 | |
| 135 | //===----------------------------------------------------------------------===// |
| 136 | // Rewrite Patterns |
| 137 | //===----------------------------------------------------------------------===// |
| 138 | |
| 139 | template <typename MulExtendedOp, bool SignExtendArguments> |
| 140 | struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> { |
| 141 | using OpRewritePattern<MulExtendedOp>::OpRewritePattern; |
| 142 | |
| 143 | LogicalResult matchAndRewrite(MulExtendedOp op, |
| 144 | PatternRewriter &rewriter) const override { |
| 145 | Location loc = op->getLoc(); |
| 146 | Value lhs = op.getOperand1(); |
| 147 | Value rhs = op.getOperand2(); |
| 148 | |
| 149 | // Currently, WGSL only supports 32-bit integer types. Any other integer |
| 150 | // types should already have been promoted/demoted to i32. |
| 151 | auto elemTy = cast<IntegerType>(getElementTypeOrSelf(type: lhs.getType())); |
| 152 | if (elemTy.getIntOrFloatBitWidth() != 32) |
| 153 | return rewriter.notifyMatchFailure( |
| 154 | loc, |
| 155 | llvm::formatv("Unexpected integer type for WebGPU: '{0}'" , elemTy)); |
| 156 | |
| 157 | Value mul = lowerExtendedMultiplication(op, rewriter, lhs, rhs, |
| 158 | SignExtendArguments); |
| 159 | rewriter.replaceOp(op, mul); |
| 160 | return success(); |
| 161 | } |
| 162 | }; |
| 163 | |
| 164 | using ExpandSMulExtendedPattern = |
| 165 | ExpandMulExtendedPattern<SMulExtendedOp, true>; |
| 166 | using ExpandUMulExtendedPattern = |
| 167 | ExpandMulExtendedPattern<UMulExtendedOp, false>; |
| 168 | |
| 169 | struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> { |
| 170 | using OpRewritePattern<IAddCarryOp>::OpRewritePattern; |
| 171 | |
| 172 | LogicalResult matchAndRewrite(IAddCarryOp op, |
| 173 | PatternRewriter &rewriter) const override { |
| 174 | Location loc = op->getLoc(); |
| 175 | Value lhs = op.getOperand1(); |
| 176 | Value rhs = op.getOperand2(); |
| 177 | |
| 178 | // Currently, WGSL only supports 32-bit integer types. Any other integer |
| 179 | // types should already have been promoted/demoted to i32. |
| 180 | Type argTy = lhs.getType(); |
| 181 | auto elemTy = cast<IntegerType>(getElementTypeOrSelf(type: argTy)); |
| 182 | if (elemTy.getIntOrFloatBitWidth() != 32) |
| 183 | return rewriter.notifyMatchFailure( |
| 184 | loc, |
| 185 | llvm::formatv("Unexpected integer type for WebGPU: '{0}'" , elemTy)); |
| 186 | |
| 187 | Value one = |
| 188 | rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1)); |
| 189 | Value zero = |
| 190 | rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0)); |
| 191 | |
| 192 | // Calculate the carry by checking if the addition resulted in an overflow. |
| 193 | Value out = rewriter.create<IAddOp>(loc, lhs, rhs); |
| 194 | Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs); |
| 195 | Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero); |
| 196 | |
| 197 | Value add = rewriter.create<CompositeConstructOp>( |
| 198 | loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry})); |
| 199 | |
| 200 | rewriter.replaceOp(op, add); |
| 201 | return success(); |
| 202 | } |
| 203 | }; |
| 204 | |
| 205 | struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> { |
| 206 | using OpRewritePattern::OpRewritePattern; |
| 207 | |
| 208 | LogicalResult matchAndRewrite(IsInfOp op, |
| 209 | PatternRewriter &rewriter) const override { |
| 210 | // We assume values to be finite and turn `IsInf` info `false`. |
| 211 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>( |
| 212 | op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); |
| 213 | return success(); |
| 214 | } |
| 215 | }; |
| 216 | |
| 217 | struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> { |
| 218 | using OpRewritePattern::OpRewritePattern; |
| 219 | |
| 220 | LogicalResult matchAndRewrite(IsNanOp op, |
| 221 | PatternRewriter &rewriter) const override { |
| 222 | // We assume values to be finite and turn `IsNan` info `false`. |
| 223 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>( |
| 224 | op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); |
| 225 | return success(); |
| 226 | } |
| 227 | }; |
| 228 | |
| 229 | //===----------------------------------------------------------------------===// |
| 230 | // Passes |
| 231 | //===----------------------------------------------------------------------===// |
| 232 | struct WebGPUPreparePass final |
| 233 | : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> { |
| 234 | void runOnOperation() override { |
| 235 | RewritePatternSet patterns(&getContext()); |
| 236 | populateSPIRVExpandExtendedMultiplicationPatterns(patterns); |
| 237 | populateSPIRVExpandNonFiniteArithmeticPatterns(patterns); |
| 238 | |
| 239 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
| 240 | signalPassFailure(); |
| 241 | } |
| 242 | }; |
| 243 | } // namespace |
| 244 | |
| 245 | //===----------------------------------------------------------------------===// |
| 246 | // Public Interface |
| 247 | //===----------------------------------------------------------------------===// |
| 248 | void populateSPIRVExpandExtendedMultiplicationPatterns( |
| 249 | RewritePatternSet &patterns) { |
| 250 | // WGSL currently does not support extended multiplication ops, see: |
| 251 | // https://github.com/gpuweb/gpuweb/issues/1565. |
| 252 | patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern, |
| 253 | ExpandAddCarryPattern>(patterns.getContext()); |
| 254 | } |
| 255 | |
| 256 | void populateSPIRVExpandNonFiniteArithmeticPatterns( |
| 257 | RewritePatternSet &patterns) { |
| 258 | // WGSL currently does not support `isInf` and `isNan`, see: |
| 259 | // https://github.com/gpuweb/gpuweb/pull/2311. |
| 260 | patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext()); |
| 261 | } |
| 262 | |
| 263 | } // namespace spirv |
| 264 | } // namespace mlir |
| 265 | |