//===- TosaReduceTransposes.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// ----------
// Motivation:
// ----------

// Some legalization pathways introduce redundant tosa.TRANSPOSE
// operations that result in avoidable data movement. For example,
// PyTorch -> TOSA contains a lot of unnecessary transposes due
// to conversions between NCHW and NHWC.

// We wish to remove all the ones that we can, since in general
// it is possible to remove the overwhelming majority.

// -------------------
// High-Level Overview:
// -------------------

// The pass works through the transpose operators in the program. It begins at
// some transpose operator with an associated permutation tensor. It traverses
// upwards through the dependencies of this transpose, verifying that we
// encounter only operators with the TosaElementwiseOperator trait and that
// the traversal terminates in constants, reshapes, or transposes.

// We then evaluate whether there are any additional restrictions (the
// transposes it terminates in must invert the one we began at, and the
// reshapes must be ones into which we can fold the transpose), and then we
// hoist the transpose through the intervening operators, folding it at the
// constants, reshapes, and transposes.

// Finally, we ensure that we do not need both the transposed form (the form
// that had the transpose hoisted through it) and the untransposed form (its
// prior form), by analyzing the usages of the dependent operators of a
// given transpose we are attempting to hoist and replace.

// If the usages would require both forms to remain live, then we do not
// replace the hoisted transpose, and the new chain becomes dead. Otherwise,
// we do, and the old chain (the untransposed form) becomes dead. Only one
// chain will ever be live, so there is no duplication.

// We then perform a simple one-pass DCE, so no canonicalization is necessary.
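
// As a hypothetical illustration (ops and shapes chosen for exposition, not
// taken from a real legalization), a PyTorch-style NCHW <-> NHWC round trip
// such as:
//
//   %0 = tosa.transpose %arg0 {perms = [0, 2, 3, 1]} // NCHW -> NHWC
//   %1 = tosa.clamp %0 ...
//   %2 = tosa.transpose %1 {perms = [0, 3, 1, 2]}    // NHWC -> NCHW
//
// has the final transpose hoisted through the clamp, where it cancels with
// the first transpose (the two perms compose to the identity), leaving after
// DCE only:
//
//   %1 = tosa.clamp %arg0 ...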
48
49// -----------
50// Future Work:
51// -----------
52
53// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
54// hoisted
55// transposes with different permutation tensors.
56
57// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
58// N -> 1x1x...x1xNx1x...x1x1.
59
60// (3) Enchance the pass to permit folding arbitrary transpose pairs, beyond
61// those that form the identity.
62
63// (4) Add support for more instructions besides TosaElementwiseOperator as
64// the intervening ones (for example, the reduce_* operators).
65
66// (5) Support hoisting transposes up to an input parameter.
67
68//===----------------------------------------------------------------------===//
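
// Note: assuming the standard registration generated from the TOSA Passes.td
// (where this pass is registered as "tosa-reduce-transposes"), it can be
// exercised with, e.g.:
//
//   mlir-opt --canonicalize --tosa-reduce-transposes input.mlir
//
// Running --canonicalize first is optional but helps the analysis (see the
// notes in runOnOperation below).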

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/TypeSwitch.h"
#include <memory>
#include <set>
#include <stack>

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// TOSA Reduce Transposes Pass.
//===----------------------------------------------------------------------===//

namespace {

struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  void runOnOperation() override;

private:
  // This collects all the data dependencies of the given Operation,
  // up to and including ConstOp, ReshapeOp, and TransposeOp.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // Checks if the two permutations, when applied consecutively, result
  // in the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // This is meant to apply to operations with the TosaElementwiseOperator
  // trait.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // This updates valuesMap when we encounter another TransposeOp as a
  // dependency of the hoisted one.
  //   %0 = tosa.transpose %arg0 <- applies to this
  //   %1 = tosa.transpose %0    <- when tracking back from this
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks if the ReshapeOp can have the hoisted TransposeOp folded into it.
  // If so, it creates a new ReshapeOp with that fold.
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // We may have something like:
  //   %0 = tosa.const
  //   %1 = tosa.transpose
  //   %2 = tosa.add %0, %1
  //   %3 = tosa.transpose %2
  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
  // in MobilenetV3.
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks which TransposeOp we should "replace", turning their converted
  // chains of ops, through which they were propagated, "live", and the old
  // code "dead." Attempts to avoid doing so when doing so would result in the
  // old code staying "live," resulting in duplication.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for dependenciesAreValid.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for getGoodReplacements to check if some TransposeOp's
  // dependencies are OK.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Applies perms to the DenseElementsAttr.
  // If it returns std::nullopt, it also triggers pass failure, since the
  // verifier guarantees from TOSA are not in place (and if this is used
  // elsewhere, it should fail there too).
  // This is a basic API and may benefit from a refactor into the core MLIR
  // APIs.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
};

std::optional<DenseElementsAttr>
TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                              ArrayRef<int32_t> perms) {
  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
  RankedTensorType newType =
      RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
                            oldType.getElementType());
  size_t rank = oldType.getRank();

  // Asserted by the TransposeOp verifier and by TOSA disallowing tensors with
  // a dimension of 0. If these guarantees are not in place, something is very
  // wrong.
  if (rank <= 0 || oldType.getNumElements() <= 0) {
    signalPassFailure();
    return std::nullopt;
  }

  if (input.isSplat())
    return input.reshape(newType);

  // The algorithm is approximately as follows:
  //   input: perms, the flat input array, the input tensor type
  //   (1/2) Determine the strides of the input/output as if they were laid
  //         out in row-major order.
  //   (3) Permute the input strides to be in the same order of indices as
  //       the output is written.
  //   (4) Process dimension by dimension.
  //       Example: perms = 2, 0, 1; input 2x3x4; output 4x2x3.
  //       For i in [0, 4), j in [0, 2), k in [0, 3):
  //         output[i][j][k] = input[j][k][i]
  //         output[6i + 3j + k] = input[12j + 4k + i]
  //       We adjust the input strides to index input[i + 12j + 4k], so we
  //       may process layer-by-layer.

  // Step 1/2: Strides for the input. We ignore the output since it is
  // row-major and we can just push_back.

  SmallVector<int64_t> originalInputStrides(rank);
  originalInputStrides[rank - 1] = 1;
  // Index with int64_t to avoid overflow.
  for (int64_t i = rank - 2; i >= 0; i--)
    originalInputStrides[i] =
        originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
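  // As a concrete check (illustrative numbers, not from the source): for a
  // 2x3x4 input, this loop yields originalInputStrides = [12, 4, 1], i.e.,
  // advancing one step in dimension 0 skips 3 * 4 = 12 flat elements.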

  // Step 3: Permute the strides of the input so they use the same indexing
  // (i, j, k, ...) as the output, which is written in row-major order.

  SmallVector<int64_t> newInputStrides;
  newInputStrides.reserve(rank);
  for (int32_t v : perms)
    newInputStrides.push_back(originalInputStrides[v]);
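  // Continuing the illustrative 2x3x4 example with perms = [2, 0, 1]:
  // newInputStrides = [originalInputStrides[2], originalInputStrides[0],
  // originalInputStrides[1]] = [1, 12, 4], matching input[i + 12j + 4k].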

  // Step 4: Write out the transposed "flat array" dimension by dimension.

  auto inputArray = input.getValues<Attribute>();
  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
  for (size_t i = 0; i < rank; i++)
    boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});

  SmallVector<Attribute> resultArray;
  resultArray.reserve(inputArray.size());

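  // The recursion below enumerates the output indices in row-major order; at
  // each leaf, accumulatedIndex is the flat index into the input (computed
  // from the permuted strides), so push_back fills the output contiguously.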
  std::function<void(int64_t,
                     SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
      processTransposeDim = [&](auto accumulatedIndex, auto it) {
        if (it == boundsAndStrides.end()) {
          resultArray.push_back(inputArray[accumulatedIndex]);
          return;
        }

        for (int64_t i = 0; i < it->first; i++) {
          int64_t j = accumulatedIndex + i * it->second;
          processTransposeDim(j, it + 1);
        }
      };

  processTransposeDim(0, boundsAndStrides.begin());

  return DenseElementsAttr::get(newType, resultArray);
}

// If the function returns true, the SetVector contains only ConstOp,
// ReshapeOp, and TransposeOp as the sources of the data dependencies, and
// ops with the TosaElementwiseOperator trait after that.
bool TosaReduceTransposes::collectFanIn(Operation *op,
                                        SetVector<Operation *> &collected) {
  // Can occur if defined through the parameter to a func.func.
  if (!op)
    return false;

  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
    return false;

  // Prevent extra work if already seen.
  if (collected.contains(op))
    return true;

  // Throw it out now so we don't have to deal with it later.
  if (op->getNumResults() != 1 ||
      !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
    return false;

  // We don't wish to traverse up a ReshapeOp, since generally we can't
  // propagate a TransposeOp through it. TransposeOp, ReshapeOp, and ConstOp
  // will have no in-edges in the data dependency graph we construct for
  // the downstream TransposeOp.
  if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
      !llvm::isa<tosa::ConstOp>(op)) {

    if (!llvm::isa<tosa::MulOp>(op) &&
        !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
      return false;

    for (Value operand : op->getOperands()) {
      // If this is a problem in future, think about alternatives to recursion.
      if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
        // Do not recurse into MulOp's shift operand.
        continue;
      }
      if (!collectFanIn(operand.getDefiningOp(), collected))
        return false;
    }
  }

  // Insert in topological order.
  collected.insert(op);

  return true;
}

// We assume, per TransposeOp verification, that the perms arrays are
// permutations of 0, 1, ..., perms.size() - 1.
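//
// As an illustrative check: perms1 = [1, 2, 0] and perms2 = [2, 0, 1]
// compose to the identity (perms2[perms1[i]] == i for every i), so that
// pair would be accepted below.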
bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
                                                   ArrayRef<int32_t> perms2) {
  if (perms1.size() != perms2.size())
    return false;
  int32_t n = perms1.size();
  for (int32_t i = 0; i < n; i++)
    if (perms2[perms1[i]] != i)
      return false;
  return true;
}

// Primary overload for ops with the TosaElementwiseOperator trait.
// The other overloads handle the operations that occur at the
// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    Operation *op, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  if (op->getNumResults() != 1 ||
      (!llvm::isa<tosa::MulOp>(op) &&
       !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
    return std::nullopt;

  auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
  SmallVector<Value, 3> operands;
  for (Value v : op->getOperands()) {
    if (valuesMap.contains(v)) {
      operands.push_back(valuesMap.at(v));
    } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
      // Special case for MulOp's shift operand.
      operands.push_back(v);
    } else {
      return std::nullopt;
    }
  }

  // Conceptually, we propagate the hoisted TransposeOp through
  // these intervening operations. For example,

  //   %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
  //   %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
  //        tensor<3x2xi32>

  // becomes:

  //   %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
  //        tensor<3x2xi32>
  //   %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>

  // We construct this new tosa.clamp here, but it doesn't
  // turn "live" until the transpose being hoisted through this chain
  // is replaced with the proper value from the new chain.

  return rewriter
      .create(op->getLoc(), op->getName().getIdentifier(), operands,
              RankedTensorType::get(
                  applyTOSAPermutation(resultType.getShape(), hoistedPerms),
                  resultType.getElementType()),
              op->getAttrs())
      ->getResult(0);
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
    return std::nullopt;
  return transposeOp.getInput1();
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto reshapeOutput = reshapeOp.getOutput();
  auto reshapeInputType =
      llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
  // We want a reshape of the form N -> 1x1x...x1xNx1x...x1x1. Check the type
  // for null before querying its shape, since the dyn_cast may have failed.
  if (!reshapeInputType || reshapeInputType.getShape().size() != 1)
    return std::nullopt;
  auto reshapeInputShape = reshapeInputType.getShape();
  auto reshapeOutputType =
      llvm::cast<RankedTensorType>(reshapeOutput.getType());

  // Instead of inserting a TransposeOp here, we check if we can fold it into
  // the ReshapeOp. There are more complex cases where this is possible, and
  // this check can be extended.

  // Checking if the reshape is N -> 1x1x...x1xNx1x...x1x1. This covers both
  // N != 1 (exactly one non-unit output dimension) and N == 1 (all output
  // dimensions are 1).
  auto shape = reshapeOutputType.getShape();
  size_t ones = llvm::count(shape, 1);
  if (ones != shape.size() - 1 &&
      !(ones == shape.size() && reshapeInputShape[0] == 1))
    return std::nullopt;

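  // For example (an illustrative case, not taken from the source): a reshape
  // tensor<8xf32> -> tensor<1x8x1xf32> followed by a hoisted transpose with
  // perms = [0, 2, 1] folds into a single reshape
  // tensor<8xf32> -> tensor<1x1x8xf32>.
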
  // Do not insert a TransposeOp; instead we fold the reshape and its
  // attribute.
  llvm::SmallVector<int64_t> newShape;
  if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
                                 newShape)) {
    // This means the shape is not constant.
    return std::nullopt;
  }
  ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
  auto foldedReshape = rewriter.create<ReshapeOp>(
      reshapeOp.getLoc(),
      RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
                            reshapeOutputType.getElementType()),
      reshapeOp.getInput1(),
      getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
                                                      hoistedPerms)));
  return foldedReshape->getResult(0);
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
  if (!denseAttr)
    return std::nullopt;
  auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
  if (!maybeNewDenseAttr.has_value())
    return std::nullopt;
  auto newDenseAttr = maybeNewDenseAttr.value();
  auto newConstOp = rewriter.create<ConstOp>(
      constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
  return newConstOp->getResult(0);
}

bool TosaReduceTransposes::convertDependentOps(
    SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {

  for (Operation *op : dependentOps) {
    if (!op || op->getNumResults() != 1)
      return false;

    Value priorValue = op->getResult(0);

    // It's possible that a prior TransposeOp had the same dependency and we
    // already resolved it.
    if (valuesMap.contains(priorValue))
      continue;

    // Keep converted ops close to the original.
    rewriter.setInsertionPointAfter(op);

    std::optional<Value> maybeValue =
        llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
            .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto transposeOp) {
              return buildMappedToValue(transposeOp, valuesMap, rewriter,
                                        hoistedPerms);
            })
            .Default([&](Operation *op) {
              return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
            });

    if (!maybeValue.has_value())
      return false;

    valuesMap[priorValue] = maybeValue.value();
  }

  return true;
}

bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
    Operation *user, std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  return llvm::none_of(
      transposeInfo,
      [&validTransposes,
       user](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
        const auto &[transposeOp, dependentOps] = info;
        return validTransposes.count(transposeOp) &&
               dependentOps.contains(user);
      });
}

// Dependencies are valid for an operation if none of them occur outside
// of the proper fan-in cones of the hoisted TransposeOp with the same perms
// that we can replace. Described in more detail within.
bool TosaReduceTransposes::dependenciesAreValid(
    ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
    std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  for (Operation *op : dependentOps) {

    // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
    // This can be changed later if we find the memory impact is too high.
    if (llvm::isa<ConstOp>(op))
      continue;

    for (OpOperand &use : op->getUses()) {
      // We want the uses to be (1) contained in the dependentOps of other
      // validTransposes, or (2) directly used in a TransposeOp with the same
      // perms. For (2), this means the fan-in is a subset of our
      // dependentOps, so it is also a validTranspose that will eventually be
      // replaced.
      Operation *user = use.getOwner();
      if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
        // We can later consider cases of transpose -> transpose
        // or reshape -> transpose, where the transposes do not necessarily
        // have the same perms as the hoisted one, if implementing a more
        // general transform. These could be permitted.
        if (!llvm::equal(perms, otherTranspose.getPerms()))
          return false;
      } else if (userNotContainedInValidTransposeDependencies(
                     user, validTransposes, transposeInfo)) {
        return false;
      }
    }
  }

  return true;
}

// Gets the set of TransposeOp that we can replace without causing
// the old fan-in cone of any TransposeOp to remain "live", i.e., not being
// dead code. This is done by iterating the set until convergence: if you are
// used outside your own fan-in cone, it's still possible to be used inside
// another fan-in cone of a TransposeOp that is being replaced -- unless we
// find that that one has an outside usage too.
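//
// As a sketch of the fixed point (illustrative): suppose transpose A's
// dependent ops are also used inside transpose B's fan-in cone. A stays
// replaceable only while B does; once B is erased from the set (say, due to
// an unrelated outside use), A must be re-checked, which is why the scan
// restarts after every erasure until nothing changes.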
std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
    ArrayRef<int32_t> perms,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  // Initially, we assume they are all good to replace,
  // and we whittle them down based on our criteria.
  std::set<TransposeOp> ableToReplace;
  for (const auto &[transposeOp, _] : transposeInfo)
    ableToReplace.insert(transposeOp);

  bool gotRid;
  do {
    gotRid = false;
    for (const auto &[transposeOp, dependentOps] : transposeInfo) {
      // We don't care about it. Already invalidated.
      if (!ableToReplace.count(transposeOp))
        continue;

      // Check for validity.
      if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
                                transposeInfo)) {
        ableToReplace.erase(transposeOp);
        gotRid = true;
        break;
      }
    }

  } while (gotRid);

  return ableToReplace;
}

void TosaReduceTransposes::runOnOperation() {
  // We want to operate only within a single block.
  if (!getOperation().getRegion().hasOneBlock())
    return;

  IRRewriter rewriter(&getContext());
  // For each perms, maintain a mapping for converted ops to avoid
  // duplication.
  DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
  // For each perms, we keep track of which TransposeOp are eligible
  // for replacement alongside their dependentOps.
  DenseMap<ArrayRef<int32_t>,
           std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
      permsToTransposeInfo;

  // Necessary for lifetime, since the DenseMap keys are ArrayRef into this
  // storage. We use SmallVector for each perms (the common case is rank <=
  // 4) but std::vector for the outer collection, since there is no guarantee
  // of its smallness.
  std::vector<SmallVector<int32_t>> collectedPerms;

  // This keeps track of the order across all eligible-for-replacement
  // TransposeOp and their perms, a necessity for the final replacements.
  std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;

  // We want to reserve the space up front, since SmallVector stores some data
  // internally and the ArrayRef can reference that, which we don't want to
  // get invalidated.
  size_t expectedMaxPerms = 0;
  getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
  collectedPerms.reserve(expectedMaxPerms);

  getOperation().walk([&](TransposeOp transposeOp) {
    SetVector<Operation *> dependentOps;
    collectedPerms.emplace_back();
    SmallVector<int32_t> &perms = collectedPerms.back();

    // Dynamic shapes are OK, but the incompatible ones will be rejected
    // later.
    auto input = transposeOp.getInput1();
    auto output = transposeOp.getOutput();

    // However, we don't support unranked tensors.
    if (!llvm::isa<RankedTensorType>(input.getType()) ||
        !llvm::isa<RankedTensorType>(output.getType()))
      return;

    llvm::for_each(transposeOp.getPerms(),
                   [&perms](const auto i) { perms.emplace_back(i); });

    // We let --canonicalize deal with identity transposes.
    if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
      return;

    // Can fail if some basic invariant we need in order to perform our
    // conversions is not met.
    if (!collectFanIn(input.getDefiningOp(), dependentOps))
      return;

    // We want to reuse the valuesMap already associated with the same perms,
    // since it's possible multiple hoisted transposes with different perms
    // converge on an op, which would result in different transformations.
    DenseMap<Value, Value> &valuesMap = permsToValues[perms];

    // Attempt to perform the conversions and placements into IR
    // without turning inserted code "live". Also fills out valuesMap.
    // Fails if there is an intermediary we do not support.
    if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
      // Some additional operations may have been inserted, but they will be
      // removed by dead code elimination.
      return;

    // This should not happen. If it does -- it's unexpected,
    // so we fail the pass.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    // It's possible the types are not compatible (because of dynamic
    // shapes); in these cases, resolve the dynamic shapes before running the
    // pass.
    if (output.getType() != valuesMap.at(input).getType())
      return;

    auto &transposeInfo = permsToTransposeInfo[perms];

    // In general, we might also want to introduce "newDependentOps"
    // if there are new usages that don't fall inside the original fan-ins
    // (like the TransposeOp we insert for ReshapeOp),
    // but in this case, that is specialized enough and overlaps
    // with another direct-use TransposeOp case we need to cover anyway.
    transposeInfo.push_back({transposeOp, dependentOps});

    // This is for the final replacement across all transposes.
    totalTransposeOrder.push({transposeOp, perms});
  });

  // We want to do a full fan-in analysis on a per-perms level,
  // since if we did it across multiple perms and they shared ops (due to a
  // shared dependency on a Reshape), then we would also get duplicate ops.
  // Const is special-cased.
  std::set<TransposeOp> ableToReplace;
  for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
    // This gives us back replacements that would never result in any
    // duplicate operations being inserted by us in the IR (i.e., our goal is
    // only to remove transposes, and not to create a "new chain" to do so,
    // but to replace the existing chains).
    // Ideally, --canonicalize is run before this pass, since it helps this
    // analysis by removing dead code to allow more potentially acceptable
    // transformations.
    auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
    ableToReplace.insert(goodReplacementsForPerms.begin(),
                         goodReplacementsForPerms.end());
  }

  // We want to do the replacement across all transposes
  // in reverse order, since doing it otherwise would
  // invalidate the valuesMap mappings.
  while (!totalTransposeOrder.empty()) {
    auto [transposeOp, perms] = totalTransposeOrder.top();
    totalTransposeOrder.pop();

    if (ableToReplace.count(transposeOp) == 0)
      continue;

    auto &valuesMap = permsToValues[perms];
    auto input = transposeOp.getInput1();

    // The purpose of this reverse iteration
    // is to avoid valuesMap invalidation. If invalidation happens,
    // something is wrong.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    rewriter.replaceOp(transposeOp, valuesMap.at(input));
  }

  // We can remove all dead code by going in reverse.
  // This is because we would remove usages before we
  // see the users.
  getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
      [&](Operation *op) {
        if (isOpTriviallyDead(op))
          rewriter.eraseOp(op);
      });
}

} // namespace
