//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

using namespace mlir;
22 | |
23 | /// Generate an error message string for the given op and the specified error. |
24 | static std::string generateErrorMessage(Operation *op, const std::string &msg) { |
25 | std::string buffer; |
26 | llvm::raw_string_ostream stream(buffer); |
27 | OpPrintingFlags flags; |
28 | // We may generate a lot of error messages and so we need to ensure the |
29 | // printing is fast. |
30 | flags.elideLargeElementsAttrs(); |
31 | flags.printGenericOpForm(); |
32 | flags.skipRegions(); |
33 | flags.useLocalScope(); |
34 | stream << "ERROR: Runtime op verification failed\n" ; |
35 | op->print(os&: stream, flags); |
36 | stream << "\n^ " << msg; |
37 | stream << "\nLocation: " ; |
38 | op->getLoc().print(os&: stream); |
39 | return stream.str(); |
40 | } |
41 | |
42 | namespace mlir { |
43 | namespace memref { |
44 | namespace { |
45 | struct CastOpInterface |
46 | : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, |
47 | CastOp> { |
48 | void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
49 | Location loc) const { |
50 | auto castOp = cast<CastOp>(op); |
51 | auto srcType = cast<BaseMemRefType>(castOp.getSource().getType()); |
52 | |
53 | // Nothing to check if the result is an unranked memref. |
54 | auto resultType = dyn_cast<MemRefType>(castOp.getType()); |
55 | if (!resultType) |
56 | return; |
57 | |
58 | if (isa<UnrankedMemRefType>(srcType)) { |
59 | // Check rank. |
60 | Value srcRank = builder.create<RankOp>(loc, castOp.getSource()); |
61 | Value resultRank = |
62 | builder.create<arith::ConstantIndexOp>(loc, resultType.getRank()); |
63 | Value isSameRank = builder.create<arith::CmpIOp>( |
64 | loc, arith::CmpIPredicate::eq, srcRank, resultRank); |
65 | builder.create<cf::AssertOp>(loc, isSameRank, |
66 | generateErrorMessage(op, "rank mismatch" )); |
67 | } |
68 | |
69 | // Get source offset and strides. We do not have an op to get offsets and |
70 | // strides from unranked memrefs, so cast the source to a type with fully |
71 | // dynamic layout, from which we can then extract the offset and strides. |
72 | // (Rank was already verified.) |
73 | int64_t dynamicOffset = ShapedType::kDynamic; |
74 | SmallVector<int64_t> dynamicShape(resultType.getRank(), |
75 | ShapedType::kDynamic); |
76 | auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), |
77 | dynamicOffset, dynamicShape); |
78 | auto dynStridesType = |
79 | MemRefType::get(dynamicShape, resultType.getElementType(), |
80 | stridedLayout, resultType.getMemorySpace()); |
81 | Value helperCast = |
82 | builder.create<CastOp>(loc, dynStridesType, castOp.getSource()); |
83 | auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast); |
84 | |
85 | // Check dimension sizes. |
86 | for (const auto &it : llvm::enumerate(resultType.getShape())) { |
87 | // Static dim size -> static/dynamic dim size does not need verification. |
88 | if (auto rankedSrcType = dyn_cast<MemRefType>(srcType)) |
89 | if (!rankedSrcType.isDynamicDim(it.index())) |
90 | continue; |
91 | |
92 | // Static/dynamic dim size -> dynamic dim size does not need verification. |
93 | if (resultType.isDynamicDim(it.index())) |
94 | continue; |
95 | |
96 | Value srcDimSz = |
97 | builder.create<DimOp>(loc, castOp.getSource(), it.index()); |
98 | Value resultDimSz = |
99 | builder.create<arith::ConstantIndexOp>(loc, it.value()); |
100 | Value isSameSz = builder.create<arith::CmpIOp>( |
101 | loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); |
102 | builder.create<cf::AssertOp>( |
103 | loc, isSameSz, |
104 | generateErrorMessage(op, "size mismatch of dim " + |
105 | std::to_string(it.index()))); |
106 | } |
107 | |
108 | // Get result offset and strides. |
109 | int64_t resultOffset; |
110 | SmallVector<int64_t> resultStrides; |
111 | if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) |
112 | return; |
113 | |
114 | // Check offset. |
115 | if (resultOffset != ShapedType::kDynamic) { |
116 | // Static/dynamic offset -> dynamic offset does not need verification. |
117 | Value srcOffset = metadataOp.getResult(1); |
118 | Value resultOffsetVal = |
119 | builder.create<arith::ConstantIndexOp>(location: loc, args&: resultOffset); |
120 | Value isSameOffset = builder.create<arith::CmpIOp>( |
121 | loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); |
122 | builder.create<cf::AssertOp>(loc, isSameOffset, |
123 | generateErrorMessage(op, "offset mismatch" )); |
124 | } |
125 | |
126 | // Check strides. |
127 | for (const auto &it : llvm::enumerate(First&: resultStrides)) { |
128 | // Static/dynamic stride -> dynamic stride does not need verification. |
129 | if (it.value() == ShapedType::kDynamic) |
130 | continue; |
131 | |
132 | Value srcStride = |
133 | metadataOp.getResult(2 + resultType.getRank() + it.index()); |
134 | Value resultStrideVal = |
135 | builder.create<arith::ConstantIndexOp>(location: loc, args&: it.value()); |
136 | Value isSameStride = builder.create<arith::CmpIOp>( |
137 | loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); |
138 | builder.create<cf::AssertOp>( |
139 | loc, isSameStride, |
140 | generateErrorMessage(op, "stride mismatch of dim " + |
141 | std::to_string(it.index()))); |
142 | } |
143 | } |
144 | }; |
145 | |
146 | /// Verifies that the indices on load/store ops are in-bounds of the memref's |
147 | /// index space: 0 <= index#i < dim#i |
148 | template <typename LoadStoreOp> |
149 | struct LoadStoreOpInterface |
150 | : public RuntimeVerifiableOpInterface::ExternalModel< |
151 | LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> { |
152 | void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
153 | Location loc) const { |
154 | auto loadStoreOp = cast<LoadStoreOp>(op); |
155 | |
156 | auto memref = loadStoreOp.getMemref(); |
157 | auto rank = memref.getType().getRank(); |
158 | if (rank == 0) { |
159 | return; |
160 | } |
161 | auto indices = loadStoreOp.getIndices(); |
162 | |
163 | auto zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0); |
164 | Value assertCond; |
165 | for (auto i : llvm::seq<int64_t>(0, rank)) { |
166 | auto index = indices[i]; |
167 | |
168 | auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i); |
169 | |
170 | auto geLow = builder.createOrFold<arith::CmpIOp>( |
171 | loc, arith::CmpIPredicate::sge, index, zero); |
172 | auto ltHigh = builder.createOrFold<arith::CmpIOp>( |
173 | loc, arith::CmpIPredicate::slt, index, dimOp); |
174 | auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh); |
175 | |
176 | assertCond = |
177 | i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp) |
178 | : andOp; |
179 | } |
180 | builder.create<cf::AssertOp>( |
181 | loc, assertCond, generateErrorMessage(op, "out-of-bounds access" )); |
182 | } |
183 | }; |
184 | |
185 | /// Compute the linear index for the provided strided layout and indices. |
186 | Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset, |
187 | ArrayRef<OpFoldResult> strides, |
188 | ArrayRef<OpFoldResult> indices) { |
189 | auto [expr, values] = computeLinearIndex(sourceOffset: offset, strides, indices); |
190 | auto index = |
191 | affine::makeComposedFoldedAffineApply(b&: builder, loc, expr, operands: values); |
192 | return getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: index); |
193 | } |
194 | |
195 | /// Returns two Values representing the bounds of the provided strided layout |
196 | /// metadata. The bounds are returned as a half open interval -- [low, high). |
197 | std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc, |
198 | OpFoldResult offset, |
199 | ArrayRef<OpFoldResult> strides, |
200 | ArrayRef<OpFoldResult> sizes) { |
201 | auto zeros = SmallVector<int64_t>(sizes.size(), 0); |
202 | auto indices = getAsIndexOpFoldResult(ctx: builder.getContext(), values: zeros); |
203 | auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices); |
204 | auto upperBound = computeLinearIndex(builder, loc, offset, strides, indices: sizes); |
205 | return {lowerBound, upperBound}; |
206 | } |
207 | |
208 | /// Returns two Values representing the bounds of the memref. The bounds are |
209 | /// returned as a half open interval -- [low, high). |
210 | std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc, |
211 | TypedValue<BaseMemRefType> memref) { |
212 | auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref); |
213 | auto offset = runtimeMetadata.getConstifiedMixedOffset(); |
214 | auto strides = runtimeMetadata.getConstifiedMixedStrides(); |
215 | auto sizes = runtimeMetadata.getConstifiedMixedSizes(); |
216 | return computeLinearBounds(builder, loc, offset, strides, sizes); |
217 | } |
218 | |
219 | /// Verifies that the linear bounds of a reinterpret_cast op are within the |
220 | /// linear bounds of the base memref: low >= baseLow && high <= baseHigh |
221 | struct ReinterpretCastOpInterface |
222 | : public RuntimeVerifiableOpInterface::ExternalModel< |
223 | ReinterpretCastOpInterface, ReinterpretCastOp> { |
224 | void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
225 | Location loc) const { |
226 | auto reinterpretCast = cast<ReinterpretCastOp>(op); |
227 | auto baseMemref = reinterpretCast.getSource(); |
228 | auto resultMemref = |
229 | cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult()); |
230 | |
231 | builder.setInsertionPointAfter(op); |
232 | |
233 | // Compute the linear bounds of the base memref |
234 | auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); |
235 | |
236 | // Compute the linear bounds of the resulting memref |
237 | auto [low, high] = computeLinearBounds(builder, loc, resultMemref); |
238 | |
239 | // Check low >= baseLow |
240 | auto geLow = builder.createOrFold<arith::CmpIOp>( |
241 | loc, arith::CmpIPredicate::sge, low, baseLow); |
242 | |
243 | // Check high <= baseHigh |
244 | auto leHigh = builder.createOrFold<arith::CmpIOp>( |
245 | loc, arith::CmpIPredicate::sle, high, baseHigh); |
246 | |
247 | auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh); |
248 | |
249 | builder.create<cf::AssertOp>( |
250 | loc, assertCond, |
251 | generateErrorMessage( |
252 | op, |
253 | "result of reinterpret_cast is out-of-bounds of the base memref" )); |
254 | } |
255 | }; |
256 | |
257 | /// Verifies that the linear bounds of a subview op are within the linear bounds |
258 | /// of the base memref: low >= baseLow && high <= baseHigh |
259 | /// TODO: This is not yet a full runtime verification of subview. For example, |
260 | /// consider: |
261 | /// %m = memref.alloc(%c10, %c10) : memref<10x10xf32> |
262 | /// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1] |
263 | /// : memref<?x?xf32> to memref<?x?xf32> |
264 | /// The subview is in-bounds of the entire base memref but the first dimension |
265 | /// is out-of-bounds. Future work would verify the bounds on a per-dimension |
266 | /// basis. |
267 | struct SubViewOpInterface |
268 | : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface, |
269 | SubViewOp> { |
270 | void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
271 | Location loc) const { |
272 | auto subView = cast<SubViewOp>(op); |
273 | auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource()); |
274 | auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult()); |
275 | |
276 | builder.setInsertionPointAfter(op); |
277 | |
278 | // Compute the linear bounds of the base memref |
279 | auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); |
280 | |
281 | // Compute the linear bounds of the resulting memref |
282 | auto [low, high] = computeLinearBounds(builder, loc, resultMemref); |
283 | |
284 | // Check low >= baseLow |
285 | auto geLow = builder.createOrFold<arith::CmpIOp>( |
286 | loc, arith::CmpIPredicate::sge, low, baseLow); |
287 | |
288 | // Check high <= baseHigh |
289 | auto leHigh = builder.createOrFold<arith::CmpIOp>( |
290 | loc, arith::CmpIPredicate::sle, high, baseHigh); |
291 | |
292 | auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh); |
293 | |
294 | builder.create<cf::AssertOp>( |
295 | loc, assertCond, |
296 | generateErrorMessage(op, |
297 | "subview is out-of-bounds of the base memref" )); |
298 | } |
299 | }; |
300 | |
301 | struct ExpandShapeOpInterface |
302 | : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface, |
303 | ExpandShapeOp> { |
304 | void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
305 | Location loc) const { |
306 | auto expandShapeOp = cast<ExpandShapeOp>(op); |
307 | |
308 | // Verify that the expanded dim sizes are a product of the collapsed dim |
309 | // size. |
310 | for (const auto &it : |
311 | llvm::enumerate(expandShapeOp.getReassociationIndices())) { |
312 | Value srcDimSz = |
313 | builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index()); |
314 | int64_t groupSz = 1; |
315 | bool foundDynamicDim = false; |
316 | for (int64_t resultDim : it.value()) { |
317 | if (expandShapeOp.getResultType().isDynamicDim(resultDim)) { |
318 | // Keep this assert here in case the op is extended in the future. |
319 | assert(!foundDynamicDim && |
320 | "more than one dynamic dim found in reassoc group" ); |
321 | (void)foundDynamicDim; |
322 | foundDynamicDim = true; |
323 | continue; |
324 | } |
325 | groupSz *= expandShapeOp.getResultType().getDimSize(resultDim); |
326 | } |
327 | Value staticResultDimSz = |
328 | builder.create<arith::ConstantIndexOp>(loc, groupSz); |
329 | // staticResultDimSz must divide srcDimSz evenly. |
330 | Value mod = |
331 | builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz); |
332 | Value isModZero = builder.create<arith::CmpIOp>( |
333 | loc, arith::CmpIPredicate::eq, mod, |
334 | builder.create<arith::ConstantIndexOp>(loc, 0)); |
335 | builder.create<cf::AssertOp>( |
336 | loc, isModZero, |
337 | generateErrorMessage(op, "static result dims in reassoc group do not " |
338 | "divide src dim evenly" )); |
339 | } |
340 | } |
341 | }; |
342 | } // namespace |
343 | } // namespace memref |
344 | } // namespace mlir |
345 | |
346 | void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( |
347 | DialectRegistry ®istry) { |
348 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) { |
349 | CastOp::attachInterface<CastOpInterface>(*ctx); |
350 | ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); |
351 | LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx); |
352 | ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx); |
353 | StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx); |
354 | SubViewOp::attachInterface<SubViewOpInterface>(*ctx); |
355 | |
356 | // Load additional dialects of which ops may get created. |
357 | ctx->loadDialect<affine::AffineDialect, arith::ArithDialect, |
358 | cf::ControlFlowDialect>(); |
359 | }); |
360 | } |
361 | |