| 1 | //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements utilities for the Linalg dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 14 | |
| 15 | #include "mlir/Analysis/SliceAnalysis.h" |
| 16 | #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" |
| 17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 18 | #include "mlir/Dialect/Affine/LoopUtils.h" |
| 19 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 20 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 21 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 22 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 26 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
| 27 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 28 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 29 | #include "mlir/IR/AffineExpr.h" |
| 30 | #include "mlir/IR/AffineExprVisitor.h" |
| 31 | #include "mlir/IR/AffineMap.h" |
| 32 | #include "mlir/IR/Matchers.h" |
| 33 | #include "llvm/ADT/TypeSwitch.h" |
| 34 | #include "llvm/Support/Debug.h" |
| 35 | #include <optional> |
| 36 | |
| 37 | #define DEBUG_TYPE "linalg-utils" |
| 38 | |
| 39 | using namespace mlir; |
| 40 | using namespace presburger; |
| 41 | using namespace mlir::affine; |
| 42 | using namespace mlir::linalg; |
| 43 | using namespace mlir::scf; |
| 44 | |
namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  // A dimension counts as tiled as soon as its tile size is statically known
  // to be non-zero.
  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroInteger(v: tileSizes[expr.getPosition()]);
  }
  // Recurse into both operands; for multiplications, enforce (debug builds
  // only) that the RHS is a positive constant coefficient.
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr: expr.getLHS());
    visit(expr: expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient" );
  }
  // Accumulated result of the traversal: true iff any visited dim is tiled.
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace
| 73 | |
| 74 | static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) { |
| 75 | if (!expr) |
| 76 | return false; |
| 77 | TileCheck t(tileSizes); |
| 78 | t.visit(expr); |
| 79 | return t.isTiled; |
| 80 | } |
| 81 | |
| 82 | // Checks whether the `map varies with respect to a non-zero `tileSize`. |
| 83 | static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) { |
| 84 | if (!map) |
| 85 | return false; |
| 86 | for (unsigned r = 0; r < map.getNumResults(); ++r) |
| 87 | if (isTiled(expr: map.getResult(idx: r), tileSizes)) |
| 88 | return true; |
| 89 | return false; |
| 90 | } |
| 91 | |
/// Matches the region of `op` against a scalar binary operation applied to
/// the two block arguments and yielded as the result. Returns the matched
/// BinaryOpKind, or std::nullopt if the region does not have that shape.
/// Currently only `arith.addi` (IAdd) is recognized.
std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  // The region must consist of exactly one block.
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(C&: region))
    return std::nullopt;

  // The block must take exactly two scalar (signless int or float) arguments.
  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(i: 0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(i: 1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  // Besides the terminator, the block must contain a single operation.
  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(C: block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(v: block.getArgument(i: 0));
  auto b = m_Val(v: block.getArgument(i: 1));

  // Match `linalg.yield(arith.addi(%arg0, %arg1))` rooted at the terminator.
  auto addPattern = m_Op<linalg::YieldOp>(matchers: m_Op<arith::AddIOp>(matchers: a, matchers: b));
  if (addPattern.match(op: &ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}
| 118 | |
| 119 | /// Explicit instantiation of loop nest generator for different loop types. |
| 120 | template struct mlir::linalg::GenerateLoopNest<scf::ForOp>; |
| 121 | template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>; |
| 122 | template struct mlir::linalg::GenerateLoopNest<AffineForOp>; |
| 123 | |
| 124 | /// Given a list of subview ranges, extract individual values for lower, upper |
| 125 | /// bounds and steps and put them into the corresponding vectors. |
| 126 | static void unpackRanges(OpBuilder &builder, Location loc, |
| 127 | ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs, |
| 128 | SmallVectorImpl<Value> &ubs, |
| 129 | SmallVectorImpl<Value> &steps) { |
| 130 | for (Range range : ranges) { |
| 131 | lbs.emplace_back( |
| 132 | Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.offset)); |
| 133 | ubs.emplace_back(Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.size)); |
| 134 | steps.emplace_back( |
| 135 | Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.stride)); |
| 136 | } |
| 137 | } |
| 138 | |
| 139 | //===----------------------------------------------------------------------===// |
| 140 | // General utilities |
| 141 | //===----------------------------------------------------------------------===// |
| 142 | // |
/// The permutation can be obtained from two permutations:
///   a) Compute the permutation vector to move the last `numPackedDims` into
///      the `innerPosDims` of a shape of rank `rank`.
///   b) Compute the permutation vector to move outer dims if the
///      `outerPerm` parameter is not empty.
/// Apply (b) permutation on (a) permutation to get the final permutation.
/// `packingMetadata` is populated as a side effect and can be reused by
/// callers.
static SmallVector<int64_t>
computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
                      ArrayRef<int64_t> &outerPerm,
                      PackingMetadata &packingMetadata) {
  int64_t numPackedDims = innerDimsPos.size();
  // (a) The packed dims occupy the trailing positions [rank - n, rank).
  auto lastDims =
      llvm::to_vector(Range: llvm::seq<int64_t>(Begin: rank - numPackedDims, End: rank));
  packingMetadata = computePackingMetadata(packedRank: rank, innerDimPos: innerDimsPos);
  SmallVector<int64_t> innerPositionsPerm =
      computePermutationVector(permSize: rank, positions: lastDims, desiredPositions: packingMetadata.insertPositions);

  // (b) Optionally permute the outer dims according to `outerPerm`.
  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  if (!outerPerm.empty())
    applyPermutationToVector(inVec&: outerPos, permutation: outerPerm);
  SmallVector<int64_t> outerPositionPerm =
      computePermutationVector(permSize: rank, positions: packingMetadata.outerPositions, desiredPositions: outerPos);

  // Compose (b) on top of (a) to obtain the final permutation.
  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
  applyPermutationToVector(inVec&: packInverseDestPermutation, permutation: outerPositionPerm);
  return packInverseDestPermutation;
}
| 170 | |
| 171 | namespace mlir { |
| 172 | namespace linalg { |
| 173 | |
| 174 | SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) { |
| 175 | |
| 176 | PackingMetadata pMetadata; |
| 177 | int64_t packedRank = packOp.getDestType().getRank(); |
| 178 | ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos(); |
| 179 | ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm(); |
| 180 | SmallVector<int64_t> packInvDestPerm = |
| 181 | computePackUnPackPerm(rank: packedRank, innerDimsPos&: innerDimPos, outerPerm, packingMetadata&: pMetadata); |
| 182 | return packInvDestPerm; |
| 183 | } |
| 184 | |
| 185 | SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) { |
| 186 | PackingMetadata metadata; |
| 187 | return getUnPackInverseSrcPerm(unpackOp, metadata); |
| 188 | } |
| 189 | |
| 190 | SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp, |
| 191 | PackingMetadata &metadata) { |
| 192 | int64_t unpackRank = unpackOp.getSourceType().getRank(); |
| 193 | ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); |
| 194 | ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm(); |
| 195 | SmallVector<int64_t> unpackInvSrcPerm = |
| 196 | computePackUnPackPerm(rank: unpackRank, innerDimsPos&: innerDimPos, outerPerm, packingMetadata&: metadata); |
| 197 | return unpackInvSrcPerm; |
| 198 | } |
| 199 | |
| 200 | bool allIndexingsAreProjectedPermutation(LinalgOp op) { |
| 201 | return llvm::all_of(Range: op.getIndexingMapsArray(), P: [](AffineMap m) { |
| 202 | return m.isProjectedPermutation(/*allowZeroInResults=*/true); |
| 203 | }); |
| 204 | } |
| 205 | |
| 206 | bool hasOnlyScalarElementwiseOp(Region &r) { |
| 207 | if (!llvm::hasSingleElement(C&: r)) |
| 208 | return false; |
| 209 | for (Operation &op : r.front()) { |
| 210 | if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp, |
| 211 | linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(Val: op) || |
| 212 | OpTrait::hasElementwiseMappableTraits(op: &op)) || |
| 213 | llvm::any_of(Range: op.getResultTypes(), |
| 214 | P: [](Type type) { return !type.isIntOrIndexOrFloat(); })) |
| 215 | return false; |
| 216 | } |
| 217 | return true; |
| 218 | } |
| 219 | |
/// Returns true if `op` is an elementwise LinalgOp: all loops are parallel,
/// all indexing maps are projected permutations, all init (output) maps are
/// full permutations, and the body consists only of scalar elementwise ops.
bool isElementwise(LinalgOp op) {
  // Any non-parallel (e.g. reduction) loop disqualifies the op.
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  // Outputs must be indexed by full permutations (no broadcast/drop).
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(opOperand: &opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(r&: op->getRegion(index: 0));
}
| 234 | |
| 235 | bool isParallelIterator(utils::IteratorType iteratorType) { |
| 236 | return iteratorType == utils::IteratorType::parallel; |
| 237 | } |
| 238 | |
| 239 | bool isReductionIterator(utils::IteratorType iteratorType) { |
| 240 | return iteratorType == utils::IteratorType::reduction; |
| 241 | } |
| 242 | |
/// Returns a value of shape `type` obtained by high-padding `source` with
/// `pad`. Before creating a new tensor::PadOp, this walks the use-def chain
/// of `source` looking for an existing value that was already high-padded
/// with the same padding value and sizes; if found, that value is reused
/// instead of emitting a fresh pad.
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold,
                            ValueRange typeDynDims) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    // Step through the init operand matching this result to keep walking the
    // chain of destination-passing-style ops.
    OpResult opResult = cast<OpResult>(Val&: current);
    current = linalgOp.getDpsInitOperand(i: opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the matched
  // LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(Range: padOp.getMixedLowPad(), P: [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
  // of the slice padded by `padOp`.
  if (llvm::any_of(
          Range: llvm::zip(t: sliceOp.getMixedSizes(), u: padOpSliceOp.getMixedSizes()),
          P: [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(ofr1: std::get<0>(t&: it), ofr2: std::get<1>(t&: it));
          }))
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(value: padOpPad, pattern: m_Constant(bind_value: &padOpPadAttr)) ||
      !matchPattern(value: pad, pattern: m_Constant(bind_value: &padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}
| 310 | |
/// Returns a linalg.generic that copies memref `from` into memref `to`
/// elementwise, using identity indexing maps and all-parallel iterators.
/// Asserts (debug builds) that both memrefs have the same rank.
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(Val: to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank" );
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(numDims: memrefTypeTo.getRank(), context: b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      location: loc,
      /*inputs=*/args&: from,
      /*outputs=*/args&: to,
      /*indexingMaps=*/args: llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/args&: iteratorTypes,
      args: [](OpBuilder &b, Location loc, ValueRange args) {
        // The copy body simply yields the input element.
        b.create<linalg::YieldOp>(location: loc, args: args.front());
      });
}
| 333 | |
/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries" );
  // With tensor (non-buffer) semantics, the dps inits become loop-carried
  // iter_args of the generated nest.
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(C&: iterArgInitValues, R: linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(builder&: b, loc, ranges: loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      builder&: b, loc, lbs, ubs, steps, iterArgs: iterArgInitValues,
      bodyBuilder: [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match" );
        // When iter_args are present, substitute them for the dps inits so
        // the body operates on the loop-carried values.
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(in_start: iterArgs.begin(), in_end: iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  // Loops marked for cyclic distribution get mapped to processor ids.
  for (const auto &loop : llvm::enumerate(First&: loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(forOp: loop.value(), processorId: procInfo[loop.index()].procId,
                            numProcessors: procInfo[loop.index()].nprocs);
    }
  }
}
| 376 | |
/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  // Affine loops carry no iter_args; the op must have buffer semantics
  // (asserted below).
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(C&: iterArgInitValues, R: linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values" );
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(builder&: b, loc, ranges: loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(N: steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(ofr: v);
    assert(constVal.has_value() && "Affine loops require constant steps" );
    constantSteps.push_back(Elt: constVal.value());
  }

  affine::buildAffineLoopNest(builder&: b, loc, lbs, ubs, steps: constantSteps,
                              bodyBuilderFn: [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}
| 408 | |
/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(ctx: b.getContext(), exprs&: d0, exprs&: d1);
  AffineExpr s0 = getAffineSymbolExpr(position: 0, context: b.getContext());
  // Each processor starts at its own offset: lb' = lb + procId * step.
  lb =
      affine::makeComposedAffineApply(b, loc, e: d0 + d1 * s0, operands: {lb, procId, step});
  // And strides over all processors: step' = nprocs * step. `ub` is unchanged.
  step = affine::makeComposedAffineApply(b, loc, e: d0 * s0, operands: {nprocs, step});
}
| 420 | |
/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes.` Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorType: iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        builder&: b, loc, lbs: lbs.take_front(), ubs: ubs.take_front(), steps: steps.take_front(),
        bodyBuilder: [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(in_start: ivs.begin(), in_end: ivs.end());
          generateParallelLoopNest(
              b, loc, lbs: lbs.drop_front(), ubs: ubs.drop_front(), steps: steps.drop_front(),
              iteratorTypes: iteratorTypes.drop_front(),
              procInfo: procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  // Compute `numProcessed`, the length of the leading run handled in this
  // step: consecutive parallel iterators when no proc info is given,
  // otherwise consecutive loops sharing the front distribution method.
  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(Pred: isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while(Pred: [&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(N: numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        location: loc, args: lbs.take_front(n: numProcessed), args: ubs.take_front(n: numProcessed),
        args: steps.take_front(n: numProcessed),
        args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(in_start: localIvs.begin(), in_end: localIvs.end());
          generateParallelLoopNest(
              b&: nestedBuilder, loc: nestedLoc, lbs: lbs.drop_front(n: numProcessed),
              ubs: ubs.drop_front(n: numProcessed), steps: steps.drop_front(n: numProcessed),
              iteratorTypes: iteratorTypes.drop_front(N: numProcessed), procInfo: remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        location: loc, args: lbs.take_front(n: numProcessed), args: ubs.take_front(n: numProcessed),
        args: steps.take_front(n: numProcessed),
        args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(in_start: localIvs.begin(), in_end: localIvs.end());
          generateParallelLoopNest(
              b&: nestedBuilder, loc: nestedLoc, lbs: lbs.drop_front(n: numProcessed),
              ubs: ubs.drop_front(n: numProcessed), steps: steps.drop_front(n: numProcessed),
              iteratorTypes: iteratorTypes.drop_front(N: numProcessed), procInfo: remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lhs: lbs[0], rhs: ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(lhs: cond, rhs: ab.slt(lhs: lbs[i], rhs: ubs[i]));
    // The processed ivs are simply the (per-processor) lower bounds.
    ivStorage.append(in_start: lbs.begin(), in_end: std::next(x: lbs.begin(), n: numProcessed));
    b.create<scf::IfOp>(location: loc, args&: cond, args: [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs: lbs.drop_front(n: numProcessed),
                               ubs: ubs.drop_front(n: numProcessed),
                               steps: steps.drop_front(n: numProcessed),
                               iteratorTypes: iteratorTypes.drop_front(N: numProcessed),
                               procInfo: remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(location: loc, args: ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(in_start: lbs.begin(), in_end: std::next(x: lbs.begin(), n: numProcessed));
    generateParallelLoopNest(
        b, loc, lbs: lbs.drop_front(n: numProcessed), ubs: ubs.drop_front(n: numProcessed),
        steps: steps.drop_front(n: numProcessed), iteratorTypes: iteratorTypes.drop_front(N: numProcessed),
        procInfo: remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}
| 543 | |
/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  // Parallel loops carry no iter_args; the op must have buffer semantics
  // (asserted below).
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(C&: iterArgInitValues, R: linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values" );
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges" );
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present" );
  iteratorTypes = iteratorTypes.take_front(N: loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(N: numLoops);
  lbsStorage.reserve(N: numLoops);
  ubsStorage.reserve(N: numLoops);
  stepsStorage.reserve(N: numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(builder&: b, loc, ranges: loopRanges, lbs&: lbsStorage, ubs&: ubsStorage, steps&: stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(First&: procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, procId: it.value().procId, nprocs: it.value().nprocs, lb&: lbsStorage[it.index()],
          ub&: ubsStorage[it.index()], step&: stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      bodyBuilderFn: [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivStorage&: ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops" );
}
| 591 | |
/// Materializes the tile of `valueToTile` described by `sliceParams` as a
/// memref.subview (for memrefs) or tensor.extract_slice (for ranked
/// tensors); any other shaped type is unreachable.
static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
                                        Value valueToTile,
                                        const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(Val: valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case(caseFn: [&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            location: loc, args&: valueToTile, args: sliceParams.offsets,
                            args: sliceParams.sizes, args: sliceParams.strides);
                      })
                      .Case(caseFn: [&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            location: loc, args&: valueToTile, args: sliceParams.offsets,
                            args: sliceParams.sizes, args: sliceParams.strides);
                      })
                      .Default(defaultFn: [](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type" );
                      });
  return sliceOp;
}
| 612 | |
| 613 | Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, |
| 614 | ArrayRef<OpFoldResult> tileSizes, AffineMap map, |
| 615 | ArrayRef<OpFoldResult> lbs, |
| 616 | ArrayRef<OpFoldResult> ubs, |
| 617 | ArrayRef<OpFoldResult> subShapeSizes, |
| 618 | bool omitPartialTileCheck) { |
| 619 | SliceParameters sliceParams = |
| 620 | computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs, |
| 621 | ubs, subShapeSizes, omitPartialTileCheck); |
| 622 | return materializeTiledShape(builder, loc, valueToTile, sliceParams); |
| 623 | } |
| 624 | |
/// Compute the slice parameters (offsets, sizes, strides) describing the part
/// of `valueToTile` accessed by one tile of the computation.
///
/// `map` is the indexing map of the operand; `lbs`/`ubs` are the (tiled) loop
/// bounds and `subShapeSizes` holds per-loop tile extents in closed-interval
/// form (i.e. size - 1, see computeTileSizes). Dimensions whose submap does
/// not involve a tiled loop keep their full extent. When
/// `omitPartialTileCheck` is true, the caller guarantees no partial/boundary
/// tile occurs, so the affine.min in-bounds clamp is skipped.
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(Val: valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled" );
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(N: rank);
  sliceParams.sizes.reserve(N: rank);
  sliceParams.strides.reserve(N: rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    // Untiled dimension: take the whole extent with offset 0 and stride 1.
    if (!isTiled(map: map.getSubMap(resultPos: {r}), tileSizes)) {
      sliceParams.offsets.push_back(Elt: builder.getIndexAttr(value: 0));
      OpFoldResult dim = createFoldedDimOp(b&: builder, loc, val: valueToTile, dim: r);
      sliceParams.sizes.push_back(Elt: dim);
      sliceParams.strides.push_back(Elt: builder.getIndexAttr(value: 1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n" );
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n" );

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap(resultPos: {r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n" );
    IRRewriter rewriter(builder);
    // The offset of the slice is m(lbs) - m(0). Evaluating the submap at all
    // zeros isolates its constant term so only the lb-dependent part remains.
    SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(value: 0));
    SmallVector<Attribute> mAtZero;
    [[maybe_unused]] auto res = m.constantFold(operandConstants: zeros, results&: mAtZero);
    assert(succeeded(res) && "affine_map must be evaluatable (not symbols)" );
    int64_t mAtZeroInt =
        cast<IntegerAttr>(Val&: mAtZero[0]).getValue().getSExtValue();
    OpFoldResult offset = makeComposedFoldedAffineApply(
        b&: rewriter, loc, expr: m.getResult(idx: 0) - mAtZeroInt, operands: lbs);
    sliceParams.offsets.push_back(Elt: offset);

    // `subShapeSizes` carries closed-interval (size - 1) quantities, so the
    // composed result is also closed-interval.
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(b&: rewriter, loc, map: m, operands: subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(position: 0, context: builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(b&: rewriter, loc, expr: s0 + 1, operands: closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n" );
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n" );
    sliceParams.strides.push_back(Elt: builder.getIndexAttr(value: 1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n" );
      sliceParams.sizes.push_back(Elt: size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(ofr: size);
    auto hasTileSizeOne = sizeCst == 1;
    auto dividesEvenly = sizeCst && ShapedType::isStatic(dValue: shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n" );

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(ctx: context, exprs&: dim0, exprs&: dim1, exprs&: dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then plus one. This is important because for
      // convolution ops, we have its input window dimension's affine map of
      // the form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is stride. Directly using the dimension size of
      // the output/filter window dimensions would cause incorrect calculation.
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  exprsList: {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 exprsList: {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(Range: llvm::map_range(C&: ubs, F: [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(b&: rewriter, loc, map: minusOneMap,
                                                 operands: {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(b&: rewriter, loc, map: m, operands: maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(b&: rewriter, loc, map: plusOneMap, operands: {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             exprsList: {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(b&: rewriter, loc, map: minMap, operands: {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n" );
    sliceParams.sizes.push_back(Elt: size);
  }
  return sliceParams;
}
| 742 | |
| 743 | SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc, |
| 744 | ArrayRef<OpFoldResult> ivs, |
| 745 | ArrayRef<OpFoldResult> tileSizes) { |
| 746 | SmallVector<OpFoldResult> offsets; |
| 747 | for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { |
| 748 | LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n" ); |
| 749 | bool isTiled = !isZeroInteger(v: tileSizes[idx]); |
| 750 | offsets.push_back(Elt: isTiled ? ivs[idxIvs++] : b.getIndexAttr(value: 0)); |
| 751 | LLVM_DEBUG(llvm::dbgs() |
| 752 | << "computeTileOffsets: " << offsets.back() << "\n" ); |
| 753 | } |
| 754 | return offsets; |
| 755 | } |
| 756 | |
| 757 | SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc, |
| 758 | ArrayRef<OpFoldResult> tileSizes, |
| 759 | ArrayRef<OpFoldResult> sizeBounds) { |
| 760 | SmallVector<OpFoldResult> sizes; |
| 761 | for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) { |
| 762 | bool isTiled = !isZeroInteger(v: tileSizes[idx]); |
| 763 | // Before composing, we need to make range a closed interval. |
| 764 | OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx]; |
| 765 | AffineExpr d0 = getAffineDimExpr(position: 0, context: b.getContext()); |
| 766 | IRRewriter rewriter(b); |
| 767 | sizes.push_back(Elt: makeComposedFoldedAffineApply(b&: rewriter, loc, expr: d0 - 1, operands: size)); |
| 768 | LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n" ); |
| 769 | } |
| 770 | return sizes; |
| 771 | } |
| 772 | |
| 773 | SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) { |
| 774 | if (op.hasPureBufferSemantics()) |
| 775 | return {}; |
| 776 | return llvm::to_vector( |
| 777 | Range: llvm::map_range(C: op.getDpsInitsMutable(), F: [&](OpOperand &opOperand) { |
| 778 | return operands[opOperand.getOperandNumber()].getType(); |
| 779 | })); |
| 780 | } |
| 781 | |
| 782 | SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc, |
| 783 | LinalgOp op, ValueRange operands, |
| 784 | ValueRange results) { |
| 785 | if (op.hasPureBufferSemantics()) |
| 786 | return {}; |
| 787 | SmallVector<Value> tensorResults; |
| 788 | tensorResults.reserve(N: results.size()); |
| 789 | // Insert a insert_slice for each output tensor. |
| 790 | unsigned resultIdx = 0; |
| 791 | for (OpOperand &opOperand : op.getDpsInitsMutable()) { |
| 792 | // TODO: use an interface/adaptor to avoid leaking position in |
| 793 | // `tiledOperands`. |
| 794 | Value outputTensor = operands[opOperand.getOperandNumber()]; |
| 795 | if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) { |
| 796 | Value inserted = builder.create<tensor::InsertSliceOp>( |
| 797 | location: loc, args: sliceOp.getSource().getType(), args: results[resultIdx], |
| 798 | args: sliceOp.getSource(), args: sliceOp.getOffsets(), args: sliceOp.getSizes(), |
| 799 | args: sliceOp.getStrides(), args: sliceOp.getStaticOffsets(), |
| 800 | args: sliceOp.getStaticSizes(), args: sliceOp.getStaticStrides()); |
| 801 | tensorResults.push_back(Elt: inserted); |
| 802 | } else { |
| 803 | tensorResults.push_back(Elt: results[resultIdx]); |
| 804 | } |
| 805 | ++resultIdx; |
| 806 | } |
| 807 | return tensorResults; |
| 808 | } |
| 809 | |
/// Compute the slice parameters for each value in `valuesToTile`, positionally
/// matched with the operands of `linalgOp`. An entry is std::nullopt when the
/// corresponding value should be used as-is (it is not tiled and is not a
/// tensor init operand); callers such as makeTiledShapes materialize a
/// subview/extract_slice only for engaged entries.
SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  // Each non-zero tile size corresponds to exactly one loop induction
  // variable (see computeTileOffsets).
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroInteger(v); })) &&
         "expected as many ivs as non-zero sizes" );

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(b&: builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(b&: builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more value to tile than operands." );
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(N: valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(t: linalgOp->getOpOperands(), u&: valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(Val: operandType) &&
                                      linalgOp.isDpsInit(opOperand: &opOperand))) {
      allSliceParams.push_back(Elt: std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n" );
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n" );

    allSliceParams.push_back(Elt: computeSliceParameters(
        builder, loc, valueToTile: shapedOp, tileSizes, map, lbs, ubs: sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}
| 861 | |
| 862 | SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc, |
| 863 | LinalgOp linalgOp, ValueRange valuesToTile, |
| 864 | ArrayRef<OpFoldResult> ivs, |
| 865 | ArrayRef<OpFoldResult> tileSizes, |
| 866 | ArrayRef<OpFoldResult> sizeBounds, |
| 867 | bool omitPartialTileCheck) { |
| 868 | SmallVector<std::optional<SliceParameters>> allSliceParameter = |
| 869 | computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs, |
| 870 | tileSizes, sizeBounds, omitPartialTileCheck); |
| 871 | SmallVector<Value> tiledShapes; |
| 872 | for (auto item : llvm::zip(t&: valuesToTile, u&: allSliceParameter)) { |
| 873 | Value valueToTile = std::get<0>(t&: item); |
| 874 | std::optional<SliceParameters> sliceParams = std::get<1>(t&: item); |
| 875 | tiledShapes.push_back( |
| 876 | Elt: sliceParams.has_value() |
| 877 | ? materializeTiledShape(builder, loc, valueToTile, sliceParams: *sliceParams) |
| 878 | ->getResult(idx: 0) |
| 879 | : valueToTile); |
| 880 | } |
| 881 | return tiledShapes; |
| 882 | } |
| 883 | |
| 884 | void offsetIndices(OpBuilder &b, LinalgOp linalgOp, |
| 885 | ArrayRef<OpFoldResult> offsets) { |
| 886 | IRRewriter rewriter(b); |
| 887 | offsetIndices(b&: rewriter, linalgOp, offests: offsets); |
| 888 | } |
| 889 | |
| 890 | void offsetIndices(RewriterBase &b, LinalgOp linalgOp, |
| 891 | ArrayRef<OpFoldResult> offsets) { |
| 892 | if (!linalgOp.hasIndexSemantics()) |
| 893 | return; |
| 894 | |
| 895 | for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) { |
| 896 | if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()]) |
| 897 | continue; |
| 898 | OpBuilder::InsertionGuard guard(b); |
| 899 | b.setInsertionPointAfter(indexOp); |
| 900 | AffineExpr index, offset; |
| 901 | bindDims(ctx: b.getContext(), exprs&: index, exprs&: offset); |
| 902 | OpFoldResult applied = makeComposedFoldedAffineApply( |
| 903 | b, loc: indexOp.getLoc(), expr: index + offset, |
| 904 | operands: {getAsOpFoldResult(val: indexOp.getResult()), offsets[indexOp.getDim()]}); |
| 905 | Value materialized = |
| 906 | getValueOrCreateConstantIndexOp(b, loc: indexOp.getLoc(), ofr: applied); |
| 907 | b.replaceUsesWithIf(from: indexOp, to: materialized, functor: [&](OpOperand &use) { |
| 908 | return use.getOwner() != materialized.getDefiningOp(); |
| 909 | }); |
| 910 | } |
| 911 | } |
| 912 | |
| 913 | /// Get the reassociation maps to fold the result of a extract_slice (or source |
| 914 | /// of a insert_slice) operation with given offsets, and sizes to its |
| 915 | /// rank-reduced version. This is only done for the cases where the size is 1 |
| 916 | /// and offset is 0. Strictly speaking the offset 0 is not required in general, |
| 917 | /// but non-zero offsets are not handled by SPIR-V backend at this point (and |
| 918 | /// potentially cannot be handled). |
| 919 | std::optional<SmallVector<ReassociationIndices>> |
| 920 | getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) { |
| 921 | SmallVector<ReassociationIndices> reassociation; |
| 922 | ReassociationIndices curr; |
| 923 | for (const auto &it : llvm::enumerate(First&: mixedSizes)) { |
| 924 | auto dim = it.index(); |
| 925 | auto size = it.value(); |
| 926 | curr.push_back(Elt: dim); |
| 927 | auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: size); |
| 928 | if (attr && cast<IntegerAttr>(Val&: attr).getInt() == 1) |
| 929 | continue; |
| 930 | reassociation.emplace_back(Args: ReassociationIndices{}); |
| 931 | std::swap(LHS&: reassociation.back(), RHS&: curr); |
| 932 | } |
| 933 | // When the reassociations are not empty, then fold the remaining |
| 934 | // unit-dimensions into the last dimension. If the reassociations so far is |
| 935 | // empty, then leave it emtpy. This will fold everything to a rank-0 tensor. |
| 936 | if (!curr.empty() && !reassociation.empty()) |
| 937 | reassociation.back().append(in_start: curr.begin(), in_end: curr.end()); |
| 938 | return reassociation; |
| 939 | } |
| 940 | |
| 941 | } // namespace linalg |
| 942 | } // namespace mlir |
| 943 | |