1 | //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
12 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace mlir { |
17 | namespace memref { |
18 | namespace { |
19 | |
20 | template <typename OpTy> |
21 | struct AllocOpInterface |
22 | : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>, |
23 | OpTy> { |
24 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
25 | ValueBoundsConstraintSet &cstr) const { |
26 | auto allocOp = cast<OpTy>(op); |
27 | assert(value == allocOp.getResult() && "invalid value" ); |
28 | |
29 | cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim]; |
30 | } |
31 | }; |
32 | |
33 | struct CastOpInterface |
34 | : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> { |
35 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
36 | ValueBoundsConstraintSet &cstr) const { |
37 | auto castOp = cast<CastOp>(op); |
38 | assert(value == castOp.getResult() && "invalid value" ); |
39 | |
40 | if (llvm::isa<MemRefType>(castOp.getResult().getType()) && |
41 | llvm::isa<MemRefType>(castOp.getSource().getType())) { |
42 | cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); |
43 | } |
44 | } |
45 | }; |
46 | |
47 | struct DimOpInterface |
48 | : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> { |
49 | void populateBoundsForIndexValue(Operation *op, Value value, |
50 | ValueBoundsConstraintSet &cstr) const { |
51 | auto dimOp = cast<DimOp>(op); |
52 | assert(value == dimOp.getResult() && "invalid value" ); |
53 | |
54 | cstr.bound(value) >= 0; |
55 | auto constIndex = dimOp.getConstantIndex(); |
56 | if (!constIndex.has_value()) |
57 | return; |
58 | cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex); |
59 | } |
60 | }; |
61 | |
62 | struct GetGlobalOpInterface |
63 | : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface, |
64 | GetGlobalOp> { |
65 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
66 | ValueBoundsConstraintSet &cstr) const { |
67 | auto getGlobalOp = cast<GetGlobalOp>(op); |
68 | assert(value == getGlobalOp.getResult() && "invalid value" ); |
69 | |
70 | auto type = getGlobalOp.getType(); |
71 | assert(!type.isDynamicDim(dim) && "expected static dim" ); |
72 | cstr.bound(value)[dim] == type.getDimSize(dim); |
73 | } |
74 | }; |
75 | |
76 | struct RankOpInterface |
77 | : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> { |
78 | void populateBoundsForIndexValue(Operation *op, Value value, |
79 | ValueBoundsConstraintSet &cstr) const { |
80 | auto rankOp = cast<RankOp>(op); |
81 | assert(value == rankOp.getResult() && "invalid value" ); |
82 | |
83 | auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType()); |
84 | if (!memrefType) |
85 | return; |
86 | cstr.bound(value) == memrefType.getRank(); |
87 | } |
88 | }; |
89 | |
90 | struct SubViewOpInterface |
91 | : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, |
92 | SubViewOp> { |
93 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
94 | ValueBoundsConstraintSet &cstr) const { |
95 | auto subViewOp = cast<SubViewOp>(op); |
96 | assert(value == subViewOp.getResult() && "invalid value" ); |
97 | |
98 | llvm::SmallBitVector dropped = subViewOp.getDroppedDims(); |
99 | int64_t ctr = -1; |
100 | for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) { |
101 | // Skip over rank-reduced dimensions. |
102 | if (!dropped.test(Idx: i)) |
103 | ++ctr; |
104 | if (ctr == dim) { |
105 | cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i]; |
106 | return; |
107 | } |
108 | } |
109 | llvm_unreachable("could not find non-rank-reduced dim" ); |
110 | } |
111 | }; |
112 | |
113 | } // namespace |
114 | } // namespace memref |
115 | } // namespace mlir |
116 | |
117 | void mlir::memref::registerValueBoundsOpInterfaceExternalModels( |
118 | DialectRegistry ®istry) { |
119 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) { |
120 | memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>( |
121 | *ctx); |
122 | memref::AllocaOp::attachInterface< |
123 | memref::AllocOpInterface<memref::AllocaOp>>(*ctx); |
124 | memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); |
125 | memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); |
126 | memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); |
127 | memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx); |
128 | memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx); |
129 | }); |
130 | } |
131 | |