1//===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10#include "mlir/IR/AffineMap.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/IRMapping.h"
14#include "llvm/ADT/StringSet.h"
15
16#include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
17
18using namespace mlir;
19
20bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
21 if (indexingMaps.size() != 3)
22 return false;
23
24 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
27
28 if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
29 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
30 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
31 return false;
32 }
33
34 // Extract dimensions for MxK * KxN -> MxN
35 AffineExpr m = map2.getResult(idx: 0);
36 AffineExpr n = map2.getResult(idx: 1);
37 AffineExpr k = map0.getResult(idx: 1);
38 auto *context = indexingMaps.getContext();
39 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
40 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
41 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
42 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
43 return indexingMaps == maps;
44}
45
46bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
47 if (indexingMaps.size() != 3)
48 return false;
49
50 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
53
54 if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
55 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
56 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
57 return false;
58 }
59
60 // Extract dimensions for KxM * NxK -> NxM
61 AffineExpr n = map2.getResult(idx: 0);
62 AffineExpr m = map2.getResult(idx: 1);
63 AffineExpr k = map0.getResult(idx: 0);
64 auto *context = indexingMaps.getContext();
65 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context));
66 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
67 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
68 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
69 return indexingMaps == maps;
70}
71
72bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
73 if (indexingMaps.size() != 3)
74 return false;
75
76 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
79
80 if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
81 map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
82 map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
83 return false;
84 }
85
86 // Extract dimensions for BxMxK * BxKxN -> BxMxN
87 AffineExpr b = map2.getResult(idx: 0);
88 AffineExpr m = map2.getResult(idx: 1);
89 AffineExpr n = map2.getResult(idx: 2);
90 AffineExpr k = map0.getResult(idx: 2);
91 auto *context = indexingMaps.getContext();
92 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
93 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
94 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
95 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
96 return indexingMaps == maps;
97}
98
99bool mlir::isVecmat(ArrayAttr indexingMaps) {
100 if (indexingMaps.size() != 3)
101 return false;
102 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
103 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
104 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
105
106 if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
107 map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
108 map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
109 return false;
110 }
111
112 // Extract dimensions for K * KxN -> N
113 AffineExpr k = map0.getResult(idx: 0);
114 AffineExpr n = map2.getResult(idx: 0);
115 auto *context = indexingMaps.getContext();
116 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
117 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
118 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
119 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
120 return indexingMaps == maps;
121}
122
123bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
124 if (indexingMaps.size() != 3)
125 return false;
126 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
129
130 if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
131 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
132 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
133 return false;
134 }
135
136 // Extract dimensions for B*K * B*K*N -> B*N
137 AffineExpr b = map0.getResult(idx: 0);
138 AffineExpr k = map0.getResult(idx: 1);
139 AffineExpr n = map2.getResult(idx: 1);
140 auto *context = indexingMaps.getContext();
141 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
142 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
143 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
144 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
145 return indexingMaps == maps;
146}
147
148bool mlir::isMatvec(ArrayAttr indexingMaps) {
149 if (indexingMaps.size() != 3)
150 return false;
151 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
152 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
153 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
154
155 if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
156 map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
157 map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
158 return false;
159 }
160
161 // Extract dimensions for N*K * K -> N
162 AffineExpr k = map1.getResult(idx: 0);
163 AffineExpr n = map2.getResult(idx: 0);
164 auto *context = indexingMaps.getContext();
165 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
166 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
167 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
168 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
169 return indexingMaps == maps;
170}
171
172bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
173 if (indexingMaps.size() != 3)
174 return false;
175 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
176 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
177 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
178
179 if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
180 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
181 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
182 return false;
183 }
184
185 // Extract dimensions for B*N*K * B*K -> B*N
186 AffineExpr b = map0.getResult(idx: 0);
187 AffineExpr k = map1.getResult(idx: 1);
188 AffineExpr n = map2.getResult(idx: 1);
189 auto *context = indexingMaps.getContext();
190 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
191 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
192 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
193 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
194 return indexingMaps == maps;
195}
196
197Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
198 ValueRange newOperands) {
199 IRMapping bvm;
200 OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
201 op->getAttrs());
202 for (Region &r : op->getRegions()) {
203 Region *newRegion = state.addRegion();
204 b.cloneRegionBefore(region&: r, parent&: *newRegion, before: newRegion->begin(), mapping&: bvm);
205 }
206 return b.create(state);
207}
208
209Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
210 TypeRange newResultTypes,
211 ValueRange newOperands) {
212 OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
213 op->getAttrs());
214 for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
215 state.addRegion();
216 return b.create(state);
217}
218
219SmallVector<NamedAttribute>
220mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
221 llvm::StringSet<> elidedAttrsSet;
222 elidedAttrsSet.insert(begin: elidedAttrs.begin(), end: elidedAttrs.end());
223 SmallVector<NamedAttribute> attrs;
224 for (auto attr : op->getAttrs()) {
225 if (elidedAttrsSet.count(attr.getName()))
226 continue;
227 attrs.push_back(Elt: attr);
228 }
229 return attrs;
230}
231

source code of mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp