//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//   (the expression does not reference d2, so a non-zero tile size at
//   position 2 has no effect).
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether `map` varies with respect to a non-zero `tileSizes` entry,
// i.e. whether any of its results is tiled.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
//
/// The permutation can be obtained from two permutations:
///   a) Compute the permutation vector to move the last `numPackedDims` into
///      the `innerDimsPos` positions of a shape of rank `rank`.
///   b) Compute the permutation vector to move outer dims if the
///      `outerPerm` parameter is not empty.
/// Apply (b) permutation on (a) permutation to get the final permutation.
static SmallVector<int64_t>
computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
                      ArrayRef<int64_t> &outerPerm,
                      PackingMetadata &packingMetadata) {
  int64_t numPackedDims = innerDimsPos.size();
  auto lastDims =
      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
  packingMetadata = computePackingMetadata(rank, innerDimsPos);
  SmallVector<int64_t> innerPositionsPerm =
      computePermutationVector(rank, lastDims, packingMetadata.insertPositions);

  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  if (!outerPerm.empty())
    applyPermutationToVector(outerPos, outerPerm);
  SmallVector<int64_t> outerPositionPerm =
      computePermutationVector(rank, packingMetadata.outerPositions, outerPos);

  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
  return packInverseDestPermutation;
}

namespace mlir {
namespace linalg {

SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {

  PackingMetadata pMetadata;
  int64_t packedRank = packOp.getDestType().getRank();
  ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
  SmallVector<int64_t> packInvDestPerm =
      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
  return packInvDestPerm;
}

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
  PackingMetadata metadata;
  return getUnPackInverseSrcPerm(unpackOp, metadata);
}

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                             PackingMetadata &metadata) {
  int64_t unpackRank = unpackOp.getSourceType().getRank();
  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
  SmallVector<int64_t> unpackInvSrcPerm =
      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
  return unpackInvSrcPerm;
}

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

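/// Create a tensor::PadOp that pads `source` to the shape of `type` with
/// `pad` as the padding value. As an optimization, if `source` can be traced
/// back, through an extract_slice and a chain of LinalgOps writing into their
/// DPS inits, to a high-padding tensor::PadOp with a matching padding value
/// and matching sizes, the already padded value is returned instead of
/// creating a new pad (the checks below spell out the exact conditions).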
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

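/// Build a rank-preserving element-wise copy from the `from` memref into the
/// `to` memref, expressed as a linalg.generic with identity indexing maps and
/// all-parallel iterators that simply yields each input element.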
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Map the scf.for loops that correspond to cyclically distributed
  // (parallel) dimensions to their processor IDs.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
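/// With cyclic distribution, this amounts to:
///   lb   <- lb + procId * step
///   step <- nprocs * step
/// so that each processor visits every `nprocs`-th iteration starting at its
/// own offset.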
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
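/// For example, with iterator types `[parallel, parallel, reduction,
/// parallel]` and no distribution, this produces an scf.parallel over the
/// first two dimensions, an scf.for over the reduction dimension, and a
/// nested scf.parallel over the last dimension.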
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

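/// Materialize the slice of `valueToTile` described by `sliceParams`:
/// a memref.subview is emitted for memref operands and a tensor.extract_slice
/// for ranked tensor operands.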
static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
                                        Value valueToTile,
                                        const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp;
}

Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                          ArrayRef<OpFoldResult> lbs,
                          ArrayRef<OpFoldResult> ubs,
                          ArrayRef<OpFoldResult> subShapeSizes,
                          bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    // The offset of the slice is m(lbs) - m(0).
    SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(0));
    SmallVector<Attribute> mAtZero;
    [[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
    assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
    int64_t mAtZeroInt =
        cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
    OpFoldResult offset = makeComposedFoldedAffineApply(
        rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
    sliceParams.offsets.push_back(offset);

    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // Resulting size needs to be made a half-open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(context, dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the affine map of an input window dimension has the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is the stride. Directly using the size of the
      // output/filter window dimensions would yield an incorrect result.
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

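/// Compute one offset per loop dimension: the current induction variable for
/// tiled dimensions (non-zero tile size) and a constant zero for untiled ones.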
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroInteger(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

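/// Compute the closed-interval subshape size for each dimension: the tile
/// size for tiled dimensions and the full size bound otherwise, minus one in
/// both cases, so the result can be composed with indexing maps before being
/// turned back into a half-open interval in computeSliceParameters.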
SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroInteger(tileSizes[idx]);
    // Before composing, we need to make range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      }));
}

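/// For each DPS init operand that was tiled through a tensor.extract_slice,
/// insert the corresponding tiled result back into the original tensor with a
/// matching tensor.insert_slice; results whose operand was not sliced are
/// returned unchanged.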
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroInteger(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
                  ->getResult(0)
            : valueToTile);
  }
  return tiledShapes;
}

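/// Add `offsets` to the results of the linalg.index ops inside `linalgOp`, so
/// that index computations inside a tile see global iteration indices rather
/// than tile-local ones. Dimensions without a corresponding offset are left
/// untouched.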
void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or
/// source of an insert_slice) operation with given offsets and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and the offset is 0. Strictly speaking the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at
/// this point (and potentially cannot be handled).
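/// For example, for static `mixedSizes` [1, 4, 1, 8] this returns the
/// reassociation [[0, 1], [2, 3]]: each unit dimension is grouped with the
/// next non-unit dimension so it can be folded away.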
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, then fold the remaining
  // unit-dimensions into the last dimension. If the reassociations so far are
  // empty, then leave them empty. This will fold everything to a rank-0
  // tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

} // namespace linalg
} // namespace mlir