1 | //===- TosaInferShapes.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Propagate shapes forward along TOSA operations to resolve dynamic shape |
10 | // operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
18 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
19 | #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
22 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
23 | #include "mlir/Pass/Pass.h" |
24 | #include "mlir/Transforms/DialectConversion.h" |
25 | |
26 | namespace mlir { |
27 | namespace tosa { |
28 | #define GEN_PASS_DEF_TOSAINFERSHAPESPASS |
29 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
30 | } // namespace tosa |
31 | } // namespace mlir |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::tosa; |
35 | |
36 | namespace { |
37 | |
38 | // Check whether this use case is replaceable. We define an op as |
39 | // being replaceable if it is used by a TosaOp, or an op with a |
40 | // type-inference related interface. |
41 | // When a non-replaceable use is encountered, the value is wrapped in a |
42 | // cast back to the original type after inference. |
43 | bool canBeRefined(Operation *user) { |
44 | if (!user->getDialect()) |
45 | return false; |
46 | return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() || |
47 | isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user); |
48 | } |
49 | |
50 | // During type propagation, the types of values in the operator graph are |
51 | // updated. For the tosa.while_loop operation, types are speculatively updated |
52 | // within the body region to determine the output type of the while_loop. This |
53 | // process is performed until a fixed point is reached, then the types are |
54 | // rolled back. |
55 | // |
56 | // This class encapsulates the state information needed to perform the roll back |
57 | // process or to commit to the final changes. |
58 | class TypeModificationState { |
59 | public: |
60 | TypeModificationState() = default; |
61 | |
62 | ~TypeModificationState() { |
63 | // Ensure the recorded modifications are either committed or rolled back. |
64 | assert(oldTypes.empty() && "unhandled type modifications" ); |
65 | } |
66 | |
67 | // Update the state of the value and record the old type. |
68 | void setType(Value value, Type type) { |
69 | if (value.getType() != type) { |
70 | oldTypes.emplace_back(value, value.getType()); |
71 | value.setType(type); |
72 | } |
73 | } |
74 | |
75 | // Roll back changes made to the types in the IR by setting all the affected |
76 | // values to their old types. |
77 | void rollBack() { |
78 | for (auto [value, type] : oldTypes) |
79 | value.setType(type); |
80 | |
81 | oldTypes.clear(); |
82 | } |
83 | |
84 | // Commit the changes to the types in the IR. |
85 | // This requires inserting tensor.cast operations to mediate the newly |
86 | // inferred result types with users that do not support type inference. |
87 | void commit() { |
88 | // For each use whose type changed, cast the value with the new type back to |
89 | // the old type. |
90 | for (auto [value, oldType] : oldTypes) { |
91 | // The call to 'use->set()' in the body of the loop below invalidates the |
92 | // iterator used to traverse op uses, so it is important to make a copy of |
93 | // these first. |
94 | llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector( |
95 | value.getUses(), |
96 | [](OpOperand &use) -> OpOperand * { |
97 | return &use; |
98 | }); |
99 | |
100 | // A 'tensor.cast' op is emitted only if needed. Once emitted, it is |
101 | // cached and reused by all consumers. |
102 | tensor::CastOp castValue; |
103 | |
104 | // Traverse all uses |
105 | for (OpOperand *use : uses) { |
106 | if (canBeRefined(use->getOwner())) |
107 | continue; |
108 | |
109 | if (!castValue) { |
110 | // Set the insertion point as far back as possible, since new |
111 | // consumers of the 'tensor.cast' op generated in future iterations |
112 | // are likely to be further up in the code due to the order in which |
113 | // they appear in the use list. |
114 | OpBuilder builder{value.getContext()}; |
115 | builder.setInsertionPointAfter(value.getDefiningOp()); |
116 | castValue = |
117 | builder.create<tensor::CastOp>(value.getLoc(), oldType, value); |
118 | } |
119 | |
120 | use->set(castValue); |
121 | } |
122 | } |
123 | |
124 | oldTypes.clear(); |
125 | } |
126 | |
127 | private: |
128 | // A record of each value whose type was updated along with that value's |
129 | // previous type. |
130 | llvm::SmallVector<std::pair<Value, Type>> oldTypes; |
131 | }; |
132 | |
133 | void propagateShapesInRegion(Region ®ion, TypeModificationState &state); |
134 | |
135 | void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) { |
136 | IfOp ifOp = dyn_cast<IfOp>(op); |
137 | if (!ifOp) |
138 | return; |
139 | |
140 | for (auto ®ion : op.getRegions()) { |
141 | Block &frontBlock = region.front(); |
142 | if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands()) |
143 | return; |
144 | |
145 | for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) { |
146 | auto inferredTy = cast<ShapedType>(op.getOperand(i).getType()); |
147 | auto blockArg = frontBlock.getArgument(i: i - 1); |
148 | auto oldType = cast<ShapedType>(blockArg.getType()); |
149 | |
150 | if (inferredTy.hasRank()) { |
151 | Type newType = oldType.clone(inferredTy.getShape()); |
152 | state.setType(value: blockArg, type: newType); |
153 | } |
154 | } |
155 | |
156 | for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) { |
157 | ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType( |
158 | type: ifOp.getOperand(i + 1).getType()); |
159 | ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType( |
160 | type: frontBlock.getArgument(i).getType()); |
161 | ValueKnowledge joinedKnowledge = |
162 | ValueKnowledge::join(lhs: operandKnowledge, rhs: blockKnowledge); |
163 | if (!joinedKnowledge) |
164 | continue; |
165 | state.setType(value: frontBlock.getArgument(i), type: joinedKnowledge.getType()); |
166 | } |
167 | |
168 | propagateShapesInRegion(region, state); |
169 | } |
170 | } |
171 | |
172 | void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) { |
173 | WhileOp whileOp = dyn_cast<WhileOp>(op); |
174 | if (!whileOp) |
175 | return; |
176 | |
177 | // Determine what the expected argument types are to the cond/body blocks. |
178 | // The expected arguments should be compatible with ever iteration of the |
179 | // loop body / condition for tosa.while. |
180 | SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes()); |
181 | |
182 | bool hasNewTypes = true; |
183 | while (hasNewTypes) { |
184 | TypeModificationState localState; |
185 | |
186 | // Set types on the block args. |
187 | Region &bodyRegion = op.getRegion(index: 1); |
188 | Block &block = bodyRegion.front(); |
189 | for (int i = 0, s = argTypes.size(); i < s; i++) { |
190 | localState.setType(value: block.getArgument(i), type: argTypes[i]); |
191 | } |
192 | |
193 | // Propagate to the end. |
194 | propagateShapesInRegion(region&: bodyRegion, state&: localState); |
195 | |
196 | // Find all the tosa yield types and verify there is a single one. |
197 | llvm::SmallVector<YieldOp> yieldOps; |
198 | for (auto &block : bodyRegion) |
199 | if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator())) |
200 | yieldOps.push_back(yieldOp); |
201 | |
202 | assert(yieldOps.size() == 1 && "missing or non-unique yield op" ); |
203 | // Using the new tosa.yield operand types, infer the new subtypes. |
204 | llvm::SmallVector<ValueKnowledge> yieldTypeInfo; |
205 | for (auto ty : argTypes) { |
206 | yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty)); |
207 | } |
208 | |
209 | for (auto yieldOp : yieldOps) { |
210 | for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { |
211 | auto newKnowledge = |
212 | ValueKnowledge::getKnowledgeFromType(it.value().getType()); |
213 | yieldTypeInfo[it.index()] = |
214 | ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge); |
215 | } |
216 | } |
217 | |
218 | // This should never happen. |
219 | if (yieldTypeInfo.size() != argTypes.size()) { |
220 | op.emitWarning(message: "has a tosa.yield with the incorrect number of operands" ); |
221 | return; |
222 | } |
223 | |
224 | // Determine the new block args and see if any changed. |
225 | hasNewTypes = false; |
226 | for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) { |
227 | Type newType = yieldTypeInfo[i].getType(); |
228 | hasNewTypes |= (newType != argTypes[i]); |
229 | argTypes[i] = newType; |
230 | } |
231 | |
232 | // Roll back all changes made during the speculative part of the algorithm. |
233 | localState.rollBack(); |
234 | } |
235 | |
236 | // We now set the block arguments according to the most recent shape |
237 | // inference results. This gives us the block arg types for the next |
238 | // iteration. |
239 | for (auto ®ion : op.getRegions()) { |
240 | for (unsigned int i = 0, s = argTypes.size(); i < s; i++) { |
241 | state.setType(value: region.front().getArgument(i), type: argTypes[i]); |
242 | } |
243 | |
244 | propagateShapesInRegion(region, state); |
245 | } |
246 | } |
247 | |
248 | void propagateShapesInRegion(Region ®ion, TypeModificationState &state) { |
249 | Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>(); |
250 | |
251 | for (auto &block : region) { |
252 | for (Operation &op : block) { |
253 | if (op.getDialect() != tosaDialect) |
254 | continue; |
255 | |
256 | propagateShapesToTosaIf(op, state); |
257 | propagateShapesToTosaWhile(op, state); |
258 | |
259 | InferShapedTypeOpInterface shapeInterface = |
260 | dyn_cast<InferShapedTypeOpInterface>(op); |
261 | if (!shapeInterface) |
262 | continue; |
263 | |
264 | SmallVector<ShapedTypeComponents> returnedShapes; |
265 | |
266 | if (shapeInterface |
267 | .inferReturnTypeComponents( |
268 | op.getContext(), op.getLoc(), op.getOperands(), |
269 | op.getDiscardableAttrDictionary(), op.getPropertiesStorage(), |
270 | op.getRegions(), returnedShapes) |
271 | .succeeded()) { |
272 | for (auto it : llvm::zip(op.getResults(), returnedShapes)) { |
273 | Value result = std::get<0>(it); |
274 | ShapedTypeComponents predictedShape = std::get<1>(it); |
275 | |
276 | // Determine the knowledge based on the output type. |
277 | // TODO: should also query WIP type probably |
278 | Type resultTy = result.getType(); |
279 | auto currentKnowledge = |
280 | ValueKnowledge::getKnowledgeFromType(resultTy); |
281 | |
282 | // Compute the knowledge based on the inferred type. |
283 | auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); |
284 | inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType(); |
285 | inferredKnowledge.hasRank = predictedShape.hasRank(); |
286 | if (predictedShape.hasRank()) { |
287 | for (auto dim : predictedShape.getDims()) { |
288 | inferredKnowledge.sizes.push_back(dim); |
289 | } |
290 | } |
291 | |
292 | // Compute the new type based on the joined version. |
293 | auto newKnowledge = |
294 | ValueKnowledge::join(currentKnowledge, inferredKnowledge); |
295 | if (!newKnowledge) |
296 | continue; |
297 | |
298 | // Set new type |
299 | state.setType(result, newKnowledge.getType()); |
300 | } |
301 | } |
302 | } |
303 | } |
304 | } |
305 | |
306 | /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region |
307 | /// and all nested regions |
308 | void validateSameOperandsAndResultRankTrait(Region ®ion) { |
309 | int errs = 0; |
310 | for (auto &block : region) { |
311 | for (auto &op : block) { |
312 | if (!op.getDialect() || |
313 | op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) |
314 | continue; |
315 | if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) { |
316 | if (OpTrait::impl::verifySameOperandsAndResultRank(op: &op).failed()) { |
317 | errs++; |
318 | (void)errs; |
319 | } |
320 | } |
321 | WhileOp whileOp = dyn_cast<WhileOp>(op); |
322 | IfOp ifOp = dyn_cast<IfOp>(op); |
323 | if (whileOp || ifOp) { |
324 | // recurse into whileOp's regions |
325 | for (auto &next : op.getRegions()) { |
326 | validateSameOperandsAndResultRankTrait(region&: next); |
327 | } |
328 | } |
329 | } |
330 | } |
331 | } |
332 | |
333 | /// Pass that performs shape propagation across TOSA operations. This includes |
334 | /// migrating to within the regions of if/while operations. |
335 | struct TosaInferShapes |
336 | : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> { |
337 | public: |
338 | void runOnOperation() override { |
339 | func::FuncOp func = getOperation(); |
340 | TypeModificationState state; |
341 | propagateShapesInRegion(func.getBody(), state); |
342 | state.commit(); |
343 | |
344 | validateSameOperandsAndResultRankTrait(func.getBody()); |
345 | } |
346 | }; |
347 | } // namespace |
348 | |