//===- FuncTransformOps.cpp - Implementation of Func transform ops --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
10
11#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Func/Utils/Utils.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/Transform/IR/TransformDialect.h"
17#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21using namespace mlir;
22
23//===----------------------------------------------------------------------===//
24// Apply...ConversionPatternsOp
25//===----------------------------------------------------------------------===//
26
27void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
28 TypeConverter &typeConverter, RewritePatternSet &patterns) {
29 populateFuncToLLVMConversionPatterns(
30 converter: static_cast<LLVMTypeConverter &>(typeConverter), patterns);
31}
32
33LogicalResult
34transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
35 transform::TypeConverterBuilderOpInterface builder) {
36 if (builder.getTypeConverterType() != "LLVMTypeConverter")
37 return emitOpError(message: "expected LLVMTypeConverter");
38 return success();
39}
40
41//===----------------------------------------------------------------------===//
42// CastAndCallOp
43//===----------------------------------------------------------------------===//
44
45DiagnosedSilenceableFailure
46transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
47 transform::TransformResults &results,
48 transform::TransformState &state) {
49 SmallVector<Value> inputs;
50 if (getInputs())
51 llvm::append_range(C&: inputs, R: state.getPayloadValues(handleValue: getInputs()));
52
53 SetVector<Value> outputs;
54 if (getOutputs()) {
55 outputs.insert_range(R: state.getPayloadValues(handleValue: getOutputs()));
56
57 // Verify that the set of output values to be replaced is unique.
58 if (outputs.size() !=
59 llvm::range_size(Range: state.getPayloadValues(handleValue: getOutputs()))) {
60 return emitSilenceableFailure(loc: getLoc())
61 << "cast and call output values must be unique";
62 }
63 }
64
65 // Get the insertion point for the call.
66 auto insertionOps = state.getPayloadOps(value: getInsertionPoint());
67 if (!llvm::hasSingleElement(C&: insertionOps)) {
68 return emitSilenceableFailure(loc: getLoc())
69 << "Only one op can be specified as an insertion point";
70 }
71 bool insertAfter = getInsertAfter();
72 Operation *insertionPoint = *insertionOps.begin();
73
74 // Check that all inputs dominate the insertion point, and the insertion
75 // point dominates all users of the outputs.
76 DominanceInfo dom(insertionPoint);
77 for (Value output : outputs) {
78 for (Operation *user : output.getUsers()) {
79 // If we are inserting after the insertion point operation, the
80 // insertion point operation must properly dominate the user. Otherwise
81 // basic dominance is enough.
82 bool doesDominate = insertAfter
83 ? dom.properlyDominates(a: insertionPoint, b: user)
84 : dom.dominates(a: insertionPoint, b: user);
85 if (!doesDominate) {
86 return emitDefiniteFailure()
87 << "User " << user << " is not dominated by insertion point "
88 << insertionPoint;
89 }
90 }
91 }
92
93 for (Value input : inputs) {
94 // If we are inserting before the insertion point operation, the
95 // input must properly dominate the insertion point operation. Otherwise
96 // basic dominance is enough.
97 bool doesDominate = insertAfter
98 ? dom.dominates(a: input, b: insertionPoint)
99 : dom.properlyDominates(a: input, b: insertionPoint);
100 if (!doesDominate) {
101 return emitDefiniteFailure()
102 << "input " << input << " does not dominate insertion point "
103 << insertionPoint;
104 }
105 }
106
107 // Get the function to call. This can either be specified by symbol or as a
108 // transform handle.
109 func::FuncOp targetFunction = nullptr;
110 if (getFunctionName()) {
111 targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112 from: insertionPoint, symbol: *getFunctionName());
113 if (!targetFunction) {
114 return emitDefiniteFailure()
115 << "unresolved symbol " << *getFunctionName();
116 }
117 } else if (getFunction()) {
118 auto payloadOps = state.getPayloadOps(value: getFunction());
119 if (!llvm::hasSingleElement(C&: payloadOps)) {
120 return emitDefiniteFailure() << "requires a single function to call";
121 }
122 targetFunction = dyn_cast<func::FuncOp>(Val: *payloadOps.begin());
123 if (!targetFunction) {
124 return emitDefiniteFailure() << "invalid non-function callee";
125 }
126 } else {
127 llvm_unreachable("Invalid CastAndCall op without a function to call");
128 return emitDefiniteFailure();
129 }
130
131 // Verify that the function argument and result lengths match the inputs and
132 // outputs given to this op.
133 if (targetFunction.getNumArguments() != inputs.size()) {
134 return emitSilenceableFailure(loc: targetFunction.getLoc())
135 << "mismatch between number of function arguments "
136 << targetFunction.getNumArguments() << " and number of inputs "
137 << inputs.size();
138 }
139 if (targetFunction.getNumResults() != outputs.size()) {
140 return emitSilenceableFailure(loc: targetFunction.getLoc())
141 << "mismatch between number of function results "
142 << targetFunction->getNumResults() << " and number of outputs "
143 << outputs.size();
144 }
145
146 // Gather all specified converters.
147 mlir::TypeConverter converter;
148 if (!getRegion().empty()) {
149 for (Operation &op : getRegion().front()) {
150 cast<transform::TypeConverterBuilderOpInterface>(Val: &op)
151 .populateTypeMaterializations(converter);
152 }
153 }
154
155 if (insertAfter)
156 rewriter.setInsertionPointAfter(insertionPoint);
157 else
158 rewriter.setInsertionPoint(insertionPoint);
159
160 for (auto [input, type] :
161 llvm::zip_equal(t&: inputs, u: targetFunction.getArgumentTypes())) {
162 if (input.getType() != type) {
163 Value newInput = converter.materializeSourceConversion(
164 builder&: rewriter, loc: input.getLoc(), resultType: type, inputs: input);
165 if (!newInput) {
166 return emitDefiniteFailure() << "Failed to materialize conversion of "
167 << input << " to type " << type;
168 }
169 input = newInput;
170 }
171 }
172
173 auto callOp = rewriter.create<func::CallOp>(location: insertionPoint->getLoc(),
174 args&: targetFunction, args&: inputs);
175
176 // Cast the call results back to the expected types. If any conversions fail
177 // this is a definite failure as the call has been constructed at this point.
178 for (auto [output, newOutput] :
179 llvm::zip_equal(t&: outputs, u: callOp.getResults())) {
180 Value convertedOutput = newOutput;
181 if (output.getType() != newOutput.getType()) {
182 convertedOutput = converter.materializeTargetConversion(
183 builder&: rewriter, loc: output.getLoc(), resultType: output.getType(), inputs: newOutput);
184 if (!convertedOutput) {
185 return emitDefiniteFailure()
186 << "Failed to materialize conversion of " << newOutput
187 << " to type " << output.getType();
188 }
189 }
190 rewriter.replaceAllUsesExcept(from: output, to: convertedOutput, exceptedUser: callOp);
191 }
192 results.set(value: cast<OpResult>(Val: getResult()), ops: {callOp});
193 return DiagnosedSilenceableFailure::success();
194}
195
196LogicalResult transform::CastAndCallOp::verify() {
197 if (!getRegion().empty()) {
198 for (Operation &op : getRegion().front()) {
199 if (!isa<transform::TypeConverterBuilderOpInterface>(Val: &op)) {
200 InFlightDiagnostic diag = emitOpError()
201 << "expected children ops to implement "
202 "TypeConverterBuilderOpInterface";
203 diag.attachNote(noteLoc: op.getLoc()) << "op without interface";
204 return diag;
205 }
206 }
207 }
208 if (!getFunction() && !getFunctionName()) {
209 return emitOpError() << "expected a function handle or name to call";
210 }
211 if (getFunction() && getFunctionName()) {
212 return emitOpError() << "function handle and name are mutually exclusive";
213 }
214 return success();
215}
216
217void transform::CastAndCallOp::getEffects(
218 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
219 transform::onlyReadsHandle(handles: getInsertionPointMutable(), effects);
220 if (getInputs())
221 transform::onlyReadsHandle(handles: getInputsMutable(), effects);
222 if (getOutputs())
223 transform::onlyReadsHandle(handles: getOutputsMutable(), effects);
224 if (getFunction())
225 transform::onlyReadsHandle(handles: getFunctionMutable(), effects);
226 transform::producesHandle(handles: getOperation()->getOpResults(), effects);
227 transform::modifiesPayload(effects);
228}
229
230//===----------------------------------------------------------------------===//
231// ReplaceFuncSignatureOp
232//===----------------------------------------------------------------------===//
233
234DiagnosedSilenceableFailure
235transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
236 transform::TransformResults &results,
237 transform::TransformState &state) {
238 auto payloadOps = state.getPayloadOps(value: getModule());
239 if (!llvm::hasSingleElement(C&: payloadOps))
240 return emitDefiniteFailure() << "requires a single module to operate on";
241
242 auto targetModuleOp = dyn_cast<ModuleOp>(Val: *payloadOps.begin());
243 if (!targetModuleOp)
244 return emitSilenceableFailure(loc: getLoc())
245 << "target is expected to be module operation";
246
247 func::FuncOp funcOp =
248 targetModuleOp.lookupSymbol<func::FuncOp>(symbol: getFunctionName());
249 if (!funcOp)
250 return emitSilenceableFailure(loc: getLoc())
251 << "function with name '" << getFunctionName() << "' not found";
252
253 unsigned numArgs = funcOp.getNumArguments();
254 unsigned numResults = funcOp.getNumResults();
255 // Check that the number of arguments and results matches the
256 // interchange sizes.
257 if (numArgs != getArgsInterchange().size())
258 return emitSilenceableFailure(loc: getLoc())
259 << "function with name '" << getFunctionName() << "' has " << numArgs
260 << " arguments, but " << getArgsInterchange().size()
261 << " args interchange were given";
262
263 if (numResults != getResultsInterchange().size())
264 return emitSilenceableFailure(loc: getLoc())
265 << "function with name '" << getFunctionName() << "' has "
266 << numResults << " results, but " << getResultsInterchange().size()
267 << " results interchange were given";
268
269 // Check that the args and results interchanges are unique.
270 SetVector<unsigned> argsInterchange, resultsInterchange;
271 argsInterchange.insert_range(R: getArgsInterchange());
272 resultsInterchange.insert_range(R: getResultsInterchange());
273 if (argsInterchange.size() != getArgsInterchange().size())
274 return emitSilenceableFailure(loc: getLoc())
275 << "args interchange must be unique";
276
277 if (resultsInterchange.size() != getResultsInterchange().size())
278 return emitSilenceableFailure(loc: getLoc())
279 << "results interchange must be unique";
280
281 // Check that the args and results interchange indices are in bounds.
282 for (unsigned index : argsInterchange) {
283 if (index >= numArgs) {
284 return emitSilenceableFailure(loc: getLoc())
285 << "args interchange index " << index
286 << " is out of bounds for function with name '"
287 << getFunctionName() << "' with " << numArgs << " arguments";
288 }
289 }
290 for (unsigned index : resultsInterchange) {
291 if (index >= numResults) {
292 return emitSilenceableFailure(loc: getLoc())
293 << "results interchange index " << index
294 << " is out of bounds for function with name '"
295 << getFunctionName() << "' with " << numResults << " results";
296 }
297 }
298
299 FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
300 rewriter, funcOp, newArgsOrder: argsInterchange.getArrayRef(),
301 newResultsOrder: resultsInterchange.getArrayRef());
302 if (failed(Result: newFuncOpOrFailure))
303 return emitSilenceableFailure(loc: getLoc())
304 << "failed to replace function signature '" << getFunctionName()
305 << "' with new order";
306
307 if (getAdjustFuncCalls()) {
308 SmallVector<func::CallOp> callOps;
309 targetModuleOp.walk(callback: [&](func::CallOp callOp) {
310 if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
311 callOps.push_back(Elt: callOp);
312 });
313
314 for (func::CallOp callOp : callOps)
315 func::replaceCallOpWithNewOrder(rewriter, callOp,
316 newArgsOrder: argsInterchange.getArrayRef(),
317 newResultsOrder: resultsInterchange.getArrayRef());
318 }
319
320 results.set(value: cast<OpResult>(Val: getTransformedModule()), ops: {targetModuleOp});
321 results.set(value: cast<OpResult>(Val: getTransformedFunction()), ops: {*newFuncOpOrFailure});
322
323 return DiagnosedSilenceableFailure::success();
324}
325
326void transform::ReplaceFuncSignatureOp::getEffects(
327 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
328 transform::consumesHandle(handles: getModuleMutable(), effects);
329 transform::producesHandle(handles: getOperation()->getOpResults(), effects);
330 transform::modifiesPayload(effects);
331}
332
333//===----------------------------------------------------------------------===//
334// Transform op registration
335//===----------------------------------------------------------------------===//
336
namespace {
/// Transform dialect extension that registers the Func transform ops whose
/// list is generated into FuncTransformOps.cpp.inc (GET_OP_LIST).
class FuncTransformDialectExtension
    : public transform::TransformDialectExtension<
          FuncTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)

  using Base::Base;

  void init() {
    // The LLVM dialect is declared as generated by this extension.
    // Presumably this is because ApplyFuncToLLVMConversionPatternsOp
    // populates Func-to-LLVM conversion patterns that create LLVM dialect ops.
    declareGeneratedDialect<LLVM::LLVMDialect>();

    // Register every op from the generated op list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
        >();
  }
};
} // namespace

// Generated op class definitions for the Func transform ops.
#define GET_OP_CLASSES
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"

/// Adds the Func transform dialect extension to `registry`.
void mlir::func::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<FuncTransformDialectExtension>();
}
363

source code of mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp