//===- FuncTransformOps.cpp - Implementation of Func transform ops ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"

#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallPtrSet.h"
19
20using namespace mlir;
21
22//===----------------------------------------------------------------------===//
23// Apply...ConversionPatternsOp
24//===----------------------------------------------------------------------===//
25
26void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
27 TypeConverter &typeConverter, RewritePatternSet &patterns) {
28 populateFuncToLLVMConversionPatterns(
29 static_cast<LLVMTypeConverter &>(typeConverter), patterns);
30}
31
32LogicalResult
33transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
34 transform::TypeConverterBuilderOpInterface builder) {
35 if (builder.getTypeConverterType() != "LLVMTypeConverter")
36 return emitOpError("expected LLVMTypeConverter");
37 return success();
38}
39
40//===----------------------------------------------------------------------===//
41// CastAndCallOp
42//===----------------------------------------------------------------------===//
43
44DiagnosedSilenceableFailure
45transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
46 transform::TransformResults &results,
47 transform::TransformState &state) {
48 SmallVector<Value> inputs;
49 if (getInputs())
50 llvm::append_range(inputs, state.getPayloadValues(getInputs()));
51
52 SetVector<Value> outputs;
53 if (getOutputs()) {
54 outputs.insert_range(state.getPayloadValues(getOutputs()));
55
56 // Verify that the set of output values to be replaced is unique.
57 if (outputs.size() !=
58 llvm::range_size(state.getPayloadValues(getOutputs()))) {
59 return emitSilenceableFailure(getLoc())
60 << "cast and call output values must be unique";
61 }
62 }
63
64 // Get the insertion point for the call.
65 auto insertionOps = state.getPayloadOps(getInsertionPoint());
66 if (!llvm::hasSingleElement(insertionOps)) {
67 return emitSilenceableFailure(getLoc())
68 << "Only one op can be specified as an insertion point";
69 }
70 bool insertAfter = getInsertAfter();
71 Operation *insertionPoint = *insertionOps.begin();
72
73 // Check that all inputs dominate the insertion point, and the insertion
74 // point dominates all users of the outputs.
75 DominanceInfo dom(insertionPoint);
76 for (Value output : outputs) {
77 for (Operation *user : output.getUsers()) {
78 // If we are inserting after the insertion point operation, the
79 // insertion point operation must properly dominate the user. Otherwise
80 // basic dominance is enough.
81 bool doesDominate = insertAfter
82 ? dom.properlyDominates(insertionPoint, user)
83 : dom.dominates(insertionPoint, user);
84 if (!doesDominate) {
85 return emitDefiniteFailure()
86 << "User " << user << " is not dominated by insertion point "
87 << insertionPoint;
88 }
89 }
90 }
91
92 for (Value input : inputs) {
93 // If we are inserting before the insertion point operation, the
94 // input must properly dominate the insertion point operation. Otherwise
95 // basic dominance is enough.
96 bool doesDominate = insertAfter
97 ? dom.dominates(input, insertionPoint)
98 : dom.properlyDominates(input, insertionPoint);
99 if (!doesDominate) {
100 return emitDefiniteFailure()
101 << "input " << input << " does not dominate insertion point "
102 << insertionPoint;
103 }
104 }
105
106 // Get the function to call. This can either be specified by symbol or as a
107 // transform handle.
108 func::FuncOp targetFunction = nullptr;
109 if (getFunctionName()) {
110 targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
111 insertionPoint, *getFunctionName());
112 if (!targetFunction) {
113 return emitDefiniteFailure()
114 << "unresolved symbol " << *getFunctionName();
115 }
116 } else if (getFunction()) {
117 auto payloadOps = state.getPayloadOps(getFunction());
118 if (!llvm::hasSingleElement(payloadOps)) {
119 return emitDefiniteFailure() << "requires a single function to call";
120 }
121 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
122 if (!targetFunction) {
123 return emitDefiniteFailure() << "invalid non-function callee";
124 }
125 } else {
126 llvm_unreachable("Invalid CastAndCall op without a function to call");
127 return emitDefiniteFailure();
128 }
129
130 // Verify that the function argument and result lengths match the inputs and
131 // outputs given to this op.
132 if (targetFunction.getNumArguments() != inputs.size()) {
133 return emitSilenceableFailure(targetFunction.getLoc())
134 << "mismatch between number of function arguments "
135 << targetFunction.getNumArguments() << " and number of inputs "
136 << inputs.size();
137 }
138 if (targetFunction.getNumResults() != outputs.size()) {
139 return emitSilenceableFailure(targetFunction.getLoc())
140 << "mismatch between number of function results "
141 << targetFunction->getNumResults() << " and number of outputs "
142 << outputs.size();
143 }
144
145 // Gather all specified converters.
146 mlir::TypeConverter converter;
147 if (!getRegion().empty()) {
148 for (Operation &op : getRegion().front()) {
149 cast<transform::TypeConverterBuilderOpInterface>(&op)
150 .populateTypeMaterializations(converter);
151 }
152 }
153
154 if (insertAfter)
155 rewriter.setInsertionPointAfter(insertionPoint);
156 else
157 rewriter.setInsertionPoint(insertionPoint);
158
159 for (auto [input, type] :
160 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
161 if (input.getType() != type) {
162 Value newInput = converter.materializeSourceConversion(
163 rewriter, input.getLoc(), type, input);
164 if (!newInput) {
165 return emitDefiniteFailure() << "Failed to materialize conversion of "
166 << input << " to type " << type;
167 }
168 input = newInput;
169 }
170 }
171
172 auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
173 targetFunction, inputs);
174
175 // Cast the call results back to the expected types. If any conversions fail
176 // this is a definite failure as the call has been constructed at this point.
177 for (auto [output, newOutput] :
178 llvm::zip_equal(outputs, callOp.getResults())) {
179 Value convertedOutput = newOutput;
180 if (output.getType() != newOutput.getType()) {
181 convertedOutput = converter.materializeTargetConversion(
182 rewriter, output.getLoc(), output.getType(), newOutput);
183 if (!convertedOutput) {
184 return emitDefiniteFailure()
185 << "Failed to materialize conversion of " << newOutput
186 << " to type " << output.getType();
187 }
188 }
189 rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
190 }
191 results.set(cast<OpResult>(getResult()), {callOp});
192 return DiagnosedSilenceableFailure::success();
193}
194
195LogicalResult transform::CastAndCallOp::verify() {
196 if (!getRegion().empty()) {
197 for (Operation &op : getRegion().front()) {
198 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
199 InFlightDiagnostic diag = emitOpError()
200 << "expected children ops to implement "
201 "TypeConverterBuilderOpInterface";
202 diag.attachNote(op.getLoc()) << "op without interface";
203 return diag;
204 }
205 }
206 }
207 if (!getFunction() && !getFunctionName()) {
208 return emitOpError() << "expected a function handle or name to call";
209 }
210 if (getFunction() && getFunctionName()) {
211 return emitOpError() << "function handle and name are mutually exclusive";
212 }
213 return success();
214}
215
216void transform::CastAndCallOp::getEffects(
217 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
218 transform::onlyReadsHandle(getInsertionPointMutable(), effects);
219 if (getInputs())
220 transform::onlyReadsHandle(getInputsMutable(), effects);
221 if (getOutputs())
222 transform::onlyReadsHandle(getOutputsMutable(), effects);
223 if (getFunction())
224 transform::onlyReadsHandle(getFunctionMutable(), effects);
225 transform::producesHandle(getOperation()->getOpResults(), effects);
226 transform::modifiesPayload(effects);
227}
228
229//===----------------------------------------------------------------------===//
230// Transform op registration
231//===----------------------------------------------------------------------===//
232
233namespace {
234class FuncTransformDialectExtension
235 : public transform::TransformDialectExtension<
236 FuncTransformDialectExtension> {
237public:
238 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
239
240 using Base::Base;
241
242 void init() {
243 declareGeneratedDialect<LLVM::LLVMDialect>();
244
245 registerTransformOps<
246#define GET_OP_LIST
247#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
248 >();
249 }
250};
251} // namespace
252
253#define GET_OP_CLASSES
254#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
255
256void mlir::func::registerTransformDialectExtension(DialectRegistry &registry) {
257 registry.addExtensions<FuncTransformDialectExtension>();
258}
259

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp