1//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous inlining utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Transforms/InliningUtils.h"
14
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Interfaces/CallInterfaces.h"
19#include "llvm/ADT/MapVector.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/raw_ostream.h"
22#include <optional>
23
24#define DEBUG_TYPE "inlining"
25
26using namespace mlir;
27
28/// Combine `callee` location with `caller` location to create a stack that
29/// represents the call chain.
30/// If `callee` location is a `CallSiteLoc`, indicating an existing stack of
31/// locations, the `caller` location is appended to the end of it, extending
32/// the chain.
33/// Otherwise, a single `CallSiteLoc` is created, representing a direct call
34/// from `caller` to `callee`.
35static LocationAttr stackLocations(Location callee, Location caller) {
36 Location lastCallee = callee;
37 SmallVector<CallSiteLoc> calleeInliningStack;
38 while (auto nextCallSite = dyn_cast<CallSiteLoc>(lastCallee)) {
39 calleeInliningStack.push_back(nextCallSite);
40 lastCallee = nextCallSite.getCaller();
41 }
42
43 CallSiteLoc firstCallSite = CallSiteLoc::get(lastCallee, caller);
44 for (CallSiteLoc currentCallSite : reverse(calleeInliningStack))
45 firstCallSite =
46 CallSiteLoc::get(currentCallSite.getCallee(), firstCallSite);
47
48 return firstCallSite;
49}
50
51/// Remap all locations reachable from the inlined blocks with CallSiteLoc
52/// locations with the provided caller location.
53static void
54remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
55 Location callerLoc) {
56 DenseMap<Location, LocationAttr> mappedLocations;
57 auto remapLoc = [&](Location loc) {
58 auto [it, inserted] = mappedLocations.try_emplace(Key: loc);
59 // Only query the attribute uniquer once per callsite attribute.
60 if (inserted) {
61 LocationAttr newLoc = stackLocations(callee: loc, caller: callerLoc);
62 it->getSecond() = newLoc;
63 }
64 return it->second;
65 };
66
67 AttrTypeReplacer attrReplacer;
68 attrReplacer.addReplacement(
69 callback: [&](LocationAttr loc) -> std::pair<LocationAttr, WalkResult> {
70 return {remapLoc(loc), WalkResult::skip()};
71 });
72
73 for (Block &block : inlinedBlocks) {
74 for (BlockArgument &arg : block.getArguments())
75 if (LocationAttr newLoc = remapLoc(arg.getLoc()))
76 arg.setLoc(newLoc);
77
78 for (Operation &op : block)
79 attrReplacer.recursivelyReplaceElementsIn(op: &op, /*replaceAttrs=*/false,
80 /*replaceLocs=*/true);
81 }
82}
83
84static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
85 IRMapping &mapper) {
86 auto remapOperands = [&](Operation *op) {
87 for (auto &operand : op->getOpOperands())
88 if (auto mappedOp = mapper.lookupOrNull(from: operand.get()))
89 operand.set(mappedOp);
90 };
91 for (auto &block : inlinedBlocks)
92 block.walk(callback&: remapOperands);
93}
94
95//===----------------------------------------------------------------------===//
96// InlinerInterface
97//===----------------------------------------------------------------------===//
98
99bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
100 bool wouldBeCloned) const {
101 if (auto *handler = getInterfaceFor(obj: call))
102 return handler->isLegalToInline(call, callable, wouldBeCloned);
103 return false;
104}
105
106bool InlinerInterface::isLegalToInline(Region *dest, Region *src,
107 bool wouldBeCloned,
108 IRMapping &valueMapping) const {
109 if (auto *handler = getInterfaceFor(obj: dest->getParentOp()))
110 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
111 return false;
112}
113
114bool InlinerInterface::isLegalToInline(Operation *op, Region *dest,
115 bool wouldBeCloned,
116 IRMapping &valueMapping) const {
117 if (auto *handler = getInterfaceFor(obj: op))
118 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
119 return false;
120}
121
122bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
123 auto *handler = getInterfaceFor(obj: op);
124 return handler ? handler->shouldAnalyzeRecursively(op) : true;
125}
126
127/// Handle the given inlined terminator by replacing it with a new operation
128/// as necessary.
129void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
130 auto *handler = getInterfaceFor(obj: op);
131 assert(handler && "expected valid dialect handler");
132 handler->handleTerminator(op, newDest);
133}
134
135/// Handle the given inlined terminator by replacing it with a new operation
136/// as necessary.
137void InlinerInterface::handleTerminator(Operation *op,
138 ValueRange valuesToRepl) const {
139 auto *handler = getInterfaceFor(obj: op);
140 assert(handler && "expected valid dialect handler");
141 handler->handleTerminator(op, valuesToReplace: valuesToRepl);
142}
143
144/// Returns true if the inliner can assume a fast path of not creating a
145/// new block, if there is only one block.
146bool InlinerInterface::allowSingleBlockOptimization(
147 iterator_range<Region::iterator> inlinedBlocks) const {
148 if (inlinedBlocks.empty()) {
149 return true;
150 }
151 auto *handler = getInterfaceFor(obj: inlinedBlocks.begin()->getParentOp());
152 assert(handler && "expected valid dialect handler");
153 return handler->allowSingleBlockOptimization(inlinedBlocks);
154}
155
156Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
157 Operation *callable, Value argument,
158 DictionaryAttr argumentAttrs) const {
159 auto *handler = getInterfaceFor(obj: callable);
160 assert(handler && "expected valid dialect handler");
161 return handler->handleArgument(builder, call, callable, argument,
162 argumentAttrs);
163}
164
165Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
166 Operation *callable, Value result,
167 DictionaryAttr resultAttrs) const {
168 auto *handler = getInterfaceFor(obj: callable);
169 assert(handler && "expected valid dialect handler");
170 return handler->handleResult(builder, call, callable, result, resultAttrs);
171}
172
173void InlinerInterface::processInlinedCallBlocks(
174 Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
175 auto *handler = getInterfaceFor(obj: call);
176 assert(handler && "expected valid dialect handler");
177 handler->processInlinedCallBlocks(call, inlinedBlocks);
178}
179
180/// Utility to check that all of the operations within 'src' can be inlined.
181static bool isLegalToInline(InlinerInterface &interface, Region *src,
182 Region *insertRegion, bool shouldCloneInlinedRegion,
183 IRMapping &valueMapping) {
184 for (auto &block : *src) {
185 for (auto &op : block) {
186 // Check this operation.
187 if (!interface.isLegalToInline(op: &op, dest: insertRegion,
188 wouldBeCloned: shouldCloneInlinedRegion, valueMapping)) {
189 LLVM_DEBUG({
190 llvm::dbgs() << "* Illegal to inline because of op: ";
191 op.dump();
192 });
193 return false;
194 }
195 // Check any nested regions.
196 if (interface.shouldAnalyzeRecursively(op: &op) &&
197 llvm::any_of(Range: op.getRegions(), P: [&](Region &region) {
198 return !isLegalToInline(interface, src: &region, insertRegion,
199 shouldCloneInlinedRegion, valueMapping);
200 }))
201 return false;
202 }
203 }
204 return true;
205}
206
207//===----------------------------------------------------------------------===//
208// Inline Methods
209//===----------------------------------------------------------------------===//
210
211static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
212 CallOpInterface call,
213 CallableOpInterface callable,
214 IRMapping &mapper) {
215 // Unpack the argument attributes if there are any.
216 SmallVector<DictionaryAttr> argAttrs(
217 callable.getCallableRegion()->getNumArguments(),
218 builder.getDictionaryAttr({}));
219 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
220 assert(arrayAttr.size() == argAttrs.size());
221 for (auto [idx, attr] : llvm::enumerate(arrayAttr))
222 argAttrs[idx] = cast<DictionaryAttr>(attr);
223 }
224
225 // Run the argument attribute handler for the given argument and attribute.
226 for (auto [blockArg, argAttr] :
227 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
228 Value newArgument = interface.handleArgument(
229 builder, call, callable, mapper.lookup(blockArg), argAttr);
230 assert(newArgument.getType() == mapper.lookup(blockArg).getType() &&
231 "expected the argument type to not change");
232
233 // Update the mapping to point the new argument returned by the handler.
234 mapper.map(blockArg, newArgument);
235 }
236}
237
238static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
239 CallOpInterface call, CallableOpInterface callable,
240 ValueRange results) {
241 // Unpack the result attributes if there are any.
242 SmallVector<DictionaryAttr> resAttrs(results.size(),
243 builder.getDictionaryAttr({}));
244 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
245 assert(arrayAttr.size() == resAttrs.size());
246 for (auto [idx, attr] : llvm::enumerate(arrayAttr))
247 resAttrs[idx] = cast<DictionaryAttr>(attr);
248 }
249
250 // Run the result attribute handler for the given result and attribute.
251 for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
252 // Store the original result users before running the handler.
253 DenseSet<Operation *> resultUsers(llvm::from_range, result.getUsers());
254
255 Value newResult =
256 interface.handleResult(builder, call, callable, result, resAttr);
257 assert(newResult.getType() == result.getType() &&
258 "expected the result type to not change");
259
260 // Replace the result uses except for the ones introduce by the handler.
261 result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
262 return resultUsers.count(operand.getOwner());
263 });
264 }
265}
266
267static LogicalResult inlineRegionImpl(
268 InlinerInterface &interface,
269 function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
270 Region *src, Block *inlineBlock, Block::iterator inlinePoint,
271 IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes,
272 std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
273 CallOpInterface call = {}) {
274 assert(resultsToReplace.size() == regionResultTypes.size());
275 // We expect the region to have at least one block.
276 if (src->empty())
277 return failure();
278
279 // Check that all of the region arguments have been mapped.
280 auto *srcEntryBlock = &src->front();
281 if (llvm::any_of(Range: srcEntryBlock->getArguments(),
282 P: [&](BlockArgument arg) { return !mapper.contains(from: arg); }))
283 return failure();
284
285 // Check that the operations within the source region are valid to inline.
286 Region *insertRegion = inlineBlock->getParent();
287 if (!interface.isLegalToInline(dest: insertRegion, src, wouldBeCloned: shouldCloneInlinedRegion,
288 valueMapping&: mapper) ||
289 !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
290 valueMapping&: mapper))
291 return failure();
292
293 // Run the argument attribute handler before inlining the callable region.
294 OpBuilder builder(inlineBlock, inlinePoint);
295 auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
296 if (call && callable)
297 handleArgumentImpl(interface, builder, call, callable, mapper);
298
299 // Clone the callee's source into the caller.
300 Block *postInsertBlock = inlineBlock->splitBlock(splitBefore: inlinePoint);
301 cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper,
302 shouldCloneInlinedRegion);
303
304 // Get the range of newly inserted blocks.
305 auto newBlocks = llvm::make_range(x: std::next(x: inlineBlock->getIterator()),
306 y: postInsertBlock->getIterator());
307 Block *firstNewBlock = &*newBlocks.begin();
308
309 // Remap the locations of the inlined operations if a valid source location
310 // was provided.
311 if (inlineLoc && !llvm::isa<UnknownLoc>(Val: *inlineLoc))
312 remapInlinedLocations(inlinedBlocks: newBlocks, callerLoc: *inlineLoc);
313
314 // If the blocks were moved in-place, make sure to remap any necessary
315 // operands.
316 if (!shouldCloneInlinedRegion)
317 remapInlinedOperands(inlinedBlocks: newBlocks, mapper);
318
319 // Process the newly inlined blocks.
320 if (call)
321 interface.processInlinedCallBlocks(call: call, inlinedBlocks: newBlocks);
322 interface.processInlinedBlocks(inlinedBlocks: newBlocks);
323
324 bool singleBlockFastPath = interface.allowSingleBlockOptimization(inlinedBlocks: newBlocks);
325
326 // Handle the case where only a single block was inlined.
327 if (singleBlockFastPath && llvm::hasSingleElement(C&: newBlocks)) {
328 // Run the result attribute handler on the terminator operands.
329 Operation *firstBlockTerminator = firstNewBlock->getTerminator();
330 builder.setInsertionPoint(firstBlockTerminator);
331 if (call && callable)
332 handleResultImpl(interface, builder, call, callable,
333 firstBlockTerminator->getOperands());
334
335 // Have the interface handle the terminator of this block.
336 interface.handleTerminator(op: firstBlockTerminator, valuesToRepl: resultsToReplace);
337 firstBlockTerminator->erase();
338
339 // Merge the post insert block into the cloned entry block.
340 firstNewBlock->getOperations().splice(where: firstNewBlock->end(),
341 L2&: postInsertBlock->getOperations());
342 postInsertBlock->erase();
343 } else {
344 // Otherwise, there were multiple blocks inlined. Add arguments to the post
345 // insertion block to represent the results to replace.
346 for (const auto &resultToRepl : llvm::enumerate(First&: resultsToReplace)) {
347 resultToRepl.value().replaceAllUsesWith(
348 newValue: postInsertBlock->addArgument(type: regionResultTypes[resultToRepl.index()],
349 loc: resultToRepl.value().getLoc()));
350 }
351
352 // Run the result attribute handler on the post insertion block arguments.
353 builder.setInsertionPointToStart(postInsertBlock);
354 if (call && callable)
355 handleResultImpl(interface, builder, call, callable,
356 postInsertBlock->getArguments());
357
358 /// Handle the terminators for each of the new blocks.
359 for (auto &newBlock : newBlocks)
360 interface.handleTerminator(op: newBlock.getTerminator(), newDest: postInsertBlock);
361 }
362
363 // Splice the instructions of the inlined entry block into the insert block.
364 inlineBlock->getOperations().splice(where: inlineBlock->end(),
365 L2&: firstNewBlock->getOperations());
366 firstNewBlock->erase();
367 return success();
368}
369
370static LogicalResult inlineRegionImpl(
371 InlinerInterface &interface,
372 function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
373 Region *src, Block *inlineBlock, Block::iterator inlinePoint,
374 ValueRange inlinedOperands, ValueRange resultsToReplace,
375 std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
376 CallOpInterface call = {}) {
377 // We expect the region to have at least one block.
378 if (src->empty())
379 return failure();
380
381 auto *entryBlock = &src->front();
382 if (inlinedOperands.size() != entryBlock->getNumArguments())
383 return failure();
384
385 // Map the provided call operands to the arguments of the region.
386 IRMapping mapper;
387 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
388 // Verify that the types of the provided values match the function argument
389 // types.
390 BlockArgument regionArg = entryBlock->getArgument(i);
391 if (inlinedOperands[i].getType() != regionArg.getType())
392 return failure();
393 mapper.map(from: regionArg, to: inlinedOperands[i]);
394 }
395
396 // Call into the main region inliner function.
397 return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
398 inlinePoint, mapper, resultsToReplace,
399 resultsToReplace.getTypes(), inlineLoc,
400 shouldCloneInlinedRegion, call);
401}
402
403LogicalResult mlir::inlineRegion(
404 InlinerInterface &interface,
405 function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
406 Region *src, Operation *inlinePoint, IRMapping &mapper,
407 ValueRange resultsToReplace, TypeRange regionResultTypes,
408 std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
409 return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
410 ++inlinePoint->getIterator(), mapper, resultsToReplace,
411 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
412}
413
414LogicalResult mlir::inlineRegion(
415 InlinerInterface &interface,
416 function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
417 Region *src, Block *inlineBlock, Block::iterator inlinePoint,
418 IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes,
419 std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
420 return inlineRegionImpl(
421 interface, cloneCallback, src, inlineBlock, inlinePoint, mapper,
422 resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
423}
424
425LogicalResult mlir::inlineRegion(
426 InlinerInterface &interface,
427 function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
428 Region *src, Operation *inlinePoint, ValueRange inlinedOperands,
429 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
430 bool shouldCloneInlinedRegion) {
431 return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
432 ++inlinePoint->getIterator(), inlinedOperands,
433 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
434}
435
436LogicalResult mlir::inlineRegion(
437 InlinerInterface &interface,
438 function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
439 Region *src, Block *inlineBlock, Block::iterator inlinePoint,
440 ValueRange inlinedOperands, ValueRange resultsToReplace,
441 std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
442 return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
443 inlinePoint, inlinedOperands, resultsToReplace,
444 inlineLoc, shouldCloneInlinedRegion);
445}
446
447/// Utility function used to generate a cast operation from the given interface,
448/// or return nullptr if a cast could not be generated.
449static Value materializeConversion(const DialectInlinerInterface *interface,
450 SmallVectorImpl<Operation *> &castOps,
451 OpBuilder &castBuilder, Value arg, Type type,
452 Location conversionLoc) {
453 if (!interface)
454 return nullptr;
455
456 // Check to see if the interface for the call can materialize a conversion.
457 Operation *castOp = interface->materializeCallConversion(builder&: castBuilder, input: arg,
458 resultType: type, conversionLoc);
459 if (!castOp)
460 return nullptr;
461 castOps.push_back(Elt: castOp);
462
463 // Ensure that the generated cast is correct.
464 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
465 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
466 return castOp->getResult(idx: 0);
467}
468
469/// This function inlines a given region, 'src', of a callable operation,
470/// 'callable', into the location defined by the given call operation. This
471/// function returns failure if inlining is not possible, success otherwise. On
472/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
473/// corresponds to whether the source region should be cloned into the 'call' or
474/// spliced directly.
475LogicalResult mlir::inlineCall(
476 InlinerInterface &interface,
477 function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
478 CallOpInterface call, CallableOpInterface callable, Region *src,
479 bool shouldCloneInlinedRegion) {
480 // We expect the region to have at least one block.
481 if (src->empty())
482 return failure();
483 auto *entryBlock = &src->front();
484 ArrayRef<Type> callableResultTypes = callable.getResultTypes();
485
486 // Make sure that the number of arguments and results matchup between the call
487 // and the region.
488 SmallVector<Value, 8> callOperands(call.getArgOperands());
489 SmallVector<Value, 8> callResults(call->getResults());
490 if (callOperands.size() != entryBlock->getNumArguments() ||
491 callResults.size() != callableResultTypes.size())
492 return failure();
493
494 // A set of cast operations generated to matchup the signature of the region
495 // with the signature of the call.
496 SmallVector<Operation *, 4> castOps;
497 castOps.reserve(N: callOperands.size() + callResults.size());
498
499 // Functor used to cleanup generated state on failure.
500 auto cleanupState = [&] {
501 for (auto *op : castOps) {
502 op->getResult(idx: 0).replaceAllUsesWith(newValue: op->getOperand(idx: 0));
503 op->erase();
504 }
505 return failure();
506 };
507
508 // Builder used for any conversion operations that need to be materialized.
509 OpBuilder castBuilder(call);
510 Location castLoc = call.getLoc();
511 const auto *callInterface = interface.getInterfaceFor(call->getDialect());
512
513 // Map the provided call operands to the arguments of the region.
514 IRMapping mapper;
515 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
516 BlockArgument regionArg = entryBlock->getArgument(i);
517 Value operand = callOperands[i];
518
519 // If the call operand doesn't match the expected region argument, try to
520 // generate a cast.
521 Type regionArgType = regionArg.getType();
522 if (operand.getType() != regionArgType) {
523 if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
524 operand, regionArgType, castLoc)))
525 return cleanupState();
526 }
527 mapper.map(from: regionArg, to: operand);
528 }
529
530 // Ensure that the resultant values of the call match the callable.
531 castBuilder.setInsertionPointAfter(call);
532 for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
533 Value callResult = callResults[i];
534 if (callResult.getType() == callableResultTypes[i])
535 continue;
536
537 // Generate a conversion that will produce the original type, so that the IR
538 // is still valid after the original call gets replaced.
539 Value castResult =
540 materializeConversion(callInterface, castOps, castBuilder, callResult,
541 callResult.getType(), castLoc);
542 if (!castResult)
543 return cleanupState();
544 callResult.replaceAllUsesWith(newValue: castResult);
545 castResult.getDefiningOp()->replaceUsesOfWith(from: castResult, to: callResult);
546 }
547
548 // Check that it is legal to inline the callable into the call.
549 if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
550 return cleanupState();
551
552 // Attempt to inline the call.
553 if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(),
554 ++call->getIterator(), mapper, callResults,
555 callableResultTypes, call.getLoc(),
556 shouldCloneInlinedRegion, call)))
557 return cleanupState();
558 return success();
559}
560

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Transforms/Utils/InliningUtils.cpp