| 1 | //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 12 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 13 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 15 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 18 | #include "mlir/Pass/Pass.h" |
| 19 | |
| 20 | namespace mlir { |
| 21 | #define GEN_PASS_DEF_CONVERTLINALGTOSTANDARDPASS |
| 22 | #include "mlir/Conversion/Passes.h.inc" |
| 23 | } // namespace mlir |
| 24 | |
| 25 | using namespace mlir; |
| 26 | using namespace mlir::linalg; |
| 27 | |
| 28 | static MemRefType makeStridedLayoutDynamic(MemRefType type) { |
| 29 | return MemRefType::Builder(type).setLayout(StridedLayoutAttr::get( |
| 30 | type.getContext(), ShapedType::kDynamic, |
| 31 | SmallVector<int64_t>(type.getRank(), ShapedType::kDynamic))); |
| 32 | } |
| 33 | |
| 34 | /// Helper function to extract the operand types that are passed to the |
| 35 | /// generated CallOp. MemRefTypes have their layout canonicalized since the |
| 36 | /// information is not used in signature generation. |
| 37 | /// Note that static size information is not modified. |
| 38 | static SmallVector<Type, 4> extractOperandTypes(Operation *op) { |
| 39 | SmallVector<Type, 4> result; |
| 40 | result.reserve(N: op->getNumOperands()); |
| 41 | for (auto type : op->getOperandTypes()) { |
| 42 | // The underlying descriptor type (e.g. LLVM) does not have layout |
| 43 | // information. Canonicalizing the type at the level of std when going into |
| 44 | // a library call avoids needing to introduce DialectCastOp. |
| 45 | if (auto memrefType = dyn_cast<MemRefType>(type)) |
| 46 | result.push_back(makeStridedLayoutDynamic(memrefType)); |
| 47 | else |
| 48 | result.push_back(Elt: type); |
| 49 | } |
| 50 | return result; |
| 51 | } |
| 52 | |
| 53 | // Get a SymbolRefAttr containing the library function name for the LinalgOp. |
| 54 | // If the library function does not exist, insert a declaration. |
| 55 | static FailureOr<FlatSymbolRefAttr> |
| 56 | getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) { |
| 57 | auto linalgOp = cast<LinalgOp>(op); |
| 58 | auto fnName = linalgOp.getLibraryCallName(); |
| 59 | if (fnName.empty()) |
| 60 | return rewriter.notifyMatchFailure(arg&: op, msg: "No library call defined for: " ); |
| 61 | |
| 62 | // fnName is a dynamic std::string, unique it via a SymbolRefAttr. |
| 63 | FlatSymbolRefAttr fnNameAttr = |
| 64 | SymbolRefAttr::get(rewriter.getContext(), fnName); |
| 65 | auto module = op->getParentOfType<ModuleOp>(); |
| 66 | if (module.lookupSymbol(fnNameAttr.getAttr())) |
| 67 | return fnNameAttr; |
| 68 | |
| 69 | SmallVector<Type, 4> inputTypes(extractOperandTypes(op)); |
| 70 | if (op->getNumResults() != 0) { |
| 71 | return rewriter.notifyMatchFailure( |
| 72 | arg&: op, |
| 73 | msg: "Library call for linalg operation can be generated only for ops that " |
| 74 | "have void return types" ); |
| 75 | } |
| 76 | auto libFnType = rewriter.getFunctionType(inputTypes, {}); |
| 77 | |
| 78 | OpBuilder::InsertionGuard guard(rewriter); |
| 79 | // Insert before module terminator. |
| 80 | rewriter.setInsertionPoint(module.getBody(), |
| 81 | std::prev(module.getBody()->end())); |
| 82 | func::FuncOp funcOp = rewriter.create<func::FuncOp>( |
| 83 | op->getLoc(), fnNameAttr.getValue(), libFnType); |
| 84 | // Insert a function attribute that will trigger the emission of the |
| 85 | // corresponding `_mlir_ciface_xxx` interface so that external libraries see |
| 86 | // a normalized ABI. This interface is added during std to llvm conversion. |
| 87 | funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), |
| 88 | UnitAttr::get(op->getContext())); |
| 89 | funcOp.setPrivate(); |
| 90 | return fnNameAttr; |
| 91 | } |
| 92 | |
| 93 | static SmallVector<Value, 4> |
| 94 | createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, |
| 95 | ValueRange operands) { |
| 96 | SmallVector<Value, 4> res; |
| 97 | res.reserve(N: operands.size()); |
| 98 | for (auto op : operands) { |
| 99 | auto memrefType = dyn_cast<MemRefType>(op.getType()); |
| 100 | if (!memrefType) { |
| 101 | res.push_back(Elt: op); |
| 102 | continue; |
| 103 | } |
| 104 | Value cast = |
| 105 | b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op); |
| 106 | res.push_back(Elt: cast); |
| 107 | } |
| 108 | return res; |
| 109 | } |
| 110 | |
| 111 | LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( |
| 112 | LinalgOp op, PatternRewriter &rewriter) const { |
| 113 | auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); |
| 114 | if (failed(libraryCallName)) |
| 115 | return failure(); |
| 116 | |
| 117 | // TODO: Add support for more complex library call signatures that include |
| 118 | // indices or captured values. |
| 119 | rewriter.replaceOpWithNewOp<func::CallOp>( |
| 120 | op, libraryCallName->getValue(), TypeRange(), |
| 121 | createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), |
| 122 | op->getOperands())); |
| 123 | return success(); |
| 124 | } |
| 125 | |
| 126 | /// Populate the given list with patterns that convert from Linalg to Standard. |
| 127 | void mlir::linalg::populateLinalgToStandardConversionPatterns( |
| 128 | RewritePatternSet &patterns) { |
| 129 | // TODO: ConvOp conversion needs to export a descriptor with relevant |
| 130 | // attribute values such as kernel striding and dilation. |
| 131 | patterns.add<LinalgOpToLibraryCallRewrite>(arg: patterns.getContext()); |
| 132 | } |
| 133 | |
| 134 | namespace { |
| 135 | struct ConvertLinalgToStandardPass |
| 136 | : public impl::ConvertLinalgToStandardPassBase< |
| 137 | ConvertLinalgToStandardPass> { |
| 138 | void runOnOperation() override; |
| 139 | }; |
| 140 | } // namespace |
| 141 | |
| 142 | void ConvertLinalgToStandardPass::runOnOperation() { |
| 143 | auto module = getOperation(); |
| 144 | ConversionTarget target(getContext()); |
| 145 | target.addLegalDialect<affine::AffineDialect, arith::ArithDialect, |
| 146 | func::FuncDialect, memref::MemRefDialect, |
| 147 | scf::SCFDialect>(); |
| 148 | target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(); |
| 149 | RewritePatternSet patterns(&getContext()); |
| 150 | populateLinalgToStandardConversionPatterns(patterns); |
| 151 | if (failed(applyFullConversion(module, target, std::move(patterns)))) |
| 152 | signalPassFailure(); |
| 153 | } |
| 154 | |