1//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the SCF dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Dialect/Tosa/IR/TosaOps.h"
17#include "mlir/IR/PatternMatch.h"
18
19using namespace mlir;
20using namespace tosa;
21
22static void inlineIfCase(Region &srcRegion, Region &dstRegion,
23 OperandRange operands, PatternRewriter &rewriter) {
24 rewriter.cloneRegionBefore(region&: srcRegion, before: &dstRegion.front());
25 rewriter.eraseBlock(block: &dstRegion.back());
26
27 Block *headBlock = &dstRegion.front();
28 for (auto it : llvm::zip(t: headBlock->getArguments(), u&: operands))
29 std::get<0>(t&: it).replaceAllUsesWith(newValue: std::get<1>(t&: it));
30
31 auto yield = cast<YieldOp>(Val: headBlock->getTerminator());
32 rewriter.setInsertionPoint(yield);
33 rewriter.create<scf::YieldOp>(location: yield.getLoc(), args: yield.getInputs());
34 rewriter.eraseOp(op: yield);
35
36 headBlock->eraseArguments(start: 0, num: headBlock->getNumArguments());
37}
38
39static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
40 PatternRewriter &rewriter, bool isCond) {
41 rewriter.cloneRegionBefore(region&: srcRegion, before: &dstRegion.back());
42 rewriter.eraseBlock(block: &dstRegion.back());
43
44 Block *headBlock = &dstRegion.front();
45
46 auto yield = cast<YieldOp>(Val: headBlock->getTerminator());
47 rewriter.setInsertionPoint(yield);
48 if (isCond) {
49 auto condition =
50 rewriter.create<tensor::ExtractOp>(location: yield.getLoc(), args: yield.getOperand(i: 0));
51 rewriter.create<scf::ConditionOp>(location: yield.getLoc(), args&: condition,
52 args: headBlock->getArguments());
53 } else {
54 rewriter.setInsertionPoint(yield);
55 rewriter.create<scf::YieldOp>(location: yield.getLoc(), args: yield.getInputs());
56 }
57 rewriter.eraseOp(op: yield);
58}
59
60namespace {
61
62class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
63public:
64 using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
65
66 LogicalResult matchAndRewrite(tosa::IfOp op,
67 PatternRewriter &rewriter) const final {
68 auto condition =
69 rewriter.create<tensor::ExtractOp>(location: op.getLoc(), args: op.getCondition());
70 auto newIf = rewriter.create<scf::IfOp>(location: op.getLoc(), args: op.getResultTypes(),
71 args&: condition, args: true);
72
73 inlineIfCase(srcRegion&: op.getThenGraph(), dstRegion&: newIf.getThenRegion(), operands: op.getInputList(),
74 rewriter);
75 inlineIfCase(srcRegion&: op.getElseGraph(), dstRegion&: newIf.getElseRegion(), operands: op.getInputList(),
76 rewriter);
77
78 rewriter.replaceOp(op, newValues: newIf.getResults());
79 return success();
80 }
81};
82
83class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
84 static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
85 int64_t dim) {
86 return builder.createOrFold<tensor::DimOp>(location: loc, args&: tensor, args&: dim);
87 }
88
89 static Value createIndexConst(OpBuilder &builder, Location loc,
90 int64_t value) {
91 return builder.create<arith::ConstantIndexOp>(location: loc, args&: value);
92 }
93
94public:
95 using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
96
97 LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
98 PatternRewriter &rewriter) const final {
99 auto valuesIn = scatter.getValuesIn();
100 auto indices = scatter.getIndices();
101 auto input = scatter.getInput();
102 auto loc = scatter.getLoc();
103
104 // N, W, C are chosen to match the TOSA spec
105 auto dimN = createTensorDim(builder&: rewriter, loc, tensor: input, dim: 0);
106 auto dimW = createTensorDim(builder&: rewriter, loc, tensor: input, dim: 1);
107 auto dimC = createTensorDim(builder&: rewriter, loc, tensor: input, dim: 2);
108
109 auto zero = createIndexConst(builder&: rewriter, loc, value: 0);
110 auto one = createIndexConst(builder&: rewriter, loc, value: 1);
111
112 // Loop bounds
113 auto lbs = llvm::SmallVector<Value>(2, zero);
114 auto steps = llvm::SmallVector<Value>(2, one);
115 auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
116
117 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
118 ValueRange args) -> scf::ValueVector {
119 auto n = ivs[0];
120
121 // Read the index and cast it to index type
122 auto index = builder.create<tensor::ExtractOp>(location: loc, args&: indices, args&: ivs);
123 auto castIndex = builder.create<arith::IndexCastOp>(
124 location: loc, args: builder.getIndexType(), args&: index);
125
126 // Offset, sizes, and strides for the input tensor
127 auto inputOffset = llvm::to_vector(Range&: ivs);
128 inputOffset.push_back(Elt: zero);
129
130 llvm::SmallVector<Value> sizes = {one, one, dimC};
131 llvm::SmallVector<Value> strides = {one, one, one};
132
133 auto slice = builder.create<tensor::ExtractSliceOp>(
134 location: loc, args&: input, args&: inputOffset, args&: sizes, args&: strides);
135
136 // Insert the slice into the output accumulator tensor.
137 llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
138 auto updated = builder.create<tensor::InsertSliceOp>(
139 location: loc, args&: slice, args: args[0], args&: outputOffset, args&: sizes, args&: strides);
140
141 return {updated};
142 };
143
144 auto loops = scf::buildLoopNest(builder&: rewriter, loc, lbs, ubs, steps,
145 iterArgs: ValueRange{valuesIn}, bodyBuilder: buildBody);
146 rewriter.replaceOp(op: scatter, newValues: loops.results);
147
148 return success();
149 }
150};
151
152class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
153public:
154 using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
155
156 LogicalResult matchAndRewrite(tosa::WhileOp op,
157 PatternRewriter &rewriter) const final {
158 auto newWhile = rewriter.create<scf::WhileOp>(
159 location: op.getLoc(), args: op.getResultTypes(), args: op.getInputList());
160 rewriter.createBlock(parent: &newWhile.getBefore());
161 rewriter.createBlock(parent: &newWhile.getAfter());
162
163 inlineWhileCase(srcRegion&: op.getCondGraph(), dstRegion&: newWhile.getBefore(), rewriter, isCond: true);
164 inlineWhileCase(srcRegion&: op.getBodyGraph(), dstRegion&: newWhile.getAfter(), rewriter, isCond: false);
165
166 rewriter.replaceOp(op, newValues: newWhile.getResults());
167
168 return success();
169 }
170};
171
172} // namespace
173
174void mlir::tosa::populateTosaToSCFConversionPatterns(
175 RewritePatternSet *patterns) {
176 patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
177 arg: patterns->getContext());
178}
179

source code of mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp