1//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10
11#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12#include "mlir/Dialect/SCF/IR/SCF.h"
13#include "mlir/Dialect/Shape/IR/Shape.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Pass/PassRegistry.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20namespace mlir {
21#define GEN_PASS_DEF_CONVERTSHAPECONSTRAINTS
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26
27namespace {
28#include "ShapeToStandard.cpp.inc"
29} // namespace
30
31namespace {
32class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
33public:
34 using OpRewritePattern::OpRewritePattern;
35 LogicalResult matchAndRewrite(shape::CstrRequireOp op,
36 PatternRewriter &rewriter) const override {
37 rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
38 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
39 return success();
40 }
41};
42} // namespace
43
44void mlir::populateConvertShapeConstraintsConversionPatterns(
45 RewritePatternSet &patterns) {
46 patterns.add<CstrBroadcastableToRequire>(patterns.getContext());
47 patterns.add<CstrEqToRequire>(patterns.getContext());
48 patterns.add<ConvertCstrRequireOp>(arg: patterns.getContext());
49}
50
51namespace {
52// This pass eliminates shape constraints from the program, converting them to
53// eager (side-effecting) error handling code. After eager error handling code
54// is emitted, witnesses are satisfied, so they are replace with
55// `shape.const_witness true`.
56class ConvertShapeConstraints
57 : public impl::ConvertShapeConstraintsBase<ConvertShapeConstraints> {
58 void runOnOperation() override {
59 auto *func = getOperation();
60 auto *context = &getContext();
61
62 RewritePatternSet patterns(context);
63 populateConvertShapeConstraintsConversionPatterns(patterns);
64
65 if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
66 return signalPassFailure();
67 }
68};
69} // namespace
70
71std::unique_ptr<Pass> mlir::createConvertShapeConstraintsPass() {
72 return std::make_unique<ConvertShapeConstraints>();
73}
74

source code of mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp