1 | //===- Linalg.cpp - C Interface for Linalg dialect ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/Dialect/Linalg.h" |
10 | #include "mlir/CAPI/Registration.h" |
11 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
12 | |
13 | using namespace mlir; |
14 | using namespace mlir::linalg; |
15 | |
16 | /// Apply the special region builder for the builtin named Linalg op. |
17 | /// Assert that `op` is a builtin named Linalg op. |
18 | void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { |
19 | Operation *op = unwrap(c: mlirOp); |
20 | auto linalgOp = cast<LinalgOp>(op); |
21 | auto *dialect = static_cast<LinalgDialect *>(linalgOp->getDialect()); |
22 | LinalgDialect::RegionBuilderFunType fun = |
23 | dialect->getRegionBuilder(op->getName().getStringRef()); |
24 | |
25 | assert(fun && "Expected a builtin named Linalg op." ); |
26 | assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region" ); |
27 | assert(op->getRegion(0).getBlocks().empty() && |
28 | "Expected Linalg op with 0 blocks" ); |
29 | |
30 | SmallVector<Type, 8> argTypes; |
31 | SmallVector<Location, 8> argLocs; |
32 | for (OpOperand &opOperand : linalgOp->getOpOperands()) { |
33 | argTypes.push_back(getElementTypeOrSelf(opOperand.get().getType())); |
34 | argLocs.push_back(opOperand.get().getLoc()); |
35 | } |
36 | |
37 | ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); |
38 | Region ®ion = op->getRegion(index: 0); |
39 | Block *body = b.createBlock(parent: ®ion, /*insertPt=*/{}, argTypes, locs: argLocs); |
40 | b.setInsertionPointToStart(body); |
41 | fun(b, *body, op->getAttrs()); |
42 | } |
43 | |
44 | MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) { |
45 | auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(c: op)); |
46 | // isaContractionOpInterface handles null linalgOp internally. |
47 | return linalg::isaContractionOpInterface(linalgOp: linalgOp); |
48 | } |
49 | |
50 | MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions |
51 | mlirLinalgInferContractionDimensions(MlirOperation op) { |
52 | MlirLinalgContractionDimensions result{}; |
53 | auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(c: op)); |
54 | if (!linalgOp) |
55 | return result; |
56 | |
57 | FailureOr<linalg::ContractionDimensions> maybeDims = |
58 | linalg::inferContractionDims(linalgOp); |
59 | if (failed(Result: maybeDims)) |
60 | return result; |
61 | |
62 | linalg::ContractionDimensions contractionDims = *maybeDims; |
63 | MLIRContext *ctx = linalgOp.getContext(); |
64 | |
65 | auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute { |
66 | return wrap( |
67 | DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals))); |
68 | }; |
69 | |
70 | result.batch = toAttr(contractionDims.batch); |
71 | result.m = toAttr(contractionDims.m); |
72 | result.n = toAttr(contractionDims.n); |
73 | result.k = toAttr(contractionDims.k); |
74 | |
75 | return result; |
76 | } |
77 | |
78 | MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) { |
79 | auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(c: op)); |
80 | if (!linalgOp) |
81 | return false; |
82 | |
83 | return linalg::isaConvolutionOpInterface(linalgOp: linalgOp); |
84 | } |
85 | |
86 | MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions |
87 | mlirLinalgInferConvolutionDimensions(MlirOperation op) { |
88 | MlirLinalgConvolutionDimensions result{}; |
89 | auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(c: op)); |
90 | if (!linalgOp) |
91 | return result; |
92 | |
93 | FailureOr<linalg::ConvolutionDimensions> maybeDims = |
94 | linalg::inferConvolutionDims(linalgOp: linalgOp); |
95 | if (failed(Result: maybeDims)) |
96 | return result; |
97 | |
98 | linalg::ConvolutionDimensions dims = *maybeDims; |
99 | MLIRContext *ctx = linalgOp.getContext(); |
100 | |
101 | auto toI32Attr = |
102 | [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute { |
103 | return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals))); |
104 | }; |
105 | |
106 | auto toI64Attr = |
107 | [&ctx](const SmallVector<int64_t, 2> &vals) -> MlirAttribute { |
108 | return wrap(DenseI64ArrayAttr::get(ctx, vals)); |
109 | }; |
110 | |
111 | result.batch = toI32Attr(dims.batch); |
112 | result.outputImage = toI32Attr(dims.outputImage); |
113 | result.outputChannel = toI32Attr(dims.outputChannel); |
114 | result.filterLoop = toI32Attr(dims.filterLoop); |
115 | result.inputChannel = toI32Attr(dims.inputChannel); |
116 | result.depth = toI32Attr(dims.depth); |
117 | result.strides = toI64Attr(dims.strides); |
118 | result.dilations = toI64Attr(dims.dilations); |
119 | |
120 | return result; |
121 | } |
122 | |
123 | MLIR_CAPI_EXPORTED MlirAttribute |
124 | mlirLinalgGetIndexingMapsAttribute(MlirOperation op) { |
125 | auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(c: op)); |
126 | if (!linalgOp) |
127 | return MlirAttribute{.ptr: nullptr}; |
128 | |
129 | ArrayAttr attr = linalgOp.getIndexingMaps(); |
130 | return wrap(attr); |
131 | } |
132 | |
// Defines the C API dialect-registration entry points for the Linalg dialect
// (C++ class LinalgDialect, dialect namespace `linalg`).
// NOTE(review): per mlir/CAPI/Registration.h this presumably expands to
// mlirGetDialectHandle__linalg__(); confirm against that header.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
134 | |