//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGTILINGPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ArrayRef<OpFoldResult> allShapeSizes,
                                  ArrayRef<OpFoldResult> allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  SmallVector<OpFoldResult> shapeSizes =
      makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes);
  SmallVector<OpFoldResult> tileSizes(allTileSizes);

  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
        static_cast<int64_t>(0)) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }
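  // For example, with tile sizes [5, 0, 3], loop 1 is not tiled: it is erased
  // here and the map records {0 -> 0, 2 -> 1}, so only two ranges are built
  // below.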

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

void mlir::linalg::transformIndexOps(
    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
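  // Offset the results of any linalg.index ops in the body by the
  // corresponding tile IVs so the indices keep referring to the original,
  // untiled iteration space.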
  offsetIndices(b, op, getAsOpFoldResult(allIvs));
}

/// Asserts that the given index-typed value is strictly positive. If the value
/// is an attribute, asserts at compile time, otherwise emits an assertion
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                         OpFoldResult value) {
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
    assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
           "expected strictly positive tile size and divisor");
    return;
  }

  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
                                            cast<Value>(value), zero);
  b.create<cf::AssertOp>(
      condition,
      b.getStringAttr("expected strictly positive tile size and divisor"));
}

FailureOr<StaticContinuousTileSizeSpecification>
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
                                               unsigned dimension,
                                               unsigned targetSize) {

  assert(!op.hasDynamicShape() &&
         "cannot compute static multi-tile sizes for an op with dynamic shape");
  assert(targetSize > 0 && "target size must be positive");
  assert(dimension < op.getNumLoops() && "dimension overflow");
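  // For example, for loopRange = 100 and targetSize = 32 this computes tile
  // sizes [32, 4] with trip counts [3, 1]: 3 * 32 + 1 * 4 covers all 100
  // iterations.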

  StaticContinuousTileSizeSpecification spec;
  int64_t loopRange = op.getStaticLoopRanges()[dimension];
  int64_t tripCount = loopRange / targetSize;

  unsigned tileSize = targetSize;

  spec.tileSizes.push_back(tileSize);
  spec.tripCounts.push_back(tripCount);

  int64_t remainderChunk = loopRange % targetSize;

  while (tileSize > 1 && remainderChunk != 0) {

    uint64_t maxPower = llvm::bit_floor(tileSize);
    tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;

    tripCount = remainderChunk / tileSize;

    if (tripCount > 0) {
      spec.tileSizes.push_back(tileSize);
      spec.tripCounts.push_back(tripCount);
    }

    remainderChunk = remainderChunk % tileSize;
  }

  auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
                            SmallVector<int64_t> tripCounts,
                            int64_t range) -> bool {
    int64_t computedRange = 0;
    for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
      computedRange += tileSize * tripCount;
    return range == computedRange;
  };

  if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
    return failure();

  return spec;
}

FailureOr<ContinuousTileSizeSpecification>
mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
                                         unsigned dimension,
                                         OpFoldResult targetSize,
                                         bool emitAssertions) {

  SmallVector<Range> loopRanges = op.getIterationDomain(builder);
  unsigned numLoops = loopRanges.size();

  // Bail out on dimension overflow.
  if (dimension >= numLoops)
    return failure();

  // The code below works only on values.
  Location loc = op->getLoc();
  ImplicitLocOpBuilder b(loc, builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
  }
  Value targetSizeValue =
      getValueOrCreateConstantIndexOp(builder, loc, targetSize);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  Value loopRange =
      getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
  ContinuousTileSizeSpecification spec;

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
  };

  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});

  OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
      b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});

  // emitAssertions above already asserts that targetSize is
  // a positive integer.
  uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);

  assert(tileSizeInt > 0 && "target size must be positive");

  spec.tileSizes.push_back(targetSizeValue);
  spec.tripCounts.push_back(tripCountValue);

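  // Repeatedly halve the tile size towards the next lower power of two and
  // record a (tileSize, tripCount) pair whenever the remaining iterations
  // admit at least one full tile, mirroring the static computation above.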
  while (tileSizeInt > 1) {
    uint64_t maxPower = llvm::bit_floor(tileSizeInt);
    tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
    auto constStepOp =
        builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
    tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});

    tripCountSize = affine::makeComposedFoldedAffineApply(
        b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});

    // Optimization if tripCount can be determined to be zero.
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
      auto intAttr = cast<IntegerAttr>(attr);
      bool isTripCountZero = intAttr.getValue().isZero();

      if (!isTripCountZero) {
        spec.tileSizes.push_back(constStepOp);
        spec.tripCounts.push_back(tripCountValue);
      }
    } else {
      spec.tileSizes.push_back(constStepOp);
      spec.tripCounts.push_back(tripCountValue);
    }

    remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
  }

  return spec;
}

FailureOr<StaticMultiSizeSpecification>
mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
                                          int64_t targetSize, int64_t divisor) {
  assert(!op.hasDynamicShape() &&
         "cannot compute static multi-tile sizes for an op with dynamic shape");
  assert(targetSize > 0 && "target size must be positive");
  assert(divisor > 0 && "divisor must be positive");
  assert(dimension < op.getNumLoops() && "dimension overflow");
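  // For example, a loop range of 15 with targetSize 8 and divisor 1 yields
  // lowTileSize 7 and highTileSize 8, each with trip count 1: 7 + 8 == 15.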

  StaticMultiSizeSpecification spec;
  int64_t tripCount = op.getStaticLoopRanges()[dimension];
  int64_t a = tripCount / divisor;
  int64_t t = (targetSize + divisor - 1) / divisor;
  int64_t totalTripCount = (a + t - 1) / t;
  spec.lowTileSize = (a / totalTripCount) * divisor;
  spec.highTileSize = spec.lowTileSize + divisor;
  spec.highTripCount = a % totalTripCount;
  spec.lowTripCount = totalTripCount - spec.highTripCount;
  if (spec.lowTileSize * spec.lowTripCount +
          spec.highTileSize * spec.highTripCount !=
      tripCount) {
    return failure();
  }
  return spec;
}

FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  Location loc = op.getLoc();
  ImplicitLocOpBuilder b(loc, builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue =
      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
  Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  SmallVector<OpFoldResult> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<OpFoldResult> loopRanges =
      makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
                                               allShapes);
  Value tripCount =
      getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
  };
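  // Let a = tripCount floordiv divisor and t = ceil(targetSize / divisor).
  // The domain is covered by d = ceil(a / t) tiles: u tiles of the low size
  // s = (a floordiv d) * divisor and v = a mod d tiles of size s + divisor.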
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  Value v = apply(s0 % s1, {a, d});
  Value u = apply(s0 - s1, {d, v});

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for an iteration dimension of size 15 and a target size of 8,
  // it is impossible to find two tile sizes both divisible by 8 that fully
  // cover the original space dimension.
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}

/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
/// less than `iterationSize`.
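/// For example, with a tile size of 4 and 8 threads the last tile starts at
/// offset 28, so for an iteration size of 32 the check can be omitted.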
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
                                           OpFoldResult numThreads,
                                           OpFoldResult iterationSize) {
  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
    return false;
  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
}

/// Build an `affine_max` of all the `vals`.
static OpFoldResult buildMax(OpBuilder &b, Location loc,
                             ArrayRef<OpFoldResult> vals) {
  return affine::makeComposedFoldedAffineMax(
      b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

/// Build an `affine_min` of all the `vals`.
static OpFoldResult buildMin(OpBuilder &b, Location loc,
                             ArrayRef<OpFoldResult> vals) {
  return affine::makeComposedFoldedAffineMin(
      b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
/// number of threads.
static void calculateTileOffsetsAndSizes(
    RewriterBase &b, Location loc, scf::ForallOp forallOp,
    ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
    bool omitTileOffsetBoundsCheck,
    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
    SmallVector<OpFoldResult> &tiledOffsets,
    SmallVector<OpFoldResult> &tiledSizes) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToStart(forallOp.getBody(0));

  SmallVector<Value> threadIds = forallOp.getInductionVars();
  SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
      numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
  int64_t nLoops = loopRanges.size();
  tiledOffsets.reserve(nLoops);
  tiledSizes.reserve(nLoops);
  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
    bool overflow = loopIdx >= numThreads.size();
    bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);
    // Degenerate case: take the whole domain.
    if (overflow || isZero) {
      tiledOffsets.push_back(loopRanges[loopIdx].offset);
      tiledSizes.push_back(loopRanges[loopIdx].size);
      continue;
    }

    // Tiled case: compute the offset and size.
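    // For example, a loop of size 100 over 8 threads gets a per-thread tile
    // size of ceil(100 / 8) = 13; thread k starts at offset 13 * k and the
    // last thread's size is clamped to 100 - 13 * 7 = 9.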
    AffineExpr i, j, m, n, o;
    bindDims(b.getContext(), i, j);
    bindSymbols(b.getContext(), m, n, o);
    OpFoldResult size = loopRanges[loopIdx].size;
    OpFoldResult offset = loopRanges[loopIdx].offset;
    OpFoldResult threadId = threadIds[threadIdIdx];
    // Symbolic fixed max size per thread.
    // TODO: floor + 0/1 depending on case for better load-balancing.
    OpFoldResult tileSizePerThread =
        nominalTileSizes.has_value()
            ? (*nominalTileSizes)[loopIdx]
            : makeComposedFoldedAffineApply(
                  b, loc, m.ceilDiv(n),
                  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});

    // Dynamic offset shifted by threadId * maxSizePerThread.
    OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
        b, loc, i + j * m, {offset, threadId, tileSizePerThread});
    // Dynamic upper-bound depending on the threadId.
    OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
        b, loc, i + j * m - n,
        {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
    if (!isZeroInteger(residualTileSize)) {
      OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
          b, loc, -i + m, {offsetPerThread, size});
      tileSizePerThread =
          buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
    }

    tiledOffsets.push_back(offsetPerThread);
    // TODO: if tileSizePerThread <= 0 early exit.
    if (!omitTileOffsetBoundsCheck &&
        !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
                                        nonZeroNumThreads[threadIdIdx], size))
      tileSizePerThread =
          buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});

    tiledSizes.push_back(tileSizePerThread);
    ++threadIdIdx;
  }
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);

  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big; only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

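  // If all tile sizes are zero, tiling is a no-op: just clone the op and
  // forward its results unchanged.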
  if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
      })) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  SmallVector<OpFoldResult> allShapeSizes =
      op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<utils::IteratorType, 4> iteratorTypes;
  for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // Interchange vector is guaranteed to be a permutation,
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // Handle distribution. Create a vector of the same size as the number of
  // loops to be tiled.
  SmallVector<linalg::ProcInfo> procInfo;
  if (options.distribution) {
    procInfo.resize(
        iteratorTypes.size(),
        linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
    // Collect the loop ranges of the tiled loops that are parallel.
    SmallVector<Range> parallelLoopRanges;
    for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
      if (!isParallelIterator(iteratorType.value()))
        break;
      parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
    }
    auto returnedProcInfo =
        options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
    unsigned procIdIdx = 0;
    // Update the distribution information for the loops.
    for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
      if (!isParallelIterator(iteratorType.value()))
        break;
      procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
    }
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty()) {
      for (AffineExpr result : invPermutationMap.getResults())
        interchangedIvs.push_back(
            ivs[cast<AffineDimExpr>(result).getPosition()]);
    } else {
      interchangedIvs.assign(ivs.begin(), ivs.end());
    }

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op->getNumOperands()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    SmallVector<OpFoldResult> sizeBounds =
        makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
                                                 allShapeSizes);
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
        sizeBounds,
        /*omitPartialTileCheck=*/false);

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = clone(b, op, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, procInfo);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (isa<BlockArgument>(iv)) {
      loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

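/// Tile a single-result reduction into partial results computed in parallel
/// inside an scf.forall (one partial result slot per thread along the
/// reduction dimension) and merge the partial results afterwards.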
FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
    RewriterBase &b, PartialReductionOpInterface op,
    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
    std::optional<ArrayAttr> mapping) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(b);

  // Ops implementing PartialReductionOpInterface are expected to implement
  // TilingInterface.
  // TODO: proper core mechanism to tie interfaces together.
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());

  // Ops implementing PartialReductionOpInterface are not necessarily expected
  // to implement DestinationStyleOpInterface. This cast is unsafe atm.
  // TODO: proper core mechanism to tie interfaces together.
  // TODO: this function requires a pair of interfaces.
  auto destinationStyleOp =
      dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!destinationStyleOp)
    return b.notifyMatchFailure(op, "not a destination style op");

  // Actually this only works for Linalg ops atm.
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
  if (!linalgOp)
    return b.notifyMatchFailure(op, "not a linalg op");

  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
  if (op->getNumResults() != 1)
    return b.notifyMatchFailure(
        op, "don't support ops with multiple results for now");

  SmallVector<utils::IteratorType> iterators =
      tilingInterfaceOp.getLoopIteratorTypes();
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);
  if (redDims.size() != 1)
    return b.notifyMatchFailure(
        op, "only support ops with one reduction dimension.");
  if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
    return b.notifyMatchFailure(op, "if tile sizes are present they must have "
                                    "as many elements as number of threads");
  int reductionDim = static_cast<int>(redDims.front());

  if (redDims.front() >= numThreads.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be mapped to threads");

  // 1. Create the initial tensor value.
  FailureOr<SmallVector<Value>> maybeInitTensors =
      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
                                                  reductionDim);
  if (failed(maybeInitTensors))
    return b.notifyMatchFailure(
        op, "Failed to create initial tensors for partial reduction");
  SmallVector<Value> &initTensors = maybeInitTensors.value();

  // Gather destination tensors.
  SmallVector<Value> dest;
  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
    return b.notifyMatchFailure(op, "failed to get destination tensors");

  Operation *tiledOp = nullptr;

  SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
      numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
  SmallVector<Value> materializedNonZeroNumThreads =
      getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);

  // 2. Create the ForallOp with an empty region.
  scf::ForallOp forallOp = b.create<scf::ForallOp>(
      loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
      mapping);

  // 3. Calculate the tile offsets and sizes for the subsequent loop that will
  // be nested under `forallOp`.
  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
                               /*omitTileOffsetBoundsCheck=*/false,
                               /*nominalTileSizes=*/std::nullopt, tiledOffsets,
                               tiledSizes);

  // 4. Clone the tileable op and update its destination operands to use the
  // output bbArgs of the ForallOp.
  SmallVector<Value> tilingResults;
  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
  {
    // 4.a. RAII guard, inserting within forallOp, before terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    SmallVector<Value> tiledDpsInitOperands;
    for (Value initOperand : destinationStyleOp.getDpsInits()) {
      auto *it = llvm::find(dest, initOperand);
      assert(it != dest.end() && "dest operand not found in dest");
      unsigned destNum = std::distance(dest.begin(), it);
      SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
      SmallVector<OpFoldResult> outOffsets(numThreads.size(),
                                           b.getIndexAttr(0));
      SmallVector<OpFoldResult> sizes = tiledSizes;
      sizes[reductionDim] = b.getIndexAttr(1);
      outOffsets[reductionDim] = forallOp.getInductionVars()[0];
      // TODO: use SubsetExtractOpInterface once it is available.
      tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
          loc, cast<RankedTensorType>(initOperand.getType()),
          destBbArgs[destNum], outOffsets, sizes, strides));
    }

    // 4.b. Clone the op and update init operands.
    // We cannot use an IRMapping here because it can replace
    // different OpOperands with the same value.
    Operation *clonedOp = b.clone(*op.getOperation());
    b.modifyOpInPlace(clonedOp, [&]() {
      for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
               cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
               tiledDpsInitOperands)) {
        initOperandPtr.set(tiledInitValue);
      }
    });

    // 5. Tile the cloned op and delete the clone.
    if (tileSizes.empty()) {
      FailureOr<TilingResult> tilingResult =
          cast<TilingInterface>(clonedOp).getTiledImplementation(
              b, tiledOffsets, tiledSizes);
      if (failed(tilingResult))
        return clonedOp->emitError("Failed to tile op: ");
      if (tilingResult->tiledOps.size() != 1) {
        return clonedOp->emitError("expected a single produced tiled op, got ")
               << tilingResult->tiledOps.size();
      }
      tiledOp = tilingResult->tiledOps.front();
      tilingResults = tilingResult->tiledValues;
    } else {
      LinalgTilingOptions options;
      FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
          b, cast<LinalgOp>(clonedOp), tileSizes, options);
      if (failed(maybeTiled))
        return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");

      SmallVector<Value> ids = forallOp.getInductionVars();
      mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
                            materializedNonZeroNumThreads);
      if (maybeTiled->loops.size() != 1) {
        return clonedOp->emitError("expected a single produced loop");
      }
      tiledOp = maybeTiled->op;
      tilingResults = maybeTiled->loops.front()->getResults();
    }

    b.eraseOp(clonedOp);
  }

  // 6. Insert the partial reductions back into a new tensor.
  for (auto [index, result, bbArg] : llvm::zip(
           llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
    // 6.a. Partial subset information is inserted just before the terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(tilingInterfaceOp.getResultTilePosition(
            b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
      return op->emitOpError("output offsets couldn't be calculated");
    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
    int64_t offIdx = 0;
    int64_t sizeIdx = 0;
    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
      if (i == reductionDim) {
        resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
        resultSizesRank.push_back(b.getIndexAttr(1));
        continue;
      }
      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
      resultSizesRank.push_back(resultSizes[sizeIdx++]);
    }
    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
                                      b.getIndexAttr(1));

    // 6.b. Parallel insertions are inserted at the end of the combining
    // terminator.
    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
    b.create<tensor::ParallelInsertSliceOp>(
        loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
  }

  // 7. Merge the partial reductions.
  b.setInsertionPointAfter(forallOp);
  FailureOr<MergeResult> mergeResult =
      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
  if (failed(mergeResult)) {
    return failure();
  }
  b.replaceOp(op, mergeResult->replacements);

  // 8. Return.
  ForallReductionTilingResult results;
  results.initialValues = initTensors;
  results.loops = forallOp;
  results.parallelTiledOps.push_back(tiledOp);
  results.mergeOps.append(mergeResult->mergeOps);
  return results;
}

template <typename LoopTy>
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
    RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<OpFoldResult> tileSizeVector =
      getAsOpFoldResult(options.tileSizeComputationFunction(b, op));
  if (tileSizeVector.size() < nLoops) {
    tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0));
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return failure();
}

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}