1 | //===- Interchange.cpp - Linalg interchange transformation ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the linalg interchange transformation. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
15 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
16 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
17 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
18 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
19 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
20 | #include "mlir/IR/AffineExpr.h" |
21 | #include "mlir/IR/Matchers.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/Pass/Pass.h" |
24 | #include "mlir/Support/LLVM.h" |
25 | #include "llvm/ADT/ScopeExit.h" |
26 | #include "llvm/Support/Debug.h" |
27 | #include "llvm/Support/raw_ostream.h" |
28 | #include <type_traits> |
29 | |
30 | #define DEBUG_TYPE "linalg-interchange" |
31 | |
32 | using namespace mlir; |
33 | using namespace mlir::linalg; |
34 | |
35 | static LogicalResult |
36 | interchangeGenericOpPrecondition(GenericOp genericOp, |
37 | ArrayRef<unsigned> interchangeVector) { |
38 | // Interchange vector must be non-empty and match the number of loops. |
39 | if (interchangeVector.empty() || |
40 | genericOp.getNumLoops() != interchangeVector.size()) |
41 | return failure(); |
42 | // Permutation map must be invertible. |
43 | if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector, |
44 | genericOp.getContext()))) |
45 | return failure(); |
46 | return success(); |
47 | } |
48 | |
49 | FailureOr<GenericOp> |
50 | mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, |
51 | ArrayRef<unsigned> interchangeVector) { |
52 | if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector))) |
53 | return rewriter.notifyMatchFailure(genericOp, "preconditions not met" ); |
54 | |
55 | // 1. Compute the inverse permutation map, it must be non-null since the |
56 | // preconditions are satisfied. |
57 | MLIRContext *context = genericOp.getContext(); |
58 | AffineMap permutationMap = inversePermutation( |
59 | map: AffineMap::getPermutationMap(permutation: interchangeVector, context)); |
60 | assert(permutationMap && "unexpected null map" ); |
61 | |
62 | // Start a guarded inplace update. |
63 | rewriter.startOpModification(op: genericOp); |
64 | auto guard = llvm::make_scope_exit( |
65 | F: [&]() { rewriter.finalizeOpModification(op: genericOp); }); |
66 | |
67 | // 2. Compute the interchanged indexing maps. |
68 | SmallVector<AffineMap> newIndexingMaps; |
69 | for (OpOperand &opOperand : genericOp->getOpOperands()) { |
70 | AffineMap m = genericOp.getMatchingIndexingMap(&opOperand); |
71 | if (!permutationMap.isEmpty()) |
72 | m = m.compose(permutationMap); |
73 | newIndexingMaps.push_back(m); |
74 | } |
75 | genericOp.setIndexingMapsAttr( |
76 | rewriter.getAffineMapArrayAttr(newIndexingMaps)); |
77 | |
78 | // 3. Compute the interchanged iterator types. |
79 | ArrayRef<Attribute> itTypes = genericOp.getIteratorTypes().getValue(); |
80 | SmallVector<Attribute> itTypesVector; |
81 | llvm::append_range(C&: itTypesVector, R&: itTypes); |
82 | SmallVector<int64_t> permutation(interchangeVector.begin(), |
83 | interchangeVector.end()); |
84 | applyPermutationToVector(inVec&: itTypesVector, permutation); |
85 | genericOp.setIteratorTypesAttr(rewriter.getArrayAttr(itTypesVector)); |
86 | |
87 | // 4. Transform the index operations by applying the permutation map. |
88 | if (genericOp.hasIndexSemantics()) { |
89 | OpBuilder::InsertionGuard guard(rewriter); |
90 | for (IndexOp indexOp : |
91 | llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) { |
92 | rewriter.setInsertionPoint(indexOp); |
93 | SmallVector<Value> allIndices; |
94 | allIndices.reserve(genericOp.getNumLoops()); |
95 | llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()), |
96 | std::back_inserter(allIndices), [&](uint64_t dim) { |
97 | return rewriter.create<IndexOp>(indexOp->getLoc(), dim); |
98 | }); |
99 | rewriter.replaceOpWithNewOp<affine::AffineApplyOp>( |
100 | indexOp, permutationMap.getSubMap(indexOp.getDim()), allIndices); |
101 | } |
102 | } |
103 | |
104 | return genericOp; |
105 | } |
106 | |