//===- Interchange.cpp - Linalg interchange transformation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg interchange transformation.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16#include "mlir/Dialect/Linalg/Utils/Utils.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19#include "mlir/Dialect/Vector/IR/VectorOps.h"
20#include "mlir/IR/AffineExpr.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/ScopeExit.h"
24
25#define DEBUG_TYPE "linalg-interchange"
26
27using namespace mlir;
28using namespace mlir::linalg;
29
30static LogicalResult
31interchangeGenericOpPrecondition(GenericOp genericOp,
32 ArrayRef<unsigned> interchangeVector) {
33 // Interchange vector must be non-empty and match the number of loops.
34 if (interchangeVector.empty() ||
35 genericOp.getNumLoops() != interchangeVector.size())
36 return failure();
37 // Permutation map must be invertible.
38 if (!inversePermutation(map: AffineMap::getPermutationMap(permutation: interchangeVector,
39 context: genericOp.getContext())))
40 return failure();
41 return success();
42}
43
44FailureOr<GenericOp>
45mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
46 ArrayRef<unsigned> interchangeVector) {
47 if (failed(Result: interchangeGenericOpPrecondition(genericOp, interchangeVector)))
48 return rewriter.notifyMatchFailure(arg&: genericOp, msg: "preconditions not met");
49
50 // 1. Compute the inverse permutation map, it must be non-null since the
51 // preconditions are satisfied.
52 MLIRContext *context = genericOp.getContext();
53 AffineMap permutationMap = inversePermutation(
54 map: AffineMap::getPermutationMap(permutation: interchangeVector, context));
55 assert(permutationMap && "unexpected null map");
56
57 // Start a guarded inplace update.
58 rewriter.startOpModification(op: genericOp);
59 auto guard = llvm::make_scope_exit(
60 F: [&]() { rewriter.finalizeOpModification(op: genericOp); });
61
62 // 2. Compute the interchanged indexing maps.
63 SmallVector<AffineMap> newIndexingMaps;
64 for (OpOperand &opOperand : genericOp->getOpOperands()) {
65 AffineMap m = genericOp.getMatchingIndexingMap(opOperand: &opOperand);
66 if (!permutationMap.isEmpty())
67 m = m.compose(map: permutationMap);
68 newIndexingMaps.push_back(Elt: m);
69 }
70 genericOp.setIndexingMapsAttr(
71 rewriter.getAffineMapArrayAttr(values: newIndexingMaps));
72
73 // 3. Compute the interchanged iterator types.
74 ArrayRef<Attribute> itTypes = genericOp.getIteratorTypes().getValue();
75 SmallVector<Attribute> itTypesVector;
76 llvm::append_range(C&: itTypesVector, R&: itTypes);
77 SmallVector<int64_t> permutation(interchangeVector);
78 applyPermutationToVector(inVec&: itTypesVector, permutation);
79 genericOp.setIteratorTypesAttr(rewriter.getArrayAttr(value: itTypesVector));
80
81 // 4. Transform the index operations by applying the permutation map.
82 if (genericOp.hasIndexSemantics()) {
83 OpBuilder::InsertionGuard guard(rewriter);
84 for (IndexOp indexOp :
85 llvm::make_early_inc_range(Range: genericOp.getBody()->getOps<IndexOp>())) {
86 rewriter.setInsertionPoint(indexOp);
87 SmallVector<Value> allIndices;
88 allIndices.reserve(N: genericOp.getNumLoops());
89 llvm::transform(Range: llvm::seq<uint64_t>(Begin: 0, End: genericOp.getNumLoops()),
90 d_first: std::back_inserter(x&: allIndices), F: [&](uint64_t dim) {
91 return rewriter.create<IndexOp>(location: indexOp->getLoc(), args&: dim);
92 });
93 rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
94 op: indexOp, args: permutationMap.getSubMap(resultPos: indexOp.getDim()), args&: allIndices);
95 }
96 }
97
98 return genericOp;
99}
100

// Source file: mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp