1//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Affine/Utils.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/Linalg/Utils/Utils.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Utils/IndexingUtils.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23#include "mlir/Interfaces/TilingInterface.h"
24#include "mlir/Interfaces/ValueBoundsOpInterface.h"
25#include "llvm/Support/Debug.h"
26#include <optional>
27
28#define DEBUG_TYPE "linalg-tiling-interface-impl"
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33//===----------------------------------------------------------------------===//
34// Utility methods for implementation of Tiling Interface for Linalg ops
35//===----------------------------------------------------------------------===//
36
37/// Return the SSA values that represent the data point accessed using a given
38/// `indexingMap` for a given point in the iteration space represented by `ivs`.
39static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
40 AffineMap indexingMap,
41 ValueRange ivs) {
42 SmallVector<Value> indices;
43 indices.reserve(N: indexingMap.getNumResults());
44 for (auto result : indexingMap.getResults()) {
45 AffineMap m = AffineMap::get(dimCount: indexingMap.getNumDims(),
46 symbolCount: indexingMap.getNumSymbols(), result);
47 Value v = b.create<affine::AffineApplyOp>(location: loc, args&: m, args&: ivs);
48 indices.push_back(Elt: v);
49 }
50 return indices;
51}
52
53/// Method to inline the payload of a `linalgOp` given the iteration space
54/// point and values for the arguments of the payload.
55static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
56 ValueRange ivs, ValueRange argValues) {
57 Block *body = linalgOp.getBlock();
58 IRMapping map;
59 map.map(from: body->getArguments(), to&: argValues);
60 for (auto &op : body->without_terminator()) {
61 if (auto indexOp = dyn_cast<IndexOp>(Val: &op)) {
62 map.map(from: indexOp.getResult(), to: ivs[indexOp.getDim()]);
63 continue;
64 }
65 b.clone(op, mapper&: map);
66 }
67
68 Operation *terminator = body->getTerminator();
69 Location loc = terminator->getLoc();
70 for (const auto &operand : llvm::enumerate(First: terminator->getOperands())) {
71 Value toStore = map.lookupOrDefault(from: operand.value());
72 OpOperand *storeInto = linalgOp.getDpsInitOperand(i: operand.index());
73 auto indices = getIndicesForAccess(
74 b, loc, indexingMap: linalgOp.getMatchingIndexingMap(opOperand: storeInto), ivs);
75 b.create<memref::StoreOp>(
76 location: loc, args&: toStore, args: linalgOp.getDpsInitOperand(i: operand.index())->get(),
77 args&: indices);
78 }
79 return success();
80}
81
82//===----------------------------------------------------------------------===//
83// External Model for implementing `TilingInterface` for `LinalgOp`s.
84//===----------------------------------------------------------------------===//
85
86namespace {
87/// External model implementation of TilingInterface for LinalgOps. An external
88/// model implementation is used for now till the use of `TilingInterface` is
89/// on-par with the current Linalg tiling + fusion patterns. Once it is
90/// maybe possible to move this into the op-definition (though there are
91/// advantages to leaving it as an external model)
92template <typename LinalgOpTy>
93struct LinalgOpTilingInterface
94 : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
95 LinalgOpTy> {
96 /// Return the loop iterator type.
97 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
98 LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
99 return concreteOp.getIteratorTypesArray();
100 }
101
102 /// Return the iteration domain range.
103 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
104 OpBuilder::InsertionGuard g(b);
105 b.setInsertionPoint(op);
106 Location loc = op->getLoc();
107 LinalgOp linalgOp = cast<LinalgOp>(Val: op);
108 SmallVector<OpFoldResult> allShapesSizes =
109 linalgOp.createFlatListOfOperandDims(b, loc);
110 AffineMap map = linalgOp.getShapesToLoopsMap();
111
112 return llvm::to_vector(
113 llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
114 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
115 b, loc, expr: loopExpr, operands: allShapesSizes);
116 return Range{.offset: b.getIndexAttr(value: 0), .size: ofr, .stride: b.getIndexAttr(value: 1)};
117 }));
118 }
119
120 /// Instantiate the tiled implementation of the operation.
121 FailureOr<TilingResult>
122 getTiledImplementation(Operation *op, OpBuilder &b,
123 ArrayRef<OpFoldResult> offsets,
124 ArrayRef<OpFoldResult> sizes) const {
125 // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
126 // specified could lead to out of bounds accesses.
127 Location loc = op->getLoc();
128 LinalgOp linalgOp = cast<LinalgOp>(Val: op);
129 SmallVector<Value> valuesToTile = linalgOp->getOperands();
130 SmallVector<Value> tiledOperands = makeTiledShapes(
131 builder&: b, loc, linalgOp, valuesToTile, ivs: offsets, tileSizes: sizes, sizeBounds: {}, omitPartialTileCheck: true);
132 SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
133 llvm::make_filter_range(
134 tiledOperands,
135 [](Value v) -> bool {
136 return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
137 Val: v.getDefiningOp());
138 }),
139 [](Value v) -> Operation * { return v.getDefiningOp(); });
140
141 SmallVector<Type> resultTensorTypes =
142 getTensorOutputTypes(op: linalgOp, operands: tiledOperands);
143
144 Operation *tiledOp = clone(b, op: linalgOp, newResultTypes: resultTensorTypes, newOperands: tiledOperands);
145 offsetIndices(b, linalgOp: cast<LinalgOp>(Val: tiledOp), offests: offsets);
146
147 return TilingResult{
148 .tiledOps: {tiledOp}, .tiledValues: SmallVector<Value>(tiledOp->getResults()), .generatedSlices: generatedSlices};
149 }
150
151 /// Utility to fetch the offsets and sizes when applied as per the indexing
152 /// map of the linalg op. This helps in fusing the linalg op as a consumer of
153 /// a given slice op.
154 static LogicalResult
155 getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
156 ArrayRef<AffineMap> indexingMaps,
157 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
158 ArrayRef<SmallVector<OpFoldResult>> allSizes,
159 SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
160 SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
161 DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
162
163 for (auto [indexingMap, offsets, sizes] :
164 llvm::zip_equal(t&: indexingMaps, u&: allOffsets, args&: allSizes)) {
165 for (auto [resultExpr, offset, size] :
166 llvm::zip_equal(t: indexingMap.getResults(), u: offsets, args: sizes)) {
167 auto dimExpr = dyn_cast<AffineDimExpr>(Val: resultExpr);
168 if (!dimExpr)
169 continue;
170 unsigned position = dimExpr.getPosition();
171 auto it = mappedOffsets.find(Val: position);
172 if (it != mappedOffsets.end()) {
173 OpFoldResult seenOffset = it->second;
174 OpFoldResult seenSize = mappedSizes.lookup(Val: position);
175 if (seenOffset != offset || seenSize != size) {
176 LLVM_DEBUG({
177 llvm::dbgs() << "inconsistent iteration space mapping from "
178 "offsets/sizes of operands/results";
179 });
180 return failure();
181 }
182 } else {
183 mappedOffsets[position] = offset;
184 mappedSizes[position] = size;
185 }
186 }
187 }
188
189 // Aggregate from the given operand offsets and sizes, or default to
190 // iteration space values.
191 SmallVector<Range> iterationDomain =
192 cast<TilingInterface>(Val: linalgOp.getOperation()).getIterationDomain(b);
193 mappedOffsetsVec.resize(N: iterationDomain.size());
194 mappedSizesVec.resize(N: iterationDomain.size());
195 for (auto [index, domain] : llvm::enumerate(First&: iterationDomain)) {
196 auto it = mappedOffsets.find(Val: index);
197 if (it != mappedOffsets.end()) {
198 mappedOffsetsVec[index] = it->second;
199 mappedSizesVec[index] = mappedSizes.lookup(Val: index);
200 continue;
201 }
202 mappedOffsetsVec[index] = domain.offset;
203 mappedSizesVec[index] = domain.size;
204 }
205 return success();
206 }
207
208 /// Method to return the position of the result tile computed by the tiled
209 /// operation.
210 LogicalResult getIterationDomainTileFromOperandTiles(
211 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
212 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
213 ArrayRef<SmallVector<OpFoldResult>> allSizes,
214 SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
215 SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
216 auto linalgOp = cast<LinalgOp>(Val: op);
217
218 std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
219 iterationSpaceSizes;
220 SmallVector<AffineMap> indexingMaps =
221 llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
222 OpOperand &opOperand = linalgOp->getOpOperand(idx: operandNumber);
223 return linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
224 });
225 if (failed(Result: getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
226 allSizes, mappedOffsetsVec&: iterDomainOffsets,
227 mappedSizesVec&: iterDomainSizes))) {
228 return failure();
229 }
230 return success();
231 }
232
233 /// Return the details of the output tile generated by the tiled
234 /// implementation.
235 LogicalResult
236 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
237 ArrayRef<OpFoldResult> offsets,
238 ArrayRef<OpFoldResult> sizes,
239 SmallVector<OpFoldResult> &resultOffsets,
240 SmallVector<OpFoldResult> &resultSizes) const {
241 Location loc = op->getLoc();
242 LinalgOp linalgOp = cast<LinalgOp>(Val: op);
243
244 AffineExpr d0;
245 bindDims(ctx: b.getContext(), exprs&: d0);
246 SmallVector<OpFoldResult> subShapeSizes =
247 llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
248 return affine::makeComposedFoldedAffineApply(b, loc, expr: d0 - 1, operands: ofr);
249 }));
250
251 OpOperand *outOperand = linalgOp.getDpsInitOperand(i: resultNumber);
252 SliceParameters sliceParams = computeSliceParameters(
253 builder&: b, loc, valueToTile: outOperand->get(), tileSizes: sizes,
254 map: linalgOp.getMatchingIndexingMap(opOperand: outOperand), lbs: offsets,
255 /*ubs*/ {}, subShapeSizes, omitPartialTileCheck: true);
256 resultOffsets = sliceParams.offsets;
257 resultSizes = sliceParams.sizes;
258 return success();
259 }
260
261 LogicalResult getIterationDomainTileFromResultTile(
262 Operation *op, OpBuilder &b, unsigned resultNumber,
263 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
264 SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
265 SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
266 auto linalgOp = cast<LinalgOp>(Val: op);
267
268 // Check that the indexing map used for the output is a projected
269 // permutation. This could be relaxed with a more general approach that can
270 // map the offsets and sizes from the result to iteration space tiles
271 // (filling in full extent for dimensions not used to access the result).
272 AffineMap indexingMap =
273 linalgOp.getIndexingMapMatchingResult(result: op->getResult(idx: resultNumber));
274 if (!indexingMap.isProjectedPermutation()) {
275 return op->emitOpError(
276 message: "unhandled tiled implementation generation when result is not "
277 "accessed using a permuted projection");
278 }
279
280 SmallVector<OpFoldResult> allOffsets = llvm::to_vector(Range&: offsets);
281 SmallVector<OpFoldResult> allSizes = llvm::to_vector(Range&: sizes);
282 auto status =
283 getMappedOffsetAndSize(linalgOp, b, indexingMaps: indexingMap, allOffsets: {allOffsets},
284 allSizes: {allSizes}, mappedOffsetsVec&: iterDomainOffsets, mappedSizesVec&: iterDomainSizes);
285 (void)status;
286 assert(succeeded(status) && "unexpected error in offset calculation");
287 return success();
288 }
289
290 FailureOr<TilingResult>
291 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
292 ArrayRef<OpFoldResult> offsets,
293 ArrayRef<OpFoldResult> sizes) const {
294 SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
295 if (failed(getIterationDomainTileFromResultTile(
296 op, b, resultNumber, offsets, sizes, iterDomainOffsets&: mappedOffsets, iterDomainSizes&: mappedSizes))) {
297 return failure();
298 }
299 auto tilingInterfaceOp = cast<TilingInterface>(Val: op);
300 FailureOr<TilingResult> tilingResult =
301 tilingInterfaceOp.getTiledImplementation(b, offsets: mappedOffsets, sizes: mappedSizes);
302
303 if (failed(Result: tilingResult))
304 return failure();
305
306 if (tilingResult->tiledOps.size() != 1)
307 return op->emitOpError(message: "failed to generate tiled implementation");
308
309 return TilingResult{
310 .tiledOps: tilingResult->tiledOps,
311 .tiledValues: SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
312 .generatedSlices: tilingResult->generatedSlices};
313 }
314
315 /// Method to generate the tiled implementation of an operation from the tile
316 /// of the operand.
317 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
318 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
319 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
320 ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
321 SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
322 if (failed(getIterationDomainTileFromOperandTiles(
323 op, b, operandNumbers, allOffsets, allSizes, iterDomainOffsets&: mappedOffsets,
324 iterDomainSizes&: mappedSizes))) {
325 return failure();
326 }
327 return getTiledImplementation(op, b, offsets: mappedOffsets, sizes: mappedSizes);
328 }
329
330 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
331 Location loc,
332 ValueRange ivs) const {
333 auto linalgOp = cast<LinalgOp>(Val: op);
334 if (!linalgOp.hasPureBufferSemantics())
335 return op->emitOpError(message: "expected operation to have buffer semantics");
336
337 SmallVector<Value> indexedValues;
338 indexedValues.reserve(N: linalgOp->getNumOperands());
339 Location linalgOpLoc = op->getLoc();
340 /// Load the data corresponding to the block arguments that
341 /// represent input operands.
342 for (OpOperand &operand : linalgOp->getOpOperands()) {
343 if (!linalgOp.payloadUsesValueFromOperand(opOperand: &operand)) {
344 indexedValues.push_back(Elt: nullptr);
345 continue;
346 }
347 if (linalgOp.isScalar(opOperand: &operand)) {
348 indexedValues.push_back(Elt: operand.get());
349 continue;
350 }
351 SmallVector<Value> indices = getIndicesForAccess(
352 b&: builder, loc: linalgOpLoc, indexingMap: linalgOp.getMatchingIndexingMap(opOperand: &operand), ivs);
353 Value load =
354 builder.create<memref::LoadOp>(location: linalgOpLoc, args: operand.get(), args&: indices);
355 indexedValues.push_back(Elt: load);
356 }
357
358 /// Inline the op payload and store the result.
359 return inlinePayload(b&: builder, linalgOp, ivs, argValues: indexedValues);
360 }
361};
362
363//===----------------------------------------------------------------------===//
364// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
365//===----------------------------------------------------------------------===//
366
367/// In a given set vector, get the position of a particular element.
368std::optional<int> getPositionIn(const llvm::SetVector<unsigned> &reductionDims,
369 unsigned value) {
370 for (auto [index, reductionDim] : llvm::enumerate(First: reductionDims)) {
371 if (reductionDim == value) {
372 return index;
373 }
374 }
375 return std::nullopt;
376}
377
378/// Return an AffineMaps to use for the `outs` operands of the linalg op
379/// generated for partial results. The new AffineMap is the AffineMap of the
380/// untiled op with reduction dimensions appended at end in order in which they
381/// were specified during tiling.
382static SmallVector<AffineMap>
383getPartialResultAffineMaps(LinalgOp linalgOp,
384 const SetVector<unsigned> &reductionDims) {
385 auto partialReductionMaps = llvm::map_to_vector(
386 C: linalgOp.getDpsInitsMutable(), F: [&](OpOperand &opOperand) {
387 AffineMap map = linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
388 for (auto redPos : reductionDims) {
389 map =
390 map.insertResult(expr: getAffineDimExpr(position: redPos, context: linalgOp.getContext()),
391 pos: map.getNumResults());
392 }
393 return map;
394 });
395 return partialReductionMaps;
396}
397
/// Aggregate describing the slice of an init value used as the destination of
/// a partial reduction op: the static result shape plus the offsets, sizes and
/// strides to build the slice op with.
struct InitSliceInfo {
  /// Static result shape of the slice (as produced by `decomposeMixedValues`).
  SmallVector<int64_t> resultShape;
  /// Per-dimension offsets of the slice.
  SmallVector<OpFoldResult> offsets;
  /// Per-dimension sizes of the slice.
  SmallVector<OpFoldResult> sizes;
  /// Per-dimension strides of the slice (unit strides in current producers).
  SmallVector<OpFoldResult> strides;
};
404
405/// Return the result shape, offsets, sizes and strides of the slice of the
406/// `initValue` to use as the destination of the partial reduction op generated
407/// with outer reduction strategy.
408static InitSliceInfo getInitSliceInfoForOuterReduction(
409 MLIRContext *context, ArrayRef<OpFoldResult> offsets,
410 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
411 ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
412 int64_t initRank = partialReductionMap.getNumResults();
413 SmallVector<OpFoldResult> initOffsets, initSizes;
414 Attribute zero = IntegerAttr::get(type: IndexType::get(context), value: 0);
415 Attribute one = IntegerAttr::get(type: IndexType::get(context), value: 1);
416 SmallVector<OpFoldResult> initStrides(initRank, one);
417 for (AffineExpr dimExpr : partialReductionMap.getResults()) {
418 unsigned dim = cast<AffineDimExpr>(Val&: dimExpr).getPosition();
419 if (reductionDims.contains(key: dim)) {
420 initOffsets.push_back(Elt: zero);
421 } else {
422 initOffsets.push_back(Elt: offsets[dim]);
423 }
424 initSizes.push_back(Elt: sizes[dim]);
425 }
426 SmallVector<int64_t> resultShape;
427 std::tie(args&: resultShape, args: std::ignore) = decomposeMixedValues(mixedValues: initSizes);
428 return {.resultShape: resultShape, .offsets: initOffsets, .sizes: initSizes, .strides: initStrides};
429}
430
431/// Return the result shape, offsets, sizes and strides of the slice of the
432/// `initValue` to use as destination of the partial reduction op generated with
433/// outer parallel strategy.
434static InitSliceInfo getInitSliceInfoForOuterParallel(
435 MLIRContext *context, ArrayRef<OpFoldResult> offsets,
436 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
437 ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
438 int64_t initRank = partialReductionMap.getNumResults();
439 SmallVector<OpFoldResult> initOffsets, initSizes;
440 Attribute one = IntegerAttr::get(type: IndexType::get(context), value: 1);
441 SmallVector<OpFoldResult> initStrides(initRank, one);
442 SmallVector<OpFoldResult> resultShape;
443 for (AffineExpr dimExpr : partialReductionMap.getResults()) {
444 unsigned dim = cast<AffineDimExpr>(Val&: dimExpr).getPosition();
445 if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, value: dim)) {
446 initOffsets.push_back(Elt: splitReductionIvs[dimPos.value()]);
447 initSizes.push_back(Elt: one);
448 } else {
449 initOffsets.push_back(Elt: offsets[dim]);
450 initSizes.push_back(Elt: sizes[dim]);
451 resultShape.push_back(Elt: sizes[dim]);
452 }
453 }
454 SmallVector<int64_t> staticShapes;
455 std::tie(args&: staticShapes, args: std::ignore) = decomposeMixedValues(mixedValues: resultShape);
456 return {.resultShape: staticShapes, .offsets: initOffsets, .sizes: initSizes, .strides: initStrides};
457}
458
459/// Return the result shape, offsets, sizes and strides of the slice of the
460/// `initValue` to use as destination of the partial reduction op.
461static InitSliceInfo getInitSliceInfo(MLIRContext *context,
462 ReductionTilingStrategy strategy,
463 ArrayRef<OpFoldResult> offsets,
464 ArrayRef<OpFoldResult> sizes,
465 const SetVector<unsigned> &reductionDims,
466 ArrayRef<OpFoldResult> splitReductionIvs,
467 AffineMap partialReductionMap) {
468 if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
469 return getInitSliceInfoForOuterReduction(context, offsets, sizes,
470 reductionDims, splitReductionIvs,
471 partialReductionMap);
472 }
473 assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
474 "unexpected ReductionTilingStrategy");
475 return getInitSliceInfoForOuterParallel(context, offsets, sizes,
476 reductionDims, splitReductionIvs,
477 partialReductionMap);
478}
479
480/// External model implementation of PartialReductionInterface for
481/// LinalgOps.
482template <typename LinalgOpTy>
483struct LinalgOpPartialReductionInterface
484 : public PartialReductionOpInterface::ExternalModel<
485 LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
486 FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
487 Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
488 const SetVector<unsigned> &reductionDims) const {
489 auto linalgOp = cast<LinalgOp>(Val: op);
490
491 OpBuilder::InsertionGuard guard(b);
492 if (linalgOp.hasPureBufferSemantics())
493 return op->emitOpError(message: "expected operation to have tensor semantics");
494
495 SmallVector<AffineMap> partialResultMaps =
496 getPartialResultAffineMaps(linalgOp, reductionDims);
497
498 SmallVector<Value> inits;
499 for (auto [initIdx, result, partialMap] :
500 llvm::enumerate(First: linalgOp->getResults(), Rest&: partialResultMaps)) {
501 SmallVector<Operation *, 4> combinerOps;
502 if (!matchReduction(iterCarriedArgs: linalgOp.getRegionOutputArgs(), redPos: initIdx,
503 combinerOps) ||
504 combinerOps.size() != 1)
505 return op->emitOpError(message: "Failed to anaysis the reduction operation.");
506
507 Operation *reductionOp = combinerOps[0];
508 std::optional<TypedAttr> identity = arith::getNeutralElement(op: reductionOp);
509 if (!identity.has_value())
510 return op->emitOpError(
511 message: "Failed to get an identity value for the reduction operation.");
512
513 // Append the new partial result dimensions.
514 SmallVector<OpFoldResult> partialResultShape;
515 for (AffineExpr dimExpr : partialMap.getResults()) {
516 auto dim = cast<AffineDimExpr>(Val&: dimExpr);
517 partialResultShape.push_back(Elt: sizes[dim.getPosition()]);
518 }
519
520 Type elType = getElementTypeOrSelf(type: result.getType());
521 Value emptyTensor =
522 b.create<tensor::EmptyOp>(location: loc, args&: partialResultShape, args&: elType);
523 Value constantOp = b.create<arith::ConstantOp>(location: loc, args&: *identity);
524 auto identityTensor =
525 b.create<linalg::FillOp>(location: loc, args&: constantOp, args&: emptyTensor);
526 inits.push_back(Elt: identityTensor.getResult(i: 0));
527 }
528
529 return inits;
530 }
531
532 FailureOr<TilingResult>
533 tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
534 ReductionTilingStrategy tilingStrategy,
535 ValueRange init, ArrayRef<OpFoldResult> offsets,
536 ArrayRef<OpFoldResult> sizes,
537 const SetVector<unsigned> &reductionDims,
538 ArrayRef<OpFoldResult> splitReductionIvs) const {
539 OpBuilder::InsertionGuard guard(b);
540 auto linalgOp = cast<LinalgOp>(Val: op);
541
542 SmallVector<AffineMap> partialReductionMaps =
543 getPartialResultAffineMaps(linalgOp, reductionDims);
544
545 // Step 1. Extend init maps to have reduction dimension dims, since we
546 // are converting them to parallel dimensions.
547 SmallVector<AffineMap> newInitMaps;
548 if (tilingStrategy ==
549 ReductionTilingStrategy::PartialReductionOuterReduction) {
550 newInitMaps = llvm::to_vector(Range&: partialReductionMaps);
551 } else {
552 newInitMaps = llvm::map_to_vector(
553 linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
554 return linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
555 });
556 }
557
558 // Step 2a: Extract a slice of the input operands.
559 SmallVector<Value> tiledInputs = makeTiledShapes(
560 builder&: b, loc, linalgOp, valuesToTile: linalgOp.getDpsInputs(), ivs: offsets, tileSizes: sizes, sizeBounds: {}, omitPartialTileCheck: true);
561 SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
562 llvm::make_filter_range(
563 tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
564 [](Value v) -> Operation * { return v.getDefiningOp(); });
565
566 // Step 2b: Extract a slice of the init operands.
567 SmallVector<Value, 1> tiledInits;
568 for (auto [partialReductionMap, valueToTile] :
569 llvm::zip_equal(t&: partialReductionMaps, u&: init)) {
570 InitSliceInfo sliceInfo = getInitSliceInfo(
571 context: b.getContext(), strategy: tilingStrategy, offsets, sizes, reductionDims,
572 splitReductionIvs, partialReductionMap);
573 auto valueToTileType = cast<RankedTensorType>(Val: valueToTile.getType());
574 RankedTensorType sliceResultType = RankedTensorType::get(
575 shape: sliceInfo.resultShape, elementType: valueToTileType.getElementType(),
576 encoding: valueToTileType.getEncoding());
577 auto sliceOp = b.create<tensor::ExtractSliceOp>(
578 location: loc, args&: sliceResultType, args&: valueToTile, args&: sliceInfo.offsets, args&: sliceInfo.sizes,
579 args&: sliceInfo.strides);
580 tiledInits.push_back(Elt: sliceOp.getResult());
581 generatedSlices.push_back(Elt: sliceOp);
582 }
583
584 // Update the indexing maps.
585 SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
586 for (auto [initOperand, newInitMap] :
587 llvm::zip_equal(t: linalgOp.getDpsInitsMutable(), u&: newInitMaps)) {
588 int mapIdx = linalgOp.getIndexingMapIndex(opOperand: &initOperand);
589 newMaps[mapIdx] = newInitMap;
590 }
591
592 // Step 3. Change the reduction dim iterator types.
593 SmallVector<utils::IteratorType> newIteratorTypes =
594 linalgOp.getIteratorTypesArray();
595 if (tilingStrategy ==
596 ReductionTilingStrategy::PartialReductionOuterReduction) {
597 for (int dim : reductionDims)
598 newIteratorTypes[dim] = utils::IteratorType::parallel;
599 }
600
601 // Step 4. Create the new generic op.
602 Operation *partialReductionOp;
603 auto resultTypes = ValueRange(tiledInits).getTypes();
604 if (tilingStrategy ==
605 ReductionTilingStrategy::PartialReductionOuterReduction) {
606 auto genericOp = b.create<GenericOp>(
607 location: loc, args&: resultTypes, args&: tiledInputs, args&: tiledInits, args&: newMaps, args&: newIteratorTypes);
608 IRMapping mapping;
609 op->getRegion(index: 0).cloneInto(dest: &genericOp.getRegion(),
610 destPos: genericOp.getRegion().begin(), mapper&: mapping);
611 partialReductionOp = genericOp.getOperation();
612 } else {
613 SmallVector<Value> operands = std::move(tiledInputs);
614 llvm::append_range(C&: operands, R&: tiledInits);
615 partialReductionOp = mlir::clone(b, op, newResultTypes: resultTypes, newOperands: operands);
616 }
617 return TilingResult{
618 {partialReductionOp},
619 llvm::map_to_vector(partialReductionOp->getResults(),
620 [](OpResult r) -> Value { return r; }),
621 generatedSlices};
622 }
623
624 FailureOr<MergeResult>
625 mergeReductions(Operation *op, OpBuilder &b, Location loc,
626 ValueRange partialReduce,
627 const SetVector<unsigned> &reductionDims) const {
628 auto linalgOp = cast<LinalgOp>(Val: op);
629 SmallVector<AffineMap> partialReductionMaps =
630 getPartialResultAffineMaps(linalgOp, reductionDims);
631
632 // Permute the reduction dims as permuted by the partial result map.
633 SmallVector<Operation *> mergeOperations;
634 SmallVector<Value> replacements;
635 for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
636 First: linalgOp.getDpsInits(), Rest&: partialReduce, Rest&: partialReductionMaps)) {
637 unsigned initIdx = idx;
638 // linalg.reduce's iteration space is the tiled result's iteration space
639 // (and not the tiled operation's iteration space). To account for this,
640 // permute the reduction dimensions based on the partial result map of the
641 // tiled result.
642 SmallVector<int64_t> partialReductionDims;
643 for (auto [resultNum, dimExpr] :
644 llvm::enumerate(First: partialMap.getResults())) {
645 unsigned dim = cast<AffineDimExpr>(Val: dimExpr).getPosition();
646 if (llvm::is_contained(Range: reductionDims, Element: dim)) {
647 partialReductionDims.push_back(Elt: resultNum);
648 }
649 }
650
651 auto reduction = b.create<linalg::ReduceOp>(
652 loc, partialResult, init, partialReductionDims,
653 [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
654 // Get the combiner op.
655 SmallVector<Operation *, 4> combinerOps;
656 matchReduction(iterCarriedArgs: linalgOp.getRegionOutputArgs(), redPos: initIdx,
657 combinerOps);
658 Operation *clonedReductionOp = b.clone(op&: *combinerOps[0]);
659 // Combine the input at idx and output at numInits + idx.
660 clonedReductionOp->setOperand(idx: 0, value: inputs[0]);
661 clonedReductionOp->setOperand(idx: 1, value: inputs[1]);
662 b.create<linalg::YieldOp>(location: loc, args: clonedReductionOp->getResult(idx: 0));
663 });
664
665 mergeOperations.push_back(Elt: reduction);
666 replacements.push_back(Elt: reduction->getResult(0));
667 }
668
669 return MergeResult{.mergeOps: mergeOperations, .replacements: replacements};
670 }
671
672 LogicalResult getPartialResultTilePosition(
673 Operation *op, OpBuilder &b, unsigned resultNumber,
674 ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets,
675 ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
676 ArrayRef<OpFoldResult> splitReductionIvs,
677 SmallVector<OpFoldResult> &resultOffsets,
678 SmallVector<OpFoldResult> &resultSizes) const {
679 auto linalgOp = cast<LinalgOp>(Val: op);
680 SmallVector<AffineMap> partialReductionMaps =
681 getPartialResultAffineMaps(linalgOp, reductionDims);
682 InitSliceInfo sliceInfo = getInitSliceInfo(
683 context: b.getContext(), strategy: tilingStrategy, offsets, sizes, reductionDims,
684 splitReductionIvs, partialReductionMap: partialReductionMaps[resultNumber]);
685 std::swap(LHS&: resultOffsets, RHS&: sliceInfo.offsets);
686 std::swap(LHS&: resultSizes, RHS&: sliceInfo.sizes);
687
688 return success();
689 }
690};
691
692template <typename OpTy>
693static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
694 OpBuilder &builder) {
695 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
696 "applies to only pack or unpack operations");
697 OpBuilder::InsertionGuard g(builder);
698 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
699 : op.getDestRank();
700 OpFoldResult zero = builder.getIndexAttr(value: 0);
701 OpFoldResult one = builder.getIndexAttr(value: 1);
702 ReifiedRankedShapedTypeDims resultShape;
703 (void)reifyResultShapes(builder, op, resultShape);
704 SmallVector<Range> loopBounds(rank);
705 for (auto dim : llvm::seq<int64_t>(Begin: 0, End: rank)) {
706 loopBounds[dim].offset = zero;
707 loopBounds[dim].stride = one;
708 loopBounds[dim].size = resultShape[0][dim];
709 }
710 return loopBounds;
711}
712
713static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
714 SmallVector<OpFoldResult> &sizes,
715 ArrayRef<int64_t> permutation) {
716 if (permutation.empty())
717 return;
718 applyPermutationToVector<OpFoldResult>(inVec&: offsets, permutation);
719 applyPermutationToVector<OpFoldResult>(inVec&: sizes, permutation);
720}
721
722struct PackOpTiling
723 : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
724
725 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
726 // Note that here we only consider untiled dimensions and outer tiled data
727 // dimensions, the inner tiled data dimensions are materialized when
728 // building the body of the operation.
729 auto packOp = cast<PackOp>(Val: op);
730 SmallVector<utils::IteratorType> iteratorTypes(
731 packOp.getSourceRank(), utils::IteratorType::parallel);
732 return iteratorTypes;
733 }
734
735 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
736 return getPackUnPackIterationDomain<PackOp>(op: cast<PackOp>(Val: op), builder&: b);
737 }
738
739 FailureOr<TilingResult>
740 getTiledImplementation(Operation *op, OpBuilder &b,
741 ArrayRef<OpFoldResult> offsets,
742 ArrayRef<OpFoldResult> sizes) const {
743 auto packOp = cast<PackOp>(Val: op);
744 Location loc = packOp.getLoc();
745
746 // The tiling is applied on interchanged dimensions. We have to undo the
747 // interchange to map sizes and offsets to the original input.
748 int64_t inputRank = packOp.getSourceRank();
749 SmallVector<OpFoldResult> origOffsets(offsets);
750 SmallVector<OpFoldResult> origSizes(sizes);
751 applyPermToRange(offsets&: origOffsets, sizes&: origSizes,
752 permutation: invertPermutationVector(permutation: packOp.getOuterDimsPerm()));
753
754 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
755 packOp.getDimAndTileMapping();
756 SmallVector<OpFoldResult> srcDimValues =
757 tensor::getMixedSizes(builder&: b, loc, value: packOp.getSource());
758 SmallVector<OpFoldResult> inputIndices, inputSizes;
759 for (auto dim : llvm::seq<int64_t>(Begin: 0, End: inputRank)) {
760 using AV = affine::AffineValueExpr;
761 affine::AffineBuilder ab(b, loc);
762 AffineExpr dim0, dim1, sym;
763 bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1);
764 bindSymbols(ctx: b.getContext(), exprs&: sym);
765 if (dimAndTileMapping.count(Val: dim)) {
766 // If the data dimension is tiled, the i-th index is the product of
767 // offset_i and tile_i, and the i-th size is the product of sizes_i and
768 // tile_i.
769 auto avOffset = AV(dim0).bind(v: origOffsets[dim]);
770 auto avSize = AV(dim0).bind(v: origSizes[dim]);
771 auto avTileSize = AV(sym).bind(v: dimAndTileMapping[dim]);
772 inputIndices.push_back(Elt: ab.mul(lhs: avOffset, rhs: avTileSize));
773 inputSizes.push_back(Elt: ab.mul(lhs: avSize, rhs: avTileSize));
774 } else {
775 inputIndices.push_back(Elt: origOffsets[dim]);
776 inputSizes.push_back(Elt: origSizes[dim]);
777 }
778
779 // Limit the size of the input operand for incomplete tiles.
780 if (packOp.getPaddingValue()) {
781 OpFoldResult dimSize = srcDimValues[dim];
782 auto avDimSize = AV(dim0).bind(v: dimSize);
783 auto avInputIdx = AV(dim1).bind(v: inputIndices.back());
784 inputSizes.back() =
785 ab.min(vals: {inputSizes.back(), ab.sub(lhs: avDimSize, rhs: avInputIdx)});
786 }
787 }
788
789 auto oneAttr = b.getI64IntegerAttr(value: 1);
790 SmallVector<OpFoldResult> strides(inputRank, oneAttr);
791
792 SmallVector<Value> tiledOperands;
793 auto sourceSlice = b.create<tensor::ExtractSliceOp>(
794 location: loc, args: packOp.getSource(), args&: inputIndices, args&: inputSizes, args&: strides);
795 tiledOperands.push_back(Elt: sourceSlice);
796
797 SmallVector<OpFoldResult> outputOffsets, outputSizes;
798 if (failed(Result: getResultTilePosition(op, b, resultNumber: 0, offsets, sizes, resultOffsets&: outputOffsets,
799 resultSizes&: outputSizes)))
800 return {};
801
802 strides.append(NumInputs: packOp.getDestRank() - inputRank, Elt: oneAttr);
803 auto outSlice = b.create<tensor::ExtractSliceOp>(
804 location: loc, args: packOp.getDest(), args&: outputOffsets, args&: outputSizes, args&: strides);
805 tiledOperands.push_back(Elt: outSlice);
806
807 if (auto val = packOp.getPaddingValue())
808 tiledOperands.push_back(Elt: val);
809 for (auto tile : packOp.getInnerTiles())
810 tiledOperands.push_back(Elt: tile);
811
812 Operation *tiledPackOp = b.create<PackOp>(
813 location: loc, args: TypeRange{outSlice.getType()}, args&: tiledOperands, args: op->getAttrs());
814
815 return TilingResult{
816 .tiledOps: {tiledPackOp},
817 .tiledValues: SmallVector<Value>(tiledPackOp->getResults()),
818 .generatedSlices: llvm::to_vector(Range: ArrayRef<Operation *>{sourceSlice, outSlice})};
819 }
820
821 LogicalResult
822 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
823 ArrayRef<OpFoldResult> offsets,
824 ArrayRef<OpFoldResult> sizes,
825 SmallVector<OpFoldResult> &resultOffsets,
826 SmallVector<OpFoldResult> &resultSizes) const {
827 // The iteration domain is over outer dimensions of packed layout. In this
828 // context, the outer dimensions of `resultOffsets` are `offsets`. The
829 // inner dimensions of `resultOffsets` are zeros because tiling is not
830 // applied to them.
831 auto packOp = cast<PackOp>(Val: op);
832 int64_t inputRank = packOp.getSourceRank();
833 int64_t outputRank = packOp.getDestRank();
834 auto zeroAttr = b.getI64IntegerAttr(value: 0);
835 resultOffsets.assign(in_start: offsets.begin(), in_end: offsets.end());
836 resultOffsets.append(NumInputs: outputRank - inputRank, Elt: zeroAttr);
837
838 ReifiedRankedShapedTypeDims outputShape;
839 (void)reifyResultShapes(b, op: packOp, reifiedReturnShapes&: outputShape);
840 resultSizes.assign(in_start: sizes.begin(), in_end: sizes.end());
841 for (auto dataTileDim : llvm::seq<unsigned>(Begin: inputRank, End: outputRank))
842 resultSizes.push_back(Elt: outputShape[0][dataTileDim]);
843
844 return success();
845 }
846
847 FailureOr<TilingResult>
848 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
849 ArrayRef<OpFoldResult> offsets,
850 ArrayRef<OpFoldResult> sizes) const {
851 auto packOp = cast<PackOp>(Val: op);
852 int64_t numTiles = packOp.getInnerDimsPos().size();
853
854 // tensor.pack op is fusible (as a producer) only if full inner tiles are
855 // iterated or inner dims are not tiled. Otherwise, it will generate a
856 // sequence of non-trivial ops (for partial tiles).
857 for (auto offset : offsets.take_back(N: numTiles))
858 if (!isZeroInteger(v: offset))
859 return failure();
860
861 for (auto iter :
862 llvm::zip_equal(t: packOp.getMixedTiles(), u: sizes.take_back(N: numTiles)))
863 if (!isEqualConstantIntOrValue(ofr1: std::get<0>(t&: iter), ofr2: std::get<1>(t&: iter)))
864 return failure();
865
866 FailureOr<TilingResult> tilingResult = getTiledImplementation(
867 op, b, offsets: offsets.drop_back(N: numTiles), sizes: sizes.drop_back(N: numTiles));
868 if (failed(Result: tilingResult))
869 return failure();
870 return tilingResult.value();
871 }
872
873 /// Method to return the position of iteration domain tile computed by the
874 /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
875 /// `resultSizes` only cover outer dimensions.
876 LogicalResult getIterationDomainTileFromOperandTiles(
877 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
878 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
879 ArrayRef<SmallVector<OpFoldResult>> allSizes,
880 SmallVectorImpl<OpFoldResult> &resultOffsets,
881 SmallVectorImpl<OpFoldResult> &resultSizes) const {
882 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
883 LLVM_DEBUG(
884 { llvm::dbgs() << "unsupported operands for consumer fusion"; });
885 return failure();
886 }
887
888 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
889 ArrayRef<OpFoldResult> sizes(allSizes[0]);
890
891 auto packOp = cast<PackOp>(Val: op);
892 // It is not trivial to infer dest tile from source tile if `packOp` has
893 // padding semantic.
894 if (packOp.getPaddingValue())
895 return failure();
896
897 Location loc = packOp.getLoc();
898
899 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
900 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
901 packOp.getDimAndTileMapping();
902 for (auto dim : llvm::seq<int64_t>(Size: packOp.getSourceRank())) {
903 if (dimAndTileMapping.count(Val: dim)) {
904 FailureOr<int64_t> cstSize =
905 ValueBoundsConstraintSet::computeConstantBound(
906 type: presburger::BoundType::UB, var: sizes[dim],
907 /*stopCondition=*/nullptr, /*closedUB=*/true);
908 std::optional<int64_t> cstInnerSize =
909 getConstantIntValue(ofr: dimAndTileMapping[dim]);
910 // Currently fusing `packOp` as consumer only expects perfect tiling
911 // scenario because even if without padding semantic, the `packOp` may
912 // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
913 // where the `tileSize` from operand of `packOp` is 5, which is not
914 // exactly divided by `innerTile`(=6) of `packOp`. As the result:
915 // 1. the first slice is extracted from (0) to (4) and inserted into
916 // (0,0)~(0,4) at first row.
917 // 2. the second slice is extracted from (5) to (9) and SHOULD BE
918 // respectively inserted into two rows with different length, including
919 // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
920 // them, thus adding below constraint to bypass them temporarily. In
921 // another word, we can only support tiling with consumer if the tile
922 // size for the producer is a multiple of the inner tile size for the
923 // packed dimensions at this moment.
924 if (failed(Result: cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
925 return failure();
926 }
927
928 using AV = affine::AffineValueExpr;
929 affine::AffineBuilder ab(b, loc);
930 AffineExpr dim0, sym;
931 bindDims(ctx: b.getContext(), exprs&: dim0);
932 bindSymbols(ctx: b.getContext(), exprs&: sym);
933 auto avOffset = AV(dim0).bind(v: offsets[dim]);
934 auto avSize = AV(dim0).bind(v: sizes[dim]);
935 auto avTileSize = AV(sym).bind(v: dimAndTileMapping[dim]);
936 outerDimOffsets.push_back(Elt: ab.floor(lhs: avOffset, rhs: avTileSize));
937 outerDimSizes.push_back(Elt: ab.ceil(lhs: avSize, rhs: avTileSize));
938 } else {
939 outerDimOffsets.push_back(Elt: offsets[dim]);
940 outerDimSizes.push_back(Elt: sizes[dim]);
941 }
942 }
943 applyPermToRange(offsets&: outerDimOffsets, sizes&: outerDimSizes, permutation: packOp.getOuterDimsPerm());
944 resultOffsets = outerDimOffsets;
945 resultSizes = outerDimSizes;
946 return success();
947 }
948
949 /// Method to return the tiled implementation of tensor.pack as a consumer.
950 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
951 Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
952 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
953 ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
954 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
955 LLVM_DEBUG(
956 { llvm ::dbgs() << "unhandled operands for consumer fusion"; });
957 return failure();
958 }
959
960 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
961 ArrayRef<OpFoldResult> sizes(allSizes[0]);
962
963 auto packOp = cast<PackOp>(Val: op);
964 Location loc = packOp.getLoc();
965
966 int64_t inputRank = packOp.getSourceRank();
967 auto oneAttr = b.getI64IntegerAttr(value: 1);
968 SmallVector<OpFoldResult> strides(inputRank, oneAttr);
969
970 SmallVector<Value> tiledOperands;
971 auto sourceSlice = b.create<tensor::ExtractSliceOp>(
972 location: loc, args: packOp.getSource(), args&: offsets, args&: sizes, args&: strides);
973 tiledOperands.push_back(Elt: sourceSlice);
974
975 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
976 if (failed(Result: getIterationDomainTileFromOperandTiles(
977 op, b, operandNumbers, allOffsets, allSizes, resultOffsets&: outerDimOffsets,
978 resultSizes&: outerDimSizes)))
979 return failure();
980
981 SmallVector<OpFoldResult> outputOffsets, outputSizes;
982 if (failed(Result: getResultTilePosition(op, b, resultNumber: 0, offsets: outerDimOffsets, sizes: outerDimSizes,
983 resultOffsets&: outputOffsets, resultSizes&: outputSizes)))
984 return failure();
985
986 strides.append(NumInputs: packOp.getDestRank() - inputRank, Elt: oneAttr);
987 auto outSlice = b.create<tensor::ExtractSliceOp>(
988 location: loc, args: packOp.getDest(), args&: outputOffsets, args&: outputSizes, args&: strides);
989 tiledOperands.push_back(Elt: outSlice);
990
991 assert(!packOp.getPaddingValue() && "Expect no padding semantic");
992 for (auto tile : packOp.getInnerTiles())
993 tiledOperands.push_back(Elt: tile);
994
995 Operation *tiledPackOp = b.create<PackOp>(
996 location: loc, args: TypeRange{outSlice.getType()}, args&: tiledOperands, args: op->getAttrs());
997
998 return TilingResult{
999 .tiledOps: {tiledPackOp},
1000 .tiledValues: SmallVector<Value>(tiledPackOp->getResults()),
1001 .generatedSlices: llvm::to_vector(Range: ArrayRef<Operation *>{sourceSlice, outSlice})};
1002 }
1003};
1004
/// Per-dimension information needed when tiling an unpack op along one
/// destination dimension; computed by `getUnpackTileDimInfo`.
struct UnpackTileDimInfo {
  // True when no partial inner tiles occur for this dim, i.e. the dim is not
  // packed at all or the requested tile size is a multiple of the inner tile
  // size.
  bool isAlignedToInnerTileSize;
  // Offset/size of the slice to extract from the packed source.
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  // Offset of the requested tile within the (possibly expanded) unpacked
  // result.
  OpFoldResult resultOffset;
  // Size of the destination needed to hold all complete inner tiles covering
  // the requested tile (used to build the tensor.empty in the imperfect case).
  OpFoldResult destExpandedSize;
};
1012
/// Returns the needed information for tiling unpack op on `tileDim` with given
/// `tileOffset` and `tileSize`. For more details, see the comment of the
/// `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(value: 0);
  Attribute oneAttr = b.getIndexAttr(value: 1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of packed data dimension. The source tile maps
  // 1:1 to the requested tile.
  if (!dimAndTileMapping.count(Val: tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1);
  bindSymbols(ctx: b.getContext(), exprs&: sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  // Try to prove alignment statically: requires a constant upper bound on the
  // tile size and a constant inner tile size.
  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      type: presburger::BoundType::UB, var: tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(ofr: innerTileSize);
  if (!failed(Result: cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals to the inner tiling size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(v: tileOffset);
      auto rhs = AV(dim1).bind(v: innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  // Aligned case: the source tile starts at an inner-tile boundary and the
  // requested tile covers whole inner tiles.
  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(lhs: AV(dim0).bind(v: tileOffset), rhs: AV(dim1).bind(v: innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be incomplete tile even
    // it is perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them have size=2. The size is represented using
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(lhs: AV(dim0).bind(v: tileSize), rhs: AV(dim1).bind(v: innerTileSize));
    return info;
  }

  // Unaligned case: compute the (inner-tile index, intra-tile offset) of the
  // first and last element of the requested tile; the source slice spans all
  // inner tiles in between.
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, lhs: getValueOrCreateConstantIndexOp(b, loc, ofr: tileOffset),
      rhs: getValueOrCreateConstantIndexOp(b, loc, ofr: innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(lhs: AV(dim0).bind(v: tileOffset), rhs: AV(dim1).bind(v: tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      lhs: getValueOrCreateConstantIndexOp(
          b, loc,
          ofr: ab.sub(lhs: AV(dim0).bind(v: tileExclusiveBound), rhs: AV(dim1).bind(v: oneAttr))),
      rhs: getValueOrCreateConstantIndexOp(b, loc, ofr: innerTileSize));

  // Number of inner tiles touched = last tile index - first tile index + 1.
  OpFoldResult lengthMinusOne = ab.sub(lhs: AV(dim0).bind(v: lastCoord.quotient),
                                       rhs: AV(dim1).bind(v: firstCoord.quotient));
  info.sourceSize =
      ab.add(lhs: AV(dim0).bind(v: lengthMinusOne), rhs: AV(dim1).bind(v: oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create an Affine ops for expanded size because the affine op is too
  // complicated which would trigger an issue in affine ops simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      location: loc, args: getValueOrCreateConstantIndexOp(b, loc, ofr: info.sourceSize),
      args: getValueOrCreateConstantIndexOp(b, loc, ofr: innerTileSize));
  return info;
}
1108
/// External model implementing the TilingInterface for linalg::UnPackOp.
struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {

  /// All destination dimensions are parallel; the iteration domain is the
  /// unpacked (dest) shape.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(Val: op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(op: cast<UnPackOp>(Val: op), builder&: b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned to
  /// the inner tile size, the corresponding tiles of source are all complete.
  /// Otherwise, there are in-complete tiles. We will need to expand the slice
  /// of source for getting complete tiles. The tiled unpack op unpacks more
  /// data from source, so We'll need an extract_slice op to shift and truncate
  /// the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0); end with
  /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements
  /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we
  /// can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(Val: op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiple of
    // inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(value: 1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(Begin: 0, End: destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, tileDim: dim, tileOffset: offsets[dim], tileSize: sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(Elt: info.sourceOffset);
      sliceSrcSizes.push_back(Elt: info.sourceSize);
      destExpandedSizes.push_back(Elt: info.destExpandedSize);
      resultOffsetsFromDest.push_back(Elt: info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(offsets&: sliceSrcIndices, sizes&: sliceSrcSizes,
                     permutation: unpackOp.getOuterDimsPerm());
    // Inner-tile dims of the source are taken whole (offset 0, full tile
    // size, unit stride).
    Attribute zeroAttr = b.getIndexAttr(value: 0);
    sliceSrcIndices.append(NumInputs: numInnerTiles, Elt: zeroAttr);
    sliceSrcSizes.append(RHS: unpackOp.getMixedTiles());
    sliceSrcStrides.append(NumInputs: numInnerTiles, Elt: oneAttr);
    SmallVector<Operation *> generatedSlices;
    tensor::ExtractSliceOp sliceSource = b.create<tensor::ExtractSliceOp>(
        location: loc, args: unpackOp.getSource(), args&: sliceSrcIndices, args&: sliceSrcSizes,
        args&: sliceSrcStrides);
    generatedSlices.push_back(Elt: sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      // Perfect tiling: unpack directly into a slice of the original dest.
      auto destSliceOp = b.create<tensor::ExtractSliceOp>(
          location: loc, args: unpackOp.getDest(), args&: offsets, args&: sizes, args&: destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(Elt: destSliceOp);
    } else {
      // Imperfect tiling: unpack into a fresh, expanded buffer that can hold
      // all complete inner tiles; the requested tile is extracted below.
      sliceDest = b.create<tensor::EmptyOp>(
          location: loc, args&: destExpandedSizes, args: unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(Elt: tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        location: loc, args: TypeRange{sliceDest.getType()}, args&: tiledOperands, args: op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{.tiledOps: {tiledUnpackOp},
                          .tiledValues: SmallVector<Value>(tiledUnpackOp->getResults()),
                          .generatedSlices: generatedSlices};

    // Shift and truncate the expanded result down to the requested tile.
    auto extractSlice = b.create<tensor::ExtractSliceOp>(
        location: loc, args: tiledUnpackOp->getResult(idx: 0), args&: resultOffsetsFromDest, args&: sizes,
        args&: destStrides);
    return TilingResult{
        .tiledOps: {tiledUnpackOp}, .tiledValues: {extractSlice.getResult()}, .generatedSlices: generatedSlices};
  }

  /// The iteration domain coincides with the dest, so the result tile is the
  /// iteration-space tile unchanged.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(Range&: offsets);
    resultSizes = llvm::to_vector(Range&: sizes);
    return success();
  }

  /// Generating a result tile is the same as tiling the op, since result and
  /// iteration space coincide.
  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(Result: tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of iteration domain tile computed by the
  /// tiled operation.
  LogicalResult getIterationDomainTileFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumbers.size() != 1) {
      LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
      return failure();
    }
    auto unPackOp = cast<UnPackOp>(Val: op);
    unsigned operandNumber = operandNumbers[0];
    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);

    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(Range&: offsets);
      resultSizes = llvm::to_vector(Range&: sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    // Otherwise the tile is on the (packed) source: drop the inner-tile dims
    // and map the outer dims back to the dest iteration space.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(N: numTiles);
    auto destSizes = sizes.drop_back(N: numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(Result: reifyResultShapes(b, op: unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(offsets&: origOffsets, sizes&: origSizes,
                     permutation: invertPermutationVector(permutation: unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(Begin: 0, End: outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1);
      bindSymbols(ctx: b.getContext(), exprs&: sym0);
      if (dimAndTileMapping.count(Val: dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i. The sizes must be clamped to the sizes of the unpack result.
        auto avOffset = AV(dim0).bind(v: origOffsets[dim]);
        auto avSize = AV(dim0).bind(v: origSizes[dim]);
        auto avTileSize = AV(sym0).bind(v: dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(v: outputMixedSizes[dim]);
        resultOffsets.push_back(Elt: ab.mul(lhs: avOffset, rhs: avTileSize));
        auto avResultOffset = AV(dim1).bind(v: resultOffsets.back());
        resultSizes.push_back(Elt: ab.min(vals: {ab.mul(lhs: avSize, rhs: avTileSize),
                                      ab.sub(lhs: avResultSize, rhs: avResultOffset)}));
      } else {
        resultOffsets.push_back(Elt: origOffsets[dim]);
        resultSizes.push_back(Elt: origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
      LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
      return failure();
    }
    auto unPackOp = cast<UnPackOp>(Val: op);
    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);

    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(t: unPackOp.getMixedTiles(), u: sizes.take_back(N: numTiles))) {
      if (!isEqualConstantIntOrValue(ofr1: std::get<0>(t&: iter), ofr2: std::get<1>(t&: iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(Result: getIterationDomainTileFromOperandTiles(
            op, b, operandNumbers, allOffsets, allSizes, resultOffsets&: outputOffsets,
            resultSizes&: outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(value: 1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = b.create<tensor::ExtractSliceOp>(
        location: loc, args: unPackOp.getDest(), args&: outputOffsets, args&: outputSizes, args&: strides);
    tiledOperands.push_back(Elt: extractDestSlice);

    strides.append(NumInputs: unPackOp.getSourceRank() - outputRank, Elt: oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = b.create<tensor::ExtractSliceOp>(
        location: loc, args: unPackOp.getSource(), args&: offsets, args&: sizes, args&: strides);
    // Source must come first in the operand list.
    tiledOperands.insert(I: tiledOperands.begin(), Elt: extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(Elt: tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(location: loc, args: TypeRange{extractDestSlice.getType()},
                           args&: tiledOperands, args: op->getAttrs());

    return TilingResult{.tiledOps: {tiledUnPackOp},
                        .tiledValues: SmallVector<Value>(tiledUnPackOp->getResults()),
                        .generatedSlices: llvm::to_vector(Range: ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};
1363
1364} // namespace
1365
/// Attach both the tiling and the partial-reduction external models to the
/// single linalg op type `OpType`.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
      *ctx);
}
1372
/// Variadic helper function: registers the external models for every op type
/// in `OpTypes` via a fold expression over `registerOne`.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}
1378
1379#define GET_OP_LIST
1380
/// Register the TilingInterface (and partial-reduction) external models for
/// all linalg structured ops plus pack/unpack.
void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(extensionFn: +[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    registerOne<linalg::GenericOp>(ctx);
    linalg::PackOp::attachInterface<PackOpTiling>(context&: *ctx);
    linalg::UnPackOp::attachInterface<UnPackOpTiling>(context&: *ctx);
    // GET_OP_LIST (defined above) makes the .inc expand to the comma-separated
    // list of all ODS-generated structured op types.
    registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
1392
1393void mlir::linalg::registerTilingInterfaceExternalModelsForPackUnPackOps(
1394 DialectRegistry &registry) {
1395 registry.addExtension(extensionFn: +[](MLIRContext *ctx, LinalgDialect *dialect) {
1396 linalg::PackOp::attachInterface<PackOpTiling>(context&: *ctx);
1397 linalg::UnPackOp::attachInterface<UnPackOpTiling>(context&: *ctx);
1398 });
1399}
1400

// Source file: mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp