| 1 | //===- TosaReduceTransposes.cpp -------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | // ---------- |
| 10 | // Motivation: |
| 11 | // ---------- |
| 12 | |
| 13 | // Some legalization pathways introduce redundant tosa.TRANSPOSE |
| 14 | // operations that result in avoidable data movement. For example, |
| 15 | // PyTorch -> TOSA contains a lot of unnecessary transposes due |
| 16 | // to conversions between NCHW and NHWC. |
| 17 | |
| 18 | // We wish to remove all the ones that we can, since in general |
| 19 | // it is possible to remove the overwhelming majority. |
| 20 | |
| 21 | // ------------------- |
| 22 | // High-Level Overview: |
| 23 | // ------------------- |
| 24 | |
| 25 | // The pass works through the transpose operators in the program. It begins at |
// some transpose operator with an associated permutation tensor. It traverses
| 27 | // upwards through the dependencies of this transpose and verifies that we |
| 28 | // encounter only operators with the TosaElementwiseOperator trait and terminate |
| 29 | // in either constants, reshapes, or transposes. |
| 30 | |
| 31 | // We then evaluate whether there are any additional restrictions (the |
| 32 | // transposes it terminates in must invert the one we began at, and the reshapes |
| 33 | // must be ones in which we can fold the transpose into), and then we hoist the |
| 34 | // transpose through the intervening operators, folding it at the constants, |
| 35 | // reshapes, and transposes. |
| 36 | |
| 37 | // Finally, we ensure that we do not need both the transposed form (the form |
| 38 | // that had the transpose hoisted through it) and the untransposed form (which |
| 39 | // it was prior), by analyzing the usages of those dependent operators of a |
| 40 | // given transpose we are attempting to hoist and replace. |
| 41 | |
| 42 | // If they are such that it would require both forms to be necessary, then we do |
| 43 | // not replace the hoisted transpose, causing the new chain to be dead. |
| 44 | // Otherwise, we do and the old chain (untransposed form) becomes dead. Only one |
| 45 | // chain will ever then be live, resulting in no duplication. |
| 46 | |
| 47 | // We then perform a simple one-pass DCE, so no canonicalization is necessary. |
| 48 | |
| 49 | // ----------- |
| 50 | // Future Work: |
| 51 | // ----------- |
| 52 | |
| 53 | // (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across |
| 54 | // hoisted |
| 55 | // transposes with different permutation tensors. |
| 56 | |
| 57 | // (2) Expand the class of foldable upstream ReshapeOp we permit beyond |
| 58 | // N -> 1x1x...x1xNx1x...x1x1. |
| 59 | |
// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
| 61 | // those that form the identity. |
| 62 | |
| 63 | // (4) Add support for more instructions besides TosaElementwiseOperator as |
| 64 | // the intervening ones (for example, the reduce_* operators). |
| 65 | |
| 66 | // (5) Support hoisting transposes up to an input parameter. |
| 67 | |
| 68 | //===----------------------------------------------------------------------===// |
| 69 | |
| 70 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 71 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| 72 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
| 73 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
| 74 | #include "mlir/IR/Iterators.h" |
| 75 | #include "mlir/IR/Matchers.h" |
| 76 | #include "llvm/ADT/TypeSwitch.h" |
| 77 | #include <memory> |
| 78 | #include <set> |
| 79 | #include <stack> |
| 80 | |
| 81 | namespace mlir { |
| 82 | namespace tosa { |
| 83 | #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES |
| 84 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
| 85 | } // namespace tosa |
| 86 | } // namespace mlir |
| 87 | |
| 88 | using namespace mlir; |
| 89 | using namespace mlir::tosa; |
| 90 | |
| 91 | //===----------------------------------------------------------------------===// |
| 92 | // TOSA Reduce Transposes Pass. |
| 93 | //===----------------------------------------------------------------------===// |
| 94 | |
| 95 | namespace { |
| 96 | |
// Pass that hoists TransposeOp through chains of elementwise TOSA ops and
// folds them into the constants, reshapes, and inverse transposes at the
// roots of those chains, then DCEs the untransposed chain.
struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  void runOnOperation() override;

private:
  // This will collect all the data dependencies for the given Operation
  // up to and including ConstOp, ReshapeOp, and TransposeOp.
  // Returns false if an unsupported operation is encountered along the way.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);

  // Builds the transposed counterpart of every op in dependentOps (in
  // topological order), recording original-value -> new-value entries in
  // valuesMap. Returns false if any op cannot be converted.
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // Checks if the two permutations, when applied consecutively, result
  // in the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // This is meant to apply to operations with the TosaElementwiseOperator
  // trait. Rebuilds the op with transposed operand/result types.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // This updates valuesMap when we encounter another TransposeOp as a
  // dependency of the hoisted one. %0 = tosa.transpose %arg0 <- applies to
  // this %1 = tosa.transpose %0 <- when tracking back from this
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks if ReshapeOp can have hoisted TransposeOp folded into it. If so,
  // it creates new ReshapeOp with that fold.
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // We may have something like:
  // %0 = tosa.const
  // %1 = tosa.transpose
  // %2 = tosa.add %0, %1
  // %3 = tosa.transpose %2
  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
  // in MobilenetV3. Folds the hoisted transpose into the constant payload.
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks which TransposeOp we should "replace", turning their converted
  // chains of ops, through which they were propagated, "live", and the old code
  // "dead." Attempts to avoid doing so when doing so would result in the old
  // code staying "live," resulting in duplication.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for dependenciesAreValid.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for getGoodReplacements to check if some TransposeOp's
  // dependencies are OK.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Applies perms to the DenseElementsAttr.
  // If it returns std::nullopt, it also triggers pass failure, since verifier
  // guarantees from TOSA are not in place (and otherwise, if used elsewhere,
  // it should fail).
  // This is a basic API and may benefit from refactor into the core MLIR APIs.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
};
| 178 | |
| 179 | std::optional<DenseElementsAttr> |
| 180 | TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input, |
| 181 | ArrayRef<int32_t> perms) { |
| 182 | RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType()); |
| 183 | RankedTensorType newType = |
| 184 | RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms), |
| 185 | oldType.getElementType()); |
| 186 | size_t rank = oldType.getRank(); |
| 187 | |
| 188 | // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension |
| 189 | // 0. If not in place, something is very wrong. |
| 190 | if (rank <= 0 || oldType.getNumElements() <= 0) { |
| 191 | signalPassFailure(); |
| 192 | return std::nullopt; |
| 193 | } |
| 194 | |
| 195 | if (input.isSplat()) |
| 196 | return input.reshape(newType); |
| 197 | |
| 198 | // The algorithm is approximately as follows: |
| 199 | // input: perms, input flat array, input tensor type |
| 200 | // (1/2) determine the strides of input/output if |
| 201 | // they were strided in row-major order. (3) adjust the strides for the |
| 202 | // input to be in the same order of indices as the output is written. |
| 203 | // (4) process dimension by dimension. example: perms 2, 0, 1; input |
| 204 | // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] = |
| 205 | // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust |
| 206 | // input strides to be as input[i + 12j + 4k] so we may process |
| 207 | // layer-by-layer. |
| 208 | |
| 209 | // Step 1/2: Strides for input. We ignore output since row-major and can just |
| 210 | // push_back. |
| 211 | |
| 212 | SmallVector<int64_t> originalInputStrides(rank); |
| 213 | originalInputStrides[rank - 1] = 1; |
| 214 | // index with int64_t to avoid overflow |
| 215 | for (int64_t i = rank - 2; i >= 0; i--) |
| 216 | originalInputStrides[i] = |
| 217 | originalInputStrides[i + 1] * oldType.getDimSize(i + 1); |
| 218 | |
| 219 | // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as |
| 220 | // output which is done in row-major order. |
| 221 | |
| 222 | SmallVector<int64_t> newInputStrides; |
| 223 | newInputStrides.reserve(rank); |
| 224 | for (int32_t v : perms) |
| 225 | newInputStrides.push_back(originalInputStrides[v]); |
| 226 | |
| 227 | // Step 4: Write out the transposed "flat array" dimension by dimension. |
| 228 | |
| 229 | auto inputArray = input.getValues<Attribute>(); |
| 230 | SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides; |
| 231 | for (size_t i = 0; i < rank; i++) |
| 232 | boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]}); |
| 233 | |
| 234 | SmallVector<Attribute> resultArray; |
| 235 | resultArray.reserve(inputArray.size()); |
| 236 | |
| 237 | std::function<void(int64_t, |
| 238 | SmallVector<std::pair<int64_t, int64_t>>::const_iterator)> |
| 239 | processTransposeDim = [&](auto accumulatedIndex, auto it) { |
| 240 | if (it == boundsAndStrides.end()) { |
| 241 | resultArray.push_back(inputArray[accumulatedIndex]); |
| 242 | return; |
| 243 | } |
| 244 | |
| 245 | for (int64_t i = 0; i < it->first; i++) { |
| 246 | int64_t j = accumulatedIndex + i * it->second; |
| 247 | processTransposeDim(j, it + 1); |
| 248 | } |
| 249 | }; |
| 250 | |
| 251 | processTransposeDim(0, boundsAndStrides.begin()); |
| 252 | |
| 253 | return DenseElementsAttr::get(newType, resultArray); |
| 254 | } |
| 255 | |
| 256 | // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp |
| 257 | // as the sources of the data dependencies, and TosaElementWiseOperator |
| 258 | // after that, if the function returns true. |
| 259 | bool TosaReduceTransposes::collectFanIn(Operation *op, |
| 260 | SetVector<Operation *> &collected) { |
| 261 | // Can occur if defined through the parameter to a func.func. |
| 262 | if (!op) |
| 263 | return false; |
| 264 | |
| 265 | if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect())) |
| 266 | return false; |
| 267 | |
| 268 | // Prevent extra work if already seen. |
| 269 | if (collected.contains(op)) |
| 270 | return true; |
| 271 | |
| 272 | // Throw it out so later don't have to deal with this. |
| 273 | if (op->getNumResults() != 1 || |
| 274 | !llvm::isa<RankedTensorType>(op->getResult(idx: 0).getType())) |
| 275 | return false; |
| 276 | |
| 277 | // We don't wish to traverse up a ReshapeOp, since generally we can't |
| 278 | // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp |
| 279 | // will have no in-edges in the data dependency graph we construct for |
| 280 | // the downstream TransposeOp. |
| 281 | if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) && |
| 282 | !llvm::isa<tosa::ConstOp>(op)) { |
| 283 | |
| 284 | if (!llvm::isa<tosa::MulOp>(op) && |
| 285 | !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()) |
| 286 | return false; |
| 287 | |
| 288 | for (Value operand : op->getOperands()) { |
| 289 | // If this is a problem in future, think about alternatives to recursion. |
| 290 | if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) { |
| 291 | // do not recurse into MulOp's shift operand |
| 292 | continue; |
| 293 | } |
| 294 | if (!collectFanIn(operand.getDefiningOp(), collected)) |
| 295 | return false; |
| 296 | } |
| 297 | } |
| 298 | |
| 299 | // Insert in topological order. |
| 300 | collected.insert(op); |
| 301 | |
| 302 | return true; |
| 303 | } |
| 304 | |
| 305 | // Assuming that due to the verification of TransposeOp perms arrays are |
| 306 | // permutations of 0 - perms.size() - 1. |
| 307 | bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1, |
| 308 | ArrayRef<int32_t> perms2) { |
| 309 | if (perms1.size() != perms2.size()) |
| 310 | return false; |
| 311 | int32_t n = perms1.size(); |
| 312 | for (int32_t i = 0; i < n; i++) |
| 313 | if (perms2[perms1[i]] != i) |
| 314 | return false; |
| 315 | return true; |
| 316 | } |
| 317 | |
| 318 | // Primary overload for those with TosaElementwiseOperator trait. |
| 319 | // The other ones handle the case of the operations that occur at the |
| 320 | // roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp). |
| 321 | std::optional<Value> TosaReduceTransposes::buildMappedToValue( |
| 322 | Operation *op, const DenseMap<Value, Value> &valuesMap, |
| 323 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
| 324 | if (op->getNumResults() != 1 || |
| 325 | (!llvm::isa<tosa::MulOp>(op) && |
| 326 | !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())) |
| 327 | return std::nullopt; |
| 328 | |
| 329 | auto resultType = llvm::cast<RankedTensorType>(op->getResult(idx: 0).getType()); |
| 330 | SmallVector<Value, 3> operands; |
| 331 | for (Value v : op->getOperands()) { |
| 332 | if (valuesMap.contains(v)) { |
| 333 | operands.push_back(valuesMap.at(v)); |
| 334 | } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) { |
| 335 | // special case for MulOp's shift operand |
| 336 | operands.push_back(v); |
| 337 | } else { |
| 338 | return std::nullopt; |
| 339 | } |
| 340 | } |
| 341 | |
| 342 | // Conceptually, we propagate the hoisted TransposeOp through |
| 343 | // these interveaning operations. For example, |
| 344 | |
| 345 | // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32> |
| 346 | // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) -> |
| 347 | // tensor<3x2xi32> |
| 348 | |
| 349 | // becomes: |
| 350 | // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) -> |
| 351 | // tensor<3x2xi32> |
| 352 | // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>) |
| 353 | |
| 354 | // We construct this new tosa.clamp here, but it doesn't |
| 355 | // turn "live" until the transpose being hoisted through this chain |
| 356 | // is replaced with the proper value from the new chain. |
| 357 | |
| 358 | return rewriter |
| 359 | .create(op->getLoc(), op->getName().getIdentifier(), operands, |
| 360 | RankedTensorType::get( |
| 361 | applyTOSAPermutation(resultType.getShape(), hoistedPerms), |
| 362 | resultType.getElementType()), |
| 363 | op->getAttrs()) |
| 364 | ->getResult(0); |
| 365 | } |
| 366 | |
| 367 | std::optional<Value> TosaReduceTransposes::buildMappedToValue( |
| 368 | TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap, |
| 369 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
| 370 | if (!areInvolutionTransposes(perms1: hoistedPerms, perms2: transposeOp.getPerms())) |
| 371 | return std::nullopt; |
| 372 | return transposeOp.getInput1(); |
| 373 | } |
| 374 | |
| 375 | std::optional<Value> TosaReduceTransposes::buildMappedToValue( |
| 376 | ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap, |
| 377 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
| 378 | auto reshapeOutput = reshapeOp.getOutput(); |
| 379 | auto reshapeInputType = |
| 380 | llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType()); |
| 381 | auto reshapeInputShape = reshapeInputType.getShape(); |
| 382 | // want reshape N -> 1x1x...x1xNx1x...x1x1 |
| 383 | if (!reshapeInputType || reshapeInputShape.size() != 1) |
| 384 | return std::nullopt; |
| 385 | auto reshapeOutputType = |
| 386 | llvm::cast<RankedTensorType>(reshapeOutput.getType()); |
| 387 | |
| 388 | // Instead of inserting a TransposeOp here, we check if we can fold it into |
| 389 | // the ReshapeOp. There is more complex cases where this is possible, and |
| 390 | // this check can be extended. |
| 391 | |
| 392 | // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1 |
| 393 | auto shape = reshapeOutputType.getShape(); |
| 394 | size_t ones = llvm::count(shape, 1); |
| 395 | // N == 1 and N != 1 |
| 396 | if (ones != shape.size() - 1 && |
| 397 | !(ones == shape.size() && reshapeInputShape[0] == 1)) |
| 398 | return std::nullopt; |
| 399 | |
| 400 | // Do not insert a TransposeOp, instead we fold the reshape and its attribute. |
| 401 | llvm::SmallVector<int64_t> newShape; |
| 402 | if (!tosa::getConstShapeValues(op: reshapeOp.getShape().getDefiningOp(), |
| 403 | result_shape&: newShape)) { |
| 404 | // this mean shape is not constant |
| 405 | return std::nullopt; |
| 406 | } |
| 407 | ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter); |
| 408 | auto foldedReshape = rewriter.create<ReshapeOp>( |
| 409 | reshapeOp.getLoc(), |
| 410 | RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms), |
| 411 | reshapeOutputType.getElementType()), |
| 412 | reshapeOp.getInput1(), |
| 413 | getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape), |
| 414 | hoistedPerms))); |
| 415 | return foldedReshape->getResult(0); |
| 416 | } |
| 417 | |
| 418 | std::optional<Value> TosaReduceTransposes::buildMappedToValue( |
| 419 | ConstOp constOp, const DenseMap<Value, Value> &valuesMap, |
| 420 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
| 421 | auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues()); |
| 422 | if (!denseAttr) |
| 423 | return std::nullopt; |
| 424 | auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms); |
| 425 | if (!maybeNewDenseAttr.has_value()) |
| 426 | return std::nullopt; |
| 427 | auto newDenseAttr = maybeNewDenseAttr.value(); |
| 428 | auto newConstOp = rewriter.create<ConstOp>( |
| 429 | constOp.getLoc(), newDenseAttr.getType(), newDenseAttr); |
| 430 | return newConstOp->getResult(0); |
| 431 | } |
| 432 | |
| 433 | bool TosaReduceTransposes::convertDependentOps( |
| 434 | SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap, |
| 435 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
| 436 | |
| 437 | for (Operation *op : dependentOps) { |
| 438 | if (!op || op->getNumResults() != 1) |
| 439 | return false; |
| 440 | |
| 441 | Value priorValue = op->getResult(0); |
| 442 | |
| 443 | // It's possible on a prior transposeOp we had the same dependency and |
| 444 | // already resolved it. |
| 445 | if (valuesMap.contains(priorValue)) |
| 446 | continue; |
| 447 | |
| 448 | // Keep converted ops close to the original. |
| 449 | rewriter.setInsertionPointAfter(op); |
| 450 | |
| 451 | std::optional<Value> maybeValue = |
| 452 | llvm::TypeSwitch<Operation *, std::optional<Value>>(op) |
| 453 | .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto transposeOp) { |
| 454 | return buildMappedToValue(transposeOp, valuesMap, rewriter, |
| 455 | hoistedPerms); |
| 456 | }) |
| 457 | .Default([&](Operation *op) { |
| 458 | return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms); |
| 459 | }); |
| 460 | |
| 461 | if (!maybeValue.has_value()) |
| 462 | return false; |
| 463 | |
| 464 | valuesMap[priorValue] = maybeValue.value(); |
| 465 | } |
| 466 | |
| 467 | return true; |
| 468 | } |
| 469 | |
| 470 | bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies( |
| 471 | Operation *user, std::set<TransposeOp> &validTransposes, |
| 472 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
| 473 | &transposeInfo) { |
| 474 | return llvm::none_of( |
| 475 | transposeInfo, |
| 476 | [&validTransposes, |
| 477 | user](const std::pair<TransposeOp, SetVector<Operation *>> &info) { |
| 478 | const auto &[transposeOp, dependentOps] = info; |
| 479 | return validTransposes.count(transposeOp) && |
| 480 | dependentOps.contains(user); |
| 481 | }); |
| 482 | } |
| 483 | |
| 484 | // Dependencies are valid for an operation if none of them occur outside |
| 485 | // of the proper fan-in cones of the hoisted TransposeOp with the same perms |
| 486 | // that we can replace. Described in more detail within. |
| 487 | bool TosaReduceTransposes::dependenciesAreValid( |
| 488 | ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps, |
| 489 | std::set<TransposeOp> &validTransposes, |
| 490 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
| 491 | &transposeInfo) { |
| 492 | for (Operation *op : dependentOps) { |
| 493 | |
| 494 | // It's OK wherever ConstOp has uses -- in the worst case, we duplicate. |
| 495 | // This can be changed later if we find the memory impact is too high. |
| 496 | if (llvm::isa<ConstOp>(op)) |
| 497 | continue; |
| 498 | |
| 499 | for (OpOperand &use : op->getUses()) { |
| 500 | // Want the uses to be (1) contained in the dependentOps of other |
| 501 | // validTransposes, or (2) to be directly used in a TransposeOp with the |
| 502 | // same perms. For (2) it means the fan-in is a subset of our |
| 503 | // dependentOps, so it is also a validTranspose that will eventually be |
| 504 | // replaced. |
| 505 | Operation *user = use.getOwner(); |
| 506 | if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) { |
| 507 | // Can later think about cases where transpose -> transpose |
| 508 | // or reshape -> transpose, where the transposes are not necessarily |
| 509 | // the same perms as the hoisted, if implementing a more general |
| 510 | // transform. These could be permitted. |
| 511 | if (!llvm::equal(perms, otherTranspose.getPerms())) |
| 512 | return false; |
| 513 | } else if (userNotContainedInValidTransposeDependencies( |
| 514 | user, validTransposes, transposeInfo)) { |
| 515 | return false; |
| 516 | } |
| 517 | } |
| 518 | } |
| 519 | |
| 520 | return true; |
| 521 | } |
| 522 | |
| 523 | // Getting the set of TransposeOp that we can replace without causing |
| 524 | // the old fan-in cones of any TransposeOp to remain "live", i.e, -- not being |
| 525 | // dead code. This is done by iterating the set until convergence, since |
| 526 | // if you are used outside your own fan-in cone, it's possible to be used |
| 527 | // in another fan-in cone of a TransposeOp that is being replaced -- unless |
| 528 | // we find that that one has a usage outside of it too. |
| 529 | std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements( |
| 530 | ArrayRef<int32_t> perms, |
| 531 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
| 532 | &transposeInfo) { |
| 533 | // Initially, we assume they are all good to replace, |
| 534 | // and we whittle them down based on our criteria. |
| 535 | std::set<TransposeOp> ableToReplace; |
| 536 | for (const auto &[transposeOp, _] : transposeInfo) |
| 537 | ableToReplace.insert(transposeOp); |
| 538 | |
| 539 | bool gotRid; |
| 540 | do { |
| 541 | gotRid = false; |
| 542 | for (const auto &[transposeOp, dependentOps] : transposeInfo) { |
| 543 | // We don't care about it. Already invalidated. |
| 544 | if (!ableToReplace.count(transposeOp)) |
| 545 | continue; |
| 546 | |
| 547 | // Check for validity. |
| 548 | if (!dependenciesAreValid(perms, dependentOps, ableToReplace, |
| 549 | transposeInfo)) { |
| 550 | ableToReplace.erase(transposeOp); |
| 551 | gotRid = true; |
| 552 | break; |
| 553 | } |
| 554 | } |
| 555 | |
| 556 | } while (gotRid); |
| 557 | |
| 558 | return ableToReplace; |
| 559 | } |
| 560 | |
// Pass driver: walks every TransposeOp, collects and converts its fan-in,
// decides which transposes can be replaced without duplication, performs
// the replacements in reverse order, then runs a single backward DCE sweep.
void TosaReduceTransposes::runOnOperation() {
  // We want to operate only within a single block.
  if (!getOperation().getRegion().hasOneBlock())
    return;

  IRRewriter rewriter(&getContext());
  // For each perms, maintain a mapping for converted ops, avoid duplication.
  DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
  // For each perms, we keep track of which TransposeOp are eligible
  // for replacement alongside their dependentOps.
  DenseMap<ArrayRef<int32_t>,
           std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
      permsToTransposeInfo;

  // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef.
  // Use SmallVector for perms (common-case is <= 4) but std::vector otherwise
  // since no guarantee of smallness.
  std::vector<SmallVector<int32_t>> collectedPerms;

  // This keeps track of the order across all eligible-for-replacement
  // TransposeOp and their perms, a necessity for the final replacements.
  std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;

  // We want to reserve the space up front, since SmallVector stores some data
  // internally and the ArrayRef can reference that, which we don't want to get
  // invalidated. (A reallocation of collectedPerms would dangle every
  // ArrayRef key in the maps above.)
  size_t expectedMaxPerms = 0;
  getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
  collectedPerms.reserve(expectedMaxPerms);

  getOperation().walk([&](TransposeOp transposeOp) {
    SetVector<Operation *> dependentOps;
    collectedPerms.emplace_back();
    SmallVector<int32_t> &perms = collectedPerms.back();

    // Dynamic shapes are OK, but the incompatible ones will be rejected later.
    auto input = transposeOp.getInput1();
    auto output = transposeOp.getOutput();

    // However, we don't support unranked tensors.
    if (!llvm::isa<RankedTensorType>(input.getType()) ||
        !llvm::isa<RankedTensorType>(output.getType()))
      return;

    // Copy the perms into stable storage (see collectedPerms note above).
    llvm::for_each(transposeOp.getPerms(),
                   [&perms](const auto i) { perms.emplace_back(i); });

    // We let --canonicalize deal with identity transpose.
    if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
      return;

    // Can fail if some set of basic invariants is not met that we want to
    // perform our conversions.
    if (!collectFanIn(input.getDefiningOp(), dependentOps))
      return;

    // Want to associate valuesMap for already converted of the same perms,
    // since it's possible multiple hoisted transposes w/ different perms
    // converge on an op, which would result in different transformations.
    DenseMap<Value, Value> &valuesMap = permsToValues[perms];

    // Attempt to perform the conversions and placements into IR
    // without turning inserted code "live". Also fills out valuesMap.
    // Fails if there is an intermediary we do not support.
    if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
      // Some additional operations may have been inserted, but will be
      // removed by dead code elimination.
      return;

    // This should not happen. If it does -- it's unexpected,
    // so we fail the pass.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    // It's possible the types are not compatible (because of dynamic shapes),
    // and in these cases, want to resolve dynamic shapes before running the
    // pass.
    if (output.getType() != valuesMap.at(input).getType())
      return;

    auto &transposeInfo = permsToTransposeInfo[perms];

    // In general, we might also want to introduce "newDependentOps"
    // if there are new usages that don't fall inside the original fan-ins
    // (like the TransposeOp we insert for ReshapeOp),
    // but in this case, that is specialized enough and overlaps
    // with another direct-use TransposeOp case we need to cover anyway.
    transposeInfo.push_back({transposeOp, dependentOps});

    // This is for the final replacement across all transposes.
    totalTransposeOrder.push({transposeOp, perms});
  });

  // We want to do a full fan-in analysis on a perms-level,
  // since if we do it on a multi-perms level, and they share (due to a shared
  // dependency on a Reshape) then we would also get duplicate ops.
  // Const is special cased.
  std::set<TransposeOp> ableToReplace;
  for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
    // Gives us back replacements that would never result in any duplicate
    // operations being inserted by us in the IR (i.e, our goal is only to
    // remove transposes, and not create a "new chain" to do so, but replace
    // the existing chains).
    // Ideally, --canonicalize is run before this pass, since it helps this
    // analysis by removing dead code to allow more potentially acceptable
    // transformations.
    auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
    ableToReplace.insert(goodReplacementsForPerms.begin(),
                         goodReplacementsForPerms.end());
  }

  // We want to do replacement across all transposes
  // in reverse order, due to invalidation of valuesMap mappings
  // if we did it otherwise.
  while (!totalTransposeOrder.empty()) {
    auto [transposeOp, perms] = totalTransposeOrder.top();
    totalTransposeOrder.pop();

    if (ableToReplace.count(transposeOp) == 0)
      continue;

    auto &valuesMap = permsToValues[perms];
    auto input = transposeOp.getInput1();

    // The purpose of this reverse iteration
    // is to avoid valuesMap invalidation. If it happens,
    // something is wrong.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    rewriter.replaceOp(transposeOp, valuesMap.at(input));
  }

  // We can remove all dead code by going in reverse.
  // This is because we would remove usages before we
  // see the users.
  getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
      [&](Operation *op) {
        if (isOpTriviallyDead(op))
          rewriter.eraseOp(op);
      });
}
| 703 | |
| 704 | } // namespace |
| 705 | |