1 | //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Structures for affine/polyhedral analysis of affine dialect ops. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" |
14 | #include "mlir/Analysis/Presburger/IntegerRelation.h" |
15 | #include "mlir/Analysis/Presburger/LinearTransform.h" |
16 | #include "mlir/Analysis/Presburger/Simplex.h" |
17 | #include "mlir/Analysis/Presburger/Utils.h" |
18 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
19 | #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
21 | #include "mlir/IR/AffineExprVisitor.h" |
22 | #include "mlir/IR/IntegerSet.h" |
23 | #include "mlir/Support/LLVM.h" |
24 | #include "mlir/Support/MathExtras.h" |
25 | #include "llvm/ADT/STLExtras.h" |
26 | #include "llvm/ADT/SmallPtrSet.h" |
27 | #include "llvm/ADT/SmallVector.h" |
28 | #include "llvm/Support/Debug.h" |
29 | #include "llvm/Support/raw_ostream.h" |
30 | #include <optional> |
31 | |
32 | #define DEBUG_TYPE "affine-structures" |
33 | |
34 | using namespace mlir; |
35 | using namespace affine; |
36 | using namespace presburger; |
37 | |
38 | |
39 | void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) { |
40 | if (containsVar(val)) |
41 | return; |
42 | |
43 | // Caller is expected to fully compose map/operands if necessary. |
44 | assert((isTopLevelValue(val) || isAffineInductionVar(val)) && |
45 | "non-terminal symbol / loop IV expected" ); |
46 | // Outer loop IVs could be used in forOp's bounds. |
47 | if (auto loop = getForInductionVarOwner(val)) { |
48 | appendDimVar(vals: val); |
49 | if (failed(this->addAffineForOpDomain(forOp: loop))) |
50 | LLVM_DEBUG( |
51 | loop.emitWarning("failed to add domain info to constraint system" )); |
52 | return; |
53 | } |
54 | if (auto parallel = getAffineParallelInductionVarOwner(val)) { |
55 | appendDimVar(parallel.getIVs()); |
56 | if (failed(this->addAffineParallelOpDomain(parallelOp: parallel))) |
57 | LLVM_DEBUG(parallel.emitWarning( |
58 | "failed to add domain info to constraint system" )); |
59 | return; |
60 | } |
61 | |
62 | // Add top level symbol. |
63 | appendSymbolVar(vals: val); |
64 | // Check if the symbol is a constant. |
65 | if (std::optional<int64_t> constOp = getConstantIntValue(ofr: val)) |
66 | addBound(type: BoundType::EQ, val, value: constOp.value()); |
67 | } |
68 | |
69 | LogicalResult |
70 | FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) { |
71 | unsigned pos; |
72 | // Pre-condition for this method. |
73 | if (!findVar(val: forOp.getInductionVar(), pos: &pos)) { |
74 | assert(false && "Value not found" ); |
75 | return failure(); |
76 | } |
77 | |
78 | int64_t step = forOp.getStepAsInt(); |
79 | if (step != 1) { |
80 | if (!forOp.hasConstantLowerBound()) |
81 | LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated" )); |
82 | else { |
83 | // Add constraints for the stride. |
84 | // (iv - lb) % step = 0 can be written as: |
85 | // (iv - lb) - step * q = 0 where q = (iv - lb) / step. |
86 | // Add local variable 'q' and add the above equality. |
87 | // The first constraint is q = (iv - lb) floordiv step |
88 | SmallVector<int64_t, 8> dividend(getNumCols(), 0); |
89 | int64_t lb = forOp.getConstantLowerBound(); |
90 | dividend[pos] = 1; |
91 | dividend.back() -= lb; |
92 | addLocalFloorDiv(dividend, divisor: step); |
93 | // Second constraint: (iv - lb) - step * q = 0. |
94 | SmallVector<int64_t, 8> eq(getNumCols(), 0); |
95 | eq[pos] = 1; |
96 | eq.back() -= lb; |
97 | // For the local var just added above. |
98 | eq[getNumCols() - 2] = -step; |
99 | addEquality(eq); |
100 | } |
101 | } |
102 | |
103 | if (forOp.hasConstantLowerBound()) { |
104 | addBound(BoundType::LB, pos, forOp.getConstantLowerBound()); |
105 | } else { |
106 | // Non-constant lower bound case. |
107 | if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(), |
108 | forOp.getLowerBoundOperands()))) |
109 | return failure(); |
110 | } |
111 | |
112 | if (forOp.hasConstantUpperBound()) { |
113 | addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1); |
114 | return success(); |
115 | } |
116 | // Non-constant upper bound case. |
117 | return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(), |
118 | forOp.getUpperBoundOperands()); |
119 | } |
120 | |
121 | LogicalResult FlatAffineValueConstraints::addAffineParallelOpDomain( |
122 | AffineParallelOp parallelOp) { |
123 | size_t ivPos = 0; |
124 | for (Value iv : parallelOp.getIVs()) { |
125 | unsigned pos; |
126 | if (!findVar(iv, &pos)) { |
127 | assert(false && "variable expected for the IV value" ); |
128 | return failure(); |
129 | } |
130 | |
131 | AffineMap lowerBound = parallelOp.getLowerBoundMap(ivPos); |
132 | if (lowerBound.isConstant()) |
133 | addBound(BoundType::LB, pos, lowerBound.getSingleConstantResult()); |
134 | else if (failed(addBound(BoundType::LB, pos, lowerBound, |
135 | parallelOp.getLowerBoundsOperands()))) |
136 | return failure(); |
137 | |
138 | auto upperBound = parallelOp.getUpperBoundMap(ivPos); |
139 | if (upperBound.isConstant()) |
140 | addBound(BoundType::UB, pos, upperBound.getSingleConstantResult() - 1); |
141 | else if (failed(addBound(BoundType::UB, pos, upperBound, |
142 | parallelOp.getUpperBoundsOperands()))) |
143 | return failure(); |
144 | ++ivPos; |
145 | } |
146 | return success(); |
147 | } |
148 | |
149 | LogicalResult |
150 | FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps, |
151 | ArrayRef<AffineMap> ubMaps, |
152 | ArrayRef<Value> operands) { |
153 | assert(lbMaps.size() == ubMaps.size()); |
154 | assert(lbMaps.size() <= getNumDimVars()); |
155 | |
156 | for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { |
157 | AffineMap lbMap = lbMaps[i]; |
158 | AffineMap ubMap = ubMaps[i]; |
159 | assert(!lbMap || lbMap.getNumInputs() == operands.size()); |
160 | assert(!ubMap || ubMap.getNumInputs() == operands.size()); |
161 | |
162 | // Check if this slice is just an equality along this dimension. If so, |
163 | // retrieve the existing loop it equates to and add it to the system. |
164 | if (lbMap && ubMap && lbMap.getNumResults() == 1 && |
165 | ubMap.getNumResults() == 1 && |
166 | lbMap.getResult(idx: 0) + 1 == ubMap.getResult(idx: 0) && |
167 | // The condition above will be true for maps describing a single |
168 | // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). |
169 | // Make sure we skip those cases by checking that the lb result is not |
170 | // just a constant. |
171 | !isa<AffineConstantExpr>(Val: lbMap.getResult(idx: 0))) { |
172 | // Limited support: we expect the lb result to be just a loop dimension. |
173 | // Not supported otherwise for now. |
174 | AffineDimExpr result = dyn_cast<AffineDimExpr>(Val: lbMap.getResult(idx: 0)); |
175 | if (!result) |
176 | return failure(); |
177 | |
178 | AffineForOp loop = |
179 | getForInductionVarOwner(operands[result.getPosition()]); |
180 | if (!loop) |
181 | return failure(); |
182 | |
183 | if (failed(addAffineForOpDomain(forOp: loop))) |
184 | return failure(); |
185 | continue; |
186 | } |
187 | |
188 | // This slice refers to a loop that doesn't exist in the IR yet. Add its |
189 | // bounds to the system assuming its dimension variable position is the |
190 | // same as the position of the loop in the loop nest. |
191 | if (lbMap && failed(result: addBound(type: BoundType::LB, pos: i, boundMap: lbMap, operands))) |
192 | return failure(); |
193 | if (ubMap && failed(result: addBound(type: BoundType::UB, pos: i, boundMap: ubMap, operands))) |
194 | return failure(); |
195 | } |
196 | return success(); |
197 | } |
198 | |
199 | void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) { |
200 | IntegerSet set = ifOp.getIntegerSet(); |
201 | // Canonicalize set and operands to ensure unique values for |
202 | // FlatAffineValueConstraints below and for early simplification. |
203 | SmallVector<Value> operands(ifOp.getOperands()); |
204 | canonicalizeSetAndOperands(set: &set, operands: &operands); |
205 | |
206 | // Create the base constraints from the integer set attached to ifOp. |
207 | FlatAffineValueConstraints cst(set, operands); |
208 | |
209 | // Merge the constraints from ifOp to the current domain. We need first merge |
210 | // and align the IDs from both constraints, and then append the constraints |
211 | // from the ifOp into the current one. |
212 | mergeAndAlignVarsWithOther(offset: 0, other: &cst); |
213 | append(other: cst); |
214 | } |
215 | |
216 | LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos, |
217 | AffineMap boundMap, |
218 | ValueRange boundOperands) { |
219 | // Fully compose map and operands; canonicalize and simplify so that we |
220 | // transitively get to terminal symbols or loop IVs. |
221 | auto map = boundMap; |
222 | SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end()); |
223 | fullyComposeAffineMapAndOperands(map: &map, operands: &operands); |
224 | map = simplifyAffineMap(map); |
225 | canonicalizeMapAndOperands(map: &map, operands: &operands); |
226 | for (auto operand : operands) |
227 | addInductionVarOrTerminalSymbol(val: operand); |
228 | return addBound(type, pos, boundMap: computeAlignedMap(map, operands)); |
229 | } |
230 | |
231 | // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper |
232 | // bounds in 'ubMaps' to each value in `values' that appears in the constraint |
233 | // system. Note that both lower/upper bounds share the same operand list |
234 | // 'operands'. |
235 | // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and |
236 | // skips any null AffineMaps in 'lbMaps' or 'ubMaps'. |
237 | // Note that both lower/upper bounds use operands from 'operands'. |
238 | // Returns failure for unimplemented cases such as semi-affine expressions or |
239 | // expressions with mod/floordiv. |
240 | LogicalResult FlatAffineValueConstraints::addSliceBounds( |
241 | ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps, |
242 | ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) { |
243 | assert(values.size() == lbMaps.size()); |
244 | assert(lbMaps.size() == ubMaps.size()); |
245 | |
246 | for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { |
247 | unsigned pos; |
248 | if (!findVar(val: values[i], pos: &pos)) |
249 | continue; |
250 | |
251 | AffineMap lbMap = lbMaps[i]; |
252 | AffineMap ubMap = ubMaps[i]; |
253 | assert(!lbMap || lbMap.getNumInputs() == operands.size()); |
254 | assert(!ubMap || ubMap.getNumInputs() == operands.size()); |
255 | |
256 | // Check if this slice is just an equality along this dimension. |
257 | if (lbMap && ubMap && lbMap.getNumResults() == 1 && |
258 | ubMap.getNumResults() == 1 && |
259 | lbMap.getResult(idx: 0) + 1 == ubMap.getResult(idx: 0)) { |
260 | if (failed(result: addBound(type: BoundType::EQ, pos, boundMap: lbMap, boundOperands: operands))) |
261 | return failure(); |
262 | continue; |
263 | } |
264 | |
265 | // If lower or upper bound maps are null or provide no results, it implies |
266 | // that the source loop was not at all sliced, and the entire loop will be a |
267 | // part of the slice. |
268 | if (lbMap && lbMap.getNumResults() != 0 && ubMap && |
269 | ubMap.getNumResults() != 0) { |
270 | if (failed(result: addBound(type: BoundType::LB, pos, boundMap: lbMap, boundOperands: operands))) |
271 | return failure(); |
272 | if (failed(result: addBound(type: BoundType::UB, pos, boundMap: ubMap, boundOperands: operands))) |
273 | return failure(); |
274 | } else { |
275 | auto loop = getForInductionVarOwner(values[i]); |
276 | if (failed(this->addAffineForOpDomain(forOp: loop))) |
277 | return failure(); |
278 | } |
279 | } |
280 | return success(); |
281 | } |
282 | |
283 | LogicalResult |
284 | FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) { |
285 | return composeMatchingMap( |
286 | other: computeAlignedMap(map: vMap->getAffineMap(), operands: vMap->getOperands())); |
287 | } |
288 | |
289 | // Turn a symbol into a dimension. |
290 | static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value value) { |
291 | unsigned pos; |
292 | if (cst->findVar(val: value, pos: &pos) && pos >= cst->getNumDimVars() && |
293 | pos < cst->getNumDimAndSymbolVars()) { |
294 | cst->swapVar(posA: pos, posB: cst->getNumDimVars()); |
295 | cst->setDimSymbolSeparation(cst->getNumSymbolVars() - 1); |
296 | } |
297 | } |
298 | |
299 | // Changes all symbol variables which are loop IVs to dim variables. |
300 | void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() { |
301 | // Gather all symbols which are loop IVs. |
302 | SmallVector<Value, 4> loopIVs; |
303 | for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++) { |
304 | if (hasValue(pos: i) && getForInductionVarOwner(getValue(pos: i))) |
305 | loopIVs.push_back(Elt: getValue(pos: i)); |
306 | } |
307 | // Turn each symbol in 'loopIVs' into a dim variable. |
308 | for (auto iv : loopIVs) { |
309 | turnSymbolIntoDim(cst: this, value: iv); |
310 | } |
311 | } |
312 | |
313 | void FlatAffineValueConstraints::getIneqAsAffineValueMap( |
314 | unsigned pos, unsigned ineqPos, AffineValueMap &vmap, |
315 | MLIRContext *context) const { |
316 | unsigned numDims = getNumDimVars(); |
317 | unsigned numSyms = getNumSymbolVars(); |
318 | |
319 | assert(pos < numDims && "invalid position" ); |
320 | assert(ineqPos < getNumInequalities() && "invalid inequality position" ); |
321 | |
322 | // Get expressions for local vars. |
323 | SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); |
324 | if (failed(result: computeLocalVars(memo, context))) |
325 | assert(false && |
326 | "one or more local exprs do not have an explicit representation" ); |
327 | auto localExprs = ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars()); |
328 | |
329 | // Compute the AffineExpr lower/upper bound for this inequality. |
330 | SmallVector<int64_t, 8> inequality = getInequality64(idx: ineqPos); |
331 | SmallVector<int64_t, 8> bound; |
332 | bound.reserve(N: getNumCols() - 1); |
333 | // Everything other than the coefficient at `pos`. |
334 | bound.append(in_start: inequality.begin(), in_end: inequality.begin() + pos); |
335 | bound.append(in_start: inequality.begin() + pos + 1, in_end: inequality.end()); |
336 | |
337 | if (inequality[pos] > 0) |
338 | // Lower bound. |
339 | std::transform(first: bound.begin(), last: bound.end(), result: bound.begin(), |
340 | unary_op: std::negate<int64_t>()); |
341 | else |
342 | // Upper bound (which is exclusive). |
343 | bound.back() += 1; |
344 | |
345 | // Convert to AffineExpr (tree) form. |
346 | auto boundExpr = getAffineExprFromFlatForm(flatExprs: bound, numDims: numDims - 1, numSymbols: numSyms, |
347 | localExprs, context); |
348 | |
349 | // Get the values to bind to this affine expr (all dims and symbols). |
350 | SmallVector<Value, 4> operands; |
351 | getValues(start: 0, end: pos, values: &operands); |
352 | SmallVector<Value, 4> trailingOperands; |
353 | getValues(start: pos + 1, end: getNumDimAndSymbolVars(), values: &trailingOperands); |
354 | operands.append(in_start: trailingOperands.begin(), in_end: trailingOperands.end()); |
355 | vmap.reset(map: AffineMap::get(dimCount: numDims - 1, symbolCount: numSyms, result: boundExpr), operands); |
356 | } |
357 | |
358 | FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const { |
359 | FlatAffineValueConstraints domain = *this; |
360 | // Convert all range variables to local variables. |
361 | domain.convertToLocal(kind: VarKind::SetDim, varStart: getNumDomainDims(), |
362 | varLimit: getNumDomainDims() + getNumRangeDims()); |
363 | return domain; |
364 | } |
365 | |
366 | FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const { |
367 | FlatAffineValueConstraints range = *this; |
368 | // Convert all domain variables to local variables. |
369 | range.convertToLocal(kind: VarKind::SetDim, varStart: 0, varLimit: getNumDomainDims()); |
370 | return range; |
371 | } |
372 | |
373 | void FlatAffineRelation::compose(const FlatAffineRelation &other) { |
374 | assert(getNumDomainDims() == other.getNumRangeDims() && |
375 | "Domain of this and range of other do not match" ); |
376 | assert(space.getDomainSpace().isAligned(other.getSpace().getRangeSpace()) && |
377 | "Values of domain of this and range of other do not match" ); |
378 | |
379 | FlatAffineRelation rel = other; |
380 | |
381 | // Convert `rel` from |
382 | // [otherDomain] -> [otherRange] |
383 | // to |
384 | // [otherDomain] -> [otherRange thisRange] |
385 | // and `this` from |
386 | // [thisDomain] -> [thisRange] |
387 | // to |
388 | // [otherDomain thisDomain] -> [thisRange]. |
389 | unsigned removeDims = rel.getNumRangeDims(); |
390 | insertDomainVar(pos: 0, num: rel.getNumDomainDims()); |
391 | rel.appendRangeVar(num: getNumRangeDims()); |
392 | |
393 | // Merge symbol and local variables. |
394 | mergeSymbolVars(other&: rel); |
395 | mergeLocalVars(other&: rel); |
396 | |
397 | // Convert `rel` from [otherDomain] -> [otherRange thisRange] to |
398 | // [otherDomain] -> [thisRange] by converting first otherRange range vars |
399 | // to local vars. |
400 | rel.convertToLocal(kind: VarKind::SetDim, varStart: rel.getNumDomainDims(), |
401 | varLimit: rel.getNumDomainDims() + removeDims); |
402 | // Convert `this` from [otherDomain thisDomain] -> [thisRange] to |
403 | // [otherDomain] -> [thisRange] by converting last thisDomain domain vars |
404 | // to local vars. |
405 | convertToLocal(kind: VarKind::SetDim, varStart: getNumDomainDims() - removeDims, |
406 | varLimit: getNumDomainDims()); |
407 | |
408 | auto thisMaybeValues = getMaybeValues(kind: VarKind::SetDim); |
409 | auto relMaybeValues = rel.getMaybeValues(kind: VarKind::SetDim); |
410 | |
411 | // Add and match domain of `rel` to domain of `this`. |
412 | for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) |
413 | if (relMaybeValues[i].has_value()) |
414 | setValue(pos: i, val: *relMaybeValues[i]); |
415 | // Add and match range of `this` to range of `rel`. |
416 | for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) { |
417 | unsigned rangeIdx = rel.getNumDomainDims() + i; |
418 | if (thisMaybeValues[rangeIdx].has_value()) |
419 | rel.setValue(pos: rangeIdx, val: *thisMaybeValues[rangeIdx]); |
420 | } |
421 | |
422 | // Append `this` to `rel` and simplify constraints. |
423 | rel.append(other: *this); |
424 | rel.removeRedundantLocalVars(); |
425 | |
426 | *this = rel; |
427 | } |
428 | |
429 | void FlatAffineRelation::inverse() { |
430 | unsigned oldDomain = getNumDomainDims(); |
431 | unsigned oldRange = getNumRangeDims(); |
432 | // Add new range vars. |
433 | appendRangeVar(num: oldDomain); |
434 | // Swap new vars with domain. |
435 | for (unsigned i = 0; i < oldDomain; ++i) |
436 | swapVar(posA: i, posB: oldDomain + oldRange + i); |
437 | // Remove the swapped domain. |
438 | removeVarRange(varStart: 0, varLimit: oldDomain); |
439 | // Set domain and range as inverse. |
440 | numDomainDims = oldRange; |
441 | numRangeDims = oldDomain; |
442 | } |
443 | |
444 | void FlatAffineRelation::insertDomainVar(unsigned pos, unsigned num) { |
445 | assert(pos <= getNumDomainDims() && |
446 | "Var cannot be inserted at invalid position" ); |
447 | insertDimVar(pos, num); |
448 | numDomainDims += num; |
449 | } |
450 | |
451 | void FlatAffineRelation::insertRangeVar(unsigned pos, unsigned num) { |
452 | assert(pos <= getNumRangeDims() && |
453 | "Var cannot be inserted at invalid position" ); |
454 | insertDimVar(pos: getNumDomainDims() + pos, num); |
455 | numRangeDims += num; |
456 | } |
457 | |
458 | void FlatAffineRelation::appendDomainVar(unsigned num) { |
459 | insertDimVar(pos: getNumDomainDims(), num); |
460 | numDomainDims += num; |
461 | } |
462 | |
463 | void FlatAffineRelation::appendRangeVar(unsigned num) { |
464 | insertDimVar(pos: getNumDimVars(), num); |
465 | numRangeDims += num; |
466 | } |
467 | |
468 | void FlatAffineRelation::removeVarRange(VarKind kind, unsigned varStart, |
469 | unsigned varLimit) { |
470 | assert(varLimit <= getNumVarKind(kind)); |
471 | if (varStart >= varLimit) |
472 | return; |
473 | |
474 | FlatAffineValueConstraints::removeVarRange(kind, varStart, varLimit); |
475 | |
476 | // If kind is not SetDim, domain and range don't need to be updated. |
477 | if (kind != VarKind::SetDim) |
478 | return; |
479 | |
480 | // Compute number of domain and range variables to remove. This is done by |
481 | // intersecting the range of domain/range vars with range of vars to remove. |
482 | unsigned intersectDomainLHS = std::min(a: varLimit, b: getNumDomainDims()); |
483 | unsigned intersectDomainRHS = varStart; |
484 | unsigned intersectRangeLHS = std::min(a: varLimit, b: getNumDimVars()); |
485 | unsigned intersectRangeRHS = std::max(a: varStart, b: getNumDomainDims()); |
486 | |
487 | if (intersectDomainLHS > intersectDomainRHS) |
488 | numDomainDims -= intersectDomainLHS - intersectDomainRHS; |
489 | if (intersectRangeLHS > intersectRangeRHS) |
490 | numRangeDims -= intersectRangeLHS - intersectRangeRHS; |
491 | } |
492 | |
493 | LogicalResult mlir::affine::getRelationFromMap(AffineMap &map, |
494 | IntegerRelation &rel) { |
495 | // Get flattened affine expressions. |
496 | std::vector<SmallVector<int64_t, 8>> flatExprs; |
497 | FlatAffineValueConstraints localVarCst; |
498 | if (failed(result: getFlattenedAffineExprs(map, flattenedExprs: &flatExprs, cst: &localVarCst))) |
499 | return failure(); |
500 | |
501 | const unsigned oldDimNum = localVarCst.getNumDimVars(); |
502 | const unsigned oldCols = localVarCst.getNumCols(); |
503 | const unsigned numRangeVars = map.getNumResults(); |
504 | const unsigned numDomainVars = map.getNumDims(); |
505 | |
506 | // Add range as the new expressions. |
507 | localVarCst.appendDimVar(num: numRangeVars); |
508 | |
509 | // Add identifiers to the local constraints as getFlattenedAffineExprs creates |
510 | // a FlatLinearConstraints with no identifiers. |
511 | for (unsigned i = 0, e = localVarCst.getNumDimAndSymbolVars(); i < e; ++i) |
512 | localVarCst.setValue(pos: i, val: Value()); |
513 | |
514 | // Add equalities between source and range. |
515 | SmallVector<int64_t, 8> eq(localVarCst.getNumCols()); |
516 | for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { |
517 | // Zero fill. |
518 | std::fill(first: eq.begin(), last: eq.end(), value: 0); |
519 | // Fill equality. |
520 | for (unsigned j = 0, f = oldDimNum; j < f; ++j) |
521 | eq[j] = flatExprs[i][j]; |
522 | for (unsigned j = oldDimNum, f = oldCols; j < f; ++j) |
523 | eq[j + numRangeVars] = flatExprs[i][j]; |
524 | // Set this dimension to -1 to equate lhs and rhs and add equality. |
525 | eq[numDomainVars + i] = -1; |
526 | localVarCst.addEquality(eq); |
527 | } |
528 | |
529 | rel = localVarCst; |
530 | return success(); |
531 | } |
532 | |
533 | LogicalResult mlir::affine::getRelationFromMap(const AffineValueMap &map, |
534 | IntegerRelation &rel) { |
535 | |
536 | AffineMap affineMap = map.getAffineMap(); |
537 | if (failed(result: getRelationFromMap(map&: affineMap, rel))) |
538 | return failure(); |
539 | |
540 | // Set identifiers for domain and symbol variables. |
541 | for (unsigned i = 0, e = affineMap.getNumDims(); i < e; ++i) |
542 | rel.setId(kind: VarKind::SetDim, i, id: Identifier(map.getOperand(i))); |
543 | |
544 | const unsigned mapNumResults = affineMap.getNumResults(); |
545 | for (unsigned i = 0, e = rel.getNumSymbolVars(); i < e; ++i) |
546 | rel.setId( |
547 | kind: VarKind::Symbol, i, |
548 | id: Identifier(map.getOperand(i: rel.getNumDimVars() + i - mapNumResults))); |
549 | |
550 | return success(); |
551 | } |
552 | |