1 | //===- RuntimeOpVerification.cpp - Op Verification ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" |
10 | |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
13 | #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
14 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
16 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
17 | #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" |
18 | |
19 | using namespace mlir; |
20 | |
21 | namespace mlir { |
22 | namespace tensor { |
23 | namespace { |
24 | /// Generate a runtime check for lb <= value < ub. |
25 | Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, |
26 | Value lb, Value ub) { |
27 | Value inBounds1 = builder.createOrFold<arith::CmpIOp>( |
28 | loc, arith::CmpIPredicate::sge, value, lb); |
29 | Value inBounds2 = builder.createOrFold<arith::CmpIOp>( |
30 | loc, arith::CmpIPredicate::slt, value, ub); |
31 | Value inBounds = |
32 | builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2); |
33 | return inBounds; |
34 | } |
35 | |
36 | struct CastOpInterface |
37 | : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, |
38 | CastOp> { |
39 | void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
40 | Location loc) const { |
41 | auto castOp = cast<CastOp>(op); |
42 | auto srcType = cast<TensorType>(castOp.getSource().getType()); |
43 | |
44 | // Nothing to check if the result is an unranked tensor. |
45 | auto resultType = dyn_cast<RankedTensorType>(castOp.getType()); |
46 | if (!resultType) |
47 | return; |
48 | |
49 | if (isa<UnrankedTensorType>(srcType)) { |
50 | // Check rank. |
51 | Value srcRank = builder.create<RankOp>(loc, castOp.getSource()); |
52 | Value resultRank = |
53 | builder.create<arith::ConstantIndexOp>(loc, resultType.getRank()); |
54 | Value isSameRank = builder.create<arith::CmpIOp>( |
55 | loc, arith::CmpIPredicate::eq, srcRank, resultRank); |
56 | builder.create<cf::AssertOp>( |
57 | loc, isSameRank, |
58 | RuntimeVerifiableOpInterface::generateErrorMessage(op, |
59 | "rank mismatch" )); |
60 | } |
61 | |
62 | // Check dimension sizes. |
63 | for (const auto &it : llvm::enumerate(resultType.getShape())) { |
64 | // Static dim size -> static/dynamic dim size does not need verification. |
65 | if (auto rankedSrcType = dyn_cast<RankedTensorType>(srcType)) |
66 | if (!rankedSrcType.isDynamicDim(it.index())) |
67 | continue; |
68 | |
69 | // Static/dynamic dim size -> dynamic dim size does not need verification. |
70 | if (resultType.isDynamicDim(it.index())) |
71 | continue; |
72 | |
73 | Value srcDimSz = |
74 | builder.create<DimOp>(loc, castOp.getSource(), it.index()); |
75 | Value resultDimSz = |
76 | builder.create<arith::ConstantIndexOp>(loc, it.value()); |
77 | Value isSameSz = builder.create<arith::CmpIOp>( |
78 | loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); |
79 | builder.create<cf::AssertOp>( |
80 | loc, isSameSz, |
81 | RuntimeVerifiableOpInterface::generateErrorMessage( |
82 | op, "size mismatch of dim " + std::to_string(it.index()))); |
83 | } |
84 | } |
85 | }; |
86 | |
87 | struct DimOpInterface |
88 | : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface, |
89 | DimOp> { |
90 | void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
91 | Location loc) const { |
92 | auto dimOp = cast<DimOp>(op); |
93 | Value rank = builder.create<RankOp>(loc, dimOp.getSource()); |
94 | Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
95 | builder.create<cf::AssertOp>( |
96 | loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), |
97 | RuntimeVerifiableOpInterface::generateErrorMessage( |
98 | op, "index is out of bounds" )); |
99 | } |
100 | }; |
101 | |
102 | /// Verifies that the indices on extract/insert ops are in-bounds of the |
103 | /// tensor's index space: 0 <= index#i < dim#i |
104 | template <typename OpTy> |
105 | struct |
106 | : public RuntimeVerifiableOpInterface::ExternalModel< |
107 | ExtractInsertOpInterface<OpTy>, OpTy> { |
108 | void (Operation *op, OpBuilder &builder, |
109 | Location loc) const { |
110 | auto = cast<OpTy>(op); |
111 | |
112 | Value tensor; |
113 | if constexpr (std::is_same_v<OpTy, ExtractOp>) { |
114 | tensor = extractInsertOp.getTensor(); |
115 | } else if constexpr (std::is_same_v<OpTy, InsertOp>) { |
116 | tensor = extractInsertOp.getDest(); |
117 | } else { |
118 | llvm_unreachable("invalid op" ); |
119 | } |
120 | auto tensorType = cast<RankedTensorType>(tensor.getType()); |
121 | auto rank = tensorType.getRank(); |
122 | if (rank == 0) { |
123 | // Nothing to check for 0-d tensors. |
124 | return; |
125 | } |
126 | |
127 | auto indices = extractInsertOp.getIndices(); |
128 | auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
129 | Value assertCond; |
130 | for (auto i : llvm::seq<int64_t>(0, rank)) { |
131 | Value dimOp = builder.createOrFold<tensor::DimOp>(loc, tensor, i); |
132 | Value inBounds = |
133 | generateInBoundsCheck(builder, loc, indices[i], zero, dimOp); |
134 | assertCond = |
135 | i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds) |
136 | : inBounds; |
137 | } |
138 | builder.create<cf::AssertOp>( |
139 | loc, assertCond, |
140 | RuntimeVerifiableOpInterface::generateErrorMessage( |
141 | op, "out-of-bounds access" )); |
142 | } |
143 | }; |
144 | |
145 | struct |
146 | : public RuntimeVerifiableOpInterface::ExternalModel< |
147 | ExtractSliceOpInterface, ExtractSliceOp> { |
148 | void (Operation *op, OpBuilder &builder, |
149 | Location loc) const { |
150 | auto = cast<ExtractSliceOp>(op); |
151 | RankedTensorType sourceType = extractSliceOp.getSource().getType(); |
152 | |
153 | // For each dimension, assert that: |
154 | // 0 <= offset < dim_size |
155 | // 0 <= offset + (size - 1) * stride < dim_size |
156 | Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
157 | Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
158 | for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { |
159 | Value offset = getValueOrCreateConstantIndexOp( |
160 | builder, loc, extractSliceOp.getMixedOffsets()[i]); |
161 | Value size = getValueOrCreateConstantIndexOp( |
162 | builder, loc, extractSliceOp.getMixedSizes()[i]); |
163 | Value stride = getValueOrCreateConstantIndexOp( |
164 | builder, loc, extractSliceOp.getMixedStrides()[i]); |
165 | |
166 | // Verify that offset is in-bounds. |
167 | Value dimSize = builder.createOrFold<tensor::DimOp>( |
168 | loc, extractSliceOp.getSource(), i); |
169 | Value offsetInBounds = |
170 | generateInBoundsCheck(builder, loc, value: offset, lb: zero, ub: dimSize); |
171 | builder.create<cf::AssertOp>( |
172 | loc, offsetInBounds, |
173 | RuntimeVerifiableOpInterface::generateErrorMessage( |
174 | op, "offset " + std::to_string(i) + " is out-of-bounds" )); |
175 | |
176 | // Verify that slice does not run out-of-bounds. |
177 | Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one); |
178 | Value sizeMinusOneTimesStride = |
179 | builder.create<arith::MulIOp>(loc, sizeMinusOne, stride); |
180 | Value lastPos = |
181 | builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride); |
182 | Value lastPosInBounds = |
183 | generateInBoundsCheck(builder, loc, value: lastPos, lb: zero, ub: dimSize); |
184 | builder.create<cf::AssertOp>( |
185 | loc, lastPosInBounds, |
186 | RuntimeVerifiableOpInterface::generateErrorMessage( |
187 | op, "extract_slice runs out-of-bounds along dimension " + |
188 | std::to_string(i))); |
189 | } |
190 | } |
191 | }; |
192 | } // namespace |
193 | } // namespace tensor |
194 | } // namespace mlir |
195 | |
196 | void mlir::tensor::registerRuntimeVerifiableOpInterfaceExternalModels( |
197 | DialectRegistry ®istry) { |
198 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) { |
199 | CastOp::attachInterface<CastOpInterface>(*ctx); |
200 | DimOp::attachInterface<DimOpInterface>(*ctx); |
201 | ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx); |
202 | ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx); |
203 | InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx); |
204 | |
205 | // Load additional dialects of which ops may get created. |
206 | ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>(); |
207 | }); |
208 | } |
209 | |