1//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15#include "mlir/Dialect/Index/IR/IndexAttrs.h"
16#include "mlir/Dialect/Index/IR/IndexDialect.h"
17#include "mlir/Dialect/Index/IR/IndexOps.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/Tensor/IR/Tensor.h"
21#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
22
23namespace mlir {
24namespace linalg {
25namespace {
26/// Verify that the runtime sizes of the operands to linalg structured ops are
27/// compatible with the runtime sizes inferred by composing the loop ranges with
28/// the linalg op's indexing maps. This is similar to the verifier except that
29/// here we insert IR to perform the verification at runtime.
30template <typename T>
31struct StructuredOpInterface
32 : public RuntimeVerifiableOpInterface::ExternalModel<
33 StructuredOpInterface<T>, T> {
34 void generateRuntimeVerification(Operation *op, OpBuilder &builder,
35 Location loc) const {
36 auto linalgOp = llvm::cast<LinalgOp>(op);
37
38 SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
39 auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
40
41 auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
42 auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
43
44 // Subtract one from the loop ends before composing with the indexing map
45 transform(ends, ends.begin(), [&](OpFoldResult end) {
46 auto endValue = getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: end);
47 return builder.createOrFold<index::SubOp>(loc, endValue, one);
48 });
49
50 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
51 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
52 auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
53 builder, loc, indexingMap, starts);
54 auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
55 builder, loc, indexingMap, ends);
56
57 for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
58 auto startIndex =
59 getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
60 auto endIndex =
61 getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
62
63 // Generate:
64 // minIndex = min(startIndex, endIndex)
65 // assert(minIndex >= 0)
66 // To ensure we do not generate a negative index. We take the minimum of
67 // the start and end indices in order to handle reverse loops such as
68 // `affine_map<(i) -> (3 - i)>`
69 auto min =
70 builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
71 auto cmpOp = builder.createOrFold<index::CmpOp>(
72 loc, index::IndexCmpPredicate::SGE, min, zero);
73 auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
74 linalgOp, "unexpected negative result on dimension #" +
75 std::to_string(dim) + " of input/output operand #" +
76 std::to_string(opOperand.getOperandNumber()));
77 builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
78
79 // Generate:
80 // inferredDimSize = max(startIndex, endIndex) + 1
81 // actualDimSize = dim(operand)
82 // assert(inferredDimSize <= actualDimSize)
83 // To ensure that we do not index past the bounds of the operands.
84 auto max =
85 builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
86
87 auto inferredDimSize =
88 builder.createOrFold<index::AddOp>(loc, max, one);
89
90 auto actualDimSize =
91 createOrFoldDimOp(builder, loc, opOperand.get(), dim);
92
93 // Similar to the verifier, when the affine expression in the indexing
94 // map is complicated, we just check that the inferred dimension sizes
95 // are in the boundary of the operands' size. Being more precise than
96 // that is difficult.
97 auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
98 ? index::IndexCmpPredicate::EQ
99 : index::IndexCmpPredicate::SLE;
100
101 cmpOp = builder.createOrFold<index::CmpOp>(
102 loc, predicate, inferredDimSize, actualDimSize);
103 msg = RuntimeVerifiableOpInterface::generateErrorMessage(
104 linalgOp, "dimension #" + std::to_string(dim) +
105 " of input/output operand #" +
106 std::to_string(opOperand.getOperandNumber()) +
107 " is incompatible with inferred dimension size");
108 builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
109 }
110 }
111 }
112};
113
114template <typename... OpTs>
115void attachInterface(MLIRContext *ctx) {
116 (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
117}
118} // namespace
119} // namespace linalg
120} // namespace mlir
121
122void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
123 DialectRegistry &registry) {
124 registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
125 attachInterface<
126#define GET_OP_LIST
127#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
128 >(ctx);
129
130 // Load additional dialects of which ops may get created.
131 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
132 cf::ControlFlowDialect, index::IndexDialect,
133 tensor::TensorDialect>();
134 });
135}
136

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp