1//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implement Winograd Conv2D algorithm. The implementation is based on the
10// paper: Fast Algorithms for Convolutional Neural Networks
11// (https://arxiv.org/abs/1509.09308)
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18#include "mlir/Dialect/Linalg/Utils/Utils.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "llvm/Support/MathExtras.h"
22
23namespace mlir {
24namespace linalg {
25
26namespace {
27
28// clang-format off
29/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
30/// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
31/// m is the output dimension and r is the filter dimension, is
32///
33/// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
34///
35/// g is filter and d is input data. We need to prepare 6 constant
36/// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
37///
38/// The following tables define these constant transformation matrices for
39/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
40///
41/// To add more transformation matrices, we need to add the following
42/// items:
43/// 1. Add the constant transformation matrix to the corresponding
44/// G, GT, BT, B, AT, or A array.
45/// 2. Add the corresponding TransformMatrix to the GMatrices, GTMatrices,
46/// BTMatrices, BMatrices, ATMatrices, or AMatrices map.
/// 3. Add an enum value F_m_r to the WinogradConv2DFmr enum.
48///
// G for F(2x2, 3x3): 4x3. Maps the 3x3 filter g into the 4x4 Winograd
// domain as G x g x G^T.
constexpr float G_2x2_3x3[] = {
 -1, 0, 0,
 1./2, -1./2, 1./2,
 1./2, 1./2, 1./2,
 0, 0, 1
};

// G^T for F(2x2, 3x3): 3x4.
constexpr float GT_2x2_3x3[] = {
 -1, 1./2, 1./2, 0,
 0, -1./2, 1./2, 0,
 0, 1./2, 1./2, 1
};

// B^T for F(2x2, 3x3): 4x4. Maps the input tile d into the Winograd
// domain as B^T x d x B.
constexpr float BT_2x2_3x3[] = {
 -1, 0, 1, 0,
 0, -1, 1, 0,
 0, 1, 1, 0,
 0, -1, 0, 1
};

// B for F(2x2, 3x3): 4x4.
constexpr float B_2x2_3x3[] = {
 -1, 0, 0, 0,
 0, -1, 1, -1,
 1, 1, 1, 0,
 0, 0, 0, 1
};

// A^T for F(2x2, 3x3): 2x4. Maps the element-wise product back to the
// 2x2 output as A^T x [...] x A.
constexpr float AT_2x2_3x3[] = {
 1, 1, 1, 0,
 0, -1, 1, 1
};

// A for F(2x2, 3x3): 4x2.
constexpr float A_2x2_3x3[] = {
 1, 0,
 1, -1,
 1, 1,
 0, 1
};

// G for F(4x4, 3x3): 6x3.
constexpr float G_4x4_3x3[] = {
 1, 0, 0,
 -1./3, 1./3, -1./3,
 -1./3, -1./3, -1./3,
 1./12, -1./6, 1./3,
 1./12, 1./6, 1./3,
 0, 0, 1
};

// G^T for F(4x4, 3x3): 3x6.
constexpr float GT_4x4_3x3[] = {
 1, -1./3, -1./3, 1./12, 1./12, 0,
 0, 1./3, -1./3, -1./6, 1./6, 0,
 0, -1./3, -1./3, 1./3, 1./3, 1
};

// B^T for F(4x4, 3x3): 6x6.
constexpr float BT_4x4_3x3[] = {
 1./4, 0, -5./16, 0, 1./16, 0,
 0, 1./4, -1./4, -1./16, 1./16, 0,
 0, -1./4, -1./4, 1./16, 1./16, 0,
 0, 1./4, -1./8, -1./4, 1./8, 0,
 0, -1./4, -1./8, 1./4, 1./8, 0,
 0, 1./4, 0, -5./16, 0, 1./16
};

// B for F(4x4, 3x3): 6x6.
constexpr float B_4x4_3x3[] = {
 1./4, 0, 0, 0, 0, 0,
 0, 1./4, -1./4, 1./4, -1./4, 1./4,
 -5./16, -1./4, -1./4, -1./8, -1./8, 0,
 0, -1./16, 1./16, -1./4, 1./4, -5./16,
 1./16, 1./16, 1./16, 1./8, 1./8, 0,
 0, 0, 0, 0, 0, 1./16
};

// A^T for F(4x4, 3x3): 4x6, with scalar factor 32 (see ATMatrices below).
constexpr float AT_4x4_3x3[] = {
 1./8, 1./4, 1./4, 1./8, 1./8, 0,
 0, -1./4, 1./4, -1./4, 1./4, 0,
 0, 1./4, 1./4, 1./2, 1./2, 0,
 0, -1./4, 1./4, -1, 1, 1./2
};

// A for F(4x4, 3x3): 6x4, with scalar factor 32.
constexpr float A_4x4_3x3[] = {
 1./8, 0, 0, 0,
 1./4, -1./4, 1./4, -1./4,
 1./4, 1./4, 1./4, 1./4,
 1./8, -1./4, 1./2, -1,
 1./8, 1./4, 1./2, 1,
 0, 0, 0, 1./2
};

// G for F(2x2, 5x5): 6x5.
constexpr float G_2x2_5x5[] = {
 1, 0, 0, 0, 0,
 1./6, -1./6, 1./6, -1./6, 1./6,
 -1./6, -1./6, -1./6, -1./6, -1./6,
-4./15, 2./15, -1./15, 1./30, -1./60,
 1./60, 1./30, 1./15, 2./15, 4./15,
 0, 0, 0, 0, 1
};

// G^T for F(2x2, 5x5): 5x6.
constexpr float GT_2x2_5x5[] = {
 1, 1./6, -1./6, -4./15, 1./60, 0,
 0, -1./6, -1./6, 2./15, 1./30, 0,
 0, 1./6, -1./6, -1./15, 1./15, 0,
 0, -1./6, -1./6, 1./30, 2./15, 0,
 0, 1./6, -1./6, -1./60, 4./15, 1
};

// B^T for F(2x2, 5x5): 6x6.
constexpr float BT_2x2_5x5[] = {
 1./8, 3./16, -1./4, -3./16, 1./8, 0,
 0, 1./8, 1./16, -5./16, 1./8, 0,
 0, -1./8, -5./16, -1./16, 1./8, 0,
 0, 1./4, -1./8, -1./4, 1./8, 0,
 0, -1./8, -1./4, 1./8, 1./4, 0,
 0, 1./8, 3./16, -1./4, -3./16, 1./8
};

// B for F(2x2, 5x5): 6x6.
constexpr float B_2x2_5x5[] = {
 1./8, 0, 0, 0, 0, 0,
 3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
 -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
 -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
 1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
 0, 0, 0, 0, 0, 1./8
};

// A^T for F(2x2, 5x5): 2x6, with scalar factor 16.
constexpr float AT_2x2_5x5[] = {
 1./2, 1, 1, 2, 1, 0,
 0, -1, 1, -1, 2, 1./2
};

// A for F(2x2, 5x5): 6x2, with scalar factor 16.
constexpr float A_2x2_5x5[] = {
 1./2, 0,
 1, -1,
 1, 1,
 2, -1,
 1, 2,
 0, 1./2
};
// clang-format on
186
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
  // `table` must point to a row-major array of `rows * cols` floats whose
  // lifetime outlives this struct (the constexpr tables above qualify).
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;   // Row-major matrix data; not owned.
  int64_t rows;         // Number of rows.
  int64_t cols;         // Number of columns.
  int64_t scalarFactor; // Factor the transformed result is multiplied by
                        // (used by the AT/A output-transform matrices; 1
                        // when no extra scaling is needed).
};
198
199/// Utility function to convert constant array to arith.constant Value.
200Value create2DTransformMatrix(OpBuilder &builder, Location loc,
201 TransformMatrix transform, Type type) {
202 ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
203
204 return builder.create<arith::ConstantOp>(
205 location: loc, args: DenseFPElementsAttr::get(
206 type: RankedTensorType::get(
207 shape: SmallVector<int64_t>{transform.rows, transform.cols}, elementType: type),
208 arg&: constVec));
209}
210
211/// Extract height x width data from 4D tensors.
212Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
213 Value loopNorFIndex, Value loopCorFIndex,
214 Value heightOffset, Value widthOffset,
215 int64_t extractHeight, int64_t extractWidth,
216 int64_t loopNorFIdx, int64_t loopCorFIdx,
217 int64_t heightIdx, int64_t widthIdx) {
218 auto sourceType = cast<ShapedType>(Val: source.getType());
219 Type elementType = sourceType.getElementType();
220 int64_t srcSize = sourceType.getRank();
221
222 auto oneIndex = builder.getIndexAttr(value: 1);
223 SmallVector<OpFoldResult> offsets;
224 offsets.resize(N: srcSize);
225 offsets[loopNorFIdx] = loopNorFIndex;
226 offsets[loopCorFIdx] = loopCorFIndex;
227 offsets[heightIdx] = heightOffset;
228 offsets[widthIdx] = widthOffset;
229 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
230 sizes[heightIdx] = builder.getIndexAttr(value: extractHeight);
231 sizes[widthIdx] = builder.getIndexAttr(value: extractWidth);
232 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
233
234 auto extractFilterType =
235 RankedTensorType::get(shape: {extractHeight, extractWidth}, elementType);
236 auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
237 location: loc, args&: extractFilterType, args&: source, args&: offsets, args&: sizes, args&: strides);
238
239 return extractFilterOp;
240}
241
242/// Extract height x width data from 6D tensors.
243Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
244 Value tileHIndex, Value tileWIndex,
245 Value loopNorFIndex, Value loopCorFIndex,
246 int64_t tileHIdx, int64_t tileWIdx,
247 int64_t loopNorFIdx, int64_t loopCorFIdx,
248 int64_t heightIdx, int64_t widthIdx) {
249 auto sourceType = cast<ShapedType>(Val: source.getType());
250 Type elementType = sourceType.getElementType();
251 auto sourceShape = sourceType.getShape();
252 int64_t srcSize = sourceType.getRank();
253 int64_t height = sourceShape[heightIdx];
254 int64_t width = sourceShape[widthIdx];
255
256 auto zeroIndex = builder.getIndexAttr(value: 0);
257 auto oneIndex = builder.getIndexAttr(value: 1);
258 SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
259 offsets.resize(N: srcSize);
260 offsets[tileHIdx] = tileHIndex;
261 offsets[tileWIdx] = tileWIndex;
262 offsets[loopNorFIdx] = loopNorFIndex;
263 offsets[loopCorFIdx] = loopCorFIndex;
264 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
265 sizes[heightIdx] = builder.getIndexAttr(value: height);
266 sizes[widthIdx] = builder.getIndexAttr(value: width);
267 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
268
269 auto extractFilterType = RankedTensorType::get(shape: {height, width}, elementType);
270 auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
271 location: loc, args&: extractFilterType, args&: source, args&: offsets, args&: sizes, args&: strides);
272
273 return extractFilterOp;
274}
275
276/// Insert transformed height x width data to 4D tensors which it is
277/// extracted from.
278Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
279 Value dest, Value loopNorFIndex, Value loopCorFIndex,
280 Value heightOffset, Value widthOffset, int64_t height,
281 int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
282 int64_t heightIdx, int64_t widthIdx) {
283 int64_t destSize = cast<ShapedType>(Val: dest.getType()).getRank();
284 auto oneIndex = builder.getIndexAttr(value: 1);
285 SmallVector<OpFoldResult> retOffsets;
286 retOffsets.resize(N: destSize);
287 retOffsets[loopNorFIdx] = loopNorFIndex;
288 retOffsets[loopCorFIdx] = loopCorFIndex;
289 retOffsets[heightIdx] = heightOffset;
290 retOffsets[widthIdx] = widthOffset;
291 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
292 retSizes[heightIdx] = builder.getIndexAttr(value: height);
293 retSizes[widthIdx] = builder.getIndexAttr(value: width);
294 SmallVector<OpFoldResult> strides(destSize, oneIndex);
295
296 auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
297 location: loc, args&: source, args&: dest, args&: retOffsets, args&: retSizes, args&: strides);
298
299 return insertSliceOp;
300}
301
302/// Insert transformed height x width data to 6D tensors which it is
303/// extracted from.
304Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
305 Value dest, Value tileHIndex, Value tileWIndex,
306 Value loopNorFIndex, Value loopCorFIndex, int64_t height,
307 int64_t width, int64_t tileHIdx, int64_t tileWIdx,
308 int64_t loopNorFIdx, int64_t loopCorFIdx,
309 int64_t heightIdx, int64_t widthIdx) {
310 int64_t destSize = cast<ShapedType>(Val: dest.getType()).getRank();
311 auto zeroIndex = builder.getIndexAttr(value: 0);
312 auto oneIndex = builder.getIndexAttr(value: 1);
313 SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
314 retOffsets.resize(N: destSize);
315 retOffsets[tileHIdx] = tileHIndex;
316 retOffsets[tileWIdx] = tileWIndex;
317 retOffsets[loopNorFIdx] = loopNorFIndex;
318 retOffsets[loopCorFIdx] = loopCorFIndex;
319 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
320 retSizes[heightIdx] = builder.getIndexAttr(value: height);
321 retSizes[widthIdx] = builder.getIndexAttr(value: width);
322 SmallVector<OpFoldResult> strides(destSize, oneIndex);
323
324 auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
325 location: loc, args&: source, args&: dest, args&: retOffsets, args&: retSizes, args&: strides);
326
327 return insertSliceOp;
328}
329
/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrix is 2-dimension. We need to extract H x W from
/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
/// After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      GMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      GTMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  // Each spatial filter dimension must either be the filter size r or 1
  // (a dimension of 1 means that side is not transformed). Otherwise the
  // transform does not apply; return a null Value to signal failure.
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  // Loop body: transform one (f, c) slice of the filter and insert the
  // result into the iteration argument (the accumulated output tensor).
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(fmr);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      // matmul accumulates into its init operand, so start from zeros.
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(fmr);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F). A non-transformed side keeps extent 1;
    // a transformed side expands to alpha = m + r - 1.
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  // Build the 2-level (F, C) loop nest, threading retValue through as the
  // iteration argument; the final result is the fully-populated tensor.
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
456
457/// This function transforms the input. The data layout of the input is NHWC.
458/// The transformation matrix is 2-dimension. We need to extract H x W from
459/// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
460/// After the transformation, we get
461///
462/// scf.for %h = 0 to tileH step 1
463/// scf.for %w = 0 to tileW step 1
464/// scf.for %n = 0 to N step 1
465/// scf.for %c = 0 to C step 1
466/// %extracted = extract %extracted<alphaH x alphaW> from
467/// %input<N x H x W x C>
468/// at [%n, (%h x m), (%w x m), %c]
469/// %ret = linalg.matmul BT, %extracted
470/// %ret = linalg.matmul %ret, B
471/// %inserted = insert %ret<alphaH x alphaW> into
472/// %output<alphaH x alphaW x tileH x tileW x N x C>
473/// at [0, 0, %h, %w, %n, %c]
474Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
475 Value retValue, WinogradConv2DFmr fmr,
476 bool leftTransform = true, bool rightTransform = true) {
477 // Map from (m, r) to BT transform matrix.
478 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
479 BTMatrices = {
480 {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
481 {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
482 {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
483 };
484
485 // Map from (m, r) to B transform matrix.
486 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
487 BMatrices = {
488 {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
489 {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
490 {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
491 };
492
493 int64_t m, r;
494 std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr);
495 auto inputType = cast<ShapedType>(Val: input.getType());
496 Type elementType = inputType.getElementType();
497 auto inputShape = inputType.getShape(); // N, H, W, C
498 int64_t inputN = inputShape[0];
499 int64_t inputC = inputShape[3];
500 auto valueType = cast<ShapedType>(Val: retValue.getType());
501 auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
502 int64_t tileH = valueShape[2];
503 int64_t tileW = valueShape[3];
504 int64_t alphaH = leftTransform ? m + r - 1 : 1;
505 int64_t alphaW = rightTransform ? m + r - 1 : 1;
506
507 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
508 ValueRange args) -> scf::ValueVector {
509 Value tileHIter = ivs[0];
510 Value tileWIter = ivs[1];
511 Value NIter = ivs[2];
512 Value CIter = ivs[3];
513
514 auto context = builder.getContext();
515
516 auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank: 1);
517 auto affineMap =
518 AffineMap::get(dimCount: 1, symbolCount: 0, results: {builder.getAffineDimExpr(position: 0) * m}, context);
519 Value heightOffset = builder.create<affine::AffineApplyOp>(
520 location: loc, args&: leftTransform ? affineMap : identityAffineMap, args&: tileHIter);
521 Value widthOffset = builder.create<affine::AffineApplyOp>(
522 location: loc, args&: rightTransform ? affineMap : identityAffineMap, args&: tileWIter);
523
524 // Extract (H, W) from (N, H, W, C).
525 auto extractInput =
526 extract2DDataFrom4D(builder, loc, source: input, loopNorFIndex: NIter, loopCorFIndex: CIter, heightOffset,
527 widthOffset, extractHeight: alphaH, extractWidth: alphaW, /*loopNorFIdx=*/0,
528 /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
529
530 int64_t retRows = 1;
531 int64_t retCols = 1;
532 Value matmulRetValue = extractInput;
533 Value zero = builder.create<arith::ConstantOp>(
534 location: loc, args: rewriter.getZeroAttr(type: elementType));
535 if (leftTransform) {
536 // Get constant transform matrix BT.
537 auto it = BTMatrices.find(Val: fmr);
538 if (it == BTMatrices.end())
539 return {};
540 const TransformMatrix &BTMatrix = it->second;
541
542 retRows = BTMatrix.rows;
543 auto matmulType = RankedTensorType::get(shape: {retRows, alphaW}, elementType);
544 auto empty =
545 builder
546 .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), args&: elementType)
547 .getResult();
548 auto init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0);
549
550 Value BT =
551 create2DTransformMatrix(builder, loc, transform: BTMatrix, type: builder.getF32Type());
552 // Multiply BT x d.
553 auto matmulOp = builder.create<linalg::MatmulOp>(
554 location: loc, args&: matmulType, args: ValueRange{BT, matmulRetValue}, args: ValueRange{init});
555 matmulRetValue = matmulOp.getResult(i: 0);
556 }
557
558 if (rightTransform) {
559 // Get constant transform matrix B.
560 auto it = BMatrices.find(Val: fmr);
561 if (it == BMatrices.end())
562 return {};
563 const TransformMatrix &BMatrix = it->second;
564
565 retCols = BMatrix.cols;
566 auto matmulType = RankedTensorType::get(shape: {retRows, retCols}, elementType);
567 auto empty =
568 builder
569 .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), args&: elementType)
570 .getResult();
571 auto init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0);
572 Value B =
573 create2DTransformMatrix(builder, loc, transform: BMatrix, type: builder.getF32Type());
574 // Multiply v = (BT x d) x B.
575 auto matmulOp = builder.create<linalg::MatmulOp>(
576 location: loc, args&: matmulType, args: ValueRange{matmulRetValue, B}, args: ValueRange{init});
577 matmulRetValue = matmulOp.getResult(i: 0);
578 }
579
580 // Insert (H, W) to (H, W, tileH, tileW, N, C).
581 auto combinedVal = insert2DDataTo6D(
582 builder, loc, source: matmulRetValue, dest: args[0], tileHIndex: tileHIter, tileWIndex: tileWIter, loopNorFIndex: NIter,
583 loopCorFIndex: CIter, height: retRows, width: retCols, tileHIdx: 2, tileWIdx: 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
584 /*heightIdx=*/0, /*widthIdx=*/1);
585
586 return {combinedVal};
587 };
588
589 auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
590 auto tileHBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: tileH);
591 auto tileWBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: tileW);
592 auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: inputN);
593 auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: inputC);
594 auto oneStep = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
595 scf::LoopNest loops = scf::buildLoopNest(
596 builder&: rewriter, loc, lbs: {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
597 ubs: {tileHBound, tileWBound, nUpperBound, cUpperBound},
598 steps: {oneStep, oneStep, oneStep, oneStep}, iterArgs: {retValue}, bodyBuilder: buildBody);
599 return loops.results[0];
600}
601
/// This function generates linalg.batch_matmul to multiply input with filter.
/// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
/// tileH x tileW x H x W data as the 1-dimensional data array. That is to
/// convert [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this
/// way, we can convert 6-dimensional inputs to 3-dimensional representation
/// that is suitable for linalg.batch_matmul.
///
/// Batched matmul will do the matrix multiply with the reduction on channel.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get return value with data layout
/// (tileH, tileW, H, W, N, F).
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply. The result element type is `outputElementType`,
  // which may differ from the operand element types (mixed-precision
  // accumulation -- presumably a wider type; verify against callers).
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  // batch_matmul accumulates into its init operand, so start from zeros.
  Value empty = rewriter
                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                             outputElementType)
                    .getResult();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(outputElementType));
  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}
677
678/// This function transforms the output. The data layout of the output is HWNF.
679/// The transformation matrix is 2-dimension. We need to extract H x W from
680/// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
681/// After the transformation, we get
682///
683/// scf.for %h = 0 to tileH step 1
684/// scf.for %w = 0 to tileW step 1
685/// scf.for %n = 0 to N step 1
686/// scf.for %f = 0 to F step 1
687/// %extracted = extract %extracted<alphaH x alphaW> from
688/// %input<alphaH x alphaW x tileH x tileW x N x F>
689/// at [0, 0, %h, %w, %n, %f]
690/// %ret = linalg.matmul AT, %extracted
691/// %ret = linalg.matmul %ret, A
692/// %inserted = insert %ret<alphaH x alphaW> into
693/// output<N x H x W x F>
694/// at [%n, (%h x m), (%w x m), %f]
695Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
696 Value output, WinogradConv2DFmr fmr,
697 bool leftTransform = true, bool rightTransform = true) {
698 // Map from (m, r) to AT transform matrix.
699 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
700 ATMatrices = {
701 {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
702 {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
703 {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
704 };
705
706 // Map from (m, r) to A transform matrix.
707 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
708 AMatrices = {
709 {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
710 {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
711 {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
712 };
713
714 int64_t m, r;
715 std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr);
716 auto valueType = cast<ShapedType>(Val: value.getType());
717 Type elementType = valueType.getElementType();
718 auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
719 int64_t valueH = valueShape[0];
720 int64_t valueW = valueShape[1];
721 int64_t valueN = valueShape[4];
722 int64_t valueF = valueShape[5];
723 int64_t alphaH = leftTransform ? m + r - 1 : 1;
724 int64_t alphaW = rightTransform ? m + r - 1 : 1;
725
726 if (valueH != alphaH && valueH != 1)
727 return Value();
728 if (valueW != alphaW && valueW != 1)
729 return Value();
730
731 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
732 ValueRange args) -> scf::ValueVector {
733 auto context = builder.getContext();
734 Value tileHIter = ivs[0];
735 Value tileWIter = ivs[1];
736 Value NIter = ivs[2];
737 Value FIter = ivs[3];
738
739 // Extract (H, W) from (H, W, tileH, tileW, N, F).
740 auto extractValue =
741 extract2DDataFrom6D(builder, loc, source: value, tileHIndex: tileHIter, tileWIndex: tileWIter, loopNorFIndex: NIter,
742 loopCorFIndex: FIter, tileHIdx: 2, tileWIdx: 3, /*loopNorFIdx=*/4,
743 /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
744
745 const TransformMatrix &AMatrix = AMatrices.at(Val: fmr);
746 const TransformMatrix &ATMatrix = ATMatrices.at(Val: fmr);
747 int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
748 (leftTransform ? ATMatrix.scalarFactor : 1);
749 int64_t retCols = rightTransform ? AMatrix.cols : 1;
750 int64_t retRows = leftTransform ? ATMatrix.rows : 1;
751
752 Value matmulRetValue = extractValue;
753 Value zero = builder.create<arith::ConstantOp>(
754 location: loc, args: rewriter.getZeroAttr(type: elementType));
755
756 auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank: 1);
757 auto affineMap =
758 AffineMap::get(dimCount: 1, symbolCount: 0, results: {builder.getAffineDimExpr(position: 0) * m}, context);
759 Value heightOffset = builder.create<affine::AffineApplyOp>(
760 location: loc, args&: leftTransform ? affineMap : identityAffineMap, args&: tileHIter);
761 Value widthOffset = builder.create<affine::AffineApplyOp>(
762 location: loc, args&: rightTransform ? affineMap : identityAffineMap, args&: tileWIter);
763
764 Value outInitVal =
765 extract2DDataFrom4D(builder, loc, source: args[0], loopNorFIndex: NIter, loopCorFIndex: FIter, heightOffset,
766 widthOffset, extractHeight: retRows, extractWidth: retCols,
767 /*loopNorFIdx=*/0,
768 /*loopCorFIdx=*/3, /*heightIdx=*/1,
769 /*widthIdx=*/2);
770 if (leftTransform) {
771 auto matmulType = RankedTensorType::get(shape: {retRows, valueW}, elementType);
772 Value init = outInitVal;
773 if (rightTransform || scalarFactor != 1) {
774 auto empty = builder
775 .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(),
776 args&: elementType)
777 .getResult();
778 init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0);
779 }
780
781 Value AT = create2DTransformMatrix(builder, loc, transform: ATMatrix, type: elementType);
782 // Multiply AT x m.
783 auto matmulOp = builder.create<linalg::MatmulOp>(
784 location: loc, args&: matmulType, args: ValueRange{AT, matmulRetValue}, args: ValueRange{init});
785 matmulRetValue = matmulOp.getResult(i: 0);
786 }
787
788 if (rightTransform) {
789 auto matmulType =
790 RankedTensorType::get(shape: {retRows, AMatrix.cols}, elementType);
791 Value init = outInitVal;
792 if (scalarFactor != 1) {
793 auto empty = builder
794 .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(),
795 args&: elementType)
796 .getResult();
797 init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0);
798 }
799
800 Value A = create2DTransformMatrix(builder, loc, transform: AMatrix, type: elementType);
801 // Multiply y = (AT x m) x A.
802 auto matmulOp = builder.create<linalg::MatmulOp>(
803 location: loc, args&: matmulType, args: ValueRange{matmulRetValue, A}, args: ValueRange{init});
804 matmulRetValue = matmulOp.getResult(i: 0);
805 }
806
807 if (scalarFactor != 1) {
808 // Multiply by scalar factor and add outInitVal.
809 Value scalarFactorValue = builder.create<arith::ConstantOp>(
810 location: loc, args: FloatAttr::get(type: elementType, value: scalarFactor));
811 auto matmulType = RankedTensorType::get(shape: {retRows, retCols}, elementType);
812 auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank: 2);
813 SmallVector<AffineMap> affineMaps = {
814 AffineMap::get(dimCount: 2, symbolCount: 0, context), identityAffineMap, identityAffineMap};
815
816 matmulRetValue =
817 rewriter
818 .create<linalg::GenericOp>(
819 location: loc, args&: matmulType,
820 args: ValueRange{scalarFactorValue, matmulRetValue},
821 args: ValueRange{outInitVal}, args&: affineMaps,
822 args: llvm::ArrayRef<utils::IteratorType>{
823 utils::IteratorType::parallel,
824 utils::IteratorType::parallel},
825 args: [&](OpBuilder &nestedBuilder, Location nestedLoc,
826 ValueRange args) {
827 auto mulf = nestedBuilder.create<arith::MulFOp>(
828 location: nestedLoc, args: args[0], args: args[1]);
829 auto addf = nestedBuilder.create<arith::AddFOp>(
830 location: nestedLoc, args: mulf.getResult(), args: args[2]);
831 nestedBuilder.create<linalg::YieldOp>(location: nestedLoc,
832 args: addf.getResult());
833 })
834 .getResult(i: 0);
835 }
836
837 // Insert (H, W) to (N, H, W, F).
838 Value combinedVal =
839 insert2DDataTo4D(builder, loc, source: matmulRetValue, dest: args[0], loopNorFIndex: NIter, loopCorFIndex: FIter,
840 heightOffset, widthOffset, height: retRows, width: retCols,
841 /*loopNorFIdx=*/0,
842 /*loopCorFIdx=*/3, /*heightIdx=*/1,
843 /*widthIdx=*/2);
844
845 return {combinedVal};
846 };
847
848 int64_t tilwH = valueShape[2];
849 int64_t tileW = valueShape[3];
850 auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
851 auto tileHBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: tilwH);
852 auto tileWBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: tileW);
853 auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: valueN);
854 auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: valueF);
855 auto oneStep = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
856 scf::LoopNest loops = scf::buildLoopNest(
857 builder&: rewriter, loc, lbs: {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
858 ubs: {tileHBound, tileWBound, nUpperBound, fUpperBound},
859 steps: {oneStep, oneStep, oneStep, oneStep}, iterArgs: {output}, bodyBuilder: buildBody);
860 return loops.results[0];
861}
862
863/// Create an empty tensor with alignedType and insert the value into the
864/// created empty tensor with aligned size.
865static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
866 Value value, ArrayRef<int64_t> alignedShape) {
867 auto valueType = cast<ShapedType>(Val: value.getType());
868 Type elementType = valueType.getElementType();
869 auto alignedType = RankedTensorType::get(shape: alignedShape, elementType);
870 Value padValue = rewriter.create<arith::ConstantOp>(
871 location: loc, args&: elementType, args: rewriter.getZeroAttr(type: elementType));
872
873 return linalg::makeComposedPadHighOp(b&: rewriter, loc, type: alignedType, source: value,
874 padding: padValue, nofold: false);
875}
876
877/// Extract sub-tensor with extractedType from value.
878static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
879 Value value,
880 RankedTensorType extractedType) {
881 OpFoldResult zeroIndex = rewriter.getIndexAttr(value: 0);
882 OpFoldResult oneIndex = rewriter.getIndexAttr(value: 1);
883 SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
884 SmallVector<OpFoldResult, 4> strides(4, oneIndex);
885
886 ArrayRef<int64_t> extractedShape = extractedType.getShape();
887 SmallVector<OpFoldResult> sizes =
888 getAsOpFoldResult(arrayAttr: rewriter.getI64ArrayAttr(values: extractedShape));
889
890 return rewriter.create<tensor::ExtractSliceOp>(location: loc, args&: extractedType, args&: value,
891 args&: offsets, args&: sizes, args&: strides);
892}
893
894/// Utility function to check all values in the attribute are 1.
895static bool hasAllOneValues(DenseIntElementsAttr attr) {
896 return llvm::all_of(
897 Range&: attr, P: [](const APInt &element) { return element.getSExtValue() == 1; });
898}
899
/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
///
/// On success the convolution is replaced by:
///   1. linalg.winograd_filter_transform on the filter,
///   2. linalg.winograd_input_transform on the (possibly padded) input,
///   3. a multiplication of the two transformed tensors (matrixMultiply),
///   4. linalg.winograd_output_transform into the (possibly padded) output,
/// extracting the valid region afterwards when output padding was introduced.
/// Returns the defining op of the replacement value, or a match failure when
/// the convolution does not satisfy the Winograd preconditions checked below.
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     WinogradConv2DFmr fmr) {
  // Only the tensor-semantics form is supported.
  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        arg&: convOp, msg: "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(Val: input.getType());
  auto filterType = cast<ShapedType>(Val: filter.getType());
  auto outputType = cast<ShapedType>(Val: output.getType());

  // Static shapes are required to compute tile counts and aligned sizes.
  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(arg&: convOp,
                                       msg: "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        arg&: convOp, msg: "expected a static shape for the filter");

  // The minimal filtering algorithm assumes unit dilations and unit strides.
  if (!hasAllOneValues(attr: convOp.getDilations()))
    return rewriter.notifyMatchFailure(arg&: convOp,
                                       msg: "expected all ones for dilations");

  if (!hasAllOneValues(attr: convOp.getStrides()))
    return rewriter.notifyMatchFailure(arg&: convOp, msg: "expected all ones for strides");

  // Unpack the FHWC filter, NHWC input, and NHWC output extents.
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  int64_t m, r;
  std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr);
  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        arg&: convOp, msg: "only support filter (r x r), (r x 1) or (1 x r)");

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = filterW != 1;
  // Collapse m/r to 1 on a non-transformed side so that the tile and padding
  // arithmetic below handles the 1-D variants uniformly.
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  // alpha = m + r - 1 is the transformed tile size per dimension.
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(Numerator: outputH, Denominator: heightM);
  int64_t tileW = llvm::divideCeilSigned(Numerator: outputW, Denominator: widthM);
  auto retType = RankedTensorType::get(shape: {alphaH, alphaW, filterC, filterF},
                                 elementType: filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(location: loc, args: retType.getShape(),
                                                    args&: filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      location: loc, args&: retType, args&: filter, args&: retValue, args&: fmr);

  // --- Create operation for input transform ---

  // When input size - (r - 1) is not aligned with output tile size, we need to
  // pad the input data to create the full tiles as tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, value: input,
                               alignedShape: {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      shape: {alphaH, alphaW, tileH, tileW, inputN, inputC}, elementType: inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(location: loc, args: retType.getShape(),
                                              args&: inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      location: loc, args&: retType, args&: input, args&: retValue, args&: fmr);

  // Multiply the transformed filter and input in the Winograd domain.
  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When output size is not aligned with output tile size, we need to pad the
  // output buffer to insert the full tiles after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        shape: {outputN, alignedOutputH, alignedOutputW, outputF}, elementType: outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, value: output, alignedShape: alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      location: loc, args&: outputType, args&: matmulRet, args&: output, args&: fmr);

  // When output size is not aligned with output tile size, extract the
  // value from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, value: transformedOutput,
        extractedType: RankedTensorType::get(shape: {outputN, outputH, outputW, outputF},
                              elementType: outputElementType));
  }

  rewriter.replaceOp(op: convOp, newValues: transformedOutput);

  return transformedOutput.getDefiningOp();
}
1042
1043/// A helper function to decompose linalg.winograd_filter_transform.
1044FailureOr<Operation *>
1045decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
1046 linalg::WinogradFilterTransformOp op) {
1047 Location loc = op.getLoc();
1048 Value filter = op.getFilter();
1049 auto filterType = cast<ShapedType>(Val: filter.getType());
1050 auto filterShape = filterType.getShape();
1051 int64_t filterH = filterShape[1];
1052 int64_t filterW = filterShape[2];
1053
1054 // For F(m x 1, r x 1), we only need to do left side transform.
1055 bool leftTransform = filterH != 1;
1056 // For F(1 x m, 1 x r), we only need to do right side transform.
1057 bool rightTransform = filterW != 1;
1058 Value transformedFilter =
1059 filterTransform(rewriter, loc, filter, retValue: op.getOutput(), fmr: op.getFmr(),
1060 leftTransform, rightTransform);
1061 if (!transformedFilter)
1062 return failure();
1063
1064 rewriter.replaceOp(op, newValues: transformedFilter);
1065
1066 return transformedFilter.getDefiningOp();
1067}
1068
1069/// A helper function to decompose linalg.winograd_input_transform.
1070FailureOr<Operation *>
1071decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
1072 linalg::WinogradInputTransformOp op) {
1073 Location loc = op.getLoc();
1074 Value output = op.getOutput();
1075 auto outputType = cast<ShapedType>(Val: output.getType());
1076 auto outputShape = outputType.getShape();
1077
1078 int64_t outputH = outputShape[0];
1079 int64_t outputW = outputShape[1];
1080
1081 // For F(m x 1, r x 1), we only need to do left side transform.
1082 bool leftTransform = outputH != 1;
1083 // For F(1 x m, 1 x r), we only need to do right side transform.
1084 bool rightTransform = outputW != 1;
1085 Value transformedInput =
1086 inputTransform(rewriter, loc, input: op.getInput(), retValue: op.getOutput(), fmr: op.getFmr(),
1087 leftTransform, rightTransform);
1088 if (!transformedInput)
1089 return failure();
1090
1091 rewriter.replaceOp(op, newValues: transformedInput);
1092
1093 return transformedInput.getDefiningOp();
1094}
1095
1096/// A helper function to decompose linalg.winograd_output_transform.
1097FailureOr<Operation *>
1098decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
1099 linalg::WinogradOutputTransformOp op) {
1100 Location loc = op.getLoc();
1101 Value value = op.getValue();
1102 auto valueType = cast<ShapedType>(Val: value.getType());
1103 auto valueShape = valueType.getShape();
1104 int64_t valueH = valueShape[0];
1105 int64_t valueW = valueShape[1];
1106
1107 // For F(m x 1, r x 1), we only need to do left side transform.
1108 bool leftTransform = valueH != 1;
1109 // For F(1 x m, 1 x r), we only need to do right side transform.
1110 bool rightTransform = valueW != 1;
1111 Value transformedOutput =
1112 outputTransform(rewriter, loc, value, output: op.getOutput(), fmr: op.getFmr(),
1113 leftTransform, rightTransform);
1114 if (!transformedOutput)
1115 return failure();
1116
1117 rewriter.replaceOp(op, newValues: transformedOutput);
1118
1119 return transformedOutput.getDefiningOp();
1120}
1121
1122/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
1123class DecomposeWinogradFilterTransform final
1124 : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
1125public:
1126 using OpRewritePattern::OpRewritePattern;
1127
1128 LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
1129 PatternRewriter &rewriter) const override {
1130 return decomposeWinogradFilterTransformHelper(rewriter, op);
1131 }
1132};
1133
1134/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
1135class DecomposeWinogradInputTransform final
1136 : public OpRewritePattern<linalg::WinogradInputTransformOp> {
1137public:
1138 using OpRewritePattern::OpRewritePattern;
1139
1140 LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
1141 PatternRewriter &rewriter) const override {
1142 return decomposeWinogradInputTransformHelper(rewriter, op);
1143 }
1144};
1145
1146/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
1147class DecomposeWinogradOutputTransform final
1148 : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
1149public:
1150 using OpRewritePattern::OpRewritePattern;
1151
1152 LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
1153 PatternRewriter &rewriter) const override {
1154 return decomposeWinogradOutputTransformHelper(rewriter, op);
1155 }
1156};
1157
1158/// A rewrite pattern for Winograd Conv2D algorithm.
1159class WinogradConv2DNhwcFhwc final
1160 : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
1161public:
1162 using OpRewritePattern::OpRewritePattern;
1163 WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, WinogradConv2DFmr fmr)
1164 : OpRewritePattern(context), fmr(fmr) {}
1165
1166 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
1167 PatternRewriter &rewriter) const override {
1168 if (failed(Result: winogradConv2DHelper(rewriter, convOp, fmr)))
1169 return failure();
1170
1171 return success();
1172 }
1173
1174private:
1175 WinogradConv2DFmr fmr;
1176};
1177
1178} // end anonymous namespace
1179
1180//===----------------------------------------------------------------------===//
1181FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1182 linalg::Conv2DNhwcFhwcOp op,
1183 linalg::WinogradConv2DFmr fmr) {
1184 return winogradConv2DHelper(rewriter, convOp: op, fmr);
1185}
1186
1187FailureOr<Operation *>
1188decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
1189 linalg::WinogradFilterTransformOp op) {
1190 return decomposeWinogradFilterTransformHelper(rewriter, op);
1191}
1192
1193FailureOr<Operation *>
1194decomposeWinogradInputTransformOp(RewriterBase &rewriter,
1195 linalg::WinogradInputTransformOp op) {
1196 return decomposeWinogradInputTransformHelper(rewriter, op);
1197}
1198
1199FailureOr<Operation *>
1200decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
1201 linalg::WinogradOutputTransformOp op) {
1202 return decomposeWinogradOutputTransformHelper(rewriter, op);
1203}
1204
1205void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
1206 WinogradConv2DFmr fmr) {
1207 MLIRContext *context = patterns.getContext();
1208 // TODO: Support more Conv2D data layout, e.g., conv_2d_nchw_fchw
1209 patterns.insert<WinogradConv2DNhwcFhwc>(arg&: context, args&: fmr);
1210}
1211
1212void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
1213 MLIRContext *context = patterns.getContext();
1214 patterns
1215 .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
1216 DecomposeWinogradOutputTransform>(arg&: context);
1217}
1218
1219} // end namespace linalg
1220} // end namespace mlir
1221

source code of mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp