1//===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/ValueBoundsOpInterface.h"
10
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/IR/Matchers.h"
13#include "mlir/Interfaces/DestinationStyleOpInterface.h"
14#include "mlir/Interfaces/ViewLikeInterface.h"
15#include "llvm/ADT/APSInt.h"
16#include "llvm/Support/Debug.h"
17
18#define DEBUG_TYPE "value-bounds-op-interface"
19
20using namespace mlir;
21using presburger::BoundType;
22using presburger::VarKind;
23
24namespace mlir {
25#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
26} // namespace mlir
27
28static Operation *getOwnerOfValue(Value value) {
29 if (auto bbArg = dyn_cast<BlockArgument>(value))
30 return bbArg.getOwner()->getParentOp();
31 return value.getDefiningOp();
32}
33
34HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
35 ArrayRef<OpFoldResult> sizes,
36 ArrayRef<OpFoldResult> strides)
37 : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
38 assert(offsets.size() == sizes.size() &&
39 "expected same number of offsets, sizes, strides");
40 assert(offsets.size() == strides.size() &&
41 "expected same number of offsets, sizes, strides");
42}
43
44HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
45 ArrayRef<OpFoldResult> sizes)
46 : mixedOffsets(offsets), mixedSizes(sizes) {
47 assert(offsets.size() == sizes.size() &&
48 "expected same number of offsets and sizes");
49 // Assume that all strides are 1.
50 if (offsets.empty())
51 return;
52 MLIRContext *ctx = offsets.front().getContext();
53 mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
54}
55
56HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
57 : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
58 op.getMixedStrides()) {}
59
60/// If ofr is a constant integer or an IntegerAttr, return the integer.
61static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
62 // Case 1: Check for Constant integer.
63 if (auto val = llvm::dyn_cast_if_present<Value>(Val&: ofr)) {
64 APSInt intVal;
65 if (matchPattern(val, m_ConstantInt(&intVal)))
66 return intVal.getSExtValue();
67 return std::nullopt;
68 }
69 // Case 2: Check for IntegerAttr.
70 Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: ofr);
71 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
72 return intAttr.getValue().getSExtValue();
73 return std::nullopt;
74}
75
76ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
77 : Variable(ofr, std::nullopt) {}
78
79ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
80 : Variable(static_cast<OpFoldResult>(indexValue)) {}
81
82ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
83 : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
84
85ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
86 std::optional<int64_t> dim) {
87 Builder b(ofr.getContext());
88 if (auto constInt = ::getConstantIntValue(ofr)) {
89 assert(!dim && "expected no dim for index-typed values");
90 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
91 result: b.getAffineConstantExpr(constant: *constInt));
92 return;
93 }
94 Value value = cast<Value>(Val&: ofr);
95#ifndef NDEBUG
96 if (dim) {
97 assert(isa<ShapedType>(value.getType()) && "expected shaped type");
98 } else {
99 assert(value.getType().isIndex() && "expected index type");
100 }
101#endif // NDEBUG
102 map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
103 result: b.getAffineSymbolExpr(position: 0));
104 mapOperands.emplace_back(Args&: value, Args&: dim);
105}
106
107ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
108 ArrayRef<Variable> mapOperands) {
109 assert(map.getNumResults() == 1 && "expected single result");
110
111 // Turn all dims into symbols.
112 Builder b(map.getContext());
113 SmallVector<AffineExpr> dimReplacements, symReplacements;
114 for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
115 dimReplacements.push_back(Elt: b.getAffineSymbolExpr(position: i));
116 for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i)
117 symReplacements.push_back(Elt: b.getAffineSymbolExpr(position: i + map.getNumDims()));
118 AffineMap tmpMap = map.replaceDimsAndSymbols(
119 dimReplacements, symReplacements, /*numResultDims=*/0,
120 /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
121
122 // Inline operands.
123 DenseMap<AffineExpr, AffineExpr> replacements;
124 for (auto [index, var] : llvm::enumerate(First&: mapOperands)) {
125 assert(var.map.getNumResults() == 1 && "expected single result");
126 assert(var.map.getNumDims() == 0 && "expected only symbols");
127 SmallVector<AffineExpr> symReplacements;
128 for (auto valueDim : var.mapOperands) {
129 auto it = llvm::find(Range&: this->mapOperands, Val: valueDim);
130 if (it != this->mapOperands.end()) {
131 // There is already a symbol for this operand.
132 symReplacements.push_back(Elt: b.getAffineSymbolExpr(
133 position: std::distance(first: this->mapOperands.begin(), last: it)));
134 } else {
135 // This is a new operand: add a new symbol.
136 symReplacements.push_back(
137 Elt: b.getAffineSymbolExpr(position: this->mapOperands.size()));
138 this->mapOperands.push_back(Elt: valueDim);
139 }
140 }
141 replacements[b.getAffineSymbolExpr(position: index)] =
142 var.map.getResult(idx: 0).replaceSymbols(symReplacements);
143 }
144 this->map = tmpMap.replace(map: replacements, /*numResultDims=*/0,
145 /*numResultSyms=*/this->mapOperands.size());
146}
147
148ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
149 ArrayRef<Value> mapOperands)
150 : Variable(map, llvm::map_to_vector(C&: mapOperands,
151 F: [](Value v) { return Variable(v); })) {}
152
153ValueBoundsConstraintSet::ValueBoundsConstraintSet(
154 MLIRContext *ctx, StopConditionFn stopCondition)
155 : builder(ctx), stopCondition(stopCondition) {
156 assert(stopCondition && "expected non-null stop condition");
157}
158
159char ValueBoundsConstraintSet::ID = 0;
160
161#ifndef NDEBUG
162static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
163 if (value.getType().isIndex()) {
164 assert(!dim.has_value() && "invalid dim value");
165 } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) {
166 assert(*dim >= 0 && "invalid dim value");
167 if (shapedType.hasRank())
168 assert(*dim < shapedType.getRank() && "invalid dim value");
169 } else {
170 llvm_unreachable("unsupported type");
171 }
172}
173#endif // NDEBUG
174
175void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos,
176 AffineExpr expr) {
177 LogicalResult status = cstr.addBound(
178 type, pos,
179 boundMap: AffineMap::get(dimCount: cstr.getNumDimVars(), symbolCount: cstr.getNumSymbolVars(), result: expr));
180 if (failed(result: status)) {
181 // Non-pure (e.g., semi-affine) expressions are not yet supported by
182 // FlatLinearConstraints. However, we can just ignore such failures here.
183 // Even without this bound, there may be enough information in the
184 // constraint system to compute the requested bound. In case this bound is
185 // actually needed, `computeBound` will return `failure`.
186 LLVM_DEBUG(llvm::dbgs() << "Failed to add bound: " << expr << "\n");
187 }
188}
189
190AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
191 std::optional<int64_t> dim) {
192#ifndef NDEBUG
193 assertValidValueDim(value, dim);
194#endif // NDEBUG
195
196 // Check if the value/dim is statically known. In that case, an affine
197 // constant expression should be returned. This allows us to support
198 // multiplications with constants. (Multiplications of two columns in the
199 // constraint set is not supported.)
200 std::optional<int64_t> constSize = std::nullopt;
201 auto shapedType = dyn_cast<ShapedType>(value.getType());
202 if (shapedType) {
203 if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
204 constSize = shapedType.getDimSize(*dim);
205 } else if (auto constInt = ::getConstantIntValue(ofr: value)) {
206 constSize = *constInt;
207 }
208
209 // If the value/dim is already mapped, return the corresponding expression
210 // directly.
211 ValueDim valueDim = std::make_pair(x&: value, y: dim.value_or(u: kIndexValue));
212 if (valueDimToPosition.contains(Val: valueDim)) {
213 // If it is a constant, return an affine constant expression. Otherwise,
214 // return an affine expression that represents the respective column in the
215 // constraint set.
216 if (constSize)
217 return builder.getAffineConstantExpr(constant: *constSize);
218 return getPosExpr(pos: getPos(value, dim));
219 }
220
221 if (constSize) {
222 // Constant index value/dim: add column to the constraint set, add EQ bound
223 // and return an affine constant expression without pushing the newly added
224 // column to the worklist.
225 (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
226 if (shapedType)
227 bound(value)[*dim] == *constSize;
228 else
229 bound(value) == *constSize;
230 return builder.getAffineConstantExpr(constant: *constSize);
231 }
232
233 // Dynamic value/dim: insert column to the constraint set and put it on the
234 // worklist. Return an affine expression that represents the newly inserted
235 // column in the constraint set.
236 return getPosExpr(pos: insert(value, dim, /*isSymbol=*/true));
237}
238
239AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
240 if (Value value = llvm::dyn_cast_if_present<Value>(Val&: ofr))
241 return getExpr(value, /*dim=*/std::nullopt);
242 auto constInt = ::getConstantIntValue(ofr);
243 assert(constInt.has_value() && "expected Integer constant");
244 return builder.getAffineConstantExpr(constant: *constInt);
245}
246
247AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
248 return builder.getAffineConstantExpr(constant);
249}
250
251int64_t ValueBoundsConstraintSet::insert(Value value,
252 std::optional<int64_t> dim,
253 bool isSymbol, bool addToWorklist) {
254#ifndef NDEBUG
255 assertValidValueDim(value, dim);
256#endif // NDEBUG
257
258 ValueDim valueDim = std::make_pair(x&: value, y: dim.value_or(u: kIndexValue));
259 assert(!valueDimToPosition.contains(valueDim) && "already mapped");
260 int64_t pos = isSymbol ? cstr.appendVar(kind: VarKind::Symbol)
261 : cstr.appendVar(kind: VarKind::SetDim);
262 LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
263 << " for: " << value
264 << " (dim: " << dim.value_or(kIndexValue)
265 << ", owner: " << getOwnerOfValue(value)->getName()
266 << ")\n");
267 positionToValueDim.insert(I: positionToValueDim.begin() + pos, Elt: valueDim);
268 // Update reverse mapping.
269 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
270 if (positionToValueDim[i].has_value())
271 valueDimToPosition[*positionToValueDim[i]] = i;
272
273 if (addToWorklist) {
274 LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
275 << " (dim: " << dim.value_or(kIndexValue) << ")\n");
276 worklist.push(x: pos);
277 }
278
279 return pos;
280}
281
282int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
283 int64_t pos = isSymbol ? cstr.appendVar(kind: VarKind::Symbol)
284 : cstr.appendVar(kind: VarKind::SetDim);
285 LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
286 << "\n");
287 positionToValueDim.insert(I: positionToValueDim.begin() + pos, Elt: std::nullopt);
288 // Update reverse mapping.
289 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
290 if (positionToValueDim[i].has_value())
291 valueDimToPosition[*positionToValueDim[i]] = i;
292 return pos;
293}
294
295int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
296 bool isSymbol) {
297 assert(map.getNumResults() == 1 && "expected affine map with one result");
298 int64_t pos = insert(isSymbol);
299
300 // Add map and operands to the constraint set. Dimensions are converted to
301 // symbols. All operands are added to the worklist (unless they were already
302 // processed).
303 auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
304 return getExpr(value: v.first, dim: v.second);
305 };
306 SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
307 Range: llvm::map_range(C: ArrayRef(operands).take_front(N: map.getNumDims()), F: mapper));
308 SmallVector<AffineExpr> symReplacements = llvm::to_vector(
309 Range: llvm::map_range(C: ArrayRef(operands).drop_front(N: map.getNumDims()), F: mapper));
310 addBound(
311 type: presburger::BoundType::EQ, pos,
312 expr: map.getResult(idx: 0).replaceDimsAndSymbols(dimReplacements, symReplacements));
313
314 return pos;
315}
316
317int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
318 return insert(map: var.map, operands: var.mapOperands, isSymbol);
319}
320
321int64_t ValueBoundsConstraintSet::getPos(Value value,
322 std::optional<int64_t> dim) const {
323#ifndef NDEBUG
324 assertValidValueDim(value, dim);
325 assert((isa<OpResult>(value) ||
326 cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
327 "unstructured control flow is not supported");
328#endif // NDEBUG
329 LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
330 << " (dim: " << dim.value_or(kIndexValue)
331 << ", owner: " << getOwnerOfValue(value)->getName()
332 << ")\n");
333 auto it =
334 valueDimToPosition.find(Val: std::make_pair(x&: value, y: dim.value_or(u: kIndexValue)));
335 assert(it != valueDimToPosition.end() && "expected mapped entry");
336 return it->second;
337}
338
339AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
340 assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
341 return pos < cstr.getNumDimVars()
342 ? builder.getAffineDimExpr(position: pos)
343 : builder.getAffineSymbolExpr(position: pos - cstr.getNumDimVars());
344}
345
346bool ValueBoundsConstraintSet::isMapped(Value value,
347 std::optional<int64_t> dim) const {
348 auto it =
349 valueDimToPosition.find(Val: std::make_pair(x&: value, y: dim.value_or(u: kIndexValue)));
350 return it != valueDimToPosition.end();
351}
352
353void ValueBoundsConstraintSet::processWorklist() {
354 LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
355 while (!worklist.empty()) {
356 int64_t pos = worklist.front();
357 worklist.pop();
358 assert(positionToValueDim[pos].has_value() &&
359 "did not expect std::nullopt on worklist");
360 ValueDim valueDim = *positionToValueDim[pos];
361 Value value = valueDim.first;
362 int64_t dim = valueDim.second;
363
364 // Check for static dim size.
365 if (dim != kIndexValue) {
366 auto shapedType = cast<ShapedType>(value.getType());
367 if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) {
368 bound(value)[dim] == getExpr(shapedType.getDimSize(dim));
369 continue;
370 }
371 }
372
373 // Do not process any further if the stop condition is met.
374 auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(t&: dim);
375 if (stopCondition(value, maybeDim, *this)) {
376 LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
377 << " (dim: " << maybeDim << ")\n");
378 continue;
379 }
380
381 // Query `ValueBoundsOpInterface` for constraints. New items may be added to
382 // the worklist.
383 auto valueBoundsOp =
384 dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
385 LLVM_DEBUG(llvm::dbgs()
386 << "Query value bounds for: " << value
387 << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
388 if (valueBoundsOp) {
389 if (dim == kIndexValue) {
390 valueBoundsOp.populateBoundsForIndexValue(value, *this);
391 } else {
392 valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
393 }
394 continue;
395 }
396 LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
397
398 // If the op does not implement `ValueBoundsOpInterface`, check if it
399 // implements the `DestinationStyleOpInterface`. OpResults of such ops are
400 // tied to OpOperands. Tied values have the same shape.
401 auto dstOp = value.getDefiningOp<DestinationStyleOpInterface>();
402 if (!dstOp || dim == kIndexValue)
403 continue;
404 Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(Val&: value))->get();
405 bound(value)[dim] == getExpr(value: tiedOperand, dim);
406 }
407}
408
409void ValueBoundsConstraintSet::projectOut(int64_t pos) {
410 assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
411 "invalid position");
412 cstr.projectOut(pos);
413 if (positionToValueDim[pos].has_value()) {
414 bool erased = valueDimToPosition.erase(Val: *positionToValueDim[pos]);
415 (void)erased;
416 assert(erased && "inconsistent reverse mapping");
417 }
418 positionToValueDim.erase(CI: positionToValueDim.begin() + pos);
419 // Update reverse mapping.
420 for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
421 if (positionToValueDim[i].has_value())
422 valueDimToPosition[*positionToValueDim[i]] = i;
423}
424
425void ValueBoundsConstraintSet::projectOut(
426 function_ref<bool(ValueDim)> condition) {
427 int64_t nextPos = 0;
428 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
429 if (positionToValueDim[nextPos].has_value() &&
430 condition(*positionToValueDim[nextPos])) {
431 projectOut(pos: nextPos);
432 // The column was projected out so another column is now at that position.
433 // Do not increase the counter.
434 } else {
435 ++nextPos;
436 }
437 }
438}
439
440void ValueBoundsConstraintSet::projectOutAnonymous(
441 std::optional<int64_t> except) {
442 int64_t nextPos = 0;
443 while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
444 if (positionToValueDim[nextPos].has_value() || except == nextPos) {
445 ++nextPos;
446 } else {
447 projectOut(pos: nextPos);
448 // The column was projected out so another column is now at that position.
449 // Do not increase the counter.
450 }
451 }
452}
453
454LogicalResult ValueBoundsConstraintSet::computeBound(
455 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
456 const Variable &var, StopConditionFn stopCondition, bool closedUB) {
457 MLIRContext *ctx = var.getContext();
458 int64_t ubAdjustment = closedUB ? 0 : 1;
459 Builder b(ctx);
460 mapOperands.clear();
461
462 // Process the backward slice of `value` (i.e., reverse use-def chain) until
463 // `stopCondition` is met.
464 ValueBoundsConstraintSet cstr(ctx, stopCondition);
465 int64_t pos = cstr.insert(var, /*isSymbol=*/false);
466 assert(pos == 0 && "expected first column");
467 cstr.processWorklist();
468
469 // Project out all variables (apart from `valueDim`) that do not match the
470 // stop condition.
471 cstr.projectOut(condition: [&](ValueDim p) {
472 auto maybeDim =
473 p.second == kIndexValue ? std::nullopt : std::make_optional(t&: p.second);
474 return !stopCondition(p.first, maybeDim, cstr);
475 });
476 cstr.projectOutAnonymous(/*except=*/pos);
477
478 // Compute lower and upper bounds for `valueDim`.
479 SmallVector<AffineMap> lb(1), ub(1);
480 cstr.cstr.getSliceBounds(offset: pos, num: 1, context: ctx, lbMaps: &lb, ubMaps: &ub,
481 /*closedUB=*/true);
482
483 // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
484 // case, no lower/upper bound can be computed at the moment.
485 // EQ, UB bounds: upper bound is needed.
486 if ((type != BoundType::LB) &&
487 (ub.empty() || !ub[0] || ub[0].getNumResults() == 0))
488 return failure();
489 // EQ, LB bounds: lower bound is needed.
490 if ((type != BoundType::UB) &&
491 (lb.empty() || !lb[0] || lb[0].getNumResults() == 0))
492 return failure();
493
494 // TODO: Generate an affine map with multiple results.
495 if (type != BoundType::LB)
496 assert(ub.size() == 1 && ub[0].getNumResults() == 1 &&
497 "multiple bounds not supported");
498 if (type != BoundType::UB)
499 assert(lb.size() == 1 && lb[0].getNumResults() == 1 &&
500 "multiple bounds not supported");
501
502 // EQ bound: lower and upper bound must match.
503 if (type == BoundType::EQ && ub[0] != lb[0])
504 return failure();
505
506 AffineMap bound;
507 if (type == BoundType::EQ || type == BoundType::LB) {
508 bound = lb[0];
509 } else {
510 // Computed UB is a closed bound.
511 bound = AffineMap::get(dimCount: ub[0].getNumDims(), symbolCount: ub[0].getNumSymbols(),
512 result: ub[0].getResult(idx: 0) + ubAdjustment);
513 }
514
515 // Gather all SSA values that are used in the computed bound.
516 assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() &&
517 "inconsistent mapping state");
518 SmallVector<AffineExpr> replacementDims, replacementSymbols;
519 int64_t numDims = 0, numSymbols = 0;
520 for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) {
521 // Skip `value`.
522 if (i == pos)
523 continue;
524 // Check if the position `i` is used in the generated bound. If so, it must
525 // be included in the generated affine.apply op.
526 bool used = false;
527 bool isDim = i < cstr.cstr.getNumDimVars();
528 if (isDim) {
529 if (bound.isFunctionOfDim(position: i))
530 used = true;
531 } else {
532 if (bound.isFunctionOfSymbol(position: i - cstr.cstr.getNumDimVars()))
533 used = true;
534 }
535
536 if (!used) {
537 // Not used: Remove dim/symbol from the result.
538 if (isDim) {
539 replacementDims.push_back(Elt: b.getAffineConstantExpr(constant: 0));
540 } else {
541 replacementSymbols.push_back(Elt: b.getAffineConstantExpr(constant: 0));
542 }
543 continue;
544 }
545
546 if (isDim) {
547 replacementDims.push_back(Elt: b.getAffineDimExpr(position: numDims++));
548 } else {
549 replacementSymbols.push_back(Elt: b.getAffineSymbolExpr(position: numSymbols++));
550 }
551
552 assert(cstr.positionToValueDim[i].has_value() &&
553 "cannot build affine map in terms of anonymous column");
554 ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i];
555 Value value = valueDim.first;
556 int64_t dim = valueDim.second;
557 if (dim == ValueBoundsConstraintSet::kIndexValue) {
558 // An index-type value is used: can be used directly in the affine.apply
559 // op.
560 assert(value.getType().isIndex() && "expected index type");
561 mapOperands.push_back(Elt: std::make_pair(x&: value, y: std::nullopt));
562 continue;
563 }
564
565 assert(cast<ShapedType>(value.getType()).isDynamicDim(dim) &&
566 "expected dynamic dim");
567 mapOperands.push_back(Elt: std::make_pair(x&: value, y&: dim));
568 }
569
570 resultMap = bound.replaceDimsAndSymbols(dimReplacements: replacementDims, symReplacements: replacementSymbols,
571 numResultDims: numDims, numResultSyms: numSymbols);
572 return success();
573}
574
575LogicalResult ValueBoundsConstraintSet::computeDependentBound(
576 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
577 const Variable &var, ValueDimList dependencies, bool closedUB) {
578 return computeBound(
579 resultMap, mapOperands, type, var,
580 stopCondition: [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
581 return llvm::is_contained(Range&: dependencies, Element: std::make_pair(x&: v, y&: d));
582 },
583 closedUB);
584}
585
586LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
587 AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
588 const Variable &var, ValueRange independencies, bool closedUB) {
589 // Return "true" if the given value is independent of all values in
590 // `independencies`. I.e., neither the value itself nor any value in the
591 // backward slice (reverse use-def chain) is contained in `independencies`.
592 auto isIndependent = [&](Value v) {
593 SmallVector<Value> worklist;
594 DenseSet<Value> visited;
595 worklist.push_back(Elt: v);
596 while (!worklist.empty()) {
597 Value next = worklist.pop_back_val();
598 if (visited.contains(V: next))
599 continue;
600 visited.insert(V: next);
601 if (llvm::is_contained(Range&: independencies, Element: next))
602 return false;
603 // TODO: DominanceInfo could be used to stop the traversal early.
604 Operation *op = next.getDefiningOp();
605 if (!op)
606 continue;
607 worklist.append(in_start: op->getOperands().begin(), in_end: op->getOperands().end());
608 }
609 return true;
610 };
611
612 // Reify bounds in terms of any independent values.
613 return computeBound(
614 resultMap, mapOperands, type, var,
615 stopCondition: [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
616 return isIndependent(v);
617 },
618 closedUB);
619}
620
621FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
622 presburger::BoundType type, const Variable &var,
623 StopConditionFn stopCondition, bool closedUB) {
624 // Default stop condition if none was specified: Keep adding constraints until
625 // a bound could be computed.
626 int64_t pos = 0;
627 auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
628 ValueBoundsConstraintSet &cstr) {
629 return cstr.cstr.getConstantBound64(type, pos).has_value();
630 };
631
632 ValueBoundsConstraintSet cstr(
633 var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
634 pos = cstr.populateConstraints(map: var.map, mapOperands: var.mapOperands);
635 assert(pos == 0 && "expected `map` is the first column");
636
637 // Compute constant bound for `valueDim`.
638 int64_t ubAdjustment = closedUB ? 0 : 1;
639 if (auto bound = cstr.cstr.getConstantBound64(type, pos))
640 return type == BoundType::UB ? *bound + ubAdjustment : *bound;
641 return failure();
642}
643
644void ValueBoundsConstraintSet::populateConstraints(Value value,
645 std::optional<int64_t> dim) {
646#ifndef NDEBUG
647 assertValidValueDim(value, dim);
648#endif // NDEBUG
649
650 // `getExpr` pushes the value/dim onto the worklist (unless it was already
651 // analyzed).
652 (void)getExpr(value, dim);
653 // Process all values/dims on the worklist. This may traverse and analyze
654 // additional IR, depending the current stop function.
655 processWorklist();
656}
657
658int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
659 ValueDimList operands) {
660 int64_t pos = insert(map, operands, /*isSymbol=*/false);
661 // Process the backward slice of `operands` (i.e., reverse use-def chain)
662 // until `stopCondition` is met.
663 processWorklist();
664 return pos;
665}
666
667FailureOr<int64_t>
668ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
669 std::optional<int64_t> dim1,
670 std::optional<int64_t> dim2) {
671#ifndef NDEBUG
672 assertValidValueDim(value: value1, dim: dim1);
673 assertValidValueDim(value: value2, dim: dim2);
674#endif // NDEBUG
675
676 Builder b(value1.getContext());
677 AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
678 result: b.getAffineDimExpr(position: 0) - b.getAffineDimExpr(position: 1));
679 return computeConstantBound(type: presburger::BoundType::EQ,
680 var: Variable(map, {{value1, dim1}, {value2, dim2}}));
681}
682
683bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
684 ComparisonOperator cmp,
685 int64_t rhsPos) {
686 // This function returns "true" if "lhs CMP rhs" is proven to hold.
687 //
688 // Example for ComparisonOperator::LE and index-typed values: We would like to
689 // prove that lhs <= rhs. Proof by contradiction: add the inverse
690 // relation (lhs > rhs) to the constraint set and check if the resulting
691 // constraint set is "empty" (i.e. has no solution). In that case,
692 // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
693
694 // We cannot prove anything if the constraint set is already empty.
695 if (cstr.isEmpty()) {
696 LLVM_DEBUG(
697 llvm::dbgs()
698 << "cannot compare value/dims: constraint system is already empty");
699 return false;
700 }
701
702 // EQ can be expressed as LE and GE.
703 if (cmp == EQ)
704 return comparePos(lhsPos, cmp: ComparisonOperator::LE, rhsPos) &&
705 comparePos(lhsPos, cmp: ComparisonOperator::GE, rhsPos);
706
707 // Construct inequality.
708 SmallVector<int64_t> eq(cstr.getNumCols(), 0);
709 if (cmp == LT || cmp == LE) {
710 ++eq[lhsPos];
711 --eq[rhsPos];
712 } else if (cmp == GT || cmp == GE) {
713 --eq[lhsPos];
714 ++eq[rhsPos];
715 } else {
716 llvm_unreachable("unsupported comparison operator");
717 }
718 if (cmp == LE || cmp == GE)
719 eq[cstr.getNumCols() - 1] -= 1;
720
721 // Add inequality to the constraint set and check if it made the constraint
722 // set empty.
723 int64_t ineqPos = cstr.getNumInequalities();
724 cstr.addInequality(inEq: eq);
725 bool isEmpty = cstr.isEmpty();
726 cstr.removeInequality(pos: ineqPos);
727 return isEmpty;
728}
729
730bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
731 ComparisonOperator cmp,
732 const Variable &rhs) {
733 int64_t lhsPos = populateConstraints(map: lhs.map, operands: lhs.mapOperands);
734 int64_t rhsPos = populateConstraints(map: rhs.map, operands: rhs.mapOperands);
735 return comparePos(lhsPos, cmp, rhsPos);
736}
737
738bool ValueBoundsConstraintSet::compare(const Variable &lhs,
739 ComparisonOperator cmp,
740 const Variable &rhs) {
741 int64_t lhsPos = -1, rhsPos = -1;
742 auto stopCondition = [&](Value v, std::optional<int64_t> dim,
743 ValueBoundsConstraintSet &cstr) {
744 // Keep processing as long as lhs/rhs were not processed.
745 if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
746 size_t(rhsPos) >= cstr.positionToValueDim.size())
747 return false;
748 // Keep processing as long as the relation cannot be proven.
749 return cstr.comparePos(lhsPos, cmp, rhsPos);
750 };
751 ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
752 lhsPos = cstr.populateConstraints(map: lhs.map, operands: lhs.mapOperands);
753 rhsPos = cstr.populateConstraints(map: rhs.map, operands: rhs.mapOperands);
754 return cstr.comparePos(lhsPos, cmp, rhsPos);
755}
756
757FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
758 const Variable &var2) {
759 if (ValueBoundsConstraintSet::compare(lhs: var1, cmp: ComparisonOperator::EQ, rhs: var2))
760 return true;
761 if (ValueBoundsConstraintSet::compare(lhs: var1, cmp: ComparisonOperator::LT, rhs: var2) ||
762 ValueBoundsConstraintSet::compare(lhs: var1, cmp: ComparisonOperator::GT, rhs: var2))
763 return false;
764 return failure();
765}
766
767FailureOr<bool>
768ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
769 HyperrectangularSlice slice1,
770 HyperrectangularSlice slice2) {
771 assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
772 "expected slices of same rank");
773 assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
774 "expected slices of same rank");
775 assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
776 "expected slices of same rank");
777
778 Builder b(ctx);
779 bool foundUnknownBound = false;
780 for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
781 AffineMap map =
782 AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
783 result: b.getAffineSymbolExpr(position: 0) +
784 b.getAffineSymbolExpr(position: 1) * b.getAffineSymbolExpr(position: 2) -
785 b.getAffineSymbolExpr(position: 3));
786 {
787 // Case 1: Slices are guaranteed to be non-overlapping if
788 // offset1 + size1 * stride1 <= offset2 (for at least one dimension).
789 SmallVector<OpFoldResult> ofrOperands;
790 ofrOperands.push_back(Elt: slice1.getMixedOffsets()[i]);
791 ofrOperands.push_back(Elt: slice1.getMixedSizes()[i]);
792 ofrOperands.push_back(Elt: slice1.getMixedStrides()[i]);
793 ofrOperands.push_back(Elt: slice2.getMixedOffsets()[i]);
794 SmallVector<Value> valueOperands;
795 AffineMap foldedMap =
796 foldAttributesIntoMap(b, map, operands: ofrOperands, remainingValues&: valueOperands);
797 FailureOr<int64_t> constBound = computeConstantBound(
798 type: presburger::BoundType::EQ, var: Variable(foldedMap, valueOperands));
799 foundUnknownBound |= failed(result: constBound);
800 if (succeeded(result: constBound) && *constBound <= 0)
801 return false;
802 }
803 {
804 // Case 2: Slices are guaranteed to be non-overlapping if
805 // offset2 + size2 * stride2 <= offset1 (for at least one dimension).
806 SmallVector<OpFoldResult> ofrOperands;
807 ofrOperands.push_back(Elt: slice2.getMixedOffsets()[i]);
808 ofrOperands.push_back(Elt: slice2.getMixedSizes()[i]);
809 ofrOperands.push_back(Elt: slice2.getMixedStrides()[i]);
810 ofrOperands.push_back(Elt: slice1.getMixedOffsets()[i]);
811 SmallVector<Value> valueOperands;
812 AffineMap foldedMap =
813 foldAttributesIntoMap(b, map, operands: ofrOperands, remainingValues&: valueOperands);
814 FailureOr<int64_t> constBound = computeConstantBound(
815 type: presburger::BoundType::EQ, var: Variable(foldedMap, valueOperands));
816 foundUnknownBound |= failed(result: constBound);
817 if (succeeded(result: constBound) && *constBound <= 0)
818 return false;
819 }
820 }
821
822 // If at least one bound could not be computed, we cannot be certain that the
823 // slices are really overlapping.
824 if (foundUnknownBound)
825 return failure();
826
827 // All bounds could be computed and none of the above cases applied.
828 // Therefore, the slices are guaranteed to overlap.
829 return true;
830}
831
832FailureOr<bool>
833ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
834 HyperrectangularSlice slice1,
835 HyperrectangularSlice slice2) {
836 assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
837 "expected slices of same rank");
838 assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
839 "expected slices of same rank");
840 assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
841 "expected slices of same rank");
842
843 // The two slices are equivalent if all of their offsets, sizes and strides
844 // are equal. If equality cannot be determined for at least one of those
845 // values, equivalence cannot be determined and this function returns
846 // "failure".
847 for (auto [offset1, offset2] :
848 llvm::zip_equal(t: slice1.getMixedOffsets(), u: slice2.getMixedOffsets())) {
849 FailureOr<bool> equal = areEqual(var1: offset1, var2: offset2);
850 if (failed(result: equal))
851 return failure();
852 if (!equal.value())
853 return false;
854 }
855 for (auto [size1, size2] :
856 llvm::zip_equal(t: slice1.getMixedSizes(), u: slice2.getMixedSizes())) {
857 FailureOr<bool> equal = areEqual(var1: size1, var2: size2);
858 if (failed(result: equal))
859 return failure();
860 if (!equal.value())
861 return false;
862 }
863 for (auto [stride1, stride2] :
864 llvm::zip_equal(t: slice1.getMixedStrides(), u: slice2.getMixedStrides())) {
865 FailureOr<bool> equal = areEqual(var1: stride1, var2: stride2);
866 if (failed(result: equal))
867 return failure();
868 if (!equal.value())
869 return false;
870 }
871 return true;
872}
873
874void ValueBoundsConstraintSet::dump() const {
875 llvm::errs() << "==========\nColumns:\n";
876 llvm::errs() << "(column\tdim\tvalue)\n";
877 for (auto [index, valueDim] : llvm::enumerate(First: positionToValueDim)) {
878 llvm::errs() << " " << index << "\t";
879 if (valueDim) {
880 if (valueDim->second == kIndexValue) {
881 llvm::errs() << "n/a\t";
882 } else {
883 llvm::errs() << valueDim->second << "\t";
884 }
885 llvm::errs() << getOwnerOfValue(value: valueDim->first)->getName() << " ";
886 if (OpResult result = dyn_cast<OpResult>(Val: valueDim->first)) {
887 llvm::errs() << "(result " << result.getResultNumber() << ")";
888 } else {
889 llvm::errs() << "(bbarg "
890 << cast<BlockArgument>(Val: valueDim->first).getArgNumber()
891 << ")";
892 }
893 llvm::errs() << "\n";
894 } else {
895 llvm::errs() << "n/a\tn/a\n";
896 }
897 }
898 llvm::errs() << "\nConstraint set:\n";
899 cstr.dump();
900 llvm::errs() << "==========\n";
901}
902
903ValueBoundsConstraintSet::BoundBuilder &
904ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
905 assert(!this->dim.has_value() && "dim was already set");
906 this->dim = dim;
907#ifndef NDEBUG
908 assertValidValueDim(value, dim: this->dim);
909#endif // NDEBUG
910 return *this;
911}
912
913void ValueBoundsConstraintSet::BoundBuilder::operator<(AffineExpr expr) {
914#ifndef NDEBUG
915 assertValidValueDim(value, dim: this->dim);
916#endif // NDEBUG
917 cstr.addBound(type: BoundType::UB, pos: cstr.getPos(value, dim: this->dim), expr);
918}
919
920void ValueBoundsConstraintSet::BoundBuilder::operator<=(AffineExpr expr) {
921 operator<(expr: expr + 1);
922}
923
924void ValueBoundsConstraintSet::BoundBuilder::operator>(AffineExpr expr) {
925 operator>=(expr: expr + 1);
926}
927
928void ValueBoundsConstraintSet::BoundBuilder::operator>=(AffineExpr expr) {
929#ifndef NDEBUG
930 assertValidValueDim(value, dim: this->dim);
931#endif // NDEBUG
932 cstr.addBound(type: BoundType::LB, pos: cstr.getPos(value, dim: this->dim), expr);
933}
934
935void ValueBoundsConstraintSet::BoundBuilder::operator==(AffineExpr expr) {
936#ifndef NDEBUG
937 assertValidValueDim(value, dim: this->dim);
938#endif // NDEBUG
939 cstr.addBound(type: BoundType::EQ, pos: cstr.getPos(value, dim: this->dim), expr);
940}
941
942void ValueBoundsConstraintSet::BoundBuilder::operator<(OpFoldResult ofr) {
943 operator<(expr: cstr.getExpr(ofr));
944}
945
946void ValueBoundsConstraintSet::BoundBuilder::operator<=(OpFoldResult ofr) {
947 operator<=(expr: cstr.getExpr(ofr));
948}
949
950void ValueBoundsConstraintSet::BoundBuilder::operator>(OpFoldResult ofr) {
951 operator>(expr: cstr.getExpr(ofr));
952}
953
954void ValueBoundsConstraintSet::BoundBuilder::operator>=(OpFoldResult ofr) {
955 operator>=(expr: cstr.getExpr(ofr));
956}
957
958void ValueBoundsConstraintSet::BoundBuilder::operator==(OpFoldResult ofr) {
959 operator==(expr: cstr.getExpr(ofr));
960}
961
962void ValueBoundsConstraintSet::BoundBuilder::operator<(int64_t i) {
963 operator<(expr: cstr.getExpr(constant: i));
964}
965
966void ValueBoundsConstraintSet::BoundBuilder::operator<=(int64_t i) {
967 operator<=(expr: cstr.getExpr(constant: i));
968}
969
970void ValueBoundsConstraintSet::BoundBuilder::operator>(int64_t i) {
971 operator>(expr: cstr.getExpr(constant: i));
972}
973
974void ValueBoundsConstraintSet::BoundBuilder::operator>=(int64_t i) {
975 operator>=(expr: cstr.getExpr(constant: i));
976}
977
978void ValueBoundsConstraintSet::BoundBuilder::operator==(int64_t i) {
979 operator==(expr: cstr.getExpr(constant: i));
980}
981

source code of mlir/lib/Interfaces/ValueBoundsOpInterface.cpp