1//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/Utils.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Complex/IR/Complex.h"
12#include "mlir/Dialect/Linalg/IR/Linalg.h"
13#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include <utility>
23
24namespace mlir {
25namespace linalg {
26static bool hasAllOneValues(DenseIntElementsAttr attr) {
27 return llvm::all_of(
28 Range&: attr, P: [](const APInt &element) { return element.getSExtValue() == 1; });
29}
30
31static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
32 if (isa<IntegerType>(Val: x.getType()))
33 return builder.create<arith::AddIOp>(location: loc, args&: x, args&: y);
34 if (isa<ComplexType>(Val: x.getType()))
35 return builder.create<complex::AddOp>(location: loc, args&: x, args&: y);
36 return builder.create<arith::AddFOp>(location: loc, args&: x, args&: y);
37}
38
39static Value createMul(Location loc, Value x, Value y, Type accType,
40 OpBuilder &builder) {
41 // Linalg named ops specify signed extend for named ops.
42 Value xConvert =
43 convertScalarToDtype(b&: builder, loc, operand: x, toType: accType, /*isUnsignedCast=*/false);
44 Value yConvert =
45 convertScalarToDtype(b&: builder, loc, operand: y, toType: accType, /*isUnsignedCast=*/false);
46 if (isa<ComplexType>(Val: accType))
47 return builder.create<complex::MulOp>(location: loc, args&: xConvert, args&: yConvert);
48 if (isa<IntegerType>(Val: accType))
49 return builder.create<arith::MulIOp>(location: loc, args&: xConvert, args&: yConvert);
50 return builder.create<arith::MulFOp>(location: loc, args&: xConvert, args&: yConvert);
51}
52
53// Delinearizes the given composite `index` by the basis specified in `factors`.
54static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
55 ArrayRef<int64_t> factors) {
56 assert(!factors.empty() && "empty factor list");
57 SmallVector<Value> basis;
58 for (int64_t f : factors)
59 basis.push_back(Elt: b.create<arith::ConstantOp>(location: loc, args: b.getIndexAttr(value: f)));
60 FailureOr<SmallVector<Value>> multiIndex =
61 affine::delinearizeIndex(b, loc, linearIndex: index, basis);
62 assert(!failed(multiIndex) && "Failed to linearize img2col index");
63 return *multiIndex;
64}
65
66// Given indices corresponding to iterators in the output (oIndex) and filter
67// (fIndex) for a convolution, compute the convolved index for the
68// input as `oIndex * stride + fIndex`.
69static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
70 Value fIndex, int64_t stride) {
71 AffineExpr oExpr, fExpr;
72 bindSymbols(ctx: b.getContext(), exprs&: oExpr, exprs&: fExpr);
73 AffineMap convMap = AffineMap::get(dimCount: 0, symbolCount: 2, result: stride * oExpr + fExpr);
74 return affine::makeComposedAffineApply(b, loc, map: convMap, operands: {oIndex, fIndex});
75}
76
/// Rewrites a linalg.conv_2d_nhwc_hwcf as an im2col gather followed by a
/// batched-matmul-shaped linalg.generic:
///   * filter (fh, fw, ic, oc) is collapsed to a (fh*fw*ic, oc) matrix,
///   * the input is gathered into an (n, oh*ow, fh*fw*ic) column tensor,
///   * a 4-loop generic contracts the two into the collapsed (n, oh*ow, oc)
///     output, which is expanded back to NHWC at the end.
/// Requires static input/filter shapes and unit dilations. Returns the
/// (im2col op, final reshape op) pair on success.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // NHWC output dims and HWCF filter dims.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  Location loc = convOp.getLoc();

  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  // Destination for the column tensor, (B, M, K) = (n, oh*ow, fh*fw*ic).
  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. (B x) M x K * K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Multiply-accumulate in the output element type (signed casts).
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Expand the collapsed (n, oh*ow, oc) result back to NHWC.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
211
212FailureOr<std::pair<Operation *, Operation *>>
213rewriteInIm2Col(RewriterBase &rewriter,
214 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
215 auto inputType = cast<RankedTensorType>(Val: convOp.getInputs()[0].getType());
216 auto filterType = cast<RankedTensorType>(Val: convOp.getInputs()[1].getType());
217 auto outputType = cast<RankedTensorType>(Val: convOp.getOutputs()[0].getType());
218
219 if (!filterType.hasStaticShape())
220 return rewriter.notifyMatchFailure(
221 arg&: convOp, msg: "expected a static shape for the filter");
222
223 if (!inputType.hasStaticShape())
224 return rewriter.notifyMatchFailure(arg&: convOp,
225 msg: "expected a static shape for the input");
226
227 // TODO: Support dilation.
228 if (!hasAllOneValues(attr: convOp.getDilations()))
229 return rewriter.notifyMatchFailure(arg&: convOp,
230 msg: "expected all ones for dilations");
231
232 Location loc = convOp.getLoc();
233
234 auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
235 auto operandTensorType = cast<RankedTensorType>(Val: operand.getType());
236 auto nloops = indices.size();
237 ArrayRef<int64_t> inputShape = operandTensorType.getShape();
238
239 SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
240 Range: llvm::map_range(C&: indices, F: [&](int64_t index) -> AffineExpr {
241 return rewriter.getAffineDimExpr(position: index);
242 }));
243
244 SmallVector<int64_t> targetShape = llvm::to_vector<4>(Range: llvm::map_range(
245 C&: indices, F: [&](int64_t index) -> int64_t { return inputShape[index]; }));
246
247 Value outputTensor = rewriter.create<tensor::EmptyOp>(
248 location: loc, args&: targetShape, args: operandTensorType.getElementType());
249
250 SmallVector<utils::IteratorType> loopAttributeTypes(
251 nloops, utils::IteratorType::parallel);
252
253 SmallVector<AffineMap> indexingMaps = {
254 inversePermutation(
255 map: AffineMap::get(dimCount: nloops, symbolCount: 0, results: exprs, context: rewriter.getContext())),
256 AffineMap::getMultiDimIdentityMap(numDims: nloops, context: rewriter.getContext())};
257
258 auto transposedOp = rewriter.create<linalg::GenericOp>(
259 location: loc, args: outputTensor.getType(),
260 /*inputs=*/args&: operand, /*outputs=*/args&: outputTensor, args&: indexingMaps,
261 args&: loopAttributeTypes,
262 args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
263 nestedBuilder.create<linalg::YieldOp>(location: nestedLoc, args: args[0]);
264 });
265
266 return transposedOp.getResult(i: 0);
267 };
268
269 Value input = convOp.getInputs()[0];
270 Value filter = convOp.getInputs()[1];
271 Value output = convOp.getOutputs()[0];
272
273 // Transpose input, filter so channels are outermost
274 Value inputT = transposeOperand(input, {0, 3, 1, 2});
275 Value filterT = transposeOperand(filter, {2, 0, 1});
276 ArrayRef<int64_t> filterTShape =
277 cast<RankedTensorType>(Val: filterT.getType()).getShape();
278 ArrayRef<int64_t> outputShape = outputType.getShape();
279
280 int n = outputShape[0];
281 int oh = outputShape[1];
282 int ow = outputShape[2];
283 int c = outputShape[3];
284 int fh = filterTShape[1];
285 int fw = filterTShape[2];
286
287 SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
288 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
289
290 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
291 bindDims(ctx: rewriter.getContext(), exprs&: nDim, exprs&: cDim, exprs&: ohDim, exprs&: owDim, exprs&: khDim, exprs&: kwDim);
292
293 AffineExpr shSym = rewriter.getAffineConstantExpr(
294 constant: convOp.getStrides().getValues<int64_t>()[0]);
295 AffineExpr swSym = rewriter.getAffineConstantExpr(
296 constant: convOp.getStrides().getValues<int64_t>()[1]);
297
298 SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
299 owDim * swSym + kwDim};
300
301 auto nloops = colTensorShape.size();
302
303 SmallVector<utils::IteratorType> loopAttributeTypes(
304 nloops, utils::IteratorType::parallel);
305
306 SmallVector<AffineMap> indexingMaps = {
307 AffineMap::get(dimCount: nloops, symbolCount: 0, results: inputExprs, context: rewriter.getContext()),
308 AffineMap::getMultiDimIdentityMap(numDims: nloops, context: rewriter.getContext())};
309
310 Value colTensor = rewriter.create<tensor::EmptyOp>(
311 location: loc, args&: colTensorShape, args: inputType.getElementType());
312
313 auto img2ColTensor = rewriter.create<linalg::GenericOp>(
314 location: loc, args: colTensor.getType(),
315 /*inputs=*/args&: inputT, /*outputs=*/args&: colTensor, args&: indexingMaps,
316 args&: loopAttributeTypes,
317 args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
318 nestedBuilder.create<linalg::YieldOp>(location: nestedLoc, args: args[0]);
319 });
320
321 SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
322 {0, 1}, {2, 3}, {4, 5}};
323 SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
324 SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
325 {2, 3}};
326
327 auto reshapedImg2ColTensorType = RankedTensorType::get(
328 shape: {n * c, oh * ow, fh * fw}, elementType: inputType.getElementType());
329 auto reshapedFilterTensorType =
330 RankedTensorType::get(shape: {c, fh * fw}, elementType: filterType.getElementType());
331 auto reshapedOutputTensorType =
332 RankedTensorType::get(shape: {n * c, oh * ow}, elementType: outputType.getElementType());
333
334 Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
335 location: loc, args&: reshapedImg2ColTensorType, args: img2ColTensor.getResult(i: 0),
336 args&: img2ColTensorReassocIndices);
337 Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
338 location: loc, args&: reshapedFilterTensorType, args&: filterT, args&: filterReassociationIndice);
339 Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
340 location: loc, args&: reshapedOutputTensorType, args&: transposedOutputTensor,
341 args&: outputReassociationIndice);
342
343 auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
344 location: loc, args: TypeRange{reshapedoutputTensor.getType()},
345 args: ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
346 args: ValueRange{reshapedoutputTensor});
347
348 SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
349 {2, 3}};
350
351 auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
352 location: loc, args: transposedOutputTensor.getType(), args: batchMatVecResult.getResult(i: 0),
353 args&: batchMatVecReassociationIndice);
354
355 Value transposedResult =
356 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
357
358 rewriter.replaceOp(op: convOp, newValues: ArrayRef<Value>{transposedResult});
359 return std::make_pair(x: img2ColTensor.getOperation(),
360 y: transposedResult.getDefiningOp());
361}
362
363FailureOr<std::pair<Operation *, Operation *>>
364rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
365 auto inputType = cast<ShapedType>(Val: convOp.getInputs()[0].getType());
366 auto filterType = cast<ShapedType>(Val: convOp.getInputs()[1].getType());
367 auto outputType = cast<ShapedType>(Val: convOp.getOutputs()[0].getType());
368
369 if (!filterType.hasStaticShape())
370 return rewriter.notifyMatchFailure(
371 arg&: convOp, msg: "expected a static shape for the filter");
372
373 if (!inputType.hasStaticShape())
374 return rewriter.notifyMatchFailure(arg&: convOp,
375 msg: "expected a static shape for the input");
376
377 // TODO: Support dilation.
378 if (!hasAllOneValues(attr: convOp.getDilations()))
379 return rewriter.notifyMatchFailure(arg&: convOp,
380 msg: "expected all ones for dilations");
381
382 Value input = convOp.getInputs()[0];
383 Value filter = convOp.getInputs()[1];
384 Value output = convOp.getOutputs()[0];
385
386 auto filterShape = filterType.getShape();
387 auto outputShape = outputType.getShape();
388
389 int64_t n = outputShape[0];
390 int64_t oc = outputShape[1];
391 int64_t oh = outputShape[2];
392 int64_t ow = outputShape[3];
393 int64_t ic = filterShape[1];
394 int64_t fh = filterShape[2];
395 int64_t fw = filterShape[3];
396
397 auto loc = convOp.getLoc();
398 MLIRContext *context = rewriter.getContext();
399
400 SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
401 auto reshapedFilterType =
402 RankedTensorType::get(shape: {oc, ic * fh * fw}, elementType: inputType.getElementType());
403 Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
404 location: loc, args&: reshapedFilterType, args&: filter, args&: filterReassocIndices);
405
406 SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
407 auto reshapedOutputType =
408 RankedTensorType::get(shape: {n, oc, oh * ow}, elementType: outputType.getElementType());
409 Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
410 location: loc, args&: reshapedOutputType, args&: output, args&: outputReassocIndices);
411
412 // Convert the input to a (BKN) tensor.
413 SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
414 Value colTensor = rewriter.create<tensor::EmptyOp>(
415 location: loc, args&: colTensorShape, args: inputType.getElementType());
416
417 auto nloops = colTensorShape.size();
418
419 auto parallel = utils::IteratorType::parallel;
420 auto reduction = utils::IteratorType::reduction;
421 SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
422
423 SmallVector<AffineMap, 4> img2colIndexingMaps = {
424 AffineMap::getMultiDimIdentityMap(numDims: nloops, context)};
425
426 auto img2ColTensor = rewriter.create<linalg::GenericOp>(
427 location: loc, args: colTensor.getType(),
428 /*inputs=*/args: ValueRange{}, /*outputs=*/args&: colTensor, args&: img2colIndexingMaps,
429 args&: img2colIterators,
430 args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
431 // Get the iterators named based on the matmul (batch, m, k).
432 Value bIndex = nestedBuilder.create<linalg::IndexOp>(location: loc, args: 0);
433 Value kIndex = nestedBuilder.create<linalg::IndexOp>(location: loc, args: 1);
434 Value nIndex = nestedBuilder.create<linalg::IndexOp>(location: loc, args: 2);
435
436 // Recover the original iteration indices from the problem/input sizes.
437 SmallVector<Value> kIndices = unrollIndex(
438 b&: nestedBuilder, loc: nestedLoc, index: kIndex, factors: ArrayRef<int64_t>{ic, fh, fw});
439 auto icIndex = kIndices[0];
440 auto fhIndex = kIndices[1];
441 auto fwIndex = kIndices[2];
442
443 SmallVector<Value> nIndices = unrollIndex(
444 b&: nestedBuilder, loc: nestedLoc, index: nIndex, factors: ArrayRef<int64_t>{oh, ow});
445 auto ohIndex = nIndices[0];
446 auto owIndex = nIndices[1];
447
448 // Extract the input element corresponding to the expanded indices.
449 Value hIndex =
450 getConvolvedIndex(b&: nestedBuilder, loc: nestedLoc, oIndex: ohIndex, fIndex: fhIndex,
451 stride: convOp.getStrides().getValues<int64_t>()[0]);
452 Value wIndex =
453 getConvolvedIndex(b&: nestedBuilder, loc: nestedLoc, oIndex: owIndex, fIndex: fwIndex,
454 stride: convOp.getStrides().getValues<int64_t>()[1]);
455
456 // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
457 SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
458 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
459 location: loc, args&: input, args&: extractionIndices);
460 nestedBuilder.create<linalg::YieldOp>(location: nestedLoc, args&: inputVal);
461 });
462
463 // Because the filter does not share the same batch dimension,
464 // the batch dimension is only used in indexing the input and output. Thus
465 // we cannot use existing linalg named ops like linalg.batch_matmul.
466 // i.e. M x K * (B x) K x N = (B x) M x N
467 AffineExpr bDim, mDim, nDim, kDim;
468 bindDims(ctx: context, exprs&: bDim, exprs&: mDim, exprs&: nDim, exprs&: kDim);
469 auto lhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {mDim, kDim}, context);
470 auto rhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, kDim, nDim}, context);
471 auto resultMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, nDim}, context);
472 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
473 parallel, reduction};
474 auto genericOp = rewriter.create<linalg::GenericOp>(
475 location: loc, args&: reshapedOutputType,
476 /*inputs=*/args: ValueRange{reshapedFilter, img2ColTensor.getResult(i: 0)},
477 /*outputs=*/args: ValueRange{reshapedOutput},
478 args: ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, args&: genericIterators,
479 args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
480 Value mul =
481 createMul(loc, x: args[0], y: args[1], accType: args[2].getType(), builder&: nestedBuilder);
482 Value add = createAdd(loc, x: mul, y: args[2], builder&: nestedBuilder);
483 nestedBuilder.create<linalg::YieldOp>(location: nestedLoc, args&: add);
484 });
485 Value result = genericOp.getResults().front();
486
487 auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
488 location: loc, args&: outputType, args&: result, args&: outputReassocIndices);
489
490 rewriter.replaceOp(op: convOp, newValues: ArrayRef<Value>{reshapedResult});
491
492 return std::make_pair(x: img2ColTensor.getOperation(),
493 y: reshapedResult.getOperation());
494}
495
/// Rewrites a linalg.conv_2d_nhwc_fhwc as an im2col gather followed by a
/// "row-wise dot product" linalg.generic:
///   * filter (oc, fh, fw, ic) is collapsed to an (oc, fh*fw*ic) matrix
///     (note: NOT transposed, unlike the HWCF path),
///   * the input is gathered into an (n, oh*ow, fh*fw*ic) column tensor,
///   * a 4-loop generic contracts the two into the collapsed (n, oh*ow, oc)
///     output using an (nDim, kDim) RHS map, then expands back to NHWC.
/// Requires static shapes and unit dilations. Returns the (im2col op,
/// final reshape op) pair on success.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // NHWC output dims and FHWC filter dims.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  // Reshape output and filter to the LHS and result of a "row-wise" matrix
  // multiplication.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  // Destination for the column tensor, (B, M, K) = (n, oh*ow, fh*fw*ic).
  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
  // products.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Multiply-accumulate in the output element type (signed casts).
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Expand the collapsed (n, oh*ow, oc) result back to NHWC.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
630
631namespace {
632
633class ConvertConv2DNhwcHwcf final
634 : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
635public:
636 using OpRewritePattern::OpRewritePattern;
637
638 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
639 PatternRewriter &rewriter) const override {
640 if (failed(Result: rewriteInIm2Col(rewriter, convOp)))
641 return failure();
642 return success();
643 }
644};
645
646class ConvertDepthwiseConv2DNhwcHwc final
647 : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
648public:
649 using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
650
651 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
652 PatternRewriter &rewriter) const override {
653 if (failed(Result: rewriteInIm2Col(rewriter, convOp)))
654 return failure();
655 return success();
656 }
657};
658
659class ConvertConv2DNchwFchw final
660 : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
661public:
662 using OpRewritePattern::OpRewritePattern;
663
664 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
665 PatternRewriter &rewriter) const override {
666 if (failed(Result: rewriteInIm2Col(rewriter, convOp)))
667 return failure();
668 return success();
669 }
670};
671
672class ConvertConv2DNhwcFhwc final
673 : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
674public:
675 using OpRewritePattern::OpRewritePattern;
676
677 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
678 PatternRewriter &rewriter) const override {
679 if (failed(Result: rewriteInIm2Col(rewriter, convOp)))
680 return failure();
681 return success();
682 }
683};
684} // end anonymous namespace
685
686void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
687 MLIRContext *context = patterns.getContext();
688 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
689 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(arg&: context);
690}
691} // end namespace linalg
692} // end namespace mlir
693

// source code of mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp