//===- FuncTransformOps.cpp - Implementation of Func transform ops --------===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" |
| 10 | |
| 11 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
| 12 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 14 | #include "mlir/Dialect/Func/Utils/Utils.h" |
| 15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 16 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 17 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 18 | #include "mlir/IR/PatternMatch.h" |
| 19 | #include "mlir/Transforms/DialectConversion.h" |
| 20 | |
| 21 | using namespace mlir; |
| 22 | |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | // Apply...ConversionPatternsOp |
| 25 | //===----------------------------------------------------------------------===// |
| 26 | |
| 27 | void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns( |
| 28 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 29 | populateFuncToLLVMConversionPatterns( |
| 30 | converter: static_cast<LLVMTypeConverter &>(typeConverter), patterns); |
| 31 | } |
| 32 | |
| 33 | LogicalResult |
| 34 | transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter( |
| 35 | transform::TypeConverterBuilderOpInterface builder) { |
| 36 | if (builder.getTypeConverterType() != "LLVMTypeConverter" ) |
| 37 | return emitOpError(message: "expected LLVMTypeConverter" ); |
| 38 | return success(); |
| 39 | } |
| 40 | |
| 41 | //===----------------------------------------------------------------------===// |
| 42 | // CastAndCallOp |
| 43 | //===----------------------------------------------------------------------===// |
| 44 | |
| 45 | DiagnosedSilenceableFailure |
| 46 | transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, |
| 47 | transform::TransformResults &results, |
| 48 | transform::TransformState &state) { |
| 49 | SmallVector<Value> inputs; |
| 50 | if (getInputs()) |
| 51 | llvm::append_range(C&: inputs, R: state.getPayloadValues(handleValue: getInputs())); |
| 52 | |
| 53 | SetVector<Value> outputs; |
| 54 | if (getOutputs()) { |
| 55 | outputs.insert_range(R: state.getPayloadValues(handleValue: getOutputs())); |
| 56 | |
| 57 | // Verify that the set of output values to be replaced is unique. |
| 58 | if (outputs.size() != |
| 59 | llvm::range_size(Range: state.getPayloadValues(handleValue: getOutputs()))) { |
| 60 | return emitSilenceableFailure(loc: getLoc()) |
| 61 | << "cast and call output values must be unique" ; |
| 62 | } |
| 63 | } |
| 64 | |
| 65 | // Get the insertion point for the call. |
| 66 | auto insertionOps = state.getPayloadOps(value: getInsertionPoint()); |
| 67 | if (!llvm::hasSingleElement(C&: insertionOps)) { |
| 68 | return emitSilenceableFailure(loc: getLoc()) |
| 69 | << "Only one op can be specified as an insertion point" ; |
| 70 | } |
| 71 | bool insertAfter = getInsertAfter(); |
| 72 | Operation *insertionPoint = *insertionOps.begin(); |
| 73 | |
| 74 | // Check that all inputs dominate the insertion point, and the insertion |
| 75 | // point dominates all users of the outputs. |
| 76 | DominanceInfo dom(insertionPoint); |
| 77 | for (Value output : outputs) { |
| 78 | for (Operation *user : output.getUsers()) { |
| 79 | // If we are inserting after the insertion point operation, the |
| 80 | // insertion point operation must properly dominate the user. Otherwise |
| 81 | // basic dominance is enough. |
| 82 | bool doesDominate = insertAfter |
| 83 | ? dom.properlyDominates(a: insertionPoint, b: user) |
| 84 | : dom.dominates(a: insertionPoint, b: user); |
| 85 | if (!doesDominate) { |
| 86 | return emitDefiniteFailure() |
| 87 | << "User " << user << " is not dominated by insertion point " |
| 88 | << insertionPoint; |
| 89 | } |
| 90 | } |
| 91 | } |
| 92 | |
| 93 | for (Value input : inputs) { |
| 94 | // If we are inserting before the insertion point operation, the |
| 95 | // input must properly dominate the insertion point operation. Otherwise |
| 96 | // basic dominance is enough. |
| 97 | bool doesDominate = insertAfter |
| 98 | ? dom.dominates(a: input, b: insertionPoint) |
| 99 | : dom.properlyDominates(a: input, b: insertionPoint); |
| 100 | if (!doesDominate) { |
| 101 | return emitDefiniteFailure() |
| 102 | << "input " << input << " does not dominate insertion point " |
| 103 | << insertionPoint; |
| 104 | } |
| 105 | } |
| 106 | |
| 107 | // Get the function to call. This can either be specified by symbol or as a |
| 108 | // transform handle. |
| 109 | func::FuncOp targetFunction = nullptr; |
| 110 | if (getFunctionName()) { |
| 111 | targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>( |
| 112 | from: insertionPoint, symbol: *getFunctionName()); |
| 113 | if (!targetFunction) { |
| 114 | return emitDefiniteFailure() |
| 115 | << "unresolved symbol " << *getFunctionName(); |
| 116 | } |
| 117 | } else if (getFunction()) { |
| 118 | auto payloadOps = state.getPayloadOps(value: getFunction()); |
| 119 | if (!llvm::hasSingleElement(C&: payloadOps)) { |
| 120 | return emitDefiniteFailure() << "requires a single function to call" ; |
| 121 | } |
| 122 | targetFunction = dyn_cast<func::FuncOp>(Val: *payloadOps.begin()); |
| 123 | if (!targetFunction) { |
| 124 | return emitDefiniteFailure() << "invalid non-function callee" ; |
| 125 | } |
| 126 | } else { |
| 127 | llvm_unreachable("Invalid CastAndCall op without a function to call" ); |
| 128 | return emitDefiniteFailure(); |
| 129 | } |
| 130 | |
| 131 | // Verify that the function argument and result lengths match the inputs and |
| 132 | // outputs given to this op. |
| 133 | if (targetFunction.getNumArguments() != inputs.size()) { |
| 134 | return emitSilenceableFailure(loc: targetFunction.getLoc()) |
| 135 | << "mismatch between number of function arguments " |
| 136 | << targetFunction.getNumArguments() << " and number of inputs " |
| 137 | << inputs.size(); |
| 138 | } |
| 139 | if (targetFunction.getNumResults() != outputs.size()) { |
| 140 | return emitSilenceableFailure(loc: targetFunction.getLoc()) |
| 141 | << "mismatch between number of function results " |
| 142 | << targetFunction->getNumResults() << " and number of outputs " |
| 143 | << outputs.size(); |
| 144 | } |
| 145 | |
| 146 | // Gather all specified converters. |
| 147 | mlir::TypeConverter converter; |
| 148 | if (!getRegion().empty()) { |
| 149 | for (Operation &op : getRegion().front()) { |
| 150 | cast<transform::TypeConverterBuilderOpInterface>(Val: &op) |
| 151 | .populateTypeMaterializations(converter); |
| 152 | } |
| 153 | } |
| 154 | |
| 155 | if (insertAfter) |
| 156 | rewriter.setInsertionPointAfter(insertionPoint); |
| 157 | else |
| 158 | rewriter.setInsertionPoint(insertionPoint); |
| 159 | |
| 160 | for (auto [input, type] : |
| 161 | llvm::zip_equal(t&: inputs, u: targetFunction.getArgumentTypes())) { |
| 162 | if (input.getType() != type) { |
| 163 | Value newInput = converter.materializeSourceConversion( |
| 164 | builder&: rewriter, loc: input.getLoc(), resultType: type, inputs: input); |
| 165 | if (!newInput) { |
| 166 | return emitDefiniteFailure() << "Failed to materialize conversion of " |
| 167 | << input << " to type " << type; |
| 168 | } |
| 169 | input = newInput; |
| 170 | } |
| 171 | } |
| 172 | |
| 173 | auto callOp = rewriter.create<func::CallOp>(location: insertionPoint->getLoc(), |
| 174 | args&: targetFunction, args&: inputs); |
| 175 | |
| 176 | // Cast the call results back to the expected types. If any conversions fail |
| 177 | // this is a definite failure as the call has been constructed at this point. |
| 178 | for (auto [output, newOutput] : |
| 179 | llvm::zip_equal(t&: outputs, u: callOp.getResults())) { |
| 180 | Value convertedOutput = newOutput; |
| 181 | if (output.getType() != newOutput.getType()) { |
| 182 | convertedOutput = converter.materializeTargetConversion( |
| 183 | builder&: rewriter, loc: output.getLoc(), resultType: output.getType(), inputs: newOutput); |
| 184 | if (!convertedOutput) { |
| 185 | return emitDefiniteFailure() |
| 186 | << "Failed to materialize conversion of " << newOutput |
| 187 | << " to type " << output.getType(); |
| 188 | } |
| 189 | } |
| 190 | rewriter.replaceAllUsesExcept(from: output, to: convertedOutput, exceptedUser: callOp); |
| 191 | } |
| 192 | results.set(value: cast<OpResult>(Val: getResult()), ops: {callOp}); |
| 193 | return DiagnosedSilenceableFailure::success(); |
| 194 | } |
| 195 | |
| 196 | LogicalResult transform::CastAndCallOp::verify() { |
| 197 | if (!getRegion().empty()) { |
| 198 | for (Operation &op : getRegion().front()) { |
| 199 | if (!isa<transform::TypeConverterBuilderOpInterface>(Val: &op)) { |
| 200 | InFlightDiagnostic diag = emitOpError() |
| 201 | << "expected children ops to implement " |
| 202 | "TypeConverterBuilderOpInterface" ; |
| 203 | diag.attachNote(noteLoc: op.getLoc()) << "op without interface" ; |
| 204 | return diag; |
| 205 | } |
| 206 | } |
| 207 | } |
| 208 | if (!getFunction() && !getFunctionName()) { |
| 209 | return emitOpError() << "expected a function handle or name to call" ; |
| 210 | } |
| 211 | if (getFunction() && getFunctionName()) { |
| 212 | return emitOpError() << "function handle and name are mutually exclusive" ; |
| 213 | } |
| 214 | return success(); |
| 215 | } |
| 216 | |
| 217 | void transform::CastAndCallOp::getEffects( |
| 218 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 219 | transform::onlyReadsHandle(handles: getInsertionPointMutable(), effects); |
| 220 | if (getInputs()) |
| 221 | transform::onlyReadsHandle(handles: getInputsMutable(), effects); |
| 222 | if (getOutputs()) |
| 223 | transform::onlyReadsHandle(handles: getOutputsMutable(), effects); |
| 224 | if (getFunction()) |
| 225 | transform::onlyReadsHandle(handles: getFunctionMutable(), effects); |
| 226 | transform::producesHandle(handles: getOperation()->getOpResults(), effects); |
| 227 | transform::modifiesPayload(effects); |
| 228 | } |
| 229 | |
| 230 | //===----------------------------------------------------------------------===// |
| 231 | // ReplaceFuncSignatureOp |
| 232 | //===----------------------------------------------------------------------===// |
| 233 | |
| 234 | DiagnosedSilenceableFailure |
| 235 | transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter, |
| 236 | transform::TransformResults &results, |
| 237 | transform::TransformState &state) { |
| 238 | auto payloadOps = state.getPayloadOps(value: getModule()); |
| 239 | if (!llvm::hasSingleElement(C&: payloadOps)) |
| 240 | return emitDefiniteFailure() << "requires a single module to operate on" ; |
| 241 | |
| 242 | auto targetModuleOp = dyn_cast<ModuleOp>(Val: *payloadOps.begin()); |
| 243 | if (!targetModuleOp) |
| 244 | return emitSilenceableFailure(loc: getLoc()) |
| 245 | << "target is expected to be module operation" ; |
| 246 | |
| 247 | func::FuncOp funcOp = |
| 248 | targetModuleOp.lookupSymbol<func::FuncOp>(symbol: getFunctionName()); |
| 249 | if (!funcOp) |
| 250 | return emitSilenceableFailure(loc: getLoc()) |
| 251 | << "function with name '" << getFunctionName() << "' not found" ; |
| 252 | |
| 253 | unsigned numArgs = funcOp.getNumArguments(); |
| 254 | unsigned numResults = funcOp.getNumResults(); |
| 255 | // Check that the number of arguments and results matches the |
| 256 | // interchange sizes. |
| 257 | if (numArgs != getArgsInterchange().size()) |
| 258 | return emitSilenceableFailure(loc: getLoc()) |
| 259 | << "function with name '" << getFunctionName() << "' has " << numArgs |
| 260 | << " arguments, but " << getArgsInterchange().size() |
| 261 | << " args interchange were given" ; |
| 262 | |
| 263 | if (numResults != getResultsInterchange().size()) |
| 264 | return emitSilenceableFailure(loc: getLoc()) |
| 265 | << "function with name '" << getFunctionName() << "' has " |
| 266 | << numResults << " results, but " << getResultsInterchange().size() |
| 267 | << " results interchange were given" ; |
| 268 | |
| 269 | // Check that the args and results interchanges are unique. |
| 270 | SetVector<unsigned> argsInterchange, resultsInterchange; |
| 271 | argsInterchange.insert_range(R: getArgsInterchange()); |
| 272 | resultsInterchange.insert_range(R: getResultsInterchange()); |
| 273 | if (argsInterchange.size() != getArgsInterchange().size()) |
| 274 | return emitSilenceableFailure(loc: getLoc()) |
| 275 | << "args interchange must be unique" ; |
| 276 | |
| 277 | if (resultsInterchange.size() != getResultsInterchange().size()) |
| 278 | return emitSilenceableFailure(loc: getLoc()) |
| 279 | << "results interchange must be unique" ; |
| 280 | |
| 281 | // Check that the args and results interchange indices are in bounds. |
| 282 | for (unsigned index : argsInterchange) { |
| 283 | if (index >= numArgs) { |
| 284 | return emitSilenceableFailure(loc: getLoc()) |
| 285 | << "args interchange index " << index |
| 286 | << " is out of bounds for function with name '" |
| 287 | << getFunctionName() << "' with " << numArgs << " arguments" ; |
| 288 | } |
| 289 | } |
| 290 | for (unsigned index : resultsInterchange) { |
| 291 | if (index >= numResults) { |
| 292 | return emitSilenceableFailure(loc: getLoc()) |
| 293 | << "results interchange index " << index |
| 294 | << " is out of bounds for function with name '" |
| 295 | << getFunctionName() << "' with " << numResults << " results" ; |
| 296 | } |
| 297 | } |
| 298 | |
| 299 | FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder( |
| 300 | rewriter, funcOp, newArgsOrder: argsInterchange.getArrayRef(), |
| 301 | newResultsOrder: resultsInterchange.getArrayRef()); |
| 302 | if (failed(Result: newFuncOpOrFailure)) |
| 303 | return emitSilenceableFailure(loc: getLoc()) |
| 304 | << "failed to replace function signature '" << getFunctionName() |
| 305 | << "' with new order" ; |
| 306 | |
| 307 | if (getAdjustFuncCalls()) { |
| 308 | SmallVector<func::CallOp> callOps; |
| 309 | targetModuleOp.walk(callback: [&](func::CallOp callOp) { |
| 310 | if (callOp.getCallee() == getFunctionName().getRootReference().getValue()) |
| 311 | callOps.push_back(Elt: callOp); |
| 312 | }); |
| 313 | |
| 314 | for (func::CallOp callOp : callOps) |
| 315 | func::replaceCallOpWithNewOrder(rewriter, callOp, |
| 316 | newArgsOrder: argsInterchange.getArrayRef(), |
| 317 | newResultsOrder: resultsInterchange.getArrayRef()); |
| 318 | } |
| 319 | |
| 320 | results.set(value: cast<OpResult>(Val: getTransformedModule()), ops: {targetModuleOp}); |
| 321 | results.set(value: cast<OpResult>(Val: getTransformedFunction()), ops: {*newFuncOpOrFailure}); |
| 322 | |
| 323 | return DiagnosedSilenceableFailure::success(); |
| 324 | } |
| 325 | |
| 326 | void transform::ReplaceFuncSignatureOp::getEffects( |
| 327 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 328 | transform::consumesHandle(handles: getModuleMutable(), effects); |
| 329 | transform::producesHandle(handles: getOperation()->getOpResults(), effects); |
| 330 | transform::modifiesPayload(effects); |
| 331 | } |
| 332 | |
| 333 | //===----------------------------------------------------------------------===// |
| 334 | // Transform op registration |
| 335 | //===----------------------------------------------------------------------===// |
| 336 | |
| 337 | namespace { |
| 338 | class FuncTransformDialectExtension |
| 339 | : public transform::TransformDialectExtension< |
| 340 | FuncTransformDialectExtension> { |
| 341 | public: |
| 342 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension) |
| 343 | |
| 344 | using Base::Base; |
| 345 | |
| 346 | void init() { |
| 347 | declareGeneratedDialect<LLVM::LLVMDialect>(); |
| 348 | |
| 349 | registerTransformOps< |
| 350 | #define GET_OP_LIST |
| 351 | #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" |
| 352 | >(); |
| 353 | } |
| 354 | }; |
| 355 | } // namespace |
| 356 | |
| 357 | #define GET_OP_CLASSES |
| 358 | #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" |
| 359 | |
| 360 | void mlir::func::registerTransformDialectExtension(DialectRegistry ®istry) { |
| 361 | registry.addExtensions<FuncTransformDialectExtension>(); |
| 362 | } |
| 363 | |