1//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
10// operations with ranked `memref` and `tensor` results.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/MemRef/Transforms/Passes.h"
15
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Interfaces/DestinationStyleOpInterface.h"
21#include "mlir/Interfaces/InferTypeOpInterface.h"
22#include "llvm/Support/InterleavedRange.h"
23
24#define DEBUG_TYPE "reify-result-shapes"
25#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
26
27namespace mlir {
28namespace memref {
29#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
30#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
31} // namespace memref
32} // namespace mlir
33
34using namespace mlir;
35
36/// Reifies the results of `op`, potentially replacing `op` with a reified
37/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
38/// otherwise it always succeeds. Users of this transform should always expect
39/// it to modify the IR, even when it fails. If any of the result types changes,
40/// the transform will insert cast operations to the old type to keep the IR
41/// consistent.
42static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
43 ReifyRankedShapedTypeOpInterface op) {
44 LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
45 // Get the reified out shapes.
46 ReifiedRankedShapedTypeDims reifiedResultShapes;
47 if (failed(Result: mlir::reifyResultShapes(b&: rewriter, op, reifiedReturnShapes&: reifiedResultShapes)) ||
48 reifiedResultShapes.empty()) {
49 return op->emitWarning() << "failed to get the reified shapes";
50 }
51
52 bool modified = false;
53 // Compute the new output types.
54 SmallVector<Type> outTypes;
55 for (const auto &[oldTy, reifiedShape] :
56 llvm::zip(t: op->getResultTypes(), u&: reifiedResultShapes)) {
57 // Skip if it's not a memref or tensor type.
58 if (!isa<RankedTensorType, MemRefType>(Val: oldTy)) {
59 outTypes.push_back(Elt: oldTy);
60 continue;
61 }
62
63 ShapedType shapedTy = dyn_cast<ShapedType>(Val: oldTy);
64
65 SmallVector<int64_t> shape = llvm::to_vector(Range: shapedTy.getShape());
66 for (auto &&[dim, ofr] : llvm::zip_equal(t&: shape, u&: reifiedShape)) {
67 std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
68 // If the reified dim is dynamic set it appropriately.
69 if (!maybeCst.has_value()) {
70 dim = ShapedType::kDynamic;
71 continue;
72 }
73 // Set the static dim.
74 dim = *maybeCst;
75 }
76
77 // If the shape didn't change continue.
78 if (shape == shapedTy.getShape()) {
79 outTypes.push_back(Elt: oldTy);
80 continue;
81 }
82 modified = true;
83 outTypes.push_back(Elt: shapedTy.cloneWith(shape, elementType: shapedTy.getElementType()));
84 }
85
86 // Return if we don't need to update.
87 if (!modified) {
88 LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
89 return success();
90 }
91
92 LLVM_DEBUG({
93 DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
94 << " \n";
95 DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
96 });
97
98 // We now have outTypes that need to be turned to cast ops.
99 Location loc = op->getLoc();
100 SmallVector<Value> newResults;
101 // TODO: `mlir::reifyResultShapes` and op verifiers may not agree atm.
102 // This is a confluence problem that will need to be addressed.
103 // For now, we know PadOp and ConcatOp are fine.
104 assert((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation())) &&
105 "incorrect op");
106 Operation *newOp = rewriter.clone(op&: *op);
107 for (auto [reifiedTy, oldRes] : llvm::zip(t&: outTypes, u: op->getResults())) {
108 OpResult newRes = newOp->getResult(idx: oldRes.getResultNumber());
109 Type oldTy = oldRes.getType();
110 // Continue if the type remained invariant or is not shaped.
111 if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(Val: oldTy)) {
112 newResults.push_back(Elt: newRes);
113 continue;
114 }
115
116 // Update the type.
117 newRes.setType(reifiedTy);
118 if (isa<RankedTensorType>(Val: reifiedTy)) {
119 newResults.push_back(Elt: rewriter.create<tensor::CastOp>(location: loc, args&: oldTy, args&: newRes));
120 } else {
121 assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
122 newResults.push_back(Elt: rewriter.create<memref::CastOp>(location: loc, args&: oldTy, args&: newRes));
123 }
124 }
125
126 LLVM_DEBUG({
127 DBGS() << "- reified results " << llvm::interleaved_array(newResults)
128 << "\n";
129 });
130 rewriter.replaceOp(op, newValues: newResults);
131 return success();
132}
133
134//===----------------------------------------------------------------------===//
135// Pass registration
136//===----------------------------------------------------------------------===//
137
138namespace {
139struct ReifyResultShapesPass final
140 : public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
141 void runOnOperation() override;
142};
143} // namespace
144
145void ReifyResultShapesPass::runOnOperation() {
146 SmallVector<ReifyRankedShapedTypeOpInterface> ops;
147 getOperation()->walk(callback: [&](ReifyRankedShapedTypeOpInterface op) {
148 // Handle ops that are not DPS and that do not carry an tied operand shapes.
149 // For now, limit to tensor::PadOp and tensor::ConcatOp.
150 if (!isa<tensor::PadOp, tensor::ConcatOp>(Val: op.getOperation()))
151 return;
152 ops.push_back(Elt: op);
153 });
154 IRRewriter rewriter(&getContext());
155 for (ReifyRankedShapedTypeOpInterface op : ops) {
156 rewriter.setInsertionPoint(op);
157 (void)reifyOpResultShapes(rewriter, op);
158 }
159}
160

source code of mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp