1//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous inlining utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Transforms/InliningUtils.h"
14
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Interfaces/CallInterfaces.h"
19#include "llvm/ADT/MapVector.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/raw_ostream.h"
22#include <optional>
23
24#define DEBUG_TYPE "inlining"
25
26using namespace mlir;
27
28/// Remap locations from the inlined blocks with CallSiteLoc locations with the
29/// provided caller location.
30static void
31remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
32 Location callerLoc) {
33 DenseMap<Location, Location> mappedLocations;
34 auto remapOpLoc = [&](Operation *op) {
35 auto it = mappedLocations.find(Val: op->getLoc());
36 if (it == mappedLocations.end()) {
37 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
38 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
39 }
40 op->setLoc(it->second);
41 };
42 for (auto &block : inlinedBlocks)
43 block.walk(callback&: remapOpLoc);
44}
45
46static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
47 IRMapping &mapper) {
48 auto remapOperands = [&](Operation *op) {
49 for (auto &operand : op->getOpOperands())
50 if (auto mappedOp = mapper.lookupOrNull(from: operand.get()))
51 operand.set(mappedOp);
52 };
53 for (auto &block : inlinedBlocks)
54 block.walk(callback&: remapOperands);
55}
56
57//===----------------------------------------------------------------------===//
58// InlinerInterface
59//===----------------------------------------------------------------------===//
60
61bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
62 bool wouldBeCloned) const {
63 if (auto *handler = getInterfaceFor(obj: call))
64 return handler->isLegalToInline(call, callable, wouldBeCloned);
65 return false;
66}
67
68bool InlinerInterface::isLegalToInline(Region *dest, Region *src,
69 bool wouldBeCloned,
70 IRMapping &valueMapping) const {
71 if (auto *handler = getInterfaceFor(obj: dest->getParentOp()))
72 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
73 return false;
74}
75
76bool InlinerInterface::isLegalToInline(Operation *op, Region *dest,
77 bool wouldBeCloned,
78 IRMapping &valueMapping) const {
79 if (auto *handler = getInterfaceFor(obj: op))
80 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
81 return false;
82}
83
84bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
85 auto *handler = getInterfaceFor(obj: op);
86 return handler ? handler->shouldAnalyzeRecursively(op) : true;
87}
88
89/// Handle the given inlined terminator by replacing it with a new operation
90/// as necessary.
91void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
92 auto *handler = getInterfaceFor(obj: op);
93 assert(handler && "expected valid dialect handler");
94 handler->handleTerminator(op, newDest);
95}
96
97/// Handle the given inlined terminator by replacing it with a new operation
98/// as necessary.
99void InlinerInterface::handleTerminator(Operation *op,
100 ValueRange valuesToRepl) const {
101 auto *handler = getInterfaceFor(obj: op);
102 assert(handler && "expected valid dialect handler");
103 handler->handleTerminator(op, valuesToReplace: valuesToRepl);
104}
105
106Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
107 Operation *callable, Value argument,
108 DictionaryAttr argumentAttrs) const {
109 auto *handler = getInterfaceFor(obj: callable);
110 assert(handler && "expected valid dialect handler");
111 return handler->handleArgument(builder, call, callable, argument,
112 argumentAttrs);
113}
114
115Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
116 Operation *callable, Value result,
117 DictionaryAttr resultAttrs) const {
118 auto *handler = getInterfaceFor(obj: callable);
119 assert(handler && "expected valid dialect handler");
120 return handler->handleResult(builder, call, callable, result, resultAttrs);
121}
122
123void InlinerInterface::processInlinedCallBlocks(
124 Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
125 auto *handler = getInterfaceFor(obj: call);
126 assert(handler && "expected valid dialect handler");
127 handler->processInlinedCallBlocks(call, inlinedBlocks);
128}
129
130/// Utility to check that all of the operations within 'src' can be inlined.
131static bool isLegalToInline(InlinerInterface &interface, Region *src,
132 Region *insertRegion, bool shouldCloneInlinedRegion,
133 IRMapping &valueMapping) {
134 for (auto &block : *src) {
135 for (auto &op : block) {
136 // Check this operation.
137 if (!interface.isLegalToInline(op: &op, dest: insertRegion,
138 wouldBeCloned: shouldCloneInlinedRegion, valueMapping)) {
139 LLVM_DEBUG({
140 llvm::dbgs() << "* Illegal to inline because of op: ";
141 op.dump();
142 });
143 return false;
144 }
145 // Check any nested regions.
146 if (interface.shouldAnalyzeRecursively(op: &op) &&
147 llvm::any_of(Range: op.getRegions(), P: [&](Region &region) {
148 return !isLegalToInline(interface, src: &region, insertRegion,
149 shouldCloneInlinedRegion, valueMapping);
150 }))
151 return false;
152 }
153 }
154 return true;
155}
156
157//===----------------------------------------------------------------------===//
158// Inline Methods
159//===----------------------------------------------------------------------===//
160
161static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
162 CallOpInterface call,
163 CallableOpInterface callable,
164 IRMapping &mapper) {
165 // Unpack the argument attributes if there are any.
166 SmallVector<DictionaryAttr> argAttrs(
167 callable.getCallableRegion()->getNumArguments(),
168 builder.getDictionaryAttr({}));
169 if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
170 assert(arrayAttr.size() == argAttrs.size());
171 for (auto [idx, attr] : llvm::enumerate(arrayAttr))
172 argAttrs[idx] = cast<DictionaryAttr>(attr);
173 }
174
175 // Run the argument attribute handler for the given argument and attribute.
176 for (auto [blockArg, argAttr] :
177 llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
178 Value newArgument = interface.handleArgument(
179 builder, call, callable, mapper.lookup(blockArg), argAttr);
180 assert(newArgument.getType() == mapper.lookup(blockArg).getType() &&
181 "expected the argument type to not change");
182
183 // Update the mapping to point the new argument returned by the handler.
184 mapper.map(blockArg, newArgument);
185 }
186}
187
188static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
189 CallOpInterface call, CallableOpInterface callable,
190 ValueRange results) {
191 // Unpack the result attributes if there are any.
192 SmallVector<DictionaryAttr> resAttrs(results.size(),
193 builder.getDictionaryAttr({}));
194 if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
195 assert(arrayAttr.size() == resAttrs.size());
196 for (auto [idx, attr] : llvm::enumerate(arrayAttr))
197 resAttrs[idx] = cast<DictionaryAttr>(attr);
198 }
199
200 // Run the result attribute handler for the given result and attribute.
201 SmallVector<DictionaryAttr> resultAttributes;
202 for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
203 // Store the original result users before running the handler.
204 DenseSet<Operation *> resultUsers;
205 for (Operation *user : result.getUsers())
206 resultUsers.insert(user);
207
208 Value newResult =
209 interface.handleResult(builder, call, callable, result, resAttr);
210 assert(newResult.getType() == result.getType() &&
211 "expected the result type to not change");
212
213 // Replace the result uses except for the ones introduce by the handler.
214 result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
215 return resultUsers.count(operand.getOwner());
216 });
217 }
218}
219
220static LogicalResult
221inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
222 Block::iterator inlinePoint, IRMapping &mapper,
223 ValueRange resultsToReplace, TypeRange regionResultTypes,
224 std::optional<Location> inlineLoc,
225 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
226 assert(resultsToReplace.size() == regionResultTypes.size());
227 // We expect the region to have at least one block.
228 if (src->empty())
229 return failure();
230
231 // Check that all of the region arguments have been mapped.
232 auto *srcEntryBlock = &src->front();
233 if (llvm::any_of(Range: srcEntryBlock->getArguments(),
234 P: [&](BlockArgument arg) { return !mapper.contains(from: arg); }))
235 return failure();
236
237 // Check that the operations within the source region are valid to inline.
238 Region *insertRegion = inlineBlock->getParent();
239 if (!interface.isLegalToInline(dest: insertRegion, src, wouldBeCloned: shouldCloneInlinedRegion,
240 valueMapping&: mapper) ||
241 !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
242 valueMapping&: mapper))
243 return failure();
244
245 // Run the argument attribute handler before inlining the callable region.
246 OpBuilder builder(inlineBlock, inlinePoint);
247 auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
248 if (call && callable)
249 handleArgumentImpl(interface, builder, call, callable, mapper);
250
251 // Check to see if the region is being cloned, or moved inline. In either
252 // case, move the new blocks after the 'insertBlock' to improve IR
253 // readability.
254 Block *postInsertBlock = inlineBlock->splitBlock(splitBefore: inlinePoint);
255 if (shouldCloneInlinedRegion)
256 src->cloneInto(dest: insertRegion, destPos: postInsertBlock->getIterator(), mapper);
257 else
258 insertRegion->getBlocks().splice(where: postInsertBlock->getIterator(),
259 L2&: src->getBlocks(), first: src->begin(),
260 last: src->end());
261
262 // Get the range of newly inserted blocks.
263 auto newBlocks = llvm::make_range(x: std::next(x: inlineBlock->getIterator()),
264 y: postInsertBlock->getIterator());
265 Block *firstNewBlock = &*newBlocks.begin();
266
267 // Remap the locations of the inlined operations if a valid source location
268 // was provided.
269 if (inlineLoc && !llvm::isa<UnknownLoc>(Val: *inlineLoc))
270 remapInlinedLocations(inlinedBlocks: newBlocks, callerLoc: *inlineLoc);
271
272 // If the blocks were moved in-place, make sure to remap any necessary
273 // operands.
274 if (!shouldCloneInlinedRegion)
275 remapInlinedOperands(inlinedBlocks: newBlocks, mapper);
276
277 // Process the newly inlined blocks.
278 if (call)
279 interface.processInlinedCallBlocks(call: call, inlinedBlocks: newBlocks);
280 interface.processInlinedBlocks(inlinedBlocks: newBlocks);
281
282 // Handle the case where only a single block was inlined.
283 if (std::next(x: newBlocks.begin()) == newBlocks.end()) {
284 // Run the result attribute handler on the terminator operands.
285 Operation *firstBlockTerminator = firstNewBlock->getTerminator();
286 builder.setInsertionPoint(firstBlockTerminator);
287 if (call && callable)
288 handleResultImpl(interface, builder, call, callable,
289 firstBlockTerminator->getOperands());
290
291 // Have the interface handle the terminator of this block.
292 interface.handleTerminator(op: firstBlockTerminator, valuesToRepl: resultsToReplace);
293 firstBlockTerminator->erase();
294
295 // Merge the post insert block into the cloned entry block.
296 firstNewBlock->getOperations().splice(where: firstNewBlock->end(),
297 L2&: postInsertBlock->getOperations());
298 postInsertBlock->erase();
299 } else {
300 // Otherwise, there were multiple blocks inlined. Add arguments to the post
301 // insertion block to represent the results to replace.
302 for (const auto &resultToRepl : llvm::enumerate(First&: resultsToReplace)) {
303 resultToRepl.value().replaceAllUsesWith(
304 newValue: postInsertBlock->addArgument(type: regionResultTypes[resultToRepl.index()],
305 loc: resultToRepl.value().getLoc()));
306 }
307
308 // Run the result attribute handler on the post insertion block arguments.
309 builder.setInsertionPointToStart(postInsertBlock);
310 if (call && callable)
311 handleResultImpl(interface, builder, call, callable,
312 postInsertBlock->getArguments());
313
314 /// Handle the terminators for each of the new blocks.
315 for (auto &newBlock : newBlocks)
316 interface.handleTerminator(op: newBlock.getTerminator(), newDest: postInsertBlock);
317 }
318
319 // Splice the instructions of the inlined entry block into the insert block.
320 inlineBlock->getOperations().splice(where: inlineBlock->end(),
321 L2&: firstNewBlock->getOperations());
322 firstNewBlock->erase();
323 return success();
324}
325
326static LogicalResult
327inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
328 Block::iterator inlinePoint, ValueRange inlinedOperands,
329 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
330 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
331 // We expect the region to have at least one block.
332 if (src->empty())
333 return failure();
334
335 auto *entryBlock = &src->front();
336 if (inlinedOperands.size() != entryBlock->getNumArguments())
337 return failure();
338
339 // Map the provided call operands to the arguments of the region.
340 IRMapping mapper;
341 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
342 // Verify that the types of the provided values match the function argument
343 // types.
344 BlockArgument regionArg = entryBlock->getArgument(i);
345 if (inlinedOperands[i].getType() != regionArg.getType())
346 return failure();
347 mapper.map(from: regionArg, to: inlinedOperands[i]);
348 }
349
350 // Call into the main region inliner function.
351 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
352 resultsToReplace, resultsToReplace.getTypes(),
353 inlineLoc, shouldCloneInlinedRegion, call);
354}
355
356LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
357 Operation *inlinePoint, IRMapping &mapper,
358 ValueRange resultsToReplace,
359 TypeRange regionResultTypes,
360 std::optional<Location> inlineLoc,
361 bool shouldCloneInlinedRegion) {
362 return inlineRegion(interface, src, inlinePoint->getBlock(),
363 ++inlinePoint->getIterator(), mapper, resultsToReplace,
364 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
365}
366LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
367 Block *inlineBlock,
368 Block::iterator inlinePoint, IRMapping &mapper,
369 ValueRange resultsToReplace,
370 TypeRange regionResultTypes,
371 std::optional<Location> inlineLoc,
372 bool shouldCloneInlinedRegion) {
373 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
374 resultsToReplace, regionResultTypes, inlineLoc,
375 shouldCloneInlinedRegion);
376}
377
378LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
379 Operation *inlinePoint,
380 ValueRange inlinedOperands,
381 ValueRange resultsToReplace,
382 std::optional<Location> inlineLoc,
383 bool shouldCloneInlinedRegion) {
384 return inlineRegion(interface, src, inlinePoint->getBlock(),
385 ++inlinePoint->getIterator(), inlinedOperands,
386 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
387}
388LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
389 Block *inlineBlock,
390 Block::iterator inlinePoint,
391 ValueRange inlinedOperands,
392 ValueRange resultsToReplace,
393 std::optional<Location> inlineLoc,
394 bool shouldCloneInlinedRegion) {
395 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
396 inlinedOperands, resultsToReplace, inlineLoc,
397 shouldCloneInlinedRegion);
398}
399
400/// Utility function used to generate a cast operation from the given interface,
401/// or return nullptr if a cast could not be generated.
402static Value materializeConversion(const DialectInlinerInterface *interface,
403 SmallVectorImpl<Operation *> &castOps,
404 OpBuilder &castBuilder, Value arg, Type type,
405 Location conversionLoc) {
406 if (!interface)
407 return nullptr;
408
409 // Check to see if the interface for the call can materialize a conversion.
410 Operation *castOp = interface->materializeCallConversion(builder&: castBuilder, input: arg,
411 resultType: type, conversionLoc);
412 if (!castOp)
413 return nullptr;
414 castOps.push_back(Elt: castOp);
415
416 // Ensure that the generated cast is correct.
417 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
418 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
419 return castOp->getResult(idx: 0);
420}
421
422/// This function inlines a given region, 'src', of a callable operation,
423/// 'callable', into the location defined by the given call operation. This
424/// function returns failure if inlining is not possible, success otherwise. On
425/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
426/// corresponds to whether the source region should be cloned into the 'call' or
427/// spliced directly.
428LogicalResult mlir::inlineCall(InlinerInterface &interface,
429 CallOpInterface call,
430 CallableOpInterface callable, Region *src,
431 bool shouldCloneInlinedRegion) {
432 // We expect the region to have at least one block.
433 if (src->empty())
434 return failure();
435 auto *entryBlock = &src->front();
436 ArrayRef<Type> callableResultTypes = callable.getResultTypes();
437
438 // Make sure that the number of arguments and results matchup between the call
439 // and the region.
440 SmallVector<Value, 8> callOperands(call.getArgOperands());
441 SmallVector<Value, 8> callResults(call->getResults());
442 if (callOperands.size() != entryBlock->getNumArguments() ||
443 callResults.size() != callableResultTypes.size())
444 return failure();
445
446 // A set of cast operations generated to matchup the signature of the region
447 // with the signature of the call.
448 SmallVector<Operation *, 4> castOps;
449 castOps.reserve(N: callOperands.size() + callResults.size());
450
451 // Functor used to cleanup generated state on failure.
452 auto cleanupState = [&] {
453 for (auto *op : castOps) {
454 op->getResult(idx: 0).replaceAllUsesWith(newValue: op->getOperand(idx: 0));
455 op->erase();
456 }
457 return failure();
458 };
459
460 // Builder used for any conversion operations that need to be materialized.
461 OpBuilder castBuilder(call);
462 Location castLoc = call.getLoc();
463 const auto *callInterface = interface.getInterfaceFor(call->getDialect());
464
465 // Map the provided call operands to the arguments of the region.
466 IRMapping mapper;
467 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
468 BlockArgument regionArg = entryBlock->getArgument(i);
469 Value operand = callOperands[i];
470
471 // If the call operand doesn't match the expected region argument, try to
472 // generate a cast.
473 Type regionArgType = regionArg.getType();
474 if (operand.getType() != regionArgType) {
475 if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
476 operand, regionArgType, castLoc)))
477 return cleanupState();
478 }
479 mapper.map(from: regionArg, to: operand);
480 }
481
482 // Ensure that the resultant values of the call match the callable.
483 castBuilder.setInsertionPointAfter(call);
484 for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
485 Value callResult = callResults[i];
486 if (callResult.getType() == callableResultTypes[i])
487 continue;
488
489 // Generate a conversion that will produce the original type, so that the IR
490 // is still valid after the original call gets replaced.
491 Value castResult =
492 materializeConversion(callInterface, castOps, castBuilder, callResult,
493 callResult.getType(), castLoc);
494 if (!castResult)
495 return cleanupState();
496 callResult.replaceAllUsesWith(newValue: castResult);
497 castResult.getDefiningOp()->replaceUsesOfWith(from: castResult, to: callResult);
498 }
499
500 // Check that it is legal to inline the callable into the call.
501 if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
502 return cleanupState();
503
504 // Attempt to inline the call.
505 if (failed(inlineRegionImpl(interface, src, call->getBlock(),
506 ++call->getIterator(), mapper, callResults,
507 callableResultTypes, call.getLoc(),
508 shouldCloneInlinedRegion, call)))
509 return cleanupState();
510 return success();
511}
512

source code of mlir/lib/Transforms/Utils/InliningUtils.cpp