1 | //===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // These rewriters lower from the TOSA dialect to the MLProgram dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h" |
14 | #include "mlir/Dialect/MLProgram/IR/MLProgram.h" |
15 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
16 | #include "mlir/IR/IRMapping.h" |
17 | #include "mlir/IR/PatternMatch.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace tosa; |
21 | namespace { |
22 | |
23 | class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> { |
24 | public: |
25 | using OpRewritePattern<tosa::VariableOp>::OpRewritePattern; |
26 | |
27 | LogicalResult matchAndRewrite(tosa::VariableOp op, |
28 | PatternRewriter &rewriter) const final { |
29 | auto variableType = tosa::getVariableType(op); |
30 | auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>( |
31 | op.getLoc(), op.getName(), variableType, /*is_mutable=*/true, |
32 | op.getInitialValueAttr(), /*sym_visibility=*/nullptr); |
33 | newVariable.setPrivate(); |
34 | rewriter.replaceOp(op, newVariable); |
35 | return success(); |
36 | } |
37 | }; |
38 | |
39 | class VariableWriteOpConverter |
40 | : public OpRewritePattern<tosa::VariableWriteOp> { |
41 | public: |
42 | using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern; |
43 | |
44 | LogicalResult matchAndRewrite(tosa::VariableWriteOp op, |
45 | PatternRewriter &rewriter) const final { |
46 | auto globalSymbolRef = |
47 | SymbolRefAttr::get(rewriter.getContext(), op.getName()); |
48 | auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>( |
49 | op.getLoc(), globalSymbolRef, op.getInput1()); |
50 | rewriter.replaceOp(op, newVariableWrite); |
51 | return success(); |
52 | } |
53 | }; |
54 | |
55 | class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> { |
56 | public: |
57 | using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern; |
58 | |
59 | LogicalResult matchAndRewrite(tosa::VariableReadOp op, |
60 | PatternRewriter &rewriter) const final { |
61 | auto globalSymbolRef = |
62 | SymbolRefAttr::get(rewriter.getContext(), op.getName()); |
63 | auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>( |
64 | op.getLoc(), op.getType(), globalSymbolRef); |
65 | rewriter.replaceOp(op, newVariableRead); |
66 | |
67 | return success(); |
68 | } |
69 | }; |
70 | |
71 | } // namespace |
72 | |
73 | void mlir::tosa::populateTosaToMLProgramConversionPatterns( |
74 | RewritePatternSet *patterns) { |
75 | patterns->add<VariableOpConverter, VariableWriteOpConverter, |
76 | VariableReadOpConverter>(arg: patterns->getContext()); |
77 | } |
78 | |