1//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the SCF dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Dialect/Tosa/IR/TosaOps.h"
17#include "mlir/IR/IRMapping.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21using namespace mlir;
22using namespace tosa;
23
24static void inlineIfCase(Region &srcRegion, Region &dstRegion,
25 OperandRange operands, PatternRewriter &rewriter) {
26 rewriter.cloneRegionBefore(region&: srcRegion, before: &dstRegion.front());
27 rewriter.eraseBlock(block: &dstRegion.back());
28
29 Block *headBlock = &dstRegion.front();
30 for (auto it : llvm::zip(t: headBlock->getArguments(), u&: operands))
31 std::get<0>(t&: it).replaceAllUsesWith(newValue: std::get<1>(t&: it));
32
33 auto yield = cast<YieldOp>(headBlock->getTerminator());
34 rewriter.setInsertionPoint(yield);
35 rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
36 rewriter.eraseOp(op: yield);
37
38 headBlock->eraseArguments(start: 0, num: headBlock->getNumArguments());
39}
40
41static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
42 PatternRewriter &rewriter, bool isCond) {
43 rewriter.cloneRegionBefore(region&: srcRegion, before: &dstRegion.back());
44 rewriter.eraseBlock(block: &dstRegion.back());
45
46 Block *headBlock = &dstRegion.front();
47
48 auto yield = cast<YieldOp>(headBlock->getTerminator());
49 rewriter.setInsertionPoint(yield);
50 if (isCond) {
51 auto condition =
52 rewriter.create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0));
53 rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
54 headBlock->getArguments());
55 } else {
56 rewriter.setInsertionPoint(yield);
57 rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
58 }
59 rewriter.eraseOp(op: yield);
60}
61
62namespace {
63
64class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
65public:
66 using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
67
68 LogicalResult matchAndRewrite(tosa::IfOp op,
69 PatternRewriter &rewriter) const final {
70 auto condition =
71 rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
72 auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
73 condition, true);
74
75 inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
76 rewriter);
77 inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
78 rewriter);
79
80 rewriter.replaceOp(op, newIf.getResults());
81 return success();
82 }
83};
84
85class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
86 static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
87 int64_t dim) {
88 return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
89 }
90
91 static Value createIndexConst(OpBuilder &builder, Location loc,
92 int64_t value) {
93 return builder.create<arith::ConstantIndexOp>(location: loc, args&: value);
94 }
95
96public:
97 using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
98
99 LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
100 PatternRewriter &rewriter) const final {
101 auto valuesIn = scatter.getValuesIn();
102 auto indices = scatter.getIndices();
103 auto input = scatter.getInput();
104 auto loc = scatter.getLoc();
105
106 // N, W, C are chosen to match the TOSA spec
107 auto dimN = createTensorDim(builder&: rewriter, loc: loc, tensor: input, dim: 0);
108 auto dimW = createTensorDim(builder&: rewriter, loc: loc, tensor: input, dim: 1);
109 auto dimC = createTensorDim(builder&: rewriter, loc: loc, tensor: input, dim: 2);
110
111 auto zero = createIndexConst(builder&: rewriter, loc: loc, value: 0);
112 auto one = createIndexConst(builder&: rewriter, loc: loc, value: 1);
113
114 // Loop bounds
115 auto lbs = llvm::SmallVector<Value>(2, zero);
116 auto steps = llvm::SmallVector<Value>(2, one);
117 auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
118
119 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
120 ValueRange args) -> scf::ValueVector {
121 auto n = ivs[0];
122
123 // Read the index and cast it to index type
124 auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
125 auto castIndex = builder.create<arith::IndexCastOp>(
126 loc, builder.getIndexType(), index);
127
128 // Offset, sizes, and strides for the input tensor
129 auto inputOffset = llvm::to_vector(Range&: ivs);
130 inputOffset.push_back(Elt: zero);
131
132 llvm::SmallVector<Value> sizes = {one, one, dimC};
133 llvm::SmallVector<Value> strides = {one, one, one};
134
135 auto slice = builder.create<tensor::ExtractSliceOp>(
136 loc, input, inputOffset, sizes, strides);
137
138 // Insert the slice into the output accumulator tensor.
139 llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
140 auto updated = builder.create<tensor::InsertSliceOp>(
141 loc, slice, args[0], outputOffset, sizes, strides);
142
143 return {updated};
144 };
145
146 auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
147 ValueRange{valuesIn}, buildBody);
148 rewriter.replaceOp(scatter, loops.results);
149
150 return success();
151 }
152};
153
154class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
155public:
156 using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
157
158 LogicalResult matchAndRewrite(tosa::WhileOp op,
159 PatternRewriter &rewriter) const final {
160 auto newWhile = rewriter.create<scf::WhileOp>(
161 op.getLoc(), op.getResultTypes(), op.getInputs());
162 rewriter.createBlock(&newWhile.getBefore());
163 rewriter.createBlock(&newWhile.getAfter());
164
165 inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
166 inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
167
168 rewriter.replaceOp(op, newWhile.getResults());
169
170 return success();
171 }
172};
173
174} // namespace
175
176void mlir::tosa::populateTosaToSCFConversionPatterns(
177 RewritePatternSet *patterns) {
178 patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
179 arg: patterns->getContext());
180}
181

source code of mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp