| 1 | //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Structures for affine/polyhedral analysis of affine dialect ops. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" |
| 14 | #include "mlir/Analysis/Presburger/IntegerRelation.h" |
| 15 | #include "mlir/Analysis/Presburger/LinearTransform.h" |
| 16 | #include "mlir/Analysis/Presburger/Simplex.h" |
| 17 | #include "mlir/Analysis/Presburger/Utils.h" |
| 18 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 19 | #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
| 20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 21 | #include "mlir/IR/AffineExprVisitor.h" |
| 22 | #include "mlir/IR/IntegerSet.h" |
| 23 | #include "mlir/Support/LLVM.h" |
| 24 | #include "llvm/ADT/STLExtras.h" |
| 25 | #include "llvm/ADT/SmallPtrSet.h" |
| 26 | #include "llvm/ADT/SmallVector.h" |
| 27 | #include "llvm/Support/Debug.h" |
| 28 | #include "llvm/Support/raw_ostream.h" |
| 29 | #include <optional> |
| 30 | |
| 31 | #define DEBUG_TYPE "affine-structures" |
| 32 | |
| 33 | using namespace mlir; |
| 34 | using namespace affine; |
| 35 | using namespace presburger; |
| 36 | |
| 37 | LogicalResult |
| 38 | FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) { |
| 39 | if (containsVar(val)) |
| 40 | return success(); |
| 41 | |
| 42 | // Caller is expected to fully compose map/operands if necessary. |
| 43 | if (val.getDefiningOp<affine::AffineApplyOp>() || |
| 44 | (!isValidSymbol(value: val) && !isAffineInductionVar(val))) { |
| 45 | LLVM_DEBUG(llvm::dbgs() |
| 46 | << "only valid terminal symbols and affine IVs supported\n" ); |
| 47 | return failure(); |
| 48 | } |
| 49 | // Outer loop IVs could be used in forOp's bounds. |
| 50 | if (auto loop = getForInductionVarOwner(val)) { |
| 51 | appendDimVar(vals: val); |
| 52 | if (failed(this->addAffineForOpDomain(forOp: loop))) { |
| 53 | LLVM_DEBUG( |
| 54 | loop.emitWarning("failed to add domain info to constraint system" )); |
| 55 | return failure(); |
| 56 | } |
| 57 | return success(); |
| 58 | } |
| 59 | |
| 60 | if (auto parallel = getAffineParallelInductionVarOwner(val)) { |
| 61 | appendDimVar(parallel.getIVs()); |
| 62 | if (failed(this->addAffineParallelOpDomain(parallelOp: parallel))) { |
| 63 | LLVM_DEBUG(parallel.emitWarning( |
| 64 | "failed to add domain info to constraint system" )); |
| 65 | return failure(); |
| 66 | } |
| 67 | return success(); |
| 68 | } |
| 69 | |
| 70 | // Add top level symbol. |
| 71 | appendSymbolVar(vals: val); |
| 72 | // Check if the symbol is a constant. |
| 73 | if (std::optional<int64_t> constOp = getConstantIntValue(ofr: val)) |
| 74 | addBound(type: BoundType::EQ, val, value: constOp.value()); |
| 75 | return success(); |
| 76 | } |
| 77 | |
| 78 | LogicalResult |
| 79 | FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) { |
| 80 | unsigned pos; |
| 81 | // Pre-condition for this method. |
| 82 | if (!findVar(val: forOp.getInductionVar(), pos: &pos)) { |
| 83 | assert(false && "Value not found" ); |
| 84 | return failure(); |
| 85 | } |
| 86 | |
| 87 | int64_t step = forOp.getStepAsInt(); |
| 88 | if (step != 1) { |
| 89 | if (!forOp.hasConstantLowerBound()) |
| 90 | LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated" )); |
| 91 | else { |
| 92 | // Add constraints for the stride. |
| 93 | // (iv - lb) % step = 0 can be written as: |
| 94 | // (iv - lb) - step * q = 0 where q = (iv - lb) / step. |
| 95 | // Add local variable 'q' and add the above equality. |
| 96 | // The first constraint is q = (iv - lb) floordiv step |
| 97 | SmallVector<int64_t, 8> dividend(getNumCols(), 0); |
| 98 | int64_t lb = forOp.getConstantLowerBound(); |
| 99 | dividend[pos] = 1; |
| 100 | dividend.back() -= lb; |
| 101 | addLocalFloorDiv(dividend, divisor: step); |
| 102 | // Second constraint: (iv - lb) - step * q = 0. |
| 103 | SmallVector<int64_t, 8> eq(getNumCols(), 0); |
| 104 | eq[pos] = 1; |
| 105 | eq.back() -= lb; |
| 106 | // For the local var just added above. |
| 107 | eq[getNumCols() - 2] = -step; |
| 108 | addEquality(eq); |
| 109 | } |
| 110 | } |
| 111 | |
| 112 | if (forOp.hasConstantLowerBound()) { |
| 113 | addBound(BoundType::LB, pos, forOp.getConstantLowerBound()); |
| 114 | } else { |
| 115 | // Non-constant lower bound case. |
| 116 | if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(), |
| 117 | forOp.getLowerBoundOperands()))) |
| 118 | return failure(); |
| 119 | } |
| 120 | |
| 121 | if (forOp.hasConstantUpperBound()) { |
| 122 | addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1); |
| 123 | return success(); |
| 124 | } |
| 125 | // Non-constant upper bound case. |
| 126 | return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(), |
| 127 | forOp.getUpperBoundOperands()); |
| 128 | } |
| 129 | |
| 130 | LogicalResult FlatAffineValueConstraints::addAffineParallelOpDomain( |
| 131 | AffineParallelOp parallelOp) { |
| 132 | size_t ivPos = 0; |
| 133 | for (Value iv : parallelOp.getIVs()) { |
| 134 | unsigned pos; |
| 135 | if (!findVar(iv, &pos)) { |
| 136 | assert(false && "variable expected for the IV value" ); |
| 137 | return failure(); |
| 138 | } |
| 139 | |
| 140 | AffineMap lowerBound = parallelOp.getLowerBoundMap(ivPos); |
| 141 | if (lowerBound.isConstant()) |
| 142 | addBound(BoundType::LB, pos, lowerBound.getSingleConstantResult()); |
| 143 | else if (failed(addBound(BoundType::LB, pos, lowerBound, |
| 144 | parallelOp.getLowerBoundsOperands()))) |
| 145 | return failure(); |
| 146 | |
| 147 | auto upperBound = parallelOp.getUpperBoundMap(ivPos); |
| 148 | if (upperBound.isConstant()) |
| 149 | addBound(BoundType::UB, pos, upperBound.getSingleConstantResult() - 1); |
| 150 | else if (failed(addBound(BoundType::UB, pos, upperBound, |
| 151 | parallelOp.getUpperBoundsOperands()))) |
| 152 | return failure(); |
| 153 | ++ivPos; |
| 154 | } |
| 155 | return success(); |
| 156 | } |
| 157 | |
| 158 | LogicalResult |
| 159 | FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps, |
| 160 | ArrayRef<AffineMap> ubMaps, |
| 161 | ArrayRef<Value> operands) { |
| 162 | assert(lbMaps.size() == ubMaps.size()); |
| 163 | assert(lbMaps.size() <= getNumDimVars()); |
| 164 | |
| 165 | for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { |
| 166 | AffineMap lbMap = lbMaps[i]; |
| 167 | AffineMap ubMap = ubMaps[i]; |
| 168 | assert(!lbMap || lbMap.getNumInputs() == operands.size()); |
| 169 | assert(!ubMap || ubMap.getNumInputs() == operands.size()); |
| 170 | |
| 171 | // Check if this slice is just an equality along this dimension. If so, |
| 172 | // retrieve the existing loop it equates to and add it to the system. |
| 173 | if (lbMap && ubMap && lbMap.getNumResults() == 1 && |
| 174 | ubMap.getNumResults() == 1 && |
| 175 | lbMap.getResult(idx: 0) + 1 == ubMap.getResult(idx: 0) && |
| 176 | // The condition above will be true for maps describing a single |
| 177 | // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). |
| 178 | // Make sure we skip those cases by checking that the lb result is not |
| 179 | // just a constant. |
| 180 | !isa<AffineConstantExpr>(Val: lbMap.getResult(idx: 0))) { |
| 181 | // Limited support: we expect the lb result to be just a loop dimension. |
| 182 | // Not supported otherwise for now. |
| 183 | AffineDimExpr result = dyn_cast<AffineDimExpr>(Val: lbMap.getResult(idx: 0)); |
| 184 | if (!result) |
| 185 | return failure(); |
| 186 | |
| 187 | AffineForOp loop = |
| 188 | getForInductionVarOwner(operands[result.getPosition()]); |
| 189 | if (!loop) |
| 190 | return failure(); |
| 191 | |
| 192 | if (failed(addAffineForOpDomain(forOp: loop))) |
| 193 | return failure(); |
| 194 | continue; |
| 195 | } |
| 196 | |
| 197 | // This slice refers to a loop that doesn't exist in the IR yet. Add its |
| 198 | // bounds to the system assuming its dimension variable position is the |
| 199 | // same as the position of the loop in the loop nest. |
| 200 | if (lbMap && failed(Result: addBound(type: BoundType::LB, pos: i, boundMap: lbMap, operands))) |
| 201 | return failure(); |
| 202 | if (ubMap && failed(Result: addBound(type: BoundType::UB, pos: i, boundMap: ubMap, operands))) |
| 203 | return failure(); |
| 204 | } |
| 205 | return success(); |
| 206 | } |
| 207 | |
| 208 | void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) { |
| 209 | IntegerSet set = ifOp.getIntegerSet(); |
| 210 | // Canonicalize set and operands to ensure unique values for |
| 211 | // FlatAffineValueConstraints below and for early simplification. |
| 212 | SmallVector<Value> operands(ifOp.getOperands()); |
| 213 | canonicalizeSetAndOperands(set: &set, operands: &operands); |
| 214 | |
| 215 | // Create the base constraints from the integer set attached to ifOp. |
| 216 | FlatAffineValueConstraints cst(set, operands); |
| 217 | |
| 218 | // Merge the constraints from ifOp to the current domain. We need first merge |
| 219 | // and align the IDs from both constraints, and then append the constraints |
| 220 | // from the ifOp into the current one. |
| 221 | mergeAndAlignVarsWithOther(offset: 0, other: &cst); |
| 222 | append(other: cst); |
| 223 | } |
| 224 | |
| 225 | LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos, |
| 226 | AffineMap boundMap, |
| 227 | ValueRange boundOperands) { |
| 228 | // Fully compose map and operands; canonicalize and simplify so that we |
| 229 | // transitively get to terminal symbols or loop IVs. |
| 230 | auto map = boundMap; |
| 231 | SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end()); |
| 232 | fullyComposeAffineMapAndOperands(map: &map, operands: &operands); |
| 233 | map = simplifyAffineMap(map); |
| 234 | canonicalizeMapAndOperands(map: &map, operands: &operands); |
| 235 | for (Value operand : operands) { |
| 236 | if (failed(Result: addInductionVarOrTerminalSymbol(val: operand))) |
| 237 | return failure(); |
| 238 | } |
| 239 | return addBound(type, pos, boundMap: computeAlignedMap(map, operands)); |
| 240 | } |
| 241 | |
| 242 | // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper |
| 243 | // bounds in 'ubMaps' to each value in `values' that appears in the constraint |
| 244 | // system. Note that both lower/upper bounds share the same operand list |
| 245 | // 'operands'. |
| 246 | // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and |
| 247 | // skips any null AffineMaps in 'lbMaps' or 'ubMaps'. |
| 248 | // Note that both lower/upper bounds use operands from 'operands'. |
| 249 | // Returns failure for unimplemented cases such as semi-affine expressions or |
| 250 | // expressions with mod/floordiv. |
| 251 | LogicalResult FlatAffineValueConstraints::addSliceBounds( |
| 252 | ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps, |
| 253 | ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) { |
| 254 | assert(values.size() == lbMaps.size()); |
| 255 | assert(lbMaps.size() == ubMaps.size()); |
| 256 | |
| 257 | for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { |
| 258 | unsigned pos; |
| 259 | if (!findVar(val: values[i], pos: &pos)) |
| 260 | continue; |
| 261 | |
| 262 | AffineMap lbMap = lbMaps[i]; |
| 263 | AffineMap ubMap = ubMaps[i]; |
| 264 | assert(!lbMap || lbMap.getNumInputs() == operands.size()); |
| 265 | assert(!ubMap || ubMap.getNumInputs() == operands.size()); |
| 266 | |
| 267 | // Check if this slice is just an equality along this dimension. |
| 268 | if (lbMap && ubMap && lbMap.getNumResults() == 1 && |
| 269 | ubMap.getNumResults() == 1 && |
| 270 | lbMap.getResult(idx: 0) + 1 == ubMap.getResult(idx: 0)) { |
| 271 | if (failed(Result: addBound(type: BoundType::EQ, pos, boundMap: lbMap, boundOperands: operands))) |
| 272 | return failure(); |
| 273 | continue; |
| 274 | } |
| 275 | |
| 276 | // If lower or upper bound maps are null or provide no results, it implies |
| 277 | // that the source loop was not at all sliced, and the entire loop will be a |
| 278 | // part of the slice. |
| 279 | if (lbMap && lbMap.getNumResults() != 0 && ubMap && |
| 280 | ubMap.getNumResults() != 0) { |
| 281 | if (failed(Result: addBound(type: BoundType::LB, pos, boundMap: lbMap, boundOperands: operands))) |
| 282 | return failure(); |
| 283 | if (failed(Result: addBound(type: BoundType::UB, pos, boundMap: ubMap, boundOperands: operands))) |
| 284 | return failure(); |
| 285 | } else { |
| 286 | auto loop = getForInductionVarOwner(values[i]); |
| 287 | if (failed(this->addAffineForOpDomain(forOp: loop))) |
| 288 | return failure(); |
| 289 | } |
| 290 | } |
| 291 | return success(); |
| 292 | } |
| 293 | |
| 294 | LogicalResult |
| 295 | FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) { |
| 296 | return composeMatchingMap( |
| 297 | other: computeAlignedMap(map: vMap->getAffineMap(), operands: vMap->getOperands())); |
| 298 | } |
| 299 | |
| 300 | // Turn a symbol into a dimension. |
| 301 | static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value value) { |
| 302 | unsigned pos; |
| 303 | if (cst->findVar(val: value, pos: &pos) && pos >= cst->getNumDimVars() && |
| 304 | pos < cst->getNumDimAndSymbolVars()) { |
| 305 | cst->swapVar(posA: pos, posB: cst->getNumDimVars()); |
| 306 | cst->setDimSymbolSeparation(cst->getNumSymbolVars() - 1); |
| 307 | } |
| 308 | } |
| 309 | |
| 310 | // Changes all symbol variables which are loop IVs to dim variables. |
| 311 | void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() { |
| 312 | // Gather all symbols which are loop IVs. |
| 313 | SmallVector<Value, 4> loopIVs; |
| 314 | for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++) { |
| 315 | if (hasValue(pos: i) && getForInductionVarOwner(getValue(pos: i))) |
| 316 | loopIVs.push_back(Elt: getValue(pos: i)); |
| 317 | } |
| 318 | // Turn each symbol in 'loopIVs' into a dim variable. |
| 319 | for (auto iv : loopIVs) { |
| 320 | turnSymbolIntoDim(cst: this, value: iv); |
| 321 | } |
| 322 | } |
| 323 | |
| 324 | void FlatAffineValueConstraints::getIneqAsAffineValueMap( |
| 325 | unsigned pos, unsigned ineqPos, AffineValueMap &vmap, |
| 326 | MLIRContext *context) const { |
| 327 | unsigned numDims = getNumDimVars(); |
| 328 | unsigned numSyms = getNumSymbolVars(); |
| 329 | |
| 330 | assert(pos < numDims && "invalid position" ); |
| 331 | assert(ineqPos < getNumInequalities() && "invalid inequality position" ); |
| 332 | |
| 333 | // Get expressions for local vars. |
| 334 | SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr()); |
| 335 | if (failed(Result: computeLocalVars(memo, context))) |
| 336 | assert(false && |
| 337 | "one or more local exprs do not have an explicit representation" ); |
| 338 | auto localExprs = ArrayRef<AffineExpr>(memo).take_back(N: getNumLocalVars()); |
| 339 | |
| 340 | // Compute the AffineExpr lower/upper bound for this inequality. |
| 341 | SmallVector<int64_t, 8> inequality = getInequality64(idx: ineqPos); |
| 342 | SmallVector<int64_t, 8> bound; |
| 343 | bound.reserve(N: getNumCols() - 1); |
| 344 | // Everything other than the coefficient at `pos`. |
| 345 | bound.append(in_start: inequality.begin(), in_end: inequality.begin() + pos); |
| 346 | bound.append(in_start: inequality.begin() + pos + 1, in_end: inequality.end()); |
| 347 | |
| 348 | if (inequality[pos] > 0) |
| 349 | // Lower bound. |
| 350 | std::transform(first: bound.begin(), last: bound.end(), result: bound.begin(), |
| 351 | unary_op: std::negate<int64_t>()); |
| 352 | else |
| 353 | // Upper bound (which is exclusive). |
| 354 | bound.back() += 1; |
| 355 | |
| 356 | // Convert to AffineExpr (tree) form. |
| 357 | auto boundExpr = getAffineExprFromFlatForm(flatExprs: bound, numDims: numDims - 1, numSymbols: numSyms, |
| 358 | localExprs, context); |
| 359 | |
| 360 | // Get the values to bind to this affine expr (all dims and symbols). |
| 361 | SmallVector<Value, 4> operands; |
| 362 | getValues(start: 0, end: pos, values: &operands); |
| 363 | SmallVector<Value, 4> trailingOperands; |
| 364 | getValues(start: pos + 1, end: getNumDimAndSymbolVars(), values: &trailingOperands); |
| 365 | operands.append(in_start: trailingOperands.begin(), in_end: trailingOperands.end()); |
| 366 | vmap.reset(map: AffineMap::get(dimCount: numDims - 1, symbolCount: numSyms, result: boundExpr), operands); |
| 367 | } |
| 368 | |
| 369 | FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const { |
| 370 | FlatAffineValueConstraints domain = *this; |
| 371 | // Convert all range variables to local variables. |
| 372 | domain.convertToLocal(kind: VarKind::SetDim, varStart: getNumDomainDims(), |
| 373 | varLimit: getNumDomainDims() + getNumRangeDims()); |
| 374 | return domain; |
| 375 | } |
| 376 | |
| 377 | FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const { |
| 378 | FlatAffineValueConstraints range = *this; |
| 379 | // Convert all domain variables to local variables. |
| 380 | range.convertToLocal(kind: VarKind::SetDim, varStart: 0, varLimit: getNumDomainDims()); |
| 381 | return range; |
| 382 | } |
| 383 | |
| 384 | void FlatAffineRelation::compose(const FlatAffineRelation &other) { |
| 385 | assert(getNumDomainDims() == other.getNumRangeDims() && |
| 386 | "Domain of this and range of other do not match" ); |
| 387 | assert(space.getDomainSpace().isAligned(other.getSpace().getRangeSpace()) && |
| 388 | "Values of domain of this and range of other do not match" ); |
| 389 | |
| 390 | FlatAffineRelation rel = other; |
| 391 | |
| 392 | // Convert `rel` from |
| 393 | // [otherDomain] -> [otherRange] |
| 394 | // to |
| 395 | // [otherDomain] -> [otherRange thisRange] |
| 396 | // and `this` from |
| 397 | // [thisDomain] -> [thisRange] |
| 398 | // to |
| 399 | // [otherDomain thisDomain] -> [thisRange]. |
| 400 | unsigned removeDims = rel.getNumRangeDims(); |
| 401 | insertDomainVar(pos: 0, num: rel.getNumDomainDims()); |
| 402 | rel.appendRangeVar(num: getNumRangeDims()); |
| 403 | |
| 404 | // Merge symbol and local variables. |
| 405 | mergeSymbolVars(other&: rel); |
| 406 | mergeLocalVars(other&: rel); |
| 407 | |
| 408 | // Convert `rel` from [otherDomain] -> [otherRange thisRange] to |
| 409 | // [otherDomain] -> [thisRange] by converting first otherRange range vars |
| 410 | // to local vars. |
| 411 | rel.convertToLocal(kind: VarKind::SetDim, varStart: rel.getNumDomainDims(), |
| 412 | varLimit: rel.getNumDomainDims() + removeDims); |
| 413 | // Convert `this` from [otherDomain thisDomain] -> [thisRange] to |
| 414 | // [otherDomain] -> [thisRange] by converting last thisDomain domain vars |
| 415 | // to local vars. |
| 416 | convertToLocal(kind: VarKind::SetDim, varStart: getNumDomainDims() - removeDims, |
| 417 | varLimit: getNumDomainDims()); |
| 418 | |
| 419 | auto thisMaybeValues = getMaybeValues(kind: VarKind::SetDim); |
| 420 | auto relMaybeValues = rel.getMaybeValues(kind: VarKind::SetDim); |
| 421 | |
| 422 | // Add and match domain of `rel` to domain of `this`. |
| 423 | for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i) |
| 424 | if (relMaybeValues[i].has_value()) |
| 425 | setValue(pos: i, val: *relMaybeValues[i]); |
| 426 | // Add and match range of `this` to range of `rel`. |
| 427 | for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) { |
| 428 | unsigned rangeIdx = rel.getNumDomainDims() + i; |
| 429 | if (thisMaybeValues[rangeIdx].has_value()) |
| 430 | rel.setValue(pos: rangeIdx, val: *thisMaybeValues[rangeIdx]); |
| 431 | } |
| 432 | |
| 433 | // Append `this` to `rel` and simplify constraints. |
| 434 | rel.append(other: *this); |
| 435 | rel.removeRedundantLocalVars(); |
| 436 | |
| 437 | *this = rel; |
| 438 | } |
| 439 | |
| 440 | void FlatAffineRelation::inverse() { |
| 441 | unsigned oldDomain = getNumDomainDims(); |
| 442 | unsigned oldRange = getNumRangeDims(); |
| 443 | // Add new range vars. |
| 444 | appendRangeVar(num: oldDomain); |
| 445 | // Swap new vars with domain. |
| 446 | for (unsigned i = 0; i < oldDomain; ++i) |
| 447 | swapVar(posA: i, posB: oldDomain + oldRange + i); |
| 448 | // Remove the swapped domain. |
| 449 | removeVarRange(varStart: 0, varLimit: oldDomain); |
| 450 | // Set domain and range as inverse. |
| 451 | numDomainDims = oldRange; |
| 452 | numRangeDims = oldDomain; |
| 453 | } |
| 454 | |
| 455 | void FlatAffineRelation::insertDomainVar(unsigned pos, unsigned num) { |
| 456 | assert(pos <= getNumDomainDims() && |
| 457 | "Var cannot be inserted at invalid position" ); |
| 458 | insertDimVar(pos, num); |
| 459 | numDomainDims += num; |
| 460 | } |
| 461 | |
| 462 | void FlatAffineRelation::insertRangeVar(unsigned pos, unsigned num) { |
| 463 | assert(pos <= getNumRangeDims() && |
| 464 | "Var cannot be inserted at invalid position" ); |
| 465 | insertDimVar(pos: getNumDomainDims() + pos, num); |
| 466 | numRangeDims += num; |
| 467 | } |
| 468 | |
| 469 | void FlatAffineRelation::appendDomainVar(unsigned num) { |
| 470 | insertDimVar(pos: getNumDomainDims(), num); |
| 471 | numDomainDims += num; |
| 472 | } |
| 473 | |
| 474 | void FlatAffineRelation::appendRangeVar(unsigned num) { |
| 475 | insertDimVar(pos: getNumDimVars(), num); |
| 476 | numRangeDims += num; |
| 477 | } |
| 478 | |
| 479 | void FlatAffineRelation::removeVarRange(VarKind kind, unsigned varStart, |
| 480 | unsigned varLimit) { |
| 481 | assert(varLimit <= getNumVarKind(kind)); |
| 482 | if (varStart >= varLimit) |
| 483 | return; |
| 484 | |
| 485 | FlatAffineValueConstraints::removeVarRange(kind, varStart, varLimit); |
| 486 | |
| 487 | // If kind is not SetDim, domain and range don't need to be updated. |
| 488 | if (kind != VarKind::SetDim) |
| 489 | return; |
| 490 | |
| 491 | // Compute number of domain and range variables to remove. This is done by |
| 492 | // intersecting the range of domain/range vars with range of vars to remove. |
| 493 | unsigned intersectDomainLHS = std::min(a: varLimit, b: getNumDomainDims()); |
| 494 | unsigned intersectDomainRHS = varStart; |
| 495 | unsigned intersectRangeLHS = std::min(a: varLimit, b: getNumDimVars()); |
| 496 | unsigned intersectRangeRHS = std::max(a: varStart, b: getNumDomainDims()); |
| 497 | |
| 498 | if (intersectDomainLHS > intersectDomainRHS) |
| 499 | numDomainDims -= intersectDomainLHS - intersectDomainRHS; |
| 500 | if (intersectRangeLHS > intersectRangeRHS) |
| 501 | numRangeDims -= intersectRangeLHS - intersectRangeRHS; |
| 502 | } |
| 503 | |
| 504 | LogicalResult mlir::affine::getRelationFromMap(AffineMap &map, |
| 505 | IntegerRelation &rel) { |
| 506 | // Get flattened affine expressions. |
| 507 | std::vector<SmallVector<int64_t, 8>> flatExprs; |
| 508 | FlatAffineValueConstraints localVarCst; |
| 509 | if (failed(Result: getFlattenedAffineExprs(map, flattenedExprs: &flatExprs, cst: &localVarCst))) |
| 510 | return failure(); |
| 511 | |
| 512 | const unsigned oldDimNum = localVarCst.getNumDimVars(); |
| 513 | const unsigned oldCols = localVarCst.getNumCols(); |
| 514 | const unsigned numRangeVars = map.getNumResults(); |
| 515 | const unsigned numDomainVars = map.getNumDims(); |
| 516 | |
| 517 | // Add range as the new expressions. |
| 518 | localVarCst.appendDimVar(num: numRangeVars); |
| 519 | |
| 520 | // Add identifiers to the local constraints as getFlattenedAffineExprs creates |
| 521 | // a FlatLinearConstraints with no identifiers. |
| 522 | for (unsigned i = 0, e = localVarCst.getNumDimAndSymbolVars(); i < e; ++i) |
| 523 | localVarCst.setValue(pos: i, val: Value()); |
| 524 | |
| 525 | // Add equalities between source and range. |
| 526 | SmallVector<int64_t, 8> eq(localVarCst.getNumCols()); |
| 527 | for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { |
| 528 | // Zero fill. |
| 529 | std::fill(first: eq.begin(), last: eq.end(), value: 0); |
| 530 | // Fill equality. |
| 531 | for (unsigned j = 0, f = oldDimNum; j < f; ++j) |
| 532 | eq[j] = flatExprs[i][j]; |
| 533 | for (unsigned j = oldDimNum, f = oldCols; j < f; ++j) |
| 534 | eq[j + numRangeVars] = flatExprs[i][j]; |
| 535 | // Set this dimension to -1 to equate lhs and rhs and add equality. |
| 536 | eq[numDomainVars + i] = -1; |
| 537 | localVarCst.addEquality(eq); |
| 538 | } |
| 539 | |
| 540 | rel = localVarCst; |
| 541 | return success(); |
| 542 | } |
| 543 | |
| 544 | LogicalResult mlir::affine::getRelationFromMap(const AffineValueMap &map, |
| 545 | IntegerRelation &rel) { |
| 546 | |
| 547 | AffineMap affineMap = map.getAffineMap(); |
| 548 | if (failed(Result: getRelationFromMap(map&: affineMap, rel))) |
| 549 | return failure(); |
| 550 | |
| 551 | // Set identifiers for domain and symbol variables. |
| 552 | for (unsigned i = 0, e = affineMap.getNumDims(); i < e; ++i) |
| 553 | rel.setId(kind: VarKind::SetDim, i, id: Identifier(map.getOperand(i))); |
| 554 | |
| 555 | const unsigned mapNumResults = affineMap.getNumResults(); |
| 556 | for (unsigned i = 0, e = rel.getNumSymbolVars(); i < e; ++i) |
| 557 | rel.setId( |
| 558 | kind: VarKind::Symbol, i, |
| 559 | id: Identifier(map.getOperand(i: rel.getNumDimVars() + i - mapNumResults))); |
| 560 | |
| 561 | return success(); |
| 562 | } |
| 563 | |