//===- FuncTransformOps.cpp - Implementation of Func transform ops ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
10
11#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/IR/TransformOps.h"
17#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18#include "mlir/Transforms/DialectConversion.h"
19
20using namespace mlir;
21
22//===----------------------------------------------------------------------===//
23// Apply...ConversionPatternsOp
24//===----------------------------------------------------------------------===//
25
26void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
27 TypeConverter &typeConverter, RewritePatternSet &patterns) {
28 populateFuncToLLVMConversionPatterns(
29 static_cast<LLVMTypeConverter &>(typeConverter), patterns);
30}
31
32LogicalResult
33transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
34 transform::TypeConverterBuilderOpInterface builder) {
35 if (builder.getTypeConverterType() != "LLVMTypeConverter")
36 return emitOpError("expected LLVMTypeConverter");
37 return success();
38}
39
40//===----------------------------------------------------------------------===//
41// CastAndCallOp
42//===----------------------------------------------------------------------===//
43
44DiagnosedSilenceableFailure
45transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
46 transform::TransformResults &results,
47 transform::TransformState &state) {
48 SmallVector<Value> inputs;
49 if (getInputs())
50 llvm::append_range(inputs, state.getPayloadValues(getInputs()));
51
52 SetVector<Value> outputs;
53 if (getOutputs()) {
54 for (auto output : state.getPayloadValues(getOutputs()))
55 outputs.insert(output);
56
57 // Verify that the set of output values to be replaced is unique.
58 if (outputs.size() !=
59 llvm::range_size(state.getPayloadValues(getOutputs()))) {
60 return emitSilenceableFailure(getLoc())
61 << "cast and call output values must be unique";
62 }
63 }
64
65 // Get the insertion point for the call.
66 auto insertionOps = state.getPayloadOps(getInsertionPoint());
67 if (!llvm::hasSingleElement(insertionOps)) {
68 return emitSilenceableFailure(getLoc())
69 << "Only one op can be specified as an insertion point";
70 }
71 bool insertAfter = getInsertAfter();
72 Operation *insertionPoint = *insertionOps.begin();
73
74 // Check that all inputs dominate the insertion point, and the insertion
75 // point dominates all users of the outputs.
76 DominanceInfo dom(insertionPoint);
77 for (Value output : outputs) {
78 for (Operation *user : output.getUsers()) {
79 // If we are inserting after the insertion point operation, the
80 // insertion point operation must properly dominate the user. Otherwise
81 // basic dominance is enough.
82 bool doesDominate = insertAfter
83 ? dom.properlyDominates(insertionPoint, user)
84 : dom.dominates(insertionPoint, user);
85 if (!doesDominate) {
86 return emitDefiniteFailure()
87 << "User " << user << " is not dominated by insertion point "
88 << insertionPoint;
89 }
90 }
91 }
92
93 for (Value input : inputs) {
94 // If we are inserting before the insertion point operation, the
95 // input must properly dominate the insertion point operation. Otherwise
96 // basic dominance is enough.
97 bool doesDominate = insertAfter
98 ? dom.dominates(input, insertionPoint)
99 : dom.properlyDominates(input, insertionPoint);
100 if (!doesDominate) {
101 return emitDefiniteFailure()
102 << "input " << input << " does not dominate insertion point "
103 << insertionPoint;
104 }
105 }
106
107 // Get the function to call. This can either be specified by symbol or as a
108 // transform handle.
109 func::FuncOp targetFunction = nullptr;
110 if (getFunctionName()) {
111 targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112 insertionPoint, *getFunctionName());
113 if (!targetFunction) {
114 return emitDefiniteFailure()
115 << "unresolved symbol " << *getFunctionName();
116 }
117 } else if (getFunction()) {
118 auto payloadOps = state.getPayloadOps(getFunction());
119 if (!llvm::hasSingleElement(payloadOps)) {
120 return emitDefiniteFailure() << "requires a single function to call";
121 }
122 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
123 if (!targetFunction) {
124 return emitDefiniteFailure() << "invalid non-function callee";
125 }
126 } else {
127 llvm_unreachable("Invalid CastAndCall op without a function to call");
128 return emitDefiniteFailure();
129 }
130
131 // Verify that the function argument and result lengths match the inputs and
132 // outputs given to this op.
133 if (targetFunction.getNumArguments() != inputs.size()) {
134 return emitSilenceableFailure(targetFunction.getLoc())
135 << "mismatch between number of function arguments "
136 << targetFunction.getNumArguments() << " and number of inputs "
137 << inputs.size();
138 }
139 if (targetFunction.getNumResults() != outputs.size()) {
140 return emitSilenceableFailure(targetFunction.getLoc())
141 << "mismatch between number of function results "
142 << targetFunction->getNumResults() << " and number of outputs "
143 << outputs.size();
144 }
145
146 // Gather all specified converters.
147 mlir::TypeConverter converter;
148 if (!getRegion().empty()) {
149 for (Operation &op : getRegion().front()) {
150 cast<transform::TypeConverterBuilderOpInterface>(&op)
151 .populateTypeMaterializations(converter);
152 }
153 }
154
155 if (insertAfter)
156 rewriter.setInsertionPointAfter(insertionPoint);
157 else
158 rewriter.setInsertionPoint(insertionPoint);
159
160 for (auto [input, type] :
161 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
162 if (input.getType() != type) {
163 Value newInput = converter.materializeSourceConversion(
164 rewriter, input.getLoc(), type, input);
165 if (!newInput) {
166 return emitDefiniteFailure() << "Failed to materialize conversion of "
167 << input << " to type " << type;
168 }
169 input = newInput;
170 }
171 }
172
173 auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
174 targetFunction, inputs);
175
176 // Cast the call results back to the expected types. If any conversions fail
177 // this is a definite failure as the call has been constructed at this point.
178 for (auto [output, newOutput] :
179 llvm::zip_equal(outputs, callOp.getResults())) {
180 Value convertedOutput = newOutput;
181 if (output.getType() != newOutput.getType()) {
182 convertedOutput = converter.materializeTargetConversion(
183 rewriter, output.getLoc(), output.getType(), newOutput);
184 if (!convertedOutput) {
185 return emitDefiniteFailure()
186 << "Failed to materialize conversion of " << newOutput
187 << " to type " << output.getType();
188 }
189 }
190 rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
191 }
192 results.set(cast<OpResult>(getResult()), {callOp});
193 return DiagnosedSilenceableFailure::success();
194}
195
196LogicalResult transform::CastAndCallOp::verify() {
197 if (!getRegion().empty()) {
198 for (Operation &op : getRegion().front()) {
199 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
200 InFlightDiagnostic diag = emitOpError()
201 << "expected children ops to implement "
202 "TypeConverterBuilderOpInterface";
203 diag.attachNote(op.getLoc()) << "op without interface";
204 return diag;
205 }
206 }
207 }
208 if (!getFunction() && !getFunctionName()) {
209 return emitOpError() << "expected a function handle or name to call";
210 }
211 if (getFunction() && getFunctionName()) {
212 return emitOpError() << "function handle and name are mutually exclusive";
213 }
214 return success();
215}
216
217void transform::CastAndCallOp::getEffects(
218 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
219 transform::onlyReadsHandle(getInsertionPoint(), effects);
220 if (getInputs())
221 transform::onlyReadsHandle(getInputs(), effects);
222 if (getOutputs())
223 transform::onlyReadsHandle(getOutputs(), effects);
224 if (getFunction())
225 transform::onlyReadsHandle(getFunction(), effects);
226 transform::producesHandle(getResult(), effects);
227 transform::modifiesPayload(effects);
228}
229
230//===----------------------------------------------------------------------===//
231// Transform op registration
232//===----------------------------------------------------------------------===//
233
234namespace {
235class FuncTransformDialectExtension
236 : public transform::TransformDialectExtension<
237 FuncTransformDialectExtension> {
238public:
239 using Base::Base;
240
241 void init() {
242 declareGeneratedDialect<LLVM::LLVMDialect>();
243
244 registerTransformOps<
245#define GET_OP_LIST
246#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
247 >();
248 }
249};
250} // namespace
251
252#define GET_OP_CLASSES
253#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
254
255void mlir::func::registerTransformDialectExtension(DialectRegistry &registry) {
256 registry.addExtensions<FuncTransformDialectExtension>();
257}
258

// Source: mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp