//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the folders and canonicalization patterns for SPIR-V ops.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <utility>

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//

/// Returns the boolean value under the hood if the given `attr` is a scalar or
/// splat vector bool constant.
static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
  if (!attr)
    return std::nullopt;

  if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
    return boolAttr.getValue();
  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
    if (splatAttr.getElementType().isInteger(1))
      return splatAttr.getSplatValue<bool>();
  return std::nullopt;
}

// Extracts an element from the given `composite` by following the given
// `indices`. Returns a null Attribute if an error occurs.
static Attribute extractCompositeElement(Attribute composite,
                                         ArrayRef<unsigned> indices) {
  // Check that the given composite is a constant.
  if (!composite)
    return {};
  // Return the composite itself if we reach the end of the index chain.
  if (indices.empty())
    return composite;

  if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
    assert(indices.size() == 1 && "must have exactly one index for a vector");
    return vector.getValues<Attribute>()[indices[0]];
  }

  if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
    assert(!indices.empty() && "must have at least one index for an array");
    return extractCompositeElement(array.getValue()[indices[0]],
                                   indices.drop_front());
  }

  return {};
}

static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
  bool div0 = b.isZero();
  bool overflow = a.isMinSignedValue() && b.isAllOnes();

  return div0 || overflow;
}

//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//

namespace {
#include "SPIRVCanonicalization.inc"
} // namespace

//===----------------------------------------------------------------------===//
// spirv.AccessChainOp
//===----------------------------------------------------------------------===//

namespace {

/// Combines chained `spirv::AccessChainOp` operations into one
/// `spirv::AccessChainOp` operation.
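///
/// An illustrative example (simplified IR; pointer and index types abbreviated,
/// not taken from the test suite):
///
///   %0 = spirv.AccessChain %base[%i] : !spirv.ptr<...>, i32
///   %1 = spirv.AccessChain %0[%j] : !spirv.ptr<...>, i32
///
/// is rewritten into a single access chain:
///
///   %1 = spirv.AccessChain %base[%i, %j] : !spirv.ptr<...>, i32, i32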
struct CombineChainedAccessChain final
    : OpRewritePattern<spirv::AccessChainOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
                                PatternRewriter &rewriter) const override {
    auto parentAccessChainOp =
        accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();

    if (!parentAccessChainOp) {
      return failure();
    }

    // Combine indices.
    SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
    llvm::append_range(indices, accessChainOp.getIndices());

    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
        accessChainOp, parentAccessChainOp.getBasePtr(), indices);

    return success();
  }
};
} // namespace

void spirv::AccessChainOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<CombineChainedAccessChain>(context);
}

//===----------------------------------------------------------------------===//
// spirv.IAddCarry
//===----------------------------------------------------------------------===//

// Constant structs are not yet supported, so the result has to be assembled
// with CompositeConstructOp / CompositeInsertOp; this cannot be done in a fold
// method and therefore requires a rewrite pattern.
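//
// An illustrative example of the zero special case below (simplified IR, %c0
// being a zero constant): per the spec quoted in the pattern, member 0 is the
// low-order sum and member 1 is the carry, so spirv.IAddCarry %x, %c0 can be
// rewritten into
//
//   %r = spirv.CompositeConstruct %x, %c0 : (i32, i32) -> !spirv.struct<(i32, i32)>
//
// When both operands are constants, the sum and carry are instead computed
// directly and inserted into an undef struct with spirv.CompositeInsert.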
struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();
    Type constituentType = lhs.getType();

    // iaddcarry (x, 0) = <x, 0>: member 0 is the sum, member 1 is the carry.
    if (matchPattern(rhs, m_Zero())) {
      Value constituents[2] = {lhs, rhs};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               constituents);
      return success();
    }

    // According to the SPIR-V spec:
    //
    // Result Type must be from OpTypeStruct. The struct must have two
    // members...
    //
    // Member 0 of the result gets the low-order bits (full component width) of
    // the addition.
    //
    // Member 1 of the result gets the high-order (carry) bit of the result of
    // the addition. That is, it gets the value 1 if the addition overflowed
    // the component width, and 0 otherwise.
    Attribute lhsAttr;
    Attribute rhsAttr;
    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
        !matchPattern(rhs, m_Constant(&rhsAttr)))
      return failure();

    auto adds = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr},
        [](const APInt &a, const APInt &b) { return a + b; });
    if (!adds)
      return failure();

    auto carrys = constFoldBinaryOp<IntegerAttr>(
        ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
          APInt zero = APInt::getZero(a.getBitWidth());
          return a.ult(b) ? (zero + 1) : zero;
        });

    if (!carrys)
      return failure();

    Value addsVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);

    Value carrysVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);

    // Create an undef struct to insert the results into.
    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
    // Fill in the adds at index 0.
    Value intermediate =
        rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
    // Fill in the carrys at index 1.
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
                                                          intermediate, 1);
    return success();
  }
};

void spirv::IAddCarryOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IAddCarryFold>(context);
}

//===----------------------------------------------------------------------===//
// spirv.[S|U]MulExtended
//===----------------------------------------------------------------------===//

// Constant structs are not yet supported, so the result has to be assembled
// with CompositeConstructOp / CompositeInsertOp; this cannot be done in a fold
// method and therefore requires a rewrite pattern.
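//
// An illustrative example for the constant case below (simplified IR): with
// two i32 constants %ca and %cb,
//
//   %r = spirv.SMulExtended %ca, %cb : !spirv.struct<(i32, i32)>
//
// is replaced by constants holding the low and high halves of the product,
// inserted into an undef struct via spirv.CompositeInsert.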
template <typename MulOp, bool IsSigned>
struct MulExtendedFold final : OpRewritePattern<MulOp> {
  using OpRewritePattern<MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MulOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();
    Type constituentType = lhs.getType();

    // [su]mulextended (x, 0) = <0, 0>
    if (matchPattern(rhs, m_Zero())) {
      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
      Value constituents[2] = {zero, zero};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               constituents);
      return success();
    }

    // According to the SPIR-V spec:
    //
    // Result Type must be from OpTypeStruct. The struct must have two
    // members...
    //
    // Member 0 of the result gets the low-order bits of the multiplication.
    //
    // Member 1 of the result gets the high-order bits of the multiplication.
    Attribute lhsAttr;
    Attribute rhsAttr;
    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
        !matchPattern(rhs, m_Constant(&rhsAttr)))
      return failure();

    auto lowBits = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr},
        [](const APInt &a, const APInt &b) { return a * b; });

    if (!lowBits)
      return failure();

    auto highBits = constFoldBinaryOp<IntegerAttr>(
        {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
          if (IsSigned) {
            return llvm::APIntOps::mulhs(a, b);
          } else {
            return llvm::APIntOps::mulhu(a, b);
          }
        });

    if (!highBits)
      return failure();

    Value lowBitsVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);

    Value highBitsVal =
        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);

    // Create empty struct
    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
    // Fill in lowBits at id 0
    Value intermediate =
        rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
    // Fill in highBits at id 1
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
                                                          intermediate, 1);
    return success();
  }
};

using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
void spirv::SMulExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<SMulExtendedOpFold>(context);
}

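// Fold spirv.UMulExtended when the second operand is constant 1. An
// illustrative example (simplified IR): since x * 1 has low bits x and high
// bits 0,
//
//   %r = spirv.UMulExtended %x, %c1 : !spirv.struct<(i32, i32)>
//
// is rewritten into a spirv.CompositeConstruct of %x and a zero constant.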
struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getOperand1();
    Value rhs = op.getOperand2();
    Type constituentType = lhs.getType();

    // umulextended (x, 1) = <x, 0>
    if (matchPattern(rhs, m_One())) {
      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
      Value constituents[2] = {lhs, zero};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                               constituents);
      return success();
    }

    return failure();
  }
};

using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
void spirv::UMulExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
}

//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//

// Input:
//    %0 = spirv.UMod %arg0, %const32 : i32
//    %1 = spirv.UMod %0, %const4 : i32
// Output:
//    %0 = spirv.UMod %arg0, %const32 : i32
//    %1 = spirv.UMod %arg0, %const4 : i32

// The transformation is only applied if the previous divisor (%const32 above)
// is a multiple of the current divisor (%const4 above).

struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
                                PatternRewriter &rewriter) const override {
    auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
    if (!prevUMod)
      return failure();

    TypedAttr prevValue;
    TypedAttr currValue;
    if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
        !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
      return failure();

    // Ensure that the previous divisor is a multiple of the current divisor.
    // If not, fail the transformation.
    bool isApplicable = false;
    if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
      auto currInt = cast<IntegerAttr>(currValue);
      isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
    } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
      auto currVec = cast<DenseElementsAttr>(currValue);
      isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
                                                  currVec.getValues<APInt>()),
                                  [](const auto &pair) {
                                    auto &[prev, curr] = pair;
                                    return prev.urem(curr) == 0;
                                  });
    }

    if (!isApplicable)
      return failure();

    // The transformation is safe. Replace the existing UMod operation with a
    // new UMod operation, using the original dividend and the current divisor.
    rewriter.replaceOpWithNewOp<spirv::UModOp>(
        umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));

    return success();
  }
};

void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<UModSimplification>(context);
}

//===----------------------------------------------------------------------===//
// spirv.BitcastOp
//===----------------------------------------------------------------------===//

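// An illustrative example for the folder below (simplified IR): a chain such
// as
//
//   %0 = spirv.Bitcast %arg : i32 to vector<2xf16>
//   %1 = spirv.Bitcast %0 : vector<2xf16> to f32
//
// is shortened so that %1 bitcasts directly from %arg, and a bitcast to the
// operand's own type folds away entirely.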
OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
  Value curInput = getOperand();
  if (getType() == curInput.getType())
    return curInput;

  // Look through nested bitcasts.
  if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
    Value prevInput = prevCast.getOperand();
    if (prevInput.getType() == getType())
      return prevInput;

    getOperandMutable().assign(prevInput);
    return getResult();
  }

  // TODO(kuhar): Consider constant-folding the operand attribute.
  return {};
}

//===----------------------------------------------------------------------===//
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//

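// An illustrative example for the folder below (simplified IR):
//
//   %c = spirv.CompositeInsert %val, %comp[1 : i32] : i32 into !spirv.array<3 x i32>
//   %e = spirv.CompositeExtract %c[1 : i32] : !spirv.array<3 x i32>
//
// Here %e folds to %val, since the extract reads back exactly what the insert
// wrote. Extracts from spirv.CompositeConstruct operands and from constant
// composites are resolved similarly.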
OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
  Value compositeOp = getComposite();

  while (auto insertOp =
             compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
    if (getIndices() == insertOp.getIndices())
      return insertOp.getObject();
    compositeOp = insertOp.getComposite();
  }

  if (auto constructOp =
          compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
    auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
    if (getIndices().size() == 1 &&
        constructOp.getConstituents().size() == type.getNumElements()) {
      auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
      if (i.getValue().getSExtValue() <
          static_cast<int64_t>(constructOp.getConstituents().size()))
        return constructOp.getConstituents()[i.getValue().getSExtValue()];
    }
  }

  auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
    return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
  });
  return extractCompositeElement(adaptor.getComposite(), indexVector);
}

//===----------------------------------------------------------------------===//
// spirv.Constant
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
  return getValue();
}

//===----------------------------------------------------------------------===//
// spirv.IAdd
//===----------------------------------------------------------------------===//

OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
  // x + 0 = x
  if (matchPattern(getOperand2(), m_Zero()))
    return getOperand1();

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}

//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//

OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
  // x * 0 == 0
  if (matchPattern(getOperand2(), m_Zero()))
    return getOperand2();
  // x * 1 = x
  if (matchPattern(getOperand2(), m_One()))
    return getOperand1();

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
}

//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
  // x - x = 0
  if (getOperand1() == getOperand2())
    return Builder(getContext()).getZeroAttr(getType());

  // According to the SPIR-V spec:
  //
  // The resulting value will equal the low-order N bits of the correct result
  // R, where N is the component width and R is computed with enough precision
  // to avoid overflow and underflow.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });
}

//===----------------------------------------------------------------------===//
// spirv.SDiv
//===----------------------------------------------------------------------===//

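// A note on the shared undefined-behavior guard used by the division and
// remainder folders below: folding is abandoned for the whole operation as
// soon as any element pair would divide by zero or overflow (INT_MIN / -1).
// For example (illustrative constants):
//
//   spirv.SDiv %c6, %c2  -> folds to a constant 3
//   spirv.SDiv %c6, %c0  -> not folded; division by zero is undefined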
OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
  // sdiv (x, 1) = x
  if (matchPattern(getOperand2(), m_One()))
    return getOperand1();

  // According to the SPIR-V spec:
  //
  // Signed-integer division of Operand 1 divided by Operand 2.
  // Results are computed per component. Behavior is undefined if Operand 2 is
  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
  // representable value for the operands' type, causing signed overflow.
  //
  // So don't fold when the behavior is undefined.
  bool div0OrOverflow = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
          div0OrOverflow = true;
          return a;
        }
        return a.sdiv(b);
      });
  return div0OrOverflow ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.SMod
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
  // smod (x, 1) = 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // According to the SPIR-V spec:
  //
  // Signed remainder operation for the remainder whose sign matches the sign
  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
  // value for the operands' type, causing signed overflow. Otherwise, the
  // result is the remainder r of Operand 1 divided by Operand 2 where if
  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
  //
  // So don't fold when the behavior is undefined.
  bool div0OrOverflow = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
          div0OrOverflow = true;
          return a;
        }
        APInt c = a.abs().urem(b.abs());
        if (c.isZero())
          return c;
        if (b.isNegative()) {
          APInt zero = APInt::getZero(c.getBitWidth());
          return a.isNegative() ? (zero - c) : (b + c);
        }
        return a.isNegative() ? (b - c) : c;
      });
  return div0OrOverflow ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.SRem
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
  // x % 1 = 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // According to the SPIR-V spec:
  //
  // Signed remainder operation for the remainder whose sign matches the sign
  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
  // value for the operands' type, causing signed overflow. Otherwise, the
  // result is the remainder r of Operand 1 divided by Operand 2 where if
  // r ≠ 0, the sign of r is the same as the sign of Operand 1.

  // Don't fold if doing so would rely on undefined behavior.
  bool div0OrOverflow = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
          div0OrOverflow = true;
          return a;
        }
        return a.srem(b);
      });
  return div0OrOverflow ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.UDiv
//===----------------------------------------------------------------------===//

OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
  // udiv (x, 1) = x
  if (matchPattern(getOperand2(), m_One()))
    return getOperand1();

  // According to the SPIR-V spec:
  //
  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
  // undefined if Operand 2 is 0.
  //
  // So don't fold when the behavior is undefined.
  bool div0 = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0 || b.isZero()) {
          div0 = true;
          return a;
        }
        return a.udiv(b);
      });
  return div0 ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//

OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
  // umod (x, 1) = 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // According to the SPIR-V spec:
  //
  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
  // undefined if Operand 2 is 0.
  //
  // So don't fold when the behavior is undefined.
  bool div0 = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0 || b.isZero()) {
          div0 = true;
          return a;
        }
        return a.urem(b);
      });
  return div0 ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.SNegate
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
  // -(-x) = 0 - (0 - x) = x
  auto op = getOperand();
  if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
    return negateOp->getOperand(0);

  // According to the SPIR-V spec:
  //
  // Signed-integer subtract of Operand from zero.
  return constFoldUnaryOp<IntegerAttr>(
      adaptor.getOperands(), [](const APInt &a) {
        APInt zero = APInt::getZero(a.getBitWidth());
        return zero - a;
      });
}

//===----------------------------------------------------------------------===//
// spirv.NotOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
  // !(!x) = x
  auto op = getOperand();
  if (auto notOp = op.getDefiningOp<spirv::NotOp>())
    return notOp->getOperand(0);

  // According to the SPIR-V spec:
  //
  // Complement the bits of Operand.
  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
    a.flipAllBits();
    return a;
  });
}

//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
  if (std::optional<bool> rhs =
          getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
    // x && true = x
    if (*rhs)
      return getOperand1();

    // x && false = false
    if (!*rhs)
      return adaptor.getOperand2();
  }

  return Attribute();
}

//===----------------------------------------------------------------------===//
// spirv.LogicalEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
  // x == x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [](const APInt &a, const APInt &b) {
        return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
  if (std::optional<bool> rhs =
          getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
    // x != false -> x
    if (!rhs.value())
      return getOperand1();
  }

  // x != x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [](const APInt &a, const APInt &b) {
        return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.LogicalNot
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
  // !(!x) = x
  auto op = getOperand();
  if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
    return notOp->getOperand(0);

  // According to the SPIR-V spec:
  //
  // Result is true if Operand is false. Result is false if Operand is true.
  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
                                       [](const APInt &a) {
                                         APInt zero = APInt::getZero(1);
                                         return a == 1 ? zero : (zero + 1);
                                       });
}

void spirv::LogicalNotOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results
      .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
           ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
          context);
}

//===----------------------------------------------------------------------===//
// spirv.LogicalOr
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
  if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
    if (*rhs) {
      // x || true = true
      return adaptor.getOperand2();
    }

    if (!*rhs) {
      // x || false = x
      return getOperand1();
    }
  }

  return Attribute();
}

//===----------------------------------------------------------------------===//
// spirv.SelectOp
//===----------------------------------------------------------------------===//

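// An illustrative example for the folder below (simplified IR):
//
//   %r = spirv.Select %cond, %x, %x : i1, i32   // folds to %x
//   %s = spirv.Select %true, %a, %b : i1, i32   // folds to %a
//
// When the condition, true value, and false value are all constant vectors,
// the result is folded element-wise.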
OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
  // spirv.Select _ x x -> x
  Value trueVals = getTrueValue();
  Value falseVals = getFalseValue();
  if (trueVals == falseVals)
    return trueVals;

  ArrayRef<Attribute> operands = adaptor.getOperands();

  // spirv.Select true x y -> x
  // spirv.Select false x y -> y
  if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
    return *boolAttr ? trueVals : falseVals;

  // Check that all the operands are constant.
  if (!operands[0] || !operands[1] || !operands[2])
    return Attribute();

  // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
  // the scalar case. Hence, we only need to consider the case of
  // DenseElementsAttr below.
  auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
  auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
  auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
  if (!condAttrs || !trueAttrs || !falseAttrs)
    return Attribute();

  auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
  auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
                               falseAttrs.getValues<Attribute>());
  for (auto [result, cond, falseRes] : iters) {
    if (!cond.getValue())
      result = falseRes;
  }

  auto resultType = trueAttrs.getType();
  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
}

//===----------------------------------------------------------------------===//
// spirv.IEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
  // x == x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.INotEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
  // x != x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.SGreaterThan
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
  // x > x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.SGreaterThanEqual
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SGreaterThanEqualOp::fold(
    spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
  // x >= x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.UGreaterThan
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
  // x > x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.UGreaterThanEqual
//===----------------------------------------------------------------------===//

OpFoldResult spirv::UGreaterThanEqualOp::fold(
    spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
  // x >= x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.SLessThan
//===----------------------------------------------------------------------===//

OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
  // x < x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.SLessThanEqual
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
  // x <= x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.ULessThan
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
  // x < x -> false
  if (getOperand1() == getOperand2()) {
    auto falseAttr = BoolAttr::get(getContext(), false);
    if (isa<IntegerType>(getType()))
      return falseAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, falseAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.ULessThanEqual
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
  // x <= x -> true
  if (getOperand1() == getOperand2()) {
    auto trueAttr = BoolAttr::get(getContext(), true);
    if (isa<IntegerType>(getType()))
      return trueAttr;
    if (auto vecTy = dyn_cast<VectorType>(getType()))
      return SplatElementsAttr::get(vecTy, trueAttr);
  }

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
        return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
      });
}

//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogical
//===----------------------------------------------------------------------===//

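// The three shift folders below share the same guard: a constant shift of zero
// returns the base unchanged, and folding is skipped whenever the shift amount
// is not smaller than the bit width of the base. For example (illustrative
// i32 constants):
//
//   spirv.ShiftLeftLogical %x, %c0   -> %x
//   spirv.ShiftLeftLogical %c1, %c40 -> not folded; the result is undefined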
OpFoldResult spirv::ShiftLeftLogicalOp::fold(
    spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
  // x << 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
    return getOperand1();
  }

  // Unfortunately, due to the undefined behavior described below, we cannot
  // fold a zero Base.

  // Results are computed per component, and within each component, per bit...
  //
  // The result is undefined if Shift is greater than or equal to the bit width
  // of the components of Base.
  //
  // So we can use APInt's << operator, but we must not fold when the behavior
  // is undefined.
  bool shiftToLarge = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (shiftToLarge || b.uge(a.getBitWidth())) {
          shiftToLarge = true;
          return a;
        }
        return a << b;
      });
  return shiftToLarge ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.ShiftRightArithmetic
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ShiftRightArithmeticOp::fold(
    spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
  // x >> 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
    return getOperand1();
  }

  // Unfortunately, due to the undefined behavior described below, we cannot
  // fold a 0 or -1 Base.

  // Results are computed per component, and within each component, per bit...
  //
  // The result is undefined if Shift is greater than or equal to the bit width
  // of the components of Base.
  //
  // So we can use APInt's ashr method, but we must not fold when the behavior
  // is undefined.
  bool shiftToLarge = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (shiftToLarge || b.uge(a.getBitWidth())) {
          shiftToLarge = true;
          return a;
        }
        return a.ashr(b);
      });
  return shiftToLarge ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.ShiftRightLogical
//===----------------------------------------------------------------------===//

OpFoldResult spirv::ShiftRightLogicalOp::fold(
    spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
  // x >> 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
    return getOperand1();
  }

  // Unfortunately, due to the undefined behavior described below, we cannot
  // fold a zero Base.

  // Results are computed per component, and within each component, per bit...
  //
  // The result is undefined if Shift is greater than or equal to the bit width
  // of the components of Base.
  //
  // So we can use APInt's lshr method, but we must not fold when the behavior
  // is undefined.
  bool shiftToLarge = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (shiftToLarge || b.uge(a.getBitWidth())) {
          shiftToLarge = true;
          return a;
        }
        return a.lshr(b);
      });
  return shiftToLarge ? Attribute() : res;
}

//===----------------------------------------------------------------------===//
// spirv.BitwiseAndOp
//===----------------------------------------------------------------------===//

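// An illustrative example for the UConvert special case handled below,
// assuming an i8 value zero-extended to i32 and masked with 0xFF:
//
//   %e = spirv.UConvert %x : i8 to i32
//   %m = spirv.BitwiseAnd %e, %c255 : i32
//
// Here %m folds to %e, because the mask keeps only bits that the
// zero-extension already guarantees.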
OpFoldResult
spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
  // x & x -> x
  if (getOperand1() == getOperand2()) {
    return getOperand1();
  }

  APInt rhsMask;
  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
    // x & 0 -> 0
    if (rhsMask.isZero())
      return getOperand2();

    // x & <all ones> -> x
    if (rhsMask.isAllOnes())
      return getOperand1();

    // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
    if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
      int valueBits =
          getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
      if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
        return getOperand1();
    }
  }

  // According to the SPIR-V spec:
  //
  // Type is a scalar or vector of integer type.
  // Results are computed per component, and within each component, per bit.
  // So we can use the APInt & method.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a & b; });
}

//===----------------------------------------------------------------------===//
// spirv.BitwiseOrOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
  // x | x -> x
  if (getOperand1() == getOperand2()) {
    return getOperand1();
  }

  APInt rhsMask;
  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
    // x | 0 -> x
    if (rhsMask.isZero())
      return getOperand1();

    // x | <all ones> -> <all ones>
    if (rhsMask.isAllOnes())
      return getOperand2();
  }

  // According to the SPIR-V spec:
  //
  // Type is a scalar or vector of integer type.
  // Results are computed per component, and within each component, per bit.
  // So we can use the APInt | method.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a | b; });
}

//===----------------------------------------------------------------------===//
// spirv.BitwiseXorOp
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
  // x ^ 0 -> x
  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
    return getOperand1();
  }

  // x ^ x -> 0
  if (getOperand1() == getOperand2())
    return Builder(getContext()).getZeroAttr(getType());

  // According to the SPIR-V spec:
  //
  // Type is a scalar or vector of integer type.
  // Results are computed per component, and within each component, per bit.
  // So we can use the APInt ^ method.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a ^ b; });
}

//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//

namespace {
// Blocks from the given `spirv.mlir.selection` operation must satisfy the
// following layout:
//
//       +-------------------------------------------------+
//       |                  header block                   |
//       | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
//       +-------------------------------------------------+
//                           /         \
//                          ...
//
//
//   +--------------------------+    +--------------------------+
//   |          case #0         |    |          case #1         |
//   | spirv.Store %ptr %value0 |    | spirv.Store %ptr %value1 |
//   | spirv.Branch ^merge      |    | spirv.Branch ^merge      |
//   +--------------------------+    +--------------------------+
//
//
//                          ...
//                           \         /
//                                v
//                         +-------------+
//                         | merge block |
//                         +-------------+
//
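// When both cases merely store different values to the same pointer, the whole
// construct is rewritten into a select followed by a single store
// (illustrative IR):
//
//   %sel = spirv.Select %cond, %value0, %value1 : i1, i32
//   spirv.Store "Function" %ptr, %sel : i32
//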
struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
                                PatternRewriter &rewriter) const override {
    Operation *op = selectionOp.getOperation();
    Region &body = op->getRegion(0);
    // The verifier allows an empty region for `spirv.mlir.selection`.
    if (body.empty()) {
      return failure();
    }

    // Check that the region consists of 4 blocks:
    // header block, `true` block, `false` block, and merge block.
    if (llvm::range_size(body) != 4) {
      return failure();
    }

    Block *headerBlock = selectionOp.getHeaderBlock();
    if (!onlyContainsBranchConditionalOp(headerBlock)) {
      return failure();
    }

    auto brConditionalOp =
        cast<spirv::BranchConditionalOp>(headerBlock->front());

    Block *trueBlock = brConditionalOp.getSuccessor(0);
    Block *falseBlock = brConditionalOp.getSuccessor(1);
    Block *mergeBlock = selectionOp.getMergeBlock();

    if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
      return failure();

    Value trueValue = getSrcValue(trueBlock);
    Value falseValue = getSrcValue(falseBlock);
    Value ptrValue = getDstPtr(trueBlock);
    auto storeOpAttributes =
        cast<spirv::StoreOp>(trueBlock->front())->getAttrs();

    auto selectOp = rewriter.create<spirv::SelectOp>(
        selectionOp.getLoc(), trueValue.getType(),
        brConditionalOp.getCondition(), trueValue, falseValue);
    rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
                                    selectOp.getResult(), storeOpAttributes);

    // `spirv.mlir.selection` is not needed anymore.
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Checks that the given blocks follow these rules:
  // 1. Each conditional block consists of two operations: the first operation
  //    is a `spirv.Store` and the last operation is a `spirv.Branch`.
  // 2. Each `spirv.Store` uses the same pointer and the same memory
  //    attributes.
  // 3. Control flow goes into the given merge block from the given
  //    conditional blocks.
  LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
                                         Block *mergeBlock) const;

  bool onlyContainsBranchConditionalOp(Block *block) const {
    return llvm::hasSingleElement(*block) &&
           isa<spirv::BranchConditionalOp>(block->front());
  }

  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
    return lhs->getDiscardableAttrDictionary() ==
               rhs->getDiscardableAttrDictionary() &&
           lhs.getProperties() == rhs.getProperties();
  }

  // Returns the source value for the given block.
  Value getSrcValue(Block *block) const {
    auto storeOp = cast<spirv::StoreOp>(block->front());
    return storeOp.getValue();
  }

  // Returns the destination pointer for the given block.
  Value getDstPtr(Block *block) const {
    auto storeOp = cast<spirv::StoreOp>(block->front());
    return storeOp.getPtr();
  }
};

LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
    Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
  // Each block must consist of 2 operations.
  if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
    return failure();
  }

  auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
  auto trueBrBranchOp =
      dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
  auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
  auto falseBrBranchOp =
      dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));

  if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
      !falseBrBranchOp) {
    return failure();
  }

  // Check that the stored value's type is valid for `spirv.SelectOp`.
  // According to the SPIR-V spec:
  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
  // Starting with version 1.4, Result Type can additionally be a composite
  // type other than a vector."
  bool isScalarOrVector =
      llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
          .isScalarOrVector();

  // Check that each `spirv.Store` uses the same pointer, the same memory
  // access attributes, and a valid type for the stored value.
  if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
      !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
    return failure();
  }

  if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
      (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
    return failure();
  }

  return success();
}
} // namespace

void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<ConvertSelectionOpToSelect>(context);
}