| 1 | //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" |
| 10 | |
| 11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 12 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| 13 | |
| 14 | using namespace mlir; |
| 15 | |
| 16 | namespace mlir { |
| 17 | namespace memref { |
| 18 | namespace { |
| 19 | |
| 20 | template <typename OpTy> |
| 21 | struct AllocOpInterface |
| 22 | : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>, |
| 23 | OpTy> { |
| 24 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
| 25 | ValueBoundsConstraintSet &cstr) const { |
| 26 | auto allocOp = cast<OpTy>(op); |
| 27 | assert(value == allocOp.getResult() && "invalid value" ); |
| 28 | |
| 29 | cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim]; |
| 30 | } |
| 31 | }; |
| 32 | |
| 33 | struct CastOpInterface |
| 34 | : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> { |
| 35 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
| 36 | ValueBoundsConstraintSet &cstr) const { |
| 37 | auto castOp = cast<CastOp>(op); |
| 38 | assert(value == castOp.getResult() && "invalid value" ); |
| 39 | |
| 40 | if (llvm::isa<MemRefType>(castOp.getResult().getType()) && |
| 41 | llvm::isa<MemRefType>(castOp.getSource().getType())) { |
| 42 | cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); |
| 43 | } |
| 44 | } |
| 45 | }; |
| 46 | |
| 47 | struct DimOpInterface |
| 48 | : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> { |
| 49 | void populateBoundsForIndexValue(Operation *op, Value value, |
| 50 | ValueBoundsConstraintSet &cstr) const { |
| 51 | auto dimOp = cast<DimOp>(op); |
| 52 | assert(value == dimOp.getResult() && "invalid value" ); |
| 53 | |
| 54 | cstr.bound(value) >= 0; |
| 55 | auto constIndex = dimOp.getConstantIndex(); |
| 56 | if (!constIndex.has_value()) |
| 57 | return; |
| 58 | cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex); |
| 59 | } |
| 60 | }; |
| 61 | |
| 62 | struct GetGlobalOpInterface |
| 63 | : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface, |
| 64 | GetGlobalOp> { |
| 65 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
| 66 | ValueBoundsConstraintSet &cstr) const { |
| 67 | auto getGlobalOp = cast<GetGlobalOp>(op); |
| 68 | assert(value == getGlobalOp.getResult() && "invalid value" ); |
| 69 | |
| 70 | auto type = getGlobalOp.getType(); |
| 71 | assert(!type.isDynamicDim(dim) && "expected static dim" ); |
| 72 | cstr.bound(value)[dim] == type.getDimSize(dim); |
| 73 | } |
| 74 | }; |
| 75 | |
| 76 | struct RankOpInterface |
| 77 | : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> { |
| 78 | void populateBoundsForIndexValue(Operation *op, Value value, |
| 79 | ValueBoundsConstraintSet &cstr) const { |
| 80 | auto rankOp = cast<RankOp>(op); |
| 81 | assert(value == rankOp.getResult() && "invalid value" ); |
| 82 | |
| 83 | auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType()); |
| 84 | if (!memrefType) |
| 85 | return; |
| 86 | cstr.bound(value) == memrefType.getRank(); |
| 87 | } |
| 88 | }; |
| 89 | |
| 90 | struct SubViewOpInterface |
| 91 | : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, |
| 92 | SubViewOp> { |
| 93 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
| 94 | ValueBoundsConstraintSet &cstr) const { |
| 95 | auto subViewOp = cast<SubViewOp>(op); |
| 96 | assert(value == subViewOp.getResult() && "invalid value" ); |
| 97 | |
| 98 | llvm::SmallBitVector dropped = subViewOp.getDroppedDims(); |
| 99 | int64_t ctr = -1; |
| 100 | for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) { |
| 101 | // Skip over rank-reduced dimensions. |
| 102 | if (!dropped.test(Idx: i)) |
| 103 | ++ctr; |
| 104 | if (ctr == dim) { |
| 105 | cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i]; |
| 106 | return; |
| 107 | } |
| 108 | } |
| 109 | llvm_unreachable("could not find non-rank-reduced dim" ); |
| 110 | } |
| 111 | }; |
| 112 | |
| 113 | } // namespace |
| 114 | } // namespace memref |
| 115 | } // namespace mlir |
| 116 | |
| 117 | void mlir::memref::registerValueBoundsOpInterfaceExternalModels( |
| 118 | DialectRegistry ®istry) { |
| 119 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) { |
| 120 | memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>( |
| 121 | *ctx); |
| 122 | memref::AllocaOp::attachInterface< |
| 123 | memref::AllocOpInterface<memref::AllocaOp>>(*ctx); |
| 124 | memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); |
| 125 | memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); |
| 126 | memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); |
| 127 | memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx); |
| 128 | memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx); |
| 129 | }); |
| 130 | } |
| 131 | |