//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement the Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
///   Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
///
/// where g is the filter and d is the input data. We need to prepare 6
/// constant transformation matrices, G, G^T, B^T, B, A^T, and A, for this
/// formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
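///
/// As a dimension check (derivable from the tables below, where
/// alpha = m + r - 1 is the transformed tile size): for F(m x m, r x r),
/// G is alpha x r, B^T is alpha x alpha, and A^T is m x alpha. For
/// F(2 x 2, 3 x 3), that gives a 4x3 G, a 4x4 B^T, and a 2x4 A^T, so both
/// G x g x G^T and B^T x d x B are 4x4, and Y comes out 2x2.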
constexpr float G_2x2_3x3[] = {
  -1, 0, 0,
  1./2, -1./2, 1./2,
  1./2, 1./2, 1./2,
  0, 0, 1
};

constexpr float GT_2x2_3x3[] = {
  -1, 1./2, 1./2, 0,
  0, -1./2, 1./2, 0,
  0, 1./2, 1./2, 1
};

constexpr float BT_2x2_3x3[] = {
  -1, 0, 1, 0,
  0, -1, 1, 0,
  0, 1, 1, 0,
  0, -1, 0, 1
};

constexpr float B_2x2_3x3[] = {
  -1, 0, 0, 0,
  0, -1, 1, -1,
  1, 1, 1, 0,
  0, 0, 0, 1
};

constexpr float AT_2x2_3x3[] = {
  1, 1, 1, 0,
  0, -1, 1, 1
};

constexpr float A_2x2_3x3[] = {
  1, 0,
  1, -1,
  1, 1,
  0, 1
};

constexpr float G_4x4_3x3[] = {
  1, 0, 0,
  -1./3, 1./3, -1./3,
  -1./3, -1./3, -1./3,
  1./12, -1./6, 1./3,
  1./12, 1./6, 1./3,
  0, 0, 1
};

constexpr float GT_4x4_3x3[] = {
  1, -1./3, -1./3, 1./12, 1./12, 0,
  0, 1./3, -1./3, -1./6, 1./6, 0,
  0, -1./3, -1./3, 1./3, 1./3, 1
};

constexpr float BT_4x4_3x3[] = {
  1./4, 0, -5./16, 0, 1./16, 0,
  0, 1./4, -1./4, -1./16, 1./16, 0,
  0, -1./4, -1./4, 1./16, 1./16, 0,
  0, 1./4, -1./8, -1./4, 1./8, 0,
  0, -1./4, -1./8, 1./4, 1./8, 0,
  0, 1./4, 0, -5./16, 0, 1./16
};

constexpr float B_4x4_3x3[] = {
  1./4, 0, 0, 0, 0, 0,
  0, 1./4, -1./4, 1./4, -1./4, 1./4,
  -5./16, -1./4, -1./4, -1./8, -1./8, 0,
  0, -1./16, 1./16, -1./4, 1./4, -5./16,
  1./16, 1./16, 1./16, 1./8, 1./8, 0,
  0, 0, 0, 0, 0, 1./16
};

constexpr float AT_4x4_3x3[] = {
  1./8, 1./4, 1./4, 1./8, 1./8, 0,
  0, -1./4, 1./4, -1./4, 1./4, 0,
  0, 1./4, 1./4, 1./2, 1./2, 0,
  0, -1./4, 1./4, -1, 1, 1./2
};

constexpr float A_4x4_3x3[] = {
  1./8, 0, 0, 0,
  1./4, -1./4, 1./4, -1./4,
  1./4, 1./4, 1./4, 1./4,
  1./8, -1./4, 1./2, -1,
  1./8, 1./4, 1./2, 1,
  0, 0, 0, 1./2
};

constexpr float G_2x2_5x5[] = {
  1, 0, 0, 0, 0,
  1./6, -1./6, 1./6, -1./6, 1./6,
  -1./6, -1./6, -1./6, -1./6, -1./6,
  -4./15, 2./15, -1./15, 1./30, -1./60,
  1./60, 1./30, 1./15, 2./15, 4./15,
  0, 0, 0, 0, 1
};

constexpr float GT_2x2_5x5[] = {
  1, 1./6, -1./6, -4./15, 1./60, 0,
  0, -1./6, -1./6, 2./15, 1./30, 0,
  0, 1./6, -1./6, -1./15, 1./15, 0,
  0, -1./6, -1./6, 1./30, 2./15, 0,
  0, 1./6, -1./6, -1./60, 4./15, 1
};

constexpr float BT_2x2_5x5[] = {
  1./8, 3./16, -1./4, -3./16, 1./8, 0,
  0, 1./8, 1./16, -5./16, 1./8, 0,
  0, -1./8, -5./16, -1./16, 1./8, 0,
  0, 1./4, -1./8, -1./4, 1./8, 0,
  0, -1./8, -1./4, 1./8, 1./4, 0,
  0, 1./8, 3./16, -1./4, -3./16, 1./8
};

constexpr float B_2x2_5x5[] = {
  1./8, 0, 0, 0, 0, 0,
  3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
  -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
  -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
  1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
  0, 0, 0, 0, 0, 1./8
};

constexpr float AT_2x2_5x5[] = {
  1./2, 1, 1, 2, 1, 0,
  0, -1, 1, -1, 2, 1./2
};

constexpr float A_2x2_5x5[] = {
  1./2, 0,
  1, -1,
  1, 1,
  2, -1,
  1, 2,
  0, 1./2
};
// clang-format on

using TransformMapKeyTy = std::pair<int, int>;

/// We use F(m, r) to define the size of minimal filtering algorithms.
/// m is the output dimension and r is the filter dimension. We can get
/// the input dimension, alpha, from the formula alpha = m + r - 1.
///
/// For example, when m = 2 and r = 3, we know its input size is 4.
/// The Conv2D will operate on 4x4 input data with a 3x3 filter and get
/// a 2x2 output result.
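///
/// A concrete sizing example (using the same tiling arithmetic as
/// winogradConv2DHelper below): with F(4, 3) and a 12x12 output, each tile
/// covers a 4x4 output patch computed from a 6x6 input window
/// (alpha = 4 + 3 - 1 = 6), so tileH = tileW = ceil(12 / 4) = 3 and
/// adjacent input windows overlap by r - 1 = 2 rows/columns.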
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};

/// Structure to keep information of constant transform matrices.
/// scalarFactor records a constant that the transformed result must still be
/// multiplied by; outputTransform applies it when combining the final result.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};

/// Utility function to convert a constant array to an arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);

  return builder.create<arith::ConstantOp>(
      loc, DenseFPElementsAttr::get(
               RankedTensorType::get(
                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
               constVec));
}
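
// For instance, passing the GMatrices entry for F_2_3 (see filterTransform
// below) to create2DTransformMatrix materializes IR along these lines
// (illustrative; the exact constant printing may differ):
//
//   %G = arith.constant dense<[[-1.0, 0.0, 0.0], [0.5, -0.5, 0.5],
//                              [0.5, 0.5, 0.5], [0.0, 0.0, 1.0]]>
//       : tensor<4x3xf32>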

/// Extract height x width data from 4D tensors.
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}
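
// As an example of what extract2DDataFrom4D builds: extracting a 3x3 filter
// tile at (%f, %c) from an FHWC tensor<2x3x3x5xf32> emits a rank-reducing
// slice roughly like this (illustrative):
//
//   %slice = tensor.extract_slice %source[%f, 0, 0, %c] [1, 3, 3, 1]
//       [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>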

/// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}

/// Insert transformed height x width data into the 4D tensor it was
/// extracted from.
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// Insert transformed height x width data into the 6D tensor it was
/// extracted from.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// This function transforms the filter. The data layout of the filter is
/// FHWC. The transformation matrices are 2-dimensional, so we extract the
/// H x W slices from FHWC first, generating 2 levels of loops to iterate on
/// F and C. After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
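///
/// As a shape walkthrough for F(2, 3) with a 3x3 filter slice g: G is 4x3,
/// so G x g is 4x3, and multiplying by the 3x4 GT yields the 4x4 transformed
/// filter tile u = G x g x GT that is inserted into the (alphaH, alphaW, C, F)
/// result.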
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(key);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(key);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F).
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrices are 2-dimensional, so we extract the H x W
/// slices from NHWC first, generating 4 levels of loops to iterate on tileH,
/// tileW, N, and C. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                        %input<N x H x W x C>
///                        at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                       %output<alphaH x alphaW x tileH x tileW x N x C>
///                       at [0, 0, %h, %w, %n, %c]
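///
/// For F(4, 3), for example, each iteration reads a 6x6 input window starting
/// at row %h * 4 and column %w * 4 (adjacent windows overlap by r - 1 = 2),
/// and BT x d x B produces the 6x6 transformed tile v stored at
/// [0, 0, %h, %w, %n, %c].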
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, int64_t m, int64_t r,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to BT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BTMatrices = {
          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to B transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BMatrices = {
          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto context = builder.getContext();

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix BT.
      auto it = BTMatrices.find(key);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value BT =
          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
      // Multiply BT x d.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix B.
      auto it = BMatrices.find(key);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      Value B =
          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
      // Multiply v = (BT x d) x B.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, /*tileHIdx=*/2, /*tileWIdx=*/3,
        /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function generates linalg.batch_matmul to multiply the transformed
/// input with the transformed filter. linalg.batch_matmul only supports
/// 3-dimensional inputs, so we collapse (alphaH, alphaW, tileH, tileW, N, C)
/// to (alphaH x alphaW, tileH x tileW x N, C), using alphaH x alphaW as the
/// batch dimension.
///
/// The batched matmul does the matrix multiply with the reduction on the
/// channel dimension.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get the return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
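///
/// Concretely, for F(4, 3) with N = 1, C = 5, F = 2, and tileH = tileW = 2
/// (shapes implied by the collapse comments below): the filter
/// tensor<6x6x5x2xf32> collapses to tensor<36x5x2xf32>, the input
/// tensor<6x6x2x2x1x5xf32> collapses to tensor<36x4x5xf32>, the batched
/// matmul produces tensor<36x4x2xf32>, and the result expands back to
/// tensor<6x6x2x2x1x2xf32>.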
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = rewriter
                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                             outputElementType)
                    .getResult();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(outputElementType));
  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of the batch matmul is
  // (alphaH x alphaW, tileH x tileW x N, F).
  // Expand the matmul result to (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}

/// This function transforms the output. The data layout of the input value is
/// (alphaH, alphaW, tileH, tileW, N, F). The transformation matrices are
/// 2-dimensional, so we extract the H x W slices first, generating 4 levels
/// of loops to iterate on tileH, tileW, N, and F. After the transformation,
/// we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                        %input<alphaH x alphaW x tileH x tileW x N x F>
///                        at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<m x m> into output<N x H x W x F>
///                       at [%n, (%h x m), (%w x m), %f]
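///
/// As a shape walkthrough for F(4, 3): AT is 4x6 and A is 6x4, so for a 6x6
/// transformed tile M, y = AT x M x A is the final 4x4 output patch. For
/// F(4, 3) and F(2, 5) the AT/A tables above are stored scaled down, so the
/// result is additionally multiplied by scalarFactor (32 x 32 = 1024 for
/// F(4, 3)) inside the linalg.generic emitted below.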
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to AT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      ATMatrices = {
          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to A transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      AMatrices = {
          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, /*tileHIdx=*/2, /*tileWIdx=*/3,
                            /*loopNorFIdx=*/4,
                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMapKeyTy key = {m, r};
    const TransformMatrix &AMatrix = AMatrices.at(key);
    const TransformMatrix &ATMatrix = ATMatrices.at(key);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply AT x m.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply y = (AT x m) x A.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by the scalar factor and add outInitVal.
      Value scalarFactorValue = builder.create<arith::ConstantOp>(
          loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          rewriter
              .create<linalg::GenericOp>(
                  loc, matmulType,
                  ValueRange{scalarFactorValue, matmulRetValue},
                  ValueRange{outInitVal}, affineMaps,
                  llvm::ArrayRef<utils::IteratorType>{
                      utils::IteratorType::parallel,
                      utils::IteratorType::parallel},
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    auto mulf = nestedBuilder.create<arith::MulFOp>(
                        nestedLoc, args[0], args[1]);
                    auto addf = nestedBuilder.create<arith::AddFOp>(
                        nestedLoc, mulf.getResult(), args[2]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                          addf.getResult());
                  })
              .getResult(0);
    }

    // Insert (H, W) to (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}

/// Pad the value with zeros of its element type to the given aligned shape.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = rewriter.create<arith::ConstantOp>(
      loc, elementType, rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, /*nofold=*/false);
}

/// Extract a sub-tensor with extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
                                                 offsets, sizes, strides);
}

/// Utility function to check that all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
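///
/// The rewrite produces a chain along these lines (illustrative; the exact
/// types and assembly syntax depend on the op definitions):
///
///   %f = linalg.winograd_filter_transform m(4) r(3)
///          ins(%filter : ...) outs(%f_init : ...) -> ...
///   %i = linalg.winograd_input_transform m(4) r(3)
///          ins(%input : ...) outs(%i_init : ...) -> ...
///   ... collapse shapes, linalg.batch_matmul, expand shape ...
///   %o = linalg.winograd_output_transform m(4) r(3)
///          ins(%expanded : ...) outs(%output : ...) -> ...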
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // Currently, we support (m, r) = (2, 3), (4, 3) or (2, 5).
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  TransformMapKeyTy key = {m, r};
  auto it = llvm::find(validConfigs, key);
  // If we cannot find the constant transformation matrix, it means we do
  // not support this configuration yet.
  if (it == validConfigs.end())
    return failure();

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                    filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, m, r);

  // --- Create operation for input transform ---

  // When input size - (r - 1) is not aligned with the output tile size, we
  // need to pad the input data to create the full tiles as tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                              inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, m, r);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When the output size is not aligned with the output tile size, we need to
  // pad the output buffer to insert the full tiles after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  // When the output size is not aligned with the output tile size, extract
  // the value from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();

  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                     op.getR(), leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = valueH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for the Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
      return failure();

    return success();
  }

private:
  int64_t m;
  int64_t r;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {
  return winogradConv2DHelper(rewriter, op, m, r);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
}
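
// A minimal usage sketch for the two entry points above (assuming a pass with
// a ModuleOp root; the variable names here are illustrative):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
//   // Later, to lower the winograd_* ops to loops and matmuls:
//   // linalg::populateDecomposeWinogradOpsPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));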

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}

} // end namespace linalg
} // end namespace mlir