1//===- Bufferize.cpp - Bufferization for Arith ops ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/Transforms/Passes.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
13#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17
18namespace mlir {
19namespace arith {
20#define GEN_PASS_DEF_ARITHBUFFERIZEPASS
21#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
22} // namespace arith
23} // namespace mlir
24
25using namespace mlir;
26using namespace bufferization;
27
28namespace {
29/// Pass to bufferize Arith ops.
30struct ArithBufferizePass
31 : public arith::impl::ArithBufferizePassBase<ArithBufferizePass> {
32 using ArithBufferizePassBase::ArithBufferizePassBase;
33
34 ArithBufferizePass(uint64_t alignment = 0, bool constantOpOnly = false)
35 : constantOpOnly(constantOpOnly) {
36 this->alignment = alignment;
37 }
38
39 void runOnOperation() override {
40 BufferizationOptions options = getPartialBufferizationOptions();
41 if (constantOpOnly) {
42 options.opFilter.allowOperation<arith::ConstantOp>();
43 } else {
44 options.opFilter.allowDialect<arith::ArithDialect>();
45 }
46 options.bufferAlignment = alignment;
47
48 if (failed(bufferizeOp(getOperation(), options)))
49 signalPassFailure();
50 }
51
52 void getDependentDialects(DialectRegistry &registry) const override {
53 registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
54 arith::ArithDialect>();
55 arith::registerBufferizableOpInterfaceExternalModels(registry);
56 }
57
58private:
59 bool constantOpOnly;
60};
61} // namespace
62
63std::unique_ptr<Pass>
64mlir::arith::createConstantBufferizePass(uint64_t alignment) {
65 return std::make_unique<ArithBufferizePass>(args&: alignment,
66 /*constantOpOnly=*/args: true);
67}
68

// Source: mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp