1//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18#include "mlir/Dialect/Utils/IndexingUtils.h"
19#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
20
21using namespace mlir;
22
23/// Generate an error message string for the given op and the specified error.
24static std::string generateErrorMessage(Operation *op, const std::string &msg) {
25 std::string buffer;
26 llvm::raw_string_ostream stream(buffer);
27 OpPrintingFlags flags;
28 // We may generate a lot of error messages and so we need to ensure the
29 // printing is fast.
30 flags.elideLargeElementsAttrs();
31 flags.printGenericOpForm();
32 flags.skipRegions();
33 flags.useLocalScope();
34 stream << "ERROR: Runtime op verification failed\n";
35 op->print(os&: stream, flags);
36 stream << "\n^ " << msg;
37 stream << "\nLocation: ";
38 op->getLoc().print(os&: stream);
39 return stream.str();
40}
41
42namespace mlir {
43namespace memref {
44namespace {
45struct CastOpInterface
46 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
47 CastOp> {
48 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
49 Location loc) const {
50 auto castOp = cast<CastOp>(op);
51 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
52
53 // Nothing to check if the result is an unranked memref.
54 auto resultType = dyn_cast<MemRefType>(castOp.getType());
55 if (!resultType)
56 return;
57
58 if (isa<UnrankedMemRefType>(srcType)) {
59 // Check rank.
60 Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
61 Value resultRank =
62 builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
63 Value isSameRank = builder.create<arith::CmpIOp>(
64 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
65 builder.create<cf::AssertOp>(loc, isSameRank,
66 generateErrorMessage(op, "rank mismatch"));
67 }
68
69 // Get source offset and strides. We do not have an op to get offsets and
70 // strides from unranked memrefs, so cast the source to a type with fully
71 // dynamic layout, from which we can then extract the offset and strides.
72 // (Rank was already verified.)
73 int64_t dynamicOffset = ShapedType::kDynamic;
74 SmallVector<int64_t> dynamicShape(resultType.getRank(),
75 ShapedType::kDynamic);
76 auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
77 dynamicOffset, dynamicShape);
78 auto dynStridesType =
79 MemRefType::get(dynamicShape, resultType.getElementType(),
80 stridedLayout, resultType.getMemorySpace());
81 Value helperCast =
82 builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
83 auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
84
85 // Check dimension sizes.
86 for (const auto &it : llvm::enumerate(resultType.getShape())) {
87 // Static dim size -> static/dynamic dim size does not need verification.
88 if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
89 if (!rankedSrcType.isDynamicDim(it.index()))
90 continue;
91
92 // Static/dynamic dim size -> dynamic dim size does not need verification.
93 if (resultType.isDynamicDim(it.index()))
94 continue;
95
96 Value srcDimSz =
97 builder.create<DimOp>(loc, castOp.getSource(), it.index());
98 Value resultDimSz =
99 builder.create<arith::ConstantIndexOp>(loc, it.value());
100 Value isSameSz = builder.create<arith::CmpIOp>(
101 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
102 builder.create<cf::AssertOp>(
103 loc, isSameSz,
104 generateErrorMessage(op, "size mismatch of dim " +
105 std::to_string(it.index())));
106 }
107
108 // Get result offset and strides.
109 int64_t resultOffset;
110 SmallVector<int64_t> resultStrides;
111 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
112 return;
113
114 // Check offset.
115 if (resultOffset != ShapedType::kDynamic) {
116 // Static/dynamic offset -> dynamic offset does not need verification.
117 Value srcOffset = metadataOp.getResult(1);
118 Value resultOffsetVal =
119 builder.create<arith::ConstantIndexOp>(location: loc, args&: resultOffset);
120 Value isSameOffset = builder.create<arith::CmpIOp>(
121 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
122 builder.create<cf::AssertOp>(loc, isSameOffset,
123 generateErrorMessage(op, "offset mismatch"));
124 }
125
126 // Check strides.
127 for (const auto &it : llvm::enumerate(First&: resultStrides)) {
128 // Static/dynamic stride -> dynamic stride does not need verification.
129 if (it.value() == ShapedType::kDynamic)
130 continue;
131
132 Value srcStride =
133 metadataOp.getResult(2 + resultType.getRank() + it.index());
134 Value resultStrideVal =
135 builder.create<arith::ConstantIndexOp>(location: loc, args&: it.value());
136 Value isSameStride = builder.create<arith::CmpIOp>(
137 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
138 builder.create<cf::AssertOp>(
139 loc, isSameStride,
140 generateErrorMessage(op, "stride mismatch of dim " +
141 std::to_string(it.index())));
142 }
143 }
144};
145
146/// Verifies that the indices on load/store ops are in-bounds of the memref's
147/// index space: 0 <= index#i < dim#i
148template <typename LoadStoreOp>
149struct LoadStoreOpInterface
150 : public RuntimeVerifiableOpInterface::ExternalModel<
151 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
152 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
153 Location loc) const {
154 auto loadStoreOp = cast<LoadStoreOp>(op);
155
156 auto memref = loadStoreOp.getMemref();
157 auto rank = memref.getType().getRank();
158 if (rank == 0) {
159 return;
160 }
161 auto indices = loadStoreOp.getIndices();
162
163 auto zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0);
164 Value assertCond;
165 for (auto i : llvm::seq<int64_t>(0, rank)) {
166 auto index = indices[i];
167
168 auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
169
170 auto geLow = builder.createOrFold<arith::CmpIOp>(
171 loc, arith::CmpIPredicate::sge, index, zero);
172 auto ltHigh = builder.createOrFold<arith::CmpIOp>(
173 loc, arith::CmpIPredicate::slt, index, dimOp);
174 auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
175
176 assertCond =
177 i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
178 : andOp;
179 }
180 builder.create<cf::AssertOp>(
181 loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
182 }
183};
184
185/// Compute the linear index for the provided strided layout and indices.
186Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
187 ArrayRef<OpFoldResult> strides,
188 ArrayRef<OpFoldResult> indices) {
189 auto [expr, values] = computeLinearIndex(sourceOffset: offset, strides, indices);
190 auto index =
191 affine::makeComposedFoldedAffineApply(b&: builder, loc, expr, operands: values);
192 return getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: index);
193}
194
195/// Returns two Values representing the bounds of the provided strided layout
196/// metadata. The bounds are returned as a half open interval -- [low, high).
197std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
198 OpFoldResult offset,
199 ArrayRef<OpFoldResult> strides,
200 ArrayRef<OpFoldResult> sizes) {
201 auto zeros = SmallVector<int64_t>(sizes.size(), 0);
202 auto indices = getAsIndexOpFoldResult(ctx: builder.getContext(), values: zeros);
203 auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
204 auto upperBound = computeLinearIndex(builder, loc, offset, strides, indices: sizes);
205 return {lowerBound, upperBound};
206}
207
208/// Returns two Values representing the bounds of the memref. The bounds are
209/// returned as a half open interval -- [low, high).
210std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
211 TypedValue<BaseMemRefType> memref) {
212 auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
213 auto offset = runtimeMetadata.getConstifiedMixedOffset();
214 auto strides = runtimeMetadata.getConstifiedMixedStrides();
215 auto sizes = runtimeMetadata.getConstifiedMixedSizes();
216 return computeLinearBounds(builder, loc, offset, strides, sizes);
217}
218
219/// Verifies that the linear bounds of a reinterpret_cast op are within the
220/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
221struct ReinterpretCastOpInterface
222 : public RuntimeVerifiableOpInterface::ExternalModel<
223 ReinterpretCastOpInterface, ReinterpretCastOp> {
224 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
225 Location loc) const {
226 auto reinterpretCast = cast<ReinterpretCastOp>(op);
227 auto baseMemref = reinterpretCast.getSource();
228 auto resultMemref =
229 cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
230
231 builder.setInsertionPointAfter(op);
232
233 // Compute the linear bounds of the base memref
234 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
235
236 // Compute the linear bounds of the resulting memref
237 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
238
239 // Check low >= baseLow
240 auto geLow = builder.createOrFold<arith::CmpIOp>(
241 loc, arith::CmpIPredicate::sge, low, baseLow);
242
243 // Check high <= baseHigh
244 auto leHigh = builder.createOrFold<arith::CmpIOp>(
245 loc, arith::CmpIPredicate::sle, high, baseHigh);
246
247 auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
248
249 builder.create<cf::AssertOp>(
250 loc, assertCond,
251 generateErrorMessage(
252 op,
253 "result of reinterpret_cast is out-of-bounds of the base memref"));
254 }
255};
256
257/// Verifies that the linear bounds of a subview op are within the linear bounds
258/// of the base memref: low >= baseLow && high <= baseHigh
259/// TODO: This is not yet a full runtime verification of subview. For example,
260/// consider:
261/// %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
262/// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
263/// : memref<?x?xf32> to memref<?x?xf32>
264/// The subview is in-bounds of the entire base memref but the first dimension
265/// is out-of-bounds. Future work would verify the bounds on a per-dimension
266/// basis.
267struct SubViewOpInterface
268 : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
269 SubViewOp> {
270 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
271 Location loc) const {
272 auto subView = cast<SubViewOp>(op);
273 auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
274 auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
275
276 builder.setInsertionPointAfter(op);
277
278 // Compute the linear bounds of the base memref
279 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
280
281 // Compute the linear bounds of the resulting memref
282 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
283
284 // Check low >= baseLow
285 auto geLow = builder.createOrFold<arith::CmpIOp>(
286 loc, arith::CmpIPredicate::sge, low, baseLow);
287
288 // Check high <= baseHigh
289 auto leHigh = builder.createOrFold<arith::CmpIOp>(
290 loc, arith::CmpIPredicate::sle, high, baseHigh);
291
292 auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
293
294 builder.create<cf::AssertOp>(
295 loc, assertCond,
296 generateErrorMessage(op,
297 "subview is out-of-bounds of the base memref"));
298 }
299};
300
301struct ExpandShapeOpInterface
302 : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
303 ExpandShapeOp> {
304 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
305 Location loc) const {
306 auto expandShapeOp = cast<ExpandShapeOp>(op);
307
308 // Verify that the expanded dim sizes are a product of the collapsed dim
309 // size.
310 for (const auto &it :
311 llvm::enumerate(expandShapeOp.getReassociationIndices())) {
312 Value srcDimSz =
313 builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
314 int64_t groupSz = 1;
315 bool foundDynamicDim = false;
316 for (int64_t resultDim : it.value()) {
317 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
318 // Keep this assert here in case the op is extended in the future.
319 assert(!foundDynamicDim &&
320 "more than one dynamic dim found in reassoc group");
321 (void)foundDynamicDim;
322 foundDynamicDim = true;
323 continue;
324 }
325 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
326 }
327 Value staticResultDimSz =
328 builder.create<arith::ConstantIndexOp>(loc, groupSz);
329 // staticResultDimSz must divide srcDimSz evenly.
330 Value mod =
331 builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
332 Value isModZero = builder.create<arith::CmpIOp>(
333 loc, arith::CmpIPredicate::eq, mod,
334 builder.create<arith::ConstantIndexOp>(loc, 0));
335 builder.create<cf::AssertOp>(
336 loc, isModZero,
337 generateErrorMessage(op, "static result dims in reassoc group do not "
338 "divide src dim evenly"));
339 }
340 }
341};
342} // namespace
343} // namespace memref
344} // namespace mlir
345
346void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
347 DialectRegistry &registry) {
348 registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
349 CastOp::attachInterface<CastOpInterface>(*ctx);
350 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
351 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
352 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
353 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
354 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
355
356 // Load additional dialects of which ops may get created.
357 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
358 cf::ControlFlowDialect>();
359 });
360}
361

// Source: mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp