1//===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/EmitC/IR/EmitC.h"
10#include "mlir/Dialect/EmitC/Transforms/Passes.h"
11#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
12#include "mlir/IR/Attributes.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Transforms/WalkPatternRewriteDriver.h"
17
18using namespace mlir;
19using namespace emitc;
20
21namespace mlir {
22namespace emitc {
23#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
24#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
25
26namespace {
27struct WrapFuncInClassPass
28 : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
29 using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
30 void runOnOperation() override {
31 Operation *rootOp = getOperation();
32
33 RewritePatternSet patterns(&getContext());
34 populateFuncPatterns(patterns, namedAttribute);
35
36 walkAndApplyPatterns(op: rootOp, patterns: std::move(patterns));
37 }
38};
39
40} // namespace
41} // namespace emitc
42} // namespace mlir
43
44class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
45public:
46 WrapFuncInClass(MLIRContext *context, StringRef attrName)
47 : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
48
49 LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
50 PatternRewriter &rewriter) const override {
51
52 auto className = funcOp.getSymNameAttr().str() + "Class";
53 ClassOp newClassOp = rewriter.create<ClassOp>(location: funcOp.getLoc(), args&: className);
54
55 SmallVector<std::pair<StringAttr, TypeAttr>> fields;
56 rewriter.createBlock(parent: &newClassOp.getBody());
57 rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
58
59 auto argAttrs = funcOp.getArgAttrs();
60 for (auto [idx, val] : llvm::enumerate(First: funcOp.getArguments())) {
61 StringAttr fieldName;
62 Attribute argAttr = nullptr;
63
64 fieldName = rewriter.getStringAttr(bytes: "fieldName" + std::to_string(val: idx));
65 if (argAttrs && idx < argAttrs->size())
66 argAttr = (*argAttrs)[idx];
67
68 TypeAttr typeAttr = TypeAttr::get(type: val.getType());
69 fields.push_back(Elt: {fieldName, typeAttr});
70 rewriter.create<emitc::FieldOp>(location: funcOp.getLoc(), args&: fieldName, args&: typeAttr,
71 args&: argAttr);
72 }
73
74 rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
75 FunctionType funcType = funcOp.getFunctionType();
76 Location loc = funcOp.getLoc();
77 FuncOp newFuncOp =
78 rewriter.create<emitc::FuncOp>(location: loc, args: ("execute"), args&: funcType);
79
80 rewriter.createBlock(parent: &newFuncOp.getBody());
81 newFuncOp.getBody().takeBody(other&: funcOp.getBody());
82
83 rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
84 std::vector<Value> newArguments;
85 newArguments.reserve(n: fields.size());
86 for (auto &[fieldName, attr] : fields) {
87 GetFieldOp arg =
88 rewriter.create<emitc::GetFieldOp>(location: loc, args: attr.getValue(), args&: fieldName);
89 newArguments.push_back(x: arg);
90 }
91
92 for (auto [oldArg, newArg] :
93 llvm::zip(t: newFuncOp.getArguments(), u&: newArguments)) {
94 rewriter.replaceAllUsesWith(from: oldArg, to: newArg);
95 }
96
97 llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
98 if (failed(Result: newFuncOp.eraseArguments(argIndices: argsToErase)))
99 newFuncOp->emitOpError(message: "failed to erase all arguments using BitVector");
100
101 rewriter.replaceOp(op: funcOp, newOp: newClassOp);
102 return success();
103 }
104
105private:
106 StringRef attributeName;
107};
108
109void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
110 StringRef namedAttribute) {
111 patterns.add<WrapFuncInClass>(arg: patterns.getContext(), args&: namedAttribute);
112}
113

source code of mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp