//===- NormalizeMemRefs.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an interprocedural pass to normalize memrefs to have
// identity layout maps.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Affine/Utils.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/MemRef/Transforms/Passes.h"
19#include "llvm/ADT/SmallSet.h"
20#include "llvm/Support/Debug.h"
21
22namespace mlir {
23namespace memref {
24#define GEN_PASS_DEF_NORMALIZEMEMREFS
25#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
26} // namespace memref
27} // namespace mlir
28
29#define DEBUG_TYPE "normalize-memrefs"
30
31using namespace mlir;
32using namespace mlir::affine;
33
34namespace {
35
36/// All memrefs passed across functions with non-trivial layout maps are
37/// converted to ones with trivial identity layout ones.
38/// If all the memref types/uses in a function are normalizable, we treat
39/// such functions as normalizable. Also, if a normalizable function is known
40/// to call a non-normalizable function, we treat that function as
41/// non-normalizable as well. We assume external functions to be normalizable.
42struct NormalizeMemRefs
43 : public memref::impl::NormalizeMemRefsBase<NormalizeMemRefs> {
44 void runOnOperation() override;
45 void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
46 bool areMemRefsNormalizable(func::FuncOp funcOp);
47 void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
48 void setCalleesAndCallersNonNormalizable(
49 func::FuncOp funcOp, ModuleOp moduleOp,
50 DenseSet<func::FuncOp> &normalizableFuncs);
51 Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp);
52};
53
54} // namespace
55
56std::unique_ptr<OperationPass<ModuleOp>>
57mlir::memref::createNormalizeMemRefsPass() {
58 return std::make_unique<NormalizeMemRefs>();
59}
60
61void NormalizeMemRefs::runOnOperation() {
62 LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
63 ModuleOp moduleOp = getOperation();
64 // We maintain all normalizable FuncOps in a DenseSet. It is initialized
65 // with all the functions within a module and then functions which are not
66 // normalizable are removed from this set.
67 // TODO: Change this to work on FuncLikeOp once there is an operation
68 // interface for it.
69 DenseSet<func::FuncOp> normalizableFuncs;
70 // Initialize `normalizableFuncs` with all the functions within a module.
71 moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
72
73 // Traverse through all the functions applying a filter which determines
74 // whether that function is normalizable or not. All callers/callees of
75 // a non-normalizable function will also become non-normalizable even if
76 // they aren't passing any or specific non-normalizable memrefs. So,
77 // functions which calls or get called by a non-normalizable becomes non-
78 // normalizable functions themselves.
79 moduleOp.walk([&](func::FuncOp funcOp) {
80 if (normalizableFuncs.contains(V: funcOp)) {
81 if (!areMemRefsNormalizable(funcOp: funcOp)) {
82 LLVM_DEBUG(llvm::dbgs()
83 << "@" << funcOp.getName()
84 << " contains ops that cannot normalize MemRefs\n");
85 // Since this function is not normalizable, we set all the caller
86 // functions and the callees of this function as not normalizable.
87 // TODO: Drop this conservative assumption in the future.
88 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
89 normalizableFuncs);
90 }
91 }
92 });
93
94 LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
95 << " functions\n");
96 // Those functions which can be normalized are subjected to normalization.
97 for (func::FuncOp &funcOp : normalizableFuncs)
98 normalizeFuncOpMemRefs(funcOp, moduleOp: moduleOp);
99}
100
101/// Check whether all the uses of oldMemRef are either dereferencing uses or the
102/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
103/// are satisfied will the value become a candidate for replacement.
104/// TODO: Extend this for DimOps.
105static bool isMemRefNormalizable(Value::user_range opUsers) {
106 return llvm::all_of(Range&: opUsers, P: [](Operation *op) {
107 return op->hasTrait<OpTrait::MemRefsNormalizable>();
108 });
109}
110
111/// Set all the calling functions and the callees of the function as not
112/// normalizable.
113void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
114 func::FuncOp funcOp, ModuleOp moduleOp,
115 DenseSet<func::FuncOp> &normalizableFuncs) {
116 if (!normalizableFuncs.contains(V: funcOp))
117 return;
118
119 LLVM_DEBUG(
120 llvm::dbgs() << "@" << funcOp.getName()
121 << " calls or is called by non-normalizable function\n");
122 normalizableFuncs.erase(funcOp);
123 // Caller of the function.
124 std::optional<SymbolTable::UseRange> symbolUses =
125 funcOp.getSymbolUses(moduleOp);
126 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
127 // TODO: Extend this for ops that are FunctionOpInterface. This would
128 // require creating an OpInterface for FunctionOpInterface ops.
129 func::FuncOp parentFuncOp =
130 symbolUse.getUser()->getParentOfType<func::FuncOp>();
131 for (func::FuncOp &funcOp : normalizableFuncs) {
132 if (parentFuncOp == funcOp) {
133 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
134 normalizableFuncs);
135 break;
136 }
137 }
138 }
139
140 // Functions called by this function.
141 funcOp.walk([&](func::CallOp callOp) {
142 StringAttr callee = callOp.getCalleeAttr().getAttr();
143 for (func::FuncOp &funcOp : normalizableFuncs) {
144 // We compare func::FuncOp and callee's name.
145 if (callee == funcOp.getNameAttr()) {
146 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
147 normalizableFuncs);
148 break;
149 }
150 }
151 });
152}
153
154/// Check whether all the uses of AllocOps, CallOps and function arguments of a
155/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
156/// or ReturnOp. Only if these constraints are satisfied will the function
157/// become a candidate for normalization. When the uses of a memref are
158/// non-normalizable and the memref map layout is trivial (identity), we can
159/// still label the entire function as normalizable. We assume external
160/// functions to be normalizable.
161bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
162 // We assume external functions to be normalizable.
163 if (funcOp.isExternal())
164 return true;
165
166 if (funcOp
167 .walk([&](memref::AllocOp allocOp) -> WalkResult {
168 Value oldMemRef = allocOp.getResult();
169 if (!allocOp.getType().getLayout().isIdentity() &&
170 !isMemRefNormalizable(opUsers: oldMemRef.getUsers()))
171 return WalkResult::interrupt();
172 return WalkResult::advance();
173 })
174 .wasInterrupted())
175 return false;
176
177 if (funcOp
178 .walk([&](func::CallOp callOp) -> WalkResult {
179 for (unsigned resIndex :
180 llvm::seq<unsigned>(0, callOp.getNumResults())) {
181 Value oldMemRef = callOp.getResult(resIndex);
182 if (auto oldMemRefType =
183 dyn_cast<MemRefType>(oldMemRef.getType()))
184 if (!oldMemRefType.getLayout().isIdentity() &&
185 !isMemRefNormalizable(oldMemRef.getUsers()))
186 return WalkResult::interrupt();
187 }
188 return WalkResult::advance();
189 })
190 .wasInterrupted())
191 return false;
192
193 for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
194 BlockArgument oldMemRef = funcOp.getArgument(argIndex);
195 if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()))
196 if (!oldMemRefType.getLayout().isIdentity() &&
197 !isMemRefNormalizable(oldMemRef.getUsers()))
198 return false;
199 }
200
201 return true;
202}
203
204/// Fetch the updated argument list and result of the function and update the
205/// function signature. This updates the function's return type at the caller
206/// site and in case the return type is a normalized memref then it updates
207/// the calling function's signature.
208/// TODO: An update to the calling function signature is required only if the
209/// returned value is in turn used in ReturnOp of the calling function.
210void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
211 ModuleOp moduleOp) {
212 FunctionType functionType = funcOp.getFunctionType();
213 SmallVector<Type, 4> resultTypes;
214 FunctionType newFuncType;
215 resultTypes = llvm::to_vector<4>(functionType.getResults());
216
217 // External function's signature was already updated in
218 // 'normalizeFuncOpMemRefs()'.
219 if (!funcOp.isExternal()) {
220 SmallVector<Type, 8> argTypes;
221 for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
222 argTypes.push_back(argEn.value().getType());
223
224 // Traverse ReturnOps to check if an update to the return type in the
225 // function signature is required.
226 funcOp.walk([&](func::ReturnOp returnOp) {
227 for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
228 Type opType = operandEn.value().getType();
229 MemRefType memrefType = dyn_cast<MemRefType>(opType);
230 // If type is not memref or if the memref type is same as that in
231 // function's return signature then no update is required.
232 if (!memrefType || memrefType == resultTypes[operandEn.index()])
233 continue;
234 // Update function's return type signature.
235 // Return type gets normalized either as a result of function argument
236 // normalization, AllocOp normalization or an update made at CallOp.
237 // There can be many call flows inside a function and an update to a
238 // specific ReturnOp has not yet been made. So we check that the result
239 // memref type is normalized.
240 // TODO: When selective normalization is implemented, handle multiple
241 // results case where some are normalized, some aren't.
242 if (memrefType.getLayout().isIdentity())
243 resultTypes[operandEn.index()] = memrefType;
244 }
245 });
246
247 // We create a new function type and modify the function signature with this
248 // new type.
249 newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
250 /*results=*/resultTypes);
251 }
252
253 // Since we update the function signature, it might affect the result types at
254 // the caller site. Since this result might even be used by the caller
255 // function in ReturnOps, the caller function's signature will also change.
256 // Hence we record the caller function in 'funcOpsToUpdate' to update their
257 // signature as well.
258 llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
259 // We iterate over all symbolic uses of the function and update the return
260 // type at the caller site.
261 std::optional<SymbolTable::UseRange> symbolUses =
262 funcOp.getSymbolUses(moduleOp);
263 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
264 Operation *userOp = symbolUse.getUser();
265 OpBuilder builder(userOp);
266 // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
267 // that the non-CallOp has no memrefs to be replaced.
268 // TODO: Handle cases where a non-CallOp symbol use of a function deals with
269 // memrefs.
270 auto callOp = dyn_cast<func::CallOp>(userOp);
271 if (!callOp)
272 continue;
273 Operation *newCallOp =
274 builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
275 resultTypes, userOp->getOperands());
276 bool replacingMemRefUsesFailed = false;
277 bool returnTypeChanged = false;
278 for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
279 OpResult oldResult = userOp->getResult(resIndex);
280 OpResult newResult = newCallOp->getResult(resIndex);
281 // This condition ensures that if the result is not of type memref or if
282 // the resulting memref was already having a trivial map layout then we
283 // need not perform any use replacement here.
284 if (oldResult.getType() == newResult.getType())
285 continue;
286 AffineMap layoutMap =
287 cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap();
288 if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
289 /*extraIndices=*/{},
290 /*indexRemap=*/layoutMap,
291 /*extraOperands=*/{},
292 /*symbolOperands=*/{},
293 /*domOpFilter=*/nullptr,
294 /*postDomOpFilter=*/nullptr,
295 /*allowNonDereferencingOps=*/true,
296 /*replaceInDeallocOp=*/true))) {
297 // If it failed (due to escapes for example), bail out.
298 // It should never hit this part of the code because it is called by
299 // only those functions which are normalizable.
300 newCallOp->erase();
301 replacingMemRefUsesFailed = true;
302 break;
303 }
304 returnTypeChanged = true;
305 }
306 if (replacingMemRefUsesFailed)
307 continue;
308 // Replace all uses for other non-memref result types.
309 userOp->replaceAllUsesWith(newCallOp);
310 userOp->erase();
311 if (returnTypeChanged) {
312 // Since the return type changed it might lead to a change in function's
313 // signature.
314 // TODO: If funcOp doesn't return any memref type then no need to update
315 // signature.
316 // TODO: Further optimization - Check if the memref is indeed part of
317 // ReturnOp at the parentFuncOp and only then updation of signature is
318 // required.
319 // TODO: Extend this for ops that are FunctionOpInterface. This would
320 // require creating an OpInterface for FunctionOpInterface ops.
321 func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>();
322 funcOpsToUpdate.insert(parentFuncOp);
323 }
324 }
325 // Because external function's signature is already updated in
326 // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
327 if (!funcOp.isExternal())
328 funcOp.setType(newFuncType);
329
330 // Updating the signature type of those functions which call the current
331 // function. Only if the return type of the current function has a normalized
332 // memref will the caller function become a candidate for signature update.
333 for (func::FuncOp parentFuncOp : funcOpsToUpdate)
334 updateFunctionSignature(funcOp: parentFuncOp, moduleOp: moduleOp);
335}
336
337/// Normalizes the memrefs within a function which includes those arising as a
338/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
339/// is used to help update function's signature after normalization.
340void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
341 ModuleOp moduleOp) {
342 // Turn memrefs' non-identity layouts maps into ones with identity. Collect
343 // alloc ops first and then process since normalizeMemRef replaces/erases ops
344 // during memref rewriting.
345 SmallVector<memref::AllocOp, 4> allocOps;
346 funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
347 for (memref::AllocOp allocOp : allocOps)
348 (void)normalizeMemRef(&allocOp);
349
350 // We use this OpBuilder to create new memref layout later.
351 OpBuilder b(funcOp);
352
353 FunctionType functionType = funcOp.getFunctionType();
354 SmallVector<Location> functionArgLocs(llvm::map_range(
355 funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
356 SmallVector<Type, 8> inputTypes;
357 // Walk over each argument of a function to perform memref normalization (if
358 for (unsigned argIndex :
359 llvm::seq<unsigned>(0, functionType.getNumInputs())) {
360 Type argType = functionType.getInput(argIndex);
361 MemRefType memrefType = dyn_cast<MemRefType>(argType);
362 // Check whether argument is of MemRef type. Any other argument type can
363 // simply be part of the final function signature.
364 if (!memrefType) {
365 inputTypes.push_back(argType);
366 continue;
367 }
368 // Fetch a new memref type after normalizing the old memref to have an
369 // identity map layout.
370 MemRefType newMemRefType = normalizeMemRefType(memrefType);
371 if (newMemRefType == memrefType || funcOp.isExternal()) {
372 // Either memrefType already had an identity map or the map couldn't be
373 // transformed to an identity map.
374 inputTypes.push_back(newMemRefType);
375 continue;
376 }
377
378 // Insert a new temporary argument with the new memref type.
379 BlockArgument newMemRef = funcOp.front().insertArgument(
380 argIndex, newMemRefType, functionArgLocs[argIndex]);
381 BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
382 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
383 // Replace all uses of the old memref.
384 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
385 /*extraIndices=*/{},
386 /*indexRemap=*/layoutMap,
387 /*extraOperands=*/{},
388 /*symbolOperands=*/{},
389 /*domOpFilter=*/nullptr,
390 /*postDomOpFilter=*/nullptr,
391 /*allowNonDereferencingOps=*/true,
392 /*replaceInDeallocOp=*/true))) {
393 // If it failed (due to escapes for example), bail out. Removing the
394 // temporary argument inserted previously.
395 funcOp.front().eraseArgument(argIndex);
396 continue;
397 }
398
399 // All uses for the argument with old memref type were replaced
400 // successfully. So we remove the old argument now.
401 funcOp.front().eraseArgument(argIndex + 1);
402 }
403
404 // Walk over normalizable operations to normalize memrefs of the operation
405 // results. When `op` has memrefs with affine map in the operation results,
406 // new operation containin normalized memrefs is created. Then, the memrefs
407 // are replaced. `CallOp` is skipped here because it is handled in
408 // `updateFunctionSignature()`.
409 funcOp.walk([&](Operation *op) {
410 if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
411 op->getNumResults() > 0 && !isa<func::CallOp>(op) &&
412 !funcOp.isExternal()) {
413 // Create newOp containing normalized memref in the operation result.
414 Operation *newOp = createOpResultsNormalized(funcOp: funcOp, oldOp: op);
415 // When all of the operation results have no memrefs or memrefs without
416 // affine map, `newOp` is the same with `op` and following process is
417 // skipped.
418 if (op != newOp) {
419 bool replacingMemRefUsesFailed = false;
420 for (unsigned resIndex : llvm::seq<unsigned>(Begin: 0, End: op->getNumResults())) {
421 // Replace all uses of the old memrefs.
422 Value oldMemRef = op->getResult(idx: resIndex);
423 Value newMemRef = newOp->getResult(idx: resIndex);
424 MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
425 // Check whether the operation result is MemRef type.
426 if (!oldMemRefType)
427 continue;
428 MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
429 if (oldMemRefType == newMemRefType)
430 continue;
431 // TODO: Assume single layout map. Multiple maps not supported.
432 AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
433 if (failed(result: replaceAllMemRefUsesWith(oldMemRef,
434 /*newMemRef=*/newMemRef,
435 /*extraIndices=*/{},
436 /*indexRemap=*/layoutMap,
437 /*extraOperands=*/{},
438 /*symbolOperands=*/{},
439 /*domOpFilter=*/nullptr,
440 /*postDomOpFilter=*/nullptr,
441 /*allowNonDereferencingOps=*/true,
442 /*replaceInDeallocOp=*/true))) {
443 newOp->erase();
444 replacingMemRefUsesFailed = true;
445 continue;
446 }
447 }
448 if (!replacingMemRefUsesFailed) {
449 // Replace other ops with new op and delete the old op when the
450 // replacement succeeded.
451 op->replaceAllUsesWith(values&: newOp);
452 op->erase();
453 }
454 }
455 }
456 });
457
458 // In a normal function, memrefs in the return type signature gets normalized
459 // as a result of normalization of functions arguments, AllocOps or CallOps'
460 // result types. Since an external function doesn't have a body, memrefs in
461 // the return type signature can only get normalized by iterating over the
462 // individual return types.
463 if (funcOp.isExternal()) {
464 SmallVector<Type, 4> resultTypes;
465 for (unsigned resIndex :
466 llvm::seq<unsigned>(0, functionType.getNumResults())) {
467 Type resType = functionType.getResult(resIndex);
468 MemRefType memrefType = dyn_cast<MemRefType>(resType);
469 // Check whether result is of MemRef type. Any other argument type can
470 // simply be part of the final function signature.
471 if (!memrefType) {
472 resultTypes.push_back(resType);
473 continue;
474 }
475 // Computing a new memref type after normalizing the old memref to have an
476 // identity map layout.
477 MemRefType newMemRefType = normalizeMemRefType(memrefType);
478 resultTypes.push_back(newMemRefType);
479 }
480
481 FunctionType newFuncType =
482 FunctionType::get(&getContext(), /*inputs=*/inputTypes,
483 /*results=*/resultTypes);
484 // Setting the new function signature for this external function.
485 funcOp.setType(newFuncType);
486 }
487 updateFunctionSignature(funcOp: funcOp, moduleOp: moduleOp);
488}
489
490/// Create an operation containing normalized memrefs in the operation results.
491/// When the results of `oldOp` have memrefs with affine map, the memrefs are
492/// normalized, and new operation containing them in the operation results is
493/// returned. If all of the results of `oldOp` have no memrefs or memrefs
494/// without affine map, `oldOp` is returned without modification.
495Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
496 Operation *oldOp) {
497 // Prepare OperationState to create newOp containing normalized memref in
498 // the operation results.
499 OperationState result(oldOp->getLoc(), oldOp->getName());
500 result.addOperands(newOperands: oldOp->getOperands());
501 result.addAttributes(newAttributes: oldOp->getAttrs());
502 // Add normalized MemRefType to the OperationState.
503 SmallVector<Type, 4> resultTypes;
504 OpBuilder b(funcOp);
505 bool resultTypeNormalized = false;
506 for (unsigned resIndex : llvm::seq<unsigned>(Begin: 0, End: oldOp->getNumResults())) {
507 auto resultType = oldOp->getResult(idx: resIndex).getType();
508 MemRefType memrefType = dyn_cast<MemRefType>(resultType);
509 // Check whether the operation result is MemRef type.
510 if (!memrefType) {
511 resultTypes.push_back(Elt: resultType);
512 continue;
513 }
514
515 // Fetch a new memref type after normalizing the old memref.
516 MemRefType newMemRefType = normalizeMemRefType(memrefType);
517 if (newMemRefType == memrefType) {
518 // Either memrefType already had an identity map or the map couldn't
519 // be transformed to an identity map.
520 resultTypes.push_back(Elt: memrefType);
521 continue;
522 }
523 resultTypes.push_back(Elt: newMemRefType);
524 resultTypeNormalized = true;
525 }
526 result.addTypes(newTypes: resultTypes);
527 // When all of the results of `oldOp` have no memrefs or memrefs without
528 // affine map, `oldOp` is returned without modification.
529 if (resultTypeNormalized) {
530 OpBuilder bb(oldOp);
531 for (auto &oldRegion : oldOp->getRegions()) {
532 Region *newRegion = result.addRegion();
533 newRegion->takeBody(other&: oldRegion);
534 }
535 return bb.create(state: result);
536 }
537 return oldOp;
538}
539

source code of mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp