//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGTILINGPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ArrayRef<OpFoldResult> allShapeSizes,
                                  ArrayRef<OpFoldResult> allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  SmallVector<OpFoldResult> shapeSizes =
      makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes);
  SmallVector<OpFoldResult> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
        static_cast<int64_t>(0)) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }
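  // Illustrative example (not from the surrounding code): with tile sizes
  // [5, 0, 7], the zero entry for loop 1 is erased above, leaving two entries
  // and loopIndexToRangeIndex = {0 -> 0, 2 -> 1}, i.e. loop 2 maps to range 1.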

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

void mlir::linalg::transformIndexOps(
    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  offsetIndices(b, op, getAsOpFoldResult(allIvs));
}

/// Asserts that the given index-typed value is strictly positive. If the value
/// is an attribute, asserts at compile time, otherwise emits an assertion
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                         OpFoldResult value) {
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
    assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
           "expected strictly positive tile size and divisor");
    return;
  }

  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
                                            value.get<Value>(), zero);
  b.create<cf::AssertOp>(
      condition,
      b.getStringAttr("expected strictly positive tile size and divisor"));
}

FailureOr<StaticMultiSizeSpecification>
mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
                                          int64_t targetSize, int64_t divisor) {
  assert(!op.hasDynamicShape() &&
         "cannot compute static multi-tile sizes for an op with dynamic shape");
  assert(targetSize > 0 && "target size must be positive");
  assert(divisor > 0 && "divisor must be positive");
  assert(dimension < op.getNumLoops() && "dimension overflow");
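  // Worked example (illustrative): for tripCount = 15, targetSize = 8 and
  // divisor = 1, the computation below yields a = 15, t = 8, totalTripCount = 2,
  // lowTileSize = 7, highTileSize = 8, highTripCount = 1, lowTripCount = 1,
  // and indeed 7 * 1 + 8 * 1 == 15.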

  StaticMultiSizeSpecification spec;
  int64_t tripCount = op.getStaticLoopRanges()[dimension];
  int64_t a = tripCount / divisor;
  int64_t t = (targetSize + divisor - 1) / divisor;
  int64_t totalTripCount = (a + t - 1) / t;
  spec.lowTileSize = (a / totalTripCount) * divisor;
  spec.highTileSize = spec.lowTileSize + divisor;
  spec.highTripCount = a % totalTripCount;
  spec.lowTripCount = totalTripCount - spec.highTripCount;
  if (spec.lowTileSize * spec.lowTripCount +
          spec.highTileSize * spec.highTripCount !=
      tripCount) {
    return failure();
  }
  return spec;
}

FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  Location loc = op.getLoc();
  ImplicitLocOpBuilder b(loc, builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue =
      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
  Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  SmallVector<OpFoldResult> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<OpFoldResult> loopRanges =
      makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
                                               allShapes);
  Value tripCount =
      getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
  };
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  Value v = apply(s0 % s1, {a, d});
  Value u = apply(s0 - s1, {d, v});
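  // The quantities above mirror the static computation in
  // computeStaticMultiTileSizes: `a` is the trip count divided by the divisor,
  // `t` the target size rounded up to a multiple of the divisor (counted in
  // divisor units), `d` the total number of tiles, `s` the low tile size,
  // `v` the number of high-size tiles, and `u` the number of low-size tiles.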

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for an iteration dimension of size 15 and a target size of 8,
  // it is impossible to find two tile sizes both divisible by 8 that fully
  // cover the original iteration space dimension.
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}

/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
/// less than `iterationSize`.
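/// For example (illustrative), with tileSize = 8, numThreads = 4 and
/// iterationSize = 32, the last tile starts at offset 8 * 3 = 24 < 32, so the
/// bounds check on the tile offset can be omitted.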
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
                                           OpFoldResult numThreads,
                                           OpFoldResult iterationSize) {
  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
    return false;
  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
}

/// Build an `affine_max` of all the `vals`.
static OpFoldResult buildMax(OpBuilder &b, Location loc,
                             ArrayRef<OpFoldResult> vals) {
  return affine::makeComposedFoldedAffineMax(
      b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

/// Build an `affine_min` of all the `vals`.
static OpFoldResult buildMin(OpBuilder &b, Location loc,
                             ArrayRef<OpFoldResult> vals) {
  return affine::makeComposedFoldedAffineMin(
      b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
/// number of threads.
static void calculateTileOffsetsAndSizes(
    RewriterBase &b, Location loc, scf::ForallOp forallOp,
    ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
    bool omitTileOffsetBoundsCheck,
    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
    SmallVector<OpFoldResult> &tiledOffsets,
    SmallVector<OpFoldResult> &tiledSizes) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToStart(forallOp.getBody(0));

  ValueRange threadIds = forallOp.getInductionVars();
  SmallVector<OpFoldResult> nonZeroNumThreads =
      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
        return !isConstantIntValue(ofr, 0);
      }));
  int64_t nLoops = loopRanges.size();
  tiledOffsets.reserve(nLoops);
  tiledSizes.reserve(nLoops);
  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
    bool overflow = loopIdx >= numThreads.size();
    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
    // Degenerate case: take the whole domain.
    if (overflow || isZero) {
      tiledOffsets.push_back(loopRanges[loopIdx].offset);
      tiledSizes.push_back(loopRanges[loopIdx].size);
      continue;
    }

    // Tiled case: compute the offset and size.
    AffineExpr i, j, m, n, o;
    bindDims(b.getContext(), i, j);
    bindSymbols(b.getContext(), m, n, o);
    OpFoldResult size = loopRanges[loopIdx].size;
    OpFoldResult offset = loopRanges[loopIdx].offset;
    OpFoldResult threadId = threadIds[threadIdIdx];
    // Symbolic fixed max size per thread.
    // TODO: floor + 0/1 depending on case for better load-balancing.
    OpFoldResult tileSizePerThread =
        nominalTileSizes.has_value()
            ? (*nominalTileSizes)[loopIdx]
            : makeComposedFoldedAffineApply(
                  b, loc, m.ceilDiv(n),
                  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});

    // Dynamic offset shifted by threadId * maxSizePerThread.
    OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
        b, loc, i + j * m, {offset, threadId, tileSizePerThread});
    // Dynamic upper-bound depending on the threadId.
    OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
        b, loc, i + j * m - n,
        {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
    if (!isConstantIntValue(residualTileSize, 0)) {
      OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
          b, loc, -i + m, {offsetPerThread, size});
      tileSizePerThread =
          buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
    }

    tiledOffsets.push_back(offsetPerThread);
    // TODO: if tileSizePerThread <= 0 early exit.
    if (!omitTileOffsetBoundsCheck &&
        !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
                                        nonZeroNumThreads[threadIdIdx], size))
      tileSizePerThread =
          buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});

    tiledSizes.push_back(tileSizePerThread);
    ++threadIdIdx;
  }
}

/// Returns a vector of bools representing whether, for each axis, `op` can be
/// tiled without incurring a race condition, i.e. whether tiling along that
/// axis is thread-safe. This is checked by iterating over `numThreads` and
/// ensuring that the corresponding iterator type is "parallel". If it is not,
/// then the dimension is unsafe to tile.
SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
                                     ArrayRef<OpFoldResult> numThreads) {
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> safeToTile(numThreads.size(), true);

  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
      }
    } else {
      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
    }
  }
  return safeToTile;
}

/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
/// tiling is specified by the number of tiles/threads `numThreads` and the
/// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
/// numThreads[i])`. If non-empty, the `mapping` is added as an
/// attribute to the resulting `scf.forall`. A zero tile size indicates
/// that the dimension is not tiled, and can be thought of as tiling by the full
/// size of data.
/// It is the user's responsibility to ensure that `numThreads` is a valid
/// tiling specification (i.e. that it only tiles parallel dimensions, e.g. in
/// the Linalg case). If a dimension is not parallelizable, a warning is issued
/// to notify the user that the generated code is not safe to parallelize. If
/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
/// `tileSize[i] * (numThread[i] - 1) <= dimSize[i]` holds.
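///
/// Illustrative sketch (an assumption about the shape of the result, not
/// verified output IR): tiling a single-result op with numThreads = [4]
/// produces roughly
///   scf.forall (%tid) in (4) shared_outs(%out = %dest) {
///     // per-thread offset/size computed from %tid as described above
///     ... tiled op applied to slices of the operands ...
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice ... into %out
///     }
///   }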
static FailureOr<ForallTilingResult> tileToForallOpImpl(
    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
    std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
  Location loc = op->getLoc();
  OpBuilder::InsertionGuard g(b);

  SmallVector<Range> loopRanges = op.getIterationDomain(b);
  if (loopRanges.empty())
    return op->emitOpError("expected non-empty loop ranges");
  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
  if (llvm::any_of(loopRanges, hasStrideOne))
    return op->emitOpError("only stride-1 supported atm");

  // Gather destination tensors.
  SmallVector<Value> dest;
  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
    return op->emitOpError("failed to get destination tensors");

  SmallVector<OpFoldResult> nonZeroNumThreads =
      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
        return !isConstantIntValue(ofr, 0);
      }));
  SmallVector<Value> materializedNonZeroNumThreads =
      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
        return getValueOrCreateConstantIndexOp(b, loc, ofr);
      }));

  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
  if (linalgOp) {
    // Check if tiling is thread safe and print a warning if not.
    SmallVector<bool> tilingSafety =
        safeToTileToForall(b.getContext(), linalgOp, numThreads);
    for (size_t i = 0; i < tilingSafety.size(); i++)
      if (!tilingSafety[i])
        op.emitWarning() << "tiling is not thread safe at axis #" << i;
  }

  // 1. Create the ForallOp. We don't use the lambda body-builder
  // version because we require the use of RewriterBase in the body, so we
  // manually move the insertion point to the body below.
  scf::ForallOp forallOp = b.create<scf::ForallOp>(
      loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);

  // 2. Fill out the ForallOp body.
  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
                               omitTileOffsetBoundsCheck, nominalTileSizes,
                               tiledOffsets, tiledSizes);

  // 3. Clone the tileable op and update its destination operands to use the
  // output bbArgs of the ForallOp.
  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
  Operation *tiledOp = nullptr;
  SmallVector<Value> tiledValues;
  {
    // 3.a. RAII guard, inserting within forallOp, before terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());
    Operation *clonedOp = b.clone(*op.getOperation());
    auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
    if (destinationStyleOp) {
      for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
        // Swap tensor inits with the corresponding block argument of the
        // scf.forall op. Memref inits remain as is.
        if (isa<TensorType>(outOperand.get().getType())) {
          auto *it = llvm::find(dest, outOperand.get());
          assert(it != dest.end() && "could not find destination tensor");
          unsigned destNum = std::distance(dest.begin(), it);
          outOperand.set(destBbArgs[destNum]);
        }
      }
    }

    // 4. Tile the cloned op and delete the clone.
    FailureOr<TilingResult> tilingResult =
        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
                                                               tiledSizes);
    if (failed(tilingResult))
      return clonedOp->emitError("Failed to tile op: ");
    if (tilingResult->tiledOps.size() != 1) {
      return clonedOp->emitError("expected a single produced tiled op, got ")
             << tilingResult->tiledOps.size();
    }

    b.eraseOp(clonedOp);
    tiledOp = tilingResult->tiledOps.front();
    tiledValues = tilingResult->tiledValues;
  }

  // 5. Parallel insert back into the result tensor.
  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
                           tiledValues, destBbArgs)) {
    // 5.a. Partial subset information is inserted just before the terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
                                        tiledSizes, resultOffsets,
                                        resultSizes)))
      return op->emitOpError("output offsets couldn't be calculated");
    SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));

    // 5.b. Parallel insertions are inserted at the end of the combining
    // terminator.
    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
                                            std::get<2>(it), resultOffsets,
                                            resultSizes, strides);
  }
  return ForallTilingResult{forallOp, tiledOp};
}

FailureOr<ForallTilingResult>
linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
                       ArrayRef<OpFoldResult> numThreads,
                       std::optional<ArrayAttr> mapping) {
  return tileToForallOpImpl(b, op, numThreads,
                            /*nominalTileSizes=*/std::nullopt, mapping,
                            /*omitTileOffsetBoundsCheck=*/false);
}

FailureOr<ForallTilingResult>
linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
                                     ArrayRef<OpFoldResult> tileSizes,
                                     std::optional<ArrayAttr> mapping) {
  SmallVector<Range> loopRanges = op.getIterationDomain(b);
  unsigned nLoops = loopRanges.size();
  SmallVector<OpFoldResult> numThreads;
  numThreads.reserve(nLoops);
  AffineExpr s0, s1;
  bindSymbols(b.getContext(), s0, s1);
  AffineExpr divExpr = s0.ceilDiv(s1);
  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
    OpFoldResult numTiles = std::get<0>(it);
    if (!isConstantIntValue(numTiles, 0))
      numTiles = makeComposedFoldedAffineApply(
          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
    numThreads.push_back(numTiles);
  }
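  // For example (illustrative), a loop range of size 100 tiled with tile size
  // 32 yields ceildiv(100, 32) = 4 threads along that dimension; a tile size
  // of 0 keeps numThreads at 0, i.e. the dimension is not tiled.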
  return tileToForallOpImpl(b, op, numThreads,
                            /*nominalTileSizes=*/tileSizes, mapping,
                            /*omitTileOffsetBoundsCheck=*/true);
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);

  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
      })) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  SmallVector<OpFoldResult> allShapeSizes =
      op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<utils::IteratorType, 4> iteratorTypes;
  for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
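    // For example (illustrative), with three loops, tile sizes [5, 0, 7] (so
    // loopIndexToRangeIndex = {0 -> 0, 2 -> 1}) and an interchange vector of
    // [2, 0, 1], the pruned interchange vector becomes [1, 0]: position 2 maps
    // to range 1, position 0 maps to range 0, and position 1 is dropped.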
    // Interchange vector is guaranteed to be a permutation,
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // Handle distribution. Create a vector of the same size as the number of
  // loops to be tiled.
  SmallVector<linalg::ProcInfo> procInfo;
  if (options.distribution) {
    procInfo.resize(
        iteratorTypes.size(),
        linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
    // Collect the loop ranges of the tiled loops that are parallel.
    SmallVector<Range> parallelLoopRanges;
    for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
      if (!isParallelIterator(iteratorType.value()))
        break;
      parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
    }
    auto returnedProcInfo =
        options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
    unsigned procIdIdx = 0;
    // Update the distribution information for the loops.
    for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
      if (!isParallelIterator(iteratorType.value()))
        break;
      procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
    }
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty()) {
      for (AffineExpr result : invPermutationMap.getResults())
        interchangedIvs.push_back(
            ivs[cast<AffineDimExpr>(result).getPosition()]);
    } else {
      interchangedIvs.assign(ivs.begin(), ivs.end());
    }

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op->getNumOperands()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    SmallVector<OpFoldResult> sizeBounds =
        makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
                                                 allShapeSizes);
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
        sizeBounds,
        /*omitPartialTileCheck=*/false);

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = clone(b, op, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, procInfo);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (isa<BlockArgument>(iv)) {
      loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of the
      // loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
    RewriterBase &b, PartialReductionOpInterface op,
    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
    std::optional<ArrayAttr> mapping) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(b);

  // Ops implementing PartialReductionOpInterface are expected to implement
  // TilingInterface.
  // TODO: proper core mechanism to tie interfaces together.
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());

  // Ops implementing PartialReductionOpInterface are not necessarily expected
  // to implement DestinationStyleOpInterface. This cast is unsafe atm.
  // TODO: proper core mechanism to tie interfaces together.
  // TODO: this function requires a pair of interfaces ..
  auto destinationStyleOp =
      dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!destinationStyleOp)
    return b.notifyMatchFailure(op, "not a destination style op");

  // Actually this only works for Linalg ops atm.
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
  if (!linalgOp)
    return b.notifyMatchFailure(op, "not a linalg op");

  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
  if (op->getNumResults() != 1)
    return b.notifyMatchFailure(
        op, "don't support ops with multiple results for now");

  SmallVector<utils::IteratorType> iterators =
      tilingInterfaceOp.getLoopIteratorTypes();
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);
  if (redDims.size() != 1)
    return b.notifyMatchFailure(
        op, "only support ops with one reduction dimension.");
  if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
    return b.notifyMatchFailure(op, "if tile sizes are present there must be "
                                    "as many tile sizes as threads");
  int reductionDim = static_cast<int>(redDims.front());

  if (redDims.front() >= numThreads.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be mapped to threads");

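  // Overall scheme (summary of steps 1-7 below): materialize one partial
  // accumulator slot per thread along the reduction dimension, have each
  // thread reduce its tile into its own slot inside the scf.forall, and merge
  // the per-thread partial results into the original destination afterwards.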
  // 1. Create the initial tensor value.
  FailureOr<Operation *> identityTensor =
      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
                                                  reductionDim);
  if (failed(identityTensor))
    return b.notifyMatchFailure(op,
                                "cannot create a tensor of identity value.");

  // Gather destination tensors.
  SmallVector<Value> dest;
  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
    return b.notifyMatchFailure(op, "failed to get destination tensors");

  Operation *tiledOp = nullptr;

  SmallVector<OpFoldResult> nonZeroNumThreads =
      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
        return !isConstantIntValue(ofr, 0);
      }));
  SmallVector<Value> materializedNonZeroNumThreads =
      getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);

  // 2. Create the ForallOp with an empty region.
  scf::ForallOp forallOp = b.create<scf::ForallOp>(
      loc, getAsOpFoldResult(materializedNonZeroNumThreads),
      (*identityTensor)->getResults(), mapping);

  // 3. Calculate the tile offsets and sizes for the subsequent loop that will
  // be nested under `forallOp`.
  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
                               /*omitTileOffsetBoundsCheck =*/false,
                               /*nominalTileSizes=*/std::nullopt, tiledOffsets,
                               tiledSizes);

  // 4. Clone the tileable op and update its destination operands to use the
  // output bbArgs of the ForallOp.
  SmallVector<Value> tilingResults;
  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
  {
    // 4.a. RAII guard, inserting within forallOp, before terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    SmallVector<Value> tiledDpsInitOperands;
    for (Value initOperand : destinationStyleOp.getDpsInits()) {
      auto *it = llvm::find(dest, initOperand);
      assert(it != dest.end() && "dest operand not found in dest");
      unsigned destNum = std::distance(dest.begin(), it);
      SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
      SmallVector<OpFoldResult> outOffsets(numThreads.size(),
                                           b.getIndexAttr(0));
      SmallVector<OpFoldResult> sizes = tiledSizes;
      sizes[reductionDim] = b.getIndexAttr(1);
      outOffsets[reductionDim] = forallOp.getInductionVars().front();
      // TODO: use SubsetExtractOpInterface once it is available.
      tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
          loc, cast<RankedTensorType>(initOperand.getType()),
          destBbArgs[destNum], outOffsets, sizes, strides));
    }

    // 4.b. Clone the op and update init operands.
    // We cannot use an IRMapping here because it can replace
    // different OpOperands with the same value.
    Operation *clonedOp = b.clone(*op.getOperation());
    b.modifyOpInPlace(clonedOp, [&]() {
      for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
               cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
               tiledDpsInitOperands)) {
        initOperandPtr.set(tiledInitValue);
      }
    });

    // 5. Tile the cloned op and delete the clone.
    if (tileSizes.empty()) {
      FailureOr<TilingResult> tilingResult =
          cast<TilingInterface>(clonedOp).getTiledImplementation(
              b, tiledOffsets, tiledSizes);
      if (failed(tilingResult))
        return clonedOp->emitError("Failed to tile op: ");
      if (tilingResult->tiledOps.size() != 1) {
        return clonedOp->emitError("expected a single produced tiled op, got ")
               << tilingResult->tiledOps.size();
      }
      tiledOp = tilingResult->tiledOps.front();
      tilingResults = tilingResult->tiledValues;
    } else {
      LinalgTilingOptions options;
      FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
          b, cast<LinalgOp>(clonedOp), tileSizes, options);
      if (failed(maybeTiled))
        return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");

      SmallVector<Value> ids = forallOp.getInductionVars();
      mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
                            materializedNonZeroNumThreads);
      if (maybeTiled->loops.size() != 1) {
        return clonedOp->emitError("expected a single produced loop");
      }
      tiledOp = maybeTiled->op;
      tilingResults = maybeTiled->loops.front()->getResults();
    }

    b.eraseOp(clonedOp);
  }

  // 6. Insert the partial reductions back into a new tensor.
  for (auto [index, result, bbArg] : llvm::zip(
           llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
    // 6.a. Partial subset information is inserted just before the terminator.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(forallOp.getTerminator());

    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(tilingInterfaceOp.getResultTilePosition(
            b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
      return op->emitOpError("output offsets couldn't be calculated");
    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
    int64_t offIdx = 0;
    int64_t sizeIdx = 0;
    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
      if (i == reductionDim) {
        resultOffsetsRank.push_back(forallOp.getInductionVars().front());
        resultSizesRank.push_back(b.getIndexAttr(1));
        continue;
      }
      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
      resultSizesRank.push_back(resultSizes[sizeIdx++]);
    }
    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
                                      b.getIndexAttr(1));

    // 6.b. Parallel insertions are inserted at the end of the combining
    // terminator.
    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
    b.create<tensor::ParallelInsertSliceOp>(
        loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
  }

  // 7. Merge the partial reductions.
  b.setInsertionPointAfter(forallOp);
  Operation *mergeOp =
      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
  b.replaceOp(op, mergeOp->getResults());

  // 8. Return.
  ForallReductionTilingResult results;
  results.initialOp = *identityTensor;
  results.loops = forallOp;
  results.parallelTiledOp = tiledOp;
  results.mergeOp = mergeOp;
  return results;
}

template <typename LoopTy>
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
    RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle instead of
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<OpFoldResult> tileSizeVector =
      getAsOpFoldResult(options.tileSizeComputationFunction(b, op));
  if (tileSizeVector.size() < nLoops) {
    tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0));
  }
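  // For example (illustrative), two user-provided tile sizes for a 3-loop op
  // are padded to [t0, t1, 0], so the third loop is left untiled.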

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return failure();
}

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}
