//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the folders and canonicalization patterns for SPIR-V ops.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <utility>

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//

/// Returns the boolean value under the hood if the given `attr` is a scalar
/// or splat vector bool constant.
static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
  if (!attr)
    return std::nullopt;

  if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
    return boolAttr.getValue();
  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
    if (splatAttr.getElementType().isInteger(1))
      return splatAttr.getSplatValue<bool>();
  return std::nullopt;
}

// Extracts an element from the given `composite` by following the given
// `indices`. Returns a null Attribute if an error occurs.
static Attribute extractCompositeElement(Attribute composite,
                                         ArrayRef<unsigned> indices) {
  // Check that the given composite is a constant.
  if (!composite)
    return {};
  // Return the composite itself if we reach the end of the index chain.
  if (indices.empty())
    return composite;

  if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
    assert(indices.size() == 1 && "must have exactly one index for a vector");
    return vector.getValues<Attribute>()[indices[0]];
  }

  if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
    assert(!indices.empty() && "must have at least one index for an array");
    return extractCompositeElement(array.getValue()[indices[0]],
                                   indices.drop_front());
  }

  return {};
}

static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
  bool div0 = b.isZero();
  bool overflow = a.isMinSignedValue() && b.isAllOnes();

  return div0 || overflow;
}

//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//

namespace {
#include "SPIRVCanonicalization.inc"
} // namespace

//===----------------------------------------------------------------------===//
// spirv.AccessChainOp
//===----------------------------------------------------------------------===//

namespace {

/// Combines chained `spirv::AccessChainOp` operations into one
/// `spirv::AccessChainOp` operation.
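///
/// For illustration (schematic IR, with types elided), this rewrites
///   %0 = spirv.AccessChain %base[%i]
///   %1 = spirv.AccessChain %0[%j]
/// into
///   %1 = spirv.AccessChain %base[%i, %j]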
struct CombineChainedAccessChain final
    : OpRewritePattern<spirv::AccessChainOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
                                PatternRewriter &rewriter) const override {
    auto parentAccessChainOp =
        accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();

    if (!parentAccessChainOp) {
      return failure();
    }

    // Combine indices.
    SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
    llvm::append_range(indices, accessChainOp.getIndices());

    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
        accessChainOp, parentAccessChainOp.getBasePtr(), indices);

    return success();
  }
};
} // namespace

void spirv::AccessChainOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<CombineChainedAccessChain>(context);
}

//===----------------------------------------------------------------------===//
// spirv.IAddCarry
//===----------------------------------------------------------------------===//

// We have to use CompositeConstructOp to create a constant struct, as struct
// values are not yet supported by spirv.Constant; since that requires creating
// a new op, it cannot be done in a fold.
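//
// For illustration (schematic, assuming i32 operands): when both operands are
// constant, the low-order sum and the carry bit are computed with APInt
// arithmetic, and the result struct is assembled from a spirv.Undef value via
// two spirv.CompositeInsert ops.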
struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();
    Type constituentType = lhs.getType();

    // iaddcarry (x, 0) = <0, x>
    if (matchPattern(rhs, m_Zero())) {
      Value constituents[2] = {rhs, lhs};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               constituents);
      return success();
    }

    // According to the SPIR-V spec:
    //
    // Result Type must be from OpTypeStruct. The struct must have two
    // members...
    //
    // Member 0 of the result gets the low-order bits (full component width) of
    // the addition.
    //
    // Member 1 of the result gets the high-order (carry) bit of the result of
    // the addition. That is, it gets the value 1 if the addition overflowed
    // the component width, and 0 otherwise.
    Attribute lhsAttr;
    Attribute rhsAttr;
    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
        !matchPattern(rhs, m_Constant(&rhsAttr)))
      return failure();

    auto adds = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr},
        [](const APInt &a, const APInt &b) { return a + b; });
    if (!adds)
      return failure();

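    // The wrapped sum is unsigned-less-than either operand exactly when the
    // addition overflowed, so comparing `adds` against `lhsAttr` recovers the
    // carry bit.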
    auto carrys = constFoldBinaryOp<IntegerAttr>(
        ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
          APInt zero = APInt::getZero(a.getBitWidth());
          return a.ult(b) ? (zero + 1) : zero;
        });

    if (!carrys)
      return failure();

    Value addsVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);

    Value carrysVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);

    // Create empty struct
    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
    // Fill in adds at id 0
    Value intermediate =
        rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
    // Fill in carrys at id 1
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
                                                          intermediate, 1);
    return success();
  }
};

void spirv::IAddCarryOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IAddCarryFold>(context);
}

//===----------------------------------------------------------------------===//
// spirv.[S|U]MulExtended
//===----------------------------------------------------------------------===//

// We have to use CompositeConstructOp to create a constant struct, as struct
// values are not yet supported by spirv.Constant; since that requires creating
// a new op, it cannot be done in a fold.
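//
// For illustration (worked example, assuming i8 components): for
// umulextended(200, 2) the full product is 400 = 0x190, so member 0 (the low
// bits) is 0x90 and member 1 (the high bits) is 0x01.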
template <typename MulOp, bool IsSigned>
struct MulExtendedFold final : OpRewritePattern<MulOp> {
  using OpRewritePattern<MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MulOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();
    Type constituentType = lhs.getType();

    // [su]mulextended (x, 0) = <0, 0>
    if (matchPattern(rhs, m_Zero())) {
      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
      Value constituents[2] = {zero, zero};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               constituents);
      return success();
    }

    // According to the SPIR-V spec:
    //
    // Result Type must be from OpTypeStruct. The struct must have two
    // members...
    //
    // Member 0 of the result gets the low-order bits of the multiplication.
    //
    // Member 1 of the result gets the high-order bits of the multiplication.
    Attribute lhsAttr;
    Attribute rhsAttr;
    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
        !matchPattern(rhs, m_Constant(&rhsAttr)))
      return failure();

    auto lowBits = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr},
        [](const APInt &a, const APInt &b) { return a * b; });

    if (!lowBits)
      return failure();

    auto highBits = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
          if (IsSigned) {
            return llvm::APIntOps::mulhs(a, b);
          } else {
            return llvm::APIntOps::mulhu(a, b);
          }
        });

    if (!highBits)
      return failure();

    Value lowBitsVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);

    Value highBitsVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);

    // Create empty struct
    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
    // Fill in lowBits at id 0
    Value intermediate =
        rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
    // Fill in highBits at id 1
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
                                                          intermediate, 1);
    return success();
  }
};

using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
void spirv::SMulExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<SMulExtendedOpFold>(context);
}

struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();
    Type constituentType = lhs.getType();

    // umulextended (x, 1) = <x, 0>
    if (matchPattern(rhs, m_One())) {
      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
      Value constituents[2] = {lhs, zero};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               constituents);
      return success();
    }

    return failure();
  }
};

using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
void spirv::UMulExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
}

//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//

// Input:
//    %0 = spirv.UMod %arg0, %const32 : i32
//    %1 = spirv.UMod %0, %const4 : i32
// Output:
//    %0 = spirv.UMod %arg0, %const32 : i32
//    %1 = spirv.UMod %arg0, %const4 : i32

// The transformation is only applied if one divisor is a multiple of the other.
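// In the example above this is sound because 4 divides 32: reducing modulo 32
// first cannot change the residue modulo 4.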

// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
                                PatternRewriter &rewriter) const override {
    auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
    if (!prevUMod)
      return failure();

    IntegerAttr prevValue;
    IntegerAttr currValue;
    if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
        !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
      return failure();

    APInt prevConstValue = prevValue.getValue();
    APInt currConstValue = currValue.getValue();

    // Ensure that one divisor is a multiple of the other. If not, fail the
    // transformation.
    if (prevConstValue.urem(currConstValue) != 0 &&
        currConstValue.urem(prevConstValue) != 0)
      return failure();

    // The transformation is safe. Replace the existing UMod operation with a
    // new UMod operation, using the original dividend and the current divisor.
    rewriter.replaceOpWithNewOp<spirv::UModOp>(
        umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));

    return success();
  }
};

void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<UModSimplification>(context);
}

//===----------------------------------------------------------------------===//
// spirv.BitcastOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
  Value curInput = getOperand();
  if (getType() == curInput.getType())
    return curInput;

  // Look through nested bitcasts.
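  // For illustration (schematic IR):
  //   %0 = spirv.Bitcast %arg : f32 to i32
  //   %1 = spirv.Bitcast %0 : i32 to f32
  // Here %1 folds to %arg directly; with a different result type, %1 is
  // instead rewritten to bitcast from %arg in one step.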
  if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
    Value prevInput = prevCast.getOperand();
    if (prevInput.getType() == getType())
      return prevInput;

    getOperandMutable().assign(prevInput);
    return getResult();
  }

  // TODO(kuhar): Consider constant-folding the operand attribute.
  return {};
}

//===----------------------------------------------------------------------===//
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
  Value compositeOp = getComposite();

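  // An extract that reads back a value inserted at the same indices folds to
  // that value. For illustration (schematic IR):
  //   %0 = spirv.CompositeInsert %val, %comp[1 : i32]
  //   %1 = spirv.CompositeExtract %0[1 : i32]
  // Here %1 folds to %val.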
  while (auto insertOp =
             compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
    if (getIndices() == insertOp.getIndices())
      return insertOp.getObject();
    compositeOp = insertOp.getComposite();
  }

  if (auto constructOp =
          compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
    auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
    if (getIndices().size() == 1 &&
        constructOp.getConstituents().size() == type.getNumElements()) {
      auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
      if (i.getValue().getSExtValue() <
          static_cast<int64_t>(constructOp.getConstituents().size()))
        return constructOp.getConstituents()[i.getValue().getSExtValue()];
    }
  }

  auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
    return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
  });
  return extractCompositeElement(adaptor.getComposite(), indexVector);
}

//===----------------------------------------------------------------------===//
// spirv.Constant
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
  return getValue();
}

//===----------------------------------------------------------------------===//
// spirv.IAdd
//===----------------------------------------------------------------------===//

OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
  // x + 0 = x
  if (matchPattern(getOperand2(), m_Zero()))
    return getOperand1();

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}

//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//

OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
  // x * 0 == 0
  if (matchPattern(getOperand2(), m_Zero()))
    return getOperand2();
  // x * 1 = x
  if (matchPattern(getOperand2(), m_One()))
    return getOperand1();

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
}

//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
  // x - x = 0
  if (getOperand1() == getOperand2())
    return Builder(getContext()).getIntegerAttr(getType(), 0);

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });
}

//===----------------------------------------------------------------------===//
// spirv.SDiv
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
  // sdiv (x, 1) = x
  if (matchPattern(getOperand2(), m_One()))
    return getOperand1();

  // According to the SPIR-V spec:
  //
  // Signed-integer division of Operand 1 divided by Operand 2.
  // Results are computed per component. Behavior is undefined if Operand 2 is
  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
  // representable value for the operands' type, causing signed overflow.
  //
  // So don't fold when the behavior is undefined.
  bool div0OrOverflow = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
          div0OrOverflow = true;
          return a;
        }
        return a.sdiv(b);
      });
  return div0OrOverflow ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.SMod
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
  // smod (x, 1) = 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // According to SPIR-V spec:
  //
  // Signed remainder operation for the remainder whose sign matches the sign
  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
  // value for the operands' type, causing signed overflow. Otherwise, the
  // result is the remainder r of Operand 1 divided by Operand 2 where if
  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
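  //
  // For illustration (worked examples): smod(-7, 3) = 2 and smod(7, -3) = -2;
  // the result takes the sign of the second operand, unlike srem below.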
  //
  // So don't fold when the behavior is undefined.
  bool div0OrOverflow = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
          div0OrOverflow = true;
          return a;
        }
        APInt c = a.abs().urem(b.abs());
        if (c.isZero())
          return c;
        if (b.isNegative()) {
          APInt zero = APInt::getZero(c.getBitWidth());
          return a.isNegative() ? (zero - c) : (b + c);
        }
        return a.isNegative() ? (b - c) : c;
      });
  return div0OrOverflow ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.SRem
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
  // x % 1 = 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // According to SPIR-V spec:
  //
  // Signed remainder operation for the remainder whose sign matches the sign
  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
  // value for the operands' type, causing signed overflow. Otherwise, the
  // result is the remainder r of Operand 1 divided by Operand 2 where if
  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
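  //
  // For illustration (worked examples): srem(-7, 3) = -1 and srem(7, -3) = 1;
  // the result takes the sign of the first operand, unlike smod above.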

  // Don't fold if it would be undefined behavior.
  bool div0OrOverflow = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
          div0OrOverflow = true;
          return a;
        }
        return a.srem(b);
      });
  return div0OrOverflow ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.UDiv
//===----------------------------------------------------------------------===//

OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
  // udiv (x, 1) = x
  if (matchPattern(getOperand2(), m_One()))
    return getOperand1();

  // According to the SPIR-V spec:
  //
  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
  // undefined if Operand 2 is 0.
  //
  // So don't fold when the behavior is undefined.
  bool div0 = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0 || b.isZero()) {
          div0 = true;
          return a;
        }
        return a.udiv(b);
      });
  return div0 ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//

OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
  // umod (x, 1) = 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // According to the SPIR-V spec:
  //
  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
  // undefined if Operand 2 is 0.
  //
  // So don't fold when the behavior is undefined.
  bool div0 = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0 || b.isZero()) {
          div0 = true;
          return a;
        }
        return a.urem(b);
      });
  return div0 ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.SNegate
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
  // -(-x) = 0 - (0 - x) = x
  auto op = getOperand();
  if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
    return negateOp->getOperand(0);

  // According to the SPIR-V spec:
  //
  // Signed-integer subtract of Operand from zero.
  return constFoldUnaryOp<IntegerAttr>(
      adaptor.getOperands(), [](const APInt &a) {
        APInt zero = APInt::getZero(a.getBitWidth());
        return zero - a;
      });
}

//===----------------------------------------------------------------------===//
// spirv.NotOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
  // !(!x) = x
  auto op = getOperand();
  if (auto notOp = op.getDefiningOp<spirv::NotOp>())
    return notOp->getOperand(0);

  // According to the SPIR-V spec:
  //
  // Complement the bits of Operand.
  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
    a.flipAllBits();
    return a;
  });
}

//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
  if (std::optional<bool> rhs =
          getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
    // x && true = x
    if (*rhs)
      return getOperand1();

    // x && false = false
    if (!*rhs)
      return adaptor.getOperand2();
  }

  return Attribute();
}

//===----------------------------------------------------------------------===//
// spirv.LogicalEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
  // x == x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [](const APInt &a, const APInt &b) {
        return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
  if (std::optional<bool> rhs =
          getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
    // x != false -> x
    if (!rhs.value())
      return getOperand1();
  }

  // x == x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [](const APInt &a, const APInt &b) {
        return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.LogicalNot
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
  // !(!x) = x
  auto op = getOperand();
  if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
    return notOp->getOperand(0);

  // According to the SPIR-V spec:
  //
  // Result is true if Operand is false. Result is false if Operand is true.
  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
                                       [](const APInt &a) {
                                         APInt zero = APInt::getZero(1);
                                         return a == 1 ? zero : (zero + 1);
                                       });
}

void spirv::LogicalNotOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results
      .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
           ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
          context);
}

//===----------------------------------------------------------------------===//
// spirv.LogicalOr
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
  if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
    if (*rhs) {
      // x || true = true
      return adaptor.getOperand2();
    }

    if (!*rhs) {
      // x || false = x
      return getOperand1();
    }
  }

  return Attribute();
}

//===----------------------------------------------------------------------===//
// spirv.SelectOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
  // spirv.Select _ x x -> x
  Value trueVals = getTrueValue();
  Value falseVals = getFalseValue();
  if (trueVals == falseVals)
    return trueVals;

  ArrayRef<Attribute> operands = adaptor.getOperands();

  // spirv.Select true x y -> x
  // spirv.Select false x y -> y
  if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
    return *boolAttr ? trueVals : falseVals;

  // Check that all the operands are constant
  if (!operands[0] || !operands[1] || !operands[2])
    return Attribute();

  // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
  // the scalar case; hence we only need to handle the DenseElementsAttr case
  // here.
  auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
  auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
  auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
  if (!condAttrs || !trueAttrs || !falseAttrs)
    return Attribute();

  auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
  auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
                               falseAttrs.getValues<Attribute>());
  for (auto [result, cond, falseRes] : iters) {
    if (!cond.getValue())
      result = falseRes;
  }

  auto resultType = trueAttrs.getType();
  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
}

//===----------------------------------------------------------------------===//
// spirv.IEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
  // x == x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.INotEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
  // x == x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.SGreaterThan
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
  // x == x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.SGreaterThanEqual
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SGreaterThanEqualOp::fold(
    spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
  // x == x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.UGreaterThan
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
  // x == x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.UGreaterThanEqual
//===----------------------------------------------------------------------===//

OpFoldResult spirv::UGreaterThanEqualOp::fold(
    spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
  // x == x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.SLessThan
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
  // x == x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.SLessThanEqual
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
  // x == x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.ULessThan
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
  // x == x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.ULessThanEqual
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
  // x == x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogical
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ShiftLeftLogicalOp::fold(
    spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
  // x << 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
    return getOperand1();
  }

  // Unfortunately, because of the undefined behavior described below, we
  // cannot fold the Base == 0 case.

  // Results are computed per component, and within each component, per bit...
  //
  // The result is undefined if Shift is greater than or equal to the bit width
  // of the components of Base.
  //
  // So we can use the APInt << method, but don't fold if the behavior would be
  // undefined.
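  //
  // For illustration (assuming i8 components): x << 8 is not folded here,
  // because the shift amount equals the component bit width and the result is
  // therefore undefined.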
  bool shiftToLarge = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (shiftToLarge || b.uge(a.getBitWidth())) {
          shiftToLarge = true;
          return a;
        }
        return a << b;
      });
  return shiftToLarge ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.ShiftRightArithmetic
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ShiftRightArithmeticOp::fold(
    spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
  // x >> 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
    return getOperand1();
  }

  // Unfortunately, because of the undefined behavior described below, we
  // cannot fold the Base == 0 and Base == -1 cases.

  // Results are computed per component, and within each component, per bit...
  //
  // The result is undefined if Shift is greater than or equal to the bit width
  // of the components of Base.
  //
  // So we can use the APInt ashr method, but don't fold if the behavior would
  // be undefined.
  bool shiftToLarge = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (shiftToLarge || b.uge(a.getBitWidth())) {
          shiftToLarge = true;
          return a;
        }
        return a.ashr(b);
      });
  return shiftToLarge ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.ShiftRightLogical
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ShiftRightLogicalOp::fold(
    spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
  // x >> 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
    return getOperand1();
  }

  // Unfortunately, because of the undefined behavior described below, we
  // cannot fold the Base == 0 case.

  // Results are computed per component, and within each component, per bit...
  //
  // The result is undefined if Shift is greater than or equal to the bit width
  // of the components of Base.
  //
  // So we can use the APInt lshr method, but don't fold if the behavior would
  // be undefined.
  bool shiftToLarge = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (shiftToLarge || b.uge(a.getBitWidth())) {
          shiftToLarge = true;
          return a;
        }
        return a.lshr(b);
      });
  return shiftToLarge ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.BitwiseAndOp
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
  // x & x -> x
  if (getOperand1() == getOperand2()) {
    return getOperand1();
  }

  APInt rhsMask;
  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
    // x & 0 -> 0
    if (rhsMask.isZero())
      return getOperand2();

    // x & <all ones> -> x
    if (rhsMask.isAllOnes())
      return getOperand1();

    // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
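    // For illustration: masking the result of (spirv.UConvert %x : i8 to i32)
    // with 0xff is a no-op, since the zero-extended upper bits are already 0.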
    if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
      int valueBits =
          getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
      if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
        return getOperand1();
    }
  }

  // According to the SPIR-V spec:
  //
  // Type is a scalar or vector of integer type.
  // Results are computed per component, and within each component, per bit.
  // So we can use the APInt & method.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a & b; });
}

//===----------------------------------------------------------------------===//
// spirv.BitwiseOrOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
  // x | x -> x
  if (getOperand1() == getOperand2()) {
    return getOperand1();
  }

  APInt rhsMask;
  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
    // x | 0 -> x
    if (rhsMask.isZero())
      return getOperand1();

    // x | <all ones> -> <all ones>
    if (rhsMask.isAllOnes())
      return getOperand2();
  }

  // According to the SPIR-V spec:
  //
  // Type is a scalar or vector of integer type.
  // Results are computed per component, and within each component, per bit.
  // So we can use the APInt | method.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a | b; });
}

//===----------------------------------------------------------------------===//
// spirv.BitwiseXorOp
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
  // x ^ 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
    return getOperand1();
  }

  // x ^ x -> 0
  if (getOperand1() == getOperand2())
    return Builder(getContext()).getZeroAttr(getType());

  // According to the SPIR-V spec:
  //
  // Type is a scalar or vector of integer type.
  // Results are computed per component, and within each component, per bit.
  // So we can use the APInt ^ method.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a ^ b; });
}

//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//

namespace {
// Blocks from the given `spirv.mlir.selection` operation must satisfy the
// following layout:
//
//       +-----------------------------------------------+
//       | header block                                  |
//       | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
//       +-----------------------------------------------+
//                            /   \
//                             ...
//
//
//   +------------------------+    +------------------------+
//   | case #0                |    | case #1                |
//   | spirv.Store %ptr %value0 |    | spirv.Store %ptr %value1 |
//   | spirv.Branch ^merge      |    | spirv.Branch ^merge      |
//   +------------------------+    +------------------------+
//
//
//                             ...
//                            \   /
//                              v
//                       +-------------+
//                       | merge block |
//                       +-------------+
//
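// Such a selection is rewritten into a single `spirv.Select` whose result is
// stored once, schematically (illustrative syntax):
//   %sel = spirv.Select %cond, %value0, %value1
//   spirv.Store "<storage class>" %ptr, %sel
//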
struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
                                PatternRewriter &rewriter) const override {
    Operation *op = selectionOp.getOperation();
    Region &body = op->getRegion(0);
    // Verifier allows an empty region for `spirv.mlir.selection`.
    if (body.empty()) {
      return failure();
    }

    // Check that region consists of 4 blocks:
    // header block, `true` block, `false` block and merge block.
    if (llvm::range_size(body) != 4) {
      return failure();
    }

    Block *headerBlock = selectionOp.getHeaderBlock();
    if (!onlyContainsBranchConditionalOp(headerBlock)) {
      return failure();
    }

    auto brConditionalOp =
        cast<spirv::BranchConditionalOp>(headerBlock->front());

    Block *trueBlock = brConditionalOp.getSuccessor(0);
    Block *falseBlock = brConditionalOp.getSuccessor(1);
    Block *mergeBlock = selectionOp.getMergeBlock();

    if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
      return failure();

    Value trueValue = getSrcValue(trueBlock);
    Value falseValue = getSrcValue(falseBlock);
    Value ptrValue = getDstPtr(trueBlock);
    auto storeOpAttributes =
        cast<spirv::StoreOp>(trueBlock->front())->getAttrs();

    auto selectOp = rewriter.create<spirv::SelectOp>(
        selectionOp.getLoc(), trueValue.getType(),
        brConditionalOp.getCondition(), trueValue, falseValue);
    rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
                                    selectOp.getResult(), storeOpAttributes);

    // `spirv.mlir.selection` is not needed anymore.
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Checks that given blocks follow the following rules:
  // 1. Each conditional block consists of two operations, the first operation
  //    is a `spirv.Store` and the last operation is a `spirv.Branch`.
  // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
  // 3. A control flow goes into the given merge block from the given
  //    conditional blocks.
  LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
                                         Block *mergeBlock) const;

  bool onlyContainsBranchConditionalOp(Block *block) const {
    return llvm::hasSingleElement(*block) &&
           isa<spirv::BranchConditionalOp>(block->front());
  }

  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
    return lhs->getDiscardableAttrDictionary() ==
               rhs->getDiscardableAttrDictionary() &&
           lhs.getProperties() == rhs.getProperties();
  }

  // Returns the source value stored in the given block.
  Value getSrcValue(Block *block) const {
    auto storeOp = cast<spirv::StoreOp>(block->front());
    return storeOp.getValue();
  }

  // Returns the destination pointer stored to in the given block.
  Value getDstPtr(Block *block) const {
    auto storeOp = cast<spirv::StoreOp>(block->front());
    return storeOp.getPtr();
  }
};

LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
    Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
  // Each block must consist of 2 operations.
  if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
    return failure();
  }

  auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
  auto trueBrBranchOp =
      dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
  auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
  auto falseBrBranchOp =
      dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));

  if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
      !falseBrBranchOp) {
    return failure();
  }

  // Checks that given type is valid for `spirv.SelectOp`.
  // According to SPIR-V spec:
  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
  // Starting with version 1.4, Result Type can additionally be a composite type
  // other than a vector."
  bool isScalarOrVector =
      llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
          .isScalarOrVector();

  // Check that each `spirv.Store` uses the same pointer, memory access
  // attributes and a valid type of the value.
  if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
      !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
    return failure();
  }

  if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
      (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
    return failure();
  }

  return success();
}
} // namespace

void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<ConvertSelectionOpToSelect>(context);
}