//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

| 94 | |
| 95 | std::optional<RegionMatcher::BinaryOpKind> |
| 96 | RegionMatcher::matchAsScalarBinaryOp(GenericOp op) { |
| 97 | auto ®ion = op.getRegion(); |
| 98 | if (!llvm::hasSingleElement(region)) |
| 99 | return std::nullopt; |
| 100 | |
| 101 | Block &block = region.front(); |
| 102 | if (block.getNumArguments() != 2 || |
| 103 | !block.getArgument(i: 0).getType().isSignlessIntOrFloat() || |
| 104 | !block.getArgument(i: 1).getType().isSignlessIntOrFloat()) |
| 105 | return std::nullopt; |
| 106 | |
| 107 | auto &ops = block.getOperations(); |
| 108 | if (!llvm::hasSingleElement(C: block.without_terminator())) |
| 109 | return std::nullopt; |
| 110 | |
| 111 | using mlir::matchers::m_Val; |
| 112 | auto a = m_Val(v: block.getArgument(i: 0)); |
| 113 | auto b = m_Val(v: block.getArgument(i: 1)); |
| 114 | |
| 115 | auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b)); |
| 116 | if (addPattern.match(&ops.back())) |
| 117 | return BinaryOpKind::IAdd; |
| 118 | |
| 119 | return std::nullopt; |
| 120 | } |
| 121 | |
| 122 | /// Explicit instantiation of loop nest generator for different loop types. |
| 123 | template struct mlir::linalg::GenerateLoopNest<scf::ForOp>; |
| 124 | template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>; |
| 125 | template struct mlir::linalg::GenerateLoopNest<AffineForOp>; |
| 126 | |
| 127 | /// Given a list of subview ranges, extract individual values for lower, upper |
| 128 | /// bounds and steps and put them into the corresponding vectors. |
| 129 | static void unpackRanges(OpBuilder &builder, Location loc, |
| 130 | ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs, |
| 131 | SmallVectorImpl<Value> &ubs, |
| 132 | SmallVectorImpl<Value> &steps) { |
| 133 | for (Range range : ranges) { |
| 134 | lbs.emplace_back( |
| 135 | Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.offset)); |
| 136 | ubs.emplace_back(Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.size)); |
| 137 | steps.emplace_back( |
| 138 | Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.stride)); |
| 139 | } |
| 140 | } |
| 141 | |
//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
//
/// The permutation can be obtained from two permutations:
///   a) Compute the permutation vector to move the last `numPackedDims` into
///      the `innerPosDims` of a shape of rank `rank`.
///   b) Compute the permutation vector to move outer dims if the
///      `outerPerm` parameter is not empty.
/// Apply (b) permutation on (a) permutation to get the final permutation.
static SmallVector<int64_t>
computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
                      ArrayRef<int64_t> &outerPerm,
                      PackingMetadata &packingMetadata) {
  int64_t numPackedDims = innerDimsPos.size();
  auto lastDims =
      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
  packingMetadata = computePackingMetadata(rank, innerDimsPos);
  SmallVector<int64_t> innerPositionsPerm =
      computePermutationVector(rank, lastDims, packingMetadata.insertPositions);

  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  if (!outerPerm.empty())
    applyPermutationToVector(outerPos, outerPerm);
  SmallVector<int64_t> outerPositionPerm =
      computePermutationVector(rank, packingMetadata.outerPositions, outerPos);

  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
  return packInverseDestPermutation;
}

namespace mlir {
namespace linalg {

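/// Compute the pack-inverse destination permutation for `packOp` from its
/// destination rank, `inner_dims_pos`, and `outer_dims_perm` (see
/// `computePackUnPackPerm` above).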
SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {

  PackingMetadata pMetadata;
  int64_t packedRank = packOp.getDestType().getRank();
  ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
  SmallVector<int64_t> packInvDestPerm =
      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
  return packInvDestPerm;
}

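/// Compute the unpack-inverse source permutation for `unpackOp` from its
/// source rank, `inner_dims_pos`, and `outer_dims_perm` (see
/// `computePackUnPackPerm`). The first overload discards the intermediate
/// `PackingMetadata`.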
SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
  PackingMetadata metadata;
  return getUnPackInverseSrcPerm(unpackOp, metadata);
}

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                             PackingMetadata &metadata) {
  int64_t unpackRank = unpackOp.getSourceType().getRank();
  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
  SmallVector<int64_t> unpackInvSrcPerm =
      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
  return unpackInvSrcPerm;
}

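/// Return true if all indexing maps of `op` are projected permutations
/// (zero results are allowed).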
bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

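/// Return true if `r` has a single block whose operations are either in a
/// small allowlist (constants, tensor.extract, linalg.yield, linalg.index,
/// affine.apply) or elementwise mappable, and all produce scalar (int, index,
/// or float) results.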
bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

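/// Detect whether `op` is elementwise: all loops are parallel, all indexing
/// maps are projected permutations (with output maps being full permutations),
/// and the body contains only scalar elementwise operations.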
bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

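/// Create a tensor::PadOp that high-pads `source` to `type` with `pad`. If
/// `source` is the result of a chain of padded LinalgOps that already produce
/// the requested padded shape with the same padding value, reuse that padded
/// result instead of emitting a new pad.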
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

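/// Return a linalg.generic that copies memref `from` into memref `to`
/// element-wise, using identity indexing maps and parallel iterators. Both
/// memrefs must have the same rank.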
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
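/// For a cyclic distribution this computes `lb = lb + procId * step` and
/// `step = nprocs * step`; `ub` is left unchanged.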
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes.` Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

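/// Materialize the slice described by `sliceParams` on `valueToTile`: a
/// memref.subview for memref operands and a tensor.extract_slice for ranked
/// tensor operands.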
static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
                                        Value valueToTile,
                                        const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp;
}

Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                          ArrayRef<OpFoldResult> lbs,
                          ArrayRef<OpFoldResult> ubs,
                          ArrayRef<OpFoldResult> subShapeSizes,
                          bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    // The offset of the slice is m(lbs) - m(0).
    SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(0));
    SmallVector<Attribute> mAtZero;
    [[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
    assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
    int64_t mAtZeroInt =
        cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
    OpFoldResult offset = makeComposedFoldedAffineApply(
        rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
    sliceParams.offsets.push_back(offset);

    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(context, dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the input window dimension's affine map is of the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is the stride. Directly using the dimension size of
      // the output/filter window dimensions would yield an incorrect result.
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

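/// Compute per-loop tile offsets: each tiled loop (non-zero tile size) takes
/// the next induction variable from `ivs`; untiled loops get offset 0.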
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroInteger(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

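/// Compute per-loop subshape sizes as closed intervals (`size - 1`), taking
/// the tile size for tiled loops and the size bound otherwise; these are later
/// composed with indexing maps in `computeSliceParameters`.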
SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroInteger(tileSizes[idx]);
    // Before composing, we need to make range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

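/// Return the types of the tensor init operands of `op`, looked up in
/// `operands`; returns an empty vector for ops with pure buffer semantics.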
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      }));
}

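/// Re-insert the tiled `results` into the original destination tensors: for
/// each init operand produced by a tensor.extract_slice, create the matching
/// tensor.insert_slice; otherwise forward the result unchanged. Returns an
/// empty vector for ops with pure buffer semantics.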
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

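/// Compute the slice parameters for every value in `valuesToTile`. Entries are
/// std::nullopt for operands that are neither tiled nor ranked-tensor inits;
/// init tensors always get an explicit extract/insert slice pair.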
SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroInteger(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more value to tile than operands.");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

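/// Materialize tiled views for all values in `valuesToTile`: compute the slice
/// parameters and create the corresponding subview/extract_slice ops, passing
/// untiled values through unchanged.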
SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
                  ->getResult(0)
            : valueToTile);
  }
  return tiledShapes;
}

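/// Add the given `offsets` to the results of all linalg.index ops in the body
/// of `linalgOp`, so that indices computed inside a tile refer to the original
/// iteration space.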
void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with given offsets, and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and offset is 0. Strictly speaking the offset 0 is not required in general,
/// but non-zero offsets are not handled by SPIR-V backend at this point (and
/// potentially cannot be handled).
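/// For example, mixedSizes [1, 4, 1, 8] yields the reassociation
/// [[0, 1], [2, 3]].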
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, then fold the remaining
  // unit-dimensions into the last dimension. If the reassociation so far is
  // empty, then leave it empty. This will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

} // namespace linalg
} // namespace mlir