1 | //===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Affine/Utils.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/Dialect/Complex/IR/Complex.h" |
12 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
13 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
14 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
15 | #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
16 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
17 | #include "mlir/IR/AffineExpr.h" |
18 | #include "mlir/IR/AffineMap.h" |
19 | #include "mlir/IR/Builders.h" |
20 | #include "mlir/IR/BuiltinAttributes.h" |
21 | #include "mlir/IR/BuiltinTypes.h" |
22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
23 | #include <utility> |
24 | |
25 | namespace mlir { |
26 | namespace linalg { |
27 | static bool hasAllOneValues(DenseIntElementsAttr attr) { |
28 | return llvm::all_of( |
29 | Range&: attr, P: [](const APInt &element) { return element.getSExtValue() == 1; }); |
30 | } |
31 | |
32 | static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { |
33 | if (isa<IntegerType>(x.getType())) |
34 | return builder.create<arith::AddIOp>(loc, x, y); |
35 | if (isa<ComplexType>(x.getType())) |
36 | return builder.create<complex::AddOp>(loc, x, y); |
37 | return builder.create<arith::AddFOp>(loc, x, y); |
38 | } |
39 | |
40 | static Value createMul(Location loc, Value x, Value y, Type accType, |
41 | OpBuilder &builder) { |
42 | // Linalg named ops specify signed extend for named ops. |
43 | Value xConvert = |
44 | convertScalarToDtype(b&: builder, loc, operand: x, toType: accType, /*isUnsignedCast=*/false); |
45 | Value yConvert = |
46 | convertScalarToDtype(b&: builder, loc, operand: y, toType: accType, /*isUnsignedCast=*/false); |
47 | if (isa<ComplexType>(accType)) |
48 | return builder.create<complex::MulOp>(loc, xConvert, yConvert); |
49 | if (isa<IntegerType>(accType)) |
50 | return builder.create<arith::MulIOp>(loc, xConvert, yConvert); |
51 | return builder.create<arith::MulFOp>(loc, xConvert, yConvert); |
52 | } |
53 | |
54 | // Delinearizes the given composite `index` by the basis specified in `factors`. |
55 | static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index, |
56 | ArrayRef<int64_t> factors) { |
57 | assert(!factors.empty() && "empty factor list" ); |
58 | SmallVector<Value> basis; |
59 | for (int64_t f : factors) |
60 | basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f))); |
61 | FailureOr<SmallVector<Value>> multiIndex = |
62 | affine::delinearizeIndex(b, loc, linearIndex: index, basis); |
63 | assert(!failed(multiIndex) && "Failed to linearize img2col index" ); |
64 | return *multiIndex; |
65 | } |
66 | |
67 | // Given indices corresponding to iterators in the output (oIndex) and filter |
68 | // (fIndex) for a convolution, compute the convolved index for the |
69 | // input as `oIndex * stride + fIndex`. |
70 | static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, |
71 | Value fIndex, int64_t stride) { |
72 | AffineExpr oExpr, fExpr; |
73 | bindSymbols(ctx: b.getContext(), exprs&: oExpr, exprs&: fExpr); |
74 | AffineMap convMap = AffineMap::get(dimCount: 0, symbolCount: 2, result: stride * oExpr + fExpr); |
75 | return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex}); |
76 | } |
77 | |
78 | FailureOr<std::pair<Operation *, Operation *>> |
79 | rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { |
80 | auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType()); |
81 | auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType()); |
82 | auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType()); |
83 | |
84 | if (!filterType.hasStaticShape()) |
85 | return rewriter.notifyMatchFailure( |
86 | convOp, "expected a static shape for the filter" ); |
87 | |
88 | if (!inputType.hasStaticShape()) |
89 | return rewriter.notifyMatchFailure(convOp, |
90 | "expected a static shape for the input" ); |
91 | |
92 | // TODO: Support dilation. |
93 | if (!hasAllOneValues(convOp.getDilations())) |
94 | return rewriter.notifyMatchFailure(convOp, |
95 | "expected all ones for dilations" ); |
96 | |
97 | MLIRContext *context = rewriter.getContext(); |
98 | Value input = convOp.getInputs()[0]; |
99 | Value filter = convOp.getInputs()[1]; |
100 | Value output = convOp.getOutputs()[0]; |
101 | |
102 | ArrayRef<int64_t> filterShape = filterType.getShape(); |
103 | ArrayRef<int64_t> outputShape = outputType.getShape(); |
104 | |
105 | int64_t n = outputShape[0]; |
106 | int64_t oh = outputShape[1]; |
107 | int64_t ow = outputShape[2]; |
108 | int64_t oc = outputShape[3]; |
109 | int64_t fh = filterShape[0]; |
110 | int64_t fw = filterShape[1]; |
111 | int64_t ic = filterShape[2]; |
112 | |
113 | Location loc = convOp.getLoc(); |
114 | |
115 | // Reshape output and filter to the LHS and result of a (B)MNK matmul. |
116 | SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}}; |
117 | auto reshapedFilterType = |
118 | RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType()); |
119 | Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>( |
120 | loc, reshapedFilterType, filter, filterReassocIndices); |
121 | |
122 | SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}}; |
123 | RankedTensorType reshapedOutputType = |
124 | RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()); |
125 | Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>( |
126 | loc, reshapedOutputType, output, outputReassocIndices); |
127 | |
128 | SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic}; |
129 | Value colTensor = rewriter.create<tensor::EmptyOp>( |
130 | loc, colTensorShape, inputType.getElementType()); |
131 | |
132 | // Convert the input to a (BMK) column tensor. |
133 | auto nloops = colTensorShape.size(); |
134 | |
135 | auto parallel = utils::IteratorType::parallel; |
136 | auto reduction = utils::IteratorType::reduction; |
137 | SmallVector<utils::IteratorType> img2colIterators(nloops, parallel); |
138 | |
139 | SmallVector<AffineMap> img2colIndexingMaps = { |
140 | AffineMap::getMultiDimIdentityMap(numDims: nloops, context)}; |
141 | |
142 | auto img2ColTensor = rewriter.create<linalg::GenericOp>( |
143 | loc, colTensor.getType(), |
144 | /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, |
145 | img2colIterators, |
146 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
147 | // Get the iterators named based on the matmul (batch, m, k). |
148 | Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0); |
149 | Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1); |
150 | Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2); |
151 | |
152 | // Recover the original iteration indices from the problem/input sizes. |
153 | SmallVector<Value> mIndices = unrollIndex( |
154 | nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow}); |
155 | auto ohIndex = mIndices[0]; |
156 | auto owIndex = mIndices[1]; |
157 | |
158 | SmallVector<Value> kIndices = unrollIndex( |
159 | nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic}); |
160 | auto fhIndex = kIndices[0]; |
161 | auto fwIndex = kIndices[1]; |
162 | auto icIndex = kIndices[2]; |
163 | |
164 | // Extract the input element corresponding to the expanded indices. |
165 | Value hIndex = |
166 | getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, |
167 | convOp.getStrides().getValues<int64_t>()[0]); |
168 | Value wIndex = |
169 | getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, |
170 | convOp.getStrides().getValues<int64_t>()[1]); |
171 | |
172 | // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] |
173 | SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex}; |
174 | Value inputVal = nestedBuilder.create<tensor::ExtractOp>( |
175 | loc, input, extractionIndices); |
176 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal); |
177 | }); |
178 | |
179 | // Because the filter does not share the same batch dimension, |
180 | // the batch dimension is only used in indexing the input and output. Thus |
181 | // we cannot use existing linalg named ops like linalg.batch_matmul. |
182 | // i.e. (B x) M x K * K x N = (B x) M x N |
183 | AffineExpr bDim, mDim, nDim, kDim; |
184 | bindDims(ctx: context, exprs&: bDim, exprs&: mDim, exprs&: nDim, exprs&: kDim); |
185 | auto lhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, kDim}, context); |
186 | auto rhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {kDim, nDim}, context); |
187 | auto resultMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, nDim}, context); |
188 | SmallVector<utils::IteratorType> genericIterators = {parallel, parallel, |
189 | parallel, reduction}; |
190 | |
191 | auto genericOp = rewriter.create<linalg::GenericOp>( |
192 | loc, reshapedOutputType, |
193 | /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter}, |
194 | /*outputs=*/ValueRange{reshapedOutput}, |
195 | ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators, |
196 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
197 | Value mul = |
198 | createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); |
199 | Value add = createAdd(loc, mul, args[2], nestedBuilder); |
200 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, add); |
201 | }); |
202 | Value result = genericOp.getResults().front(); |
203 | |
204 | auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>( |
205 | loc, outputType, result, outputReassocIndices); |
206 | |
207 | rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult}); |
208 | |
209 | return std::make_pair(img2ColTensor.getOperation(), |
210 | reshapedResult.getOperation()); |
211 | } |
212 | |
213 | FailureOr<std::pair<Operation *, Operation *>> |
214 | rewriteInIm2Col(RewriterBase &rewriter, |
215 | linalg::DepthwiseConv2DNhwcHwcOp convOp) { |
216 | auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType()); |
217 | auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType()); |
218 | auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType()); |
219 | |
220 | if (!filterType.hasStaticShape()) |
221 | return rewriter.notifyMatchFailure( |
222 | convOp, "expected a static shape for the filter" ); |
223 | |
224 | if (!inputType.hasStaticShape()) |
225 | return rewriter.notifyMatchFailure(convOp, |
226 | "expected a static shape for the input" ); |
227 | |
228 | // TODO: Support dilation. |
229 | if (!hasAllOneValues(convOp.getDilations())) |
230 | return rewriter.notifyMatchFailure(convOp, |
231 | "expected all ones for dilations" ); |
232 | |
233 | Location loc = convOp.getLoc(); |
234 | |
235 | auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) { |
236 | auto operandTensorType = cast<RankedTensorType>(operand.getType()); |
237 | auto nloops = indices.size(); |
238 | ArrayRef<int64_t> inputShape = operandTensorType.getShape(); |
239 | |
240 | SmallVector<AffineExpr> exprs = llvm::to_vector<4>( |
241 | Range: llvm::map_range(C&: indices, F: [&](int64_t index) -> AffineExpr { |
242 | return rewriter.getAffineDimExpr(position: index); |
243 | })); |
244 | |
245 | SmallVector<int64_t> targetShape = llvm::to_vector<4>(Range: llvm::map_range( |
246 | C&: indices, F: [&](int64_t index) -> int64_t { return inputShape[index]; })); |
247 | |
248 | Value outputTensor = rewriter.create<tensor::EmptyOp>( |
249 | loc, targetShape, operandTensorType.getElementType()); |
250 | |
251 | SmallVector<utils::IteratorType> loopAttributeTypes( |
252 | nloops, utils::IteratorType::parallel); |
253 | |
254 | SmallVector<AffineMap> indexingMaps = { |
255 | inversePermutation( |
256 | map: AffineMap::get(dimCount: nloops, symbolCount: 0, results: exprs, context: rewriter.getContext())), |
257 | AffineMap::getMultiDimIdentityMap(numDims: nloops, context: rewriter.getContext())}; |
258 | |
259 | auto transposedOp = rewriter.create<linalg::GenericOp>( |
260 | loc, outputTensor.getType(), |
261 | /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps, |
262 | loopAttributeTypes, |
263 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
264 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]); |
265 | }); |
266 | |
267 | return transposedOp.getResult(0); |
268 | }; |
269 | |
270 | Value input = convOp.getInputs()[0]; |
271 | Value filter = convOp.getInputs()[1]; |
272 | Value output = convOp.getOutputs()[0]; |
273 | |
274 | // Transpose input, filter so channels are outermost |
275 | Value inputT = transposeOperand(input, {0, 3, 1, 2}); |
276 | Value filterT = transposeOperand(filter, {2, 0, 1}); |
277 | ArrayRef<int64_t> filterTShape = |
278 | cast<RankedTensorType>(filterT.getType()).getShape(); |
279 | ArrayRef<int64_t> outputShape = outputType.getShape(); |
280 | |
281 | int n = outputShape[0]; |
282 | int oh = outputShape[1]; |
283 | int ow = outputShape[2]; |
284 | int c = outputShape[3]; |
285 | int fh = filterTShape[1]; |
286 | int fw = filterTShape[2]; |
287 | |
288 | SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw}; |
289 | Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2}); |
290 | |
291 | AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim; |
292 | bindDims(ctx: rewriter.getContext(), exprs&: nDim, exprs&: cDim, exprs&: ohDim, exprs&: owDim, exprs&: khDim, exprs&: kwDim); |
293 | |
294 | AffineExpr shSym = rewriter.getAffineConstantExpr( |
295 | constant: convOp.getStrides().getValues<int64_t>()[0]); |
296 | AffineExpr swSym = rewriter.getAffineConstantExpr( |
297 | constant: convOp.getStrides().getValues<int64_t>()[1]); |
298 | |
299 | SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim, |
300 | owDim * swSym + kwDim}; |
301 | |
302 | auto nloops = colTensorShape.size(); |
303 | |
304 | SmallVector<utils::IteratorType> loopAttributeTypes( |
305 | nloops, utils::IteratorType::parallel); |
306 | |
307 | SmallVector<AffineMap> indexingMaps = { |
308 | AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()), |
309 | AffineMap::getMultiDimIdentityMap(numDims: nloops, context: rewriter.getContext())}; |
310 | |
311 | Value colTensor = rewriter.create<tensor::EmptyOp>( |
312 | loc, colTensorShape, inputType.getElementType()); |
313 | |
314 | auto img2ColTensor = rewriter.create<linalg::GenericOp>( |
315 | loc, colTensor.getType(), |
316 | /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps, |
317 | loopAttributeTypes, |
318 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
319 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]); |
320 | }); |
321 | |
322 | SmallVector<ReassociationIndices> img2ColTensorReassocIndices = { |
323 | {0, 1}, {2, 3}, {4, 5}}; |
324 | SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}}; |
325 | SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1}, |
326 | {2, 3}}; |
327 | |
328 | auto reshapedImg2ColTensorType = RankedTensorType::get( |
329 | {n * c, oh * ow, fh * fw}, inputType.getElementType()); |
330 | auto reshapedFilterTensorType = |
331 | RankedTensorType::get({c, fh * fw}, filterType.getElementType()); |
332 | auto reshapedOutputTensorType = |
333 | RankedTensorType::get({n * c, oh * ow}, outputType.getElementType()); |
334 | |
335 | Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>( |
336 | loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0), |
337 | img2ColTensorReassocIndices); |
338 | Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>( |
339 | loc, reshapedFilterTensorType, filterT, filterReassociationIndice); |
340 | Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>( |
341 | loc, reshapedOutputTensorType, transposedOutputTensor, |
342 | outputReassociationIndice); |
343 | |
344 | auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>( |
345 | loc, TypeRange{reshapedoutputTensor.getType()}, |
346 | ValueRange{reshapedImg2ColTensor, reshapedFilterTensor}, |
347 | ValueRange{reshapedoutputTensor}); |
348 | |
349 | SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1}, |
350 | {2, 3}}; |
351 | |
352 | Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>( |
353 | loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0), |
354 | batchMatVecReassociationIndice); |
355 | |
356 | Value transposedResult = |
357 | transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1}); |
358 | |
359 | rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult}); |
360 | return std::make_pair(img2ColTensor.getOperation(), |
361 | transposedResult.getDefiningOp()); |
362 | } |
363 | |
364 | FailureOr<std::pair<Operation *, Operation *>> |
365 | rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { |
366 | auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType()); |
367 | auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType()); |
368 | auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType()); |
369 | |
370 | if (!filterType.hasStaticShape()) |
371 | return rewriter.notifyMatchFailure( |
372 | convOp, "expected a static shape for the filter" ); |
373 | |
374 | if (!inputType.hasStaticShape()) |
375 | return rewriter.notifyMatchFailure(convOp, |
376 | "expected a static shape for the input" ); |
377 | |
378 | // TODO: Support dilation. |
379 | if (!hasAllOneValues(convOp.getDilations())) |
380 | return rewriter.notifyMatchFailure(convOp, |
381 | "expected all ones for dilations" ); |
382 | |
383 | Value input = convOp.getInputs()[0]; |
384 | Value filter = convOp.getInputs()[1]; |
385 | Value output = convOp.getOutputs()[0]; |
386 | |
387 | auto filterShape = filterType.getShape(); |
388 | auto outputShape = outputType.getShape(); |
389 | |
390 | int64_t n = outputShape[0]; |
391 | int64_t oc = outputShape[1]; |
392 | int64_t oh = outputShape[2]; |
393 | int64_t ow = outputShape[3]; |
394 | int64_t ic = filterShape[1]; |
395 | int64_t fh = filterShape[2]; |
396 | int64_t fw = filterShape[3]; |
397 | |
398 | auto loc = convOp.getLoc(); |
399 | MLIRContext *context = rewriter.getContext(); |
400 | |
401 | SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}}; |
402 | auto reshapedFilterType = |
403 | RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType()); |
404 | Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>( |
405 | loc, reshapedFilterType, filter, filterReassocIndices); |
406 | |
407 | SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}}; |
408 | auto reshapedOutputType = |
409 | RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType()); |
410 | Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>( |
411 | loc, reshapedOutputType, output, outputReassocIndices); |
412 | |
413 | // Convert the input to a (BKN) tensor. |
414 | SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow}; |
415 | Value colTensor = rewriter.create<tensor::EmptyOp>( |
416 | loc, colTensorShape, inputType.getElementType()); |
417 | |
418 | auto nloops = colTensorShape.size(); |
419 | |
420 | auto parallel = utils::IteratorType::parallel; |
421 | auto reduction = utils::IteratorType::reduction; |
422 | SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel); |
423 | |
424 | SmallVector<AffineMap, 4> img2colIndexingMaps = { |
425 | AffineMap::getMultiDimIdentityMap(numDims: nloops, context)}; |
426 | |
427 | auto img2ColTensor = rewriter.create<linalg::GenericOp>( |
428 | loc, colTensor.getType(), |
429 | /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, |
430 | img2colIterators, |
431 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
432 | // Get the iterators named based on the matmul (batch, m, k). |
433 | Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0); |
434 | Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1); |
435 | Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2); |
436 | |
437 | // Recover the original iteration indices from the problem/input sizes. |
438 | SmallVector<Value> kIndices = unrollIndex( |
439 | nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw}); |
440 | auto icIndex = kIndices[0]; |
441 | auto fhIndex = kIndices[1]; |
442 | auto fwIndex = kIndices[2]; |
443 | |
444 | SmallVector<Value> nIndices = unrollIndex( |
445 | nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow}); |
446 | auto ohIndex = nIndices[0]; |
447 | auto owIndex = nIndices[1]; |
448 | |
449 | // Extract the input element corresponding to the expanded indices. |
450 | Value hIndex = |
451 | getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, |
452 | convOp.getStrides().getValues<int64_t>()[0]); |
453 | Value wIndex = |
454 | getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, |
455 | convOp.getStrides().getValues<int64_t>()[1]); |
456 | |
457 | // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw] |
458 | SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex}; |
459 | Value inputVal = nestedBuilder.create<tensor::ExtractOp>( |
460 | loc, input, extractionIndices); |
461 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal); |
462 | }); |
463 | |
464 | // Because the filter does not share the same batch dimension, |
465 | // the batch dimension is only used in indexing the input and output. Thus |
466 | // we cannot use existing linalg named ops like linalg.batch_matmul. |
467 | // i.e. M x K * (B x) K x N = (B x) M x N |
468 | AffineExpr bDim, mDim, nDim, kDim; |
469 | bindDims(ctx: context, exprs&: bDim, exprs&: mDim, exprs&: nDim, exprs&: kDim); |
470 | auto lhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {mDim, kDim}, context); |
471 | auto rhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, kDim, nDim}, context); |
472 | auto resultMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, nDim}, context); |
473 | SmallVector<utils::IteratorType> genericIterators = {parallel, parallel, |
474 | parallel, reduction}; |
475 | auto genericOp = rewriter.create<linalg::GenericOp>( |
476 | loc, reshapedOutputType, |
477 | /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)}, |
478 | /*outputs=*/ValueRange{reshapedOutput}, |
479 | ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators, |
480 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
481 | Value mul = |
482 | createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); |
483 | Value add = createAdd(loc, mul, args[2], nestedBuilder); |
484 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, add); |
485 | }); |
486 | Value result = genericOp.getResults().front(); |
487 | |
488 | auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>( |
489 | loc, outputType, result, outputReassocIndices); |
490 | |
491 | rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult}); |
492 | |
493 | return std::make_pair(img2ColTensor.getOperation(), |
494 | reshapedResult.getOperation()); |
495 | } |
496 | |
497 | FailureOr<std::pair<Operation *, Operation *>> |
498 | rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { |
499 | auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType()); |
500 | auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType()); |
501 | auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType()); |
502 | |
503 | if (!filterType.hasStaticShape()) |
504 | return rewriter.notifyMatchFailure( |
505 | convOp, "expected a static shape for the filter" ); |
506 | |
507 | if (!inputType.hasStaticShape()) |
508 | return rewriter.notifyMatchFailure(convOp, |
509 | "expected a static shape for the input" ); |
510 | |
511 | // TODO: Support dilation. |
512 | if (!hasAllOneValues(convOp.getDilations())) |
513 | return rewriter.notifyMatchFailure(convOp, |
514 | "expected all ones for dilations" ); |
515 | |
516 | MLIRContext *context = rewriter.getContext(); |
517 | Value input = convOp.getInputs()[0]; |
518 | Value filter = convOp.getInputs()[1]; |
519 | Value output = convOp.getOutputs()[0]; |
520 | |
521 | ArrayRef<int64_t> filterShape = filterType.getShape(); |
522 | ArrayRef<int64_t> outputShape = outputType.getShape(); |
523 | |
524 | int64_t n = outputShape[0]; |
525 | int64_t oh = outputShape[1]; |
526 | int64_t ow = outputShape[2]; |
527 | int64_t oc = outputShape[3]; |
528 | int64_t fh = filterShape[1]; |
529 | int64_t fw = filterShape[2]; |
530 | int64_t ic = filterShape[3]; |
531 | |
532 | Location loc = convOp.getLoc(); |
533 | |
534 | // Reshape output and filter to the LHS and result of a "row-wise" matrix |
535 | // multiplication. |
536 | SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}}; |
537 | auto reshapedFilterType = |
538 | RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType()); |
539 | Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>( |
540 | loc, reshapedFilterType, filter, filterReassocIndices); |
541 | |
542 | SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}}; |
543 | RankedTensorType reshapedOutputType = |
544 | RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()); |
545 | Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>( |
546 | loc, reshapedOutputType, output, outputReassocIndices); |
547 | |
548 | SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic}; |
549 | Value colTensor = rewriter.create<tensor::EmptyOp>( |
550 | loc, colTensorShape, inputType.getElementType()); |
551 | |
552 | // Convert the input to a (BMK) column tensor. |
553 | auto nloops = colTensorShape.size(); |
554 | |
555 | auto parallel = utils::IteratorType::parallel; |
556 | auto reduction = utils::IteratorType::reduction; |
557 | SmallVector<utils::IteratorType> img2colIterators(nloops, parallel); |
558 | |
559 | SmallVector<AffineMap> img2colIndexingMaps = { |
560 | AffineMap::getMultiDimIdentityMap(numDims: nloops, context)}; |
561 | |
562 | auto img2ColTensor = rewriter.create<linalg::GenericOp>( |
563 | loc, colTensor.getType(), |
564 | /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, |
565 | img2colIterators, |
566 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
567 | // Get the iterators named based on the matmul (batch, m, k). |
568 | Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0); |
569 | Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1); |
570 | Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2); |
571 | |
572 | // Recover the original iteration indices from the problem/input sizes. |
573 | SmallVector<Value> mIndices = unrollIndex( |
574 | nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow}); |
575 | auto ohIndex = mIndices[0]; |
576 | auto owIndex = mIndices[1]; |
577 | |
578 | SmallVector<Value> kIndices = unrollIndex( |
579 | nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic}); |
580 | auto fhIndex = kIndices[0]; |
581 | auto fwIndex = kIndices[1]; |
582 | auto icIndex = kIndices[2]; |
583 | |
584 | // Extract the input element corresponding to the expanded indices. |
585 | Value hIndex = |
586 | getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, |
587 | convOp.getStrides().getValues<int64_t>()[0]); |
588 | Value wIndex = |
589 | getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, |
590 | convOp.getStrides().getValues<int64_t>()[1]); |
591 | |
592 | // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] |
593 | SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex}; |
594 | Value inputVal = nestedBuilder.create<tensor::ExtractOp>( |
595 | loc, input, extractionIndices); |
596 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal); |
597 | }); |
598 | |
599 | // Because we didn't transpose the filters we don't actually have a batched |
600 | // matrix multiply. Instead, we have an operation consisting of "row-wise" dot |
601 | // products. |
602 | AffineExpr bDim, mDim, nDim, kDim; |
603 | bindDims(ctx: context, exprs&: bDim, exprs&: mDim, exprs&: nDim, exprs&: kDim); |
604 | auto lhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, kDim}, context); |
605 | auto rhsMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {nDim, kDim}, context); |
606 | auto resultMap = AffineMap::get(dimCount: 4, symbolCount: 0, results: {bDim, mDim, nDim}, context); |
607 | SmallVector<utils::IteratorType> genericIterators = {parallel, parallel, |
608 | parallel, reduction}; |
609 | |
610 | auto genericOp = rewriter.create<linalg::GenericOp>( |
611 | loc, reshapedOutputType, |
612 | /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter}, |
613 | /*outputs=*/ValueRange{reshapedOutput}, |
614 | ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators, |
615 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
616 | Value mul = |
617 | createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); |
618 | Value add = createAdd(loc, mul, args[2], nestedBuilder); |
619 | nestedBuilder.create<linalg::YieldOp>(nestedLoc, add); |
620 | }); |
621 | Value result = genericOp.getResults().front(); |
622 | |
623 | auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>( |
624 | loc, outputType, result, outputReassocIndices); |
625 | |
626 | rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult}); |
627 | |
628 | return std::make_pair(img2ColTensor.getOperation(), |
629 | reshapedResult.getOperation()); |
630 | } |
631 | |
632 | namespace { |
633 | |
634 | class ConvertConv2DNhwcHwcf final |
635 | : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> { |
636 | public: |
637 | using OpRewritePattern::OpRewritePattern; |
638 | |
639 | LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, |
640 | PatternRewriter &rewriter) const override { |
641 | if (failed(rewriteInIm2Col(rewriter, convOp))) |
642 | return failure(); |
643 | return success(); |
644 | } |
645 | }; |
646 | |
647 | class ConvertDepthwiseConv2DNhwcHwc final |
648 | : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> { |
649 | public: |
650 | using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern; |
651 | |
652 | LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp, |
653 | PatternRewriter &rewriter) const override { |
654 | if (failed(rewriteInIm2Col(rewriter, convOp))) |
655 | return failure(); |
656 | return success(); |
657 | } |
658 | }; |
659 | |
660 | class ConvertConv2DNchwFchw final |
661 | : public OpRewritePattern<linalg::Conv2DNchwFchwOp> { |
662 | public: |
663 | using OpRewritePattern::OpRewritePattern; |
664 | |
665 | LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, |
666 | PatternRewriter &rewriter) const override { |
667 | if (failed(rewriteInIm2Col(rewriter, convOp))) |
668 | return failure(); |
669 | return success(); |
670 | } |
671 | }; |
672 | |
673 | class ConvertConv2DNhwcFhwc final |
674 | : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> { |
675 | public: |
676 | using OpRewritePattern::OpRewritePattern; |
677 | |
678 | LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp, |
679 | PatternRewriter &rewriter) const override { |
680 | if (failed(rewriteInIm2Col(rewriter, convOp))) |
681 | return failure(); |
682 | return success(); |
683 | } |
684 | }; |
685 | } // end anonymous namespace |
686 | |
687 | void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) { |
688 | MLIRContext *context = patterns.getContext(); |
689 | patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc, |
690 | ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(arg&: context); |
691 | } |
692 | } // end namespace linalg |
693 | } // end namespace mlir |
694 | |