1 | //===- NormalizeMemRefs.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements an interprocedural pass to normalize memrefs to have |
10 | // identity layout maps. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Affine/Utils.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
19 | #include "llvm/ADT/SmallSet.h" |
20 | #include "llvm/Support/Debug.h" |
21 | |
22 | namespace mlir { |
23 | namespace memref { |
24 | #define GEN_PASS_DEF_NORMALIZEMEMREFS |
25 | #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
26 | } // namespace memref |
27 | } // namespace mlir |
28 | |
29 | #define DEBUG_TYPE "normalize-memrefs" |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::affine; |
33 | |
34 | namespace { |
35 | |
36 | /// All memrefs passed across functions with non-trivial layout maps are |
37 | /// converted to ones with trivial identity layout ones. |
38 | /// If all the memref types/uses in a function are normalizable, we treat |
39 | /// such functions as normalizable. Also, if a normalizable function is known |
40 | /// to call a non-normalizable function, we treat that function as |
41 | /// non-normalizable as well. We assume external functions to be normalizable. |
42 | struct NormalizeMemRefs |
43 | : public memref::impl::NormalizeMemRefsBase<NormalizeMemRefs> { |
44 | void runOnOperation() override; |
45 | void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp); |
46 | bool areMemRefsNormalizable(func::FuncOp funcOp); |
47 | void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp); |
48 | void setCalleesAndCallersNonNormalizable( |
49 | func::FuncOp funcOp, ModuleOp moduleOp, |
50 | DenseSet<func::FuncOp> &normalizableFuncs); |
51 | Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp); |
52 | }; |
53 | |
54 | } // namespace |
55 | |
56 | std::unique_ptr<OperationPass<ModuleOp>> |
57 | mlir::memref::createNormalizeMemRefsPass() { |
58 | return std::make_unique<NormalizeMemRefs>(); |
59 | } |
60 | |
61 | void NormalizeMemRefs::runOnOperation() { |
62 | LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n" ); |
63 | ModuleOp moduleOp = getOperation(); |
64 | // We maintain all normalizable FuncOps in a DenseSet. It is initialized |
65 | // with all the functions within a module and then functions which are not |
66 | // normalizable are removed from this set. |
67 | // TODO: Change this to work on FuncLikeOp once there is an operation |
68 | // interface for it. |
69 | DenseSet<func::FuncOp> normalizableFuncs; |
70 | // Initialize `normalizableFuncs` with all the functions within a module. |
71 | moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); }); |
72 | |
73 | // Traverse through all the functions applying a filter which determines |
74 | // whether that function is normalizable or not. All callers/callees of |
75 | // a non-normalizable function will also become non-normalizable even if |
76 | // they aren't passing any or specific non-normalizable memrefs. So, |
77 | // functions which calls or get called by a non-normalizable becomes non- |
78 | // normalizable functions themselves. |
79 | moduleOp.walk([&](func::FuncOp funcOp) { |
80 | if (normalizableFuncs.contains(V: funcOp)) { |
81 | if (!areMemRefsNormalizable(funcOp: funcOp)) { |
82 | LLVM_DEBUG(llvm::dbgs() |
83 | << "@" << funcOp.getName() |
84 | << " contains ops that cannot normalize MemRefs\n" ); |
85 | // Since this function is not normalizable, we set all the caller |
86 | // functions and the callees of this function as not normalizable. |
87 | // TODO: Drop this conservative assumption in the future. |
88 | setCalleesAndCallersNonNormalizable(funcOp, moduleOp, |
89 | normalizableFuncs); |
90 | } |
91 | } |
92 | }); |
93 | |
94 | LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size() |
95 | << " functions\n" ); |
96 | // Those functions which can be normalized are subjected to normalization. |
97 | for (func::FuncOp &funcOp : normalizableFuncs) |
98 | normalizeFuncOpMemRefs(funcOp, moduleOp: moduleOp); |
99 | } |
100 | |
101 | /// Check whether all the uses of oldMemRef are either dereferencing uses or the |
102 | /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints |
103 | /// are satisfied will the value become a candidate for replacement. |
104 | /// TODO: Extend this for DimOps. |
105 | static bool isMemRefNormalizable(Value::user_range opUsers) { |
106 | return llvm::all_of(Range&: opUsers, P: [](Operation *op) { |
107 | return op->hasTrait<OpTrait::MemRefsNormalizable>(); |
108 | }); |
109 | } |
110 | |
111 | /// Set all the calling functions and the callees of the function as not |
112 | /// normalizable. |
113 | void NormalizeMemRefs::setCalleesAndCallersNonNormalizable( |
114 | func::FuncOp funcOp, ModuleOp moduleOp, |
115 | DenseSet<func::FuncOp> &normalizableFuncs) { |
116 | if (!normalizableFuncs.contains(V: funcOp)) |
117 | return; |
118 | |
119 | LLVM_DEBUG( |
120 | llvm::dbgs() << "@" << funcOp.getName() |
121 | << " calls or is called by non-normalizable function\n" ); |
122 | normalizableFuncs.erase(funcOp); |
123 | // Caller of the function. |
124 | std::optional<SymbolTable::UseRange> symbolUses = |
125 | funcOp.getSymbolUses(moduleOp); |
126 | for (SymbolTable::SymbolUse symbolUse : *symbolUses) { |
127 | // TODO: Extend this for ops that are FunctionOpInterface. This would |
128 | // require creating an OpInterface for FunctionOpInterface ops. |
129 | func::FuncOp parentFuncOp = |
130 | symbolUse.getUser()->getParentOfType<func::FuncOp>(); |
131 | for (func::FuncOp &funcOp : normalizableFuncs) { |
132 | if (parentFuncOp == funcOp) { |
133 | setCalleesAndCallersNonNormalizable(funcOp, moduleOp, |
134 | normalizableFuncs); |
135 | break; |
136 | } |
137 | } |
138 | } |
139 | |
140 | // Functions called by this function. |
141 | funcOp.walk([&](func::CallOp callOp) { |
142 | StringAttr callee = callOp.getCalleeAttr().getAttr(); |
143 | for (func::FuncOp &funcOp : normalizableFuncs) { |
144 | // We compare func::FuncOp and callee's name. |
145 | if (callee == funcOp.getNameAttr()) { |
146 | setCalleesAndCallersNonNormalizable(funcOp, moduleOp, |
147 | normalizableFuncs); |
148 | break; |
149 | } |
150 | } |
151 | }); |
152 | } |
153 | |
154 | /// Check whether all the uses of AllocOps, CallOps and function arguments of a |
155 | /// function are either of dereferencing type or are uses in: DeallocOp, CallOp |
156 | /// or ReturnOp. Only if these constraints are satisfied will the function |
157 | /// become a candidate for normalization. When the uses of a memref are |
158 | /// non-normalizable and the memref map layout is trivial (identity), we can |
159 | /// still label the entire function as normalizable. We assume external |
160 | /// functions to be normalizable. |
161 | bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { |
162 | // We assume external functions to be normalizable. |
163 | if (funcOp.isExternal()) |
164 | return true; |
165 | |
166 | if (funcOp |
167 | .walk([&](memref::AllocOp allocOp) -> WalkResult { |
168 | Value oldMemRef = allocOp.getResult(); |
169 | if (!allocOp.getType().getLayout().isIdentity() && |
170 | !isMemRefNormalizable(opUsers: oldMemRef.getUsers())) |
171 | return WalkResult::interrupt(); |
172 | return WalkResult::advance(); |
173 | }) |
174 | .wasInterrupted()) |
175 | return false; |
176 | |
177 | if (funcOp |
178 | .walk([&](func::CallOp callOp) -> WalkResult { |
179 | for (unsigned resIndex : |
180 | llvm::seq<unsigned>(0, callOp.getNumResults())) { |
181 | Value oldMemRef = callOp.getResult(resIndex); |
182 | if (auto oldMemRefType = |
183 | dyn_cast<MemRefType>(oldMemRef.getType())) |
184 | if (!oldMemRefType.getLayout().isIdentity() && |
185 | !isMemRefNormalizable(oldMemRef.getUsers())) |
186 | return WalkResult::interrupt(); |
187 | } |
188 | return WalkResult::advance(); |
189 | }) |
190 | .wasInterrupted()) |
191 | return false; |
192 | |
193 | for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) { |
194 | BlockArgument oldMemRef = funcOp.getArgument(argIndex); |
195 | if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType())) |
196 | if (!oldMemRefType.getLayout().isIdentity() && |
197 | !isMemRefNormalizable(oldMemRef.getUsers())) |
198 | return false; |
199 | } |
200 | |
201 | return true; |
202 | } |
203 | |
204 | /// Fetch the updated argument list and result of the function and update the |
205 | /// function signature. This updates the function's return type at the caller |
206 | /// site and in case the return type is a normalized memref then it updates |
207 | /// the calling function's signature. |
208 | /// TODO: An update to the calling function signature is required only if the |
209 | /// returned value is in turn used in ReturnOp of the calling function. |
210 | void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, |
211 | ModuleOp moduleOp) { |
212 | FunctionType functionType = funcOp.getFunctionType(); |
213 | SmallVector<Type, 4> resultTypes; |
214 | FunctionType newFuncType; |
215 | resultTypes = llvm::to_vector<4>(functionType.getResults()); |
216 | |
217 | // External function's signature was already updated in |
218 | // 'normalizeFuncOpMemRefs()'. |
219 | if (!funcOp.isExternal()) { |
220 | SmallVector<Type, 8> argTypes; |
221 | for (const auto &argEn : llvm::enumerate(funcOp.getArguments())) |
222 | argTypes.push_back(argEn.value().getType()); |
223 | |
224 | // Traverse ReturnOps to check if an update to the return type in the |
225 | // function signature is required. |
226 | funcOp.walk([&](func::ReturnOp returnOp) { |
227 | for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) { |
228 | Type opType = operandEn.value().getType(); |
229 | MemRefType memrefType = dyn_cast<MemRefType>(opType); |
230 | // If type is not memref or if the memref type is same as that in |
231 | // function's return signature then no update is required. |
232 | if (!memrefType || memrefType == resultTypes[operandEn.index()]) |
233 | continue; |
234 | // Update function's return type signature. |
235 | // Return type gets normalized either as a result of function argument |
236 | // normalization, AllocOp normalization or an update made at CallOp. |
237 | // There can be many call flows inside a function and an update to a |
238 | // specific ReturnOp has not yet been made. So we check that the result |
239 | // memref type is normalized. |
240 | // TODO: When selective normalization is implemented, handle multiple |
241 | // results case where some are normalized, some aren't. |
242 | if (memrefType.getLayout().isIdentity()) |
243 | resultTypes[operandEn.index()] = memrefType; |
244 | } |
245 | }); |
246 | |
247 | // We create a new function type and modify the function signature with this |
248 | // new type. |
249 | newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes, |
250 | /*results=*/resultTypes); |
251 | } |
252 | |
253 | // Since we update the function signature, it might affect the result types at |
254 | // the caller site. Since this result might even be used by the caller |
255 | // function in ReturnOps, the caller function's signature will also change. |
256 | // Hence we record the caller function in 'funcOpsToUpdate' to update their |
257 | // signature as well. |
258 | llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate; |
259 | // We iterate over all symbolic uses of the function and update the return |
260 | // type at the caller site. |
261 | std::optional<SymbolTable::UseRange> symbolUses = |
262 | funcOp.getSymbolUses(moduleOp); |
263 | for (SymbolTable::SymbolUse symbolUse : *symbolUses) { |
264 | Operation *userOp = symbolUse.getUser(); |
265 | OpBuilder builder(userOp); |
266 | // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes |
267 | // that the non-CallOp has no memrefs to be replaced. |
268 | // TODO: Handle cases where a non-CallOp symbol use of a function deals with |
269 | // memrefs. |
270 | auto callOp = dyn_cast<func::CallOp>(userOp); |
271 | if (!callOp) |
272 | continue; |
273 | Operation *newCallOp = |
274 | builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(), |
275 | resultTypes, userOp->getOperands()); |
276 | bool replacingMemRefUsesFailed = false; |
277 | bool returnTypeChanged = false; |
278 | for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) { |
279 | OpResult oldResult = userOp->getResult(resIndex); |
280 | OpResult newResult = newCallOp->getResult(resIndex); |
281 | // This condition ensures that if the result is not of type memref or if |
282 | // the resulting memref was already having a trivial map layout then we |
283 | // need not perform any use replacement here. |
284 | if (oldResult.getType() == newResult.getType()) |
285 | continue; |
286 | AffineMap layoutMap = |
287 | cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap(); |
288 | if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult, |
289 | /*extraIndices=*/{}, |
290 | /*indexRemap=*/layoutMap, |
291 | /*extraOperands=*/{}, |
292 | /*symbolOperands=*/{}, |
293 | /*domOpFilter=*/nullptr, |
294 | /*postDomOpFilter=*/nullptr, |
295 | /*allowNonDereferencingOps=*/true, |
296 | /*replaceInDeallocOp=*/true))) { |
297 | // If it failed (due to escapes for example), bail out. |
298 | // It should never hit this part of the code because it is called by |
299 | // only those functions which are normalizable. |
300 | newCallOp->erase(); |
301 | replacingMemRefUsesFailed = true; |
302 | break; |
303 | } |
304 | returnTypeChanged = true; |
305 | } |
306 | if (replacingMemRefUsesFailed) |
307 | continue; |
308 | // Replace all uses for other non-memref result types. |
309 | userOp->replaceAllUsesWith(newCallOp); |
310 | userOp->erase(); |
311 | if (returnTypeChanged) { |
312 | // Since the return type changed it might lead to a change in function's |
313 | // signature. |
314 | // TODO: If funcOp doesn't return any memref type then no need to update |
315 | // signature. |
316 | // TODO: Further optimization - Check if the memref is indeed part of |
317 | // ReturnOp at the parentFuncOp and only then updation of signature is |
318 | // required. |
319 | // TODO: Extend this for ops that are FunctionOpInterface. This would |
320 | // require creating an OpInterface for FunctionOpInterface ops. |
321 | func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>(); |
322 | funcOpsToUpdate.insert(parentFuncOp); |
323 | } |
324 | } |
325 | // Because external function's signature is already updated in |
326 | // 'normalizeFuncOpMemRefs()', we don't need to update it here again. |
327 | if (!funcOp.isExternal()) |
328 | funcOp.setType(newFuncType); |
329 | |
330 | // Updating the signature type of those functions which call the current |
331 | // function. Only if the return type of the current function has a normalized |
332 | // memref will the caller function become a candidate for signature update. |
333 | for (func::FuncOp parentFuncOp : funcOpsToUpdate) |
334 | updateFunctionSignature(funcOp: parentFuncOp, moduleOp: moduleOp); |
335 | } |
336 | |
337 | /// Normalizes the memrefs within a function which includes those arising as a |
338 | /// result of AllocOps, CallOps and function's argument. The ModuleOp argument |
339 | /// is used to help update function's signature after normalization. |
340 | void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, |
341 | ModuleOp moduleOp) { |
342 | // Turn memrefs' non-identity layouts maps into ones with identity. Collect |
343 | // alloc ops first and then process since normalizeMemRef replaces/erases ops |
344 | // during memref rewriting. |
345 | SmallVector<memref::AllocOp, 4> allocOps; |
346 | funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); }); |
347 | for (memref::AllocOp allocOp : allocOps) |
348 | (void)normalizeMemRef(&allocOp); |
349 | |
350 | // We use this OpBuilder to create new memref layout later. |
351 | OpBuilder b(funcOp); |
352 | |
353 | FunctionType functionType = funcOp.getFunctionType(); |
354 | SmallVector<Location> functionArgLocs(llvm::map_range( |
355 | funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); })); |
356 | SmallVector<Type, 8> inputTypes; |
357 | // Walk over each argument of a function to perform memref normalization (if |
358 | for (unsigned argIndex : |
359 | llvm::seq<unsigned>(0, functionType.getNumInputs())) { |
360 | Type argType = functionType.getInput(argIndex); |
361 | MemRefType memrefType = dyn_cast<MemRefType>(argType); |
362 | // Check whether argument is of MemRef type. Any other argument type can |
363 | // simply be part of the final function signature. |
364 | if (!memrefType) { |
365 | inputTypes.push_back(argType); |
366 | continue; |
367 | } |
368 | // Fetch a new memref type after normalizing the old memref to have an |
369 | // identity map layout. |
370 | MemRefType newMemRefType = normalizeMemRefType(memrefType); |
371 | if (newMemRefType == memrefType || funcOp.isExternal()) { |
372 | // Either memrefType already had an identity map or the map couldn't be |
373 | // transformed to an identity map. |
374 | inputTypes.push_back(newMemRefType); |
375 | continue; |
376 | } |
377 | |
378 | // Insert a new temporary argument with the new memref type. |
379 | BlockArgument newMemRef = funcOp.front().insertArgument( |
380 | argIndex, newMemRefType, functionArgLocs[argIndex]); |
381 | BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1); |
382 | AffineMap layoutMap = memrefType.getLayout().getAffineMap(); |
383 | // Replace all uses of the old memref. |
384 | if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, |
385 | /*extraIndices=*/{}, |
386 | /*indexRemap=*/layoutMap, |
387 | /*extraOperands=*/{}, |
388 | /*symbolOperands=*/{}, |
389 | /*domOpFilter=*/nullptr, |
390 | /*postDomOpFilter=*/nullptr, |
391 | /*allowNonDereferencingOps=*/true, |
392 | /*replaceInDeallocOp=*/true))) { |
393 | // If it failed (due to escapes for example), bail out. Removing the |
394 | // temporary argument inserted previously. |
395 | funcOp.front().eraseArgument(argIndex); |
396 | continue; |
397 | } |
398 | |
399 | // All uses for the argument with old memref type were replaced |
400 | // successfully. So we remove the old argument now. |
401 | funcOp.front().eraseArgument(argIndex + 1); |
402 | } |
403 | |
404 | // Walk over normalizable operations to normalize memrefs of the operation |
405 | // results. When `op` has memrefs with affine map in the operation results, |
406 | // new operation containin normalized memrefs is created. Then, the memrefs |
407 | // are replaced. `CallOp` is skipped here because it is handled in |
408 | // `updateFunctionSignature()`. |
409 | funcOp.walk([&](Operation *op) { |
410 | if (op->hasTrait<OpTrait::MemRefsNormalizable>() && |
411 | op->getNumResults() > 0 && !isa<func::CallOp>(op) && |
412 | !funcOp.isExternal()) { |
413 | // Create newOp containing normalized memref in the operation result. |
414 | Operation *newOp = createOpResultsNormalized(funcOp: funcOp, oldOp: op); |
415 | // When all of the operation results have no memrefs or memrefs without |
416 | // affine map, `newOp` is the same with `op` and following process is |
417 | // skipped. |
418 | if (op != newOp) { |
419 | bool replacingMemRefUsesFailed = false; |
420 | for (unsigned resIndex : llvm::seq<unsigned>(Begin: 0, End: op->getNumResults())) { |
421 | // Replace all uses of the old memrefs. |
422 | Value oldMemRef = op->getResult(idx: resIndex); |
423 | Value newMemRef = newOp->getResult(idx: resIndex); |
424 | MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()); |
425 | // Check whether the operation result is MemRef type. |
426 | if (!oldMemRefType) |
427 | continue; |
428 | MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType()); |
429 | if (oldMemRefType == newMemRefType) |
430 | continue; |
431 | // TODO: Assume single layout map. Multiple maps not supported. |
432 | AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap(); |
433 | if (failed(result: replaceAllMemRefUsesWith(oldMemRef, |
434 | /*newMemRef=*/newMemRef, |
435 | /*extraIndices=*/{}, |
436 | /*indexRemap=*/layoutMap, |
437 | /*extraOperands=*/{}, |
438 | /*symbolOperands=*/{}, |
439 | /*domOpFilter=*/nullptr, |
440 | /*postDomOpFilter=*/nullptr, |
441 | /*allowNonDereferencingOps=*/true, |
442 | /*replaceInDeallocOp=*/true))) { |
443 | newOp->erase(); |
444 | replacingMemRefUsesFailed = true; |
445 | continue; |
446 | } |
447 | } |
448 | if (!replacingMemRefUsesFailed) { |
449 | // Replace other ops with new op and delete the old op when the |
450 | // replacement succeeded. |
451 | op->replaceAllUsesWith(values&: newOp); |
452 | op->erase(); |
453 | } |
454 | } |
455 | } |
456 | }); |
457 | |
458 | // In a normal function, memrefs in the return type signature gets normalized |
459 | // as a result of normalization of functions arguments, AllocOps or CallOps' |
460 | // result types. Since an external function doesn't have a body, memrefs in |
461 | // the return type signature can only get normalized by iterating over the |
462 | // individual return types. |
463 | if (funcOp.isExternal()) { |
464 | SmallVector<Type, 4> resultTypes; |
465 | for (unsigned resIndex : |
466 | llvm::seq<unsigned>(0, functionType.getNumResults())) { |
467 | Type resType = functionType.getResult(resIndex); |
468 | MemRefType memrefType = dyn_cast<MemRefType>(resType); |
469 | // Check whether result is of MemRef type. Any other argument type can |
470 | // simply be part of the final function signature. |
471 | if (!memrefType) { |
472 | resultTypes.push_back(resType); |
473 | continue; |
474 | } |
475 | // Computing a new memref type after normalizing the old memref to have an |
476 | // identity map layout. |
477 | MemRefType newMemRefType = normalizeMemRefType(memrefType); |
478 | resultTypes.push_back(newMemRefType); |
479 | } |
480 | |
481 | FunctionType newFuncType = |
482 | FunctionType::get(&getContext(), /*inputs=*/inputTypes, |
483 | /*results=*/resultTypes); |
484 | // Setting the new function signature for this external function. |
485 | funcOp.setType(newFuncType); |
486 | } |
487 | updateFunctionSignature(funcOp: funcOp, moduleOp: moduleOp); |
488 | } |
489 | |
490 | /// Create an operation containing normalized memrefs in the operation results. |
491 | /// When the results of `oldOp` have memrefs with affine map, the memrefs are |
492 | /// normalized, and new operation containing them in the operation results is |
493 | /// returned. If all of the results of `oldOp` have no memrefs or memrefs |
494 | /// without affine map, `oldOp` is returned without modification. |
495 | Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp, |
496 | Operation *oldOp) { |
497 | // Prepare OperationState to create newOp containing normalized memref in |
498 | // the operation results. |
499 | OperationState result(oldOp->getLoc(), oldOp->getName()); |
500 | result.addOperands(newOperands: oldOp->getOperands()); |
501 | result.addAttributes(newAttributes: oldOp->getAttrs()); |
502 | // Add normalized MemRefType to the OperationState. |
503 | SmallVector<Type, 4> resultTypes; |
504 | OpBuilder b(funcOp); |
505 | bool resultTypeNormalized = false; |
506 | for (unsigned resIndex : llvm::seq<unsigned>(Begin: 0, End: oldOp->getNumResults())) { |
507 | auto resultType = oldOp->getResult(idx: resIndex).getType(); |
508 | MemRefType memrefType = dyn_cast<MemRefType>(resultType); |
509 | // Check whether the operation result is MemRef type. |
510 | if (!memrefType) { |
511 | resultTypes.push_back(Elt: resultType); |
512 | continue; |
513 | } |
514 | |
515 | // Fetch a new memref type after normalizing the old memref. |
516 | MemRefType newMemRefType = normalizeMemRefType(memrefType); |
517 | if (newMemRefType == memrefType) { |
518 | // Either memrefType already had an identity map or the map couldn't |
519 | // be transformed to an identity map. |
520 | resultTypes.push_back(Elt: memrefType); |
521 | continue; |
522 | } |
523 | resultTypes.push_back(Elt: newMemRefType); |
524 | resultTypeNormalized = true; |
525 | } |
526 | result.addTypes(newTypes: resultTypes); |
527 | // When all of the results of `oldOp` have no memrefs or memrefs without |
528 | // affine map, `oldOp` is returned without modification. |
529 | if (resultTypeNormalized) { |
530 | OpBuilder bb(oldOp); |
531 | for (auto &oldRegion : oldOp->getRegions()) { |
532 | Region *newRegion = result.addRegion(); |
533 | newRegion->takeBody(other&: oldRegion); |
534 | } |
535 | return bb.create(state: result); |
536 | } |
537 | return oldOp; |
538 | } |
539 | |