| 1 | //===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Implement Winograd Conv2D algorithm. The implementation is based on the |
| 10 | // paper: Fast Algorithms for Convolutional Neural Networks |
| 11 | // (https://arxiv.org/abs/1509.09308) |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 18 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 19 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 21 | #include "llvm/Support/MathExtras.h" |
| 22 | |
| 23 | namespace mlir { |
| 24 | namespace linalg { |
| 25 | |
| 26 | namespace { |
| 27 | |
| 28 | // clang-format off |
| 29 | /// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its |
| 30 | /// result. The formula of minimal 2D filtering algorithm F(m x m, r x r), |
| 31 | /// m is the output dimension and r is the filter dimension, is |
| 32 | /// |
| 33 | /// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A |
| 34 | /// |
| 35 | /// g is filter and d is input data. We need to prepare 6 constant |
| 36 | /// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula. |
| 37 | /// |
| 38 | /// The following tables define these constant transformation matrices for |
| 39 | /// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5) |
| 40 | /// |
| 41 | /// To add more transformation matrices, we need to add the following |
| 42 | /// items: |
| 43 | /// 1. Add the constant transformation matrix to the corresponding |
| 44 | /// G, GT, BT, B, AT, or A array. |
| 45 | /// 2. Add the corresponding TransformMatrix to the GMatrices, GTMatrices, |
| 46 | /// BTMatrices, BMatrices, ATMatrices, or AMatrices map. |
/// 3. Add an enum value F_m_r to the WinogradConv2DFmr enum.
| 48 | /// |
| 49 | constexpr float G_2x2_3x3[] = { |
| 50 | -1, 0, 0, |
| 51 | 1./2, -1./2, 1./2, |
| 52 | 1./2, 1./2, 1./2, |
| 53 | 0, 0, 1 |
| 54 | }; |
| 55 | |
| 56 | constexpr float GT_2x2_3x3[] = { |
| 57 | -1, 1./2, 1./2, 0, |
| 58 | 0, -1./2, 1./2, 0, |
| 59 | 0, 1./2, 1./2, 1 |
| 60 | }; |
| 61 | |
| 62 | constexpr float BT_2x2_3x3[] = { |
| 63 | -1, 0, 1, 0, |
| 64 | 0, -1, 1, 0, |
| 65 | 0, 1, 1, 0, |
| 66 | 0, -1, 0, 1 |
| 67 | }; |
| 68 | |
| 69 | constexpr float B_2x2_3x3[] = { |
| 70 | -1, 0, 0, 0, |
| 71 | 0, -1, 1, -1, |
| 72 | 1, 1, 1, 0, |
| 73 | 0, 0, 0, 1 |
| 74 | }; |
| 75 | |
| 76 | constexpr float AT_2x2_3x3[] = { |
| 77 | 1, 1, 1, 0, |
| 78 | 0, -1, 1, 1 |
| 79 | }; |
| 80 | |
| 81 | constexpr float A_2x2_3x3[] = { |
| 82 | 1, 0, |
| 83 | 1, -1, |
| 84 | 1, 1, |
| 85 | 0, 1 |
| 86 | }; |
| 87 | |
| 88 | constexpr float G_4x4_3x3[] = { |
| 89 | 1, 0, 0, |
| 90 | -1./3, 1./3, -1./3, |
| 91 | -1./3, -1./3, -1./3, |
| 92 | 1./12, -1./6, 1./3, |
| 93 | 1./12, 1./6, 1./3, |
| 94 | 0, 0, 1 |
| 95 | }; |
| 96 | |
| 97 | constexpr float GT_4x4_3x3[] = { |
| 98 | 1, -1./3, -1./3, 1./12, 1./12, 0, |
| 99 | 0, 1./3, -1./3, -1./6, 1./6, 0, |
| 100 | 0, -1./3, -1./3, 1./3, 1./3, 1 |
| 101 | }; |
| 102 | |
| 103 | constexpr float BT_4x4_3x3[] = { |
| 104 | 1./4, 0, -5./16, 0, 1./16, 0, |
| 105 | 0, 1./4, -1./4, -1./16, 1./16, 0, |
| 106 | 0, -1./4, -1./4, 1./16, 1./16, 0, |
| 107 | 0, 1./4, -1./8, -1./4, 1./8, 0, |
| 108 | 0, -1./4, -1./8, 1./4, 1./8, 0, |
| 109 | 0, 1./4, 0, -5./16, 0, 1./16 |
| 110 | }; |
| 111 | |
| 112 | constexpr float B_4x4_3x3[] = { |
| 113 | 1./4, 0, 0, 0, 0, 0, |
| 114 | 0, 1./4, -1./4, 1./4, -1./4, 1./4, |
| 115 | -5./16, -1./4, -1./4, -1./8, -1./8, 0, |
| 116 | 0, -1./16, 1./16, -1./4, 1./4, -5./16, |
| 117 | 1./16, 1./16, 1./16, 1./8, 1./8, 0, |
| 118 | 0, 0, 0, 0, 0, 1./16 |
| 119 | }; |
| 120 | |
| 121 | constexpr float AT_4x4_3x3[] = { |
| 122 | 1./8, 1./4, 1./4, 1./8, 1./8, 0, |
| 123 | 0, -1./4, 1./4, -1./4, 1./4, 0, |
| 124 | 0, 1./4, 1./4, 1./2, 1./2, 0, |
| 125 | 0, -1./4, 1./4, -1, 1, 1./2 |
| 126 | }; |
| 127 | |
| 128 | constexpr float A_4x4_3x3[] = { |
| 129 | 1./8, 0, 0, 0, |
| 130 | 1./4, -1./4, 1./4, -1./4, |
| 131 | 1./4, 1./4, 1./4, 1./4, |
| 132 | 1./8, -1./4, 1./2, -1, |
| 133 | 1./8, 1./4, 1./2, 1, |
| 134 | 0, 0, 0, 1./2 |
| 135 | }; |
| 136 | |
| 137 | constexpr float G_2x2_5x5[] = { |
| 138 | 1, 0, 0, 0, 0, |
| 139 | 1./6, -1./6, 1./6, -1./6, 1./6, |
| 140 | -1./6, -1./6, -1./6, -1./6, -1./6, |
| 141 | -4./15, 2./15, -1./15, 1./30, -1./60, |
| 142 | 1./60, 1./30, 1./15, 2./15, 4./15, |
| 143 | 0, 0, 0, 0, 1 |
| 144 | }; |
| 145 | |
| 146 | constexpr float GT_2x2_5x5[] = { |
| 147 | 1, 1./6, -1./6, -4./15, 1./60, 0, |
| 148 | 0, -1./6, -1./6, 2./15, 1./30, 0, |
| 149 | 0, 1./6, -1./6, -1./15, 1./15, 0, |
| 150 | 0, -1./6, -1./6, 1./30, 2./15, 0, |
| 151 | 0, 1./6, -1./6, -1./60, 4./15, 1 |
| 152 | }; |
| 153 | |
| 154 | constexpr float BT_2x2_5x5[] = { |
| 155 | 1./8, 3./16, -1./4, -3./16, 1./8, 0, |
| 156 | 0, 1./8, 1./16, -5./16, 1./8, 0, |
| 157 | 0, -1./8, -5./16, -1./16, 1./8, 0, |
| 158 | 0, 1./4, -1./8, -1./4, 1./8, 0, |
| 159 | 0, -1./8, -1./4, 1./8, 1./4, 0, |
| 160 | 0, 1./8, 3./16, -1./4, -3./16, 1./8 |
| 161 | }; |
| 162 | |
| 163 | constexpr float B_2x2_5x5[] = { |
| 164 | 1./8, 0, 0, 0, 0, 0, |
| 165 | 3./16, 1./8, -1./8, 1./4, -1./8, 1./8, |
| 166 | -1./4, 1./16, -5./16, -1./8, -1./4, 3./16, |
| 167 | -3./16, -5./16, -1./16, -1./4, 1./8, -1./4, |
| 168 | 1./8, 1./8, 1./8, 1./8, 1./4, -3./16, |
| 169 | 0, 0, 0, 0, 0, 1./8 |
| 170 | }; |
| 171 | |
| 172 | constexpr float AT_2x2_5x5[] = { |
| 173 | 1./2, 1, 1, 2, 1, 0, |
| 174 | 0, -1, 1, -1, 2, 1./2 |
| 175 | }; |
| 176 | |
| 177 | constexpr float A_2x2_5x5[] = { |
| 178 | 1./2, 0, |
| 179 | 1, -1, |
| 180 | 1, 1, |
| 181 | 2, -1, |
| 182 | 1, 2, |
| 183 | 0, 1./2 |
| 184 | }; |
| 185 | // clang-format on |
| 186 | |
| 187 | /// Structure to keep information of constant transform matrices. |
| 188 | struct TransformMatrix { |
| 189 | TransformMatrix(const float *table, int64_t rows, int64_t cols, |
| 190 | int64_t scalarFactor = 1) |
| 191 | : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {} |
| 192 | |
| 193 | const float *table; |
| 194 | int64_t rows; |
| 195 | int64_t cols; |
| 196 | int64_t scalarFactor; |
| 197 | }; |
| 198 | |
| 199 | /// Utility function to convert constant array to arith.constant Value. |
| 200 | Value create2DTransformMatrix(OpBuilder &builder, Location loc, |
| 201 | TransformMatrix transform, Type type) { |
| 202 | ArrayRef<float> constVec(transform.table, transform.rows * transform.cols); |
| 203 | |
| 204 | return builder.create<arith::ConstantOp>( |
| 205 | location: loc, args: DenseFPElementsAttr::get( |
| 206 | type: RankedTensorType::get( |
| 207 | shape: SmallVector<int64_t>{transform.rows, transform.cols}, elementType: type), |
| 208 | arg&: constVec)); |
| 209 | } |
| 210 | |
| 211 | /// Extract height x width data from 4D tensors. |
| 212 | Value (OpBuilder &builder, Location loc, Value source, |
| 213 | Value loopNorFIndex, Value loopCorFIndex, |
| 214 | Value heightOffset, Value widthOffset, |
| 215 | int64_t , int64_t , |
| 216 | int64_t loopNorFIdx, int64_t loopCorFIdx, |
| 217 | int64_t heightIdx, int64_t widthIdx) { |
| 218 | auto sourceType = cast<ShapedType>(Val: source.getType()); |
| 219 | Type elementType = sourceType.getElementType(); |
| 220 | int64_t srcSize = sourceType.getRank(); |
| 221 | |
| 222 | auto oneIndex = builder.getIndexAttr(value: 1); |
| 223 | SmallVector<OpFoldResult> offsets; |
| 224 | offsets.resize(N: srcSize); |
| 225 | offsets[loopNorFIdx] = loopNorFIndex; |
| 226 | offsets[loopCorFIdx] = loopCorFIndex; |
| 227 | offsets[heightIdx] = heightOffset; |
| 228 | offsets[widthIdx] = widthOffset; |
| 229 | SmallVector<OpFoldResult> sizes(srcSize, oneIndex); |
| 230 | sizes[heightIdx] = builder.getIndexAttr(value: extractHeight); |
| 231 | sizes[widthIdx] = builder.getIndexAttr(value: extractWidth); |
| 232 | SmallVector<OpFoldResult> strides(srcSize, oneIndex); |
| 233 | |
| 234 | auto = |
| 235 | RankedTensorType::get(shape: {extractHeight, extractWidth}, elementType); |
| 236 | auto = builder.create<tensor::ExtractSliceOp>( |
| 237 | location: loc, args&: extractFilterType, args&: source, args&: offsets, args&: sizes, args&: strides); |
| 238 | |
| 239 | return extractFilterOp; |
| 240 | } |
| 241 | |
| 242 | /// Extract height x width data from 6D tensors. |
| 243 | Value (OpBuilder &builder, Location loc, Value source, |
| 244 | Value tileHIndex, Value tileWIndex, |
| 245 | Value loopNorFIndex, Value loopCorFIndex, |
| 246 | int64_t tileHIdx, int64_t tileWIdx, |
| 247 | int64_t loopNorFIdx, int64_t loopCorFIdx, |
| 248 | int64_t heightIdx, int64_t widthIdx) { |
| 249 | auto sourceType = cast<ShapedType>(Val: source.getType()); |
| 250 | Type elementType = sourceType.getElementType(); |
| 251 | auto sourceShape = sourceType.getShape(); |
| 252 | int64_t srcSize = sourceType.getRank(); |
| 253 | int64_t height = sourceShape[heightIdx]; |
| 254 | int64_t width = sourceShape[widthIdx]; |
| 255 | |
| 256 | auto zeroIndex = builder.getIndexAttr(value: 0); |
| 257 | auto oneIndex = builder.getIndexAttr(value: 1); |
| 258 | SmallVector<OpFoldResult> offsets(srcSize, zeroIndex); |
| 259 | offsets.resize(N: srcSize); |
| 260 | offsets[tileHIdx] = tileHIndex; |
| 261 | offsets[tileWIdx] = tileWIndex; |
| 262 | offsets[loopNorFIdx] = loopNorFIndex; |
| 263 | offsets[loopCorFIdx] = loopCorFIndex; |
| 264 | SmallVector<OpFoldResult> sizes(srcSize, oneIndex); |
| 265 | sizes[heightIdx] = builder.getIndexAttr(value: height); |
| 266 | sizes[widthIdx] = builder.getIndexAttr(value: width); |
| 267 | SmallVector<OpFoldResult> strides(srcSize, oneIndex); |
| 268 | |
| 269 | auto = RankedTensorType::get(shape: {height, width}, elementType); |
| 270 | auto = builder.create<tensor::ExtractSliceOp>( |
| 271 | location: loc, args&: extractFilterType, args&: source, args&: offsets, args&: sizes, args&: strides); |
| 272 | |
| 273 | return extractFilterOp; |
| 274 | } |
| 275 | |
| 276 | /// Insert transformed height x width data to 4D tensors which it is |
| 277 | /// extracted from. |
| 278 | Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source, |
| 279 | Value dest, Value loopNorFIndex, Value loopCorFIndex, |
| 280 | Value heightOffset, Value widthOffset, int64_t height, |
| 281 | int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx, |
| 282 | int64_t heightIdx, int64_t widthIdx) { |
| 283 | int64_t destSize = cast<ShapedType>(Val: dest.getType()).getRank(); |
| 284 | auto oneIndex = builder.getIndexAttr(value: 1); |
| 285 | SmallVector<OpFoldResult> retOffsets; |
| 286 | retOffsets.resize(N: destSize); |
| 287 | retOffsets[loopNorFIdx] = loopNorFIndex; |
| 288 | retOffsets[loopCorFIdx] = loopCorFIndex; |
| 289 | retOffsets[heightIdx] = heightOffset; |
| 290 | retOffsets[widthIdx] = widthOffset; |
| 291 | SmallVector<OpFoldResult> retSizes(destSize, oneIndex); |
| 292 | retSizes[heightIdx] = builder.getIndexAttr(value: height); |
| 293 | retSizes[widthIdx] = builder.getIndexAttr(value: width); |
| 294 | SmallVector<OpFoldResult> strides(destSize, oneIndex); |
| 295 | |
| 296 | auto insertSliceOp = builder.create<tensor::InsertSliceOp>( |
| 297 | location: loc, args&: source, args&: dest, args&: retOffsets, args&: retSizes, args&: strides); |
| 298 | |
| 299 | return insertSliceOp; |
| 300 | } |
| 301 | |
| 302 | /// Insert transformed height x width data to 6D tensors which it is |
| 303 | /// extracted from. |
| 304 | Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source, |
| 305 | Value dest, Value tileHIndex, Value tileWIndex, |
| 306 | Value loopNorFIndex, Value loopCorFIndex, int64_t height, |
| 307 | int64_t width, int64_t tileHIdx, int64_t tileWIdx, |
| 308 | int64_t loopNorFIdx, int64_t loopCorFIdx, |
| 309 | int64_t heightIdx, int64_t widthIdx) { |
| 310 | int64_t destSize = cast<ShapedType>(Val: dest.getType()).getRank(); |
| 311 | auto zeroIndex = builder.getIndexAttr(value: 0); |
| 312 | auto oneIndex = builder.getIndexAttr(value: 1); |
| 313 | SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex); |
| 314 | retOffsets.resize(N: destSize); |
| 315 | retOffsets[tileHIdx] = tileHIndex; |
| 316 | retOffsets[tileWIdx] = tileWIndex; |
| 317 | retOffsets[loopNorFIdx] = loopNorFIndex; |
| 318 | retOffsets[loopCorFIdx] = loopCorFIndex; |
| 319 | SmallVector<OpFoldResult> retSizes(destSize, oneIndex); |
| 320 | retSizes[heightIdx] = builder.getIndexAttr(value: height); |
| 321 | retSizes[widthIdx] = builder.getIndexAttr(value: width); |
| 322 | SmallVector<OpFoldResult> strides(destSize, oneIndex); |
| 323 | |
| 324 | auto insertSliceOp = builder.create<tensor::InsertSliceOp>( |
| 325 | location: loc, args&: source, args&: dest, args&: retOffsets, args&: retSizes, args&: strides); |
| 326 | |
| 327 | return insertSliceOp; |
| 328 | } |
| 329 | |
| 330 | /// This function transforms the filter. The data layout of the filter is FHWC. |
| 331 | /// The transformation matrix is 2-dimension. We need to extract H x W from |
| 332 | /// FHWC first. We need to generate 2 levels of loops to iterate on F and C. |
| 333 | /// After the transformation, we get |
| 334 | /// |
| 335 | /// scf.for %f = lo_f to hi_f step 1 |
| 336 | /// scf.for %c = lo_c to hi_c step 1 |
| 337 | /// %extracted = extract filter<h x w> from filter<f x h x w x c> |
| 338 | /// %ret = linalg.matmul G, %extracted |
| 339 | /// %ret = linalg.matmul %ret, GT |
| 340 | /// %inserted = insert %ret into filter<h x w x c x f> |
| 341 | Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, |
| 342 | Value retValue, WinogradConv2DFmr fmr, |
| 343 | bool leftTransform = true, bool rightTransform = true) { |
| 344 | // Map from (m, r) to G transform matrix. |
| 345 | static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix> |
| 346 | GMatrices = { |
| 347 | {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, |
| 348 | {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, |
| 349 | {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, |
| 350 | }; |
| 351 | |
| 352 | // Map from (m, r) to GT transform matrix. |
| 353 | static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix> |
| 354 | GTMatrices = { |
| 355 | {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, |
| 356 | {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, |
| 357 | {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, |
| 358 | }; |
| 359 | |
| 360 | auto filterType = cast<ShapedType>(Val: filter.getType()); |
| 361 | Type elementType = filterType.getElementType(); |
| 362 | auto filterShape = filterType.getShape(); // F, H, W, C |
| 363 | int64_t filterF = filterShape[0]; |
| 364 | int64_t filterH = filterShape[1]; |
| 365 | int64_t filterW = filterShape[2]; |
| 366 | int64_t filterC = filterShape[3]; |
| 367 | |
| 368 | int64_t m, r; |
| 369 | std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr); |
| 370 | if (filterH != r && filterH != 1) |
| 371 | return Value(); |
| 372 | if (filterW != r && filterW != 1) |
| 373 | return Value(); |
| 374 | |
| 375 | Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 376 | auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, |
| 377 | ValueRange args) -> scf::ValueVector { |
| 378 | Value FIter = ivs[0]; |
| 379 | Value CIter = ivs[1]; |
| 380 | |
| 381 | // Extract (H, W) from (F, H, W, C). |
| 382 | auto = |
| 383 | extract2DDataFrom4D(builder, loc, source: filter, loopNorFIndex: FIter, loopCorFIndex: CIter, heightOffset: zeroIdx, |
| 384 | widthOffset: zeroIdx, extractHeight: filterH, extractWidth: filterW, /*loopNorFIdx=*/0, |
| 385 | /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2); |
| 386 | |
| 387 | int64_t retRows = 1; |
| 388 | Value matmulRetValue = extractFilter; |
| 389 | Value zero = builder.create<arith::ConstantOp>( |
| 390 | location: loc, args: rewriter.getZeroAttr(type: elementType)); |
| 391 | if (leftTransform) { |
| 392 | // Get constant transform matrix G. |
| 393 | auto it = GMatrices.find(Val: fmr); |
| 394 | if (it == GMatrices.end()) |
| 395 | return {}; |
| 396 | const TransformMatrix &GMatrix = it->second; |
| 397 | |
| 398 | retRows = GMatrix.rows; |
| 399 | auto matmulType = RankedTensorType::get(shape: {retRows, filterW}, elementType); |
| 400 | auto empty = |
| 401 | builder |
| 402 | .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), args&: elementType) |
| 403 | .getResult(); |
| 404 | auto init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0); |
| 405 | |
| 406 | Value G = create2DTransformMatrix(builder, loc, transform: GMatrix, type: elementType); |
| 407 | // Multiply G x g. |
| 408 | auto matmulOp = builder.create<linalg::MatmulOp>( |
| 409 | location: loc, args&: matmulType, args: ValueRange{G, extractFilter}, args: ValueRange{init}); |
| 410 | matmulRetValue = matmulOp.getResult(i: 0); |
| 411 | } |
| 412 | |
| 413 | if (rightTransform) { |
| 414 | // Get constant transform matrix GT. |
| 415 | auto it = GTMatrices.find(Val: fmr); |
| 416 | if (it == GTMatrices.end()) |
| 417 | return {}; |
| 418 | const TransformMatrix >Matrix = it->second; |
| 419 | |
| 420 | auto matmulType = |
| 421 | RankedTensorType::get(shape: {retRows, GTMatrix.cols}, elementType); |
| 422 | auto empty = |
| 423 | builder |
| 424 | .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), args&: elementType) |
| 425 | .getResult(); |
| 426 | auto init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0); |
| 427 | |
| 428 | Value GT = create2DTransformMatrix(builder, loc, transform: GTMatrix, type: elementType); |
| 429 | // Multiply u = (G x g) x GT. |
| 430 | auto matmulOp = builder.create<linalg::MatmulOp>( |
| 431 | location: loc, args&: matmulType, args: ValueRange{matmulRetValue, GT}, args: ValueRange{init}); |
| 432 | matmulRetValue = matmulOp.getResult(i: 0); |
| 433 | } |
| 434 | |
| 435 | // Insert (H, W) to (H, W, C, F). |
| 436 | int64_t retHeight = leftTransform ? m + r - 1 : 1; |
| 437 | int64_t retWidth = rightTransform ? m + r - 1 : 1; |
| 438 | |
| 439 | auto insertSliceOp = |
| 440 | insert2DDataTo4D(builder, loc, source: matmulRetValue, dest: args[0], loopNorFIndex: FIter, loopCorFIndex: CIter, |
| 441 | heightOffset: zeroIdx, widthOffset: zeroIdx, height: retHeight, width: retWidth, |
| 442 | /*loopNorFIdx=*/3, /*loopCorFIdx=*/2, |
| 443 | /*heightIdx=*/0, /*widthIdx=*/1); |
| 444 | |
| 445 | return {insertSliceOp}; |
| 446 | }; |
| 447 | |
| 448 | auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: filterF); |
| 449 | auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: filterC); |
| 450 | auto oneStep = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 451 | scf::LoopNest loops = scf::buildLoopNest( |
| 452 | builder&: rewriter, loc, lbs: {zeroIdx, zeroIdx}, ubs: {fUpperBound, cUpperBound}, |
| 453 | steps: {oneStep, oneStep}, iterArgs: {retValue}, bodyBuilder: buildBody); |
| 454 | return loops.results[0]; |
| 455 | } |
| 456 | |
| 457 | /// This function transforms the input. The data layout of the input is NHWC. |
| 458 | /// The transformation matrix is 2-dimension. We need to extract H x W from |
| 459 | /// NHWC first. We need to generate 2 levels of loops to iterate on N and C. |
| 460 | /// After the transformation, we get |
| 461 | /// |
| 462 | /// scf.for %h = 0 to tileH step 1 |
| 463 | /// scf.for %w = 0 to tileW step 1 |
| 464 | /// scf.for %n = 0 to N step 1 |
| 465 | /// scf.for %c = 0 to C step 1 |
| 466 | /// %extracted = extract %extracted<alphaH x alphaW> from |
| 467 | /// %input<N x H x W x C> |
| 468 | /// at [%n, (%h x m), (%w x m), %c] |
| 469 | /// %ret = linalg.matmul BT, %extracted |
| 470 | /// %ret = linalg.matmul %ret, B |
| 471 | /// %inserted = insert %ret<alphaH x alphaW> into |
| 472 | /// %output<alphaH x alphaW x tileH x tileW x N x C> |
| 473 | /// at [0, 0, %h, %w, %n, %c] |
| 474 | Value inputTransform(RewriterBase &rewriter, Location loc, Value input, |
| 475 | Value retValue, WinogradConv2DFmr fmr, |
| 476 | bool leftTransform = true, bool rightTransform = true) { |
| 477 | // Map from (m, r) to BT transform matrix. |
| 478 | static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix> |
| 479 | BTMatrices = { |
| 480 | {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)}, |
| 481 | {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)}, |
| 482 | {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)}, |
| 483 | }; |
| 484 | |
| 485 | // Map from (m, r) to B transform matrix. |
| 486 | static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix> |
| 487 | BMatrices = { |
| 488 | {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)}, |
| 489 | {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)}, |
| 490 | {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)}, |
| 491 | }; |
| 492 | |
| 493 | int64_t m, r; |
| 494 | std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr); |
| 495 | auto inputType = cast<ShapedType>(Val: input.getType()); |
| 496 | Type elementType = inputType.getElementType(); |
| 497 | auto inputShape = inputType.getShape(); // N, H, W, C |
| 498 | int64_t inputN = inputShape[0]; |
| 499 | int64_t inputC = inputShape[3]; |
| 500 | auto valueType = cast<ShapedType>(Val: retValue.getType()); |
| 501 | auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C |
| 502 | int64_t tileH = valueShape[2]; |
| 503 | int64_t tileW = valueShape[3]; |
| 504 | int64_t alphaH = leftTransform ? m + r - 1 : 1; |
| 505 | int64_t alphaW = rightTransform ? m + r - 1 : 1; |
| 506 | |
| 507 | auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, |
| 508 | ValueRange args) -> scf::ValueVector { |
| 509 | Value tileHIter = ivs[0]; |
| 510 | Value tileWIter = ivs[1]; |
| 511 | Value NIter = ivs[2]; |
| 512 | Value CIter = ivs[3]; |
| 513 | |
| 514 | auto context = builder.getContext(); |
| 515 | |
| 516 | auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank: 1); |
| 517 | auto affineMap = |
| 518 | AffineMap::get(dimCount: 1, symbolCount: 0, results: {builder.getAffineDimExpr(position: 0) * m}, context); |
| 519 | Value heightOffset = builder.create<affine::AffineApplyOp>( |
| 520 | location: loc, args&: leftTransform ? affineMap : identityAffineMap, args&: tileHIter); |
| 521 | Value widthOffset = builder.create<affine::AffineApplyOp>( |
| 522 | location: loc, args&: rightTransform ? affineMap : identityAffineMap, args&: tileWIter); |
| 523 | |
| 524 | // Extract (H, W) from (N, H, W, C). |
| 525 | auto = |
| 526 | extract2DDataFrom4D(builder, loc, source: input, loopNorFIndex: NIter, loopCorFIndex: CIter, heightOffset, |
| 527 | widthOffset, extractHeight: alphaH, extractWidth: alphaW, /*loopNorFIdx=*/0, |
| 528 | /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2); |
| 529 | |
| 530 | int64_t retRows = 1; |
| 531 | int64_t retCols = 1; |
| 532 | Value matmulRetValue = extractInput; |
| 533 | Value zero = builder.create<arith::ConstantOp>( |
| 534 | location: loc, args: rewriter.getZeroAttr(type: elementType)); |
| 535 | if (leftTransform) { |
| 536 | // Get constant transform matrix BT. |
| 537 | auto it = BTMatrices.find(Val: fmr); |
| 538 | if (it == BTMatrices.end()) |
| 539 | return {}; |
| 540 | const TransformMatrix &BTMatrix = it->second; |
| 541 | |
| 542 | retRows = BTMatrix.rows; |
| 543 | auto matmulType = RankedTensorType::get(shape: {retRows, alphaW}, elementType); |
| 544 | auto empty = |
| 545 | builder |
| 546 | .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), args&: elementType) |
| 547 | .getResult(); |
| 548 | auto init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0); |
| 549 | |
| 550 | Value BT = |
| 551 | create2DTransformMatrix(builder, loc, transform: BTMatrix, type: builder.getF32Type()); |
| 552 | // Multiply BT x d. |
| 553 | auto matmulOp = builder.create<linalg::MatmulOp>( |
| 554 | location: loc, args&: matmulType, args: ValueRange{BT, matmulRetValue}, args: ValueRange{init}); |
| 555 | matmulRetValue = matmulOp.getResult(i: 0); |
| 556 | } |
| 557 | |
| 558 | if (rightTransform) { |
| 559 | // Get constant transform matrix B. |
| 560 | auto it = BMatrices.find(Val: fmr); |
| 561 | if (it == BMatrices.end()) |
| 562 | return {}; |
| 563 | const TransformMatrix &BMatrix = it->second; |
| 564 | |
| 565 | retCols = BMatrix.cols; |
| 566 | auto matmulType = RankedTensorType::get(shape: {retRows, retCols}, elementType); |
| 567 | auto empty = |
| 568 | builder |
| 569 | .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), args&: elementType) |
| 570 | .getResult(); |
| 571 | auto init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0); |
| 572 | Value B = |
| 573 | create2DTransformMatrix(builder, loc, transform: BMatrix, type: builder.getF32Type()); |
| 574 | // Multiply v = (BT x d) x B. |
| 575 | auto matmulOp = builder.create<linalg::MatmulOp>( |
| 576 | location: loc, args&: matmulType, args: ValueRange{matmulRetValue, B}, args: ValueRange{init}); |
| 577 | matmulRetValue = matmulOp.getResult(i: 0); |
| 578 | } |
| 579 | |
| 580 | // Insert (H, W) to (H, W, tileH, tileW, N, C). |
| 581 | auto combinedVal = insert2DDataTo6D( |
| 582 | builder, loc, source: matmulRetValue, dest: args[0], tileHIndex: tileHIter, tileWIndex: tileWIter, loopNorFIndex: NIter, |
| 583 | loopCorFIndex: CIter, height: retRows, width: retCols, tileHIdx: 2, tileWIdx: 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5, |
| 584 | /*heightIdx=*/0, /*widthIdx=*/1); |
| 585 | |
| 586 | return {combinedVal}; |
| 587 | }; |
| 588 | |
| 589 | auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 590 | auto tileHBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: tileH); |
| 591 | auto tileWBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: tileW); |
| 592 | auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: inputN); |
| 593 | auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: inputC); |
| 594 | auto oneStep = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 595 | scf::LoopNest loops = scf::buildLoopNest( |
| 596 | builder&: rewriter, loc, lbs: {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, |
| 597 | ubs: {tileHBound, tileWBound, nUpperBound, cUpperBound}, |
| 598 | steps: {oneStep, oneStep, oneStep, oneStep}, iterArgs: {retValue}, bodyBuilder: buildBody); |
| 599 | return loops.results[0]; |
| 600 | } |
| 601 | |
| 602 | /// This function generates linalg.batch_matmul to multiply input with filter. |
| 603 | /// linalg.batch_matmul only supports 3-dimensional inputs. We can treat |
| 604 | /// tileH x tileW x H x W data as the 1-dimensional data array. That is to |
| 605 | /// convert [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this |
| 606 | /// way, we can convert 6-dimensional inputs to 3-dimensional representation |
| 607 | /// that is suitable for linalg.batch_matmul. |
| 608 | /// |
| 609 | /// Batched matmul will do the matrix multiply with the reduction on channel. |
| 610 | /// |
| 611 | /// We get |
| 612 | /// |
| 613 | /// %collapsed_input = tensor.collapse_shape %input |
| 614 | /// %collapsed_filter = tensor.collapse_shape %filter |
| 615 | /// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter |
| 616 | /// %expanded_ret = tensor.expand_shape %ret |
| 617 | /// |
| 618 | /// After this function, we get return value with data layout |
| 619 | /// (tileH, tileW, H, W, N, F). |
| 620 | static Value matrixMultiply(RewriterBase &rewriter, Location loc, |
| 621 | Value transformedFilter, Value transformedInput, |
| 622 | Type outputElementType) { |
| 623 | // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter. |
| 624 | auto filterType = cast<ShapedType>(Val: transformedFilter.getType()); |
| 625 | assert(filterType.hasStaticShape() && "only support static shapes." ); |
| 626 | ArrayRef<int64_t> filterShape = filterType.getShape(); |
| 627 | Type filterElementType = filterType.getElementType(); |
| 628 | auto filterReassocType = RankedTensorType::get( |
| 629 | shape: {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]}, |
| 630 | elementType: filterElementType); |
| 631 | SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}}; |
| 632 | Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>( |
| 633 | location: loc, args&: filterReassocType, args&: transformedFilter, args&: filterReassoc); |
| 634 | |
| 635 | // Convert (alphaH, alphaW, tileH, tileW, N, C) to |
| 636 | // (alphaH x alphaW, tileH x tileW x N, C) for input. |
| 637 | auto inputType = cast<ShapedType>(Val: transformedInput.getType()); |
| 638 | assert(inputType.hasStaticShape() && "only support static shapes." ); |
| 639 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
| 640 | Type inputElementType = inputType.getElementType(); |
| 641 | auto inputReassocType = RankedTensorType::get( |
| 642 | shape: {inputShape[0] * inputShape[1], |
| 643 | inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]}, |
| 644 | elementType: inputElementType); |
| 645 | SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}}; |
| 646 | Value collapseInput = rewriter.create<tensor::CollapseShapeOp>( |
| 647 | location: loc, args&: inputReassocType, args&: transformedInput, args&: inputReassoc); |
| 648 | |
| 649 | // Batched matrix multiply. |
| 650 | auto matmulType = RankedTensorType::get( |
| 651 | shape: {inputShape[0] * inputShape[1], |
| 652 | inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]}, |
| 653 | elementType: outputElementType); |
| 654 | Value empty = rewriter |
| 655 | .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), |
| 656 | args&: outputElementType) |
| 657 | .getResult(); |
| 658 | Value zero = rewriter.create<arith::ConstantOp>( |
| 659 | location: loc, args: rewriter.getZeroAttr(type: outputElementType)); |
| 660 | Value init = rewriter.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0); |
| 661 | |
| 662 | auto matmulOp = rewriter.create<linalg::BatchMatmulOp>( |
| 663 | location: loc, args&: matmulType, args: ValueRange({collapseInput, collapseFilter}), |
| 664 | args: ValueRange{init}); |
| 665 | |
| 666 | // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F) |
| 667 | // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F). |
| 668 | SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}}; |
| 669 | auto outputReassocType = |
| 670 | RankedTensorType::get(shape: {inputShape[0], inputShape[1], inputShape[2], |
| 671 | inputShape[3], inputShape[4], filterShape[3]}, |
| 672 | elementType: outputElementType); |
| 673 | auto expandOutput = rewriter.create<tensor::ExpandShapeOp>( |
| 674 | location: loc, args&: outputReassocType, args: matmulOp.getResult(i: 0), args&: outputReassoc); |
| 675 | return expandOutput; |
| 676 | } |
| 677 | |
| 678 | /// This function transforms the output. The data layout of the output is HWNF. |
| 679 | /// The transformation matrix is 2-dimension. We need to extract H x W from |
| 680 | /// HWNF first. We need to generate 2 levels of loops to iterate on N and F. |
| 681 | /// After the transformation, we get |
| 682 | /// |
| 683 | /// scf.for %h = 0 to tileH step 1 |
| 684 | /// scf.for %w = 0 to tileW step 1 |
| 685 | /// scf.for %n = 0 to N step 1 |
| 686 | /// scf.for %f = 0 to F step 1 |
| 687 | /// %extracted = extract %extracted<alphaH x alphaW> from |
| 688 | /// %input<alphaH x alphaW x tileH x tileW x N x F> |
| 689 | /// at [0, 0, %h, %w, %n, %f] |
| 690 | /// %ret = linalg.matmul AT, %extracted |
| 691 | /// %ret = linalg.matmul %ret, A |
| 692 | /// %inserted = insert %ret<alphaH x alphaW> into |
| 693 | /// output<N x H x W x F> |
| 694 | /// at [%n, (%h x m), (%w x m), %f] |
| 695 | Value outputTransform(RewriterBase &rewriter, Location loc, Value value, |
| 696 | Value output, WinogradConv2DFmr fmr, |
| 697 | bool leftTransform = true, bool rightTransform = true) { |
| 698 | // Map from (m, r) to AT transform matrix. |
| 699 | static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix> |
| 700 | ATMatrices = { |
| 701 | {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)}, |
| 702 | {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)}, |
| 703 | {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)}, |
| 704 | }; |
| 705 | |
| 706 | // Map from (m, r) to A transform matrix. |
| 707 | static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix> |
| 708 | AMatrices = { |
| 709 | {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)}, |
| 710 | {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)}, |
| 711 | {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)}, |
| 712 | }; |
| 713 | |
| 714 | int64_t m, r; |
| 715 | std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr); |
| 716 | auto valueType = cast<ShapedType>(Val: value.getType()); |
| 717 | Type elementType = valueType.getElementType(); |
| 718 | auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F |
| 719 | int64_t valueH = valueShape[0]; |
| 720 | int64_t valueW = valueShape[1]; |
| 721 | int64_t valueN = valueShape[4]; |
| 722 | int64_t valueF = valueShape[5]; |
| 723 | int64_t alphaH = leftTransform ? m + r - 1 : 1; |
| 724 | int64_t alphaW = rightTransform ? m + r - 1 : 1; |
| 725 | |
| 726 | if (valueH != alphaH && valueH != 1) |
| 727 | return Value(); |
| 728 | if (valueW != alphaW && valueW != 1) |
| 729 | return Value(); |
| 730 | |
| 731 | auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, |
| 732 | ValueRange args) -> scf::ValueVector { |
| 733 | auto context = builder.getContext(); |
| 734 | Value tileHIter = ivs[0]; |
| 735 | Value tileWIter = ivs[1]; |
| 736 | Value NIter = ivs[2]; |
| 737 | Value FIter = ivs[3]; |
| 738 | |
| 739 | // Extract (H, W) from (H, W, tileH, tileW, N, F). |
| 740 | auto = |
| 741 | extract2DDataFrom6D(builder, loc, source: value, tileHIndex: tileHIter, tileWIndex: tileWIter, loopNorFIndex: NIter, |
| 742 | loopCorFIndex: FIter, tileHIdx: 2, tileWIdx: 3, /*loopNorFIdx=*/4, |
| 743 | /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1); |
| 744 | |
| 745 | const TransformMatrix &AMatrix = AMatrices.at(Val: fmr); |
| 746 | const TransformMatrix &ATMatrix = ATMatrices.at(Val: fmr); |
| 747 | int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) * |
| 748 | (leftTransform ? ATMatrix.scalarFactor : 1); |
| 749 | int64_t retCols = rightTransform ? AMatrix.cols : 1; |
| 750 | int64_t retRows = leftTransform ? ATMatrix.rows : 1; |
| 751 | |
| 752 | Value matmulRetValue = extractValue; |
| 753 | Value zero = builder.create<arith::ConstantOp>( |
| 754 | location: loc, args: rewriter.getZeroAttr(type: elementType)); |
| 755 | |
| 756 | auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank: 1); |
| 757 | auto affineMap = |
| 758 | AffineMap::get(dimCount: 1, symbolCount: 0, results: {builder.getAffineDimExpr(position: 0) * m}, context); |
| 759 | Value heightOffset = builder.create<affine::AffineApplyOp>( |
| 760 | location: loc, args&: leftTransform ? affineMap : identityAffineMap, args&: tileHIter); |
| 761 | Value widthOffset = builder.create<affine::AffineApplyOp>( |
| 762 | location: loc, args&: rightTransform ? affineMap : identityAffineMap, args&: tileWIter); |
| 763 | |
| 764 | Value outInitVal = |
| 765 | extract2DDataFrom4D(builder, loc, source: args[0], loopNorFIndex: NIter, loopCorFIndex: FIter, heightOffset, |
| 766 | widthOffset, extractHeight: retRows, extractWidth: retCols, |
| 767 | /*loopNorFIdx=*/0, |
| 768 | /*loopCorFIdx=*/3, /*heightIdx=*/1, |
| 769 | /*widthIdx=*/2); |
| 770 | if (leftTransform) { |
| 771 | auto matmulType = RankedTensorType::get(shape: {retRows, valueW}, elementType); |
| 772 | Value init = outInitVal; |
| 773 | if (rightTransform || scalarFactor != 1) { |
| 774 | auto empty = builder |
| 775 | .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), |
| 776 | args&: elementType) |
| 777 | .getResult(); |
| 778 | init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0); |
| 779 | } |
| 780 | |
| 781 | Value AT = create2DTransformMatrix(builder, loc, transform: ATMatrix, type: elementType); |
| 782 | // Multiply AT x m. |
| 783 | auto matmulOp = builder.create<linalg::MatmulOp>( |
| 784 | location: loc, args&: matmulType, args: ValueRange{AT, matmulRetValue}, args: ValueRange{init}); |
| 785 | matmulRetValue = matmulOp.getResult(i: 0); |
| 786 | } |
| 787 | |
| 788 | if (rightTransform) { |
| 789 | auto matmulType = |
| 790 | RankedTensorType::get(shape: {retRows, AMatrix.cols}, elementType); |
| 791 | Value init = outInitVal; |
| 792 | if (scalarFactor != 1) { |
| 793 | auto empty = builder |
| 794 | .create<tensor::EmptyOp>(location: loc, args: matmulType.getShape(), |
| 795 | args&: elementType) |
| 796 | .getResult(); |
| 797 | init = builder.create<linalg::FillOp>(location: loc, args&: zero, args&: empty).getResult(i: 0); |
| 798 | } |
| 799 | |
| 800 | Value A = create2DTransformMatrix(builder, loc, transform: AMatrix, type: elementType); |
| 801 | // Multiply y = (AT x m) x A. |
| 802 | auto matmulOp = builder.create<linalg::MatmulOp>( |
| 803 | location: loc, args&: matmulType, args: ValueRange{matmulRetValue, A}, args: ValueRange{init}); |
| 804 | matmulRetValue = matmulOp.getResult(i: 0); |
| 805 | } |
| 806 | |
| 807 | if (scalarFactor != 1) { |
| 808 | // Multiply by scalar factor and add outInitVal. |
| 809 | Value scalarFactorValue = builder.create<arith::ConstantOp>( |
| 810 | location: loc, args: FloatAttr::get(type: elementType, value: scalarFactor)); |
| 811 | auto matmulType = RankedTensorType::get(shape: {retRows, retCols}, elementType); |
| 812 | auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank: 2); |
| 813 | SmallVector<AffineMap> affineMaps = { |
| 814 | AffineMap::get(dimCount: 2, symbolCount: 0, context), identityAffineMap, identityAffineMap}; |
| 815 | |
| 816 | matmulRetValue = |
| 817 | rewriter |
| 818 | .create<linalg::GenericOp>( |
| 819 | location: loc, args&: matmulType, |
| 820 | args: ValueRange{scalarFactorValue, matmulRetValue}, |
| 821 | args: ValueRange{outInitVal}, args&: affineMaps, |
| 822 | args: llvm::ArrayRef<utils::IteratorType>{ |
| 823 | utils::IteratorType::parallel, |
| 824 | utils::IteratorType::parallel}, |
| 825 | args: [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| 826 | ValueRange args) { |
| 827 | auto mulf = nestedBuilder.create<arith::MulFOp>( |
| 828 | location: nestedLoc, args: args[0], args: args[1]); |
| 829 | auto addf = nestedBuilder.create<arith::AddFOp>( |
| 830 | location: nestedLoc, args: mulf.getResult(), args: args[2]); |
| 831 | nestedBuilder.create<linalg::YieldOp>(location: nestedLoc, |
| 832 | args: addf.getResult()); |
| 833 | }) |
| 834 | .getResult(i: 0); |
| 835 | } |
| 836 | |
| 837 | // Insert (H, W) to (N, H, W, F). |
| 838 | Value combinedVal = |
| 839 | insert2DDataTo4D(builder, loc, source: matmulRetValue, dest: args[0], loopNorFIndex: NIter, loopCorFIndex: FIter, |
| 840 | heightOffset, widthOffset, height: retRows, width: retCols, |
| 841 | /*loopNorFIdx=*/0, |
| 842 | /*loopCorFIdx=*/3, /*heightIdx=*/1, |
| 843 | /*widthIdx=*/2); |
| 844 | |
| 845 | return {combinedVal}; |
| 846 | }; |
| 847 | |
| 848 | int64_t tilwH = valueShape[2]; |
| 849 | int64_t tileW = valueShape[3]; |
| 850 | auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 851 | auto tileHBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: tilwH); |
| 852 | auto tileWBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: tileW); |
| 853 | auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: valueN); |
| 854 | auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: valueF); |
| 855 | auto oneStep = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 856 | scf::LoopNest loops = scf::buildLoopNest( |
| 857 | builder&: rewriter, loc, lbs: {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, |
| 858 | ubs: {tileHBound, tileWBound, nUpperBound, fUpperBound}, |
| 859 | steps: {oneStep, oneStep, oneStep, oneStep}, iterArgs: {output}, bodyBuilder: buildBody); |
| 860 | return loops.results[0]; |
| 861 | } |
| 862 | |
| 863 | /// Create an empty tensor with alignedType and insert the value into the |
| 864 | /// created empty tensor with aligned size. |
| 865 | static Value padToAlignedTensor(RewriterBase &rewriter, Location loc, |
| 866 | Value value, ArrayRef<int64_t> alignedShape) { |
| 867 | auto valueType = cast<ShapedType>(Val: value.getType()); |
| 868 | Type elementType = valueType.getElementType(); |
| 869 | auto alignedType = RankedTensorType::get(shape: alignedShape, elementType); |
| 870 | Value padValue = rewriter.create<arith::ConstantOp>( |
| 871 | location: loc, args&: elementType, args: rewriter.getZeroAttr(type: elementType)); |
| 872 | |
| 873 | return linalg::makeComposedPadHighOp(b&: rewriter, loc, type: alignedType, source: value, |
| 874 | padding: padValue, nofold: false); |
| 875 | } |
| 876 | |
| 877 | /// Extract sub-tensor with extractedType from value. |
| 878 | static Value (RewriterBase &rewriter, Location loc, |
| 879 | Value value, |
| 880 | RankedTensorType ) { |
| 881 | OpFoldResult zeroIndex = rewriter.getIndexAttr(value: 0); |
| 882 | OpFoldResult oneIndex = rewriter.getIndexAttr(value: 1); |
| 883 | SmallVector<OpFoldResult, 4> offsets(4, zeroIndex); |
| 884 | SmallVector<OpFoldResult, 4> strides(4, oneIndex); |
| 885 | |
| 886 | ArrayRef<int64_t> = extractedType.getShape(); |
| 887 | SmallVector<OpFoldResult> sizes = |
| 888 | getAsOpFoldResult(arrayAttr: rewriter.getI64ArrayAttr(values: extractedShape)); |
| 889 | |
| 890 | return rewriter.create<tensor::ExtractSliceOp>(location: loc, args&: extractedType, args&: value, |
| 891 | args&: offsets, args&: sizes, args&: strides); |
| 892 | } |
| 893 | |
| 894 | /// Utility function to check all values in the attribute are 1. |
| 895 | static bool hasAllOneValues(DenseIntElementsAttr attr) { |
| 896 | return llvm::all_of( |
| 897 | Range&: attr, P: [](const APInt &element) { return element.getSExtValue() == 1; }); |
| 898 | } |
| 899 | |
| 900 | /// A helper function to convert linalg.conv_2d_nhwc_fhwc to |
| 901 | /// linalg.winograd_*_transform ops. |
| 902 | static FailureOr<Operation *> |
| 903 | winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, |
| 904 | WinogradConv2DFmr fmr) { |
| 905 | if (!convOp.hasPureTensorSemantics()) |
| 906 | return rewriter.notifyMatchFailure( |
| 907 | arg&: convOp, msg: "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc" ); |
| 908 | |
| 909 | Value input = convOp.getInputs()[0]; |
| 910 | Value filter = convOp.getInputs()[1]; |
| 911 | Value output = convOp.getOutputs()[0]; |
| 912 | auto inputType = cast<ShapedType>(Val: input.getType()); |
| 913 | auto filterType = cast<ShapedType>(Val: filter.getType()); |
| 914 | auto outputType = cast<ShapedType>(Val: output.getType()); |
| 915 | |
| 916 | if (!inputType.hasStaticShape()) |
| 917 | return rewriter.notifyMatchFailure(arg&: convOp, |
| 918 | msg: "expected a static shape for the input" ); |
| 919 | |
| 920 | if (!filterType.hasStaticShape()) |
| 921 | return rewriter.notifyMatchFailure( |
| 922 | arg&: convOp, msg: "expected a static shape for the filter" ); |
| 923 | |
| 924 | if (!hasAllOneValues(attr: convOp.getDilations())) |
| 925 | return rewriter.notifyMatchFailure(arg&: convOp, |
| 926 | msg: "expected all ones for dilations" ); |
| 927 | |
| 928 | if (!hasAllOneValues(attr: convOp.getStrides())) |
| 929 | return rewriter.notifyMatchFailure(arg&: convOp, msg: "expected all ones for strides" ); |
| 930 | |
| 931 | ArrayRef<int64_t> filterShape = filterType.getShape(); |
| 932 | int64_t filterF = filterShape[0]; |
| 933 | int64_t filterH = filterShape[1]; |
| 934 | int64_t filterW = filterShape[2]; |
| 935 | int64_t filterC = filterShape[3]; |
| 936 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
| 937 | int64_t inputN = inputShape[0]; |
| 938 | int64_t inputH = inputShape[1]; |
| 939 | int64_t inputW = inputShape[2]; |
| 940 | int64_t inputC = inputShape[3]; |
| 941 | ArrayRef<int64_t> outputShape = outputType.getShape(); |
| 942 | int64_t outputN = outputShape[0]; |
| 943 | int64_t outputH = outputShape[1]; |
| 944 | int64_t outputW = outputShape[2]; |
| 945 | int64_t outputF = outputShape[3]; |
| 946 | |
| 947 | int64_t m, r; |
| 948 | std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr); |
| 949 | // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r). |
| 950 | bool isSupportedFilter = false; |
| 951 | if (filterH == filterW && filterH == r) |
| 952 | isSupportedFilter = true; |
| 953 | if (filterH == r && filterW == 1) |
| 954 | isSupportedFilter = true; |
| 955 | if (filterH == 1 && filterW == r) |
| 956 | isSupportedFilter = true; |
| 957 | |
| 958 | if (!isSupportedFilter) |
| 959 | return rewriter.notifyMatchFailure( |
| 960 | arg&: convOp, msg: "only support filter (r x r), (r x 1) or (1 x r)" ); |
| 961 | |
| 962 | // All the criterias are satisfied. We can do Winograd Conv2D. |
| 963 | Location loc = convOp.getLoc(); |
| 964 | |
| 965 | // For F(m x 1, r x 1), we only need to do left side transform. |
| 966 | bool leftTransform = filterH != 1; |
| 967 | // For F(1 x m, 1 x r), we only need to do right side transform. |
| 968 | bool rightTransform = filterW != 1; |
| 969 | int64_t heightM = leftTransform ? m : 1; |
| 970 | int64_t widthM = rightTransform ? m : 1; |
| 971 | int64_t heightR = leftTransform ? r : 1; |
| 972 | int64_t widthR = rightTransform ? r : 1; |
| 973 | |
| 974 | // --- Create operation for filter transform --- |
| 975 | Type filterElementType = filterType.getElementType(); |
| 976 | int64_t alphaH = heightM + heightR - 1; |
| 977 | int64_t alphaW = widthM + widthR - 1; |
| 978 | int64_t tileH = llvm::divideCeilSigned(Numerator: outputH, Denominator: heightM); |
| 979 | int64_t tileW = llvm::divideCeilSigned(Numerator: outputW, Denominator: widthM); |
| 980 | auto retType = RankedTensorType::get(shape: {alphaH, alphaW, filterC, filterF}, |
| 981 | elementType: filterElementType); |
| 982 | Value retValue = rewriter.create<tensor::EmptyOp>(location: loc, args: retType.getShape(), |
| 983 | args&: filterElementType); |
| 984 | auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>( |
| 985 | location: loc, args&: retType, args&: filter, args&: retValue, args&: fmr); |
| 986 | |
| 987 | // --- Create operation for input transform --- |
| 988 | |
| 989 | // When input size - (r - 1) is not aligned with output tile size, we need to |
| 990 | // pad the input data to create the full tiles as tiling. |
| 991 | Type inputElementType = inputType.getElementType(); |
| 992 | int64_t alignedInputH = tileH * heightM + (heightR - 1); |
| 993 | int64_t alignedInputW = tileW * widthM + (widthR - 1); |
| 994 | if (alignedInputH != inputH || alignedInputW != inputW) { |
| 995 | input = padToAlignedTensor(rewriter, loc, value: input, |
| 996 | alignedShape: {inputN, alignedInputH, alignedInputW, inputC}); |
| 997 | } |
| 998 | |
| 999 | retType = RankedTensorType::get( |
| 1000 | shape: {alphaH, alphaW, tileH, tileW, inputN, inputC}, elementType: inputElementType); |
| 1001 | retValue = rewriter.create<tensor::EmptyOp>(location: loc, args: retType.getShape(), |
| 1002 | args&: inputElementType); |
| 1003 | auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>( |
| 1004 | location: loc, args&: retType, args&: input, args&: retValue, args&: fmr); |
| 1005 | |
| 1006 | Type outputElementType = outputType.getElementType(); |
| 1007 | Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter, |
| 1008 | transformedInput, outputElementType); |
| 1009 | |
| 1010 | // --- Create operation for output transform --- |
| 1011 | |
| 1012 | // When output size is not aligned with output tile size, we need to pad the |
| 1013 | // output buffer to insert the full tiles after tiling. |
| 1014 | int64_t alignedOutputH = tileH * heightM; |
| 1015 | int64_t alignedOutputW = tileW * widthM; |
| 1016 | bool isOutputUnaligned = |
| 1017 | ((alignedOutputH != outputH) || (alignedOutputW != outputW)); |
| 1018 | if (isOutputUnaligned) { |
| 1019 | auto alignedOutputType = RankedTensorType::get( |
| 1020 | shape: {outputN, alignedOutputH, alignedOutputW, outputF}, elementType: outputElementType); |
| 1021 | output = |
| 1022 | padToAlignedTensor(rewriter, loc, value: output, alignedShape: alignedOutputType.getShape()); |
| 1023 | outputType = alignedOutputType; |
| 1024 | } |
| 1025 | |
| 1026 | Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>( |
| 1027 | location: loc, args&: outputType, args&: matmulRet, args&: output, args&: fmr); |
| 1028 | |
| 1029 | // When output size is not aligned with output tile size, extract the |
| 1030 | // value from the padded buffer. |
| 1031 | if (isOutputUnaligned) { |
| 1032 | transformedOutput = extractFromAlignedTensor( |
| 1033 | rewriter, loc, value: transformedOutput, |
| 1034 | extractedType: RankedTensorType::get(shape: {outputN, outputH, outputW, outputF}, |
| 1035 | elementType: outputElementType)); |
| 1036 | } |
| 1037 | |
| 1038 | rewriter.replaceOp(op: convOp, newValues: transformedOutput); |
| 1039 | |
| 1040 | return transformedOutput.getDefiningOp(); |
| 1041 | } |
| 1042 | |
| 1043 | /// A helper function to decompose linalg.winograd_filter_transform. |
| 1044 | FailureOr<Operation *> |
| 1045 | decomposeWinogradFilterTransformHelper(RewriterBase &rewriter, |
| 1046 | linalg::WinogradFilterTransformOp op) { |
| 1047 | Location loc = op.getLoc(); |
| 1048 | Value filter = op.getFilter(); |
| 1049 | auto filterType = cast<ShapedType>(Val: filter.getType()); |
| 1050 | auto filterShape = filterType.getShape(); |
| 1051 | int64_t filterH = filterShape[1]; |
| 1052 | int64_t filterW = filterShape[2]; |
| 1053 | |
| 1054 | // For F(m x 1, r x 1), we only need to do left side transform. |
| 1055 | bool leftTransform = filterH != 1; |
| 1056 | // For F(1 x m, 1 x r), we only need to do right side transform. |
| 1057 | bool rightTransform = filterW != 1; |
| 1058 | Value transformedFilter = |
| 1059 | filterTransform(rewriter, loc, filter, retValue: op.getOutput(), fmr: op.getFmr(), |
| 1060 | leftTransform, rightTransform); |
| 1061 | if (!transformedFilter) |
| 1062 | return failure(); |
| 1063 | |
| 1064 | rewriter.replaceOp(op, newValues: transformedFilter); |
| 1065 | |
| 1066 | return transformedFilter.getDefiningOp(); |
| 1067 | } |
| 1068 | |
| 1069 | /// A helper function to decompose linalg.winograd_input_transform. |
| 1070 | FailureOr<Operation *> |
| 1071 | decomposeWinogradInputTransformHelper(RewriterBase &rewriter, |
| 1072 | linalg::WinogradInputTransformOp op) { |
| 1073 | Location loc = op.getLoc(); |
| 1074 | Value output = op.getOutput(); |
| 1075 | auto outputType = cast<ShapedType>(Val: output.getType()); |
| 1076 | auto outputShape = outputType.getShape(); |
| 1077 | |
| 1078 | int64_t outputH = outputShape[0]; |
| 1079 | int64_t outputW = outputShape[1]; |
| 1080 | |
| 1081 | // For F(m x 1, r x 1), we only need to do left side transform. |
| 1082 | bool leftTransform = outputH != 1; |
| 1083 | // For F(1 x m, 1 x r), we only need to do right side transform. |
| 1084 | bool rightTransform = outputW != 1; |
| 1085 | Value transformedInput = |
| 1086 | inputTransform(rewriter, loc, input: op.getInput(), retValue: op.getOutput(), fmr: op.getFmr(), |
| 1087 | leftTransform, rightTransform); |
| 1088 | if (!transformedInput) |
| 1089 | return failure(); |
| 1090 | |
| 1091 | rewriter.replaceOp(op, newValues: transformedInput); |
| 1092 | |
| 1093 | return transformedInput.getDefiningOp(); |
| 1094 | } |
| 1095 | |
| 1096 | /// A helper function to decompose linalg.winograd_output_transform. |
| 1097 | FailureOr<Operation *> |
| 1098 | decomposeWinogradOutputTransformHelper(RewriterBase &rewriter, |
| 1099 | linalg::WinogradOutputTransformOp op) { |
| 1100 | Location loc = op.getLoc(); |
| 1101 | Value value = op.getValue(); |
| 1102 | auto valueType = cast<ShapedType>(Val: value.getType()); |
| 1103 | auto valueShape = valueType.getShape(); |
| 1104 | int64_t valueH = valueShape[0]; |
| 1105 | int64_t valueW = valueShape[1]; |
| 1106 | |
| 1107 | // For F(m x 1, r x 1), we only need to do left side transform. |
| 1108 | bool leftTransform = valueH != 1; |
| 1109 | // For F(1 x m, 1 x r), we only need to do right side transform. |
| 1110 | bool rightTransform = valueW != 1; |
| 1111 | Value transformedOutput = |
| 1112 | outputTransform(rewriter, loc, value, output: op.getOutput(), fmr: op.getFmr(), |
| 1113 | leftTransform, rightTransform); |
| 1114 | if (!transformedOutput) |
| 1115 | return failure(); |
| 1116 | |
| 1117 | rewriter.replaceOp(op, newValues: transformedOutput); |
| 1118 | |
| 1119 | return transformedOutput.getDefiningOp(); |
| 1120 | } |
| 1121 | |
| 1122 | /// A rewrite pattern to decompose linalg.winograd_filter_transform operations. |
| 1123 | class DecomposeWinogradFilterTransform final |
| 1124 | : public OpRewritePattern<linalg::WinogradFilterTransformOp> { |
| 1125 | public: |
| 1126 | using OpRewritePattern::OpRewritePattern; |
| 1127 | |
| 1128 | LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op, |
| 1129 | PatternRewriter &rewriter) const override { |
| 1130 | return decomposeWinogradFilterTransformHelper(rewriter, op); |
| 1131 | } |
| 1132 | }; |
| 1133 | |
| 1134 | /// A rewrite pattern to decompose linalg.winograd_input_transform operations. |
| 1135 | class DecomposeWinogradInputTransform final |
| 1136 | : public OpRewritePattern<linalg::WinogradInputTransformOp> { |
| 1137 | public: |
| 1138 | using OpRewritePattern::OpRewritePattern; |
| 1139 | |
| 1140 | LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op, |
| 1141 | PatternRewriter &rewriter) const override { |
| 1142 | return decomposeWinogradInputTransformHelper(rewriter, op); |
| 1143 | } |
| 1144 | }; |
| 1145 | |
| 1146 | /// A rewrite pattern to decompose linalg.winograd_output_transform operations. |
| 1147 | class DecomposeWinogradOutputTransform final |
| 1148 | : public OpRewritePattern<linalg::WinogradOutputTransformOp> { |
| 1149 | public: |
| 1150 | using OpRewritePattern::OpRewritePattern; |
| 1151 | |
| 1152 | LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op, |
| 1153 | PatternRewriter &rewriter) const override { |
| 1154 | return decomposeWinogradOutputTransformHelper(rewriter, op); |
| 1155 | } |
| 1156 | }; |
| 1157 | |
| 1158 | /// A rewrite pattern for Winograd Conv2D algorithm. |
| 1159 | class WinogradConv2DNhwcFhwc final |
| 1160 | : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> { |
| 1161 | public: |
| 1162 | using OpRewritePattern::OpRewritePattern; |
| 1163 | WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, WinogradConv2DFmr fmr) |
| 1164 | : OpRewritePattern(context), fmr(fmr) {} |
| 1165 | |
| 1166 | LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp, |
| 1167 | PatternRewriter &rewriter) const override { |
| 1168 | if (failed(Result: winogradConv2DHelper(rewriter, convOp, fmr))) |
| 1169 | return failure(); |
| 1170 | |
| 1171 | return success(); |
| 1172 | } |
| 1173 | |
| 1174 | private: |
| 1175 | WinogradConv2DFmr fmr; |
| 1176 | }; |
| 1177 | |
| 1178 | } // end anonymous namespace |
| 1179 | |
| 1180 | //===----------------------------------------------------------------------===// |
| 1181 | FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter, |
| 1182 | linalg::Conv2DNhwcFhwcOp op, |
| 1183 | linalg::WinogradConv2DFmr fmr) { |
| 1184 | return winogradConv2DHelper(rewriter, convOp: op, fmr); |
| 1185 | } |
| 1186 | |
/// Public entry point: decompose a linalg.winograd_filter_transform op into
/// loops and matmuls. Delegates to the file-local helper.
FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}
| 1192 | |
/// Public entry point: decompose a linalg.winograd_input_transform op into
/// loops and matmuls. Delegates to the file-local helper.
FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}
| 1198 | |
/// Public entry point: decompose a linalg.winograd_output_transform op into
/// loops and matmuls. Delegates to the file-local helper.
FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}
| 1204 | |
| 1205 | void populateWinogradConv2DPatterns(RewritePatternSet &patterns, |
| 1206 | WinogradConv2DFmr fmr) { |
| 1207 | MLIRContext *context = patterns.getContext(); |
| 1208 | // TODO: Support more Conv2D data layout, e.g., conv_2d_nchw_fchw |
| 1209 | patterns.insert<WinogradConv2DNhwcFhwc>(arg&: context, args&: fmr); |
| 1210 | } |
| 1211 | |
| 1212 | void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) { |
| 1213 | MLIRContext *context = patterns.getContext(); |
| 1214 | patterns |
| 1215 | .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform, |
| 1216 | DecomposeWinogradOutputTransform>(arg&: context); |
| 1217 | } |
| 1218 | |
| 1219 | } // end namespace linalg |
| 1220 | } // end namespace mlir |
| 1221 | |