//===- IndependenceTransforms.cpp - Make ops independent of values --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::memref;

/// Make the given OpFoldResult independent of all independencies.
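/// If the OpFoldResult is a dynamic Value, compute a closed upper bound that
/// does not depend on any of the `independencies` via the value bounds
/// interface and materialize it with affine ops. Returns failure if no such
/// bound can be computed.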
static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                               OpFoldResult ofr,
                                               ValueRange independencies) {
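  // A constant attribute is already independent of any SSA value.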
  if (isa<Attribute>(ofr))
    return ofr;
  AffineMap boundMap;
  ValueDimList mapOperands;
  if (failed(ValueBoundsConstraintSet::computeIndependentBound(
          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
          /*closedUB=*/true)))
    return failure();
  return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}

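/// Make the dynamic sizes of the given AllocaOp independent of the given
/// `independencies` by allocating an independent upper bound for each size
/// and taking a subview of the original size. A rough sketch (value names are
/// illustrative), with `%sz` being one of the independencies:
///
///   %alloca = memref.alloca(%sz) : memref<?xf32>
///
/// is rewritten to:
///
///   %ub = <upper bound of %sz that does not depend on %sz>
///   %new = memref.alloca(%ub) : memref<?xf32>
///   %subview = memref.subview %new[0] [%sz] [1] : ...
///
/// If none of the sizes change, the existing op is returned instead.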
FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
                                            memref::AllocaOp allocaOp,
                                            ValueRange independencies) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(allocaOp);
  Location loc = allocaOp.getLoc();

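  // Compute an upper bound for each size of the allocation that does not
  // depend on any of the independencies.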
  SmallVector<OpFoldResult> newSizes;
  for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
    auto ub = makeIndependent(b, loc, ofr, independencies);
    if (failed(ub))
      return failure();
    newSizes.push_back(*ub);
  }

  // Return existing memref::AllocaOp if nothing has changed.
  if (llvm::equal(allocaOp.getMixedSizes(), newSizes))
    return allocaOp.getResult();

  // Create a new memref::AllocaOp.
  Value newAllocaOp =
      b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType());

  // Create a memref::SubViewOp.
  SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
  return b
      .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(),
                         strides)
      .getResult();
}

/// Push down an UnrealizedConversionCastOp past a SubViewOp.
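/// The subview is recreated directly on the cast's source value, and a new
/// cast back to the original subview result type is inserted in its place, so
/// existing uses keep seeing the original type.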
static UnrealizedConversionCastOp
propagateSubViewOp(RewriterBase &rewriter,
                   UnrealizedConversionCastOp conversionOp, SubViewOp op) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  MemRefType newResultType = SubViewOp::inferRankReducedResultType(
      op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
      op.getMixedSizes(), op.getMixedStrides());
  Value newSubview = rewriter.create<SubViewOp>(
      op.getLoc(), newResultType, conversionOp.getOperand(0),
      op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
  auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
      op.getLoc(), op.getType(), newSubview);
  rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
  return newConversionOp;
}

/// Given an original op and a new, modified op with the same number of
/// results but possibly different memref result types, replace all uses of
/// the original op with the new op and propagate the new memref types through
/// the IR.
///
/// Example:
/// %from = memref.alloca(%sz) : memref<?xf32>
/// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>>
/// memref.store %cst, %from[%c0] : memref<?xf32>
///
/// In the above example, all uses of %from are replaced with %to. This can be
/// done directly for ops such as memref.store. For ops that have memref
/// results (e.g., memref.subview), the result type may depend on the operand
/// type, so uses cannot simply be replaced. Common memref ops are handled
/// specially; for all other ops, an unrealized_conversion_cast is inserted.
static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
                                          Operation *from, Operation *to) {
  assert(from->getNumResults() == to->getNumResults() &&
         "expected same number of results");
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointAfter(to);

  // Wrap new results in unrealized_conversion_cast and replace all uses of the
  // original op.
  SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
  for (const auto &it :
       llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) {
    unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>(
        to->getLoc(), std::get<0>(it.value()).getType(),
        std::get<1>(it.value())));
    rewriter.replaceAllUsesWith(from->getResult(it.index()),
                                unrealizedConversions.back()->getResult(0));
  }

  // Push unrealized_conversion_cast ops further down in the IR. I.e., try to
  // wrap results instead of operands in a cast.
  for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) {
    UnrealizedConversionCastOp conversion = unrealizedConversions[i];
    assert(conversion->getNumOperands() == 1 &&
           conversion->getNumResults() == 1 &&
           "expected single operand and single result");
    SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers());
    for (Operation *user : users) {
      // Handle common memref dialect ops that produce new memrefs and must
      // be recreated with the new result type.
      if (auto subviewOp = dyn_cast<SubViewOp>(user)) {
        unrealizedConversions.push_back(
            propagateSubViewOp(rewriter, conversion, subviewOp));
        continue;
      }

      // TODO: Other memref ops such as memref.collapse_shape/expand_shape
      // should also be handled here.

      // Skip ops that produce MemRef results or that have MemRef region block
      // arguments. These may need special handling (e.g., scf.for).
      if (llvm::any_of(user->getResultTypes(),
                       [](Type t) { return isa<MemRefType>(t); }))
        continue;
      if (llvm::any_of(user->getRegions(), [](Region &r) {
            return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) {
              return isa<MemRefType>(bbArg.getType());
            });
          }))
        continue;

      // For all other ops, we assume that we can directly replace the operand.
      // This may have to be revised in the future; e.g., there may be ops that
      // do not support non-identity layout maps.
      for (OpOperand &operand : user->getOpOperands()) {
        if ([[maybe_unused]] auto castOp =
                operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
          rewriter.modifyOpInPlace(
              user, [&]() { operand.set(conversion->getOperand(0)); });
        }
      }
    }
  }

  // Erase all unrealized_conversion_cast ops without uses.
  for (auto op : unrealizedConversions)
    if (op->getUses().empty())
      rewriter.eraseOp(op);
}

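/// Build a variant of the given AllocaOp whose sizes are independent of the
/// given values, replace the original op with it, and propagate the resulting
/// memref type change to the users of the original op.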
FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
                                                  memref::AllocaOp allocaOp,
                                                  ValueRange independencies) {
  auto replacement =
      memref::buildIndependentOp(rewriter, allocaOp, independencies);
  if (failed(replacement))
    return failure();
  replaceAndPropagateMemRefType(rewriter, allocaOp,
                                replacement->getDefiningOp());
  return replacement;
}

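/// Convert the given memref.alloc to a memref.alloca, provided a matching
/// memref.dealloc (accepted by the optional `filter`) is found in the same
/// block. The dealloc is erased along the way. Returns a null op if no
/// suitable dealloc is found.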
memref::AllocaOp memref::allocToAlloca(
    RewriterBase &rewriter, memref::AllocOp alloc,
    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
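  // Look for a dealloc of this alloc's memref in the remainder of the block.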
  memref::DeallocOp dealloc = nullptr;
  for (Operation &candidate :
       llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
    auto candidateDealloc = dyn_cast<memref::DeallocOp>(candidate);
    if (candidateDealloc &&
        candidateDealloc.getMemref() == alloc.getMemref() &&
        (!filter || filter(alloc, candidateDealloc))) {
      dealloc = candidateDealloc;
      break;
    }
  }

  if (!dealloc)
    return nullptr;

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(alloc);
  auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
      alloc, alloc.getMemref().getType(), alloc.getOperands());
  rewriter.eraseOp(dealloc);
  return alloca;
}