1//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/Utils.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Complex/IR/Complex.h"
12#include "mlir/Dialect/Linalg/IR/Linalg.h"
13#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23#include <utility>
24
25namespace mlir {
26namespace linalg {
27static bool hasAllOneValues(DenseIntElementsAttr attr) {
28 return llvm::all_of(
29 Range&: attr, P: [](const APInt &element) { return element.getSExtValue() == 1; });
30}
31
32static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
33 if (isa<IntegerType>(x.getType()))
34 return builder.create<arith::AddIOp>(loc, x, y);
35 if (isa<ComplexType>(x.getType()))
36 return builder.create<complex::AddOp>(loc, x, y);
37 return builder.create<arith::AddFOp>(loc, x, y);
38}
39
40static Value createMul(Location loc, Value x, Value y, Type accType,
41 OpBuilder &builder) {
42 // Linalg named ops specify signed extend for named ops.
43 Value xConvert =
44 convertScalarToDtype(b&: builder, loc, operand: x, toType: accType, /*isUnsignedCast=*/false);
45 Value yConvert =
46 convertScalarToDtype(b&: builder, loc, operand: y, toType: accType, /*isUnsignedCast=*/false);
47 if (isa<ComplexType>(accType))
48 return builder.create<complex::MulOp>(loc, xConvert, yConvert);
49 if (isa<IntegerType>(accType))
50 return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
51 return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
52}
53
54// Delinearizes the given composite `index` by the basis specified in `factors`.
55static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
56 ArrayRef<int64_t> factors) {
57 assert(!factors.empty() && "empty factor list");
58 SmallVector<Value> basis;
59 for (int64_t f : factors)
60 basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
61 FailureOr<SmallVector<Value>> multiIndex =
62 affine::delinearizeIndex(b, loc, linearIndex: index, basis);
63 assert(!failed(multiIndex) && "Failed to linearize img2col index");
64 return *multiIndex;
65}
66
67// Given indices corresponding to iterators in the output (oIndex) and filter
68// (fIndex) for a convolution, compute the convolved index for the
69// input as `oIndex * stride + fIndex`.
70static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
71 Value fIndex, int64_t stride) {
72 AffineExpr oExpr, fExpr;
73 bindSymbols(ctx: b.getContext(), exprs&: oExpr, exprs&: fExpr);
74 AffineMap convMap = AffineMap::get(dimCount: 0, symbolCount: 2, result: stride * oExpr + fExpr);
75 return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
76}
77
78FailureOr<std::pair<Operation *, Operation *>>
79rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
80 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
81 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
82 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
83
84 if (!filterType.hasStaticShape())
85 return rewriter.notifyMatchFailure(
86 convOp, "expected a static shape for the filter");
87
88 if (!inputType.hasStaticShape())
89 return rewriter.notifyMatchFailure(convOp,
90 "expected a static shape for the input");
91
92 // TODO: Support dilation.
93 if (!hasAllOneValues(convOp.getDilations()))
94 return rewriter.notifyMatchFailure(convOp,
95 "expected all ones for dilations");
96
97 MLIRContext *context = rewriter.getContext();
98 Value input = convOp.getInputs()[0];
99 Value filter = convOp.getInputs()[1];
100 Value output = convOp.getOutputs()[0];
101
102 ArrayRef<int64_t> filterShape = filterType.getShape();
103 ArrayRef<int64_t> outputShape = outputType.getShape();
104
105 int64_t n = outputShape[0];
106 int64_t oh = outputShape[1];
107 int64_t ow = outputShape[2];
108 int64_t oc = outputShape[3];
109 int64_t fh = filterShape[0];
110 int64_t fw = filterShape[1];
111 int64_t ic = filterShape[2];
112
113 Location loc = convOp.getLoc();
114
115 // Reshape output and filter to the LHS and result of a (B)MNK matmul.
116 SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
117 auto reshapedFilterType =
118 RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
119 Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
120 loc, reshapedFilterType, filter, filterReassocIndices);
121
122 SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
123 RankedTensorType reshapedOutputType =
124 RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
125 Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
126 loc, reshapedOutputType, output, outputReassocIndices);
127
128 SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
129 Value colTensor = rewriter.create<tensor::EmptyOp>(
130 loc, colTensorShape, inputType.getElementType());
131
132 // Convert the input to a (BMK) column tensor.
133 auto nloops = colTensorShape.size();
134
135 auto parallel = utils::IteratorType::parallel;
136 auto reduction = utils::IteratorType::reduction;
137 SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
138
139 SmallVector<AffineMap> img2colIndexingMaps = {
140 AffineMap::getMultiDimIdentityMap(numDims: nloops, context)};
141
142 auto img2ColTensor = rewriter.create<linalg::GenericOp>(
143 loc, colTensor.getType(),
144 /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
145 img2colIterators,
146 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
147 // Get the iterators named based on the matmul (batch, m, k).
148 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
149 Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
150 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
151
152 // Recover the original iteration indices from the problem/input sizes.
153 SmallVector<Value> mIndices = unrollIndex(
154 nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
155 auto ohIndex = mIndices[0];
156 auto owIndex = mIndices[1];
157
158 SmallVector<Value> kIndices = unrollIndex(
159 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
160 auto fhIndex = kIndices[0];
161 auto fwIndex = kIndices[1];
162 auto icIndex = kIndices[2];
163
164 // Extract the input element corresponding to the expanded indices.
165 Value hIndex =
166 getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
167 convOp.getStrides().getValues<int64_t>()[0]);
168 Value wIndex =
169 getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
170 convOp.getStrides().getValues<int64_t>()[1]);
171
172 // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
173 SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
174 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
175 loc, input, extractionIndices);
176 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
177 });
178
179 // Because the filter does not share the same batch dimension,
180 // the batch dimension is only used in indexing the input and output. Thus
181 // we cannot use existing linalg named ops like linalg.batch_matmul.
182 // i.e. (B x) M x K * K x N = (B x) M x N
183 AffineExpr bDim, mDim, nDim, kDim;
184 bindDims(ctx: context, exprs&: bDim, exprs&: mDim, exprs&: nDim, exprs&: kDim);
185 auto lhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, kDim}, context);
186 auto rhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {kDim, nDim}, context);
187 auto resultMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, nDim}, context);
188 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
189 parallel, reduction};
190
191 auto genericOp = rewriter.create<linalg::GenericOp>(
192 loc, reshapedOutputType,
193 /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
194 /*outputs=*/ValueRange{reshapedOutput},
195 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
196 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
197 Value mul =
198 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
199 Value add = createAdd(loc, mul, args[2], nestedBuilder);
200 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
201 });
202 Value result = genericOp.getResults().front();
203
204 auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
205 loc, outputType, result, outputReassocIndices);
206
207 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
208
209 return std::make_pair(img2ColTensor.getOperation(),
210 reshapedResult.getOperation());
211}
212
213FailureOr<std::pair<Operation *, Operation *>>
214rewriteInIm2Col(RewriterBase &rewriter,
215 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
216 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
217 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
218 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
219
220 if (!filterType.hasStaticShape())
221 return rewriter.notifyMatchFailure(
222 convOp, "expected a static shape for the filter");
223
224 if (!inputType.hasStaticShape())
225 return rewriter.notifyMatchFailure(convOp,
226 "expected a static shape for the input");
227
228 // TODO: Support dilation.
229 if (!hasAllOneValues(convOp.getDilations()))
230 return rewriter.notifyMatchFailure(convOp,
231 "expected all ones for dilations");
232
233 Location loc = convOp.getLoc();
234
235 auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
236 auto operandTensorType = cast<RankedTensorType>(operand.getType());
237 auto nloops = indices.size();
238 ArrayRef<int64_t> inputShape = operandTensorType.getShape();
239
240 SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
241 Range: llvm::map_range(C&: indices, F: [&](int64_t index) -> AffineExpr {
242 return rewriter.getAffineDimExpr(position: index);
243 }));
244
245 SmallVector<int64_t> targetShape = llvm::to_vector<4>(Range: llvm::map_range(
246 C&: indices, F: [&](int64_t index) -> int64_t { return inputShape[index]; }));
247
248 Value outputTensor = rewriter.create<tensor::EmptyOp>(
249 loc, targetShape, operandTensorType.getElementType());
250
251 SmallVector<utils::IteratorType> loopAttributeTypes(
252 nloops, utils::IteratorType::parallel);
253
254 SmallVector<AffineMap> indexingMaps = {
255 inversePermutation(
256 map: AffineMap::get(dimCount: nloops, symbolCount: 0, results: exprs, context: rewriter.getContext())),
257 AffineMap::getMultiDimIdentityMap(numDims: nloops, context: rewriter.getContext())};
258
259 auto transposedOp = rewriter.create<linalg::GenericOp>(
260 loc, outputTensor.getType(),
261 /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
262 loopAttributeTypes,
263 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
264 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
265 });
266
267 return transposedOp.getResult(0);
268 };
269
270 Value input = convOp.getInputs()[0];
271 Value filter = convOp.getInputs()[1];
272 Value output = convOp.getOutputs()[0];
273
274 // Transpose input, filter so channels are outermost
275 Value inputT = transposeOperand(input, {0, 3, 1, 2});
276 Value filterT = transposeOperand(filter, {2, 0, 1});
277 ArrayRef<int64_t> filterTShape =
278 cast<RankedTensorType>(filterT.getType()).getShape();
279 ArrayRef<int64_t> outputShape = outputType.getShape();
280
281 int n = outputShape[0];
282 int oh = outputShape[1];
283 int ow = outputShape[2];
284 int c = outputShape[3];
285 int fh = filterTShape[1];
286 int fw = filterTShape[2];
287
288 SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
289 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
290
291 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
292 bindDims(ctx: rewriter.getContext(), exprs&: nDim, exprs&: cDim, exprs&: ohDim, exprs&: owDim, exprs&: khDim, exprs&: kwDim);
293
294 AffineExpr shSym = rewriter.getAffineConstantExpr(
295 constant: convOp.getStrides().getValues<int64_t>()[0]);
296 AffineExpr swSym = rewriter.getAffineConstantExpr(
297 constant: convOp.getStrides().getValues<int64_t>()[1]);
298
299 SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
300 owDim * swSym + kwDim};
301
302 auto nloops = colTensorShape.size();
303
304 SmallVector<utils::IteratorType> loopAttributeTypes(
305 nloops, utils::IteratorType::parallel);
306
307 SmallVector<AffineMap> indexingMaps = {
308 AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
309 AffineMap::getMultiDimIdentityMap(numDims: nloops, context: rewriter.getContext())};
310
311 Value colTensor = rewriter.create<tensor::EmptyOp>(
312 loc, colTensorShape, inputType.getElementType());
313
314 auto img2ColTensor = rewriter.create<linalg::GenericOp>(
315 loc, colTensor.getType(),
316 /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
317 loopAttributeTypes,
318 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
319 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
320 });
321
322 SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
323 {0, 1}, {2, 3}, {4, 5}};
324 SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
325 SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
326 {2, 3}};
327
328 auto reshapedImg2ColTensorType = RankedTensorType::get(
329 {n * c, oh * ow, fh * fw}, inputType.getElementType());
330 auto reshapedFilterTensorType =
331 RankedTensorType::get({c, fh * fw}, filterType.getElementType());
332 auto reshapedOutputTensorType =
333 RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
334
335 Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
336 loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
337 img2ColTensorReassocIndices);
338 Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
339 loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
340 Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
341 loc, reshapedOutputTensorType, transposedOutputTensor,
342 outputReassociationIndice);
343
344 auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
345 loc, TypeRange{reshapedoutputTensor.getType()},
346 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
347 ValueRange{reshapedoutputTensor});
348
349 SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
350 {2, 3}};
351
352 Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
353 loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
354 batchMatVecReassociationIndice);
355
356 Value transposedResult =
357 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
358
359 rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
360 return std::make_pair(img2ColTensor.getOperation(),
361 transposedResult.getDefiningOp());
362}
363
364FailureOr<std::pair<Operation *, Operation *>>
365rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
366 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
367 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
368 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
369
370 if (!filterType.hasStaticShape())
371 return rewriter.notifyMatchFailure(
372 convOp, "expected a static shape for the filter");
373
374 if (!inputType.hasStaticShape())
375 return rewriter.notifyMatchFailure(convOp,
376 "expected a static shape for the input");
377
378 // TODO: Support dilation.
379 if (!hasAllOneValues(convOp.getDilations()))
380 return rewriter.notifyMatchFailure(convOp,
381 "expected all ones for dilations");
382
383 Value input = convOp.getInputs()[0];
384 Value filter = convOp.getInputs()[1];
385 Value output = convOp.getOutputs()[0];
386
387 auto filterShape = filterType.getShape();
388 auto outputShape = outputType.getShape();
389
390 int64_t n = outputShape[0];
391 int64_t oc = outputShape[1];
392 int64_t oh = outputShape[2];
393 int64_t ow = outputShape[3];
394 int64_t ic = filterShape[1];
395 int64_t fh = filterShape[2];
396 int64_t fw = filterShape[3];
397
398 auto loc = convOp.getLoc();
399 MLIRContext *context = rewriter.getContext();
400
401 SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
402 auto reshapedFilterType =
403 RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
404 Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
405 loc, reshapedFilterType, filter, filterReassocIndices);
406
407 SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
408 auto reshapedOutputType =
409 RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
410 Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
411 loc, reshapedOutputType, output, outputReassocIndices);
412
413 // Convert the input to a (BKN) tensor.
414 SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
415 Value colTensor = rewriter.create<tensor::EmptyOp>(
416 loc, colTensorShape, inputType.getElementType());
417
418 auto nloops = colTensorShape.size();
419
420 auto parallel = utils::IteratorType::parallel;
421 auto reduction = utils::IteratorType::reduction;
422 SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
423
424 SmallVector<AffineMap, 4> img2colIndexingMaps = {
425 AffineMap::getMultiDimIdentityMap(numDims: nloops, context)};
426
427 auto img2ColTensor = rewriter.create<linalg::GenericOp>(
428 loc, colTensor.getType(),
429 /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
430 img2colIterators,
431 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
432 // Get the iterators named based on the matmul (batch, m, k).
433 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
434 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
435 Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
436
437 // Recover the original iteration indices from the problem/input sizes.
438 SmallVector<Value> kIndices = unrollIndex(
439 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
440 auto icIndex = kIndices[0];
441 auto fhIndex = kIndices[1];
442 auto fwIndex = kIndices[2];
443
444 SmallVector<Value> nIndices = unrollIndex(
445 nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
446 auto ohIndex = nIndices[0];
447 auto owIndex = nIndices[1];
448
449 // Extract the input element corresponding to the expanded indices.
450 Value hIndex =
451 getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
452 convOp.getStrides().getValues<int64_t>()[0]);
453 Value wIndex =
454 getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
455 convOp.getStrides().getValues<int64_t>()[1]);
456
457 // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
458 SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
459 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
460 loc, input, extractionIndices);
461 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
462 });
463
464 // Because the filter does not share the same batch dimension,
465 // the batch dimension is only used in indexing the input and output. Thus
466 // we cannot use existing linalg named ops like linalg.batch_matmul.
467 // i.e. M x K * (B x) K x N = (B x) M x N
468 AffineExpr bDim, mDim, nDim, kDim;
469 bindDims(ctx: context, exprs&: bDim, exprs&: mDim, exprs&: nDim, exprs&: kDim);
470 auto lhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {mDim, kDim}, context);
471 auto rhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, kDim, nDim}, context);
472 auto resultMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, nDim}, context);
473 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
474 parallel, reduction};
475 auto genericOp = rewriter.create<linalg::GenericOp>(
476 loc, reshapedOutputType,
477 /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
478 /*outputs=*/ValueRange{reshapedOutput},
479 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
480 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
481 Value mul =
482 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
483 Value add = createAdd(loc, mul, args[2], nestedBuilder);
484 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
485 });
486 Value result = genericOp.getResults().front();
487
488 auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
489 loc, outputType, result, outputReassocIndices);
490
491 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
492
493 return std::make_pair(img2ColTensor.getOperation(),
494 reshapedResult.getOperation());
495}
496
497FailureOr<std::pair<Operation *, Operation *>>
498rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
499 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
500 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
501 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
502
503 if (!filterType.hasStaticShape())
504 return rewriter.notifyMatchFailure(
505 convOp, "expected a static shape for the filter");
506
507 if (!inputType.hasStaticShape())
508 return rewriter.notifyMatchFailure(convOp,
509 "expected a static shape for the input");
510
511 // TODO: Support dilation.
512 if (!hasAllOneValues(convOp.getDilations()))
513 return rewriter.notifyMatchFailure(convOp,
514 "expected all ones for dilations");
515
516 MLIRContext *context = rewriter.getContext();
517 Value input = convOp.getInputs()[0];
518 Value filter = convOp.getInputs()[1];
519 Value output = convOp.getOutputs()[0];
520
521 ArrayRef<int64_t> filterShape = filterType.getShape();
522 ArrayRef<int64_t> outputShape = outputType.getShape();
523
524 int64_t n = outputShape[0];
525 int64_t oh = outputShape[1];
526 int64_t ow = outputShape[2];
527 int64_t oc = outputShape[3];
528 int64_t fh = filterShape[1];
529 int64_t fw = filterShape[2];
530 int64_t ic = filterShape[3];
531
532 Location loc = convOp.getLoc();
533
534 // Reshape output and filter to the LHS and result of a "row-wise" matrix
535 // multiplication.
536 SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
537 auto reshapedFilterType =
538 RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
539 Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
540 loc, reshapedFilterType, filter, filterReassocIndices);
541
542 SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
543 RankedTensorType reshapedOutputType =
544 RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
545 Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
546 loc, reshapedOutputType, output, outputReassocIndices);
547
548 SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
549 Value colTensor = rewriter.create<tensor::EmptyOp>(
550 loc, colTensorShape, inputType.getElementType());
551
552 // Convert the input to a (BMK) column tensor.
553 auto nloops = colTensorShape.size();
554
555 auto parallel = utils::IteratorType::parallel;
556 auto reduction = utils::IteratorType::reduction;
557 SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
558
559 SmallVector<AffineMap> img2colIndexingMaps = {
560 AffineMap::getMultiDimIdentityMap(numDims: nloops, context)};
561
562 auto img2ColTensor = rewriter.create<linalg::GenericOp>(
563 loc, colTensor.getType(),
564 /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
565 img2colIterators,
566 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
567 // Get the iterators named based on the matmul (batch, m, k).
568 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
569 Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
570 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
571
572 // Recover the original iteration indices from the problem/input sizes.
573 SmallVector<Value> mIndices = unrollIndex(
574 nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
575 auto ohIndex = mIndices[0];
576 auto owIndex = mIndices[1];
577
578 SmallVector<Value> kIndices = unrollIndex(
579 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
580 auto fhIndex = kIndices[0];
581 auto fwIndex = kIndices[1];
582 auto icIndex = kIndices[2];
583
584 // Extract the input element corresponding to the expanded indices.
585 Value hIndex =
586 getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
587 convOp.getStrides().getValues<int64_t>()[0]);
588 Value wIndex =
589 getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
590 convOp.getStrides().getValues<int64_t>()[1]);
591
592 // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
593 SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
594 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
595 loc, input, extractionIndices);
596 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
597 });
598
599 // Because we didn't transpose the filters we don't actually have a batched
600 // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
601 // products.
602 AffineExpr bDim, mDim, nDim, kDim;
603 bindDims(ctx: context, exprs&: bDim, exprs&: mDim, exprs&: nDim, exprs&: kDim);
604 auto lhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, kDim}, context);
605 auto rhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {nDim, kDim}, context);
606 auto resultMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, nDim}, context);
607 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
608 parallel, reduction};
609
610 auto genericOp = rewriter.create<linalg::GenericOp>(
611 loc, reshapedOutputType,
612 /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
613 /*outputs=*/ValueRange{reshapedOutput},
614 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
615 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
616 Value mul =
617 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
618 Value add = createAdd(loc, mul, args[2], nestedBuilder);
619 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
620 });
621 Value result = genericOp.getResults().front();
622
623 auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
624 loc, outputType, result, outputReassocIndices);
625
626 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
627
628 return std::make_pair(img2ColTensor.getOperation(),
629 reshapedResult.getOperation());
630}
631
632namespace {
633
634class ConvertConv2DNhwcHwcf final
635 : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
636public:
637 using OpRewritePattern::OpRewritePattern;
638
639 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
640 PatternRewriter &rewriter) const override {
641 if (failed(rewriteInIm2Col(rewriter, convOp)))
642 return failure();
643 return success();
644 }
645};
646
647class ConvertDepthwiseConv2DNhwcHwc final
648 : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
649public:
650 using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
651
652 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
653 PatternRewriter &rewriter) const override {
654 if (failed(rewriteInIm2Col(rewriter, convOp)))
655 return failure();
656 return success();
657 }
658};
659
660class ConvertConv2DNchwFchw final
661 : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
662public:
663 using OpRewritePattern::OpRewritePattern;
664
665 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
666 PatternRewriter &rewriter) const override {
667 if (failed(rewriteInIm2Col(rewriter, convOp)))
668 return failure();
669 return success();
670 }
671};
672
673class ConvertConv2DNhwcFhwc final
674 : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
675public:
676 using OpRewritePattern::OpRewritePattern;
677
678 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
679 PatternRewriter &rewriter) const override {
680 if (failed(rewriteInIm2Col(rewriter, convOp)))
681 return failure();
682 return success();
683 }
684};
685} // end anonymous namespace
686
687void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
688 MLIRContext *context = patterns.getContext();
689 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
690 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(arg&: context);
691}
692} // end namespace linalg
693} // end namespace mlir
694

// Source code of mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp