//===- ExpandOps.cpp - Expand MemRef ops before lowering to LLVM ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
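// This file implements MemRef transformations that expand operations to help
// the lowering to LLVM. Currently implemented transformations rewrite
// `memref.atomic_rmw` with floating-point min/max kinds into
// `memref.generic_atomic_rmw`, and `memref.reshape` with a statically-shaped
// shape operand into `memref.reinterpret_cast`.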
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/MemRef/Transforms/Passes.h"
16
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Arith/Transforms/Passes.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
21#include "mlir/IR/TypeUtilities.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/STLExtras.h"
24
25namespace mlir {
26namespace memref {
27#define GEN_PASS_DEF_EXPANDOPS
28#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29} // namespace memref
30} // namespace mlir
31
32using namespace mlir;
33
34namespace {
35
36/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
37/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
38/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
39/// code.
40///
41/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
42///
43/// will be lowered to
44///
45/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
46/// ^bb0(%current: f32):
47/// %1 = arith.maximumf %current, %fval : f32
48/// memref.atomic_yield %1 : f32
49/// }
50struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
51public:
52 using OpRewritePattern::OpRewritePattern;
53
54 LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
55 PatternRewriter &rewriter) const final {
56 auto loc = op.getLoc();
57 auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
58 loc, op.getMemref(), op.getIndices());
59 OpBuilder bodyBuilder =
60 OpBuilder::atBlockEnd(block: genericOp.getBody(), listener: rewriter.getListener());
61
62 Value lhs = genericOp.getCurrentValue();
63 Value rhs = op.getValue();
64
65 Value arithOp =
66 mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
67 bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
68
69 rewriter.replaceOp(op, genericOp.getResult());
70 return success();
71 }
72};
73
74/// Converts `memref.reshape` that has a target shape of a statically-known
75/// size to `memref.reinterpret_cast`.
76struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
77public:
78 using OpRewritePattern::OpRewritePattern;
79
80 LogicalResult matchAndRewrite(memref::ReshapeOp op,
81 PatternRewriter &rewriter) const final {
82 auto shapeType = cast<MemRefType>(op.getShape().getType());
83 if (!shapeType.hasStaticShape())
84 return failure();
85
86 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
87 SmallVector<OpFoldResult, 4> sizes, strides;
88 sizes.resize(rank);
89 strides.resize(rank);
90
91 Location loc = op.getLoc();
92 Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
93 for (int i = rank - 1; i >= 0; --i) {
94 Value size;
95 // Load dynamic sizes from the shape input, use constants for static dims.
96 if (op.getType().isDynamicDim(i)) {
97 Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
98 size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
99 if (!isa<IndexType>(size.getType()))
100 size = rewriter.create<arith::IndexCastOp>(
101 loc, rewriter.getIndexType(), size);
102 sizes[i] = size;
103 } else {
104 auto sizeAttr = rewriter.getIndexAttr(value: op.getType().getDimSize(i));
105 size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
106 sizes[i] = sizeAttr;
107 }
108 strides[i] = stride;
109 if (i > 0)
110 stride = rewriter.create<arith::MulIOp>(loc, stride, size);
111 }
112 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
113 op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
114 sizes, strides);
115 return success();
116 }
117};
118
119struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
120 void runOnOperation() override {
121 MLIRContext &ctx = getContext();
122
123 RewritePatternSet patterns(&ctx);
124 memref::populateExpandOpsPatterns(patterns);
125 ConversionTarget target(ctx);
126
127 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
128 target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
129 [](memref::AtomicRMWOp op) {
130 constexpr std::array shouldBeExpandedKinds = {
131 arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
132 arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
133 return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
134 });
135 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
136 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
137 });
138 if (failed(applyPartialConversion(getOperation(), target,
139 std::move(patterns))))
140 signalPassFailure();
141 }
142};
143
144} // namespace
145
146void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
147 patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
148 arg: patterns.getContext());
149}
150
151std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
152 return std::make_unique<ExpandOpsPass>();
153}
154

source code of mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp