1//===- TosaReduceTransposes.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// ----------
10// Motivation:
11// ----------
12
13// Some legalization pathways introduce redundant tosa.TRANSPOSE
14// operations that result in avoidable data movement. For example,
15// PyTorch -> TOSA contains a lot of unnecessary transposes due
16// to conversions between NCHW and NHWC.
17
18// We wish to remove all the ones that we can, since in general
19// it is possible to remove the overwhelming majority.
20
21// -------------------
22// High-Level Overview:
23// -------------------
24
25// The pass works through the transpose operators in the program. It begins at
26// some transpose operator with an associated permutations tensor. It traverses
27// upwards through the dependencies of this transpose and verifies that we
28// encounter only operators with the TosaElementwiseOperator trait and terminate
29// in either constants, reshapes, or transposes.
30
31// We then evaluate whether there are any additional restrictions (the
32// transposes it terminates in must invert the one we began at, and the reshapes
33// must be ones in which we can fold the transpose into), and then we hoist the
34// transpose through the intervening operators, folding it at the constants,
35// reshapes, and transposes.
36
37// Finally, we ensure that we do not need both the transposed form (the form
38// that had the transpose hoisted through it) and the untransposed form (which
39// it was prior), by analyzing the usages of those dependent operators of a
40// given transpose we are attempting to hoist and replace.
41
42// If they are such that it would require both forms to be necessary, then we do
43// not replace the hoisted transpose, causing the new chain to be dead.
44// Otherwise, we do and the old chain (untransposed form) becomes dead. Only one
45// chain will ever then be live, resulting in no duplication.
46
47// We then perform a simple one-pass DCE, so no canonicalization is necessary.
48
49// -----------
50// Future Work:
51// -----------
52
53// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
54// hoisted
55// transposes with different permutation tensors.
56
57// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
58// N -> 1x1x...x1xNx1x...x1x1.
59
// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
61// those that form the identity.
62
63// (4) Add support for more instructions besides TosaElementwiseOperator as
64// the intervening ones (for example, the reduce_* operators).
65
66// (5) Support hoisting transposes up to an input parameter.
67
68//===----------------------------------------------------------------------===//
69
70#include "mlir/Dialect/Func/IR/FuncOps.h"
71#include "mlir/Dialect/Tosa/IR/TosaOps.h"
72#include "mlir/Dialect/Tosa/Transforms/Passes.h"
73#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
74#include "mlir/IR/Iterators.h"
75#include "llvm/ADT/TypeSwitch.h"
76#include <set>
77#include <stack>
78
79namespace mlir {
80namespace tosa {
81#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
82#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
83} // namespace tosa
84} // namespace mlir
85
86using namespace mlir;
87using namespace mlir::tosa;
88
89//===----------------------------------------------------------------------===//
90// TOSA Reduce Transposes Pass.
91//===----------------------------------------------------------------------===//
92
93namespace {
94
// Pass implementation: hoists tosa.transpose ops upward through chains of
// elementwise operations and folds them into the constants, reshapes, and
// inverse transposes found at the roots of those chains, then removes the
// now-dead original chains.
struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  void runOnOperation() override;

private:
  // This will collect all the data dependencies for the given Operation
  // up to and including ConstOp, ReshapeOp, and TransposeOp.
  // Returns false if an unsupported operation is encountered along the way.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);

  // Builds the transposed counterpart of every op in dependentOps (in
  // topological order), recording original-value -> converted-value pairs in
  // valuesMap. Returns false on the first op that cannot be converted.
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // Checks if the two permutations, when applied consecutively, result
  // in the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // This is meant to apply to operations with the TosaElementwiseOperator
  // trait.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // This updates valuesMap when we encounter another TransposeOp as a
  // dependency of the hoisted one. %0 = tosa.transpose %arg0 <- applies to
  // this %1 = tosa.transpose %0 <- when tracking back from this
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks if ReshapeOp can have hoisted TransposeOp folded into it. If so,
  // it creates new ReshapeOp with that fold.
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // We may have something like:
  // %0 = tosa.const
  // %1 = tosa.transpose
  // %2 = tosa.add %0, %1
  // %3 = tosa.transpose %2
  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
  // in MobilenetV3.
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks which TransposeOp we should "replace", turning their converted
  // chains of ops, through which they were propagated, "live", and the old code
  // "dead." Attempts to avoid doing so when doing so would result in the old
  // code staying "live," resulting in duplication.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for dependenciesAreValid.
  // True iff no still-valid transpose's fan-in cone contains `user`.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for getGoodReplacements to check if some TransposeOp's
  // dependencies are OK.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Applies perms to the DenseElementsAttr.
  // If it returns std::nullopt, it also triggers pass failure, since verifier
  // guarantees from TOSA are not in place (and otherwise, if used elsewhere,
  // it should fail).
  // This is a basic API and may benefit from refactor into the core MLIR APIs.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
};
176
177std::optional<DenseElementsAttr>
178TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
179 ArrayRef<int32_t> perms) {
180 RankedTensorType oldType = llvm::cast<RankedTensorType>(Val: input.getType());
181 RankedTensorType newType =
182 RankedTensorType::get(shape: applyTOSAPermutation(input: oldType.getShape(), perms),
183 elementType: oldType.getElementType());
184 size_t rank = oldType.getRank();
185
186 // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
187 // 0. If not in place, something is very wrong.
188 if (rank <= 0 || oldType.getNumElements() <= 0) {
189 signalPassFailure();
190 return std::nullopt;
191 }
192
193 if (input.isSplat())
194 return input.reshape(newType);
195
196 // The algorithm is approximately as follows:
197 // input: perms, input flat array, input tensor type
198 // (1/2) determine the strides of input/output if
199 // they were strided in row-major order. (3) adjust the strides for the
200 // input to be in the same order of indices as the output is written.
201 // (4) process dimension by dimension. example: perms 2, 0, 1; input
202 // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] =
203 // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust
204 // input strides to be as input[i + 12j + 4k] so we may process
205 // layer-by-layer.
206
207 // Step 1/2: Strides for input. We ignore output since row-major and can just
208 // push_back.
209
210 SmallVector<int64_t> originalInputStrides(rank);
211 originalInputStrides[rank - 1] = 1;
212 // index with int64_t to avoid overflow
213 for (int64_t i = rank - 2; i >= 0; i--)
214 originalInputStrides[i] =
215 originalInputStrides[i + 1] * oldType.getDimSize(idx: i + 1);
216
217 // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
218 // output which is done in row-major order.
219
220 SmallVector<int64_t> newInputStrides;
221 newInputStrides.reserve(N: rank);
222 for (int32_t v : perms)
223 newInputStrides.push_back(Elt: originalInputStrides[v]);
224
225 // Step 4: Write out the transposed "flat array" dimension by dimension.
226
227 auto inputArray = input.getValues<Attribute>();
228 SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
229 for (size_t i = 0; i < rank; i++)
230 boundsAndStrides.push_back(Elt: {newType.getDimSize(idx: i), newInputStrides[i]});
231
232 SmallVector<Attribute> resultArray;
233 resultArray.reserve(N: inputArray.size());
234
235 std::function<void(int64_t,
236 SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
237 processTransposeDim = [&](auto accumulatedIndex, auto it) {
238 if (it == boundsAndStrides.end()) {
239 resultArray.push_back(Elt: inputArray[accumulatedIndex]);
240 return;
241 }
242
243 for (int64_t i = 0; i < it->first; i++) {
244 int64_t j = accumulatedIndex + i * it->second;
245 processTransposeDim(j, it + 1);
246 }
247 };
248
249 processTransposeDim(0, boundsAndStrides.begin());
250
251 return DenseElementsAttr::get(type: newType, values: resultArray);
252}
253
254// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
255// as the sources of the data dependencies, and TosaElementWiseOperator
256// after that, if the function returns true.
257bool TosaReduceTransposes::collectFanIn(Operation *op,
258 SetVector<Operation *> &collected) {
259 // Can occur if defined through the parameter to a func.func.
260 if (!op)
261 return false;
262
263 if (!llvm::isa_and_present<tosa::TosaDialect>(Val: op->getDialect()))
264 return false;
265
266 // Prevent extra work if already seen.
267 if (collected.contains(key: op))
268 return true;
269
270 // Throw it out so later don't have to deal with this.
271 if (op->getNumResults() != 1 ||
272 !llvm::isa<RankedTensorType>(Val: op->getResult(idx: 0).getType()))
273 return false;
274
275 // We don't wish to traverse up a ReshapeOp, since generally we can't
276 // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp
277 // will have no in-edges in the data dependency graph we construct for
278 // the downstream TransposeOp.
279 if (!llvm::isa<tosa::TransposeOp>(Val: op) && !llvm::isa<tosa::ReshapeOp>(Val: op) &&
280 !llvm::isa<tosa::ConstOp>(Val: op)) {
281
282 if (!llvm::isa<tosa::MulOp>(Val: op) &&
283 !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
284 return false;
285
286 for (Value operand : op->getOperands()) {
287 // If this is a problem in future, think about alternatives to recursion.
288 if (llvm::isa<tosa::MulOp>(Val: op) && operand == op->getOperand(idx: 2)) {
289 // do not recurse into MulOp's shift operand
290 continue;
291 }
292 if (!collectFanIn(op: operand.getDefiningOp(), collected))
293 return false;
294 }
295 }
296
297 // Insert in topological order.
298 collected.insert(X: op);
299
300 return true;
301}
302
303// Assuming that due to the verification of TransposeOp perms arrays are
304// permutations of 0 - perms.size() - 1.
305bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
306 ArrayRef<int32_t> perms2) {
307 if (perms1.size() != perms2.size())
308 return false;
309 int32_t n = perms1.size();
310 for (int32_t i = 0; i < n; i++)
311 if (perms2[perms1[i]] != i)
312 return false;
313 return true;
314}
315
316// Primary overload for those with TosaElementwiseOperator trait.
317// The other ones handle the case of the operations that occur at the
318// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
319std::optional<Value> TosaReduceTransposes::buildMappedToValue(
320 Operation *op, const DenseMap<Value, Value> &valuesMap,
321 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
322 if (op->getNumResults() != 1 ||
323 (!llvm::isa<tosa::MulOp>(Val: op) &&
324 !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
325 return std::nullopt;
326
327 auto resultType = llvm::cast<RankedTensorType>(Val: op->getResult(idx: 0).getType());
328 SmallVector<Value, 3> operands;
329 for (Value v : op->getOperands()) {
330 if (valuesMap.contains(Val: v)) {
331 operands.push_back(Elt: valuesMap.at(Val: v));
332 } else if (llvm::isa<tosa::MulOp>(Val: op) && v == op->getOperand(idx: 2)) {
333 // special case for MulOp's shift operand
334 operands.push_back(Elt: v);
335 } else {
336 return std::nullopt;
337 }
338 }
339
340 // Conceptually, we propagate the hoisted TransposeOp through
341 // these interveaning operations. For example,
342
343 // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
344 // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
345 // tensor<3x2xi32>
346
347 // becomes:
348 // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
349 // tensor<3x2xi32>
350 // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>)
351
352 // We construct this new tosa.clamp here, but it doesn't
353 // turn "live" until the transpose being hoisted through this chain
354 // is replaced with the proper value from the new chain.
355
356 return rewriter
357 .create(loc: op->getLoc(), opName: op->getName().getIdentifier(), operands,
358 types: RankedTensorType::get(
359 shape: applyTOSAPermutation(input: resultType.getShape(), perms: hoistedPerms),
360 elementType: resultType.getElementType()),
361 attributes: op->getAttrs())
362 ->getResult(idx: 0);
363}
364
365std::optional<Value> TosaReduceTransposes::buildMappedToValue(
366 TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
367 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
368 if (!areInvolutionTransposes(perms1: hoistedPerms, perms2: transposeOp.getPerms()))
369 return std::nullopt;
370 return transposeOp.getInput1();
371}
372
373std::optional<Value> TosaReduceTransposes::buildMappedToValue(
374 ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
375 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
376 auto reshapeOutput = reshapeOp.getOutput();
377 auto reshapeInputType =
378 llvm::dyn_cast<RankedTensorType>(Val: reshapeOp.getInput1().getType());
379 auto reshapeInputShape = reshapeInputType.getShape();
380 // want reshape N -> 1x1x...x1xNx1x...x1x1
381 if (!reshapeInputType || reshapeInputShape.size() != 1)
382 return std::nullopt;
383 auto reshapeOutputType =
384 llvm::cast<RankedTensorType>(Val: reshapeOutput.getType());
385
386 // Instead of inserting a TransposeOp here, we check if we can fold it into
387 // the ReshapeOp. There is more complex cases where this is possible, and
388 // this check can be extended.
389
390 // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1
391 auto shape = reshapeOutputType.getShape();
392 size_t ones = llvm::count(Range&: shape, Element: 1);
393 // N == 1 and N != 1
394 if (ones != shape.size() - 1 &&
395 !(ones == shape.size() && reshapeInputShape[0] == 1))
396 return std::nullopt;
397
398 // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
399 llvm::SmallVector<int64_t> newShape;
400 if (!tosa::getConstShapeValues(op: reshapeOp.getShape().getDefiningOp(),
401 result_shape&: newShape)) {
402 // this mean shape is not constant
403 return std::nullopt;
404 }
405 ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
406 auto foldedReshape = rewriter.create<ReshapeOp>(
407 location: reshapeOp.getLoc(),
408 args: RankedTensorType::get(shape: applyTOSAPermutation(input: shape, perms: hoistedPerms),
409 elementType: reshapeOutputType.getElementType()),
410 args: reshapeOp.getInput1(),
411 args: getTosaConstShape(builder, shape: applyTOSAPermutation(input: llvm::ArrayRef(newShape),
412 perms: hoistedPerms)));
413 return foldedReshape->getResult(idx: 0);
414}
415
416std::optional<Value> TosaReduceTransposes::buildMappedToValue(
417 ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
418 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
419 auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(Val: constOp.getValues());
420 if (!denseAttr)
421 return std::nullopt;
422 auto maybeNewDenseAttr = transposeDenseAttribute(input: denseAttr, perms: hoistedPerms);
423 if (!maybeNewDenseAttr.has_value())
424 return std::nullopt;
425 auto newDenseAttr = maybeNewDenseAttr.value();
426 auto newConstOp = rewriter.create<ConstOp>(
427 location: constOp.getLoc(), args: newDenseAttr.getType(), args&: newDenseAttr);
428 return newConstOp->getResult(idx: 0);
429}
430
431bool TosaReduceTransposes::convertDependentOps(
432 SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
433 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
434
435 for (Operation *op : dependentOps) {
436 if (!op || op->getNumResults() != 1)
437 return false;
438
439 Value priorValue = op->getResult(idx: 0);
440
441 // It's possible on a prior transposeOp we had the same dependency and
442 // already resolved it.
443 if (valuesMap.contains(Val: priorValue))
444 continue;
445
446 // Keep converted ops close to the original.
447 rewriter.setInsertionPointAfter(op);
448
449 std::optional<Value> maybeValue =
450 llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
451 .Case<TransposeOp, ReshapeOp, ConstOp>(caseFn: [&](auto transposeOp) {
452 return buildMappedToValue(transposeOp, valuesMap, rewriter,
453 hoistedPerms);
454 })
455 .Default(defaultFn: [&](Operation *op) {
456 return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
457 });
458
459 if (!maybeValue.has_value())
460 return false;
461
462 valuesMap[priorValue] = maybeValue.value();
463 }
464
465 return true;
466}
467
468bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
469 Operation *user, std::set<TransposeOp> &validTransposes,
470 std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
471 &transposeInfo) {
472 return llvm::none_of(
473 Range&: transposeInfo,
474 P: [&validTransposes,
475 user](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
476 const auto &[transposeOp, dependentOps] = info;
477 return validTransposes.count(x: transposeOp) &&
478 dependentOps.contains(key: user);
479 });
480}
481
482// Dependencies are valid for an operation if none of them occur outside
483// of the proper fan-in cones of the hoisted TransposeOp with the same perms
484// that we can replace. Described in more detail within.
485bool TosaReduceTransposes::dependenciesAreValid(
486 ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
487 std::set<TransposeOp> &validTransposes,
488 std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
489 &transposeInfo) {
490 for (Operation *op : dependentOps) {
491
492 // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
493 // This can be changed later if we find the memory impact is too high.
494 if (llvm::isa<ConstOp>(Val: op))
495 continue;
496
497 for (OpOperand &use : op->getUses()) {
498 // Want the uses to be (1) contained in the dependentOps of other
499 // validTransposes, or (2) to be directly used in a TransposeOp with the
500 // same perms. For (2) it means the fan-in is a subset of our
501 // dependentOps, so it is also a validTranspose that will eventually be
502 // replaced.
503 Operation *user = use.getOwner();
504 if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(Val: user)) {
505 // Can later think about cases where transpose -> transpose
506 // or reshape -> transpose, where the transposes are not necessarily
507 // the same perms as the hoisted, if implementing a more general
508 // transform. These could be permitted.
509 if (!llvm::equal(LRange&: perms, RRange: otherTranspose.getPerms()))
510 return false;
511 } else if (userNotContainedInValidTransposeDependencies(
512 user, validTransposes, transposeInfo)) {
513 return false;
514 }
515 }
516 }
517
518 return true;
519}
520
521// Getting the set of TransposeOp that we can replace without causing
522// the old fan-in cones of any TransposeOp to remain "live", i.e, -- not being
523// dead code. This is done by iterating the set until convergence, since
524// if you are used outside your own fan-in cone, it's possible to be used
525// in another fan-in cone of a TransposeOp that is being replaced -- unless
526// we find that that one has a usage outside of it too.
527std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
528 ArrayRef<int32_t> perms,
529 std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
530 &transposeInfo) {
531 // Initially, we assume they are all good to replace,
532 // and we whittle them down based on our criteria.
533 std::set<TransposeOp> ableToReplace;
534 for (const auto &[transposeOp, _] : transposeInfo)
535 ableToReplace.insert(x: transposeOp);
536
537 bool gotRid;
538 do {
539 gotRid = false;
540 for (const auto &[transposeOp, dependentOps] : transposeInfo) {
541 // We don't care about it. Already invalidated.
542 if (!ableToReplace.count(x: transposeOp))
543 continue;
544
545 // Check for validity.
546 if (!dependenciesAreValid(perms, dependentOps, validTransposes&: ableToReplace,
547 transposeInfo)) {
548 ableToReplace.erase(x: transposeOp);
549 gotRid = true;
550 break;
551 }
552 }
553
554 } while (gotRid);
555
556 return ableToReplace;
557}
558
// Pass driver: collects every hoistable TransposeOp per permutation, builds
// the converted (transposed) chains, decides which replacements avoid code
// duplication, performs the replacements in reverse order, then DCEs the old
// chains. Operates only on single-block functions.
void TosaReduceTransposes::runOnOperation() {
  // We want to operate only within a single block.
  if (!getOperation().getRegion().hasOneBlock())
    return;

  IRRewriter rewriter(&getContext());
  // For each perms, maintain a mapping for converted ops, avoid duplication.
  DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
  // For each perms, we keep track of which TransposeOp are eligible
  // for replacement alongside their dependentOps.
  DenseMap<ArrayRef<int32_t>,
           std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
      permsToTransposeInfo;

  // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef.
  // Use SmallVector for perms (common-case is <= 4) but std::vector otherwise
  // since no guarantee of smallness.
  std::vector<SmallVector<int32_t>> collectedPerms;

  // This keeps track of the order across all eligible-for-replacement
  // TransposeOp and their perms, a necessity for the final replacements.
  std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;

  // We want to reserve the space up front, since SmallVector stores some data
  // internally and the ArrayRef can reference that, which we don't want to get
  // invalidated by reallocation of the outer std::vector.
  size_t expectedMaxPerms = 0;
  getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
  collectedPerms.reserve(expectedMaxPerms);

  getOperation().walk([&](TransposeOp transposeOp) {
    SetVector<Operation *> dependentOps;
    collectedPerms.emplace_back();
    SmallVector<int32_t> &perms = collectedPerms.back();

    // Dynamic shapes are OK, but the incompatible ones will be rejected later.
    auto input = transposeOp.getInput1();
    auto output = transposeOp.getOutput();

    // However, we don't support unranked tensors.
    if (!llvm::isa<RankedTensorType>(input.getType()) ||
        !llvm::isa<RankedTensorType>(output.getType()))
      return;

    llvm::append_range(perms, transposeOp.getPerms());

    // We let --canonicalize deal with identity transpose.
    if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
      return;

    // Can fail if some set of basic invariants is not met that we want to
    // perform our conversions.
    if (!collectFanIn(input.getDefiningOp(), dependentOps))
      return;

    // Want to associate valuesMap for already converted of the same perms,
    // since it's possible multiple hoisted transposes w/ different perms
    // converge on an op, which would result in different transformations.
    DenseMap<Value, Value> &valuesMap = permsToValues[perms];

    // Attempt to perform the conversions and placements into IR
    // without turning inserted code "live". Also fills out valuesMap.
    // Fails if there is an intermediary we do not support.
    if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
      // Some additional operations may have been inserted, but will be
      // removed by dead code elimination.
      return;

    // This should not happen. If it does -- it's unexpected,
    // so we fail the pass.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    // It's possible the types are not compatible (because of dynamic shapes),
    // and in these cases, want to resolve dynamic shapes before running the
    // pass.
    if (output.getType() != valuesMap.at(input).getType())
      return;

    auto &transposeInfo = permsToTransposeInfo[perms];

    // In general, we might also want to introduce "newDependentOps"
    // if there are new usages that don't fall inside the original fan-ins
    // (like the TransposeOp we insert for ReshapeOp),
    // but in this case, that is specialized enough and overlaps
    // with another direct-use TransposeOp case we need to cover anyway.
    transposeInfo.push_back({transposeOp, dependentOps});

    // This is for the final replacement across all transposes.
    totalTransposeOrder.push({transposeOp, perms});
  });

  // We want to do a full fan-in analysis on a perms-level,
  // since if we do it on a multi-perms level, and they share (due to a shared
  // dependency on a Reshape) then we would also get duplicate ops.
  // Const is special cased.
  std::set<TransposeOp> ableToReplace;
  for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
    // Gives us back replacements that would never result in any duplicate
    // operations being inserted by us in the IR (i.e, our goal is only to
    // remove transposes, and not create a "new chain" to do so, but replace
    // the existing chains).
    // Ideally, --canonicalize is run before this pass, since it helps this
    // analysis by removing dead code to allow more potentially acceptable
    // transformations.
    auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
    ableToReplace.insert(goodReplacementsForPerms.begin(),
                         goodReplacementsForPerms.end());
  }

  // We want to do replacement across all transposes
  // in reverse order, due to invalidation of valuesMap mappings
  // if we did it otherwise.
  while (!totalTransposeOrder.empty()) {
    auto [transposeOp, perms] = totalTransposeOrder.top();
    totalTransposeOrder.pop();

    if (ableToReplace.count(transposeOp) == 0)
      continue;

    auto &valuesMap = permsToValues[perms];
    auto input = transposeOp.getInput1();

    // The purpose of this reverse iteration
    // is to avoid valuesMap invalidation. If it happens,
    // something is wrong.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    rewriter.replaceOp(transposeOp, valuesMap.at(input));
  }

  // We can remove all dead code by going in reverse.
  // This is because we would remove usages before we
  // see the users.
  getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
      [&](Operation *op) {
        if (isOpTriviallyDead(op))
          rewriter.eraseOp(op);
      });
}
700
701} // namespace
702

source code of mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp