//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

//===----------------------------------------------------------------------===//
// Utility methods for implementation of Tiling Interface for Linalg ops
//===----------------------------------------------------------------------===//

/// Return the SSA values that represent the data point accessed using a given
/// `indexingMap` for a given point in the iteration space represented by `ivs`.
static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
                                              AffineMap indexingMap,
                                              ValueRange ivs) {
  SmallVector<Value> indices;
  indices.reserve(indexingMap.getNumResults());
  for (auto result : indexingMap.getResults()) {
    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
                                 indexingMap.getNumSymbols(), result);
    Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
    indices.push_back(v);
  }
  return indices;
}

/// Method to inline the payload of a `linalgOp` given the iteration space
/// point and values for the arguments of the payload.
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
                                   ValueRange ivs, ValueRange argValues) {
  Block *body = linalgOp.getBlock();
  IRMapping map;
  map.map(body->getArguments(), argValues);
  for (auto &op : body->without_terminator()) {
    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
      map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
      continue;
    }
    b.clone(op, map);
  }

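  // Store each value yielded by the payload terminator into the corresponding
  // output buffer, at the indices given by the indexing map of that output.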
  Operation *terminator = body->getTerminator();
  Location loc = terminator->getLoc();
  for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
    Value toStore = map.lookupOrDefault(operand.value());
    OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
    auto indices = getIndicesForAccess(
        b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
    b.create<memref::StoreOp>(
        loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
        indices);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// External Model for implementing `TilingInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

namespace {
/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now till the use of `TilingInterface` is
/// on-par with the current Linalg tiling + fusion patterns. Once it is, it may
/// be possible to move this into the op definition (though there are
/// advantages to leaving it as an external model).
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
  /// Return the loop iterator type.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return concreteOp.getIteratorTypesArray();
  }

  /// Return the iteration domain range.
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(op);
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<OpFoldResult> allShapesSizes =
        linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();

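    // Apply the shapes-to-loops map to the flattened operand sizes to obtain
    // the extent of each loop. All loops start at 0 and have stride 1.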
    return llvm::to_vector(
        llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
          OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
              b, loc, loopExpr, allShapesSizes);
          return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
        }));
  }

  /// Instantiate the tiled implementation of the operation.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
    // specified could lead to out of bounds accesses.
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
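    // Collect the extract_slice/subview ops created for the tiled operands so
    // that callers (e.g. fusion drivers) can keep track of them.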
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledOperands,
            [](Value v) -> bool {
              return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
                  v.getDefiningOp());
            }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(linalgOp, tiledOperands);

    Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);

    return TilingResult{
        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
  }

  /// Utility to fetch the offsets and sizes when applied as per the indexing
  /// map of the linalg op. This helps in fusing the linalg op as a consumer of
  /// a given slice op.
  void
  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
    unsigned numLoops = linalgOp.getNumLoops();
    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
    mappedOffsets.resize(numLoops);
    mappedSizes.resize(numLoops);
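    // If the indexing map does not cover all loops (i.e. it is not a full
    // permutation), start from the full iteration domain; the covered
    // dimensions are overwritten below.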
    if (!indexingMap.isPermutation()) {
      SmallVector<Range> iterationDomain =
          tilingInterfaceOp.getIterationDomain(b);
      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
        mappedOffsets[index] = value.offset;
        mappedSizes[index] = value.size;
      }
    }
    for (const auto &&[index, value] :
         llvm::enumerate(indexingMap.getResults())) {
      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
      mappedOffsets[dimPosition] = offsets[index];
      mappedSizes[dimPosition] = sizes[index];
    }
  }

  /// Method to return the position of the result tile computed by the tiled
  /// operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Check that the indexing map used for the operand is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the operand to iteration space tiles
    // (filling in full extent for dimensions not used to access the result).
    AffineMap indexingMap =
        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitError()
             << "unhandled get iter domain position when operand is not "
                "accessed using a permuted projection";
    }

    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
                           iterDomainOffsets, iterDomainSizes);
    return success();
  }

  /// Return the details of the output tile generated by the tiled
  /// implementation.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);

    AffineExpr d0;
    bindDims(b.getContext(), d0);
    SmallVector<OpFoldResult> subShapeSizes =
        llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
          return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
        }));

    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    SliceParameters sliceParams = computeSliceParameters(
        b, loc, outOperand->get(), sizes,
        linalgOp.getMatchingIndexingMap(outOperand), offsets,
        /*ubs*/ {}, subShapeSizes, true);
    resultOffsets = sliceParams.offsets;
    resultSizes = sliceParams.sizes;
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Check that the indexing map used for the output is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the result to iteration space tiles
    // (filling in full extent for dimensions not used to access the result).
    AffineMap indexingMap =
        linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitOpError(
          "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
    }

    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
                           iterDomainOffsets, iterDomainSizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromResultTile(
            op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
      return failure();
    }
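    // Tile the op on the mapped iteration-domain tile and return only the
    // tiled value corresponding to the requested result.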
    auto tilingInterfaceOp = cast<TilingInterface>(op);
    FailureOr<TilingResult> tilingResult =
        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

    if (failed(tilingResult))
      return failure();

    if (tilingResult->tiledOps.size() != 1)
      return op->emitOpError("failed to generate tiled implementation");

    return TilingResult{
        tilingResult->tiledOps,
        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
        tilingResult->generatedSlices};
  }

  /// Method to generate the tiled implementation of an operation from the tile
  /// of the operand.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, operandNumber, offsets, sizes, mappedOffsets,
            mappedSizes))) {
      return failure();
    }
    return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
  }

  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
                                             Location loc,
                                             ValueRange ivs) const {
    auto linalgOp = cast<LinalgOp>(op);
    if (!linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have buffer semantics");

    SmallVector<Value> indexedValues;
    indexedValues.reserve(linalgOp->getNumOperands());
    Location linalgOpLoc = op->getLoc();
    /// Load the data corresponding to the block arguments that
    /// represent input operands.
    for (OpOperand &operand : linalgOp->getOpOperands()) {
      if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
        indexedValues.push_back(nullptr);
        continue;
      }
      if (linalgOp.isScalar(&operand)) {
        indexedValues.push_back(operand.get());
        continue;
      }
      SmallVector<Value> indices = getIndicesForAccess(
          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
      Value load =
          builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
      indexedValues.push_back(load);
    }

    /// Inline the op payload and store the result.
    return inlinePayload(builder, linalgOp, ivs, indexedValues);
  }
};

//===----------------------------------------------------------------------===//
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

/// Return an AffineMap for a partial result for the given result number,
/// assuming the partial tiling strategy is outer-reduction loop +
/// inner-parallel tile. The returned AffineMap can be used as the replacement
/// AffineMap for the inner-parallel tile linalg op for the given result number.
///
/// The new AffineMap is the old AffineMap with the reduction dimensions
/// appended at the end.
static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
                                           ArrayRef<int> reductionDims,
                                           unsigned resultNumber) {
  AffineMap map =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
  for (int redPos : reductionDims) {
    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
                           map.getNumResults());
  }
  return map;
}

/// External model implementation of PartialReductionInterface for
/// LinalgOps.
template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
    : public PartialReductionOpInterface::ExternalModel<
          LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
  FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
      ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);
    OpBuilder::InsertionGuard guard(b);

    if (linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have tensor semantics");

    // LinalgOp implements TilingInterface.
    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
    SmallVector<OpFoldResult> shape =
        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
                            [](Range x) { return x.size; });

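    // A tile size of zero means the dimension is not tiled; use the full
    // dimension size in that case.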
    SmallVector<OpFoldResult> tiledShape;
    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
      if (isZeroInteger(tileSize)) {
        tiledShape.push_back(dimSize);
      } else {
        tiledShape.push_back(tileSize);
      }
    }

    SmallVector<Value> inits;
    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
         ++initIdx) {
      SmallVector<Operation *, 4> combinerOps;
      if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                          combinerOps) ||
          combinerOps.size() != 1)
        return op->emitOpError("Failed to analyze the reduction operation.");

      Operation *reductionOp = combinerOps[0];
      std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
      if (!identity.has_value())
        return op->emitOpError(
            "Failed to get an identity value for the reduction operation.");

      // Append the new partial result dimensions.
      AffineMap partialMap =
          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
      SmallVector<OpFoldResult> partialResultShape;
      for (AffineExpr dimExpr : partialMap.getResults()) {
        auto dim = cast<AffineDimExpr>(dimExpr);
        partialResultShape.push_back(tiledShape[dim.getPosition()]);
      }

      Type elType =
          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
      Value emptyTensor =
          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
      Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
      auto identityTensor =
          b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
      inits.push_back(identityTensor.getResult(0));
    }

    return inits;
  }

  FailureOr<TilingResult>
  tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
                         ValueRange init, ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         ArrayRef<int> reductionDims) const {
    OpBuilder::InsertionGuard guard(b);
    auto linalgOp = cast<LinalgOp>(op);

    // Step 1. Extend the init maps to include the reduction dimensions, since
    // we are converting them to parallel dimensions.
    SmallVector<AffineMap> newInitMaps;
    newInitMaps.reserve(linalgOp.getNumDpsInits());
    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
      // this with a for range loop when we have it.
      AffineMap newMap =
          getPartialResultAffineMap(linalgOp, reductionDims, idx);
      newInitMaps.push_back(newMap);
    }

    // Step 2a: Extract a slice of the input operands.
    SmallVector<Value> tiledInputs = makeTiledShapes(
        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    // Step 2b: Extract a slice of the init operands.
    SmallVector<Value, 1> tiledInits;
    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
      int64_t initRank = valueMap.getNumResults();
      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
      SmallVector<OpFoldResult> initSizes;
      for (AffineExpr dimExpr : valueMap.getResults()) {
        auto dim = cast<AffineDimExpr>(dimExpr);
        initSizes.push_back(sizes[dim.getPosition()]);
      }
      // TODO: Use SubsetExtractOpInterface here once available.
      auto extractSlice = b.create<tensor::ExtractSliceOp>(
          loc, valueToTile, initOffset, initSizes, initStride);
      tiledInits.push_back(extractSlice);
      generatedSlices.push_back(extractSlice);
    }

    // Update the indexing maps.
    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
    // Change the init maps.
    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
      // this with a for range loop when we have it.
      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
      newMaps[mapIdx] = newInitMaps[idx];
    }

    // Step 3. Change the reduction dim iterator types.
    SmallVector<utils::IteratorType> newIteratorTypes =
        linalgOp.getIteratorTypesArray();
    for (int dim : reductionDims)
      newIteratorTypes[dim] = utils::IteratorType::parallel;

    // Step 4. Create the new generic op.
    auto genericOp =
        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
                            tiledInits, newMaps, newIteratorTypes);
    IRMapping mapping;
    op->getRegion(0).cloneInto(&genericOp.getRegion(),
                               genericOp.getRegion().begin(), mapping);
    return TilingResult{
        {genericOp.getOperation()},
        llvm::map_to_vector(genericOp->getResults(),
                            [](OpResult r) -> Value { return r; }),
        generatedSlices};
  }

  FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
                                         Location loc, ValueRange partialReduce,
                                         ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Permute the reduction dims as permuted by the partial result map.

    int64_t numInits = linalgOp.getNumDpsInits();
    SmallVector<Operation *> mergeOperations;
    SmallVector<Value> replacements;
    for (int idx : llvm::seq(numInits)) {
      // linalg.reduce's iteration space is the tiled result's iteration space
      // (and not the tiled operation's iteration space). To account for this,
      // permute the reduction dimensions based on the partial result map of
      // the tiled result.
      AffineMap partialMap =
          getPartialResultAffineMap(linalgOp, reductionDims, idx);
      SmallVector<int64_t> partialReductionDims;
      for (auto [resultNum, dimExpr] :
           llvm::enumerate(partialMap.getResults())) {
        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
        if (llvm::is_contained(reductionDims, dim)) {
          partialReductionDims.push_back(resultNum);
        }
      }

      Value partialResult = partialReduce[idx];
      Value init = linalgOp.getDpsInits()[idx];

      auto reduction = b.create<linalg::ReduceOp>(
          loc, partialResult, init, partialReductionDims,
          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
            // Get the combiner op.
            SmallVector<Operation *, 4> combinerOps;
            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
            // Combine the input at idx and output at numInits + idx.
            clonedReductionOp->setOperand(0, inputs[0]);
            clonedReductionOp->setOperand(1, inputs[1]);
            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
          });

      mergeOperations.push_back(reduction);
      replacements.push_back(reduction->getResult(0));
    }

    return MergeResult{mergeOperations, replacements};
  }

  LogicalResult getPartialResultTilePosition(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVector<OpFoldResult> &resultOffsets,
      SmallVector<OpFoldResult> &resultSizes,
      ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);

    AffineMap partialMap =
        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
    for (AffineExpr dimExpr : partialMap.getResults()) {
      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
      resultSizes.push_back(sizes[dim]);

      if (llvm::is_contained(reductionDims, dim)) {
        // Reduction dims are reduced, and are always output in the same
        // place. So use offset 0 for them.
        resultOffsets.push_back(b.getIndexAttr(0));
      } else {
        resultOffsets.push_back(offsets[dim]);
      }
    }

    return success();
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
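  // Use the leading `rank` dims of the reified result shape as the loop
  // bounds; every loop starts at 0 with stride 1.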
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

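/// Apply `permutation` to `offsets` and `sizes` in place. This is a no-op if
/// the permutation is empty.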
static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions, the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<tensor::ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<tensor::ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over outer dimensions of packed layout. In this
    // context, the outer dimensions of `resultOffsets` are `offsets`. The
    // inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
    for (auto offset : offsets.take_back(numTiles))
      if (!isZeroInteger(offset))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context, the
  /// `resultOffsets` and `resultSizes` only cover outer dimensions.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantics.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Fusing `packOp` as a consumer currently only supports the perfect
        // tiling scenario, because even without padding semantics the `packOp`
        // may yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
        // where the `tileSize` from the operand of `packOp` is 5, which is not
        // exactly divisible by the `innerTile` (=6) of `packOp`. As a result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        //    (0,0)~(0,4) of the first row.
        // 2. the second slice is extracted from (5) to (9) and would have to
        //    be inserted into two rows of different lengths: first row (0,5)
        //    and second row (1,0)~(1,3). It is hard to coordinate them, so the
        //    constraint below bypasses this case for now. In other words,
        //    tiling with a consumer is only supported if the tile size for the
        //    producer is a multiple of the inner tile size for the packed
        //    dimensions.
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }

        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }

  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<tensor::ExtractSliceOp>(
        loc, packOp.getSource(), offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<tensor::ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment
/// on `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there can be an incomplete tile even
    // in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the affine op
  // would be too complicated and would trigger an issue in affine op
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. In that case we need to
  /// expand the slice of the source to get complete tiles. The tiled unpack
  /// op then unpacks more data from the source, so an extract_slice op is
  /// needed to shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to
  /// complete the rows, i.e., the input coordinates would start at (1, 0) and
  /// end at (3, 7). In this context, the tiled unpack produces 3 * n elements
  /// because there are 3 rows in total. A following tensor.extract_slice op
  /// then yields the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples
    // of the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    tensor::ExtractSliceOp sliceSource = b.create<tensor::ExtractSliceOp>(
        loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
        sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<tensor::ExtractSliceOp>(
          loc, unpackOp.getDest(), offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<tensor::EmptyOp>(
          loc, destExpandedSizes, unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    auto extractSlice = b.create<tensor::ExtractSliceOp>(
        loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
        destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i. The sizes must be clamped to the sizes of the unpack result.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = b.create<tensor::ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = b.create<tensor::ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};

} // namespace

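/// Attach the TilingInterface and PartialReductionOpInterface external models
/// to `OpType`.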
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
      *ctx);
}

/// Variadic helper function.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}

#define GET_OP_LIST

void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    registerOne<linalg::GenericOp>(ctx);
    linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
    linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
    registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}

void mlir::linalg::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
    linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
    linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}