1//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Linalg/Utils/Utils.h"
14
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Affine/LoopUtils.h"
19#include "mlir/Dialect/Arith/IR/Arith.h"
20#include "mlir/Dialect/Arith/Utils/Utils.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/Linalg/IR/Linalg.h"
23#include "mlir/Dialect/MemRef/IR/MemRef.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/Tensor/IR/Tensor.h"
26#include "mlir/Dialect/Tensor/Utils/Utils.h"
27#include "mlir/Dialect/Utils/IndexingUtils.h"
28#include "mlir/Dialect/Utils/StaticValueUtils.h"
29#include "mlir/IR/AffineExpr.h"
30#include "mlir/IR/AffineExprVisitor.h"
31#include "mlir/IR/AffineMap.h"
32#include "mlir/IR/Matchers.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
35#include <optional>
36
37#define DEBUG_TYPE "linalg-utils"
38
39using namespace mlir;
40using namespace presburger;
41using namespace mlir::affine;
42using namespace mlir::linalg;
43using namespace mlir::scf;
44
45namespace {
46
47// Helper visitor to determine whether an AffineExpr is tiled.
48// This is achieved by traversing every AffineDimExpr with position `pos` and
49// checking whether the corresponding `tileSizes[pos]` is non-zero.
50// This also enforces only positive coefficients occur in multiplications.
51//
52// Example:
53// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
54//
55struct TileCheck : public AffineExprVisitor<TileCheck> {
56 TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
57
58 void visitDimExpr(AffineDimExpr expr) {
59 isTiled |= !isZeroInteger(v: tileSizes[expr.getPosition()]);
60 }
61 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
62 visit(expr: expr.getLHS());
63 visit(expr: expr.getRHS());
64 if (expr.getKind() == mlir::AffineExprKind::Mul)
65 assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
66 "nonpositive multiplying coefficient");
67 }
68 bool isTiled = false;
69 ArrayRef<OpFoldResult> tileSizes;
70};
71
72} // namespace
73
74static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
75 if (!expr)
76 return false;
77 TileCheck t(tileSizes);
78 t.visit(expr);
79 return t.isTiled;
80}
81
82// Checks whether the `map varies with respect to a non-zero `tileSize`.
83static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
84 if (!map)
85 return false;
86 for (unsigned r = 0; r < map.getNumResults(); ++r)
87 if (isTiled(expr: map.getResult(idx: r), tileSizes))
88 return true;
89 return false;
90}
91
/// Matches a linalg.generic region that implements a scalar binary operation
/// on the two block arguments and returns the matched kind, or std::nullopt
/// if the region does not have the expected shape. Currently only
/// `linalg.yield(arith.addi(%arg0, %arg1))` is recognized (IAdd).
std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  // The region must consist of a single block.
  if (!llvm::hasSingleElement(C&: region))
    return std::nullopt;

  Block &block = region.front();
  // The block must take exactly two scalar (signless int or float) arguments.
  if (block.getNumArguments() != 2 ||
      !block.getArgument(i: 0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(i: 1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  // Exactly one operation besides the terminator is allowed.
  if (!llvm::hasSingleElement(C: block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(v: block.getArgument(i: 0));
  auto b = m_Val(v: block.getArgument(i: 1));

  // Match the terminator: `linalg.yield(arith.addi(a, b))`.
  auto addPattern = m_Op<linalg::YieldOp>(matchers: m_Op<arith::AddIOp>(matchers: a, matchers: b));
  if (addPattern.match(op: &ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}
118
/// Explicit instantiation of loop nest generator for different loop types.
/// Instantiated here for scf.for, scf.parallel and affine.for; the
/// specializations of `doit` are defined below in this file.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
123
124/// Given a list of subview ranges, extract individual values for lower, upper
125/// bounds and steps and put them into the corresponding vectors.
126static void unpackRanges(OpBuilder &builder, Location loc,
127 ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
128 SmallVectorImpl<Value> &ubs,
129 SmallVectorImpl<Value> &steps) {
130 for (Range range : ranges) {
131 lbs.emplace_back(
132 Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.offset));
133 ubs.emplace_back(Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.size));
134 steps.emplace_back(
135 Args: getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: range.stride));
136 }
137}
138
139//===----------------------------------------------------------------------===//
140// General utilities
141//===----------------------------------------------------------------------===//
142//
/// The permutation can be obtained from two permutations:
///   a) Compute the permutation vector to move the last `numPackedDims` into
///      the `innerPosDims` of a shape of rank `rank`.
///   b) Compute the permutation vector to move outer dims if the
///      `outerPerm` parameter is not empty.
/// Apply (b) permutation on (a) permutation to get the final permutation.
/// `packingMetadata` is an out-parameter filled with the metadata computed
/// for `innerDimsPos`.
static SmallVector<int64_t>
computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
                      ArrayRef<int64_t> &outerPerm,
                      PackingMetadata &packingMetadata) {
  int64_t numPackedDims = innerDimsPos.size();
  // (a) The packed dims are the trailing `numPackedDims` dims of the packed
  // shape; compute the permutation that moves them to their insert positions.
  auto lastDims =
      llvm::to_vector(Range: llvm::seq<int64_t>(Begin: rank - numPackedDims, End: rank));
  packingMetadata = computePackingMetadata(packedRank: rank, innerDimPos: innerDimsPos);
  SmallVector<int64_t> innerPositionsPerm =
      computePermutationVector(permSize: rank, positions: lastDims, desiredPositions: packingMetadata.insertPositions);

  // (b) Outer dims are only permuted when `outerPerm` is provided.
  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  if (!outerPerm.empty())
    applyPermutationToVector(inVec&: outerPos, permutation: outerPerm);
  SmallVector<int64_t> outerPositionPerm =
      computePermutationVector(permSize: rank, positions: packingMetadata.outerPositions, desiredPositions: outerPos);

  // Compose (b) over (a) to obtain the final permutation.
  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
  applyPermutationToVector(inVec&: packInverseDestPermutation, permutation: outerPositionPerm);
  return packInverseDestPermutation;
}
170
171namespace mlir {
172namespace linalg {
173
174SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
175
176 PackingMetadata pMetadata;
177 int64_t packedRank = packOp.getDestType().getRank();
178 ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
179 ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
180 SmallVector<int64_t> packInvDestPerm =
181 computePackUnPackPerm(rank: packedRank, innerDimsPos&: innerDimPos, outerPerm, packingMetadata&: pMetadata);
182 return packInvDestPerm;
183}
184
185SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
186 PackingMetadata metadata;
187 return getUnPackInverseSrcPerm(unpackOp, metadata);
188}
189
190SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
191 PackingMetadata &metadata) {
192 int64_t unpackRank = unpackOp.getSourceType().getRank();
193 ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
194 ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
195 SmallVector<int64_t> unpackInvSrcPerm =
196 computePackUnPackPerm(rank: unpackRank, innerDimsPos&: innerDimPos, outerPerm, packingMetadata&: metadata);
197 return unpackInvSrcPerm;
198}
199
200bool allIndexingsAreProjectedPermutation(LinalgOp op) {
201 return llvm::all_of(Range: op.getIndexingMapsArray(), P: [](AffineMap m) {
202 return m.isProjectedPermutation(/*allowZeroInResults=*/true);
203 });
204}
205
206bool hasOnlyScalarElementwiseOp(Region &r) {
207 if (!llvm::hasSingleElement(C&: r))
208 return false;
209 for (Operation &op : r.front()) {
210 if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
211 linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(Val: op) ||
212 OpTrait::hasElementwiseMappableTraits(op: &op)) ||
213 llvm::any_of(Range: op.getResultTypes(),
214 P: [](Type type) { return !type.isIntOrIndexOrFloat(); }))
215 return false;
216 }
217 return true;
218}
219
220bool isElementwise(LinalgOp op) {
221 if (op.getNumLoops() != op.getNumParallelLoops())
222 return false;
223
224 if (!allIndexingsAreProjectedPermutation(op))
225 return false;
226
227 // TODO: relax the restrictions on indexing map.
228 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
229 if (!op.getMatchingIndexingMap(opOperand: &opOperand).isPermutation())
230 return false;
231 }
232 return hasOnlyScalarElementwiseOp(r&: op->getRegion(index: 0));
233}
234
/// Returns true if `iteratorType` is a parallel iterator.
bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}
238
/// Returns true if `iteratorType` is a reduction iterator.
bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}
242
/// Returns a value of `type` padded high with `pad`, reusing an existing
/// padded value when possible.
///
/// Fast path: when `source` is an extract_slice whose source flows (through a
/// chain of LinalgOps, followed via their DPS init operands) from a
/// tensor.pad that (a) only pads high, (b) pads a non-rank-reducing slice
/// whose sizes match `source`'s slice, and (c) uses the same constant padding
/// value as `pad`, and the slice source already has type `type`, the
/// already-padded value is returned directly. In every other case a fresh
/// pad-high op is created via tensor::createPadHighOp.
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold,
                            ValueRange typeDynDims) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Search the `source` use-def chain for padded LinalgOps. Each LinalgOp is
  // traversed from its result to the matching init operand.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(Val&: current);
    current = linalgOp.getDpsInitOperand(i: opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the matched
  // LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if the LinalgOps are not high padded.
  // Note: a dynamic (non-constant) low pad also fails this check, since
  // getConstantIntValue then returns nullopt, which compares unequal to 0.
  if (llvm::any_of(Range: padOp.getMixedLowPad(), P: [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
  // of the slice padded by `padOp`.
  if (llvm::any_of(
          Range: llvm::zip(t: sliceOp.getMixedSizes(), u: padOpSliceOp.getMixedSizes()),
          P: [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(ofr1: std::get<0>(t&: it), ofr2: std::get<1>(t&: it));
          }))
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(value: padOpPad, pattern: m_Constant(bind_value: &padOpPadAttr)) ||
      !matchPattern(value: pad, pattern: m_Constant(bind_value: &padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(resType: type, source, pad, nofold, loc, builder&: b,
                                   dynOutDims: typeDynDims);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}
310
/// Returns a linalg.generic that copies memref `from` into memref `to`
/// elementwise, using identity indexing maps for both operands and
/// all-parallel iterators. Both memrefs must have the same rank (checked in
/// debug builds only).
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(Val: to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(numDims: memrefTypeTo.getRank(), context: b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  // The body simply yields the input element into the output.
  return b.create<linalg::GenericOp>(
      location: loc,
      /*inputs=*/args&: from,
      /*outputs=*/args&: to,
      /*indexingMaps=*/args: llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/args&: iteratorTypes,
      args: [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(location: loc, args: args.front());
      });
}
333
/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  // With tensor semantics, the DPS inits are threaded through the nest as
  // iter_args; with pure buffer semantics there are none.
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(C&: iterArgInitValues, R: linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(builder&: b, loc, ranges: loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      builder&: b, loc, lbs, ubs, steps, iterArgs: iterArgInitValues,
      bodyBuilder: [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        // When iter_args are present, replace the init operands with the
        // loop-carried values before invoking the body builder.
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(in_start: iterArgs.begin(), in_end: iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  // Only cyclically-distributed loops get mapped to processor ids.
  for (const auto &loop : llvm::enumerate(First&: loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(forOp: loop.value(), processorId: procInfo[loop.index()].procId,
                            numProcessors: procInfo[loop.index()].nprocs);
    }
  }
}
376
/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  // Affine loops cannot carry iter_args here, so tensor semantics (which
  // would require them) are rejected by the assert below.
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(C&: iterArgInitValues, R: linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(builder&: b, loc, ranges: loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(N: steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(ofr: v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(Elt: constVal.value());
  }

  affine::buildAffineLoopNest(builder&: b, loc, lbs, ubs, steps: constantSteps,
                              bodyBuilderFn: [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}
408
/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
/// Cyclic distribution rewrites the range as:
///   lb   <- lb + procId * step
///   step <- nprocs * step
/// so that processor `procId` visits every `nprocs`-th iteration starting at
/// its own offset; `ub` is left unchanged.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(ctx: b.getContext(), exprs&: d0, exprs&: d1);
  AffineExpr s0 = getAffineSymbolExpr(position: 0, context: b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, e: d0 + d1 * s0, operands: {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, e: d0 * s0, operands: {nprocs, step});
}
420
/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes.` Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorType: iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        builder&: b, loc, lbs: lbs.take_front(), ubs: ubs.take_front(), steps: steps.take_front(),
        bodyBuilder: [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(in_start: ivs.begin(), in_end: ivs.end());
          generateParallelLoopNest(
              b, loc, lbs: lbs.drop_front(), ubs: ubs.drop_front(), steps: steps.drop_front(),
              iteratorTypes: iteratorTypes.drop_front(),
              procInfo: procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  // Compute the longest prefix of loops to emit together: without proc info
  // this is the run of leading parallel iterators; with proc info it is the
  // run of loops that share the front loop's distribution method.
  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(Pred: isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while(Pred: [&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(N: numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        location: loc, args: lbs.take_front(n: numProcessed), args: ubs.take_front(n: numProcessed),
        args: steps.take_front(n: numProcessed),
        args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(in_start: localIvs.begin(), in_end: localIvs.end());
          generateParallelLoopNest(
              b&: nestedBuilder, loc: nestedLoc, lbs: lbs.drop_front(n: numProcessed),
              ubs: ubs.drop_front(n: numProcessed), steps: steps.drop_front(n: numProcessed),
              iteratorTypes: iteratorTypes.drop_front(N: numProcessed), procInfo: remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse. The bounds were already rewritten for
    // cyclic distribution by updateBoundsForCyclicDistribution (see the
    // scf::ParallelOp specialization of GenerateLoopNest::doit).
    b.create<scf::ParallelOp>(
        location: loc, args: lbs.take_front(n: numProcessed), args: ubs.take_front(n: numProcessed),
        args: steps.take_front(n: numProcessed),
        args: [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(in_start: localIvs.begin(), in_end: localIvs.end());
          generateParallelLoopNest(
              b&: nestedBuilder, loc: nestedLoc, lbs: lbs.drop_front(n: numProcessed),
              ubs: ubs.drop_front(n: numProcessed), steps: steps.drop_front(n: numProcessed),
              iteratorTypes: iteratorTypes.drop_front(N: numProcessed), procInfo: remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lhs: lbs[0], rhs: ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(lhs: cond, rhs: ab.slt(lhs: lbs[i], rhs: ubs[i]));
    // The processed induction variables collapse to their lower bounds.
    ivStorage.append(in_start: lbs.begin(), in_end: std::next(x: lbs.begin(), n: numProcessed));
    b.create<scf::IfOp>(location: loc, args&: cond, args: [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs: lbs.drop_front(n: numProcessed),
                               ubs: ubs.drop_front(n: numProcessed),
                               steps: steps.drop_front(n: numProcessed),
                               iteratorTypes: iteratorTypes.drop_front(N: numProcessed),
                               procInfo: remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(location: loc, args: ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(in_start: lbs.begin(), in_end: std::next(x: lbs.begin(), n: numProcessed));
    generateParallelLoopNest(
        b, loc, lbs: lbs.drop_front(n: numProcessed), ubs: ubs.drop_front(n: numProcessed),
        steps: steps.drop_front(n: numProcessed), iteratorTypes: iteratorTypes.drop_front(N: numProcessed),
        procInfo: remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}
543
/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  // scf.parallel cannot carry iter_args here, so tensor semantics (which
  // would require them) are rejected by the assert below.
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(C&: iterArgInitValues, R: linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(N: loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(N: numLoops);
  lbsStorage.reserve(N: numLoops);
  ubsStorage.reserve(N: numLoops);
  stepsStorage.reserve(N: numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(builder&: b, loc, ranges: loopRanges, lbs&: lbsStorage, ubs&: ubsStorage, steps&: stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(First&: procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, procId: it.value().procId, nprocs: it.value().nprocs, lb&: lbsStorage[it.index()],
          ub&: ubsStorage[it.index()], step&: stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      bodyBuilderFn: [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivStorage&: ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}
591
592static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
593 Value valueToTile,
594 const SliceParameters &sliceParams) {
595 auto shapedType = dyn_cast<ShapedType>(Val: valueToTile.getType());
596 auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
597 .Case(caseFn: [&](MemRefType) {
598 return builder.create<memref::SubViewOp>(
599 location: loc, args&: valueToTile, args: sliceParams.offsets,
600 args: sliceParams.sizes, args: sliceParams.strides);
601 })
602 .Case(caseFn: [&](RankedTensorType) {
603 return builder.create<tensor::ExtractSliceOp>(
604 location: loc, args&: valueToTile, args: sliceParams.offsets,
605 args: sliceParams.sizes, args: sliceParams.strides);
606 })
607 .Default(defaultFn: [](ShapedType) -> Operation * {
608 llvm_unreachable("Unexpected shaped type");
609 });
610 return sliceOp;
611}
612
613Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
614 ArrayRef<OpFoldResult> tileSizes, AffineMap map,
615 ArrayRef<OpFoldResult> lbs,
616 ArrayRef<OpFoldResult> ubs,
617 ArrayRef<OpFoldResult> subShapeSizes,
618 bool omitPartialTileCheck) {
619 SliceParameters sliceParams =
620 computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
621 ubs, subShapeSizes, omitPartialTileCheck);
622 return materializeTiledShape(builder, loc, valueToTile, sliceParams);
623}
624
/// Computes the offsets, sizes and strides (`SliceParameters`) of the tile of
/// `valueToTile` selected by the indexing map `map` over the loop bounds
/// `lbs`/`ubs` with (closed-interval) tile sizes `subShapeSizes`. Dimensions
/// whose submap does not vary with a non-zero tile size are taken whole.
/// When `omitPartialTileCheck` is false, tiled sizes are clamped with an
/// affine.min so partial/boundary tiles stay in bounds.
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(Val: valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(N: rank);
  sliceParams.sizes.reserve(N: rank);
  sliceParams.strides.reserve(N: rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map: map.getSubMap(resultPos: {r}), tileSizes)) {
      // Untiled dimension: take the whole extent with offset 0, stride 1.
      sliceParams.offsets.push_back(Elt: builder.getIndexAttr(value: 0));
      OpFoldResult dim = createFoldedDimOp(b&: builder, loc, val: valueToTile, dim: r);
      sliceParams.sizes.push_back(Elt: dim);
      sliceParams.strides.push_back(Elt: builder.getIndexAttr(value: 1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap(resultPos: {r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    // The offset of the slice is m(lbs) - m(0).
    SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(value: 0));
    SmallVector<Attribute> mAtZero;
    [[maybe_unused]] auto res = m.constantFold(operandConstants: zeros, results&: mAtZero);
    assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
    int64_t mAtZeroInt =
        cast<IntegerAttr>(Val&: mAtZero[0]).getValue().getSExtValue();
    OpFoldResult offset = makeComposedFoldedAffineApply(
        b&: rewriter, loc, expr: m.getResult(idx: 0) - mAtZeroInt, operands: lbs);
    sliceParams.offsets.push_back(Elt: offset);

    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(b&: rewriter, loc, map: m, operands: subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(position: 0, context: builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(b&: rewriter, loc, expr: s0 + 1, operands: closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(Elt: builder.getIndexAttr(value: 1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(Elt: size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(ofr: size);
    auto hasTileSizeOne = sizeCst == 1;
    auto dividesEvenly = sizeCst && ShapedType::isStatic(dValue: shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(ctx: context, exprs&: dim0, exprs&: dim1, exprs&: dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then plus one. This is important because for
      // convolution ops, we have its input window dimension's affine map of the
      // form `(d0 * s0 + d1)`, where `d0`/`d1 is an output/filter window
      // dimension and `s0` is stride. Directly use the dimension size of
      // output/filer window dimensions will cause incorrect calculation.
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  exprsList: {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 exprsList: {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(Range: llvm::map_range(C&: ubs, F: [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(b&: rewriter, loc, map: minusOneMap,
                                                 operands: {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(b&: rewriter, loc, map: m, operands: maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(b&: rewriter, loc, map: plusOneMap, operands: {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             exprsList: {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(b&: rewriter, loc, map: minMap, operands: {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(Elt: size);
  }
  return sliceParams;
}
742
743SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
744 ArrayRef<OpFoldResult> ivs,
745 ArrayRef<OpFoldResult> tileSizes) {
746 SmallVector<OpFoldResult> offsets;
747 for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
748 LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
749 bool isTiled = !isZeroInteger(v: tileSizes[idx]);
750 offsets.push_back(Elt: isTiled ? ivs[idxIvs++] : b.getIndexAttr(value: 0));
751 LLVM_DEBUG(llvm::dbgs()
752 << "computeTileOffsets: " << offsets.back() << "\n");
753 }
754 return offsets;
755}
756
757SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
758 ArrayRef<OpFoldResult> tileSizes,
759 ArrayRef<OpFoldResult> sizeBounds) {
760 SmallVector<OpFoldResult> sizes;
761 for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
762 bool isTiled = !isZeroInteger(v: tileSizes[idx]);
763 // Before composing, we need to make range a closed interval.
764 OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
765 AffineExpr d0 = getAffineDimExpr(position: 0, context: b.getContext());
766 IRRewriter rewriter(b);
767 sizes.push_back(Elt: makeComposedFoldedAffineApply(b&: rewriter, loc, expr: d0 - 1, operands: size));
768 LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
769 }
770 return sizes;
771}
772
773SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
774 if (op.hasPureBufferSemantics())
775 return {};
776 return llvm::to_vector(
777 Range: llvm::map_range(C: op.getDpsInitsMutable(), F: [&](OpOperand &opOperand) {
778 return operands[opOperand.getOperandNumber()].getType();
779 }));
780}
781
782SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
783 LinalgOp op, ValueRange operands,
784 ValueRange results) {
785 if (op.hasPureBufferSemantics())
786 return {};
787 SmallVector<Value> tensorResults;
788 tensorResults.reserve(N: results.size());
789 // Insert a insert_slice for each output tensor.
790 unsigned resultIdx = 0;
791 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
792 // TODO: use an interface/adaptor to avoid leaking position in
793 // `tiledOperands`.
794 Value outputTensor = operands[opOperand.getOperandNumber()];
795 if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
796 Value inserted = builder.create<tensor::InsertSliceOp>(
797 location: loc, args: sliceOp.getSource().getType(), args: results[resultIdx],
798 args: sliceOp.getSource(), args: sliceOp.getOffsets(), args: sliceOp.getSizes(),
799 args: sliceOp.getStrides(), args: sliceOp.getStaticOffsets(),
800 args: sliceOp.getStaticSizes(), args: sliceOp.getStaticStrides());
801 tensorResults.push_back(Elt: inserted);
802 } else {
803 tensorResults.push_back(Elt: results[resultIdx]);
804 }
805 ++resultIdx;
806 }
807 return tensorResults;
808}
809
810SmallVector<std::optional<SliceParameters>>
811computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
812 ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
813 ArrayRef<OpFoldResult> tileSizes,
814 ArrayRef<OpFoldResult> sizeBounds,
815 bool omitPartialTileCheck) {
816 assert(ivs.size() == static_cast<size_t>(llvm::count_if(
817 llvm::make_range(tileSizes.begin(), tileSizes.end()),
818 [](OpFoldResult v) { return !isZeroInteger(v); })) &&
819 "expected as many ivs as non-zero sizes");
820
821 // Construct (potentially temporary) mins and maxes on which to apply maps
822 // that define tile subshapes.
823 SmallVector<OpFoldResult> lbs =
824 computeTileOffsets(b&: builder, loc, ivs, tileSizes);
825 SmallVector<OpFoldResult> subShapeSizes =
826 computeTileSizes(b&: builder, loc, tileSizes, sizeBounds);
827
828 assert(static_cast<int64_t>(valuesToTile.size()) <=
829 linalgOp->getNumOperands() &&
830 "more value to tile than operands.");
831 SmallVector<std::optional<SliceParameters>> allSliceParams;
832 allSliceParams.reserve(N: valuesToTile.size());
833 for (auto [opOperand, val] :
834 llvm::zip(t: linalgOp->getOpOperands(), u&: valuesToTile)) {
835 Value shapedOp = val;
836 LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
837 AffineMap map = linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
838 // Use `opOperand` as is if it is not tiled and not an output tensor. Having
839 // an extract/insert slice pair for all output tensors simplifies follow up
840 // transformations such as padding and bufferization since the
841 // extract/insert slice pairs make the accessed iteration argument
842 // subdomains explicit.
843
844 Type operandType = opOperand.get().getType();
845 if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(Val: operandType) &&
846 linalgOp.isDpsInit(opOperand: &opOperand))) {
847 allSliceParams.push_back(Elt: std::nullopt);
848 LLVM_DEBUG(llvm::dbgs()
849 << ": not tiled: use shape: " << operandType << "\n");
850 continue;
851 }
852 LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
853
854 allSliceParams.push_back(Elt: computeSliceParameters(
855 builder, loc, valueToTile: shapedOp, tileSizes, map, lbs, ubs: sizeBounds, subShapeSizes,
856 omitPartialTileCheck));
857 }
858
859 return allSliceParams;
860}
861
862SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
863 LinalgOp linalgOp, ValueRange valuesToTile,
864 ArrayRef<OpFoldResult> ivs,
865 ArrayRef<OpFoldResult> tileSizes,
866 ArrayRef<OpFoldResult> sizeBounds,
867 bool omitPartialTileCheck) {
868 SmallVector<std::optional<SliceParameters>> allSliceParameter =
869 computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
870 tileSizes, sizeBounds, omitPartialTileCheck);
871 SmallVector<Value> tiledShapes;
872 for (auto item : llvm::zip(t&: valuesToTile, u&: allSliceParameter)) {
873 Value valueToTile = std::get<0>(t&: item);
874 std::optional<SliceParameters> sliceParams = std::get<1>(t&: item);
875 tiledShapes.push_back(
876 Elt: sliceParams.has_value()
877 ? materializeTiledShape(builder, loc, valueToTile, sliceParams: *sliceParams)
878 ->getResult(idx: 0)
879 : valueToTile);
880 }
881 return tiledShapes;
882}
883
884void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
885 ArrayRef<OpFoldResult> offsets) {
886 IRRewriter rewriter(b);
887 offsetIndices(b&: rewriter, linalgOp, offests: offsets);
888}
889
/// Add `offsets[dim]` to the result of every linalg.index op for dimension
/// `dim` inside `linalgOp`, turning tile-relative index values into indices
/// into the original iteration space. No-op for ops without index semantics.
void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    // Skip dimensions that have no (non-null) offset to apply.
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(ctx: b.getContext(), exprs&: index, exprs&: offset);
    // Compose `index + offset`, constant-folding where possible; otherwise
    // this creates an affine.apply right after the index op.
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, loc: indexOp.getLoc(), expr: index + offset,
        operands: {getAsOpFoldResult(val: indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, loc: indexOp.getLoc(), ofr: applied);
    // Replace all uses of the index op except the use held by the op that
    // materializes the offset value itself (it must keep consuming the
    // original index to avoid a self-referential cycle).
    b.replaceUsesWithIf(from: indexOp, to: materialized, functor: [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}
912
913/// Get the reassociation maps to fold the result of a extract_slice (or source
914/// of a insert_slice) operation with given offsets, and sizes to its
915/// rank-reduced version. This is only done for the cases where the size is 1
916/// and offset is 0. Strictly speaking the offset 0 is not required in general,
917/// but non-zero offsets are not handled by SPIR-V backend at this point (and
918/// potentially cannot be handled).
919std::optional<SmallVector<ReassociationIndices>>
920getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
921 SmallVector<ReassociationIndices> reassociation;
922 ReassociationIndices curr;
923 for (const auto &it : llvm::enumerate(First&: mixedSizes)) {
924 auto dim = it.index();
925 auto size = it.value();
926 curr.push_back(Elt: dim);
927 auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: size);
928 if (attr && cast<IntegerAttr>(Val&: attr).getInt() == 1)
929 continue;
930 reassociation.emplace_back(Args: ReassociationIndices{});
931 std::swap(LHS&: reassociation.back(), RHS&: curr);
932 }
933 // When the reassociations are not empty, then fold the remaining
934 // unit-dimensions into the last dimension. If the reassociations so far is
935 // empty, then leave it emtpy. This will fold everything to a rank-0 tensor.
936 if (!curr.empty() && !reassociation.empty())
937 reassociation.back().append(in_start: curr.begin(), in_end: curr.end());
938 return reassociation;
939}
940
941} // namespace linalg
942} // namespace mlir
943

source code of mlir/lib/Dialect/Linalg/Utils/Utils.cpp