1//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the folders and canonicalization patterns for SPIR-V ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include <optional>
14#include <utility>
15
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17
18#include "mlir/Dialect/CommonFolders.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20#include "mlir/Dialect/UB/IR/UBOps.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/PatternMatch.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVectorExtras.h"
25
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// Common utility functions
30//===----------------------------------------------------------------------===//
31
32/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
33/// or splat vector bool constant.
34static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
35 if (!attr)
36 return std::nullopt;
37
38 if (auto boolAttr = llvm::dyn_cast<BoolAttr>(Val&: attr))
39 return boolAttr.getValue();
40 if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(Val&: attr))
41 if (splatAttr.getElementType().isInteger(width: 1))
42 return splatAttr.getSplatValue<bool>();
43 return std::nullopt;
44}
45
46// Extracts an element from the given `composite` by following the given
47// `indices`. Returns a null Attribute if error happens.
48static Attribute extractCompositeElement(Attribute composite,
49 ArrayRef<unsigned> indices) {
50 // Check that given composite is a constant.
51 if (!composite)
52 return {};
53 // Return composite itself if we reach the end of the index chain.
54 if (indices.empty())
55 return composite;
56
57 if (auto vector = llvm::dyn_cast<ElementsAttr>(Val&: composite)) {
58 assert(indices.size() == 1 && "must have exactly one index for a vector");
59 return vector.getValues<Attribute>()[indices[0]];
60 }
61
62 if (auto array = llvm::dyn_cast<ArrayAttr>(Val&: composite)) {
63 assert(!indices.empty() && "must have at least one index for an array");
64 return extractCompositeElement(composite: array.getValue()[indices[0]],
65 indices: indices.drop_front());
66 }
67
68 return {};
69}
70
71static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
72 bool div0 = b.isZero();
73 bool overflow = a.isMinSignedValue() && b.isAllOnes();
74
75 return div0 || overflow;
76}
77
78//===----------------------------------------------------------------------===//
79// TableGen'erated canonicalizers
80//===----------------------------------------------------------------------===//
81
82namespace {
83#include "SPIRVCanonicalization.inc"
84} // namespace
85
86//===----------------------------------------------------------------------===//
87// spirv.AccessChainOp
88//===----------------------------------------------------------------------===//
89
90namespace {
91
92/// Combines chained `spirv::AccessChainOp` operations into one
93/// `spirv::AccessChainOp` operation.
94struct CombineChainedAccessChain final
95 : OpRewritePattern<spirv::AccessChainOp> {
96 using OpRewritePattern::OpRewritePattern;
97
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter) const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
102
103 if (!parentAccessChainOp) {
104 return failure();
105 }
106
107 // Combine indices.
108 SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
109 llvm::append_range(C&: indices, R: accessChainOp.getIndices());
110
111 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
112 op: accessChainOp, args: parentAccessChainOp.getBasePtr(), args&: indices);
113
114 return success();
115 }
116};
117} // namespace
118
119void spirv::AccessChainOp::getCanonicalizationPatterns(
120 RewritePatternSet &results, MLIRContext *context) {
121 results.add<CombineChainedAccessChain>(arg&: context);
122}
123
124//===----------------------------------------------------------------------===//
125// spirv.IAddCarry
126//===----------------------------------------------------------------------===//
127
128// We are required to use CompositeConstructOp to create a constant struct as
129// they are not yet implemented as constant, hence we can not do so in a fold.
130struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
131 using OpRewritePattern::OpRewritePattern;
132
133 LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
134 PatternRewriter &rewriter) const override {
135 Location loc = op.getLoc();
136 Value lhs = op.getOperand1();
137 Value rhs = op.getOperand2();
138 Type constituentType = lhs.getType();
139
140 // iaddcarry (x, 0) = <0, x>
141 if (matchPattern(value: rhs, pattern: m_Zero())) {
142 Value constituents[2] = {rhs, lhs};
143 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, args: op.getType(),
144 args&: constituents);
145 return success();
146 }
147
148 // According to the SPIR-V spec:
149 //
150 // Result Type must be from OpTypeStruct. The struct must have two
151 // members...
152 //
153 // Member 0 of the result gets the low-order bits (full component width) of
154 // the addition.
155 //
156 // Member 1 of the result gets the high-order (carry) bit of the result of
157 // the addition. That is, it gets the value 1 if the addition overflowed
158 // the component width, and 0 otherwise.
159 Attribute lhsAttr;
160 Attribute rhsAttr;
161 if (!matchPattern(value: lhs, pattern: m_Constant(bind_value: &lhsAttr)) ||
162 !matchPattern(value: rhs, pattern: m_Constant(bind_value: &rhsAttr)))
163 return failure();
164
165 auto adds = constFoldBinaryOp<IntegerAttr>(
166 operands: {lhsAttr, rhsAttr},
167 calculate: [](const APInt &a, const APInt &b) { return a + b; });
168 if (!adds)
169 return failure();
170
171 auto carrys = constFoldBinaryOp<IntegerAttr>(
172 operands: ArrayRef{adds, lhsAttr}, calculate: [](const APInt &a, const APInt &b) {
173 APInt zero = APInt::getZero(numBits: a.getBitWidth());
174 return a.ult(RHS: b) ? (zero + 1) : zero;
175 });
176
177 if (!carrys)
178 return failure();
179
180 Value addsVal =
181 rewriter.create<spirv::ConstantOp>(location: loc, args&: constituentType, args&: adds);
182
183 Value carrysVal =
184 rewriter.create<spirv::ConstantOp>(location: loc, args&: constituentType, args&: carrys);
185
186 // Create empty struct
187 Value undef = rewriter.create<spirv::UndefOp>(location: loc, args: op.getType());
188 // Fill in adds at id 0
189 Value intermediate =
190 rewriter.create<spirv::CompositeInsertOp>(location: loc, args&: addsVal, args&: undef, args: 0);
191 // Fill in carrys at id 1
192 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, args&: carrysVal,
193 args&: intermediate, args: 1);
194 return success();
195 }
196};
197
198void spirv::IAddCarryOp::getCanonicalizationPatterns(
199 RewritePatternSet &patterns, MLIRContext *context) {
200 patterns.add<IAddCarryFold>(arg&: context);
201}
202
203//===----------------------------------------------------------------------===//
204// spirv.[S|U]MulExtended
205//===----------------------------------------------------------------------===//
206
207// We are required to use CompositeConstructOp to create a constant struct as
208// they are not yet implemented as constant, hence we can not do so in a fold.
209template <typename MulOp, bool IsSigned>
210struct MulExtendedFold final : OpRewritePattern<MulOp> {
211 using OpRewritePattern<MulOp>::OpRewritePattern;
212
213 LogicalResult matchAndRewrite(MulOp op,
214 PatternRewriter &rewriter) const override {
215 Location loc = op.getLoc();
216 Value lhs = op.getOperand1();
217 Value rhs = op.getOperand2();
218 Type constituentType = lhs.getType();
219
220 // [su]mulextended (x, 0) = <0, 0>
221 if (matchPattern(value: rhs, pattern: m_Zero())) {
222 Value zero = spirv::ConstantOp::getZero(type: constituentType, loc, builder&: rewriter);
223 Value constituents[2] = {zero, zero};
224 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
225 constituents);
226 return success();
227 }
228
229 // According to the SPIR-V spec:
230 //
231 // Result Type must be from OpTypeStruct. The struct must have two
232 // members...
233 //
234 // Member 0 of the result gets the low-order bits of the multiplication.
235 //
236 // Member 1 of the result gets the high-order bits of the multiplication.
237 Attribute lhsAttr;
238 Attribute rhsAttr;
239 if (!matchPattern(value: lhs, pattern: m_Constant(bind_value: &lhsAttr)) ||
240 !matchPattern(value: rhs, pattern: m_Constant(bind_value: &rhsAttr)))
241 return failure();
242
243 auto lowBits = constFoldBinaryOp<IntegerAttr>(
244 {lhsAttr, rhsAttr},
245 [](const APInt &a, const APInt &b) { return a * b; });
246
247 if (!lowBits)
248 return failure();
249
250 auto highBits = constFoldBinaryOp<IntegerAttr>(
251 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
252 if (IsSigned) {
253 return llvm::APIntOps::mulhs(C1: a, C2: b);
254 } else {
255 return llvm::APIntOps::mulhu(C1: a, C2: b);
256 }
257 });
258
259 if (!highBits)
260 return failure();
261
262 Value lowBitsVal =
263 rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
264
265 Value highBitsVal =
266 rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
267
268 // Create empty struct
269 Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
270 // Fill in lowBits at id 0
271 Value intermediate =
272 rewriter.create<spirv::CompositeInsertOp>(location: loc, args&: lowBitsVal, args&: undef, args: 0);
273 // Fill in highBits at id 1
274 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
275 intermediate, 1);
276 return success();
277 }
278};
279
280using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
281void spirv::SMulExtendedOp::getCanonicalizationPatterns(
282 RewritePatternSet &patterns, MLIRContext *context) {
283 patterns.add<SMulExtendedOpFold>(arg&: context);
284}
285
286struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
287 using OpRewritePattern::OpRewritePattern;
288
289 LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
290 PatternRewriter &rewriter) const override {
291 Location loc = op.getLoc();
292 Value lhs = op.getOperand1();
293 Value rhs = op.getOperand2();
294 Type constituentType = lhs.getType();
295
296 // umulextended (x, 1) = <x, 0>
297 if (matchPattern(value: rhs, pattern: m_One())) {
298 Value zero = spirv::ConstantOp::getZero(type: constituentType, loc, builder&: rewriter);
299 Value constituents[2] = {lhs, zero};
300 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, args: op.getType(),
301 args&: constituents);
302 return success();
303 }
304
305 return failure();
306 }
307};
308
309using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
310void spirv::UMulExtendedOp::getCanonicalizationPatterns(
311 RewritePatternSet &patterns, MLIRContext *context) {
312 patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(arg&: context);
313}
314
315//===----------------------------------------------------------------------===//
316// spirv.UMod
317//===----------------------------------------------------------------------===//
318
319// Input:
320// %0 = spirv.UMod %arg0, %const32 : i32
321// %1 = spirv.UMod %0, %const4 : i32
322// Output:
323// %0 = spirv.UMod %arg0, %const32 : i32
324// %1 = spirv.UMod %arg0, %const4 : i32
325
326// The transformation is only applied if one divisor is a multiple of the other.
327
328struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
329 using OpRewritePattern::OpRewritePattern;
330
331 LogicalResult matchAndRewrite(spirv::UModOp umodOp,
332 PatternRewriter &rewriter) const override {
333 auto prevUMod = umodOp.getOperand(i: 0).getDefiningOp<spirv::UModOp>();
334 if (!prevUMod)
335 return failure();
336
337 TypedAttr prevValue;
338 TypedAttr currValue;
339 if (!matchPattern(value: prevUMod.getOperand(i: 1), pattern: m_Constant(bind_value: &prevValue)) ||
340 !matchPattern(value: umodOp.getOperand(i: 1), pattern: m_Constant(bind_value: &currValue)))
341 return failure();
342
343 // Ensure that previous divisor is a multiple of the current divisor. If
344 // not, fail the transformation.
345 bool isApplicable = false;
346 if (auto prevInt = dyn_cast<IntegerAttr>(Val&: prevValue)) {
347 auto currInt = cast<IntegerAttr>(Val&: currValue);
348 isApplicable = prevInt.getValue().urem(RHS: currInt.getValue()) == 0;
349 } else if (auto prevVec = dyn_cast<DenseElementsAttr>(Val&: prevValue)) {
350 auto currVec = cast<DenseElementsAttr>(Val&: currValue);
351 isApplicable = llvm::all_of(Range: llvm::zip_equal(t: prevVec.getValues<APInt>(),
352 u: currVec.getValues<APInt>()),
353 P: [](const auto &pair) {
354 auto &[prev, curr] = pair;
355 return prev.urem(curr) == 0;
356 });
357 }
358
359 if (!isApplicable)
360 return failure();
361
362 // The transformation is safe. Replace the existing UMod operation with a
363 // new UMod operation, using the original dividend and the current divisor.
364 rewriter.replaceOpWithNewOp<spirv::UModOp>(
365 op: umodOp, args: umodOp.getType(), args: prevUMod.getOperand(i: 0), args: umodOp.getOperand(i: 1));
366
367 return success();
368 }
369};
370
371void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
372 MLIRContext *context) {
373 patterns.insert<UModSimplification>(arg&: context);
374}
375
376//===----------------------------------------------------------------------===//
377// spirv.BitcastOp
378//===----------------------------------------------------------------------===//
379
380OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
381 Value curInput = getOperand();
382 if (getType() == curInput.getType())
383 return curInput;
384
385 // Look through nested bitcasts.
386 if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
387 Value prevInput = prevCast.getOperand();
388 if (prevInput.getType() == getType())
389 return prevInput;
390
391 getOperandMutable().assign(value: prevInput);
392 return getResult();
393 }
394
395 // TODO(kuhar): Consider constant-folding the operand attribute.
396 return {};
397}
398
399//===----------------------------------------------------------------------===//
400// spirv.CompositeExtractOp
401//===----------------------------------------------------------------------===//
402
403OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
404 Value compositeOp = getComposite();
405
406 while (auto insertOp =
407 compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
408 if (getIndices() == insertOp.getIndices())
409 return insertOp.getObject();
410 compositeOp = insertOp.getComposite();
411 }
412
413 if (auto constructOp =
414 compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
415 auto type = llvm::cast<spirv::CompositeType>(Val: constructOp.getType());
416 if (getIndices().size() == 1 &&
417 constructOp.getConstituents().size() == type.getNumElements()) {
418 auto i = llvm::cast<IntegerAttr>(Val: *getIndices().begin());
419 if (i.getValue().getSExtValue() <
420 static_cast<int64_t>(constructOp.getConstituents().size()))
421 return constructOp.getConstituents()[i.getValue().getSExtValue()];
422 }
423 }
424
425 auto indexVector = llvm::map_to_vector(C: getIndices(), F: [](Attribute attr) {
426 return static_cast<unsigned>(llvm::cast<IntegerAttr>(Val&: attr).getInt());
427 });
428 return extractCompositeElement(composite: adaptor.getComposite(), indices: indexVector);
429}
430
431//===----------------------------------------------------------------------===//
432// spirv.Constant
433//===----------------------------------------------------------------------===//
434
435OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
436 return getValue();
437}
438
439//===----------------------------------------------------------------------===//
440// spirv.IAdd
441//===----------------------------------------------------------------------===//
442
443OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
444 // x + 0 = x
445 if (matchPattern(value: getOperand2(), pattern: m_Zero()))
446 return getOperand1();
447
448 // According to the SPIR-V spec:
449 //
450 // The resulting value will equal the low-order N bits of the correct result
451 // R, where N is the component width and R is computed with enough precision
452 // to avoid overflow and underflow.
453 return constFoldBinaryOp<IntegerAttr>(
454 operands: adaptor.getOperands(),
455 calculate: [](APInt a, const APInt &b) { return std::move(a) + b; });
456}
457
458//===----------------------------------------------------------------------===//
459// spirv.IMul
460//===----------------------------------------------------------------------===//
461
462OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
463 // x * 0 == 0
464 if (matchPattern(value: getOperand2(), pattern: m_Zero()))
465 return getOperand2();
466 // x * 1 = x
467 if (matchPattern(value: getOperand2(), pattern: m_One()))
468 return getOperand1();
469
470 // According to the SPIR-V spec:
471 //
472 // The resulting value will equal the low-order N bits of the correct result
473 // R, where N is the component width and R is computed with enough precision
474 // to avoid overflow and underflow.
475 return constFoldBinaryOp<IntegerAttr>(
476 operands: adaptor.getOperands(),
477 calculate: [](const APInt &a, const APInt &b) { return a * b; });
478}
479
480//===----------------------------------------------------------------------===//
481// spirv.ISub
482//===----------------------------------------------------------------------===//
483
484OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
485 // x - x = 0
486 if (getOperand1() == getOperand2())
487 return Builder(getContext()).getZeroAttr(type: getType());
488
489 // According to the SPIR-V spec:
490 //
491 // The resulting value will equal the low-order N bits of the correct result
492 // R, where N is the component width and R is computed with enough precision
493 // to avoid overflow and underflow.
494 return constFoldBinaryOp<IntegerAttr>(
495 operands: adaptor.getOperands(),
496 calculate: [](APInt a, const APInt &b) { return std::move(a) - b; });
497}
498
499//===----------------------------------------------------------------------===//
500// spirv.SDiv
501//===----------------------------------------------------------------------===//
502
503OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
504 // sdiv (x, 1) = x
505 if (matchPattern(value: getOperand2(), pattern: m_One()))
506 return getOperand1();
507
508 // According to the SPIR-V spec:
509 //
510 // Signed-integer division of Operand 1 divided by Operand 2.
511 // Results are computed per component. Behavior is undefined if Operand 2 is
512 // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
513 // representable value for the operands' type, causing signed overflow.
514 //
515 // So don't fold during undefined behavior.
516 bool div0OrOverflow = false;
517 auto res = constFoldBinaryOp<IntegerAttr>(
518 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
519 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
520 div0OrOverflow = true;
521 return a;
522 }
523 return a.sdiv(RHS: b);
524 });
525 return div0OrOverflow ? Attribute() : res;
526}
527
528//===----------------------------------------------------------------------===//
529// spirv.SMod
530//===----------------------------------------------------------------------===//
531
532OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
533 // smod (x, 1) = 0
534 if (matchPattern(value: getOperand2(), pattern: m_One()))
535 return Builder(getContext()).getZeroAttr(type: getType());
536
537 // According to SPIR-V spec:
538 //
539 // Signed remainder operation for the remainder whose sign matches the sign
540 // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
541 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
542 // value for the operands' type, causing signed overflow. Otherwise, the
543 // result is the remainder r of Operand 1 divided by Operand 2 where if
544 // r ≠ 0, the sign of r is the same as the sign of Operand 2.
545 //
546 // So don't fold during undefined behavior
547 bool div0OrOverflow = false;
548 auto res = constFoldBinaryOp<IntegerAttr>(
549 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
550 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
551 div0OrOverflow = true;
552 return a;
553 }
554 APInt c = a.abs().urem(RHS: b.abs());
555 if (c.isZero())
556 return c;
557 if (b.isNegative()) {
558 APInt zero = APInt::getZero(numBits: c.getBitWidth());
559 return a.isNegative() ? (zero - c) : (b + c);
560 }
561 return a.isNegative() ? (b - c) : c;
562 });
563 return div0OrOverflow ? Attribute() : res;
564}
565
566//===----------------------------------------------------------------------===//
567// spirv.SRem
568//===----------------------------------------------------------------------===//
569
570OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
571 // x % 1 = 0
572 if (matchPattern(value: getOperand2(), pattern: m_One()))
573 return Builder(getContext()).getZeroAttr(type: getType());
574
575 // According to SPIR-V spec:
576 //
577 // Signed remainder operation for the remainder whose sign matches the sign
578 // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
579 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
580 // value for the operands' type, causing signed overflow. Otherwise, the
581 // result is the remainder r of Operand 1 divided by Operand 2 where if
582 // r ≠ 0, the sign of r is the same as the sign of Operand 1.
583
584 // Don't fold if it would do undefined behavior.
585 bool div0OrOverflow = false;
586 auto res = constFoldBinaryOp<IntegerAttr>(
587 operands: adaptor.getOperands(), calculate: [&](APInt a, const APInt &b) {
588 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
589 div0OrOverflow = true;
590 return a;
591 }
592 return a.srem(RHS: b);
593 });
594 return div0OrOverflow ? Attribute() : res;
595}
596
597//===----------------------------------------------------------------------===//
598// spirv.UDiv
599//===----------------------------------------------------------------------===//
600
601OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
602 // udiv (x, 1) = x
603 if (matchPattern(value: getOperand2(), pattern: m_One()))
604 return getOperand1();
605
606 // According to the SPIR-V spec:
607 //
608 // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
609 // undefined if Operand 2 is 0.
610 //
611 // So don't fold during undefined behavior.
612 bool div0 = false;
613 auto res = constFoldBinaryOp<IntegerAttr>(
614 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
615 if (div0 || b.isZero()) {
616 div0 = true;
617 return a;
618 }
619 return a.udiv(RHS: b);
620 });
621 return div0 ? Attribute() : res;
622}
623
624//===----------------------------------------------------------------------===//
625// spirv.UMod
626//===----------------------------------------------------------------------===//
627
628OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
629 // umod (x, 1) = 0
630 if (matchPattern(value: getOperand2(), pattern: m_One()))
631 return Builder(getContext()).getZeroAttr(type: getType());
632
633 // According to the SPIR-V spec:
634 //
635 // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
636 // undefined if Operand 2 is 0.
637 //
638 // So don't fold during undefined behavior.
639 bool div0 = false;
640 auto res = constFoldBinaryOp<IntegerAttr>(
641 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
642 if (div0 || b.isZero()) {
643 div0 = true;
644 return a;
645 }
646 return a.urem(RHS: b);
647 });
648 return div0 ? Attribute() : res;
649}
650
651//===----------------------------------------------------------------------===//
652// spirv.SNegate
653//===----------------------------------------------------------------------===//
654
655OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
656 // -(-x) = 0 - (0 - x) = x
657 auto op = getOperand();
658 if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
659 return negateOp->getOperand(idx: 0);
660
661 // According to the SPIR-V spec:
662 //
663 // Signed-integer subtract of Operand from zero.
664 return constFoldUnaryOp<IntegerAttr>(
665 operands: adaptor.getOperands(), calculate: [](const APInt &a) {
666 APInt zero = APInt::getZero(numBits: a.getBitWidth());
667 return zero - a;
668 });
669}
670
671//===----------------------------------------------------------------------===//
672// spirv.NotOp
673//===----------------------------------------------------------------------===//
674
675OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
676 // !(!x) = x
677 auto op = getOperand();
678 if (auto notOp = op.getDefiningOp<spirv::NotOp>())
679 return notOp->getOperand(idx: 0);
680
681 // According to the SPIR-V spec:
682 //
683 // Complement the bits of Operand.
684 return constFoldUnaryOp<IntegerAttr>(operands: adaptor.getOperands(), calculate: [&](APInt a) {
685 a.flipAllBits();
686 return a;
687 });
688}
689
690//===----------------------------------------------------------------------===//
691// spirv.LogicalAnd
692//===----------------------------------------------------------------------===//
693
694OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
695 if (std::optional<bool> rhs =
696 getScalarOrSplatBoolAttr(attr: adaptor.getOperand2())) {
697 // x && true = x
698 if (*rhs)
699 return getOperand1();
700
701 // x && false = false
702 if (!*rhs)
703 return adaptor.getOperand2();
704 }
705
706 return Attribute();
707}
708
709//===----------------------------------------------------------------------===//
710// spirv.LogicalEqualOp
711//===----------------------------------------------------------------------===//
712
713OpFoldResult
714spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
715 // x == x -> true
716 if (getOperand1() == getOperand2()) {
717 auto trueAttr = BoolAttr::get(context: getContext(), value: true);
718 if (isa<IntegerType>(Val: getType()))
719 return trueAttr;
720 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
721 return SplatElementsAttr::get(type: vecTy, values: trueAttr);
722 }
723
724 return constFoldBinaryOp<IntegerAttr>(
725 operands: adaptor.getOperands(), calculate: [](const APInt &a, const APInt &b) {
726 return a == b ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
727 });
728}
729
730//===----------------------------------------------------------------------===//
731// spirv.LogicalNotEqualOp
732//===----------------------------------------------------------------------===//
733
734OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
735 if (std::optional<bool> rhs =
736 getScalarOrSplatBoolAttr(attr: adaptor.getOperand2())) {
737 // x != false -> x
738 if (!rhs.value())
739 return getOperand1();
740 }
741
742 // x == x -> false
743 if (getOperand1() == getOperand2()) {
744 auto falseAttr = BoolAttr::get(context: getContext(), value: false);
745 if (isa<IntegerType>(Val: getType()))
746 return falseAttr;
747 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
748 return SplatElementsAttr::get(type: vecTy, values: falseAttr);
749 }
750
751 return constFoldBinaryOp<IntegerAttr>(
752 operands: adaptor.getOperands(), calculate: [](const APInt &a, const APInt &b) {
753 return a == b ? APInt::getZero(numBits: 1) : APInt::getAllOnes(numBits: 1);
754 });
755}
756
757//===----------------------------------------------------------------------===//
758// spirv.LogicalNot
759//===----------------------------------------------------------------------===//
760
761OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
762 // !(!x) = x
763 auto op = getOperand();
764 if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
765 return notOp->getOperand(idx: 0);
766
767 // According to the SPIR-V spec:
768 //
769 // Complement the bits of Operand.
770 return constFoldUnaryOp<IntegerAttr>(operands: adaptor.getOperands(),
771 calculate: [](const APInt &a) {
772 APInt zero = APInt::getZero(numBits: 1);
773 return a == 1 ? zero : (zero + 1);
774 });
775}
776
777void spirv::LogicalNotOp::getCanonicalizationPatterns(
778 RewritePatternSet &results, MLIRContext *context) {
779 results
780 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
781 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
782 arg&: context);
783}
784
785//===----------------------------------------------------------------------===//
786// spirv.LogicalOr
787//===----------------------------------------------------------------------===//
788
789OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
790 if (auto rhs = getScalarOrSplatBoolAttr(attr: adaptor.getOperand2())) {
791 if (*rhs) {
792 // x || true = true
793 return adaptor.getOperand2();
794 }
795
796 if (!*rhs) {
797 // x || false = x
798 return getOperand1();
799 }
800 }
801
802 return Attribute();
803}
804
805//===----------------------------------------------------------------------===//
806// spirv.SelectOp
807//===----------------------------------------------------------------------===//
808
809OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
810 // spirv.Select _ x x -> x
811 Value trueVals = getTrueValue();
812 Value falseVals = getFalseValue();
813 if (trueVals == falseVals)
814 return trueVals;
815
816 ArrayRef<Attribute> operands = adaptor.getOperands();
817
818 // spirv.Select true x y -> x
819 // spirv.Select false x y -> y
820 if (auto boolAttr = getScalarOrSplatBoolAttr(attr: operands[0]))
821 return *boolAttr ? trueVals : falseVals;
822
823 // Check that all the operands are constant
824 if (!operands[0] || !operands[1] || !operands[2])
825 return Attribute();
826
827 // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
828 // the scalar case. Hence, we are only required to consider the case of
829 // DenseElementsAttr in foldSelectOp.
830 auto condAttrs = dyn_cast<DenseElementsAttr>(Val: operands[0]);
831 auto trueAttrs = dyn_cast<DenseElementsAttr>(Val: operands[1]);
832 auto falseAttrs = dyn_cast<DenseElementsAttr>(Val: operands[2]);
833 if (!condAttrs || !trueAttrs || !falseAttrs)
834 return Attribute();
835
836 auto elementResults = llvm::to_vector<4>(Range: trueAttrs.getValues<Attribute>());
837 auto iters = llvm::zip_equal(t&: elementResults, u: condAttrs.getValues<BoolAttr>(),
838 args: falseAttrs.getValues<Attribute>());
839 for (auto [result, cond, falseRes] : iters) {
840 if (!cond.getValue())
841 result = falseRes;
842 }
843
844 auto resultType = trueAttrs.getType();
845 return DenseElementsAttr::get(type: cast<ShapedType>(Val&: resultType), values: elementResults);
846}
847
848//===----------------------------------------------------------------------===//
849// spirv.IEqualOp
850//===----------------------------------------------------------------------===//
851
852OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
853 // x == x -> true
854 if (getOperand1() == getOperand2()) {
855 auto trueAttr = BoolAttr::get(context: getContext(), value: true);
856 if (isa<IntegerType>(Val: getType()))
857 return trueAttr;
858 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
859 return SplatElementsAttr::get(type: vecTy, values: trueAttr);
860 }
861
862 return constFoldBinaryOp<IntegerAttr>(
863 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
864 return a == b ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
865 });
866}
867
868//===----------------------------------------------------------------------===//
869// spirv.INotEqualOp
870//===----------------------------------------------------------------------===//
871
872OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
873 // x == x -> false
874 if (getOperand1() == getOperand2()) {
875 auto falseAttr = BoolAttr::get(context: getContext(), value: false);
876 if (isa<IntegerType>(Val: getType()))
877 return falseAttr;
878 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
879 return SplatElementsAttr::get(type: vecTy, values: falseAttr);
880 }
881
882 return constFoldBinaryOp<IntegerAttr>(
883 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
884 return a == b ? APInt::getZero(numBits: 1) : APInt::getAllOnes(numBits: 1);
885 });
886}
887
888//===----------------------------------------------------------------------===//
889// spirv.SGreaterThan
890//===----------------------------------------------------------------------===//
891
892OpFoldResult
893spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
894 // x == x -> false
895 if (getOperand1() == getOperand2()) {
896 auto falseAttr = BoolAttr::get(context: getContext(), value: false);
897 if (isa<IntegerType>(Val: getType()))
898 return falseAttr;
899 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
900 return SplatElementsAttr::get(type: vecTy, values: falseAttr);
901 }
902
903 return constFoldBinaryOp<IntegerAttr>(
904 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
905 return a.sgt(RHS: b) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
906 });
907}
908
909//===----------------------------------------------------------------------===//
910// spirv.SGreaterThanEqual
911//===----------------------------------------------------------------------===//
912
913OpFoldResult spirv::SGreaterThanEqualOp::fold(
914 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
915 // x == x -> true
916 if (getOperand1() == getOperand2()) {
917 auto trueAttr = BoolAttr::get(context: getContext(), value: true);
918 if (isa<IntegerType>(Val: getType()))
919 return trueAttr;
920 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
921 return SplatElementsAttr::get(type: vecTy, values: trueAttr);
922 }
923
924 return constFoldBinaryOp<IntegerAttr>(
925 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
926 return a.sge(RHS: b) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
927 });
928}
929
930//===----------------------------------------------------------------------===//
931// spirv.UGreaterThan
932//===----------------------------------------------------------------------===//
933
934OpFoldResult
935spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
936 // x == x -> false
937 if (getOperand1() == getOperand2()) {
938 auto falseAttr = BoolAttr::get(context: getContext(), value: false);
939 if (isa<IntegerType>(Val: getType()))
940 return falseAttr;
941 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
942 return SplatElementsAttr::get(type: vecTy, values: falseAttr);
943 }
944
945 return constFoldBinaryOp<IntegerAttr>(
946 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
947 return a.ugt(RHS: b) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
948 });
949}
950
951//===----------------------------------------------------------------------===//
952// spirv.UGreaterThanEqual
953//===----------------------------------------------------------------------===//
954
955OpFoldResult spirv::UGreaterThanEqualOp::fold(
956 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
957 // x == x -> true
958 if (getOperand1() == getOperand2()) {
959 auto trueAttr = BoolAttr::get(context: getContext(), value: true);
960 if (isa<IntegerType>(Val: getType()))
961 return trueAttr;
962 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
963 return SplatElementsAttr::get(type: vecTy, values: trueAttr);
964 }
965
966 return constFoldBinaryOp<IntegerAttr>(
967 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
968 return a.uge(RHS: b) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
969 });
970}
971
972//===----------------------------------------------------------------------===//
973// spirv.SLessThan
974//===----------------------------------------------------------------------===//
975
976OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
977 // x == x -> false
978 if (getOperand1() == getOperand2()) {
979 auto falseAttr = BoolAttr::get(context: getContext(), value: false);
980 if (isa<IntegerType>(Val: getType()))
981 return falseAttr;
982 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
983 return SplatElementsAttr::get(type: vecTy, values: falseAttr);
984 }
985
986 return constFoldBinaryOp<IntegerAttr>(
987 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
988 return a.slt(RHS: b) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
989 });
990}
991
992//===----------------------------------------------------------------------===//
993// spirv.SLessThanEqual
994//===----------------------------------------------------------------------===//
995
996OpFoldResult
997spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
998 // x == x -> true
999 if (getOperand1() == getOperand2()) {
1000 auto trueAttr = BoolAttr::get(context: getContext(), value: true);
1001 if (isa<IntegerType>(Val: getType()))
1002 return trueAttr;
1003 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
1004 return SplatElementsAttr::get(type: vecTy, values: trueAttr);
1005 }
1006
1007 return constFoldBinaryOp<IntegerAttr>(
1008 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
1009 return a.sle(RHS: b) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
1010 });
1011}
1012
1013//===----------------------------------------------------------------------===//
1014// spirv.ULessThan
1015//===----------------------------------------------------------------------===//
1016
1017OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1018 // x == x -> false
1019 if (getOperand1() == getOperand2()) {
1020 auto falseAttr = BoolAttr::get(context: getContext(), value: false);
1021 if (isa<IntegerType>(Val: getType()))
1022 return falseAttr;
1023 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
1024 return SplatElementsAttr::get(type: vecTy, values: falseAttr);
1025 }
1026
1027 return constFoldBinaryOp<IntegerAttr>(
1028 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
1029 return a.ult(RHS: b) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
1030 });
1031}
1032
1033//===----------------------------------------------------------------------===//
1034// spirv.ULessThanEqual
1035//===----------------------------------------------------------------------===//
1036
1037OpFoldResult
1038spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1039 // x == x -> true
1040 if (getOperand1() == getOperand2()) {
1041 auto trueAttr = BoolAttr::get(context: getContext(), value: true);
1042 if (isa<IntegerType>(Val: getType()))
1043 return trueAttr;
1044 if (auto vecTy = dyn_cast<VectorType>(Val: getType()))
1045 return SplatElementsAttr::get(type: vecTy, values: trueAttr);
1046 }
1047
1048 return constFoldBinaryOp<IntegerAttr>(
1049 operands: adaptor.getOperands(), resultType: getType(), calculate: [](const APInt &a, const APInt &b) {
1050 return a.ule(RHS: b) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
1051 });
1052}
1053
1054//===----------------------------------------------------------------------===//
1055// spirv.ShiftLeftLogical
1056//===----------------------------------------------------------------------===//
1057
1058OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1059 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1060 // x << 0 -> x
1061 if (matchPattern(attr: adaptor.getOperand2(), pattern: m_Zero())) {
1062 return getOperand1();
1063 }
1064
1065 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1066
1067 // Results are computed per component, and within each component, per bit...
1068 //
1069 // The result is undefined if Shift is greater than or equal to the bit width
1070 // of the components of Base.
1071 //
1072 // So we can use the APInt << method, but don't fold if undefined behaviour.
1073 bool shiftToLarge = false;
1074 auto res = constFoldBinaryOp<IntegerAttr>(
1075 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
1076 if (shiftToLarge || b.uge(RHS: a.getBitWidth())) {
1077 shiftToLarge = true;
1078 return a;
1079 }
1080 return a << b;
1081 });
1082 return shiftToLarge ? Attribute() : res;
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// spirv.ShiftRightArithmetic
1087//===----------------------------------------------------------------------===//
1088
1089OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1090 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1091 // x >> 0 -> x
1092 if (matchPattern(attr: adaptor.getOperand2(), pattern: m_Zero())) {
1093 return getOperand1();
1094 }
1095
1096 // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1097
1098 // Results are computed per component, and within each component, per bit...
1099 //
1100 // The result is undefined if Shift is greater than or equal to the bit width
1101 // of the components of Base.
1102 //
1103 // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1104 bool shiftToLarge = false;
1105 auto res = constFoldBinaryOp<IntegerAttr>(
1106 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
1107 if (shiftToLarge || b.uge(RHS: a.getBitWidth())) {
1108 shiftToLarge = true;
1109 return a;
1110 }
1111 return a.ashr(ShiftAmt: b);
1112 });
1113 return shiftToLarge ? Attribute() : res;
1114}
1115
1116//===----------------------------------------------------------------------===//
1117// spirv.ShiftRightLogical
1118//===----------------------------------------------------------------------===//
1119
1120OpFoldResult spirv::ShiftRightLogicalOp::fold(
1121 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1122 // x >> 0 -> x
1123 if (matchPattern(attr: adaptor.getOperand2(), pattern: m_Zero())) {
1124 return getOperand1();
1125 }
1126
1127 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1128
1129 // Results are computed per component, and within each component, per bit...
1130 //
1131 // The result is undefined if Shift is greater than or equal to the bit width
1132 // of the components of Base.
1133 //
1134 // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1135 bool shiftToLarge = false;
1136 auto res = constFoldBinaryOp<IntegerAttr>(
1137 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
1138 if (shiftToLarge || b.uge(RHS: a.getBitWidth())) {
1139 shiftToLarge = true;
1140 return a;
1141 }
1142 return a.lshr(ShiftAmt: b);
1143 });
1144 return shiftToLarge ? Attribute() : res;
1145}
1146
1147//===----------------------------------------------------------------------===//
1148// spirv.BitwiseAndOp
1149//===----------------------------------------------------------------------===//
1150
1151OpFoldResult
1152spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1153 // x & x -> x
1154 if (getOperand1() == getOperand2()) {
1155 return getOperand1();
1156 }
1157
1158 APInt rhsMask;
1159 if (matchPattern(attr: adaptor.getOperand2(), pattern: m_ConstantInt(bind_value: &rhsMask))) {
1160 // x & 0 -> 0
1161 if (rhsMask.isZero())
1162 return getOperand2();
1163
1164 // x & <all ones> -> x
1165 if (rhsMask.isAllOnes())
1166 return getOperand1();
1167
1168 // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1169 if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1170 int valueBits =
1171 getElementTypeOrSelf(val: zext.getOperand()).getIntOrFloatBitWidth();
1172 if (rhsMask.zextOrTrunc(width: valueBits).isAllOnes())
1173 return getOperand1();
1174 }
1175 }
1176
1177 // According to the SPIR-V spec:
1178 //
1179 // Type is a scalar or vector of integer type.
1180 // Results are computed per component, and within each component, per bit.
1181 // So we can use the APInt & method.
1182 return constFoldBinaryOp<IntegerAttr>(
1183 operands: adaptor.getOperands(),
1184 calculate: [](const APInt &a, const APInt &b) { return a & b; });
1185}
1186
1187//===----------------------------------------------------------------------===//
1188// spirv.BitwiseOrOp
1189//===----------------------------------------------------------------------===//
1190
1191OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1192 // x | x -> x
1193 if (getOperand1() == getOperand2()) {
1194 return getOperand1();
1195 }
1196
1197 APInt rhsMask;
1198 if (matchPattern(attr: adaptor.getOperand2(), pattern: m_ConstantInt(bind_value: &rhsMask))) {
1199 // x | 0 -> x
1200 if (rhsMask.isZero())
1201 return getOperand1();
1202
1203 // x | <all ones> -> <all ones>
1204 if (rhsMask.isAllOnes())
1205 return getOperand2();
1206 }
1207
1208 // According to the SPIR-V spec:
1209 //
1210 // Type is a scalar or vector of integer type.
1211 // Results are computed per component, and within each component, per bit.
1212 // So we can use the APInt | method.
1213 return constFoldBinaryOp<IntegerAttr>(
1214 operands: adaptor.getOperands(),
1215 calculate: [](const APInt &a, const APInt &b) { return a | b; });
1216}
1217
1218//===----------------------------------------------------------------------===//
1219// spirv.BitwiseXorOp
1220//===----------------------------------------------------------------------===//
1221
1222OpFoldResult
1223spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1224 // x ^ 0 -> x
1225 if (matchPattern(attr: adaptor.getOperand2(), pattern: m_Zero())) {
1226 return getOperand1();
1227 }
1228
1229 // x ^ x -> 0
1230 if (getOperand1() == getOperand2())
1231 return Builder(getContext()).getZeroAttr(type: getType());
1232
1233 // According to the SPIR-V spec:
1234 //
1235 // Type is a scalar or vector of integer type.
1236 // Results are computed per component, and within each component, per bit.
1237 // So we can use the APInt ^ method.
1238 return constFoldBinaryOp<IntegerAttr>(
1239 operands: adaptor.getOperands(),
1240 calculate: [](const APInt &a, const APInt &b) { return a ^ b; });
1241}
1242
1243//===----------------------------------------------------------------------===//
1244// spirv.mlir.selection
1245//===----------------------------------------------------------------------===//
1246
1247namespace {
1248// Blocks from the given `spirv.mlir.selection` operation must satisfy the
1249// following layout:
1250//
1251// +-----------------------------------------------+
1252// | header block |
1253// | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1254// +-----------------------------------------------+
1255// / \
1256// ...
1257//
1258//
1259// +------------------------+ +------------------------+
1260// | case #0 | | case #1 |
1261// | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 |
1262// | spirv.Branch ^merge | | spirv.Branch ^merge |
1263// +------------------------+ +------------------------+
1264//
1265//
1266// ...
1267// \ /
1268// v
1269// +-------------+
1270// | merge block |
1271// +-------------+
1272//
1273struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1274 using OpRewritePattern::OpRewritePattern;
1275
1276 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1277 PatternRewriter &rewriter) const override {
1278 Operation *op = selectionOp.getOperation();
1279 Region &body = op->getRegion(index: 0);
1280 // Verifier allows an empty region for `spirv.mlir.selection`.
1281 if (body.empty()) {
1282 return failure();
1283 }
1284
1285 // Check that region consists of 4 blocks:
1286 // header block, `true` block, `false` block and merge block.
1287 if (llvm::range_size(Range&: body) != 4) {
1288 return failure();
1289 }
1290
1291 Block *headerBlock = selectionOp.getHeaderBlock();
1292 if (!onlyContainsBranchConditionalOp(block: headerBlock)) {
1293 return failure();
1294 }
1295
1296 auto brConditionalOp =
1297 cast<spirv::BranchConditionalOp>(Val&: headerBlock->front());
1298
1299 Block *trueBlock = brConditionalOp.getSuccessor(i: 0);
1300 Block *falseBlock = brConditionalOp.getSuccessor(i: 1);
1301 Block *mergeBlock = selectionOp.getMergeBlock();
1302
1303 if (failed(Result: canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1304 return failure();
1305
1306 Value trueValue = getSrcValue(block: trueBlock);
1307 Value falseValue = getSrcValue(block: falseBlock);
1308 Value ptrValue = getDstPtr(block: trueBlock);
1309 auto storeOpAttributes =
1310 cast<spirv::StoreOp>(Val&: trueBlock->front())->getAttrs();
1311
1312 auto selectOp = rewriter.create<spirv::SelectOp>(
1313 location: selectionOp.getLoc(), args: trueValue.getType(),
1314 args: brConditionalOp.getCondition(), args&: trueValue, args&: falseValue);
1315 rewriter.create<spirv::StoreOp>(location: selectOp.getLoc(), args&: ptrValue,
1316 args: selectOp.getResult(), args&: storeOpAttributes);
1317
1318 // `spirv.mlir.selection` is not needed anymore.
1319 rewriter.eraseOp(op);
1320 return success();
1321 }
1322
1323private:
1324 // Checks that given blocks follow the following rules:
1325 // 1. Each conditional block consists of two operations, the first operation
1326 // is a `spirv.Store` and the last operation is a `spirv.Branch`.
1327 // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1328 // 3. A control flow goes into the given merge block from the given
1329 // conditional blocks.
1330 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1331 Block *mergeBlock) const;
1332
1333 bool onlyContainsBranchConditionalOp(Block *block) const {
1334 return llvm::hasSingleElement(C&: *block) &&
1335 isa<spirv::BranchConditionalOp>(Val: block->front());
1336 }
1337
1338 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1339 return lhs->getDiscardableAttrDictionary() ==
1340 rhs->getDiscardableAttrDictionary() &&
1341 lhs.getProperties() == rhs.getProperties();
1342 }
1343
1344 // Returns a source value for the given block.
1345 Value getSrcValue(Block *block) const {
1346 auto storeOp = cast<spirv::StoreOp>(Val&: block->front());
1347 return storeOp.getValue();
1348 }
1349
1350 // Returns a destination value for the given block.
1351 Value getDstPtr(Block *block) const {
1352 auto storeOp = cast<spirv::StoreOp>(Val&: block->front());
1353 return storeOp.getPtr();
1354 }
1355};
1356
1357LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1358 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1359 // Each block must consists of 2 operations.
1360 if (llvm::range_size(Range&: *trueBlock) != 2 || llvm::range_size(Range&: *falseBlock) != 2) {
1361 return failure();
1362 }
1363
1364 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(Val&: trueBlock->front());
1365 auto trueBrBranchOp =
1366 dyn_cast<spirv::BranchOp>(Val&: *std::next(x: trueBlock->begin()));
1367 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(Val&: falseBlock->front());
1368 auto falseBrBranchOp =
1369 dyn_cast<spirv::BranchOp>(Val&: *std::next(x: falseBlock->begin()));
1370
1371 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1372 !falseBrBranchOp) {
1373 return failure();
1374 }
1375
1376 // Checks that given type is valid for `spirv.SelectOp`.
1377 // According to SPIR-V spec:
1378 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1379 // Starting with version 1.4, Result Type can additionally be a composite type
1380 // other than a vector."
1381 bool isScalarOrVector =
1382 llvm::cast<spirv::SPIRVType>(Val: trueBrStoreOp.getValue().getType())
1383 .isScalarOrVector();
1384
1385 // Check that each `spirv.Store` uses the same pointer, memory access
1386 // attributes and a valid type of the value.
1387 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1388 !isSameAttrList(lhs: trueBrStoreOp, rhs: falseBrStoreOp) || !isScalarOrVector) {
1389 return failure();
1390 }
1391
1392 if ((trueBrBranchOp->getSuccessor(index: 0) != mergeBlock) ||
1393 (falseBrBranchOp->getSuccessor(index: 0) != mergeBlock)) {
1394 return failure();
1395 }
1396
1397 return success();
1398}
1399} // namespace
1400
1401void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1402 MLIRContext *context) {
1403 results.add<ConvertSelectionOpToSelect>(arg&: context);
1404}
1405

source code of mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp