1//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Shape/Transforms/Passes.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/Shape/IR/Shape.h"
14#include "mlir/IR/Builders.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace mlir {
19#define GEN_PASS_DEF_SHAPETOSHAPELOWERINGPASS
20#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
21} // namespace mlir
22
23using namespace mlir;
24using namespace mlir::shape;
25
26namespace {
27/// Converts `shape.num_elements` to `shape.reduce`.
28struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
29public:
30 using OpRewritePattern::OpRewritePattern;
31
32 LogicalResult matchAndRewrite(NumElementsOp op,
33 PatternRewriter &rewriter) const final;
34};
35} // namespace
36
37LogicalResult
38NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
39 PatternRewriter &rewriter) const {
40 auto loc = op.getLoc();
41 Type valueType = op.getResult().getType();
42 Value init = op->getDialect()
43 ->materializeConstant(builder&: rewriter, value: rewriter.getIndexAttr(value: 1),
44 type: valueType, loc)
45 ->getResult(idx: 0);
46 ReduceOp reduce = rewriter.create<ReduceOp>(location: loc, args: op.getShape(), args&: init);
47
48 // Generate reduce operator.
49 Block *body = reduce.getBody();
50 OpBuilder b = OpBuilder::atBlockEnd(block: body);
51 Value product = b.create<MulOp>(location: loc, args&: valueType, args: body->getArgument(i: 1),
52 args: body->getArgument(i: 2));
53 b.create<shape::YieldOp>(location: loc, args&: product);
54
55 rewriter.replaceOp(op, newValues: reduce.getResult());
56 return success();
57}
58
59namespace {
60struct ShapeToShapeLowering
61 : public impl::ShapeToShapeLoweringPassBase<ShapeToShapeLowering> {
62 void runOnOperation() override;
63};
64} // namespace
65
66void ShapeToShapeLowering::runOnOperation() {
67 MLIRContext &ctx = getContext();
68
69 RewritePatternSet patterns(&ctx);
70 populateShapeRewritePatterns(patterns);
71
72 ConversionTarget target(getContext());
73 target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
74 target.addIllegalOp<NumElementsOp>();
75 if (failed(Result: mlir::applyPartialConversion(op: getOperation(), target,
76 patterns: std::move(patterns))))
77 signalPassFailure();
78}
79
80void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
81 patterns.add<NumElementsOpConverter>(arg: patterns.getContext());
82}
83

source code of mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp