//===- TosaInferShapes.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Propagate shapes forward along TOSA operations to resolve dynamic shape
// operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAINFERSHAPESPASS
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

namespace {

// Check whether this use is replaceable. We define a use as being
// replaceable if its owner is a TosaOp or an op with a type-inference
// related interface.
// When a non-replaceable use is encountered, the value is wrapped in a
// cast back to the original type after inference.
bool canBeRefined(Operation *user) {
  if (!user->getDialect())
    return false;
  return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
         isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}
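
// As an illustration (a sketch, not taken from an actual test): suppose a
// value is refined from tensor<?xf32> to tensor<4xf32>, but one of its users
// cannot accept the refined type. TypeModificationState::commit() below then
// reroutes that use through a cast (op names here are hypothetical):
//
//   %0 = "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
//   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
//   "some.unknown_op"(%1) : (tensor<?xf32>) -> ()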

// During type propagation, the types of values in the operator graph are
// updated. For the tosa.while_loop operation, types are speculatively updated
// within the body region to determine the output type of the while_loop. This
// process is performed until a fixed point is reached, then the types are
// rolled back.
//
// This class encapsulates the state information needed to perform the
// rollback or to commit the final changes.
class TypeModificationState {
public:
  TypeModificationState() = default;

  ~TypeModificationState() {
    // Ensure the recorded modifications are either committed or rolled back.
    assert(oldTypes.empty() && "unhandled type modifications");
  }

  // Update the type of the value and record the old type.
  void setType(Value value, Type type) {
    if (value.getType() != type) {
      oldTypes.emplace_back(value, value.getType());
      value.setType(type);
    }
  }

  // Roll back changes made to the types in the IR by setting all the affected
  // values to their old types.
  void rollBack() {
    for (auto [value, type] : oldTypes)
      value.setType(type);

    oldTypes.clear();
  }

  // Commit the changes to the types in the IR.
  // This requires inserting tensor.cast operations to mediate the newly
  // inferred result types with users that do not support type inference.
  void commit() {
    // For each value whose type changed, cast it back to its old type for
    // users that cannot accept the new type.
    for (auto [value, oldType] : oldTypes) {
      // The call to 'use->set()' in the body of the loop below invalidates the
      // iterator used to traverse op uses, so it is important to make a copy
      // of these first.
      llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
          value.getUses(),
          [](OpOperand &use) -> OpOperand * {
            return &use;
          });

      // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
      // cached and reused by all consumers.
      tensor::CastOp castValue;

      // Traverse all uses.
      for (OpOperand *use : uses) {
        if (canBeRefined(use->getOwner()))
          continue;

        if (!castValue) {
          // Set the insertion point as far back as possible, since new
          // consumers of the 'tensor.cast' op generated in future iterations
          // are likely to be further up in the code due to the order in which
          // they appear in the use list.
          OpBuilder builder{value.getContext()};
          builder.setInsertionPointAfter(value.getDefiningOp());
          castValue =
              builder.create<tensor::CastOp>(value.getLoc(), oldType, value);
        }

        use->set(castValue);
      }
    }

    oldTypes.clear();
  }

private:
  // A record of each value whose type was updated along with that value's
  // previous type.
  llvm::SmallVector<std::pair<Value, Type>> oldTypes;
};
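
// Typical lifetime of a TypeModificationState (a usage sketch; 'value' and
// 'refinedType' are hypothetical):
//
//   TypeModificationState state;
//   state.setType(value, refinedType); // records old type, applies new one
//   ...
//   state.commit();    // keep the new types, inserting tensor.cast ops for
//                      // users that cannot be refined
//   // or: state.rollBack(); // restore every recorded value to its old type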

void propagateShapesInRegion(Region &region, TypeModificationState &state);

void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
  IfOp ifOp = dyn_cast<IfOp>(op);
  if (!ifOp)
    return;

  for (auto &region : op.getRegions()) {
    Block &frontBlock = region.front();
    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
      return;

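    // Note: operand 0 of tosa.cond_if is the condition, so operand i (i >= 1)
    // corresponds to block argument i - 1 of each region; hence the index
    // shifts below.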
    for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
      auto blockArg = frontBlock.getArgument(i - 1);
      auto oldType = cast<ShapedType>(blockArg.getType());

      if (inferredTy.hasRank()) {
        Type newType = oldType.clone(inferredTy.getShape());
        state.setType(blockArg, newType);
      }
    }

    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
          ifOp.getOperand(i + 1).getType());
      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
          frontBlock.getArgument(i).getType());
      ValueKnowledge joinedKnowledge =
          ValueKnowledge::join(operandKnowledge, blockKnowledge);
      if (!joinedKnowledge)
        continue;
      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
    }

    propagateShapesInRegion(region, state);
  }
}

void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
  WhileOp whileOp = dyn_cast<WhileOp>(op);
  if (!whileOp)
    return;

  // Determine what the expected argument types are for the cond/body blocks.
  // The expected arguments should be compatible with every iteration of the
  // loop body / condition for tosa.while.
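  // For example (a sketch): if a loop-carried value enters as tensor<1xf32>
  // but a later iteration yields tensor<2xf32>, the meet of the two below is
  // the more general tensor<?xf32>, which becomes the new block argument type.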
  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());

  bool hasNewTypes = true;
  while (hasNewTypes) {
    TypeModificationState localState;

    // Set types on the block args.
    Region &bodyRegion = op.getRegion(1);
    Block &block = bodyRegion.front();
    for (int i = 0, s = argTypes.size(); i < s; i++) {
      localState.setType(block.getArgument(i), argTypes[i]);
    }

    // Propagate to the end.
    propagateShapesInRegion(bodyRegion, localState);

    // Find all the tosa.yield ops and verify there is exactly one.
    llvm::SmallVector<YieldOp> yieldOps;
    for (auto &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);

    assert(yieldOps.size() == 1 && "missing or non-unique yield op");
    // Using the new tosa.yield operand types, infer the new subtypes.
    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (auto ty : argTypes) {
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
    }

    for (auto yieldOp : yieldOps) {
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // This should never happen.
    if (yieldTypeInfo.size() != argTypes.size()) {
      op.emitWarning("has a tosa.yield with the incorrect number of operands");
      return;
    }

    // Determine the new block args and see if any changed.
    hasNewTypes = false;
    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    // Roll back all changes made during the speculative part of the algorithm.
    localState.rollBack();
  }

  // We now set the block arguments according to the fixed-point shape
  // inference results and propagate shapes through each region one final
  // time.
  for (auto &region : op.getRegions()) {
    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
      state.setType(region.front().getArgument(i), argTypes[i]);
    }

    propagateShapesInRegion(region, state);
  }
}

void propagateShapesInRegion(Region &region, TypeModificationState &state) {
  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();

  for (auto &block : region) {
    for (Operation &op : block) {
      if (op.getDialect() != tosaDialect)
        continue;

      propagateShapesToTosaIf(op, state);
      propagateShapesToTosaWhile(op, state);

      InferShapedTypeOpInterface shapeInterface =
          dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapeInterface)
        continue;

      SmallVector<ShapedTypeComponents> returnedShapes;

      if (shapeInterface
              .inferReturnTypeComponents(
                  op.getContext(), op.getLoc(), op.getOperands(),
                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
                  op.getRegions(), returnedShapes)
              .succeeded()) {
        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
          Value result = std::get<0>(it);
          ShapedTypeComponents predictedShape = std::get<1>(it);

          // Determine the knowledge based on the output type.
          // TODO: should also query WIP type probably
          Type resultTy = result.getType();
          auto currentKnowledge =
              ValueKnowledge::getKnowledgeFromType(resultTy);

          // Compute the knowledge based on the inferred type.
          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
          inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
          inferredKnowledge.hasRank = predictedShape.hasRank();
          if (predictedShape.hasRank()) {
            for (auto dim : predictedShape.getDims()) {
              inferredKnowledge.sizes.push_back(dim);
            }
          }

          // Compute the new type based on the joined version.
          auto newKnowledge =
              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
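          // For example (a sketch): joining tensor<?x4xf32> with
          // tensor<2x?xf32> yields tensor<2x4xf32>; conflicting static
          // dimensions make the join fail, leaving the result type unchanged.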
          if (!newKnowledge)
            continue;

          // Set new type
          state.setType(result, newKnowledge.getType());
        }
      }
    }
  }
}

/// Recursively validate TOSA ops with the SameOperandsAndResultRank trait
/// in a region and all of its nested regions.
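/// For instance (a hypothetical case): an elementwise TOSA op whose operands
/// ended up with ranks 2 and 1 after inference would be reported here, since
/// the trait requires all operand and result ranks to match.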
void validateSameOperandsAndResultRankTrait(Region &region) {
  int errs = 0;
  for (auto &block : region) {
    for (auto &op : block) {
      if (!op.getDialect() ||
          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
        continue;
      if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
        if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
          errs++;
          (void)errs;
        }
      }
      WhileOp whileOp = dyn_cast<WhileOp>(op);
      IfOp ifOp = dyn_cast<IfOp>(op);
      if (whileOp || ifOp) {
        // Recurse into the regions of while/if ops.
        for (auto &next : op.getRegions()) {
          validateSameOperandsAndResultRankTrait(next);
        }
      }
    }
  }
}

/// Pass that performs shape propagation across TOSA operations. This includes
/// propagation into the regions of if/while operations.
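/// The pass is typically exposed through the generated registration (assuming
/// the standard pass name), e.g.: mlir-opt --tosa-infer-shapes input.mlir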
struct TosaInferShapes
    : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
public:
  void runOnOperation() override {
    func::FuncOp func = getOperation();
    TypeModificationState state;
    propagateShapesInRegion(func.getBody(), state);
    state.commit();

    validateSameOperandsAndResultRankTrait(func.getBody());
  }
};
} // namespace
