1//===- TosaInferShapes.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Propagate shapes forward along TOSA operations to resolve dynamic shape
// operations.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
24
25namespace mlir {
26namespace tosa {
27#define GEN_PASS_DEF_TOSAINFERSHAPES
28#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
29} // namespace tosa
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::tosa;
34
35namespace {
36
37// Check whether this use case is replaceable. We define an op as
38// being replaceable if it is used by a TosaOp, or an op with a
39// type-inference related interface.
40// When a non-replaceable use is encountered, the value is wrapped in a
41// cast back to the original type after inference.
42bool isReplaceableUser(Operation *user) {
43 // Handle unregistered dialects.
44 if (!user->getDialect())
45 return false;
46
47 return user->getDialect()->getNamespace() ==
48 TosaDialect::getDialectNamespace() ||
49 isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
50}
51
52// During type propagation, the types of values in the operator graph are
53// updated. For the tosa.while_loop operation, types are speculatively updated
54// within the body region to determine the output type of the while_loop. This
55// process is performed until a fixed point is reached, then the types are
56// reverted.
57//
58// This class encapsulates the state information needed to perform the reversion
59// process or to commit to the final changes.
60class TypeModificationState {
61public:
62 TypeModificationState() = default;
63
64 ~TypeModificationState() {
65 // Ensure the recorded modifications are either committed or reverted.
66 assert(oldTypes.empty() && "unhandled type modifications");
67 }
68
69 // Update the state of the value and record the old type.
70 void setType(Value value, Type type) {
71 if (value.getType() != type) {
72 oldTypes.emplace_back(value, value.getType());
73 value.setType(type);
74 }
75 }
76
77 // Revert changes made to the types in the IR by setting all the affected
78 // values to their old types.
79 void revert() {
80 // Otherwise revert the changes.
81 for (auto [value, type] : oldTypes)
82 value.setType(type);
83
84 oldTypes.clear();
85 }
86
87 // Commit the changes to the types in the IR.
88 // This requires inserting tensor.cast operations to mediate the newly
89 // inferred result types with users that do not support type inference.
90 void commit() {
91 // For each use whose type changed, cast the value with the new type back to
92 // the old type.
93 for (auto [value, oldType] : oldTypes) {
94 for (auto &use : value.getUses()) {
95 if (isReplaceableUser(use.getOwner()))
96 continue;
97
98 OpBuilder builder(value.getContext());
99 builder.setInsertionPoint(use.getOwner());
100
101 Location loc = value.getLoc();
102 use.set(builder.create<tensor::CastOp>(loc, oldType, value));
103 }
104 }
105
106 oldTypes.clear();
107 }
108
109private:
110 // A record of each value whose type was updated along with that value's
111 // previous type.
112 llvm::SmallVector<std::pair<Value, Type>> oldTypes;
113};
114
115void propagateShapesInRegion(Region &region, TypeModificationState &state);
116
117void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
118 IfOp ifOp = dyn_cast<IfOp>(op);
119 if (!ifOp)
120 return;
121
122 for (auto &region : op.getRegions()) {
123 Block &frontBlock = region.front();
124 if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
125 return;
126
127 for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
128 auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
129 auto blockArg = frontBlock.getArgument(i: i - 1);
130 auto oldType = cast<ShapedType>(blockArg.getType());
131
132 if (inferredTy.hasRank()) {
133 Type newType = oldType.clone(inferredTy.getShape());
134 state.setType(value: blockArg, type: newType);
135 }
136 }
137
138 for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
139 ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
140 type: ifOp.getOperand(i + 1).getType());
141 ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
142 type: frontBlock.getArgument(i).getType());
143 ValueKnowledge joinedKnowledge =
144 ValueKnowledge::join(lhs: operandKnowledge, rhs: blockKnowledge);
145 if (!joinedKnowledge)
146 continue;
147 state.setType(value: frontBlock.getArgument(i), type: joinedKnowledge.getType());
148 }
149
150 propagateShapesInRegion(region, state);
151 }
152}
153
154void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
155 WhileOp whileOp = dyn_cast<WhileOp>(op);
156 if (!whileOp)
157 return;
158
159 // Determine what the expected argument types are to the cond/body blocks.
160 // The expected arguments should be compatible with ever iteration of the
161 // loop body / condition for tosa.while.
162 SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
163
164 bool hasNewTypes = true;
165 while (hasNewTypes) {
166 TypeModificationState localState;
167
168 // Set types on the block args.
169 Region &bodyRegion = op.getRegion(index: 1);
170 Block &block = bodyRegion.front();
171 for (int i = 0, s = argTypes.size(); i < s; i++) {
172 localState.setType(value: block.getArgument(i), type: argTypes[i]);
173 }
174
175 // Propagate to the end.
176 propagateShapesInRegion(region&: bodyRegion, state&: localState);
177
178 // Find all the tosa yield types and verify there is a single one.
179 llvm::SmallVector<YieldOp> yieldOps;
180 for (auto &block : bodyRegion)
181 if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
182 yieldOps.push_back(yieldOp);
183
184 assert(yieldOps.size() == 1 && "missing or non-unique yield op");
185 // Using the new tosa.yield operand types, infer the new subtypes.
186 llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
187 for (auto ty : argTypes) {
188 yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
189 }
190
191 for (auto yieldOp : yieldOps) {
192 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
193 auto newKnowledge =
194 ValueKnowledge::getKnowledgeFromType(it.value().getType());
195 yieldTypeInfo[it.index()] =
196 ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
197 }
198 }
199
200 // This should never happen.
201 if (yieldTypeInfo.size() != argTypes.size()) {
202 op.emitWarning(message: "has a tosa.yield with the incorrect number of operands");
203 return;
204 }
205
206 // Determine the new block args and see if any changed.
207 hasNewTypes = false;
208 for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
209 Type newType = yieldTypeInfo[i].getType();
210 hasNewTypes |= (newType != argTypes[i]);
211 argTypes[i] = newType;
212 }
213
214 // Revert all changes made during the speculative part of the algorithm.
215 localState.revert();
216 }
217
218 // We now set the block arguments according to the most recent shape
219 // inference results. This gives us the block arg types for the next
220 // iteration.
221 for (auto &region : op.getRegions()) {
222 for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
223 state.setType(value: region.front().getArgument(i), type: argTypes[i]);
224 }
225
226 propagateShapesInRegion(region, state);
227 }
228}
229
230void propagateShapesInRegion(Region &region, TypeModificationState &state) {
231 for (auto &block : region) {
232 for (Operation &op : block) {
233 if (!op.getDialect() ||
234 op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
235 continue;
236
237 propagateShapesToTosaIf(op, state);
238 propagateShapesToTosaWhile(op, state);
239
240 InferShapedTypeOpInterface shapeInterface =
241 dyn_cast<InferShapedTypeOpInterface>(op);
242 if (!shapeInterface)
243 continue;
244
245 SmallVector<ShapedTypeComponents> returnedShapes;
246
247 if (shapeInterface
248 .inferReturnTypeComponents(
249 op.getContext(), op.getLoc(), op.getOperands(),
250 op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
251 op.getRegions(), returnedShapes)
252 .succeeded()) {
253 for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
254 Value result = std::get<0>(it);
255 ShapedTypeComponents predictedShape = std::get<1>(it);
256
257 // Determine the knowledge based on the output type.
258 // TODO: should also query WIP type probably
259 Type resultTy = result.getType();
260 auto currentKnowledge =
261 ValueKnowledge::getKnowledgeFromType(resultTy);
262
263 // Compute the knowledge based on the inferred type.
264 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
265 inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
266 inferredKnowledge.hasRank = predictedShape.hasRank();
267 if (predictedShape.hasRank()) {
268 for (auto dim : predictedShape.getDims()) {
269 inferredKnowledge.sizes.push_back(dim);
270 }
271 }
272
273 // Compute the new type based on the joined version.
274 auto newKnowledge =
275 ValueKnowledge::join(currentKnowledge, inferredKnowledge);
276 if (!newKnowledge)
277 continue;
278
279 // Set new type
280 state.setType(result, newKnowledge.getType());
281 }
282 }
283 }
284 }
285}
286
287/// Pass that performs shape propagation across TOSA operations. This includes
288/// migrating to within the regions of if/while operations.
289struct TosaInferShapes
290 : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
291public:
292 void runOnOperation() override {
293 func::FuncOp func = getOperation();
294 TypeModificationState state;
295 propagateShapesInRegion(func.getBody(), state);
296 state.commit();
297 }
298};
299} // namespace
300
301std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
302 return std::make_unique<TosaInferShapes>();
303}
304

source code of mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp