1 | //===- TosaReduceTransposes.cpp -------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // ---------- |
10 | // Motivation: |
11 | // ---------- |
12 | |
13 | // Some legalization pathways introduce redundant tosa.TRANSPOSE |
14 | // operations that result in avoidable data movement. For example, |
15 | // PyTorch -> TOSA contains a lot of unnecessary transposes due |
16 | // to conversions between NCHW and NHWC. |
17 | |
18 | // We wish to remove all the ones that we can, since in general |
19 | // it is possible to remove the overwhelming majority. |
20 | |
21 | // ------------------- |
22 | // High-Level Overview: |
23 | // ------------------- |
24 | |
25 | // The pass works through the transpose operators in the program. It begins at |
26 | // some transpose operator with an associated permutations tensor. It traverses |
27 | // upwards through the dependencies of this transpose and verifies that we |
28 | // encounter only operators with the TosaElementwiseOperator trait and terminate |
29 | // in either constants, reshapes, or transposes. |
30 | |
31 | // We then evaluate whether there are any additional restrictions (the |
32 | // transposes it terminates in must invert the one we began at, and the reshapes |
33 | // must be ones in which we can fold the transpose into), and then we hoist the |
34 | // transpose through the intervening operators, folding it at the constants, |
35 | // reshapes, and transposes. |
36 | |
37 | // Finally, we ensure that we do not need both the transposed form (the form |
38 | // that had the transpose hoisted through it) and the untransposed form (which |
39 | // it was prior), by analyzing the usages of those dependent operators of a |
40 | // given transpose we are attempting to hoist and replace. |
41 | |
42 | // If they are such that it would require both forms to be necessary, then we do |
43 | // not replace the hoisted transpose, causing the new chain to be dead. |
44 | // Otherwise, we do and the old chain (untransposed form) becomes dead. Only one |
45 | // chain will ever then be live, resulting in no duplication. |
46 | |
47 | // We then perform a simple one-pass DCE, so no canonicalization is necessary. |
48 | |
49 | // ----------- |
50 | // Future Work: |
51 | // ----------- |
52 | |
// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
//     hoisted transposes with different permutation tensors.
56 | |
57 | // (2) Expand the class of foldable upstream ReshapeOp we permit beyond |
58 | // N -> 1x1x...x1xNx1x...x1x1. |
59 | |
// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
61 | // those that form the identity. |
62 | |
63 | // (4) Add support for more instructions besides TosaElementwiseOperator as |
64 | // the intervening ones (for example, the reduce_* operators). |
65 | |
66 | // (5) Support hoisting transposes up to an input parameter. |
67 | |
68 | //===----------------------------------------------------------------------===// |
69 | |
70 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
71 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
72 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
73 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
74 | #include "mlir/IR/Iterators.h" |
75 | #include "mlir/IR/Matchers.h" |
76 | #include "llvm/ADT/TypeSwitch.h" |
77 | #include <memory> |
78 | #include <set> |
79 | #include <stack> |
80 | |
81 | namespace mlir { |
82 | namespace tosa { |
83 | #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES |
84 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
85 | } // namespace tosa |
86 | } // namespace mlir |
87 | |
88 | using namespace mlir; |
89 | using namespace mlir::tosa; |
90 | |
91 | //===----------------------------------------------------------------------===// |
92 | // TOSA Reduce Transposes Pass. |
93 | //===----------------------------------------------------------------------===// |
94 | |
95 | namespace { |
96 | |
97 | struct TosaReduceTransposes final |
98 | : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> { |
99 | void runOnOperation() override; |
100 | |
101 | private: |
102 | // This will collect all the data dependencies for the given Operation |
103 | // up to and including ConstOp, ReshapeOp, and TransposeOp. |
104 | bool collectFanIn(Operation *op, SetVector<Operation *> &collected); |
105 | bool convertDependentOps(SetVector<Operation *> &dependentOps, |
106 | DenseMap<Value, Value> &valuesMap, |
107 | IRRewriter &rewriter, |
108 | ArrayRef<int32_t> hoistedPerms); |
109 | |
110 | // Checks if the two permutations, when applied consecutively, result |
111 | // in the identity. |
112 | bool areInvolutionTransposes(ArrayRef<int32_t> perms1, |
113 | ArrayRef<int32_t> perms2); |
114 | |
115 | // This is meant to apply to operations with the TosaElementwiseOperator |
116 | // trait. |
117 | std::optional<Value> |
118 | buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap, |
119 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms); |
120 | |
121 | // This updates valuesMap when we encounter another TransposeOp as a |
122 | // dependency of the hoisted one. %0 = tosa.transpose %arg0 <- applies to |
123 | // this %1 = tosa.transpose %0 <- when tracking back from this |
124 | std::optional<Value> |
125 | buildMappedToValue(TransposeOp transposeOp, |
126 | const DenseMap<Value, Value> &valuesMap, |
127 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms); |
128 | |
129 | // Checks if ReshapeOp can have hoisted TransposeOp folded into it. If so, |
130 | // it creates new ReshapeOp with that fold. |
131 | std::optional<Value> |
132 | buildMappedToValue(ReshapeOp reshapeOp, |
133 | const DenseMap<Value, Value> &valuesMap, |
134 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms); |
135 | |
136 | // We may have something like: |
137 | // %0 = tosa.const |
138 | // %1 = tosa.transpose |
139 | // %2 = tosa.add %0, %1 |
140 | // %3 = tosa.transpose %2 |
141 | // that --tosa-layerwise-const-fold wouldn't handle. This use shows up |
142 | // in MobilenetV3. |
143 | std::optional<Value> |
144 | buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap, |
145 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms); |
146 | |
147 | // Checks which TransposeOp we should "replace", turning their converted |
148 | // chains of ops, through which they were propagated, "live", and the old code |
149 | // "dead." Attempts to avoid doing so when doing so would result in the old |
150 | // code staying "live," resulting in duplication. |
151 | std::set<TransposeOp> getGoodReplacements( |
152 | ArrayRef<int32_t> perms, |
153 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
154 | &transposeInfo); |
155 | |
156 | // Helper function for dependenciesAreValid. |
157 | bool userNotContainedInValidTransposeDependencies( |
158 | Operation *user, std::set<TransposeOp> &validTransposes, |
159 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
160 | &transposeInfo); |
161 | |
162 | // Helper function for getGoodReplacements to check if some TransposeOp's |
163 | // dependencies are OK. |
164 | bool dependenciesAreValid( |
165 | ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps, |
166 | std::set<TransposeOp> &validTransposes, |
167 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
168 | &transposeInfo); |
169 | |
170 | // Applies perms to the DenseElementsAttr. |
171 | // If it returns std::nullopt, it also triggers pass failure, since verifier |
172 | // guarantees from TOSA are not in place (and otherwise, if used elsewhere, |
173 | // it should fail). |
174 | // This is a basic API and may benefit from refactor into the core MLIR APIs. |
175 | std::optional<DenseElementsAttr> |
176 | transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms); |
177 | }; |
178 | |
179 | std::optional<DenseElementsAttr> |
180 | TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input, |
181 | ArrayRef<int32_t> perms) { |
182 | RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType()); |
183 | RankedTensorType newType = |
184 | RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms), |
185 | oldType.getElementType()); |
186 | size_t rank = oldType.getRank(); |
187 | |
188 | // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension |
189 | // 0. If not in place, something is very wrong. |
190 | if (rank <= 0 || oldType.getNumElements() <= 0) { |
191 | signalPassFailure(); |
192 | return std::nullopt; |
193 | } |
194 | |
195 | if (input.isSplat()) |
196 | return input.reshape(newType); |
197 | |
198 | // The algorithm is approximately as follows: |
199 | // input: perms, input flat array, input tensor type |
200 | // (1/2) determine the strides of input/output if |
201 | // they were strided in row-major order. (3) adjust the strides for the |
202 | // input to be in the same order of indices as the output is written. |
203 | // (4) process dimension by dimension. example: perms 2, 0, 1; input |
204 | // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] = |
205 | // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust |
206 | // input strides to be as input[i + 12j + 4k] so we may process |
207 | // layer-by-layer. |
208 | |
209 | // Step 1/2: Strides for input. We ignore output since row-major and can just |
210 | // push_back. |
211 | |
212 | SmallVector<int64_t> originalInputStrides(rank); |
213 | originalInputStrides[rank - 1] = 1; |
214 | // index with int64_t to avoid overflow |
215 | for (int64_t i = rank - 2; i >= 0; i--) |
216 | originalInputStrides[i] = |
217 | originalInputStrides[i + 1] * oldType.getDimSize(i + 1); |
218 | |
219 | // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as |
220 | // output which is done in row-major order. |
221 | |
222 | SmallVector<int64_t> newInputStrides; |
223 | newInputStrides.reserve(rank); |
224 | for (int32_t v : perms) |
225 | newInputStrides.push_back(originalInputStrides[v]); |
226 | |
227 | // Step 4: Write out the transposed "flat array" dimension by dimension. |
228 | |
229 | auto inputArray = input.getValues<Attribute>(); |
230 | SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides; |
231 | for (size_t i = 0; i < rank; i++) |
232 | boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]}); |
233 | |
234 | SmallVector<Attribute> resultArray; |
235 | resultArray.reserve(inputArray.size()); |
236 | |
237 | std::function<void(int64_t, |
238 | SmallVector<std::pair<int64_t, int64_t>>::const_iterator)> |
239 | processTransposeDim = [&](auto accumulatedIndex, auto it) { |
240 | if (it == boundsAndStrides.end()) { |
241 | resultArray.push_back(inputArray[accumulatedIndex]); |
242 | return; |
243 | } |
244 | |
245 | for (int64_t i = 0; i < it->first; i++) { |
246 | int64_t j = accumulatedIndex + i * it->second; |
247 | processTransposeDim(j, it + 1); |
248 | } |
249 | }; |
250 | |
251 | processTransposeDim(0, boundsAndStrides.begin()); |
252 | |
253 | return DenseElementsAttr::get(newType, resultArray); |
254 | } |
255 | |
256 | // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp |
257 | // as the sources of the data dependencies, and TosaElementWiseOperator |
258 | // after that, if the function returns true. |
259 | bool TosaReduceTransposes::collectFanIn(Operation *op, |
260 | SetVector<Operation *> &collected) { |
261 | // Can occur if defined through the parameter to a func.func. |
262 | if (!op) |
263 | return false; |
264 | |
265 | if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect())) |
266 | return false; |
267 | |
268 | // Prevent extra work if already seen. |
269 | if (collected.contains(op)) |
270 | return true; |
271 | |
272 | // Throw it out so later don't have to deal with this. |
273 | if (op->getNumResults() != 1 || |
274 | !llvm::isa<RankedTensorType>(op->getResult(idx: 0).getType())) |
275 | return false; |
276 | |
277 | // We don't wish to traverse up a ReshapeOp, since generally we can't |
278 | // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp |
279 | // will have no in-edges in the data dependency graph we construct for |
280 | // the downstream TransposeOp. |
281 | if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) && |
282 | !llvm::isa<tosa::ConstOp>(op)) { |
283 | |
284 | if (!llvm::isa<tosa::MulOp>(op) && |
285 | !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()) |
286 | return false; |
287 | |
288 | for (Value operand : op->getOperands()) { |
289 | // If this is a problem in future, think about alternatives to recursion. |
290 | if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) { |
291 | // do not recurse into MulOp's shift operand |
292 | continue; |
293 | } |
294 | if (!collectFanIn(operand.getDefiningOp(), collected)) |
295 | return false; |
296 | } |
297 | } |
298 | |
299 | // Insert in topological order. |
300 | collected.insert(op); |
301 | |
302 | return true; |
303 | } |
304 | |
305 | // Assuming that due to the verification of TransposeOp perms arrays are |
306 | // permutations of 0 - perms.size() - 1. |
307 | bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1, |
308 | ArrayRef<int32_t> perms2) { |
309 | if (perms1.size() != perms2.size()) |
310 | return false; |
311 | int32_t n = perms1.size(); |
312 | for (int32_t i = 0; i < n; i++) |
313 | if (perms2[perms1[i]] != i) |
314 | return false; |
315 | return true; |
316 | } |
317 | |
318 | // Primary overload for those with TosaElementwiseOperator trait. |
319 | // The other ones handle the case of the operations that occur at the |
320 | // roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp). |
321 | std::optional<Value> TosaReduceTransposes::buildMappedToValue( |
322 | Operation *op, const DenseMap<Value, Value> &valuesMap, |
323 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
324 | if (op->getNumResults() != 1 || |
325 | (!llvm::isa<tosa::MulOp>(op) && |
326 | !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())) |
327 | return std::nullopt; |
328 | |
329 | auto resultType = llvm::cast<RankedTensorType>(op->getResult(idx: 0).getType()); |
330 | SmallVector<Value, 3> operands; |
331 | for (Value v : op->getOperands()) { |
332 | if (valuesMap.contains(v)) { |
333 | operands.push_back(valuesMap.at(v)); |
334 | } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) { |
335 | // special case for MulOp's shift operand |
336 | operands.push_back(v); |
337 | } else { |
338 | return std::nullopt; |
339 | } |
340 | } |
341 | |
342 | // Conceptually, we propagate the hoisted TransposeOp through |
343 | // these interveaning operations. For example, |
344 | |
345 | // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32> |
346 | // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) -> |
347 | // tensor<3x2xi32> |
348 | |
349 | // becomes: |
350 | // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) -> |
351 | // tensor<3x2xi32> |
352 | // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>) |
353 | |
354 | // We construct this new tosa.clamp here, but it doesn't |
355 | // turn "live" until the transpose being hoisted through this chain |
356 | // is replaced with the proper value from the new chain. |
357 | |
358 | return rewriter |
359 | .create(op->getLoc(), op->getName().getIdentifier(), operands, |
360 | RankedTensorType::get( |
361 | applyTOSAPermutation(resultType.getShape(), hoistedPerms), |
362 | resultType.getElementType()), |
363 | op->getAttrs()) |
364 | ->getResult(0); |
365 | } |
366 | |
367 | std::optional<Value> TosaReduceTransposes::buildMappedToValue( |
368 | TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap, |
369 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
370 | if (!areInvolutionTransposes(perms1: hoistedPerms, perms2: transposeOp.getPerms())) |
371 | return std::nullopt; |
372 | return transposeOp.getInput1(); |
373 | } |
374 | |
375 | std::optional<Value> TosaReduceTransposes::buildMappedToValue( |
376 | ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap, |
377 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
378 | auto reshapeOutput = reshapeOp.getOutput(); |
379 | auto reshapeInputType = |
380 | llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType()); |
381 | auto reshapeInputShape = reshapeInputType.getShape(); |
382 | // want reshape N -> 1x1x...x1xNx1x...x1x1 |
383 | if (!reshapeInputType || reshapeInputShape.size() != 1) |
384 | return std::nullopt; |
385 | auto reshapeOutputType = |
386 | llvm::cast<RankedTensorType>(reshapeOutput.getType()); |
387 | |
388 | // Instead of inserting a TransposeOp here, we check if we can fold it into |
389 | // the ReshapeOp. There is more complex cases where this is possible, and |
390 | // this check can be extended. |
391 | |
392 | // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1 |
393 | auto shape = reshapeOutputType.getShape(); |
394 | size_t ones = llvm::count(shape, 1); |
395 | // N == 1 and N != 1 |
396 | if (ones != shape.size() - 1 && |
397 | !(ones == shape.size() && reshapeInputShape[0] == 1)) |
398 | return std::nullopt; |
399 | |
400 | // Do not insert a TransposeOp, instead we fold the reshape and its attribute. |
401 | llvm::SmallVector<int64_t> newShape; |
402 | if (!tosa::getConstShapeValues(op: reshapeOp.getShape().getDefiningOp(), |
403 | result_shape&: newShape)) { |
404 | // this mean shape is not constant |
405 | return std::nullopt; |
406 | } |
407 | ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter); |
408 | auto foldedReshape = rewriter.create<ReshapeOp>( |
409 | reshapeOp.getLoc(), |
410 | RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms), |
411 | reshapeOutputType.getElementType()), |
412 | reshapeOp.getInput1(), |
413 | getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape), |
414 | hoistedPerms))); |
415 | return foldedReshape->getResult(0); |
416 | } |
417 | |
418 | std::optional<Value> TosaReduceTransposes::buildMappedToValue( |
419 | ConstOp constOp, const DenseMap<Value, Value> &valuesMap, |
420 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
421 | auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues()); |
422 | if (!denseAttr) |
423 | return std::nullopt; |
424 | auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms); |
425 | if (!maybeNewDenseAttr.has_value()) |
426 | return std::nullopt; |
427 | auto newDenseAttr = maybeNewDenseAttr.value(); |
428 | auto newConstOp = rewriter.create<ConstOp>( |
429 | constOp.getLoc(), newDenseAttr.getType(), newDenseAttr); |
430 | return newConstOp->getResult(0); |
431 | } |
432 | |
433 | bool TosaReduceTransposes::convertDependentOps( |
434 | SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap, |
435 | IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { |
436 | |
437 | for (Operation *op : dependentOps) { |
438 | if (!op || op->getNumResults() != 1) |
439 | return false; |
440 | |
441 | Value priorValue = op->getResult(0); |
442 | |
443 | // It's possible on a prior transposeOp we had the same dependency and |
444 | // already resolved it. |
445 | if (valuesMap.contains(priorValue)) |
446 | continue; |
447 | |
448 | // Keep converted ops close to the original. |
449 | rewriter.setInsertionPointAfter(op); |
450 | |
451 | std::optional<Value> maybeValue = |
452 | llvm::TypeSwitch<Operation *, std::optional<Value>>(op) |
453 | .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto transposeOp) { |
454 | return buildMappedToValue(transposeOp, valuesMap, rewriter, |
455 | hoistedPerms); |
456 | }) |
457 | .Default([&](Operation *op) { |
458 | return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms); |
459 | }); |
460 | |
461 | if (!maybeValue.has_value()) |
462 | return false; |
463 | |
464 | valuesMap[priorValue] = maybeValue.value(); |
465 | } |
466 | |
467 | return true; |
468 | } |
469 | |
470 | bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies( |
471 | Operation *user, std::set<TransposeOp> &validTransposes, |
472 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
473 | &transposeInfo) { |
474 | return llvm::none_of( |
475 | transposeInfo, |
476 | [&validTransposes, |
477 | user](const std::pair<TransposeOp, SetVector<Operation *>> &info) { |
478 | const auto &[transposeOp, dependentOps] = info; |
479 | return validTransposes.count(transposeOp) && |
480 | dependentOps.contains(user); |
481 | }); |
482 | } |
483 | |
484 | // Dependencies are valid for an operation if none of them occur outside |
485 | // of the proper fan-in cones of the hoisted TransposeOp with the same perms |
486 | // that we can replace. Described in more detail within. |
487 | bool TosaReduceTransposes::dependenciesAreValid( |
488 | ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps, |
489 | std::set<TransposeOp> &validTransposes, |
490 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
491 | &transposeInfo) { |
492 | for (Operation *op : dependentOps) { |
493 | |
494 | // It's OK wherever ConstOp has uses -- in the worst case, we duplicate. |
495 | // This can be changed later if we find the memory impact is too high. |
496 | if (llvm::isa<ConstOp>(op)) |
497 | continue; |
498 | |
499 | for (OpOperand &use : op->getUses()) { |
500 | // Want the uses to be (1) contained in the dependentOps of other |
501 | // validTransposes, or (2) to be directly used in a TransposeOp with the |
502 | // same perms. For (2) it means the fan-in is a subset of our |
503 | // dependentOps, so it is also a validTranspose that will eventually be |
504 | // replaced. |
505 | Operation *user = use.getOwner(); |
506 | if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) { |
507 | // Can later think about cases where transpose -> transpose |
508 | // or reshape -> transpose, where the transposes are not necessarily |
509 | // the same perms as the hoisted, if implementing a more general |
510 | // transform. These could be permitted. |
511 | if (!llvm::equal(perms, otherTranspose.getPerms())) |
512 | return false; |
513 | } else if (userNotContainedInValidTransposeDependencies( |
514 | user, validTransposes, transposeInfo)) { |
515 | return false; |
516 | } |
517 | } |
518 | } |
519 | |
520 | return true; |
521 | } |
522 | |
523 | // Getting the set of TransposeOp that we can replace without causing |
524 | // the old fan-in cones of any TransposeOp to remain "live", i.e, -- not being |
525 | // dead code. This is done by iterating the set until convergence, since |
526 | // if you are used outside your own fan-in cone, it's possible to be used |
527 | // in another fan-in cone of a TransposeOp that is being replaced -- unless |
528 | // we find that that one has a usage outside of it too. |
529 | std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements( |
530 | ArrayRef<int32_t> perms, |
531 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>> |
532 | &transposeInfo) { |
533 | // Initially, we assume they are all good to replace, |
534 | // and we whittle them down based on our criteria. |
535 | std::set<TransposeOp> ableToReplace; |
536 | for (const auto &[transposeOp, _] : transposeInfo) |
537 | ableToReplace.insert(transposeOp); |
538 | |
539 | bool gotRid; |
540 | do { |
541 | gotRid = false; |
542 | for (const auto &[transposeOp, dependentOps] : transposeInfo) { |
543 | // We don't care about it. Already invalidated. |
544 | if (!ableToReplace.count(transposeOp)) |
545 | continue; |
546 | |
547 | // Check for validity. |
548 | if (!dependenciesAreValid(perms, dependentOps, ableToReplace, |
549 | transposeInfo)) { |
550 | ableToReplace.erase(transposeOp); |
551 | gotRid = true; |
552 | break; |
553 | } |
554 | } |
555 | |
556 | } while (gotRid); |
557 | |
558 | return ableToReplace; |
559 | } |
560 | |
561 | void TosaReduceTransposes::runOnOperation() { |
562 | // We want to operate only within a single block. |
563 | if (!getOperation().getRegion().hasOneBlock()) |
564 | return; |
565 | |
566 | IRRewriter rewriter(&getContext()); |
567 | // For each perms, maintain a mapping for converted ops, avoid duplication. |
568 | DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues; |
569 | // For each perms, we keep track of which TransposeOp are eligible |
570 | // for replacement alongside their dependentOps. |
571 | DenseMap<ArrayRef<int32_t>, |
572 | std::vector<std::pair<TransposeOp, SetVector<Operation *>>>> |
573 | permsToTransposeInfo; |
574 | |
575 | // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef. |
576 | // Use SmallVector for perms (common-case is <= 4) but std::vector otherwise |
577 | // since no guarantee of smallness. |
578 | std::vector<SmallVector<int32_t>> collectedPerms; |
579 | |
580 | // This keeps track of the order across all eligible-for-replacement |
581 | // TransposeOp and their perms, a necessity for the final replacements. |
582 | std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder; |
583 | |
584 | // We want to reserve the space up front, since SmallVector stores some data |
585 | // internally and the ArrayRef can reference that, which we don't want to get |
586 | // invalidated. |
587 | size_t expectedMaxPerms = 0; |
588 | getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; }); |
589 | collectedPerms.reserve(expectedMaxPerms); |
590 | |
591 | getOperation().walk([&](TransposeOp transposeOp) { |
592 | SetVector<Operation *> dependentOps; |
593 | collectedPerms.emplace_back(); |
594 | SmallVector<int32_t> &perms = collectedPerms.back(); |
595 | |
596 | // Dynamic shapes are OK, but the incompatible ones will be rejected later. |
597 | auto input = transposeOp.getInput1(); |
598 | auto output = transposeOp.getOutput(); |
599 | |
600 | // However, we don't support unranked tensors. |
601 | if (!llvm::isa<RankedTensorType>(input.getType()) || |
602 | !llvm::isa<RankedTensorType>(output.getType())) |
603 | return; |
604 | |
605 | llvm::for_each(transposeOp.getPerms(), |
606 | [&perms](const auto i) { perms.emplace_back(i); }); |
607 | |
608 | // We let --canonicalize deal with identity transpose. |
609 | if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms)) |
610 | return; |
611 | |
612 | // Can fail if some set of basic invariants is not met that we want to |
613 | // perform our conversions. |
614 | if (!collectFanIn(input.getDefiningOp(), dependentOps)) |
615 | return; |
616 | |
617 | // Want to associate valuesMap for already converted of the same perms, |
618 | // since it's possible multiple hoisted transposes w/ different perms |
619 | // converge on an op, which would result in different transformations. |
620 | DenseMap<Value, Value> &valuesMap = permsToValues[perms]; |
621 | |
622 | // Attempt to perform the conversions and placements into IR |
623 | // without turning inserted code "live". Also fills out valuesMap. |
624 | // Fails if there is an intermediary we do not support. |
625 | if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms)) |
626 | // Some additional operations may have been inserted, but will be |
627 | // removed by dead code elimination. |
628 | return; |
629 | |
630 | // This should not happen. If it does -- it's unexpected, |
631 | // so we fail the pass. |
632 | if (!valuesMap.contains(input)) |
633 | return signalPassFailure(); |
634 | |
635 | // It's possible the types are not compatible (because of dynamic shapes), |
636 | // and in these cases, want to resolve dynamic shapes before running the |
637 | // pass. |
638 | if (output.getType() != valuesMap.at(input).getType()) |
639 | return; |
640 | |
641 | auto &transposeInfo = permsToTransposeInfo[perms]; |
642 | |
643 | // In general, we might also want to introduce "newDependentOps" |
644 | // if there are new usages that don't fall inside the original fan-ins |
645 | // (like the TransposeOp we insert for ReshapeOp), |
646 | // but in this case, that is specialized enough and overlaps |
647 | // with another direct-use TransposeOp case we need to cover anyway. |
648 | transposeInfo.push_back({transposeOp, dependentOps}); |
649 | |
650 | // This is for the final replacement across all transposes. |
651 | totalTransposeOrder.push({transposeOp, perms}); |
652 | }); |
653 | |
654 | // We want to do a full fan-in analysis on a perms-level, |
655 | // since if we do it on a multi-perms level, and they share (due to a shared |
656 | // dependency on a Reshape) then we would also get duplicate ops. |
657 | // Const is special cased. |
658 | std::set<TransposeOp> ableToReplace; |
659 | for (auto &[perms, transposeInfo] : permsToTransposeInfo) { |
660 | // Gives us back replacements that would never result in any duplicate |
661 | // operations being inserted by us in the IR (i.e, our goal is only to |
662 | // remove transposes, and not create a "new chain" to do so, but replace |
663 | // the existing chains). |
664 | // Ideally, --canonicalize is run before this pass, since it helps this |
665 | // analysis by removing dead code to allow more potentially acceptable |
666 | // transformations. |
667 | auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo); |
668 | ableToReplace.insert(goodReplacementsForPerms.begin(), |
669 | goodReplacementsForPerms.end()); |
670 | } |
671 | |
672 | // We want to do replacement across all transposes |
673 | // in reverse order, due to invalidation of valuesMap mappings |
674 | // if we did it otherwise. |
675 | while (!totalTransposeOrder.empty()) { |
676 | auto [transposeOp, perms] = totalTransposeOrder.top(); |
677 | totalTransposeOrder.pop(); |
678 | |
679 | if (ableToReplace.count(transposeOp) == 0) |
680 | continue; |
681 | |
682 | auto &valuesMap = permsToValues[perms]; |
683 | auto input = transposeOp.getInput1(); |
684 | |
685 | // The purpose of this reverse iteration |
686 | // is to avoid valuesMap invalidation. If it happens, |
687 | // something is wrong. |
688 | if (!valuesMap.contains(input)) |
689 | return signalPassFailure(); |
690 | |
691 | rewriter.replaceOp(transposeOp, valuesMap.at(input)); |
692 | } |
693 | |
694 | // We can remove all dead code by going in reverse. |
695 | // This is because we would remove usages before we |
696 | // see the users. |
697 | getOperation().walk<WalkOrder::PostOrder, ReverseIterator>( |
698 | [&](Operation *op) { |
699 | if (isOpTriviallyDead(op)) |
700 | rewriter.eraseOp(op); |
701 | }); |
702 | } |
703 | |
704 | } // namespace |
705 | |