1//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10
11#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12#include "mlir/Dialect/SCF/IR/SCF.h"
13#include "mlir/Dialect/Shape/IR/Shape.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17
18namespace mlir {
19#define GEN_PASS_DEF_CONVERTSHAPECONSTRAINTSPASS
20#include "mlir/Conversion/Passes.h.inc"
21} // namespace mlir
22
23using namespace mlir;
24
25namespace {
26#include "ShapeToStandard.cpp.inc"
27} // namespace
28
29namespace {
30class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
31public:
32 using OpRewritePattern::OpRewritePattern;
33 LogicalResult matchAndRewrite(shape::CstrRequireOp op,
34 PatternRewriter &rewriter) const override {
35 rewriter.create<cf::AssertOp>(location: op.getLoc(), args: op.getPred(), args: op.getMsgAttr());
36 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, args: true);
37 return success();
38 }
39};
40} // namespace
41
42void mlir::populateConvertShapeConstraintsConversionPatterns(
43 RewritePatternSet &patterns) {
44 patterns.add<CstrBroadcastableToRequire>(arg: patterns.getContext());
45 patterns.add<CstrEqToRequire>(arg: patterns.getContext());
46 patterns.add<ConvertCstrRequireOp>(arg: patterns.getContext());
47}
48
49namespace {
50// This pass eliminates shape constraints from the program, converting them to
51// eager (side-effecting) error handling code. After eager error handling code
52// is emitted, witnesses are satisfied, so they are replace with
53// `shape.const_witness true`.
54class ConvertShapeConstraints
55 : public impl::ConvertShapeConstraintsPassBase<ConvertShapeConstraints> {
56 void runOnOperation() override {
57 auto *func = getOperation();
58 auto *context = &getContext();
59
60 RewritePatternSet patterns(context);
61 populateConvertShapeConstraintsConversionPatterns(patterns);
62
63 if (failed(Result: applyPatternsGreedily(op: func, patterns: std::move(patterns))))
64 return signalPassFailure();
65 }
66};
67} // namespace
68

// Source: mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp