//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement the Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
///   Y = A^T x [ (G x g x G^T) (*) (B^T x d x B) ] x A
///
/// where g is the filter, d is the input data, and (*) denotes element-wise
/// multiplication. We need to prepare 6 constant transformation matrices,
/// G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
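///
/// For example, for F(2 x 2, 3 x 3), G is 4x3, B^T is 4x4, and A^T is 2x4.
/// Given a 3x3 filter g and a 4x4 input tile d, both G x g x G^T and
/// B^T x d x B are 4x4 matrices; their element-wise product, transformed by
/// A^T and A, yields the 2x2 output tile Y.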
constexpr float G_2x2_3x3[] = {
     -1,     0,    0,
   1./2, -1./2, 1./2,
   1./2,  1./2, 1./2,
      0,     0,    1
};

constexpr float GT_2x2_3x3[] = {
  -1,  1./2, 1./2, 0,
   0, -1./2, 1./2, 0,
   0,  1./2, 1./2, 1
};

constexpr float BT_2x2_3x3[] = {
  -1,  0, 1, 0,
   0, -1, 1, 0,
   0,  1, 1, 0,
   0, -1, 0, 1
};

constexpr float B_2x2_3x3[] = {
  -1,  0, 0,  0,
   0, -1, 1, -1,
   1,  1, 1,  0,
   0,  0, 0,  1
};

constexpr float AT_2x2_3x3[] = {
  1,  1, 1, 0,
  0, -1, 1, 1
};

constexpr float A_2x2_3x3[] = {
  1,  0,
  1, -1,
  1,  1,
  0,  1
};

constexpr float G_4x4_3x3[] = {
      1,     0,     0,
  -1./3,  1./3, -1./3,
  -1./3, -1./3, -1./3,
  1./12, -1./6,  1./3,
  1./12,  1./6,  1./3,
      0,     0,     1
};

constexpr float GT_4x4_3x3[] = {
  1, -1./3, -1./3, 1./12, 1./12, 0,
  0,  1./3, -1./3, -1./6,  1./6, 0,
  0, -1./3, -1./3,  1./3,  1./3, 1
};

constexpr float BT_4x4_3x3[] = {
  1./4,     0, -5./16,      0, 1./16,     0,
     0,  1./4,  -1./4, -1./16, 1./16,     0,
     0, -1./4,  -1./4,  1./16, 1./16,     0,
     0,  1./4,  -1./8,  -1./4,  1./8,     0,
     0, -1./4,  -1./8,   1./4,  1./8,     0,
     0,  1./4,      0, -5./16,     0, 1./16
};

constexpr float B_4x4_3x3[] = {
    1./4,      0,     0,      0,     0,      0,
       0,   1./4, -1./4,   1./4, -1./4,   1./4,
  -5./16,  -1./4, -1./4,  -1./8, -1./8,      0,
       0, -1./16, 1./16,  -1./4,  1./4, -5./16,
   1./16,  1./16, 1./16,   1./8,  1./8,      0,
       0,      0,     0,      0,     0,  1./16
};

constexpr float AT_4x4_3x3[] = {
  1./8,  1./4, 1./4,  1./8, 1./8,    0,
     0, -1./4, 1./4, -1./4, 1./4,    0,
     0,  1./4, 1./4,  1./2, 1./2,    0,
     0, -1./4, 1./4,    -1,    1, 1./2
};

constexpr float A_4x4_3x3[] = {
  1./8,     0,    0,     0,
  1./4, -1./4, 1./4, -1./4,
  1./4,  1./4, 1./4,  1./4,
  1./8, -1./4, 1./2,    -1,
  1./8,  1./4, 1./2,     1,
     0,     0,    0,  1./2
};

constexpr float G_2x2_5x5[] = {
       1,     0,      0,     0,      0,
    1./6, -1./6,   1./6, -1./6,   1./6,
   -1./6, -1./6,  -1./6, -1./6,  -1./6,
  -4./15, 2./15, -1./15, 1./30, -1./60,
   1./60, 1./30,  1./15, 2./15,  4./15,
       0,     0,      0,     0,      1
};

constexpr float GT_2x2_5x5[] = {
  1,  1./6, -1./6, -4./15, 1./60, 0,
  0, -1./6, -1./6,  2./15, 1./30, 0,
  0,  1./6, -1./6, -1./15, 1./15, 0,
  0, -1./6, -1./6,  1./30, 2./15, 0,
  0,  1./6, -1./6, -1./60, 4./15, 1
};

constexpr float BT_2x2_5x5[] = {
  1./8, 3./16,  -1./4, -3./16,   1./8,    0,
     0,  1./8,  1./16, -5./16,   1./8,    0,
     0, -1./8, -5./16, -1./16,   1./8,    0,
     0,  1./4,  -1./8,  -1./4,   1./8,    0,
     0, -1./8,  -1./4,   1./8,   1./4,    0,
     0,  1./8,  3./16,  -1./4, -3./16, 1./8
};

constexpr float B_2x2_5x5[] = {
    1./8,      0,      0,     0,     0,      0,
   3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
   -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
  -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
    1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
       0,      0,      0,     0,     0,   1./8
};

constexpr float AT_2x2_5x5[] = {
  1./2,  1, 1,  2, 1,    0,
     0, -1, 1, -1, 2, 1./2
};

constexpr float A_2x2_5x5[] = {
  1./2,    0,
     1,   -1,
     1,    1,
     2,   -1,
     1,    2,
     0, 1./2
};
// clang-format on

using TransformMapKeyTy = std::pair<int, int>;

/// We use F(m, r) to define the size of a minimal filtering algorithm.
/// m is the output dimension and r is the filter dimension. We can get
/// the input dimension, alpha, from the formula alpha = m + r - 1.
///
/// For example, when m = 2 and r = 3, the input size is 4. The Conv2D
/// operates on 4x4 input data with a 3x3 filter and produces a 2x2 output.
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};

/// Structure to keep the information of a constant transform matrix.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};
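
// For example, TransformMatrix(G_4x4_3x3, 6, 3) describes the 6x3 G matrix of
// F(4 x 4, 3 x 3). The scalarFactor field is only nontrivial for the AT/A
// matrices above; the output transform folds it back in with a final
// element-wise multiplication.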

/// Utility function to convert a constant array to an arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);

  return builder.create<arith::ConstantOp>(
      loc, DenseFPElementsAttr::get(
               RankedTensorType::get(
                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
               constVec));
}
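
// For illustration, for the G matrix of F(2 x 2, 3 x 3) with an f32 element
// type, the helper above emits IR along these lines (a sketch, not verbatim
// printer output):
//
//   %G = arith.constant dense<[[-1.0,  0.0, 0.0],
//                              [ 0.5, -0.5, 0.5],
//                              [ 0.5,  0.5, 0.5],
//                              [ 0.0,  0.0, 1.0]]> : tensor<4x3xf32>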

/// Extract height x width data from 4D tensors.
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}
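
// For example (a sketch with assumed shapes), extracting a 3x3 tile at
// (%f, %c) from an FHWC filter tensor<2x3x3x5xf32> produces a rank-reducing
// slice roughly like:
//
//   %tile = tensor.extract_slice %filter[%f, 0, 0, %c] [1, 3, 3, 1]
//       [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>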

/// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractValueType = RankedTensorType::get({height, width}, elementType);
  auto extractValueOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractValueType, source, offsets, sizes, strides);

  return extractValueOp;
}

/// Insert transformed height x width data into the 4D tensor it was
/// extracted from.
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}
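
// For example (a sketch with assumed shapes), inserting a transformed 4x4
// tile back at (%c, %f) of an HWCF destination tensor<4x4x5x2xf32> produces
// roughly:
//
//   %r = tensor.insert_slice %tile into %dest[0, 0, %c, %f] [4, 4, 1, 1]
//       [1, 1, 1, 1] : tensor<4x4xf32> into tensor<4x4x5x2xf32>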

/// Insert transformed height x width data into the 6D tensor it was
/// extracted from.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrices are 2-dimensional, so we need to extract the
/// H x W data from FHWC first, generating 2 levels of loops to iterate on F
/// and C. After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(key);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(key);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F).
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrices are 2-dimensional, so we need to extract the
/// H x W data from NHWC first, generating 4 levels of loops to iterate on
/// tileH, tileW, N, and C. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<N x H x W x C>
///                              at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                            %output<alphaH x alphaW x tileH x tileW x N x C>
///                            at [0, 0, %h, %w, %n, %c]
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, int64_t m, int64_t r,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to BT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BTMatrices = {
          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to B transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BMatrices = {
          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto context = builder.getContext();

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix BT.
      auto it = BTMatrices.find(key);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value BT =
          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
      // Multiply BT x d.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix B.
      auto it = BMatrices.find(key);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      Value B =
          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
      // Multiply v = (BT x d) x B.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, /*tileHIdx=*/2, /*tileWIdx=*/3,
        /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function generates linalg.batch_matmul to multiply the input with the
/// filter. linalg.batch_matmul only supports 3-dimensional inputs, so we
/// treat the alphaH x alphaW transform dimensions as a 1-dimensional batch
/// and flatten tileH x tileW x N into a single row dimension. That is to
/// convert [alphaH, alphaW, tileH, tileW, N, C] to
/// [alphaH x alphaW, tileH x tileW x N, C]. In this way, we can convert the
/// 6-dimensional input to a 3-dimensional representation that is suitable
/// for linalg.batch_matmul.
///
/// The batched matmul does the matrix multiplication with the reduction on
/// the channel dimension.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get the return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
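///
/// For example (a sketch with assumed shapes), for F(4 x 4, 3 x 3) with
/// alphaH = alphaW = 6, tileH = tileW = 2, N = 1, C = 5, and F = 2:
///
///   filter: tensor<6x6x5x2xf32>     -> collapse -> tensor<36x5x2xf32>
///   input:  tensor<6x6x2x2x1x5xf32> -> collapse -> tensor<36x4x5xf32>
///   linalg.batch_matmul             -> tensor<36x4x2xf32>
///   expand                          -> tensor<6x6x2x2x1x2xf32>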
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = rewriter
                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                             outputElementType)
                    .getResult();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(outputElementType));
  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of the batch matmul is
  // (alphaH x alphaW, tileH x tileW x N, F). Expand the matmul result to
  // (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}

/// This function transforms the output. The data layout of the input value is
/// (alphaH, alphaW, tileH, tileW, N, F) and the final output is NHWF. The
/// transformation matrices are 2-dimensional, so we need to extract the
/// H x W data first, generating 4 levels of loops to iterate on tileH, tileW,
/// N, and F. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<alphaH x alphaW x tileH x tileW x N x F>
///                              at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<m x m> into output<N x H x W x F>
///                            at [%n, (%h x m), (%w x m), %f]
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to AT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      ATMatrices = {
          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to A transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      AMatrices = {
          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, /*tileHIdx=*/2, /*tileWIdx=*/3,
                            /*loopNorFIdx=*/4,
                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMapKeyTy key = {m, r};
    const TransformMatrix &AMatrix = AMatrices.at(key);
    const TransformMatrix &ATMatrix = ATMatrices.at(key);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply AT x m.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply y = (AT x m) x A.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by the scalar factor and add outInitVal.
      Value scalarFactorValue = builder.create<arith::ConstantOp>(
          loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          rewriter
              .create<linalg::GenericOp>(
                  loc, matmulType,
                  ValueRange{scalarFactorValue, matmulRetValue},
                  ValueRange{outInitVal}, affineMaps,
                  llvm::ArrayRef<utils::IteratorType>{
                      utils::IteratorType::parallel,
                      utils::IteratorType::parallel},
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    auto mulf = nestedBuilder.create<arith::MulFOp>(
                        nestedLoc, args[0], args[1]);
                    auto addf = nestedBuilder.create<arith::AddFOp>(
                        nestedLoc, mulf.getResult(), args[2]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                          addf.getResult());
                  })
              .getResult(0);
    }

    // Insert (H, W) to (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}

/// Zero-pad the value tensor on the high sides to the given aligned shape.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = rewriter.create<arith::ConstantOp>(
      loc, elementType, rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, false);
}
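
// For example (a sketch), padding tensor<1x9x9x5xf32> to the aligned shape
// {1, 10, 10, 5} yields IR along these lines:
//
//   %zero = arith.constant 0.000000e+00 : f32
//   %padded = tensor.pad %value low[0, 0, 0, 0] high[0, 1, 1, 0] {
//   ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index):
//     tensor.yield %zero : f32
//   } : tensor<1x9x9x5xf32> to tensor<1x10x10x5xf32>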

/// Extract a sub-tensor with extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
                                                 offsets, sizes, strides);
}

/// Utility function to check that all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
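///
/// For example (a sketch with assumed shapes, not verbatim IR), a
/// conv_2d_nhwc_fhwc with input tensor<1x10x10x5xf32>, filter
/// tensor<2x3x3x5xf32>, and output tensor<1x8x8x2xf32>, rewritten with
/// F(4 x 4, 3 x 3), becomes roughly:
///
///   %tf = linalg.winograd_filter_transform m(4) r(3) ...
///                                      // tensor<6x6x5x2xf32>
///   %ti = linalg.winograd_input_transform m(4) r(3) ...
///                                      // tensor<6x6x2x2x1x5xf32>
///   ... collapse, linalg.batch_matmul, expand (see matrixMultiply) ...
///   %out = linalg.winograd_output_transform m(4) r(3) ...
///                                      // tensor<1x8x8x2xf32>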
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // Currently, we support (m, r) = (2, 3), (4, 3) or (2, 5).
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  TransformMapKeyTy key = {m, r};
  auto it = llvm::find(validConfigs, key);
  // If we cannot find the constant transformation matrix, it means we do
  // not support this configuration yet.
  if (it == validConfigs.end())
    return failure();

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                    filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, m, r);

  // --- Create operation for input transform ---

  // When the input size minus (r - 1) is not aligned with the output tile
  // size, we need to pad the input data to create full tiles for tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                              inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, m, r);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When the output size is not aligned with the output tile size, we need to
  // pad the output buffer so that full tiles can be inserted after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  // When the output size is not aligned with the output tile size, extract
  // the value from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();

  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                     op.getR(), leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = valueH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for the Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
      return failure();

    return success();
  }

private:
  int64_t m;
  int64_t r;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {
  return winogradConv2DHelper(rewriter, op, m, r);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

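// For reference, a minimal usage sketch of the entry points below, assuming a
// caller that owns a RewritePatternSet built from MLIRContext `ctx` and a
// target operation `op`:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
//   linalg::populateDecomposeWinogradOpsPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));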
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
}

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}

} // end namespace linalg
} // end namespace mlir