| 1 | //===- TosaInferShapes.cpp ------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Propagate shapes forward along TOSA operations to resolve dynamic shape |
| 10 | // operations. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
| 15 | |
| 16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 17 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 18 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| 19 | #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" |
| 20 | #include "mlir/IR/Builders.h" |
| 21 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
| 22 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
| 23 | #include "mlir/Pass/Pass.h" |
| 24 | #include "mlir/Transforms/DialectConversion.h" |
| 25 | |
| 26 | namespace mlir { |
| 27 | namespace tosa { |
| 28 | #define GEN_PASS_DEF_TOSAINFERSHAPESPASS |
| 29 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
| 30 | } // namespace tosa |
| 31 | } // namespace mlir |
| 32 | |
| 33 | using namespace mlir; |
| 34 | using namespace mlir::tosa; |
| 35 | |
| 36 | namespace { |
| 37 | |
| 38 | // Check whether this use case is replaceable. We define an op as |
| 39 | // being replaceable if it is used by a TosaOp, or an op with a |
| 40 | // type-inference related interface. |
| 41 | // When a non-replaceable use is encountered, the value is wrapped in a |
| 42 | // cast back to the original type after inference. |
| 43 | bool canBeRefined(Operation *user) { |
| 44 | if (!user->getDialect()) |
| 45 | return false; |
| 46 | return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() || |
| 47 | isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user); |
| 48 | } |
| 49 | |
| 50 | // During type propagation, the types of values in the operator graph are |
| 51 | // updated. For the tosa.while_loop operation, types are speculatively updated |
| 52 | // within the body region to determine the output type of the while_loop. This |
| 53 | // process is performed until a fixed point is reached, then the types are |
| 54 | // rolled back. |
| 55 | // |
| 56 | // This class encapsulates the state information needed to perform the roll back |
| 57 | // process or to commit to the final changes. |
| 58 | class TypeModificationState { |
| 59 | public: |
| 60 | TypeModificationState() = default; |
| 61 | |
| 62 | ~TypeModificationState() { |
| 63 | // Ensure the recorded modifications are either committed or rolled back. |
| 64 | assert(oldTypes.empty() && "unhandled type modifications" ); |
| 65 | } |
| 66 | |
| 67 | // Update the state of the value and record the old type. |
| 68 | void setType(Value value, Type type) { |
| 69 | if (value.getType() != type) { |
| 70 | oldTypes.emplace_back(value, value.getType()); |
| 71 | value.setType(type); |
| 72 | } |
| 73 | } |
| 74 | |
| 75 | // Roll back changes made to the types in the IR by setting all the affected |
| 76 | // values to their old types. |
| 77 | void rollBack() { |
| 78 | for (auto [value, type] : oldTypes) |
| 79 | value.setType(type); |
| 80 | |
| 81 | oldTypes.clear(); |
| 82 | } |
| 83 | |
| 84 | // Commit the changes to the types in the IR. |
| 85 | // This requires inserting tensor.cast operations to mediate the newly |
| 86 | // inferred result types with users that do not support type inference. |
| 87 | void commit() { |
| 88 | // For each use whose type changed, cast the value with the new type back to |
| 89 | // the old type. |
| 90 | for (auto [value, oldType] : oldTypes) { |
| 91 | // The call to 'use->set()' in the body of the loop below invalidates the |
| 92 | // iterator used to traverse op uses, so it is important to make a copy of |
| 93 | // these first. |
| 94 | llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector( |
| 95 | value.getUses(), |
| 96 | [](OpOperand &use) -> OpOperand * { |
| 97 | return &use; |
| 98 | }); |
| 99 | |
| 100 | // A 'tensor.cast' op is emitted only if needed. Once emitted, it is |
| 101 | // cached and reused by all consumers. |
| 102 | tensor::CastOp castValue; |
| 103 | |
| 104 | // Traverse all uses |
| 105 | for (OpOperand *use : uses) { |
| 106 | if (canBeRefined(use->getOwner())) |
| 107 | continue; |
| 108 | |
| 109 | if (!castValue) { |
| 110 | // Set the insertion point as far back as possible, since new |
| 111 | // consumers of the 'tensor.cast' op generated in future iterations |
| 112 | // are likely to be further up in the code due to the order in which |
| 113 | // they appear in the use list. |
| 114 | OpBuilder builder{value.getContext()}; |
| 115 | builder.setInsertionPointAfter(value.getDefiningOp()); |
| 116 | castValue = |
| 117 | builder.create<tensor::CastOp>(value.getLoc(), oldType, value); |
| 118 | } |
| 119 | |
| 120 | use->set(castValue); |
| 121 | } |
| 122 | } |
| 123 | |
| 124 | oldTypes.clear(); |
| 125 | } |
| 126 | |
| 127 | private: |
| 128 | // A record of each value whose type was updated along with that value's |
| 129 | // previous type. |
| 130 | llvm::SmallVector<std::pair<Value, Type>> oldTypes; |
| 131 | }; |
| 132 | |
| 133 | void propagateShapesInRegion(Region ®ion, TypeModificationState &state); |
| 134 | |
| 135 | void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) { |
| 136 | IfOp ifOp = dyn_cast<IfOp>(op); |
| 137 | if (!ifOp) |
| 138 | return; |
| 139 | |
| 140 | for (auto ®ion : op.getRegions()) { |
| 141 | Block &frontBlock = region.front(); |
| 142 | if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands()) |
| 143 | return; |
| 144 | |
| 145 | for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) { |
| 146 | auto inferredTy = cast<ShapedType>(op.getOperand(i).getType()); |
| 147 | auto blockArg = frontBlock.getArgument(i: i - 1); |
| 148 | auto oldType = cast<ShapedType>(blockArg.getType()); |
| 149 | |
| 150 | if (inferredTy.hasRank()) { |
| 151 | Type newType = oldType.clone(inferredTy.getShape()); |
| 152 | state.setType(value: blockArg, type: newType); |
| 153 | } |
| 154 | } |
| 155 | |
| 156 | for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) { |
| 157 | ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType( |
| 158 | type: ifOp.getOperand(i + 1).getType()); |
| 159 | ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType( |
| 160 | type: frontBlock.getArgument(i).getType()); |
| 161 | ValueKnowledge joinedKnowledge = |
| 162 | ValueKnowledge::join(lhs: operandKnowledge, rhs: blockKnowledge); |
| 163 | if (!joinedKnowledge) |
| 164 | continue; |
| 165 | state.setType(value: frontBlock.getArgument(i), type: joinedKnowledge.getType()); |
| 166 | } |
| 167 | |
| 168 | propagateShapesInRegion(region, state); |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) { |
| 173 | WhileOp whileOp = dyn_cast<WhileOp>(op); |
| 174 | if (!whileOp) |
| 175 | return; |
| 176 | |
| 177 | // Determine what the expected argument types are to the cond/body blocks. |
| 178 | // The expected arguments should be compatible with ever iteration of the |
| 179 | // loop body / condition for tosa.while. |
| 180 | SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes()); |
| 181 | |
| 182 | bool hasNewTypes = true; |
| 183 | while (hasNewTypes) { |
| 184 | TypeModificationState localState; |
| 185 | |
| 186 | // Set types on the block args. |
| 187 | Region &bodyRegion = op.getRegion(index: 1); |
| 188 | Block &block = bodyRegion.front(); |
| 189 | for (int i = 0, s = argTypes.size(); i < s; i++) { |
| 190 | localState.setType(value: block.getArgument(i), type: argTypes[i]); |
| 191 | } |
| 192 | |
| 193 | // Propagate to the end. |
| 194 | propagateShapesInRegion(region&: bodyRegion, state&: localState); |
| 195 | |
| 196 | // Find all the tosa yield types and verify there is a single one. |
| 197 | llvm::SmallVector<YieldOp> yieldOps; |
| 198 | for (auto &block : bodyRegion) |
| 199 | if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator())) |
| 200 | yieldOps.push_back(yieldOp); |
| 201 | |
| 202 | assert(yieldOps.size() == 1 && "missing or non-unique yield op" ); |
| 203 | // Using the new tosa.yield operand types, infer the new subtypes. |
| 204 | llvm::SmallVector<ValueKnowledge> yieldTypeInfo; |
| 205 | for (auto ty : argTypes) { |
| 206 | yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty)); |
| 207 | } |
| 208 | |
| 209 | for (auto yieldOp : yieldOps) { |
| 210 | for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { |
| 211 | auto newKnowledge = |
| 212 | ValueKnowledge::getKnowledgeFromType(it.value().getType()); |
| 213 | yieldTypeInfo[it.index()] = |
| 214 | ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge); |
| 215 | } |
| 216 | } |
| 217 | |
| 218 | // This should never happen. |
| 219 | if (yieldTypeInfo.size() != argTypes.size()) { |
| 220 | op.emitWarning(message: "has a tosa.yield with the incorrect number of operands" ); |
| 221 | return; |
| 222 | } |
| 223 | |
| 224 | // Determine the new block args and see if any changed. |
| 225 | hasNewTypes = false; |
| 226 | for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) { |
| 227 | Type newType = yieldTypeInfo[i].getType(); |
| 228 | hasNewTypes |= (newType != argTypes[i]); |
| 229 | argTypes[i] = newType; |
| 230 | } |
| 231 | |
| 232 | // Roll back all changes made during the speculative part of the algorithm. |
| 233 | localState.rollBack(); |
| 234 | } |
| 235 | |
| 236 | // We now set the block arguments according to the most recent shape |
| 237 | // inference results. This gives us the block arg types for the next |
| 238 | // iteration. |
| 239 | for (auto ®ion : op.getRegions()) { |
| 240 | for (unsigned int i = 0, s = argTypes.size(); i < s; i++) { |
| 241 | state.setType(value: region.front().getArgument(i), type: argTypes[i]); |
| 242 | } |
| 243 | |
| 244 | propagateShapesInRegion(region, state); |
| 245 | } |
| 246 | } |
| 247 | |
| 248 | void propagateShapesInRegion(Region ®ion, TypeModificationState &state) { |
| 249 | Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>(); |
| 250 | |
| 251 | for (auto &block : region) { |
| 252 | for (Operation &op : block) { |
| 253 | if (op.getDialect() != tosaDialect) |
| 254 | continue; |
| 255 | |
| 256 | propagateShapesToTosaIf(op, state); |
| 257 | propagateShapesToTosaWhile(op, state); |
| 258 | |
| 259 | InferShapedTypeOpInterface shapeInterface = |
| 260 | dyn_cast<InferShapedTypeOpInterface>(op); |
| 261 | if (!shapeInterface) |
| 262 | continue; |
| 263 | |
| 264 | SmallVector<ShapedTypeComponents> returnedShapes; |
| 265 | |
| 266 | if (shapeInterface |
| 267 | .inferReturnTypeComponents( |
| 268 | op.getContext(), op.getLoc(), op.getOperands(), |
| 269 | op.getDiscardableAttrDictionary(), op.getPropertiesStorage(), |
| 270 | op.getRegions(), returnedShapes) |
| 271 | .succeeded()) { |
| 272 | for (auto it : llvm::zip(op.getResults(), returnedShapes)) { |
| 273 | Value result = std::get<0>(it); |
| 274 | ShapedTypeComponents predictedShape = std::get<1>(it); |
| 275 | |
| 276 | // Determine the knowledge based on the output type. |
| 277 | // TODO: should also query WIP type probably |
| 278 | Type resultTy = result.getType(); |
| 279 | auto currentKnowledge = |
| 280 | ValueKnowledge::getKnowledgeFromType(resultTy); |
| 281 | |
| 282 | // Compute the knowledge based on the inferred type. |
| 283 | auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); |
| 284 | inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType(); |
| 285 | inferredKnowledge.hasRank = predictedShape.hasRank(); |
| 286 | if (predictedShape.hasRank()) { |
| 287 | for (auto dim : predictedShape.getDims()) { |
| 288 | inferredKnowledge.sizes.push_back(dim); |
| 289 | } |
| 290 | } |
| 291 | |
| 292 | // Compute the new type based on the joined version. |
| 293 | auto newKnowledge = |
| 294 | ValueKnowledge::join(currentKnowledge, inferredKnowledge); |
| 295 | if (!newKnowledge) |
| 296 | continue; |
| 297 | |
| 298 | // Set new type |
| 299 | state.setType(result, newKnowledge.getType()); |
| 300 | } |
| 301 | } |
| 302 | } |
| 303 | } |
| 304 | } |
| 305 | |
| 306 | /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region |
| 307 | /// and all nested regions |
| 308 | void validateSameOperandsAndResultRankTrait(Region ®ion) { |
| 309 | int errs = 0; |
| 310 | for (auto &block : region) { |
| 311 | for (auto &op : block) { |
| 312 | if (!op.getDialect() || |
| 313 | op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) |
| 314 | continue; |
| 315 | if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) { |
| 316 | if (OpTrait::impl::verifySameOperandsAndResultRank(op: &op).failed()) { |
| 317 | errs++; |
| 318 | (void)errs; |
| 319 | } |
| 320 | } |
| 321 | WhileOp whileOp = dyn_cast<WhileOp>(op); |
| 322 | IfOp ifOp = dyn_cast<IfOp>(op); |
| 323 | if (whileOp || ifOp) { |
| 324 | // recurse into whileOp's regions |
| 325 | for (auto &next : op.getRegions()) { |
| 326 | validateSameOperandsAndResultRankTrait(region&: next); |
| 327 | } |
| 328 | } |
| 329 | } |
| 330 | } |
| 331 | } |
| 332 | |
| 333 | /// Pass that performs shape propagation across TOSA operations. This includes |
| 334 | /// migrating to within the regions of if/while operations. |
| 335 | struct TosaInferShapes |
| 336 | : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> { |
| 337 | public: |
| 338 | void runOnOperation() override { |
| 339 | func::FuncOp func = getOperation(); |
| 340 | TypeModificationState state; |
| 341 | propagateShapesInRegion(func.getBody(), state); |
| 342 | state.commit(); |
| 343 | |
| 344 | validateSameOperandsAndResultRankTrait(func.getBody()); |
| 345 | } |
| 346 | }; |
| 347 | } // namespace |
| 348 | |