1//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18
19namespace mlir {
20#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARDPASS
21#include "mlir/Conversion/Passes.h.inc"
22} // namespace mlir
23
24using namespace mlir;
25using namespace mlir::linalg;
26
27static MemRefType makeStridedLayoutDynamic(MemRefType type) {
28 return MemRefType::Builder(type).setLayout(StridedLayoutAttr::get(
29 context: type.getContext(), offset: ShapedType::kDynamic,
30 strides: SmallVector<int64_t>(type.getRank(), ShapedType::kDynamic)));
31}
32
33/// Helper function to extract the operand types that are passed to the
34/// generated CallOp. MemRefTypes have their layout canonicalized since the
35/// information is not used in signature generation.
36/// Note that static size information is not modified.
37static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
38 SmallVector<Type, 4> result;
39 result.reserve(N: op->getNumOperands());
40 for (auto type : op->getOperandTypes()) {
41 // The underlying descriptor type (e.g. LLVM) does not have layout
42 // information. Canonicalizing the type at the level of std when going into
43 // a library call avoids needing to introduce DialectCastOp.
44 if (auto memrefType = dyn_cast<MemRefType>(Val&: type))
45 result.push_back(Elt: makeStridedLayoutDynamic(type: memrefType));
46 else
47 result.push_back(Elt: type);
48 }
49 return result;
50}
51
52// Get a SymbolRefAttr containing the library function name for the LinalgOp.
53// If the library function does not exist, insert a declaration.
54static FailureOr<FlatSymbolRefAttr>
55getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) {
56 auto linalgOp = cast<LinalgOp>(Val: op);
57 auto fnName = linalgOp.getLibraryCallName();
58 if (fnName.empty())
59 return rewriter.notifyMatchFailure(arg&: op, msg: "No library call defined for: ");
60
61 // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
62 FlatSymbolRefAttr fnNameAttr =
63 SymbolRefAttr::get(ctx: rewriter.getContext(), value: fnName);
64 auto module = op->getParentOfType<ModuleOp>();
65 if (module.lookupSymbol(name: fnNameAttr.getAttr()))
66 return fnNameAttr;
67
68 SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
69 if (op->getNumResults() != 0) {
70 return rewriter.notifyMatchFailure(
71 arg&: op,
72 msg: "Library call for linalg operation can be generated only for ops that "
73 "have void return types");
74 }
75 auto libFnType = rewriter.getFunctionType(inputs: inputTypes, results: {});
76
77 OpBuilder::InsertionGuard guard(rewriter);
78 // Insert before module terminator.
79 rewriter.setInsertionPoint(block: module.getBody(),
80 insertPoint: std::prev(x: module.getBody()->end()));
81 func::FuncOp funcOp = rewriter.create<func::FuncOp>(
82 location: op->getLoc(), args: fnNameAttr.getValue(), args&: libFnType);
83 // Insert a function attribute that will trigger the emission of the
84 // corresponding `_mlir_ciface_xxx` interface so that external libraries see
85 // a normalized ABI. This interface is added during std to llvm conversion.
86 funcOp->setAttr(name: LLVM::LLVMDialect::getEmitCWrapperAttrName(),
87 value: UnitAttr::get(context: op->getContext()));
88 funcOp.setPrivate();
89 return fnNameAttr;
90}
91
92static SmallVector<Value, 4>
93createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
94 ValueRange operands) {
95 SmallVector<Value, 4> res;
96 res.reserve(N: operands.size());
97 for (auto op : operands) {
98 auto memrefType = dyn_cast<MemRefType>(Val: op.getType());
99 if (!memrefType) {
100 res.push_back(Elt: op);
101 continue;
102 }
103 Value cast =
104 b.create<memref::CastOp>(location: loc, args: makeStridedLayoutDynamic(type: memrefType), args&: op);
105 res.push_back(Elt: cast);
106 }
107 return res;
108}
109
110LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
111 LinalgOp op, PatternRewriter &rewriter) const {
112 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
113 if (failed(Result: libraryCallName))
114 return failure();
115
116 // TODO: Add support for more complex library call signatures that include
117 // indices or captured values.
118 rewriter.replaceOpWithNewOp<func::CallOp>(
119 op, args: libraryCallName->getValue(), args: TypeRange(),
120 args: createTypeCanonicalizedMemRefOperands(b&: rewriter, loc: op->getLoc(),
121 operands: op->getOperands()));
122 return success();
123}
124
125/// Populate the given list with patterns that convert from Linalg to Standard.
126void mlir::linalg::populateLinalgToStandardConversionPatterns(
127 RewritePatternSet &patterns) {
128 // TODO: ConvOp conversion needs to export a descriptor with relevant
129 // attribute values such as kernel striding and dilation.
130 patterns.add<LinalgOpToLibraryCallRewrite>(arg: patterns.getContext());
131}
132
133namespace {
134struct ConvertLinalgToStandardPass
135 : public impl::ConvertLinalgToStandardPassBase<
136 ConvertLinalgToStandardPass> {
137 void runOnOperation() override;
138};
139} // namespace
140
141void ConvertLinalgToStandardPass::runOnOperation() {
142 auto module = getOperation();
143 ConversionTarget target(getContext());
144 target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
145 func::FuncDialect, memref::MemRefDialect,
146 scf::SCFDialect>();
147 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
148 RewritePatternSet patterns(&getContext());
149 populateLinalgToStandardConversionPatterns(patterns);
150 if (failed(Result: applyFullConversion(op: module, target, patterns: std::move(patterns))))
151 signalPassFailure();
152}
153

source code of mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp