1 | //===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Implement Winograd Conv2D algorithm. The implementation is based on the |
10 | // paper: Fast Algorithms for Convolutional Neural Networks |
11 | // (https://arxiv.org/abs/1509.09308) |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
18 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
19 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
22 | #include "llvm/Support/MathExtras.h" |
23 | |
24 | namespace mlir { |
25 | namespace linalg { |
26 | |
27 | namespace { |
28 | |
29 | // clang-format off |
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) (*) (B^T x d x B) ] x A
///
/// where g is the filter, d is the input data, and (*) denotes element-wise
/// (Hadamard) multiplication. We need to prepare 6 constant transformation
/// matrices, G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
41 | constexpr float G_2x2_3x3[] = { |
42 | -1, 0, 0, |
43 | 1./2, -1./2, 1./2, |
44 | 1./2, 1./2, 1./2, |
45 | 0, 0, 1 |
46 | }; |
47 | |
48 | constexpr float GT_2x2_3x3[] = { |
49 | -1, 1./2, 1./2, 0, |
50 | 0, -1./2, 1./2, 0, |
51 | 0, 1./2, 1./2, 1 |
52 | }; |
53 | |
54 | constexpr float BT_2x2_3x3[] = { |
55 | -1, 0, 1, 0, |
56 | 0, -1, 1, 0, |
57 | 0, 1, 1, 0, |
58 | 0, -1, 0, 1 |
59 | }; |
60 | |
61 | constexpr float B_2x2_3x3[] = { |
62 | -1, 0, 0, 0, |
63 | 0, -1, 1, -1, |
64 | 1, 1, 1, 0, |
65 | 0, 0, 0, 1 |
66 | }; |
67 | |
68 | constexpr float AT_2x2_3x3[] = { |
69 | 1, 1, 1, 0, |
70 | 0, -1, 1, 1 |
71 | }; |
72 | |
73 | constexpr float A_2x2_3x3[] = { |
74 | 1, 0, |
75 | 1, -1, |
76 | 1, 1, |
77 | 0, 1 |
78 | }; |
79 | |
80 | constexpr float G_4x4_3x3[] = { |
81 | 1, 0, 0, |
82 | -1./3, 1./3, -1./3, |
83 | -1./3, -1./3, -1./3, |
84 | 1./12, -1./6, 1./3, |
85 | 1./12, 1./6, 1./3, |
86 | 0, 0, 1 |
87 | }; |
88 | |
89 | constexpr float GT_4x4_3x3[] = { |
90 | 1, -1./3, -1./3, 1./12, 1./12, 0, |
91 | 0, 1./3, -1./3, -1./6, 1./6, 0, |
92 | 0, -1./3, -1./3, 1./3, 1./3, 1 |
93 | }; |
94 | |
95 | constexpr float BT_4x4_3x3[] = { |
96 | 1./4, 0, -5./16, 0, 1./16, 0, |
97 | 0, 1./4, -1./4, -1./16, 1./16, 0, |
98 | 0, -1./4, -1./4, 1./16, 1./16, 0, |
99 | 0, 1./4, -1./8, -1./4, 1./8, 0, |
100 | 0, -1./4, -1./8, 1./4, 1./8, 0, |
101 | 0, 1./4, 0, -5./16, 0, 1./16 |
102 | }; |
103 | |
104 | constexpr float B_4x4_3x3[] = { |
105 | 1./4, 0, 0, 0, 0, 0, |
106 | 0, 1./4, -1./4, 1./4, -1./4, 1./4, |
107 | -5./16, -1./4, -1./4, -1./8, -1./8, 0, |
108 | 0, -1./16, 1./16, -1./4, 1./4, -5./16, |
109 | 1./16, 1./16, 1./16, 1./8, 1./8, 0, |
110 | 0, 0, 0, 0, 0, 1./16 |
111 | }; |
112 | |
113 | constexpr float AT_4x4_3x3[] = { |
114 | 1./8, 1./4, 1./4, 1./8, 1./8, 0, |
115 | 0, -1./4, 1./4, -1./4, 1./4, 0, |
116 | 0, 1./4, 1./4, 1./2, 1./2, 0, |
117 | 0, -1./4, 1./4, -1, 1, 1./2 |
118 | }; |
119 | |
120 | constexpr float A_4x4_3x3[] = { |
121 | 1./8, 0, 0, 0, |
122 | 1./4, -1./4, 1./4, -1./4, |
123 | 1./4, 1./4, 1./4, 1./4, |
124 | 1./8, -1./4, 1./2, -1, |
125 | 1./8, 1./4, 1./2, 1, |
126 | 0, 0, 0, 1./2 |
127 | }; |
128 | |
129 | constexpr float G_2x2_5x5[] = { |
130 | 1, 0, 0, 0, 0, |
131 | 1./6, -1./6, 1./6, -1./6, 1./6, |
132 | -1./6, -1./6, -1./6, -1./6, -1./6, |
133 | -4./15, 2./15, -1./15, 1./30, -1./60, |
134 | 1./60, 1./30, 1./15, 2./15, 4./15, |
135 | 0, 0, 0, 0, 1 |
136 | }; |
137 | |
138 | constexpr float GT_2x2_5x5[] = { |
139 | 1, 1./6, -1./6, -4./15, 1./60, 0, |
140 | 0, -1./6, -1./6, 2./15, 1./30, 0, |
141 | 0, 1./6, -1./6, -1./15, 1./15, 0, |
142 | 0, -1./6, -1./6, 1./30, 2./15, 0, |
143 | 0, 1./6, -1./6, -1./60, 4./15, 1 |
144 | }; |
145 | |
146 | constexpr float BT_2x2_5x5[] = { |
147 | 1./8, 3./16, -1./4, -3./16, 1./8, 0, |
148 | 0, 1./8, 1./16, -5./16, 1./8, 0, |
149 | 0, -1./8, -5./16, -1./16, 1./8, 0, |
150 | 0, 1./4, -1./8, -1./4, 1./8, 0, |
151 | 0, -1./8, -1./4, 1./8, 1./4, 0, |
152 | 0, 1./8, 3./16, -1./4, -3./16, 1./8 |
153 | }; |
154 | |
155 | constexpr float B_2x2_5x5[] = { |
156 | 1./8, 0, 0, 0, 0, 0, |
157 | 3./16, 1./8, -1./8, 1./4, -1./8, 1./8, |
158 | -1./4, 1./16, -5./16, -1./8, -1./4, 3./16, |
159 | -3./16, -5./16, -1./16, -1./4, 1./8, -1./4, |
160 | 1./8, 1./8, 1./8, 1./8, 1./4, -3./16, |
161 | 0, 0, 0, 0, 0, 1./8 |
162 | }; |
163 | |
164 | constexpr float AT_2x2_5x5[] = { |
165 | 1./2, 1, 1, 2, 1, 0, |
166 | 0, -1, 1, -1, 2, 1./2 |
167 | }; |
168 | |
169 | constexpr float A_2x2_5x5[] = { |
170 | 1./2, 0, |
171 | 1, -1, |
172 | 1, 1, |
173 | 2, -1, |
174 | 1, 2, |
175 | 0, 1./2 |
176 | }; |
177 | // clang-format on |
178 | |
179 | using TransformMapKeyTy = std::pair<int, int>; |
180 | |
181 | /// We use F(m, r) to define the size of minimal filtering algorithms. |
182 | /// m is the output dimension and r is the filter dimension. We can get |
183 | /// the input dimension, alpha, from the formula, alpha = m + r - 1. |
184 | /// |
/// For example, when m = 2 and r = 3, the input size is 4. The Conv2D
/// operates on 4x4 input data with a 3x3 filter and produces a 2x2 output.
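///
/// As a dimension check for F(2 x 2, 3 x 3): G is 4x3, B^T is 4x4, and A^T
/// is 2x4, so (G x g x G^T) is (4x3)(3x3)(3x4) = 4x4, (B^T x d x B) is
/// (4x4)(4x4)(4x4) = 4x4, their element-wise product is 4x4, and
/// A^T x [...] x A is (2x4)(4x4)(4x2) = 2x2, the expected output size.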
188 | constexpr TransformMapKeyTy F_2_3{2, 3}; |
189 | constexpr TransformMapKeyTy F_4_3{4, 3}; |
190 | constexpr TransformMapKeyTy F_2_5{2, 5}; |
191 | |
/// Structure that holds the data of a constant transform matrix.
193 | struct TransformMatrix { |
194 | TransformMatrix(const float *table, int64_t rows, int64_t cols, |
195 | int64_t scalarFactor = 1) |
196 | : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {} |
197 | |
198 | const float *table; |
199 | int64_t rows; |
200 | int64_t cols; |
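  // Multiplier that the output transform applies to the combined matmul
  // result before accumulating into the output (1 means no extra scaling).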
201 | int64_t scalarFactor; |
202 | }; |
203 | |
/// Utility function to convert a constant array to an arith.constant Value.
205 | Value create2DTransformMatrix(OpBuilder &builder, Location loc, |
206 | TransformMatrix transform, Type type) { |
207 | ArrayRef<float> constVec(transform.table, transform.rows * transform.cols); |
208 | |
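  // For example, for G_2x2_3x3 with an f32 element type this materializes
  // roughly as: %cst = arith.constant dense<...> : tensor<4x3xf32>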
209 | return builder.create<arith::ConstantOp>( |
210 | loc, DenseFPElementsAttr::get( |
211 | RankedTensorType::get( |
212 | SmallVector<int64_t>{transform.rows, transform.cols}, type), |
213 | constVec)); |
214 | } |
215 | |
216 | /// Extract height x width data from 4D tensors. |
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
223 | auto sourceType = cast<ShapedType>(source.getType()); |
224 | Type elementType = sourceType.getElementType(); |
225 | int64_t srcSize = sourceType.getRank(); |
226 | |
227 | auto oneIndex = builder.getIndexAttr(1); |
228 | SmallVector<OpFoldResult> offsets; |
  offsets.resize(srcSize);
230 | offsets[loopNorFIdx] = loopNorFIndex; |
231 | offsets[loopCorFIdx] = loopCorFIndex; |
232 | offsets[heightIdx] = heightOffset; |
233 | offsets[widthIdx] = widthOffset; |
234 | SmallVector<OpFoldResult> sizes(srcSize, oneIndex); |
235 | sizes[heightIdx] = builder.getIndexAttr(extractHeight); |
236 | sizes[widthIdx] = builder.getIndexAttr(extractWidth); |
237 | SmallVector<OpFoldResult> strides(srcSize, oneIndex); |
238 | |
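  // The result type below is rank-2, so this is a rank-reducing
  // tensor.extract_slice: the unit dims at loopNorFIdx and loopCorFIdx are
  // dropped.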
  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
242 | loc, extractFilterType, source, offsets, sizes, strides); |
243 | |
244 | return extractFilterOp; |
245 | } |
246 | |
247 | /// Extract height x width data from 6D tensors. |
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
254 | auto sourceType = cast<ShapedType>(source.getType()); |
255 | Type elementType = sourceType.getElementType(); |
256 | auto sourceShape = sourceType.getShape(); |
257 | int64_t srcSize = sourceType.getRank(); |
258 | int64_t height = sourceShape[heightIdx]; |
259 | int64_t width = sourceShape[widthIdx]; |
260 | |
261 | auto zeroIndex = builder.getIndexAttr(0); |
262 | auto oneIndex = builder.getIndexAttr(1); |
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
265 | offsets[tileHIdx] = tileHIndex; |
266 | offsets[tileWIdx] = tileWIndex; |
267 | offsets[loopNorFIdx] = loopNorFIndex; |
268 | offsets[loopCorFIdx] = loopCorFIndex; |
269 | SmallVector<OpFoldResult> sizes(srcSize, oneIndex); |
270 | sizes[heightIdx] = builder.getIndexAttr(height); |
271 | sizes[widthIdx] = builder.getIndexAttr(width); |
272 | SmallVector<OpFoldResult> strides(srcSize, oneIndex); |
273 | |
  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
276 | loc, extractFilterType, source, offsets, sizes, strides); |
277 | |
278 | return extractFilterOp; |
279 | } |
280 | |
/// Insert transformed height x width data into the 4D tensor from which it
/// was extracted.
283 | Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source, |
284 | Value dest, Value loopNorFIndex, Value loopCorFIndex, |
285 | Value heightOffset, Value widthOffset, int64_t height, |
286 | int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx, |
287 | int64_t heightIdx, int64_t widthIdx) { |
288 | int64_t destSize = cast<ShapedType>(dest.getType()).getRank(); |
289 | auto oneIndex = builder.getIndexAttr(1); |
290 | SmallVector<OpFoldResult> retOffsets; |
  retOffsets.resize(destSize);
292 | retOffsets[loopNorFIdx] = loopNorFIndex; |
293 | retOffsets[loopCorFIdx] = loopCorFIndex; |
294 | retOffsets[heightIdx] = heightOffset; |
295 | retOffsets[widthIdx] = widthOffset; |
296 | SmallVector<OpFoldResult> retSizes(destSize, oneIndex); |
297 | retSizes[heightIdx] = builder.getIndexAttr(height); |
298 | retSizes[widthIdx] = builder.getIndexAttr(width); |
299 | SmallVector<OpFoldResult> strides(destSize, oneIndex); |
300 | |
301 | auto insertSliceOp = builder.create<tensor::InsertSliceOp>( |
302 | loc, source, dest, retOffsets, retSizes, strides); |
303 | |
304 | return insertSliceOp; |
305 | } |
306 | |
/// Insert transformed height x width data into the 6D tensor from which it
/// was extracted.
309 | Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source, |
310 | Value dest, Value tileHIndex, Value tileWIndex, |
311 | Value loopNorFIndex, Value loopCorFIndex, int64_t height, |
312 | int64_t width, int64_t tileHIdx, int64_t tileWIdx, |
313 | int64_t loopNorFIdx, int64_t loopCorFIdx, |
314 | int64_t heightIdx, int64_t widthIdx) { |
315 | int64_t destSize = cast<ShapedType>(dest.getType()).getRank(); |
316 | auto zeroIndex = builder.getIndexAttr(0); |
317 | auto oneIndex = builder.getIndexAttr(1); |
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
320 | retOffsets[tileHIdx] = tileHIndex; |
321 | retOffsets[tileWIdx] = tileWIndex; |
322 | retOffsets[loopNorFIdx] = loopNorFIndex; |
323 | retOffsets[loopCorFIdx] = loopCorFIndex; |
324 | SmallVector<OpFoldResult> retSizes(destSize, oneIndex); |
325 | retSizes[heightIdx] = builder.getIndexAttr(height); |
326 | retSizes[widthIdx] = builder.getIndexAttr(width); |
327 | SmallVector<OpFoldResult> strides(destSize, oneIndex); |
328 | |
329 | auto insertSliceOp = builder.create<tensor::InsertSliceOp>( |
330 | loc, source, dest, retOffsets, retSizes, strides); |
331 | |
332 | return insertSliceOp; |
333 | } |
334 | |
/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrices are 2-dimensional, so we need to extract the
/// H x W slice from FHWC first and generate 2 levels of loops to iterate on
/// F and C. After the transformation, we get
339 | /// |
340 | /// scf.for %f = lo_f to hi_f step 1 |
341 | /// scf.for %c = lo_c to hi_c step 1 |
342 | /// %extracted = extract filter<h x w> from filter<f x h x w x c> |
343 | /// %ret = linalg.matmul G, %extracted |
344 | /// %ret = linalg.matmul %ret, GT |
345 | /// %inserted = insert %ret into filter<h x w x c x f> |
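///
/// For example, for F(2 x 2, 3 x 3), each extracted 3x3 filter tile g is
/// transformed as G x g x G^T = (4x3)(3x3)(3x4) = 4x4 and inserted into the
/// (4, 4, C, F) result tensor.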
346 | Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, |
347 | Value retValue, int64_t m, int64_t r, |
348 | bool leftTransform = true, bool rightTransform = true) { |
349 | // Map from (m, r) to G transform matrix. |
350 | static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> |
351 | GMatrices = { |
352 | {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, |
353 | {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, |
354 | {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, |
355 | }; |
356 | |
357 | // Map from (m, r) to GT transform matrix. |
358 | static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> |
359 | GTMatrices = { |
360 | {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, |
361 | {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, |
362 | {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, |
363 | }; |
364 | |
365 | auto filterType = cast<ShapedType>(filter.getType()); |
366 | Type elementType = filterType.getElementType(); |
367 | auto filterShape = filterType.getShape(); // F, H, W, C |
368 | int64_t filterF = filterShape[0]; |
369 | int64_t filterH = filterShape[1]; |
370 | int64_t filterW = filterShape[2]; |
371 | int64_t filterC = filterShape[3]; |
372 | |
373 | if (filterH != r && filterH != 1) |
374 | return Value(); |
375 | if (filterW != r && filterW != 1) |
376 | return Value(); |
377 | |
  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
379 | auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, |
380 | ValueRange args) -> scf::ValueVector { |
381 | Value FIter = ivs[0]; |
382 | Value CIter = ivs[1]; |
383 | |
384 | // Extract (H, W) from (F, H, W, C). |
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
389 | |
390 | TransformMapKeyTy key = {m, r}; |
391 | int64_t retRows = 1; |
392 | Value matmulRetValue = extractFilter; |
393 | Value zero = builder.create<arith::ConstantOp>( |
394 | loc, rewriter.getZeroAttr(elementType)); |
395 | if (leftTransform) { |
396 | // Get constant transform matrix G. |
      auto it = GMatrices.find(key);
398 | if (it == GMatrices.end()) |
399 | return {}; |
400 | const TransformMatrix &GMatrix = it->second; |
401 | |
402 | retRows = GMatrix.rows; |
403 | auto matmulType = RankedTensorType::get({retRows, filterW}, elementType); |
404 | auto empty = |
405 | builder |
406 | .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) |
407 | .getResult(); |
408 | auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); |
409 | |
      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
411 | // Multiply G x g. |
412 | auto matmulOp = builder.create<linalg::MatmulOp>( |
413 | loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init}); |
414 | matmulRetValue = matmulOp.getResult(0); |
415 | } |
416 | |
417 | if (rightTransform) { |
418 | // Get constant transform matrix GT. |
      auto it = GTMatrices.find(key);
420 | if (it == GTMatrices.end()) |
421 | return {}; |
      const TransformMatrix &GTMatrix = it->second;
423 | |
424 | auto matmulType = |
425 | RankedTensorType::get({retRows, GTMatrix.cols}, elementType); |
426 | auto empty = |
427 | builder |
428 | .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) |
429 | .getResult(); |
430 | auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); |
431 | |
      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
433 | // Multiply u = (G x g) x GT. |
434 | auto matmulOp = builder.create<linalg::MatmulOp>( |
435 | loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init}); |
436 | matmulRetValue = matmulOp.getResult(0); |
437 | } |
438 | |
439 | // Insert (H, W) to (H, W, C, F). |
440 | int64_t retHeight = leftTransform ? m + r - 1 : 1; |
441 | int64_t retWidth = rightTransform ? m + r - 1 : 1; |
442 | |
    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);
448 | |
449 | return {insertSliceOp}; |
450 | }; |
451 | |
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
455 | scf::LoopNest loops = scf::buildLoopNest( |
456 | rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound}, |
457 | {oneStep, oneStep}, {retValue}, buildBody); |
458 | return loops.results[0]; |
459 | } |
460 | |
/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrices are 2-dimensional, so we need to extract the
/// H x W slice from NHWC first and generate 4 levels of loops to iterate on
/// tileH, tileW, N, and C. After the transformation, we get
465 | /// |
466 | /// scf.for %h = 0 to tileH step 1 |
467 | /// scf.for %w = 0 to tileW step 1 |
468 | /// scf.for %n = 0 to N step 1 |
469 | /// scf.for %c = 0 to C step 1 |
470 | /// %extracted = extract %extracted<alphaH x alphaW> from |
471 | /// %input<N x H x W x C> |
472 | /// at [%n, (%h x m), (%w x m), %c] |
473 | /// %ret = linalg.matmul BT, %extracted |
474 | /// %ret = linalg.matmul %ret, B |
475 | /// %inserted = insert %ret<alphaH x alphaW> into |
476 | /// %output<alphaH x alphaW x tileH x tileW x N x C> |
477 | /// at [0, 0, %h, %w, %n, %c] |
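///
/// Note that consecutive tiles are read at a stride of m along each
/// transformed dimension while every tile is alpha = m + r - 1 wide, so
/// neighboring input tiles overlap by r - 1 rows/columns, as the Winograd
/// algorithm requires.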
478 | Value inputTransform(RewriterBase &rewriter, Location loc, Value input, |
479 | Value retValue, int64_t m, int64_t r, |
480 | bool leftTransform = true, bool rightTransform = true) { |
481 | // Map from (m, r) to BT transform matrix. |
482 | static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> |
483 | BTMatrices = { |
484 | {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)}, |
485 | {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)}, |
486 | {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)}, |
487 | }; |
488 | |
489 | // Map from (m, r) to B transform matrix. |
490 | static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> |
491 | BMatrices = { |
492 | {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)}, |
493 | {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)}, |
494 | {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)}, |
495 | }; |
496 | |
497 | auto inputType = cast<ShapedType>(input.getType()); |
498 | Type elementType = inputType.getElementType(); |
499 | auto inputShape = inputType.getShape(); // N, H, W, C |
500 | int64_t inputN = inputShape[0]; |
501 | int64_t inputC = inputShape[3]; |
502 | auto valueType = cast<ShapedType>(retValue.getType()); |
503 | auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C |
504 | int64_t tileH = valueShape[2]; |
505 | int64_t tileW = valueShape[3]; |
506 | int64_t alphaH = leftTransform ? m + r - 1 : 1; |
507 | int64_t alphaW = rightTransform ? m + r - 1 : 1; |
508 | |
509 | auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, |
510 | ValueRange args) -> scf::ValueVector { |
511 | Value tileHIter = ivs[0]; |
512 | Value tileWIter = ivs[1]; |
513 | Value NIter = ivs[2]; |
514 | Value CIter = ivs[3]; |
515 | |
516 | auto context = builder.getContext(); |
517 | |
    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
521 | Value heightOffset = builder.create<affine::AffineApplyOp>( |
522 | loc, leftTransform ? affineMap : identityAffineMap, tileHIter); |
523 | Value widthOffset = builder.create<affine::AffineApplyOp>( |
524 | loc, rightTransform ? affineMap : identityAffineMap, tileWIter); |
525 | |
526 | // Extract (H, W) from (N, H, W, C). |
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
531 | |
532 | TransformMapKeyTy key = {m, r}; |
533 | int64_t retRows = 1; |
534 | int64_t retCols = 1; |
535 | Value matmulRetValue = extractInput; |
536 | Value zero = builder.create<arith::ConstantOp>( |
537 | loc, rewriter.getZeroAttr(elementType)); |
538 | if (leftTransform) { |
539 | // Get constant transform matrix BT. |
      auto it = BTMatrices.find(key);
541 | if (it == BTMatrices.end()) |
542 | return {}; |
543 | const TransformMatrix &BTMatrix = it->second; |
544 | |
545 | retRows = BTMatrix.rows; |
546 | auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType); |
547 | auto empty = |
548 | builder |
549 | .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) |
550 | .getResult(); |
551 | auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); |
552 | |
553 | Value BT = |
554 | create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type()); |
555 | // Multiply BT x d. |
556 | auto matmulOp = builder.create<linalg::MatmulOp>( |
557 | loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init}); |
558 | matmulRetValue = matmulOp.getResult(0); |
559 | } |
560 | |
561 | if (rightTransform) { |
562 | // Get constant transform matrix B. |
      auto it = BMatrices.find(key);
564 | if (it == BMatrices.end()) |
565 | return {}; |
566 | const TransformMatrix &BMatrix = it->second; |
567 | |
568 | retCols = BMatrix.cols; |
569 | auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); |
570 | auto empty = |
571 | builder |
572 | .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) |
573 | .getResult(); |
574 | auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); |
575 | Value B = |
576 | create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type()); |
577 | // Multiply v = (BT x d) x B. |
578 | auto matmulOp = builder.create<linalg::MatmulOp>( |
579 | loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init}); |
580 | matmulRetValue = matmulOp.getResult(0); |
581 | } |
582 | |
583 | // Insert (H, W) to (H, W, tileH, tileW, N, C). |
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, /*tileHIdx=*/2, /*tileWIdx=*/3,
        /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);
588 | |
589 | return {combinedVal}; |
590 | }; |
591 | |
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
598 | scf::LoopNest loops = scf::buildLoopNest( |
599 | rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, |
600 | {tileHBound, tileWBound, nUpperBound, cUpperBound}, |
601 | {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody); |
602 | return loops.results[0]; |
603 | } |
604 | |
/// This function generates linalg.batch_matmul to multiply input with filter.
/// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
/// alphaH x alphaW as the batch dimension and flatten the remaining
/// dimensions, i.e., convert [alphaH, alphaW, tileH, tileW, N, C] to
/// [alphaH x alphaW, tileH x tileW x N, C]. In this way, we can convert the
/// 6-dimensional inputs to a 3-dimensional representation that is suitable
/// for linalg.batch_matmul.
///
/// The batched matmul does the matrix multiply with the reduction on the
/// channel dimension.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get the return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
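///
/// For example, for F(4 x 4, 3 x 3) (alpha = 6), the transformed input
/// (6, 6, tileH, tileW, N, C) collapses to (36, tileH x tileW x N, C), the
/// transformed filter (6, 6, C, F) collapses to (36, C, F), and the batched
/// matmul produces (36, tileH x tileW x N, F).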
623 | static Value matrixMultiply(RewriterBase &rewriter, Location loc, |
624 | Value transformedFilter, Value transformedInput, |
625 | Type outputElementType) { |
626 | // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter. |
627 | auto filterType = cast<ShapedType>(transformedFilter.getType()); |
  assert(filterType.hasStaticShape() && "only support static shapes.");
629 | ArrayRef<int64_t> filterShape = filterType.getShape(); |
630 | Type filterElementType = filterType.getElementType(); |
631 | auto filterReassocType = RankedTensorType::get( |
632 | {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]}, |
633 | filterElementType); |
634 | SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}}; |
635 | Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>( |
636 | loc, filterReassocType, transformedFilter, filterReassoc); |
637 | |
638 | // Convert (alphaH, alphaW, tileH, tileW, N, C) to |
639 | // (alphaH x alphaW, tileH x tileW x N, C) for input. |
640 | auto inputType = cast<ShapedType>(transformedInput.getType()); |
  assert(inputType.hasStaticShape() && "only support static shapes.");
642 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
643 | Type inputElementType = inputType.getElementType(); |
644 | auto inputReassocType = RankedTensorType::get( |
645 | {inputShape[0] * inputShape[1], |
646 | inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]}, |
647 | inputElementType); |
648 | SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}}; |
649 | Value collapseInput = rewriter.create<tensor::CollapseShapeOp>( |
650 | loc, inputReassocType, transformedInput, inputReassoc); |
651 | |
652 | // Batched matrix multiply. |
653 | auto matmulType = RankedTensorType::get( |
654 | {inputShape[0] * inputShape[1], |
655 | inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]}, |
656 | outputElementType); |
657 | Value empty = rewriter |
658 | .create<tensor::EmptyOp>(loc, matmulType.getShape(), |
659 | outputElementType) |
660 | .getResult(); |
661 | Value zero = rewriter.create<arith::ConstantOp>( |
662 | loc, rewriter.getZeroAttr(outputElementType)); |
663 | Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0); |
664 | |
665 | auto matmulOp = rewriter.create<linalg::BatchMatmulOp>( |
666 | loc, matmulType, ValueRange({collapseInput, collapseFilter}), |
667 | ValueRange{init}); |
668 | |
669 | // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F) |
670 | // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F). |
671 | SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}}; |
672 | auto outputReassocType = |
673 | RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2], |
674 | inputShape[3], inputShape[4], filterShape[3]}, |
675 | outputElementType); |
676 | auto expandOutput = rewriter.create<tensor::ExpandShapeOp>( |
677 | loc, outputReassocType, matmulOp.getResult(0), outputReassoc); |
678 | return expandOutput; |
679 | } |
680 | |
/// This function transforms the output. The data layout of the value to
/// transform is (alphaH, alphaW, tileH, tileW, N, F) and the data layout of
/// the final output is NHWF. The transformation matrices are 2-dimensional,
/// so we need to extract the H x W slice first and generate 4 levels of loops
/// to iterate on tileH, tileW, N, and F. After the transformation, we get
685 | /// |
686 | /// scf.for %h = 0 to tileH step 1 |
687 | /// scf.for %w = 0 to tileW step 1 |
688 | /// scf.for %n = 0 to N step 1 |
689 | /// scf.for %f = 0 to F step 1 |
690 | /// %extracted = extract %extracted<alphaH x alphaW> from |
691 | /// %input<alphaH x alphaW x tileH x tileW x N x F> |
692 | /// at [0, 0, %h, %w, %n, %f] |
693 | /// %ret = linalg.matmul AT, %extracted |
694 | /// %ret = linalg.matmul %ret, A |
695 | /// %inserted = insert %ret<alphaH x alphaW> into |
696 | /// output<N x H x W x F> |
697 | /// at [%n, (%h x m), (%w x m), %f] |
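///
/// The AT and A tables for F(4 x 4, 3 x 3) and F(2 x 2, 5 x 5) carry a scalar
/// factor (32 and 16, respectively), so when both transforms apply, the
/// combined result is multiplied by 32 x 32 = 1024 or 16 x 16 = 256 before it
/// is accumulated into the output.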
698 | Value outputTransform(RewriterBase &rewriter, Location loc, Value value, |
699 | Value output, int64_t m, int64_t r, |
700 | bool leftTransform = true, bool rightTransform = true) { |
701 | // Map from (m, r) to AT transform matrix. |
702 | static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> |
703 | ATMatrices = { |
704 | {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)}, |
705 | {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)}, |
706 | {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)}, |
707 | }; |
708 | |
709 | // Map from (m, r) to A transform matrix. |
710 | static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> |
711 | AMatrices = { |
712 | {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)}, |
713 | {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)}, |
714 | {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)}, |
715 | }; |
716 | |
717 | auto valueType = cast<ShapedType>(value.getType()); |
718 | Type elementType = valueType.getElementType(); |
719 | auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F |
720 | int64_t valueH = valueShape[0]; |
721 | int64_t valueW = valueShape[1]; |
722 | int64_t valueN = valueShape[4]; |
723 | int64_t valueF = valueShape[5]; |
724 | int64_t alphaH = leftTransform ? m + r - 1 : 1; |
725 | int64_t alphaW = rightTransform ? m + r - 1 : 1; |
726 | |
727 | if (valueH != alphaH && valueH != 1) |
728 | return Value(); |
729 | if (valueW != alphaW && valueW != 1) |
730 | return Value(); |
731 | |
732 | auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, |
733 | ValueRange args) -> scf::ValueVector { |
734 | auto context = builder.getContext(); |
735 | Value tileHIter = ivs[0]; |
736 | Value tileWIter = ivs[1]; |
737 | Value NIter = ivs[2]; |
738 | Value FIter = ivs[3]; |
739 | |
740 | // Extract (H, W) from (H, W, tileH, tileW, N, F). |
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, /*tileHIdx=*/2, /*tileWIdx=*/3,
                            /*loopNorFIdx=*/4,
                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
745 | |
746 | const TransformMapKeyTy key = {m, r}; |
    const TransformMatrix &AMatrix = AMatrices.at(key);
    const TransformMatrix &ATMatrix = ATMatrices.at(key);
749 | int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) * |
750 | (leftTransform ? ATMatrix.scalarFactor : 1); |
751 | int64_t retCols = rightTransform ? AMatrix.cols : 1; |
752 | int64_t retRows = leftTransform ? ATMatrix.rows : 1; |
753 | |
754 | Value matmulRetValue = extractValue; |
755 | Value zero = builder.create<arith::ConstantOp>( |
756 | loc, rewriter.getZeroAttr(elementType)); |
757 | |
    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
761 | Value heightOffset = builder.create<affine::AffineApplyOp>( |
762 | loc, leftTransform ? affineMap : identityAffineMap, tileHIter); |
763 | Value widthOffset = builder.create<affine::AffineApplyOp>( |
764 | loc, rightTransform ? affineMap : identityAffineMap, tileWIter); |
765 | |
    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
772 | if (leftTransform) { |
773 | auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); |
774 | Value init = outInitVal; |
775 | if (rightTransform || scalarFactor != 1) { |
776 | auto empty = builder |
777 | .create<tensor::EmptyOp>(loc, matmulType.getShape(), |
778 | elementType) |
779 | .getResult(); |
780 | init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); |
781 | } |
782 | |
      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
784 | // Multiply AT x m. |
785 | auto matmulOp = builder.create<linalg::MatmulOp>( |
786 | loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init}); |
787 | matmulRetValue = matmulOp.getResult(0); |
788 | } |
789 | |
790 | if (rightTransform) { |
791 | auto matmulType = |
792 | RankedTensorType::get({retRows, AMatrix.cols}, elementType); |
793 | Value init = outInitVal; |
794 | if (scalarFactor != 1) { |
795 | auto empty = builder |
796 | .create<tensor::EmptyOp>(loc, matmulType.getShape(), |
797 | elementType) |
798 | .getResult(); |
799 | init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); |
800 | } |
801 | |
      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
803 | // Multiply y = (AT x m) x A. |
804 | auto matmulOp = builder.create<linalg::MatmulOp>( |
805 | loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init}); |
806 | matmulRetValue = matmulOp.getResult(0); |
807 | } |
808 | |
809 | if (scalarFactor != 1) { |
810 | // Multiply by scalar factor and add outInitVal. |
811 | Value scalarFactorValue = builder.create<arith::ConstantOp>( |
812 | loc, FloatAttr::get(elementType, scalarFactor)); |
813 | auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); |
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
817 | |
818 | matmulRetValue = |
819 | rewriter |
820 | .create<linalg::GenericOp>( |
821 | loc, matmulType, |
822 | ValueRange{scalarFactorValue, matmulRetValue}, |
823 | ValueRange{outInitVal}, affineMaps, |
824 | llvm::ArrayRef<utils::IteratorType>{ |
825 | utils::IteratorType::parallel, |
826 | utils::IteratorType::parallel}, |
827 | [&](OpBuilder &nestedBuilder, Location nestedLoc, |
828 | ValueRange args) { |
829 | auto mulf = nestedBuilder.create<arith::MulFOp>( |
830 | nestedLoc, args[0], args[1]); |
831 | auto addf = nestedBuilder.create<arith::AddFOp>( |
832 | nestedLoc, mulf.getResult(), args[2]); |
833 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, |
834 | addf.getResult()); |
835 | }) |
836 | .getResult(0); |
837 | } |
838 | |
839 | // Insert (H, W) to (N, H, W, F). |
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);
846 | |
847 | return {combinedVal}; |
848 | }; |
849 | |
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
858 | scf::LoopNest loops = scf::buildLoopNest( |
859 | rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, |
860 | {tileHBound, tileWBound, nUpperBound, fUpperBound}, |
861 | {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody); |
862 | return loops.results[0]; |
863 | } |
864 | |
/// Pad the value with zeros on the high side of each dimension so that its
/// shape matches alignedShape.
867 | static Value padToAlignedTensor(RewriterBase &rewriter, Location loc, |
868 | Value value, ArrayRef<int64_t> alignedShape) { |
869 | auto valueType = cast<ShapedType>(value.getType()); |
870 | Type elementType = valueType.getElementType(); |
871 | auto alignedType = RankedTensorType::get(alignedShape, elementType); |
872 | Value padValue = rewriter.create<arith::ConstantOp>( |
873 | loc, elementType, rewriter.getZeroAttr(elementType)); |
874 | |
  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, /*nofold=*/false);
877 | } |
878 | |
/// Extract a sub-tensor of type extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
883 | OpFoldResult zeroIndex = rewriter.getIndexAttr(0); |
884 | OpFoldResult oneIndex = rewriter.getIndexAttr(1); |
885 | SmallVector<OpFoldResult, 4> offsets(4, zeroIndex); |
886 | SmallVector<OpFoldResult, 4> strides(4, oneIndex); |
887 | |
  ArrayRef<int64_t> extractedShape = extractedType.getShape();
889 | SmallVector<OpFoldResult> sizes = |
890 | getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape)); |
891 | |
892 | return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value, |
893 | offsets, sizes, strides); |
894 | } |
895 | |
896 | /// Utility function to check all values in the attribute are 1. |
897 | static bool hasAllOneValues(DenseIntElementsAttr attr) { |
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
900 | } |
901 | |
902 | /// A helper function to convert linalg.conv_2d_nhwc_fhwc to |
903 | /// linalg.winograd_*_transform ops. |
904 | static FailureOr<Operation *> |
905 | winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, |
906 | int64_t m, int64_t r) { |
907 | Value input = convOp.getInputs()[0]; |
908 | Value filter = convOp.getInputs()[1]; |
909 | Value output = convOp.getOutputs()[0]; |
910 | auto inputType = cast<ShapedType>(input.getType()); |
911 | auto filterType = cast<ShapedType>(filter.getType()); |
912 | auto outputType = cast<ShapedType>(output.getType()); |
913 | |
  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
928 | |
929 | ArrayRef<int64_t> filterShape = filterType.getShape(); |
930 | int64_t filterF = filterShape[0]; |
931 | int64_t filterH = filterShape[1]; |
932 | int64_t filterW = filterShape[2]; |
933 | int64_t filterC = filterShape[3]; |
934 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
935 | int64_t inputN = inputShape[0]; |
936 | int64_t inputH = inputShape[1]; |
937 | int64_t inputW = inputShape[2]; |
938 | int64_t inputC = inputShape[3]; |
939 | ArrayRef<int64_t> outputShape = outputType.getShape(); |
940 | int64_t outputN = outputShape[0]; |
941 | int64_t outputH = outputShape[1]; |
942 | int64_t outputW = outputShape[2]; |
943 | int64_t outputF = outputShape[3]; |
944 | |
945 | // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r). |
946 | bool isSupportedFilter = false; |
947 | if (filterH == filterW && filterH == r) |
948 | isSupportedFilter = true; |
949 | if (filterH == r && filterW == 1) |
950 | isSupportedFilter = true; |
951 | if (filterH == 1 && filterW == r) |
952 | isSupportedFilter = true; |
953 | |
954 | if (!isSupportedFilter) |
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
957 | |
958 | // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5). |
959 | static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = { |
960 | F_2_3, F_4_3, F_2_5}; |
961 | |
962 | TransformMapKeyTy key = {m, r}; |
  auto it = llvm::find(validConfigs, key);
964 | // If we cannot find the constant transformation matrix, it means we do |
965 | // not support this configuration yet. |
966 | if (it == validConfigs.end()) |
967 | return failure(); |
968 | |
  // All the criteria are satisfied. We can do Winograd Conv2D.
970 | Location loc = convOp.getLoc(); |
971 | |
972 | // For F(m x 1, r x 1), we only need to do left side transform. |
973 | bool leftTransform = filterH != 1; |
974 | // For F(1 x m, 1 x r), we only need to do right side transform. |
975 | bool rightTransform = filterW != 1; |
976 | int64_t heightM = leftTransform ? m : 1; |
977 | int64_t widthM = rightTransform ? m : 1; |
978 | int64_t heightR = leftTransform ? r : 1; |
979 | int64_t widthR = rightTransform ? r : 1; |
980 | |
981 | // --- Create operation for filter transform --- |
982 | Type filterElementType = filterType.getElementType(); |
983 | int64_t alphaH = heightM + heightR - 1; |
984 | int64_t alphaW = widthM + widthR - 1; |
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
987 | auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF}, |
988 | filterElementType); |
989 | Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), |
990 | filterElementType); |
991 | auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>( |
992 | loc, retType, filter, retValue, m, r); |
993 | |
994 | // --- Create operation for input transform --- |
995 | |
  // When the input size minus (r - 1) is not aligned with the output tile
  // size, we need to pad the input data to create full tiles for tiling.
998 | Type inputElementType = inputType.getElementType(); |
999 | int64_t alignedInputH = tileH * heightM + (heightR - 1); |
1000 | int64_t alignedInputW = tileW * widthM + (widthR - 1); |
1001 | if (alignedInputH != inputH || alignedInputW != inputW) { |
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
1004 | } |
1005 | |
1006 | retType = RankedTensorType::get( |
1007 | {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType); |
1008 | retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), |
1009 | inputElementType); |
1010 | auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>( |
1011 | loc, retType, input, retValue, m, r); |
1012 | |
1013 | Type outputElementType = outputType.getElementType(); |
1014 | Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter, |
1015 | transformedInput, outputElementType); |
1016 | |
1017 | // --- Create operation for output transform --- |
1018 | |
  // When the output size is not aligned with the output tile size, we need to
  // pad the output buffer so that full tiles can be inserted after tiling.
1021 | int64_t alignedOutputH = tileH * heightM; |
1022 | int64_t alignedOutputW = tileW * widthM; |
1023 | bool isOutputUnaligned = |
1024 | ((alignedOutputH != outputH) || (alignedOutputW != outputW)); |
1025 | if (isOutputUnaligned) { |
1026 | auto alignedOutputType = RankedTensorType::get( |
1027 | {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType); |
1028 | output = |
1029 | padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape()); |
1030 | outputType = alignedOutputType; |
1031 | } |
1032 | |
1033 | Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>( |
1034 | loc, outputType, matmulRet, output, m, r); |
1035 | |
  // When the output size is not aligned with the output tile size, extract
  // the final result from the padded buffer.
1038 | if (isOutputUnaligned) { |
1039 | transformedOutput = extractFromAlignedTensor( |
1040 | rewriter, loc, transformedOutput, |
1041 | RankedTensorType::get({outputN, outputH, outputW, outputF}, |
1042 | outputElementType)); |
1043 | } |
1044 | |
1045 | rewriter.replaceOp(convOp, transformedOutput); |
1046 | |
1047 | return transformedOutput.getDefiningOp(); |
1048 | } |
1049 | |
1050 | /// A helper function to decompose linalg.winograd_filter_transform. |
1051 | FailureOr<Operation *> |
1052 | decomposeWinogradFilterTransformHelper(RewriterBase &rewriter, |
1053 | linalg::WinogradFilterTransformOp op) { |
1054 | Location loc = op.getLoc(); |
1055 | Value filter = op.getFilter(); |
1056 | auto filterType = cast<ShapedType>(filter.getType()); |
1057 | auto filterShape = filterType.getShape(); |
1058 | int64_t filterH = filterShape[1]; |
1059 | int64_t filterW = filterShape[2]; |
1060 | |
1061 | // For F(m x 1, r x 1), we only need to do left side transform. |
1062 | bool leftTransform = filterH != 1; |
1063 | // For F(1 x m, 1 x r), we only need to do right side transform. |
1064 | bool rightTransform = filterW != 1; |
1065 | Value transformedFilter = |
1066 | filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(), |
1067 | op.getR(), leftTransform, rightTransform); |
1068 | if (!transformedFilter) |
1069 | return failure(); |
1070 | |
1071 | rewriter.replaceOp(op, transformedFilter); |
1072 | |
1073 | return transformedFilter.getDefiningOp(); |
1074 | } |
1075 | |
1076 | /// A helper function to decompose linalg.winograd_input_transform. |
1077 | FailureOr<Operation *> |
1078 | decomposeWinogradInputTransformHelper(RewriterBase &rewriter, |
1079 | linalg::WinogradInputTransformOp op) { |
1080 | Location loc = op.getLoc(); |
1081 | Value output = op.getOutput(); |
1082 | auto outputType = cast<ShapedType>(output.getType()); |
1083 | auto outputShape = outputType.getShape(); |
1084 | |
1085 | int64_t outputH = outputShape[0]; |
1086 | int64_t outputW = outputShape[1]; |
1087 | |
1088 | // For F(m x 1, r x 1), we only need to do left side transform. |
1089 | bool leftTransform = outputH != 1; |
1090 | // For F(1 x m, 1 x r), we only need to do right side transform. |
1091 | bool rightTransform = outputW != 1; |
1092 | Value transformedInput = |
1093 | inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(), |
1094 | op.getR(), leftTransform, rightTransform); |
1095 | if (!transformedInput) |
1096 | return failure(); |
1097 | |
1098 | rewriter.replaceOp(op, transformedInput); |
1099 | |
1100 | return transformedInput.getDefiningOp(); |
1101 | } |
1102 | |
1103 | /// A helper function to decompose linalg.winograd_output_transform. |
1104 | FailureOr<Operation *> |
1105 | decomposeWinogradOutputTransformHelper(RewriterBase &rewriter, |
1106 | linalg::WinogradOutputTransformOp op) { |
1107 | Location loc = op.getLoc(); |
1108 | Value value = op.getValue(); |
1109 | auto valueType = cast<ShapedType>(value.getType()); |
1110 | auto valueShape = valueType.getShape(); |
1111 | int64_t valueH = valueShape[0]; |
1112 | int64_t valueW = valueShape[1]; |
1113 | |
1114 | // For F(m x 1, r x 1), we only need to do left side transform. |
1115 | bool leftTransform = valueH != 1; |
1116 | // For F(1 x m, 1 x r), we only need to do right side transform. |
1117 | bool rightTransform = valueW != 1; |
1118 | Value transformedOutput = |
1119 | outputTransform(rewriter, loc, value, op.getOutput(), op.getM(), |
1120 | op.getR(), leftTransform, rightTransform); |
1121 | if (!transformedOutput) |
1122 | return failure(); |
1123 | |
1124 | rewriter.replaceOp(op, transformedOutput); |
1125 | |
1126 | return transformedOutput.getDefiningOp(); |
1127 | } |
1128 | |
1129 | /// A rewrite pattern to decompose linalg.winograd_filter_transform operations. |
1130 | class DecomposeWinogradFilterTransform final |
1131 | : public OpRewritePattern<linalg::WinogradFilterTransformOp> { |
1132 | public: |
1133 | using OpRewritePattern::OpRewritePattern; |
1134 | |
1135 | LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op, |
1136 | PatternRewriter &rewriter) const override { |
1137 | return decomposeWinogradFilterTransformHelper(rewriter, op); |
1138 | } |
1139 | }; |
1140 | |
1141 | /// A rewrite pattern to decompose linalg.winograd_input_transform operations. |
1142 | class DecomposeWinogradInputTransform final |
1143 | : public OpRewritePattern<linalg::WinogradInputTransformOp> { |
1144 | public: |
1145 | using OpRewritePattern::OpRewritePattern; |
1146 | |
1147 | LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op, |
1148 | PatternRewriter &rewriter) const override { |
1149 | return decomposeWinogradInputTransformHelper(rewriter, op); |
1150 | } |
1151 | }; |
1152 | |
1153 | /// A rewrite pattern to decompose linalg.winograd_output_transform operations. |
1154 | class DecomposeWinogradOutputTransform final |
1155 | : public OpRewritePattern<linalg::WinogradOutputTransformOp> { |
1156 | public: |
1157 | using OpRewritePattern::OpRewritePattern; |
1158 | |
1159 | LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op, |
1160 | PatternRewriter &rewriter) const override { |
1161 | return decomposeWinogradOutputTransformHelper(rewriter, op); |
1162 | } |
1163 | }; |
1164 | |
1165 | /// A rewrite pattern for Winograd Conv2D algorithm. |
1166 | class WinogradConv2DNhwcFhwc final |
1167 | : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> { |
1168 | public: |
1169 | using OpRewritePattern::OpRewritePattern; |
1170 | WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r) |
1171 | : OpRewritePattern(context), m(m), r(r) {} |
1172 | |
1173 | LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp, |
1174 | PatternRewriter &rewriter) const override { |
1175 | if (failed(winogradConv2DHelper(rewriter, convOp, m, r))) |
1176 | return failure(); |
1177 | |
1178 | return success(); |
1179 | } |
1180 | |
1181 | private: |
1182 | int64_t m; |
1183 | int64_t r; |
1184 | }; |
1185 | } // end anonymous namespace |
1186 | |
1187 | //===----------------------------------------------------------------------===// |
1188 | FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter, |
1189 | linalg::Conv2DNhwcFhwcOp op, int64_t m, |
1190 | int64_t r) { |
1191 | return winogradConv2DHelper(rewriter, op, m, r); |
1192 | } |
1193 | |
1194 | FailureOr<Operation *> |
1195 | decomposeWinogradFilterTransformOp(RewriterBase &rewriter, |
1196 | linalg::WinogradFilterTransformOp op) { |
1197 | return decomposeWinogradFilterTransformHelper(rewriter, op); |
1198 | } |
1199 | |
1200 | FailureOr<Operation *> |
1201 | decomposeWinogradInputTransformOp(RewriterBase &rewriter, |
1202 | linalg::WinogradInputTransformOp op) { |
1203 | return decomposeWinogradInputTransformHelper(rewriter, op); |
1204 | } |
1205 | |
1206 | FailureOr<Operation *> |
1207 | decomposeWinogradOutputTransformOp(RewriterBase &rewriter, |
1208 | linalg::WinogradOutputTransformOp op) { |
1209 | return decomposeWinogradOutputTransformHelper(rewriter, op); |
1210 | } |
1211 | |
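// A short usage sketch for the pattern entry points below (hypothetical pass
// body; assumes the greedy driver from GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));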
1212 | void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, |
1213 | int64_t r) { |
1214 | MLIRContext *context = patterns.getContext(); |
1215 | // TODO: Support more Conv2D data layout, e.g., conv_2d_nchw_fchw |
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
1217 | } |
1218 | |
1219 | void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) { |
1220 | MLIRContext *context = patterns.getContext(); |
1221 | patterns |
1222 | .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform, |
              DecomposeWinogradOutputTransform>(context);
1224 | } |
1225 | |
1226 | } // end namespace linalg |
1227 | } // end namespace mlir |
1228 | |