1//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Pass/Pass.h"
19
20namespace mlir {
21#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::linalg;
27
28static MemRefType makeStridedLayoutDynamic(MemRefType type) {
29 return MemRefType::Builder(type).setLayout(StridedLayoutAttr::get(
30 type.getContext(), ShapedType::kDynamic,
31 SmallVector<int64_t>(type.getRank(), ShapedType::kDynamic)));
32}
33
34/// Helper function to extract the operand types that are passed to the
35/// generated CallOp. MemRefTypes have their layout canonicalized since the
36/// information is not used in signature generation.
37/// Note that static size information is not modified.
38static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
39 SmallVector<Type, 4> result;
40 result.reserve(N: op->getNumOperands());
41 for (auto type : op->getOperandTypes()) {
42 // The underlying descriptor type (e.g. LLVM) does not have layout
43 // information. Canonicalizing the type at the level of std when going into
44 // a library call avoids needing to introduce DialectCastOp.
45 if (auto memrefType = dyn_cast<MemRefType>(type))
46 result.push_back(makeStridedLayoutDynamic(memrefType));
47 else
48 result.push_back(Elt: type);
49 }
50 return result;
51}
52
53// Get a SymbolRefAttr containing the library function name for the LinalgOp.
54// If the library function does not exist, insert a declaration.
55static FailureOr<FlatSymbolRefAttr>
56getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) {
57 auto linalgOp = cast<LinalgOp>(op);
58 auto fnName = linalgOp.getLibraryCallName();
59 if (fnName.empty())
60 return rewriter.notifyMatchFailure(arg&: op, msg: "No library call defined for: ");
61
62 // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
63 FlatSymbolRefAttr fnNameAttr =
64 SymbolRefAttr::get(rewriter.getContext(), fnName);
65 auto module = op->getParentOfType<ModuleOp>();
66 if (module.lookupSymbol(fnNameAttr.getAttr()))
67 return fnNameAttr;
68
69 SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
70 if (op->getNumResults() != 0) {
71 return rewriter.notifyMatchFailure(
72 arg&: op,
73 msg: "Library call for linalg operation can be generated only for ops that "
74 "have void return types");
75 }
76 auto libFnType = rewriter.getFunctionType(inputTypes, {});
77
78 OpBuilder::InsertionGuard guard(rewriter);
79 // Insert before module terminator.
80 rewriter.setInsertionPoint(module.getBody(),
81 std::prev(module.getBody()->end()));
82 func::FuncOp funcOp = rewriter.create<func::FuncOp>(
83 op->getLoc(), fnNameAttr.getValue(), libFnType);
84 // Insert a function attribute that will trigger the emission of the
85 // corresponding `_mlir_ciface_xxx` interface so that external libraries see
86 // a normalized ABI. This interface is added during std to llvm conversion.
87 funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
88 UnitAttr::get(op->getContext()));
89 funcOp.setPrivate();
90 return fnNameAttr;
91}
92
93static SmallVector<Value, 4>
94createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
95 ValueRange operands) {
96 SmallVector<Value, 4> res;
97 res.reserve(N: operands.size());
98 for (auto op : operands) {
99 auto memrefType = dyn_cast<MemRefType>(op.getType());
100 if (!memrefType) {
101 res.push_back(Elt: op);
102 continue;
103 }
104 Value cast =
105 b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op);
106 res.push_back(Elt: cast);
107 }
108 return res;
109}
110
111LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
112 LinalgOp op, PatternRewriter &rewriter) const {
113 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
114 if (failed(libraryCallName))
115 return failure();
116
117 // TODO: Add support for more complex library call signatures that include
118 // indices or captured values.
119 rewriter.replaceOpWithNewOp<func::CallOp>(
120 op, libraryCallName->getValue(), TypeRange(),
121 createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
122 op->getOperands()));
123 return success();
124}
125
126/// Populate the given list with patterns that convert from Linalg to Standard.
127void mlir::linalg::populateLinalgToStandardConversionPatterns(
128 RewritePatternSet &patterns) {
129 // TODO: ConvOp conversion needs to export a descriptor with relevant
130 // attribute values such as kernel striding and dilation.
131 patterns.add<LinalgOpToLibraryCallRewrite>(arg: patterns.getContext());
132}
133
134namespace {
135struct ConvertLinalgToStandardPass
136 : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
137 void runOnOperation() override;
138};
139} // namespace
140
141void ConvertLinalgToStandardPass::runOnOperation() {
142 auto module = getOperation();
143 ConversionTarget target(getContext());
144 target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145 func::FuncDialect, memref::MemRefDialect,
146 scf::SCFDialect>();
147 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
148 RewritePatternSet patterns(&getContext());
149 populateLinalgToStandardConversionPatterns(patterns);
150 if (failed(applyFullConversion(module, target, std::move(patterns))))
151 signalPassFailure();
152}
153
154std::unique_ptr<OperationPass<ModuleOp>>
155mlir::createConvertLinalgToStandardPass() {
156 return std::make_unique<ConvertLinalgToStandardPass>();
157}
158

source code of mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp