//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
12 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace mlir { |
17 | namespace memref { |
18 | namespace { |
19 | |
20 | template <typename OpTy> |
21 | struct AllocOpInterface |
22 | : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>, |
23 | OpTy> { |
24 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
25 | ValueBoundsConstraintSet &cstr) const { |
26 | auto allocOp = cast<OpTy>(op); |
27 | assert(value == allocOp.getResult() && "invalid value" ); |
28 | |
29 | cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim]; |
30 | } |
31 | }; |
32 | |
33 | struct CastOpInterface |
34 | : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> { |
35 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
36 | ValueBoundsConstraintSet &cstr) const { |
37 | auto castOp = cast<CastOp>(op); |
38 | assert(value == castOp.getResult() && "invalid value" ); |
39 | |
40 | if (llvm::isa<MemRefType>(castOp.getResult().getType()) && |
41 | llvm::isa<MemRefType>(castOp.getSource().getType())) { |
42 | cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); |
43 | } |
44 | } |
45 | }; |
46 | |
47 | struct DimOpInterface |
48 | : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> { |
49 | void populateBoundsForIndexValue(Operation *op, Value value, |
50 | ValueBoundsConstraintSet &cstr) const { |
51 | auto dimOp = cast<DimOp>(op); |
52 | assert(value == dimOp.getResult() && "invalid value" ); |
53 | |
54 | auto constIndex = dimOp.getConstantIndex(); |
55 | if (!constIndex.has_value()) |
56 | return; |
57 | cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex); |
58 | } |
59 | }; |
60 | |
61 | struct GetGlobalOpInterface |
62 | : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface, |
63 | GetGlobalOp> { |
64 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
65 | ValueBoundsConstraintSet &cstr) const { |
66 | auto getGlobalOp = cast<GetGlobalOp>(op); |
67 | assert(value == getGlobalOp.getResult() && "invalid value" ); |
68 | |
69 | auto type = getGlobalOp.getType(); |
70 | assert(!type.isDynamicDim(dim) && "expected static dim" ); |
71 | cstr.bound(value)[dim] == type.getDimSize(dim); |
72 | } |
73 | }; |
74 | |
75 | struct RankOpInterface |
76 | : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> { |
77 | void populateBoundsForIndexValue(Operation *op, Value value, |
78 | ValueBoundsConstraintSet &cstr) const { |
79 | auto rankOp = cast<RankOp>(op); |
80 | assert(value == rankOp.getResult() && "invalid value" ); |
81 | |
82 | auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType()); |
83 | if (!memrefType) |
84 | return; |
85 | cstr.bound(value) == memrefType.getRank(); |
86 | } |
87 | }; |
88 | |
89 | struct SubViewOpInterface |
90 | : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, |
91 | SubViewOp> { |
92 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
93 | ValueBoundsConstraintSet &cstr) const { |
94 | auto subViewOp = cast<SubViewOp>(op); |
95 | assert(value == subViewOp.getResult() && "invalid value" ); |
96 | |
97 | llvm::SmallBitVector dropped = subViewOp.getDroppedDims(); |
98 | int64_t ctr = -1; |
99 | for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) { |
100 | // Skip over rank-reduced dimensions. |
101 | if (!dropped.test(Idx: i)) |
102 | ++ctr; |
103 | if (ctr == dim) { |
104 | cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i]; |
105 | return; |
106 | } |
107 | } |
108 | llvm_unreachable("could not find non-rank-reduced dim" ); |
109 | } |
110 | }; |
111 | |
112 | } // namespace |
113 | } // namespace memref |
114 | } // namespace mlir |
115 | |
116 | void mlir::memref::registerValueBoundsOpInterfaceExternalModels( |
117 | DialectRegistry ®istry) { |
118 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) { |
119 | memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>( |
120 | *ctx); |
121 | memref::AllocaOp::attachInterface< |
122 | memref::AllocOpInterface<memref::AllocaOp>>(*ctx); |
123 | memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); |
124 | memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); |
125 | memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); |
126 | memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx); |
127 | memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx); |
128 | }); |
129 | } |
130 | |