1//===- IndependenceTransforms.cpp - Make ops independent of values --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
10
11#include "mlir/Dialect/Affine/Transforms/Transforms.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/Interfaces/ValueBoundsOpInterface.h"
14
15using namespace mlir;
16using namespace mlir::memref;
17
18/// Make the given OpFoldResult independent of all independencies.
19static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
20 OpFoldResult ofr,
21 ValueRange independencies) {
22 if (isa<Attribute>(Val: ofr))
23 return ofr;
24 AffineMap boundMap;
25 ValueDimList mapOperands;
26 if (failed(Result: ValueBoundsConstraintSet::computeIndependentBound(
27 resultMap&: boundMap, mapOperands, type: presburger::BoundType::UB, var: ofr, independencies,
28 /*closedUB=*/true)))
29 return failure();
30 return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
31}
32
33FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
34 memref::AllocaOp allocaOp,
35 ValueRange independencies) {
36 OpBuilder::InsertionGuard g(b);
37 b.setInsertionPoint(allocaOp);
38 Location loc = allocaOp.getLoc();
39
40 SmallVector<OpFoldResult> newSizes;
41 for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
42 auto ub = makeIndependent(b, loc, ofr, independencies);
43 if (failed(Result: ub))
44 return failure();
45 newSizes.push_back(Elt: *ub);
46 }
47
48 // Return existing memref::AllocaOp if nothing has changed.
49 if (llvm::equal(LRange: allocaOp.getMixedSizes(), RRange&: newSizes))
50 return allocaOp.getResult();
51
52 // Create a new memref::AllocaOp.
53 Value newAllocaOp =
54 b.create<AllocaOp>(location: loc, args&: newSizes, args: allocaOp.getType().getElementType());
55
56 // Create a memref::SubViewOp.
57 SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(value: 0));
58 SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(value: 1));
59 return b
60 .create<SubViewOp>(location: loc, args&: newAllocaOp, args&: offsets, args: allocaOp.getMixedSizes(),
61 args&: strides)
62 .getResult();
63}
64
65/// Push down an UnrealizedConversionCastOp past a SubViewOp.
66static UnrealizedConversionCastOp
67propagateSubViewOp(RewriterBase &rewriter,
68 UnrealizedConversionCastOp conversionOp, SubViewOp op) {
69 OpBuilder::InsertionGuard g(rewriter);
70 rewriter.setInsertionPoint(op);
71 MemRefType newResultType = SubViewOp::inferRankReducedResultType(
72 resultShape: op.getType().getShape(), sourceMemRefType: op.getSourceType(), staticOffsets: op.getMixedOffsets(),
73 staticSizes: op.getMixedSizes(), staticStrides: op.getMixedStrides());
74 Value newSubview = rewriter.create<SubViewOp>(
75 location: op.getLoc(), args&: newResultType, args: conversionOp.getOperand(i: 0),
76 args: op.getMixedOffsets(), args: op.getMixedSizes(), args: op.getMixedStrides());
77 auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
78 location: op.getLoc(), args: op.getType(), args&: newSubview);
79 rewriter.replaceAllUsesWith(from: op.getResult(), to: newConversionOp->getResult(idx: 0));
80 return newConversionOp;
81}
82
83/// Given an original op and a new, modified op with the same number of results,
84/// whose memref return types may differ, replace all uses of the original op
85/// with the new op and propagate the new memref types through the IR.
86///
87/// Example:
88/// %from = memref.alloca(%sz) : memref<?xf32>
89/// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>>
90/// memref.store %cst, %from[%c0] : memref<?xf32>
91///
92/// In the above example, all uses of %from are replaced with %to. This can be
93/// done directly for ops such as memref.store. For ops that have memref results
94/// (e.g., memref.subview), the result type may depend on the operand type, so
95/// we cannot just replace all uses. There is special handling for common memref
96/// ops. For all other ops, unrealized_conversion_cast is inserted.
97static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
98 Operation *from, Operation *to) {
99 assert(from->getNumResults() == to->getNumResults() &&
100 "expected same number of results");
101 OpBuilder::InsertionGuard g(rewriter);
102 rewriter.setInsertionPointAfter(to);
103
104 // Wrap new results in unrealized_conversion_cast and replace all uses of the
105 // original op.
106 SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
107 for (const auto &it :
108 llvm::enumerate(First: llvm::zip(t: from->getResults(), u: to->getResults()))) {
109 unrealizedConversions.push_back(Elt: rewriter.create<UnrealizedConversionCastOp>(
110 location: to->getLoc(), args: std::get<0>(t&: it.value()).getType(),
111 args&: std::get<1>(t&: it.value())));
112 rewriter.replaceAllUsesWith(from: from->getResult(idx: it.index()),
113 to: unrealizedConversions.back()->getResult(idx: 0));
114 }
115
116 // Push unrealized_conversion_cast ops further down in the IR. I.e., try to
117 // wrap results instead of operands in a cast.
118 for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) {
119 UnrealizedConversionCastOp conversion = unrealizedConversions[i];
120 assert(conversion->getNumOperands() == 1 &&
121 conversion->getNumResults() == 1 &&
122 "expected single operand and single result");
123 SmallVector<Operation *> users = llvm::to_vector(Range: conversion->getUsers());
124 for (Operation *user : users) {
125 // Handle common memref dialect ops that produce new memrefs and must
126 // be recreated with the new result type.
127 if (auto subviewOp = dyn_cast<SubViewOp>(Val: user)) {
128 unrealizedConversions.push_back(
129 Elt: propagateSubViewOp(rewriter, conversionOp: conversion, op: subviewOp));
130 continue;
131 }
132
133 // TODO: Other memref ops such as memref.collapse_shape/expand_shape
134 // should also be handled here.
135
136 // Skip any ops that produce MemRef result or have MemRef region block
137 // arguments. These may need special handling (e.g., scf.for).
138 if (llvm::any_of(Range: user->getResultTypes(),
139 P: [](Type t) { return isa<MemRefType>(Val: t); }))
140 continue;
141 if (llvm::any_of(Range: user->getRegions(), P: [](Region &r) {
142 return llvm::any_of(Range: r.getArguments(), P: [](BlockArgument bbArg) {
143 return isa<MemRefType>(Val: bbArg.getType());
144 });
145 }))
146 continue;
147
148 // For all other ops, we assume that we can directly replace the operand.
149 // This may have to be revised in the future; e.g., there may be ops that
150 // do not support non-identity layout maps.
151 for (OpOperand &operand : user->getOpOperands()) {
152 if ([[maybe_unused]] auto castOp =
153 operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
154 rewriter.modifyOpInPlace(
155 root: user, callable: [&]() { operand.set(conversion->getOperand(idx: 0)); });
156 }
157 }
158 }
159 }
160
161 // Erase all unrealized_conversion_cast ops without uses.
162 for (auto op : unrealizedConversions)
163 if (op->getUses().empty())
164 rewriter.eraseOp(op);
165}
166
167FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
168 memref::AllocaOp allocaOp,
169 ValueRange independencies) {
170 auto replacement =
171 memref::buildIndependentOp(b&: rewriter, allocaOp, independencies);
172 if (failed(Result: replacement))
173 return failure();
174 replaceAndPropagateMemRefType(rewriter, from: allocaOp,
175 to: replacement->getDefiningOp());
176 return replacement;
177}
178
179memref::AllocaOp memref::allocToAlloca(
180 RewriterBase &rewriter, memref::AllocOp alloc,
181 function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
182 memref::DeallocOp dealloc = nullptr;
183 for (Operation &candidate :
184 llvm::make_range(x: alloc->getIterator(), y: alloc->getBlock()->end())) {
185 dealloc = dyn_cast<memref::DeallocOp>(Val&: candidate);
186 if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
187 (!filter || filter(alloc, dealloc))) {
188 break;
189 }
190 }
191
192 if (!dealloc)
193 return nullptr;
194
195 OpBuilder::InsertionGuard guard(rewriter);
196 rewriter.setInsertionPoint(alloc);
197 auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
198 op: alloc, args: alloc.getMemref().getType(), args: alloc.getOperands());
199 rewriter.eraseOp(op: dealloc);
200 return alloca;
201}
202

// Source: mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp