1//===- NamedOpConversions.cpp - Implements conversions between named ops --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements conversions between named ops that can be seen as
// canonicalizations of named ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Passes.h"
15
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/ADT/TypeSwitch.h"
22
// Pull in the tablegen-generated definition for the
// LinalgNamedOpConversionPass; presumably this is what provides
// impl::LinalgNamedOpConversionPassBase, which the pass struct below derives
// from. NOTE(review): confirm against Passes.td.
namespace mlir {
#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
27
28using namespace mlir;
29using namespace mlir::linalg;
30
31static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
32 return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
33}
34
35static LogicalResult
36matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
37 Value iZp, Value kZp, Value init, Attribute stride,
38 Attribute dilation, PatternRewriter &rewriter) {
39 Location loc = operation->getLoc();
40 auto linalgOp = dyn_cast<LinalgOp>(operation);
41 // Exit out on the memref version of this operation.
42 if (!linalgOp || !linalgOp.hasPureTensorSemantics())
43 return failure();
44
45 auto result = operation->getResult(idx: 0);
46
47 auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
48 auto initTy = dyn_cast<RankedTensorType>(init.getType());
49 auto resultTy = dyn_cast<RankedTensorType>(result.getType());
50 if (!kernelTy || !initTy || !resultTy)
51 return failure();
52
53 if (kernelTy.getDimSize(3) != 1)
54 return failure();
55
56 // Collapse kernel dims.
57 SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
58 getIndicesVector(start: 0, end: 1), getIndicesVector(start: 1, end: 2), getIndicesVector(start: 2, end: 4)};
59 auto newKernelTy = RankedTensorType::get(
60 {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
61 kernelTy.getElementType());
62 auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
63 loc, newKernelTy, kernel, collapsedKernelDims);
64
65 // Collapse init dims.
66 SmallVector<ReassociationIndices, 4> collapsedInitDims = {
67 getIndicesVector(start: 0, end: 1), getIndicesVector(start: 1, end: 2), getIndicesVector(start: 2, end: 3),
68 getIndicesVector(start: 3, end: 5)};
69 auto newInitTy =
70 RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
71 initTy.getDimSize(2), initTy.getDimSize(3)},
72 initTy.getElementType());
73 auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
74 loc, newInitTy, init, collapsedInitDims);
75
76 SmallVector<NamedAttribute> preservedAttrs;
77 Operation *newConv =
78 TypeSwitch<Operation *, Operation *>(operation)
79 .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
80 preservedAttrs = getPrunedAttributeList(op);
81 return rewriter.create<DepthwiseConv2DNhwcHwcOp>(
82 loc, newInitTy, ValueRange{input, collapsedKernel},
83 ValueRange{collapsedInit}, stride, dilation);
84 })
85 .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
86 preservedAttrs = getPrunedAttributeList(op);
87 return rewriter.create<DepthwiseConv2DNhwcHwcQOp>(
88 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
89 ValueRange{collapsedInit}, stride, dilation);
90 })
91 .Default([](Operation *op) { return nullptr; });
92 if (!newConv)
93 return failure();
94 for (auto attr : preservedAttrs)
95 newConv->setAttr(attr.getName(), attr.getValue());
96
97 // Expand dimensions back out to
98 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
99 operation, resultTy, newConv->getResult(0), collapsedInitDims);
100 return success();
101}
102
103namespace {
104struct SimplifyDepthwiseConvOp
105 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
106 using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
107
108 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
109 PatternRewriter &rewriter) const override {
110 Operation *operation = op.getOperation();
111 Value input = op.getDpsInputOperand(0)->get();
112 Value kernel = op.getDpsInputOperand(1)->get();
113 Value init = op.getDpsInitOperand(0)->get();
114
115 auto stride = op.getStrides();
116 auto dilation = op.getDilations();
117
118 return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
119 nullptr, init, stride, dilation,
120 rewriter);
121 }
122};
123
124struct SimplifyDepthwiseConvQOp
125 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
126 using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
127
128 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
129 PatternRewriter &rewriter) const override {
130 Operation *operation = op.getOperation();
131 Value input = op.getDpsInputOperand(0)->get();
132 Value kernel = op.getDpsInputOperand(1)->get();
133 Value iZp = op.getDpsInputOperand(2)->get();
134 Value kZp = op.getDpsInputOperand(3)->get();
135 Value init = op.getDpsInitOperand(0)->get();
136
137 auto stride = op.getStrides();
138 auto dilation = op.getDilations();
139
140 return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
141 init, stride, dilation, rewriter);
142 }
143};
144
145struct LinalgNamedOpConversionPass
146 : public impl::LinalgNamedOpConversionPassBase<
147 LinalgNamedOpConversionPass> {
148 using impl::LinalgNamedOpConversionPassBase<
149 LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
150
151 void runOnOperation() override {
152 Operation *op = getOperation();
153 RewritePatternSet patterns(op->getContext());
154 populateLinalgNamedOpConversionPatterns(patterns);
155 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
156 return signalPassFailure();
157 }
158};
159} // namespace
160
161void mlir::linalg::populateLinalgNamedOpConversionPatterns(
162 RewritePatternSet &patterns) {
163 patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
164 arg: patterns.getContext());
165}
166

source code of mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp