1//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Shape/Transforms/Passes.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/Shape/IR/Shape.h"
14#include "mlir/IR/Builders.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/DialectConversion.h"
18
19namespace mlir {
20#define GEN_PASS_DEF_SHAPETOSHAPELOWERING
21#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
22} // namespace mlir
23
24using namespace mlir;
25using namespace mlir::shape;
26
27namespace {
28/// Converts `shape.num_elements` to `shape.reduce`.
29struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
30public:
31 using OpRewritePattern::OpRewritePattern;
32
33 LogicalResult matchAndRewrite(NumElementsOp op,
34 PatternRewriter &rewriter) const final;
35};
36} // namespace
37
38LogicalResult
39NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
40 PatternRewriter &rewriter) const {
41 auto loc = op.getLoc();
42 Type valueType = op.getResult().getType();
43 Value init = op->getDialect()
44 ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
45 valueType, loc)
46 ->getResult(0);
47 ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);
48
49 // Generate reduce operator.
50 Block *body = reduce.getBody();
51 OpBuilder b = OpBuilder::atBlockEnd(block: body);
52 Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
53 body->getArgument(2));
54 b.create<shape::YieldOp>(loc, product);
55
56 rewriter.replaceOp(op, reduce.getResult());
57 return success();
58}
59
60namespace {
61struct ShapeToShapeLowering
62 : public impl::ShapeToShapeLoweringBase<ShapeToShapeLowering> {
63 void runOnOperation() override;
64};
65} // namespace
66
67void ShapeToShapeLowering::runOnOperation() {
68 MLIRContext &ctx = getContext();
69
70 RewritePatternSet patterns(&ctx);
71 populateShapeRewritePatterns(patterns);
72
73 ConversionTarget target(getContext());
74 target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
75 target.addIllegalOp<NumElementsOp>();
76 if (failed(mlir::applyPartialConversion(getOperation(), target,
77 std::move(patterns))))
78 signalPassFailure();
79}
80
81void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
82 patterns.add<NumElementsOpConverter>(arg: patterns.getContext());
83}
84
85std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
86 return std::make_unique<ShapeToShapeLowering>();
87}
88

source code of mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp