1 | //===- NormalizeMemRefs.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements an interprocedural pass to normalize memrefs to have |
10 | // identity layout maps. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Affine/Utils.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
19 | #include "llvm/ADT/SmallSet.h" |
20 | #include "llvm/Support/Debug.h" |
21 | |
22 | namespace mlir { |
23 | namespace memref { |
24 | #define GEN_PASS_DEF_NORMALIZEMEMREFSPASS |
25 | #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
26 | } // namespace memref |
27 | } // namespace mlir |
28 | |
29 | #define DEBUG_TYPE "normalize-memrefs" |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::affine; |
33 | using namespace mlir::memref; |
34 | |
35 | namespace { |
36 | |
37 | /// All memrefs passed across functions with non-trivial layout maps are |
38 | /// converted to ones with trivial identity layout ones. |
39 | /// If all the memref types/uses in a function are normalizable, we treat |
40 | /// such functions as normalizable. Also, if a normalizable function is known |
41 | /// to call a non-normalizable function, we treat that function as |
42 | /// non-normalizable as well. We assume external functions to be normalizable. |
43 | struct NormalizeMemRefs |
44 | : public memref::impl::NormalizeMemRefsPassBase<NormalizeMemRefs> { |
45 | void runOnOperation() override; |
46 | void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp); |
47 | bool areMemRefsNormalizable(func::FuncOp funcOp); |
48 | void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp); |
49 | void setCalleesAndCallersNonNormalizable( |
50 | func::FuncOp funcOp, ModuleOp moduleOp, |
51 | DenseSet<func::FuncOp> &normalizableFuncs); |
52 | Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp); |
53 | }; |
54 | |
55 | } // namespace |
56 | |
57 | void NormalizeMemRefs::runOnOperation() { |
58 | LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n" ); |
59 | ModuleOp moduleOp = getOperation(); |
60 | // We maintain all normalizable FuncOps in a DenseSet. It is initialized |
61 | // with all the functions within a module and then functions which are not |
62 | // normalizable are removed from this set. |
63 | // TODO: Change this to work on FuncLikeOp once there is an operation |
64 | // interface for it. |
65 | DenseSet<func::FuncOp> normalizableFuncs; |
66 | // Initialize `normalizableFuncs` with all the functions within a module. |
67 | moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); }); |
68 | |
69 | // Traverse through all the functions applying a filter which determines |
70 | // whether that function is normalizable or not. All callers/callees of |
71 | // a non-normalizable function will also become non-normalizable even if |
72 | // they aren't passing any or specific non-normalizable memrefs. So, |
73 | // functions which calls or get called by a non-normalizable becomes non- |
74 | // normalizable functions themselves. |
75 | moduleOp.walk([&](func::FuncOp funcOp) { |
76 | if (normalizableFuncs.contains(V: funcOp)) { |
77 | if (!areMemRefsNormalizable(funcOp: funcOp)) { |
78 | LLVM_DEBUG(llvm::dbgs() |
79 | << "@" << funcOp.getName() |
80 | << " contains ops that cannot normalize MemRefs\n" ); |
81 | // Since this function is not normalizable, we set all the caller |
82 | // functions and the callees of this function as not normalizable. |
83 | // TODO: Drop this conservative assumption in the future. |
84 | setCalleesAndCallersNonNormalizable(funcOp, moduleOp, |
85 | normalizableFuncs); |
86 | } |
87 | } |
88 | }); |
89 | |
90 | LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size() |
91 | << " functions\n" ); |
92 | // Those functions which can be normalized are subjected to normalization. |
93 | for (func::FuncOp &funcOp : normalizableFuncs) |
94 | normalizeFuncOpMemRefs(funcOp, moduleOp: moduleOp); |
95 | } |
96 | |
97 | /// Check whether all the uses of oldMemRef are either dereferencing uses or the |
98 | /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints |
99 | /// are satisfied will the value become a candidate for replacement. |
100 | /// TODO: Extend this for DimOps. |
101 | static bool isMemRefNormalizable(Value::user_range opUsers) { |
102 | return llvm::all_of(Range&: opUsers, P: [](Operation *op) { |
103 | return op->hasTrait<OpTrait::MemRefsNormalizable>(); |
104 | }); |
105 | } |
106 | |
107 | /// Set all the calling functions and the callees of the function as not |
108 | /// normalizable. |
109 | void NormalizeMemRefs::setCalleesAndCallersNonNormalizable( |
110 | func::FuncOp funcOp, ModuleOp moduleOp, |
111 | DenseSet<func::FuncOp> &normalizableFuncs) { |
112 | if (!normalizableFuncs.contains(V: funcOp)) |
113 | return; |
114 | |
115 | LLVM_DEBUG( |
116 | llvm::dbgs() << "@" << funcOp.getName() |
117 | << " calls or is called by non-normalizable function\n" ); |
118 | normalizableFuncs.erase(funcOp); |
119 | // Caller of the function. |
120 | std::optional<SymbolTable::UseRange> symbolUses = |
121 | funcOp.getSymbolUses(moduleOp); |
122 | for (SymbolTable::SymbolUse symbolUse : *symbolUses) { |
123 | // TODO: Extend this for ops that are FunctionOpInterface. This would |
124 | // require creating an OpInterface for FunctionOpInterface ops. |
125 | func::FuncOp parentFuncOp = |
126 | symbolUse.getUser()->getParentOfType<func::FuncOp>(); |
127 | for (func::FuncOp &funcOp : normalizableFuncs) { |
128 | if (parentFuncOp == funcOp) { |
129 | setCalleesAndCallersNonNormalizable(funcOp, moduleOp, |
130 | normalizableFuncs); |
131 | break; |
132 | } |
133 | } |
134 | } |
135 | |
136 | // Functions called by this function. |
137 | funcOp.walk([&](func::CallOp callOp) { |
138 | StringAttr callee = callOp.getCalleeAttr().getAttr(); |
139 | for (func::FuncOp &funcOp : normalizableFuncs) { |
140 | // We compare func::FuncOp and callee's name. |
141 | if (callee == funcOp.getNameAttr()) { |
142 | setCalleesAndCallersNonNormalizable(funcOp, moduleOp, |
143 | normalizableFuncs); |
144 | break; |
145 | } |
146 | } |
147 | }); |
148 | } |
149 | |
150 | /// Check whether all the uses of AllocOps, AllocaOps, CallOps and function |
151 | /// arguments of a function are either of dereferencing type or are uses in: |
152 | /// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will |
153 | /// the function become a candidate for normalization. When the uses of a memref |
154 | /// are non-normalizable and the memref map layout is trivial (identity), we can |
155 | /// still label the entire function as normalizable. We assume external |
156 | /// functions to be normalizable. |
157 | bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { |
158 | // We assume external functions to be normalizable. |
159 | if (funcOp.isExternal()) |
160 | return true; |
161 | |
162 | if (funcOp |
163 | .walk([&](AllocOp allocOp) -> WalkResult { |
164 | Value oldMemRef = allocOp.getResult(); |
165 | if (!allocOp.getType().getLayout().isIdentity() && |
166 | !isMemRefNormalizable(opUsers: oldMemRef.getUsers())) |
167 | return WalkResult::interrupt(); |
168 | return WalkResult::advance(); |
169 | }) |
170 | .wasInterrupted()) |
171 | return false; |
172 | |
173 | if (funcOp |
174 | .walk([&](AllocaOp allocaOp) -> WalkResult { |
175 | Value oldMemRef = allocaOp.getResult(); |
176 | if (!allocaOp.getType().getLayout().isIdentity() && |
177 | !isMemRefNormalizable(opUsers: oldMemRef.getUsers())) |
178 | return WalkResult::interrupt(); |
179 | return WalkResult::advance(); |
180 | }) |
181 | .wasInterrupted()) |
182 | return false; |
183 | |
184 | if (funcOp |
185 | .walk([&](func::CallOp callOp) -> WalkResult { |
186 | for (unsigned resIndex : |
187 | llvm::seq<unsigned>(0, callOp.getNumResults())) { |
188 | Value oldMemRef = callOp.getResult(resIndex); |
189 | if (auto oldMemRefType = |
190 | dyn_cast<MemRefType>(oldMemRef.getType())) |
191 | if (!oldMemRefType.getLayout().isIdentity() && |
192 | !isMemRefNormalizable(oldMemRef.getUsers())) |
193 | return WalkResult::interrupt(); |
194 | } |
195 | return WalkResult::advance(); |
196 | }) |
197 | .wasInterrupted()) |
198 | return false; |
199 | |
200 | for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) { |
201 | BlockArgument oldMemRef = funcOp.getArgument(argIndex); |
202 | if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType())) |
203 | if (!oldMemRefType.getLayout().isIdentity() && |
204 | !isMemRefNormalizable(oldMemRef.getUsers())) |
205 | return false; |
206 | } |
207 | |
208 | return true; |
209 | } |
210 | |
211 | /// Fetch the updated argument list and result of the function and update the |
212 | /// function signature. This updates the function's return type at the caller |
213 | /// site and in case the return type is a normalized memref then it updates |
214 | /// the calling function's signature. |
215 | /// TODO: An update to the calling function signature is required only if the |
216 | /// returned value is in turn used in ReturnOp of the calling function. |
217 | void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, |
218 | ModuleOp moduleOp) { |
219 | FunctionType functionType = funcOp.getFunctionType(); |
220 | SmallVector<Type, 4> resultTypes; |
221 | FunctionType newFuncType; |
222 | resultTypes = llvm::to_vector<4>(functionType.getResults()); |
223 | |
224 | // External function's signature was already updated in |
225 | // 'normalizeFuncOpMemRefs()'. |
226 | if (!funcOp.isExternal()) { |
227 | SmallVector<Type, 8> argTypes; |
228 | for (const auto &argEn : llvm::enumerate(funcOp.getArguments())) |
229 | argTypes.push_back(argEn.value().getType()); |
230 | |
231 | // Traverse ReturnOps to check if an update to the return type in the |
232 | // function signature is required. |
233 | funcOp.walk([&](func::ReturnOp returnOp) { |
234 | for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) { |
235 | Type opType = operandEn.value().getType(); |
236 | MemRefType memrefType = dyn_cast<MemRefType>(opType); |
237 | // If type is not memref or if the memref type is same as that in |
238 | // function's return signature then no update is required. |
239 | if (!memrefType || memrefType == resultTypes[operandEn.index()]) |
240 | continue; |
241 | // Update function's return type signature. |
242 | // Return type gets normalized either as a result of function argument |
243 | // normalization, AllocOp normalization or an update made at CallOp. |
244 | // There can be many call flows inside a function and an update to a |
245 | // specific ReturnOp has not yet been made. So we check that the result |
246 | // memref type is normalized. |
247 | // TODO: When selective normalization is implemented, handle multiple |
248 | // results case where some are normalized, some aren't. |
249 | if (memrefType.getLayout().isIdentity()) |
250 | resultTypes[operandEn.index()] = memrefType; |
251 | } |
252 | }); |
253 | |
254 | // We create a new function type and modify the function signature with this |
255 | // new type. |
256 | newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes, |
257 | /*results=*/resultTypes); |
258 | } |
259 | |
260 | // Since we update the function signature, it might affect the result types at |
261 | // the caller site. Since this result might even be used by the caller |
262 | // function in ReturnOps, the caller function's signature will also change. |
263 | // Hence we record the caller function in 'funcOpsToUpdate' to update their |
264 | // signature as well. |
265 | llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate; |
266 | // We iterate over all symbolic uses of the function and update the return |
267 | // type at the caller site. |
268 | std::optional<SymbolTable::UseRange> symbolUses = |
269 | funcOp.getSymbolUses(moduleOp); |
270 | for (SymbolTable::SymbolUse symbolUse : *symbolUses) { |
271 | Operation *userOp = symbolUse.getUser(); |
272 | OpBuilder builder(userOp); |
273 | // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes |
274 | // that the non-CallOp has no memrefs to be replaced. |
275 | // TODO: Handle cases where a non-CallOp symbol use of a function deals with |
276 | // memrefs. |
277 | auto callOp = dyn_cast<func::CallOp>(userOp); |
278 | if (!callOp) |
279 | continue; |
280 | Operation *newCallOp = |
281 | builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(), |
282 | resultTypes, userOp->getOperands()); |
283 | bool replacingMemRefUsesFailed = false; |
284 | bool returnTypeChanged = false; |
285 | for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) { |
286 | OpResult oldResult = userOp->getResult(resIndex); |
287 | OpResult newResult = newCallOp->getResult(resIndex); |
288 | // This condition ensures that if the result is not of type memref or if |
289 | // the resulting memref was already having a trivial map layout then we |
290 | // need not perform any use replacement here. |
291 | if (oldResult.getType() == newResult.getType()) |
292 | continue; |
293 | AffineMap layoutMap = |
294 | cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap(); |
295 | if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult, |
296 | /*extraIndices=*/{}, |
297 | /*indexRemap=*/layoutMap, |
298 | /*extraOperands=*/{}, |
299 | /*symbolOperands=*/{}, |
300 | /*domOpFilter=*/nullptr, |
301 | /*postDomOpFilter=*/nullptr, |
302 | /*allowNonDereferencingOps=*/true, |
303 | /*replaceInDeallocOp=*/true))) { |
304 | // If it failed (due to escapes for example), bail out. |
305 | // It should never hit this part of the code because it is called by |
306 | // only those functions which are normalizable. |
307 | newCallOp->erase(); |
308 | replacingMemRefUsesFailed = true; |
309 | break; |
310 | } |
311 | returnTypeChanged = true; |
312 | } |
313 | if (replacingMemRefUsesFailed) |
314 | continue; |
315 | // Replace all uses for other non-memref result types. |
316 | userOp->replaceAllUsesWith(newCallOp); |
317 | userOp->erase(); |
318 | if (returnTypeChanged) { |
319 | // Since the return type changed it might lead to a change in function's |
320 | // signature. |
321 | // TODO: If funcOp doesn't return any memref type then no need to update |
322 | // signature. |
323 | // TODO: Further optimization - Check if the memref is indeed part of |
324 | // ReturnOp at the parentFuncOp and only then updation of signature is |
325 | // required. |
326 | // TODO: Extend this for ops that are FunctionOpInterface. This would |
327 | // require creating an OpInterface for FunctionOpInterface ops. |
328 | func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>(); |
329 | funcOpsToUpdate.insert(parentFuncOp); |
330 | } |
331 | } |
332 | // Because external function's signature is already updated in |
333 | // 'normalizeFuncOpMemRefs()', we don't need to update it here again. |
334 | if (!funcOp.isExternal()) |
335 | funcOp.setType(newFuncType); |
336 | |
337 | // Updating the signature type of those functions which call the current |
338 | // function. Only if the return type of the current function has a normalized |
339 | // memref will the caller function become a candidate for signature update. |
340 | for (func::FuncOp parentFuncOp : funcOpsToUpdate) |
341 | updateFunctionSignature(funcOp: parentFuncOp, moduleOp: moduleOp); |
342 | } |
343 | |
344 | /// Normalizes the memrefs within a function which includes those arising as a |
345 | /// result of AllocOps, AllocaOps, CallOps, ReinterpretCastOps and function's |
346 | /// argument. The ModuleOp argument is used to help update function's signature |
347 | /// after normalization. |
348 | void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, |
349 | ModuleOp moduleOp) { |
350 | // Turn memrefs' non-identity layouts maps into ones with identity. Collect |
351 | // alloc, alloca ops and reinterpret_cast ops first and then process since |
352 | // normalizeMemRef replaces/erases ops during memref rewriting. |
353 | SmallVector<AllocOp, 4> allocOps; |
354 | SmallVector<AllocaOp> allocaOps; |
355 | SmallVector<ReinterpretCastOp> reinterpretCastOps; |
356 | funcOp.walk([&](Operation *op) { |
357 | if (auto allocOp = dyn_cast<AllocOp>(op)) |
358 | allocOps.push_back(allocOp); |
359 | else if (auto allocaOp = dyn_cast<AllocaOp>(op)) |
360 | allocaOps.push_back(allocaOp); |
361 | else if (auto reinterpretCastOp = dyn_cast<ReinterpretCastOp>(op)) |
362 | reinterpretCastOps.push_back(reinterpretCastOp); |
363 | }); |
364 | for (AllocOp allocOp : allocOps) |
365 | (void)normalizeMemRef(allocOp); |
366 | for (AllocaOp allocaOp : allocaOps) |
367 | (void)normalizeMemRef(allocaOp); |
368 | for (ReinterpretCastOp reinterpretCastOp : reinterpretCastOps) |
369 | (void)normalizeMemRef(reinterpretCastOp); |
370 | |
371 | // We use this OpBuilder to create new memref layout later. |
372 | OpBuilder b(funcOp); |
373 | |
374 | FunctionType functionType = funcOp.getFunctionType(); |
375 | SmallVector<Location> functionArgLocs(llvm::map_range( |
376 | funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); })); |
377 | SmallVector<Type, 8> inputTypes; |
378 | // Walk over each argument of a function to perform memref normalization (if |
379 | for (unsigned argIndex : |
380 | llvm::seq<unsigned>(0, functionType.getNumInputs())) { |
381 | Type argType = functionType.getInput(argIndex); |
382 | MemRefType memrefType = dyn_cast<MemRefType>(argType); |
383 | // Check whether argument is of MemRef type. Any other argument type can |
384 | // simply be part of the final function signature. |
385 | if (!memrefType) { |
386 | inputTypes.push_back(argType); |
387 | continue; |
388 | } |
389 | // Fetch a new memref type after normalizing the old memref to have an |
390 | // identity map layout. |
391 | MemRefType newMemRefType = normalizeMemRefType(memrefType); |
392 | if (newMemRefType == memrefType || funcOp.isExternal()) { |
393 | // Either memrefType already had an identity map or the map couldn't be |
394 | // transformed to an identity map. |
395 | inputTypes.push_back(newMemRefType); |
396 | continue; |
397 | } |
398 | |
399 | // Insert a new temporary argument with the new memref type. |
400 | BlockArgument newMemRef = funcOp.front().insertArgument( |
401 | argIndex, newMemRefType, functionArgLocs[argIndex]); |
402 | BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1); |
403 | AffineMap layoutMap = memrefType.getLayout().getAffineMap(); |
404 | // Replace all uses of the old memref. |
405 | if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, |
406 | /*extraIndices=*/{}, |
407 | /*indexRemap=*/layoutMap, |
408 | /*extraOperands=*/{}, |
409 | /*symbolOperands=*/{}, |
410 | /*domOpFilter=*/nullptr, |
411 | /*postDomOpFilter=*/nullptr, |
412 | /*allowNonDereferencingOps=*/true, |
413 | /*replaceInDeallocOp=*/true))) { |
414 | // If it failed (due to escapes for example), bail out. Removing the |
415 | // temporary argument inserted previously. |
416 | funcOp.front().eraseArgument(argIndex); |
417 | continue; |
418 | } |
419 | |
420 | // All uses for the argument with old memref type were replaced |
421 | // successfully. So we remove the old argument now. |
422 | funcOp.front().eraseArgument(argIndex + 1); |
423 | } |
424 | |
425 | // Walk over normalizable operations to normalize memrefs of the operation |
426 | // results. When `op` has memrefs with affine map in the operation results, |
427 | // new operation containin normalized memrefs is created. Then, the memrefs |
428 | // are replaced. `CallOp` is skipped here because it is handled in |
429 | // `updateFunctionSignature()`. |
430 | funcOp.walk([&](Operation *op) { |
431 | if (op->hasTrait<OpTrait::MemRefsNormalizable>() && |
432 | op->getNumResults() > 0 && !isa<func::CallOp>(op) && |
433 | !funcOp.isExternal()) { |
434 | // Create newOp containing normalized memref in the operation result. |
435 | Operation *newOp = createOpResultsNormalized(funcOp: funcOp, oldOp: op); |
436 | // When all of the operation results have no memrefs or memrefs without |
437 | // affine map, `newOp` is the same with `op` and following process is |
438 | // skipped. |
439 | if (op != newOp) { |
440 | bool replacingMemRefUsesFailed = false; |
441 | for (unsigned resIndex : llvm::seq<unsigned>(Begin: 0, End: op->getNumResults())) { |
442 | // Replace all uses of the old memrefs. |
443 | Value oldMemRef = op->getResult(idx: resIndex); |
444 | Value newMemRef = newOp->getResult(idx: resIndex); |
445 | MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()); |
446 | // Check whether the operation result is MemRef type. |
447 | if (!oldMemRefType) |
448 | continue; |
449 | MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType()); |
450 | if (oldMemRefType == newMemRefType) |
451 | continue; |
452 | // TODO: Assume single layout map. Multiple maps not supported. |
453 | AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap(); |
454 | if (failed(Result: replaceAllMemRefUsesWith(oldMemRef, |
455 | /*newMemRef=*/newMemRef, |
456 | /*extraIndices=*/{}, |
457 | /*indexRemap=*/layoutMap, |
458 | /*extraOperands=*/{}, |
459 | /*symbolOperands=*/{}, |
460 | /*domOpFilter=*/nullptr, |
461 | /*postDomOpFilter=*/nullptr, |
462 | /*allowNonDereferencingOps=*/true, |
463 | /*replaceInDeallocOp=*/true))) { |
464 | newOp->erase(); |
465 | replacingMemRefUsesFailed = true; |
466 | continue; |
467 | } |
468 | } |
469 | if (!replacingMemRefUsesFailed) { |
470 | // Replace other ops with new op and delete the old op when the |
471 | // replacement succeeded. |
472 | op->replaceAllUsesWith(values&: newOp); |
473 | op->erase(); |
474 | } |
475 | } |
476 | } |
477 | }); |
478 | |
479 | // In a normal function, memrefs in the return type signature gets normalized |
480 | // as a result of normalization of functions arguments, AllocOps or CallOps' |
481 | // result types. Since an external function doesn't have a body, memrefs in |
482 | // the return type signature can only get normalized by iterating over the |
483 | // individual return types. |
484 | if (funcOp.isExternal()) { |
485 | SmallVector<Type, 4> resultTypes; |
486 | for (unsigned resIndex : |
487 | llvm::seq<unsigned>(0, functionType.getNumResults())) { |
488 | Type resType = functionType.getResult(resIndex); |
489 | MemRefType memrefType = dyn_cast<MemRefType>(resType); |
490 | // Check whether result is of MemRef type. Any other argument type can |
491 | // simply be part of the final function signature. |
492 | if (!memrefType) { |
493 | resultTypes.push_back(resType); |
494 | continue; |
495 | } |
496 | // Computing a new memref type after normalizing the old memref to have an |
497 | // identity map layout. |
498 | MemRefType newMemRefType = normalizeMemRefType(memrefType); |
499 | resultTypes.push_back(newMemRefType); |
500 | } |
501 | |
502 | FunctionType newFuncType = |
503 | FunctionType::get(&getContext(), /*inputs=*/inputTypes, |
504 | /*results=*/resultTypes); |
505 | // Setting the new function signature for this external function. |
506 | funcOp.setType(newFuncType); |
507 | } |
508 | updateFunctionSignature(funcOp: funcOp, moduleOp: moduleOp); |
509 | } |
510 | |
511 | /// Create an operation containing normalized memrefs in the operation results. |
512 | /// When the results of `oldOp` have memrefs with affine map, the memrefs are |
513 | /// normalized, and new operation containing them in the operation results is |
514 | /// returned. If all of the results of `oldOp` have no memrefs or memrefs |
515 | /// without affine map, `oldOp` is returned without modification. |
516 | Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp, |
517 | Operation *oldOp) { |
518 | // Prepare OperationState to create newOp containing normalized memref in |
519 | // the operation results. |
520 | OperationState result(oldOp->getLoc(), oldOp->getName()); |
521 | result.addOperands(newOperands: oldOp->getOperands()); |
522 | result.addAttributes(newAttributes: oldOp->getAttrs()); |
523 | // Add normalized MemRefType to the OperationState. |
524 | SmallVector<Type, 4> resultTypes; |
525 | OpBuilder b(funcOp); |
526 | bool resultTypeNormalized = false; |
527 | for (unsigned resIndex : llvm::seq<unsigned>(Begin: 0, End: oldOp->getNumResults())) { |
528 | auto resultType = oldOp->getResult(idx: resIndex).getType(); |
529 | MemRefType memrefType = dyn_cast<MemRefType>(resultType); |
530 | // Check whether the operation result is MemRef type. |
531 | if (!memrefType) { |
532 | resultTypes.push_back(Elt: resultType); |
533 | continue; |
534 | } |
535 | |
536 | // Fetch a new memref type after normalizing the old memref. |
537 | MemRefType newMemRefType = normalizeMemRefType(memrefType); |
538 | if (newMemRefType == memrefType) { |
539 | // Either memrefType already had an identity map or the map couldn't |
540 | // be transformed to an identity map. |
541 | resultTypes.push_back(Elt: memrefType); |
542 | continue; |
543 | } |
544 | resultTypes.push_back(Elt: newMemRefType); |
545 | resultTypeNormalized = true; |
546 | } |
547 | result.addTypes(newTypes: resultTypes); |
548 | // When all of the results of `oldOp` have no memrefs or memrefs without |
549 | // affine map, `oldOp` is returned without modification. |
550 | if (resultTypeNormalized) { |
551 | OpBuilder bb(oldOp); |
552 | for (auto &oldRegion : oldOp->getRegions()) { |
553 | Region *newRegion = result.addRegion(); |
554 | newRegion->takeBody(other&: oldRegion); |
555 | } |
556 | return bb.create(state: result); |
557 | } |
558 | return oldOp; |
559 | } |
560 | |