1 | //===- TosaInferShapes.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Propagate shapes forward along TOSA operations to resolve dynamic shape
10 | // operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
18 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
19 | #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
22 | #include "mlir/Pass/Pass.h" |
23 | #include "mlir/Transforms/DialectConversion.h" |
24 | |
25 | namespace mlir { |
26 | namespace tosa { |
27 | #define GEN_PASS_DEF_TOSAINFERSHAPES |
28 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
29 | } // namespace tosa |
30 | } // namespace mlir |
31 | |
32 | using namespace mlir; |
33 | using namespace mlir::tosa; |
34 | |
35 | namespace { |
36 | |
// Check whether this use of a value is replaceable. A user is
// replaceable if it is a TOSA op, or an op implementing a
// type-inference related interface.
// When a non-replaceable use is encountered, the value is wrapped in a
// cast back to the original type after inference.
42 | bool isReplaceableUser(Operation *user) { |
43 | // Handle unregistered dialects. |
44 | if (!user->getDialect()) |
45 | return false; |
46 | |
47 | return user->getDialect()->getNamespace() == |
48 | TosaDialect::getDialectNamespace() || |
49 | isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user); |
50 | } |
51 | |
52 | // During type propagation, the types of values in the operator graph are |
53 | // updated. For the tosa.while_loop operation, types are speculatively updated |
54 | // within the body region to determine the output type of the while_loop. This |
55 | // process is performed until a fixed point is reached, then the types are |
56 | // reverted. |
57 | // |
58 | // This class encapsulates the state information needed to perform the reversion |
59 | // process or to commit to the final changes. |
60 | class TypeModificationState { |
61 | public: |
62 | TypeModificationState() = default; |
63 | |
64 | ~TypeModificationState() { |
65 | // Ensure the recorded modifications are either committed or reverted. |
66 | assert(oldTypes.empty() && "unhandled type modifications" ); |
67 | } |
68 | |
69 | // Update the state of the value and record the old type. |
70 | void setType(Value value, Type type) { |
71 | if (value.getType() != type) { |
72 | oldTypes.emplace_back(value, value.getType()); |
73 | value.setType(type); |
74 | } |
75 | } |
76 | |
77 | // Revert changes made to the types in the IR by setting all the affected |
78 | // values to their old types. |
79 | void revert() { |
80 | // Otherwise revert the changes. |
81 | for (auto [value, type] : oldTypes) |
82 | value.setType(type); |
83 | |
84 | oldTypes.clear(); |
85 | } |
86 | |
87 | // Commit the changes to the types in the IR. |
88 | // This requires inserting tensor.cast operations to mediate the newly |
89 | // inferred result types with users that do not support type inference. |
90 | void commit() { |
91 | // For each use whose type changed, cast the value with the new type back to |
92 | // the old type. |
93 | for (auto [value, oldType] : oldTypes) { |
94 | for (auto &use : value.getUses()) { |
95 | if (isReplaceableUser(use.getOwner())) |
96 | continue; |
97 | |
98 | OpBuilder builder(value.getContext()); |
99 | builder.setInsertionPoint(use.getOwner()); |
100 | |
101 | Location loc = value.getLoc(); |
102 | use.set(builder.create<tensor::CastOp>(loc, oldType, value)); |
103 | } |
104 | } |
105 | |
106 | oldTypes.clear(); |
107 | } |
108 | |
109 | private: |
110 | // A record of each value whose type was updated along with that value's |
111 | // previous type. |
112 | llvm::SmallVector<std::pair<Value, Type>> oldTypes; |
113 | }; |
114 | |
115 | void propagateShapesInRegion(Region ®ion, TypeModificationState &state); |
116 | |
117 | void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) { |
118 | IfOp ifOp = dyn_cast<IfOp>(op); |
119 | if (!ifOp) |
120 | return; |
121 | |
122 | for (auto ®ion : op.getRegions()) { |
123 | Block &frontBlock = region.front(); |
124 | if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands()) |
125 | return; |
126 | |
127 | for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) { |
128 | auto inferredTy = cast<ShapedType>(op.getOperand(i).getType()); |
129 | auto blockArg = frontBlock.getArgument(i: i - 1); |
130 | auto oldType = cast<ShapedType>(blockArg.getType()); |
131 | |
132 | if (inferredTy.hasRank()) { |
133 | Type newType = oldType.clone(inferredTy.getShape()); |
134 | state.setType(value: blockArg, type: newType); |
135 | } |
136 | } |
137 | |
138 | for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) { |
139 | ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType( |
140 | type: ifOp.getOperand(i + 1).getType()); |
141 | ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType( |
142 | type: frontBlock.getArgument(i).getType()); |
143 | ValueKnowledge joinedKnowledge = |
144 | ValueKnowledge::join(lhs: operandKnowledge, rhs: blockKnowledge); |
145 | if (!joinedKnowledge) |
146 | continue; |
147 | state.setType(value: frontBlock.getArgument(i), type: joinedKnowledge.getType()); |
148 | } |
149 | |
150 | propagateShapesInRegion(region, state); |
151 | } |
152 | } |
153 | |
154 | void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) { |
155 | WhileOp whileOp = dyn_cast<WhileOp>(op); |
156 | if (!whileOp) |
157 | return; |
158 | |
159 | // Determine what the expected argument types are to the cond/body blocks. |
160 | // The expected arguments should be compatible with ever iteration of the |
161 | // loop body / condition for tosa.while. |
162 | SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes()); |
163 | |
164 | bool hasNewTypes = true; |
165 | while (hasNewTypes) { |
166 | TypeModificationState localState; |
167 | |
168 | // Set types on the block args. |
169 | Region &bodyRegion = op.getRegion(index: 1); |
170 | Block &block = bodyRegion.front(); |
171 | for (int i = 0, s = argTypes.size(); i < s; i++) { |
172 | localState.setType(value: block.getArgument(i), type: argTypes[i]); |
173 | } |
174 | |
175 | // Propagate to the end. |
176 | propagateShapesInRegion(region&: bodyRegion, state&: localState); |
177 | |
178 | // Find all the tosa yield types and verify there is a single one. |
179 | llvm::SmallVector<YieldOp> yieldOps; |
180 | for (auto &block : bodyRegion) |
181 | if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator())) |
182 | yieldOps.push_back(yieldOp); |
183 | |
184 | assert(yieldOps.size() == 1 && "missing or non-unique yield op" ); |
185 | // Using the new tosa.yield operand types, infer the new subtypes. |
186 | llvm::SmallVector<ValueKnowledge> yieldTypeInfo; |
187 | for (auto ty : argTypes) { |
188 | yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty)); |
189 | } |
190 | |
191 | for (auto yieldOp : yieldOps) { |
192 | for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { |
193 | auto newKnowledge = |
194 | ValueKnowledge::getKnowledgeFromType(it.value().getType()); |
195 | yieldTypeInfo[it.index()] = |
196 | ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge); |
197 | } |
198 | } |
199 | |
200 | // This should never happen. |
201 | if (yieldTypeInfo.size() != argTypes.size()) { |
202 | op.emitWarning(message: "has a tosa.yield with the incorrect number of operands" ); |
203 | return; |
204 | } |
205 | |
206 | // Determine the new block args and see if any changed. |
207 | hasNewTypes = false; |
208 | for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) { |
209 | Type newType = yieldTypeInfo[i].getType(); |
210 | hasNewTypes |= (newType != argTypes[i]); |
211 | argTypes[i] = newType; |
212 | } |
213 | |
214 | // Revert all changes made during the speculative part of the algorithm. |
215 | localState.revert(); |
216 | } |
217 | |
218 | // We now set the block arguments according to the most recent shape |
219 | // inference results. This gives us the block arg types for the next |
220 | // iteration. |
221 | for (auto ®ion : op.getRegions()) { |
222 | for (unsigned int i = 0, s = argTypes.size(); i < s; i++) { |
223 | state.setType(value: region.front().getArgument(i), type: argTypes[i]); |
224 | } |
225 | |
226 | propagateShapesInRegion(region, state); |
227 | } |
228 | } |
229 | |
230 | void propagateShapesInRegion(Region ®ion, TypeModificationState &state) { |
231 | for (auto &block : region) { |
232 | for (Operation &op : block) { |
233 | if (!op.getDialect() || |
234 | op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) |
235 | continue; |
236 | |
237 | propagateShapesToTosaIf(op, state); |
238 | propagateShapesToTosaWhile(op, state); |
239 | |
240 | InferShapedTypeOpInterface shapeInterface = |
241 | dyn_cast<InferShapedTypeOpInterface>(op); |
242 | if (!shapeInterface) |
243 | continue; |
244 | |
245 | SmallVector<ShapedTypeComponents> returnedShapes; |
246 | |
247 | if (shapeInterface |
248 | .inferReturnTypeComponents( |
249 | op.getContext(), op.getLoc(), op.getOperands(), |
250 | op.getDiscardableAttrDictionary(), op.getPropertiesStorage(), |
251 | op.getRegions(), returnedShapes) |
252 | .succeeded()) { |
253 | for (auto it : llvm::zip(op.getResults(), returnedShapes)) { |
254 | Value result = std::get<0>(it); |
255 | ShapedTypeComponents predictedShape = std::get<1>(it); |
256 | |
257 | // Determine the knowledge based on the output type. |
258 | // TODO: should also query WIP type probably |
259 | Type resultTy = result.getType(); |
260 | auto currentKnowledge = |
261 | ValueKnowledge::getKnowledgeFromType(resultTy); |
262 | |
263 | // Compute the knowledge based on the inferred type. |
264 | auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); |
265 | inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType(); |
266 | inferredKnowledge.hasRank = predictedShape.hasRank(); |
267 | if (predictedShape.hasRank()) { |
268 | for (auto dim : predictedShape.getDims()) { |
269 | inferredKnowledge.sizes.push_back(dim); |
270 | } |
271 | } |
272 | |
273 | // Compute the new type based on the joined version. |
274 | auto newKnowledge = |
275 | ValueKnowledge::join(currentKnowledge, inferredKnowledge); |
276 | if (!newKnowledge) |
277 | continue; |
278 | |
279 | // Set new type |
280 | state.setType(result, newKnowledge.getType()); |
281 | } |
282 | } |
283 | } |
284 | } |
285 | } |
286 | |
287 | /// Pass that performs shape propagation across TOSA operations. This includes |
288 | /// migrating to within the regions of if/while operations. |
289 | struct TosaInferShapes |
290 | : public tosa::impl::TosaInferShapesBase<TosaInferShapes> { |
291 | public: |
292 | void runOnOperation() override { |
293 | func::FuncOp func = getOperation(); |
294 | TypeModificationState state; |
295 | propagateShapesInRegion(func.getBody(), state); |
296 | state.commit(); |
297 | } |
298 | }; |
299 | } // namespace |
300 | |
/// Create an instance of the TOSA shape-inference pass.
std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
  return std::make_unique<TosaInferShapes>();
}
304 | |