//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include "mlir/AsmParser/AsmParser.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;

#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")

/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
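/// For example, `DecomposeOp` below invokes this helper as
/// `tryApply<DownscaleConv2DOp>(target)` to apply a single decomposition
/// pattern in place.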
template <typename PatternTy, typename... Args>
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
  // Check if the given operation has the type expected by the pattern.
  using OpTy = typename llvm::function_traits<
      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
  auto op = dyn_cast<OpTy>(operation);
  if (!op)
    return failure();

  // Apply the pattern directly to the op.
  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
  // We want to discourage direct use of PatternRewriter in APIs, but in this
  // very specific case, an IRRewriter is not enough.
  struct TrivialPatternRewriter : public PatternRewriter {
  public:
    explicit TrivialPatternRewriter(MLIRContext *context)
        : PatternRewriter(context) {}
  };
  TrivialPatternRewriter rewriter(operation->getContext());
  rewriter.setInsertionPoint(operation);
  auto result = pattern.returningMatchAndRewrite(op, rewriter);
  if (failed(result))
    return failure();
  return cast<LinalgOp>(result->getOperation());
}

/// Assuming that `ofr` is an index attr or a param of index type
/// or a transform dialect handle mapped to exactly one op
/// with one index result, return that value.
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
  for (OpFoldResult ofr : ofrs) {
    if (ofr.is<Attribute>()) {
      if (!isa<IntegerAttr>(ofr.get<Attribute>()))
        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
      result.push_back(ofr);
      continue;
    }

    Value transformValue = ofr.get<Value>();
    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
      ArrayRef<Attribute> params = state.getParams(transformValue);
      if (params.size() != 1)
        return transformOp.emitDefiniteFailure()
               << "requires exactly one parameter associated";
      result.push_back(params[0]);
      continue;
    }

    auto payloadOps = state.getPayloadOps(transformValue);
    if (!llvm::hasSingleElement(payloadOps)) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "handle must be mapped to exactly one payload op";
      diag.attachNote(transformValue.getLoc())
          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
      return diag;
    }

    Operation *op = *payloadOps.begin();
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }
    result.push_back(op->getResult(0));
  }

  return DiagnosedSilenceableFailure::success();
}

// Given a list of params that are index attrs or a list of OpFoldResults
// that are either index attrs or op handles, return a list of OpFoldResults
// of index attrs or a list of OpFoldResults where all op handles are
// replaced with the first (and only) OpResult of that payload op.
// (There must be exactly one parameter associated with the AnyParamType or
// one mapped payload op which must have exactly one index result.)
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVector<OpFoldResult> &result, Value packedHandle) {
  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
    ArrayRef<Attribute> params = state.getParams(packedHandle);
    for (auto param : params) {
      if (!isa<IntegerAttr>(param))
        return transformOp.emitDefiniteFailure()
               << "expected the parameter to be associated with an integer "
                  "attribute";
      result.push_back(param);
    }
    return DiagnosedSilenceableFailure::success();
  }

  for (Operation *op : state.getPayloadOps(packedHandle)) {
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }
    result.push_back(op->getResult(0));
  }

  return DiagnosedSilenceableFailure::success();
}
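
// For illustration, ops such as PackOp below route their mixed sizes through
// these helpers, so a transform script may mix a static size with a param
// handle (sketch; syntax and types are assumed, abbreviated from the
// transform dialect docs):
//
//   %c8 = transform.param.constant 8 : i64 -> !transform.param<i64>
//   %packed = transform.structured.pack %op packed_sizes = [%c8, 16]
//       : (!transform.any_op, !transform.param<i64>) -> !transform.any_op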

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
}

void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::ControlDropUnitDims options;
  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
}

void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::ControlDropUnitDims options;
  options.rankReductionStrategy =
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
}

void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}
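
// These pattern ops are meant to be nested inside a transform.apply_patterns
// region. A minimal sketch (assumed handle %func of type !transform.any_op):
//
//   transform.apply_patterns to %func {
//     transform.apply_patterns.linalg.tiling_canonicalization
//   } : !transform.any_op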

//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               OperationState &result,
                                               Value target,
                                               Attribute memorySpace) {
  SmallVector<Type> resultTypes;
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
               /*resultTypes=*/resultTypes,
               /*target=*/target,
               /*memorySpace=*/memorySpace);
}

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               OperationState &result,
                                               Value target,
                                               int64_t memorySpace) {
  SmallVector<Type> resultTypes;
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
               /*resultTypes=*/resultTypes,
               /*target=*/target,
               /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
}

namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
  using RewriterBase::ForwardingListener::ForwardingListener;

  SmallVector<Operation *> getNewOps() const {
    return SmallVector<Operation *>(newOps.begin(), newOps.end());
  }

private:
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override {
    ForwardingListener::notifyOperationInserted(op, previous);
    // We only care about newly created ops.
    if (previous.isSet())
      return;
    auto inserted = newOps.insert(op);
    (void)inserted;
    assert(inserted.second && "expected newly created op");
  }

  void notifyOperationErased(Operation *op) override {
    ForwardingListener::notifyOperationErased(op);
    op->walk([&](Operation *op) { newOps.erase(op); });
  }

  DenseSet<Operation *> newOps;
};
} // namespace

DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  // Attach listener to keep track of newly created ops.
  OpBuilder::Listener *previousListener = rewriter.getListener();
  auto resetListener =
      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
  NewOpsListener newOpsListener(previousListener);
  rewriter.setListener(&newOpsListener);

  linalg::BufferizeToAllocationOptions options;
  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
    options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
        MaterializeInDestination;
  } else if (getMemcpyOp() == "memref.copy") {
    options.memcpyOp =
        linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
  } else if (getMemcpyOp() == "linalg.copy") {
    options.memcpyOp =
        linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
  } else {
    llvm_unreachable("invalid memcpy op");
  }
  if (getAllocOp() == "memref.alloc") {
    options.allocOp =
        linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
  } else if (getAllocOp() == "memref.alloca") {
    options.allocOp =
        linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
  } else {
    llvm_unreachable("invalid alloc op");
  }
  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
  options.emitDealloc = getEmitDealloc();

  // Bufferize ops.
  Attribute memorySpace =
      getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
  SmallVector<Value> allocatedBuffers;
  for (Operation *op : state.getPayloadOps(getTarget())) {
    Value buffer =
        linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
    if (!buffer) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
                                         << "failed to bufferize operation";
      diag.attachNote(op->getLoc()) << "target payload op";
      return diag;
    }
    allocatedBuffers.push_back(buffer);
  }

  // Set results.
  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
  return DiagnosedSilenceableFailure::success();
}

void transform::BufferizeToAllocationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (getBufferizeDestinationOnly()) {
    // The destination is replaced with a newly allocated buffer, but the op
    // itself remains in place.
    onlyReadsHandle(getTarget(), effects);
  } else {
    consumesHandle(getTarget(), effects);
  }
  producesHandle(getAllocatedBuffer(), effects);
  producesHandle(getNewOps(), effects);
  modifiesPayload(effects);
}

LogicalResult transform::BufferizeToAllocationOp::verify() {
  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
      getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
    return emitOpError() << "unsupported memcpy op";
  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
    return emitOpError() << "unsupported alloc op";
  return success();
}
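
// A minimal usage sketch in transform IR (assumed syntax; the attribute names
// mirror the getters used above):
//
//   %buffer, %new_ops = transform.structured.bufferize_to_allocation %op
//       {memory_space = 4, memcpy_op = "memref.copy",
//        alloc_op = "memref.alloc"} : !transform.any_op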

//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
#define DOWNSCALE(trans)                                                       \
  {                                                                            \
    FailureOr<LinalgOp> res = tryApply<trans>(target);                         \
    if (succeeded(res)) {                                                      \
      results.push_back(*res);                                                 \
      return DiagnosedSilenceableFailure::success();                           \
    }                                                                          \
  }

#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))

  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
  DOWNSCALE(DownscaleConv2DOp)
#undef DOWNSCALE_NORMAL
#undef DOWNSCALE_CALL
#undef DOWNSCALE
  return emitDefaultSilenceableFailure(target);
}
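
// For example, a transform script can decompose a unit-window 2-D convolution
// into its 1-D counterpart (sketch; assumes %conv targets a matching
// linalg.conv_2d_nhwc_hwcf with a unit spatial dimension):
//
//   %decomposed = transform.structured.decompose %conv
//       : (!transform.any_op) -> !transform.any_op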

//===----------------------------------------------------------------------===//
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//

// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replace the values produced by
// \p target) in the `results`.
DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
  if (!decomposableOp) {
    failed(rewriter.notifyMatchFailure(target,
                                       "payload is not a decomposable op"));
    return emitDefaultSilenceableFailure(target);
  }

  FailureOr<SmallVector<Value>> maybeNewResults =
      decomposableOp.decomposeOperation(rewriter);
  if (failed(maybeNewResults))
    return emitDefaultSilenceableFailure(target);

  rewriter.replaceOp(decomposableOp, *maybeNewResults);
  for (Value val : *maybeNewResults) {
    Operation *definition = val.getDefiningOp();
    if (definition)
      results.push_back(definition);
  }
  return DiagnosedSilenceableFailure::success();
}
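
// One payload op implementing AggregatedOpInterface is linalg.softmax; a
// transform script can break it into its primitive linalg ops with (sketch;
// assumed syntax and result arity):
//
//   %ops = transform.structured.decompose_interface %softmax
//       : (!transform.any_op) -> !transform.any_op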

//===----------------------------------------------------------------------===//
// EliminateLinalgOpAnchoredEmptyTensorsOp
//===----------------------------------------------------------------------===//

void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTarget(), effects);
  modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
    transform::TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  bufferization::OneShotBufferizationOptions options;
  options.allowReturnAllocsFromLoops = true;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    bufferization::OneShotAnalysisState state(target, options);
    if (failed(analyzeOp(target, state)))
      return mlir::emitSilenceableFailure(target->getLoc())
             << "failed to analyze op";
    if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
            rewriter, target, state)))
      return mlir::emitSilenceableFailure(target->getLoc())
             << "failed to eliminate LinalgOp anchored tensor.empty ops";
  }
  return DiagnosedSilenceableFailure::success();
}
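
// Conceptually, this replaces a tensor.empty that only feeds the `outs` of a
// LinalgOp with a materialization of the op's eventual destination (e.g. a
// slice of the tensor the result is later inserted into), so that one-shot
// bufferization does not need a fresh allocation. This is an informal summary
// of the rewrite, not a precise specification.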

//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//

/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult applyTilingToAll(
    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
    unsigned numLoops, transform::TransformResults &transformResults,
    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
        applyFn) {
  SmallVector<Operation *> tiledLinalgOps;
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        applyFn(tilingInterfaceOp);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    SmallVector<Operation *> opsToReplace{target};
    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res))
          rewriter.replaceAllUsesWith(res, replacement);
      if (toReplace->use_empty()) {
        rewriter.eraseOp(toReplace);
      }
    }

    // Report back the relevant handles to the transform op.
    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
           "failed");
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiledResults->loops[i]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

  return success();
}

DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                         mlir::transform::TransformResults &transformResults,
                         mlir::transform::TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> tileInterchange =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.interchangeVector = tileInterchange;
  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.tilingOptions = tilingOptions;
  LogicalResult result = applyTilingToAll(
      rewriter, getOperation(), state.getPayloadOps(getTarget()),
      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
      [&](TilingInterface tilingInterfaceOp)
          -> FailureOr<scf::SCFTileAndFuseResult> {
        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                    tileAndFuseOptions);
      });
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

LogicalResult transform::FuseOp::verify() {
  SmallVector<int64_t> permutation =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(),
                           permutation.begin(), permutation.end())) {
    return emitOpError() << "expects interchange to be a permutation, found "
                         << getTileInterchange();
  }

  SmallVector<int64_t> sizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
  if (numExpectedLoops != getNumResults() - 1)
    return emitOpError() << "expects " << numExpectedLoops << " loop results";

  return success();
}
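
// Sketch (assumed syntax): tile the two outer loops and fuse producers; a
// zero tile size leaves a loop untiled, so only two loop handles result:
//
//   %tiled, %loops:2 = transform.structured.fuse %target
//       {tile_sizes = [32, 32, 0], tile_interchange = [0, 1, 2]}
//       : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)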

//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//

void transform::FuseIntoContainingOp::build(OpBuilder &builder,
                                            OperationState &result,
                                            Value producerOp,
                                            Value containingOp) {
  result.addOperands({producerOp, containingOp});
  auto resultType = transform::AnyOpType::get(builder.getContext());
  result.addTypes({resultType, resultType});
}

/// Add new operands to the forall op for users of the producerOp
/// that are dominated by the containing scf.forall op.
static Operation *replaceForAllWithNewSignature(
    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
    Operation *containingOp, TilingResult &tileAndFuseResult,
    int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
    SmallVector<OpFoldResult> &sizes) {

  // Count the number of users, not including the containing op.
  SetVector<Operation *> dominatedUsers;
  DominanceInfo domInfo(containingOp);
  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
    if (!containingOp->isAncestor(user) &&
        (domInfo.dominates(containingOp, user))) {
      dominatedUsers.insert(user);
    }
  }
  if (dominatedUsers.empty())
    return nullptr;

  // Prepare to create a new scf.forall op.
  auto forallOp = cast<scf::ForallOp>(containingOp);
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(forallOp);

  // Get the new output.
  Location loc = forallOp.getLoc();
  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
  if (!genericOp)
    return nullptr;
  SmallVector<Value> outputs = genericOp.getOutputs();
  SmallVector<Value> newOuts(forallOp.getOutputs());
  newOuts.push_back(outputs[resultNumber]);

  // Create the new scf.forall op.
  auto newforallOp = rewriter.create<scf::ForallOp>(
      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
  rewriter.eraseBlock(newforallOp.getBody());
  newforallOp.getRegion().takeBody(forallOp.getRegion());

  // Add an additional block argument for the new value being returned
  // and replace all uses of the new output with the corresponding bbArg
  // inside the scf.forall to enable fusion into this new scf.forall.
  newforallOp.getBody()->addArgument(newOuts.back().getType(),
                                     newOuts.back().getLoc());
  auto bbArgs = newforallOp.getBody()->getArguments();
  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
                             [&](OpOperand &use) {
                               Operation *op = use.getOwner();
                               return newforallOp->isProperAncestor(op);
                             });

  // Fix the terminator.
  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
      terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
  Operation *firstYieldOp = yieldingOps.front();
  rewriter.setInsertionPoint(firstYieldOp);
  Value src = tileAndFuseResult.tiledValues[0];
  Value dst = newforallOp.getRegionIterArgs().back();
  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
                                                 dst, offsets, sizes, strides);

  for (auto result : llvm::enumerate(forallOp.getResults())) {
    rewriter.replaceAllUsesWith(result.value(),
                                newforallOp->getResult(result.index()));
  }
  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
                             newforallOp->getResults().back(),
                             [&](OpOperand &use) {
                               Operation *user = use.getOwner();
                               return dominatedUsers.contains(user);
                             });
  return newforallOp;
}

/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
/// If tiled op has uses that are dominated by `containingOp`, return
/// a new `containingOp` with results of the fused op appended to
/// results of the `containingOp` or nullptr if there are no dominated uses.
static std::tuple<SmallVector<Operation *>, Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
                           Operation *producerOp, Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer) {
    diag.attachNote(producerOp->getLoc())
        << "producer is not a TileableInterface: " << *producerOp;
    return {};
  }

  // Search the producer slices accessed within the containing operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
  // evolve into an interface.
  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp && containingOp->isProperAncestor(sliceOp);
  });

  // Find a fusion opportunity.
  if (it == tileableProducer->getUsers().end()) {
    diag.attachNote(tileableProducer->getLoc())
        << "could not find fusion opportunity for: " << *tileableProducer;
    return {};
  }
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);

  // Try to fuse the producer in-place.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Tile the producer.
  int64_t resultNumber =
      cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();

  FailureOr<TilingResult> tileAndFuseResult =
      tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
                                               sizes);

  if (failed(tileAndFuseResult)) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;
    return {};
  }

#ifndef NDEBUG
  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
    LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
  }
#endif

  // Replace the extract op.
  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
  if (failed(maybeRankReduced)) {
    diag.attachNote(producerOp->getLoc())
        << "shape types don't match (missing canonicalization?):\nTiledOp: "
        << tileAndFuseResult->tiledValues[0]
        << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
    return {};
  }
  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

  // Add new outputs to containing op, if required.
  Operation *newContainingOp = replaceForAllWithNewSignature(
      rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
      resultNumber, offsets, sizes);

  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
}

/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
/// it is exactly the `containingOp`, otherwise bail.
/// Then, find the first "extract" user of the tied block argument and tile it
/// right before its "extract" use. The tiled op is fused under the
/// `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static SmallVector<Operation *>
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
    Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");

  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer) {
    diag.attachNote(producerOp->getLoc())
        << "producer is not a TileableInterface: " << *producerOp;
    return {};
  }

  // Search the first use by a "scf::ForallOp" user.
  scf::ForallOp forallOp;
  auto itProducerUses =
      llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
        forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
        return forallOp;
      });
  // If it's not from the containing op, return.
  if (!forallOp || forallOp != containingOp) {
    diag.attachNote(tileableProducer->getLoc())
        << "could not find a use by the containing op: " << *tileableProducer;
    return {};
  }
  // Fetch the use of the producer by the containing op and its tied block
  // argument.
  OpOperand *pUse = &(*itProducerUses);
  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);

  // Search the producer slices accessed within the containing operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
  // evolve into an interface.
  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp && containingOp->isProperAncestor(sliceOp);
  });

  // Find a fusion opportunity.
  if (itBBArgUsers == bbArg.getUsers().end()) {
    diag.attachNote(containingOp->getLoc())
        << "could not find fusion opportunity for bbArg: " << bbArg;
    return {};
  }
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

  // Try to fuse the producer in-place.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Replace the use in the tileableProducer before tiling: clone, replace and
  // then tile.
  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  // Gather destination tensors.
  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(
          rewriter, tileableProducer->getLoc(), tileableProducer,
          destinationTensors))) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to get destination tensors for: " << *tileableProducer;
    return {};
  }

  IRMapping bvm;
  bvm.map(destinationTensors[resultNumber], bbArg);
  auto tileableProducerClone =
      cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
  auto scopeGuard =
      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

  // Tile the producer.
  FailureOr<TilingResult> tileAndFuseResult =
      tileableProducerClone.generateResultTileValue(
          rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
          sliceOpToTile.getMixedSizes());
  if (failed(tileAndFuseResult)) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;
    return {};
  }

  // Replace the extract op.
  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
  assert(succeeded(maybeRankReduced) && "unexpected shape");
  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

  // Replace the use in containingOp.
  rewriter.modifyOpInPlace(containingOp, [&]() {
    containingOp->setOperand(pUse->getOperandNumber(),
                             destinationTensors.front());
  });

  return tileAndFuseResult->tiledOps;
}

static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
                                       Operation *producerOp,
                                       Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse a use by cloning\n");

  // Gather all uses inside the containing op.
  SmallVector<OpOperand *> uses;
  for (OpResult result : producerOp->getOpResults()) {
    for (OpOperand &use : result.getUses()) {
      if (containingOp->isProperAncestor(use.getOwner())) {
        uses.push_back(&use);
        continue;
      }
      // Cannot clone and fuse if the use is by the containing op itself: fail
      // immediately.
      if (containingOp == use.getOwner()) {
        diag.attachNote(producerOp->getLoc())
            << "producer op use by containing op cannot be fused by cloning";
        return nullptr;
      }
    }
  }

  // Check for a non-empty list of fusion opportunities.
  if (uses.empty()) {
    diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
    return nullptr;
  }

  // Clone and fuse inside the containing op.
  Operation *fusedOp = nullptr;
  OpOperand *use = uses.front();
  // Parallel insert slice is not a valid clone destination.
  // TODO: Generalize to other types of ops.
  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
         "Parallel insert slice is not a valid clone destination");
  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(use->getOwner());
  fusedOp = rewriter.clone(*producerOp);
  rewriter.modifyOpInPlace(
      use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });

  return fusedOp;
}

bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
  // Allow repeated handles since we are fusing everything anyway.
  return true;
}

DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  SmallVector<Operation *> fusedOps;
  auto producerOps = state.getPayloadOps(getProducerOp());
  auto containingOps = state.getPayloadOps(getContainingOp());
  if (!llvm::hasSingleElement(containingOps)) {
    return emitDefiniteFailure()
           << "requires exactly one containing_op handle (got "
           << llvm::range_size(containingOps) << ")";
  }
  Operation *containingOp = *containingOps.begin();

  // If nothing to fuse, propagate success.
  if (std::empty(producerOps)) {
    results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
    results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
    return DiagnosedSilenceableFailure::success();
  }

  // Helper function to find the next producer that should be fused. Take any
  // producer that has a use inside the containing op.
  SetVector<Operation *> remainingProducers(producerOps.begin(),
                                            producerOps.end());
  auto getNextProducer = [&]() -> FailureOr<Operation *> {
    for (const auto &it : enumerate(remainingProducers)) {
      Operation *producerOp = it.value();
      // The containing op may be a user of producerOp: use isAncestor.
      int64_t numUsesInContainingOp =
          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
            return containingOp->isAncestor(op);
          });
      // TODO: When resolving the TODO below (no duplicate ops), take an op
      // that has no use among the remaining producers. This is a topological
      // sorting.
      if (numUsesInContainingOp > 0) {
        if (numUsesInContainingOp == 1)
          remainingProducers.erase(remainingProducers.begin() + it.index());
        return producerOp;
      }
    }
    return failure();
  };

  while (!remainingProducers.empty()) {
    auto nextProducer = getNextProducer();
    if (failed(nextProducer)) {
      auto diag = mlir::emitSilenceableFailure(getLoc())
                  << "could not find next producer to fuse into container";
      diag.attachNote(containingOp->getLoc()) << "containing op";
      return diag;
    }

    Operation *producerOp = *nextProducer;

    // Default diagnostic, to be complemented with more failure information.
    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
    diag << "could not fuse " << *producerOp << " into " << *containingOp;

    // TODO: If there are multiple uses of the producer in the containing op,
    // we currently tile/clone the op multiple times (once per use). In some
    // cases, we can tile/clone once and reuse the value for each use.
    // Furthermore, producers should then be traversed according to a
    // topological sorting.
    auto [tiledOps, newContainingOp] =
        tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
    if (!tiledOps.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
      fusedOps.append(tiledOps);
      if (newContainingOp) {
        // Update handles associated with the containing op so we don't need to
        // invalidate them. This is a hack to support better composability
        // between tiling and fusion while a proper mechanism is being
        // investigated.
        //
        // DO NOT replicate this elsewhere unless you understand what you are
        // doing.
        LogicalResult replacementStatus =
            rewriter.notifyPayloadOperationReplaced(containingOp,
                                                    newContainingOp);
        (void)replacementStatus;
        assert(succeeded(replacementStatus) &&
               "unable to update transform state mapping");
        rewriter.eraseOp(containingOp);
        containingOp = newContainingOp;
      }
      continue;
    }

    SmallVector<Operation *> tiledContainingOpOperand =
        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
            rewriter, diag, producerOp, containingOp);
    if (!tiledContainingOpOperand.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
                        << *containingOp);
      fusedOps.append(tiledContainingOpOperand);
      continue;
    }

    Operation *cloned =
        cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
    if (cloned) {
      LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
      fusedOps.push_back(cloned);
      continue;
    }
    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
  }

  results.set(cast<OpResult>(getFusedOp()), fusedOps);
  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
  return DiagnosedSilenceableFailure::success();
}

void transform::FuseIntoContainingOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getProducerOp(), effects);
  onlyReadsHandle(getContainingOp(), effects);
  producesHandle(getResults(), effects);
  modifiesPayload(effects);
}
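
// A typical pairing in transform IR fuses a producer into a previously tiled
// scf.forall loop (sketch; assumes %producer and %forall handles exist):
//
//   %fused, %new_containing = transform.structured.fuse_into_containing_op
//       %producer into %forall
//       : (!transform.any_op, !transform.any_op)
//       -> (!transform.any_op, !transform.any_op)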

//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  // Exit early if no transformation is needed.
  if (isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
  if (succeeded(generic)) {
    results.push_back(generic->getOperation());
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}
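
// Generalization rewrites a named op such as linalg.matmul into an equivalent
// linalg.generic; SpecializeOp below is its inverse. Sketch (assumed syntax):
//
//   %generic = transform.structured.generalize %matmul
//       : (!transform.any_op) -> !transform.any_op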

//===----------------------------------------------------------------------===//
// SpecializeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  // Exit early if the operation is not a generic.
  if (!isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> named =
      specializeGenericOp(rewriter, cast<GenericOp>(target));
  if (succeeded(named)) {
    results.push_back(named->getOperation());
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
                                     GenericOp target,
                                     transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
  // Exit early if no transformation is needed.
  if (interchangeVector.empty()) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }

  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
  if (interchangeVector.size() != numLoops) {
    return emitSilenceableError()
           << getIteratorInterchangeAttrName() << " has length ("
           << interchangeVector.size()
           << ") different from the number of loops in the target operation ("
           << numLoops << ")";
  }
  FailureOr<GenericOp> res =
      interchangeGenericOp(rewriter, target,
                           SmallVector<unsigned>(interchangeVector.begin(),
                                                 interchangeVector.end()));
  if (failed(res))
    return emitDefiniteFailure() << "failed to apply";
  results.push_back(res->getOperation());
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::InterchangeOp::verify() {
  ArrayRef<int64_t> permutation = getIteratorInterchange();
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(),
                           permutation.begin(), permutation.end())) {
    return emitOpError()
           << "expects iterator_interchange to be a permutation, found "
           << getIteratorInterchange();
  }
  return success();
}
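
// Sketch of swapping the two loops of a generic (assumed syntax):
//
//   %interchanged = transform.structured.interchange %generic
//       iterator_interchange = [1, 0]
//       : (!transform.any_op) -> !transform.any_op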

//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
    transform::TransformRewriter &rewriter, tensor::PackOp target,
    transform::ApplyToEachResultList &transformResults,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
  if (failed(res)) {
    return mlir::emitSilenceableFailure(target->getLoc())
           << "cannot lower to pad + expand + transpose";
  }
  transformResults.push_back(res->padOp);
  transformResults.push_back(res->expandShapeOp);
  transformResults.push_back(res->transposeOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// LowerUnPackOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
    transform::TransformRewriter &rewriter, tensor::UnPackOp target,
    transform::ApplyToEachResultList &transformResults,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
  if (failed(res)) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "cannot lower to transpose + collapse + extract";
    diag.attachNote(target->getLoc()) << "target payload op";
    return diag;
  }
  transformResults.push_back(res->emptyOp);
  transformResults.push_back(res->transposeOp);
  transformResults.push_back(res->collapseShapeOp);
  transformResults.push_back(res->extractSliceOp);
  return DiagnosedSilenceableFailure::success();
}
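
// These two rewrites are inverses of each other at the payload level: lowering
// tensor.pack yields tensor.pad + tensor.expand_shape + linalg.transpose,
// while lowering tensor.unpack yields tensor.empty + linalg.transpose +
// tensor.collapse_shape + tensor.extract_slice, matching the handles returned
// above.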

//===---------------------------------------------------------------------===//
// MatchOp
//===---------------------------------------------------------------------===//

void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
                               Value target, ArrayRef<StringRef> opNames) {
  result.addOperands(target);
  result.addAttribute(MatchOp::getOpsAttrName(result.name),
                      builder.getStrArrayAttr(opNames));
  result.addTypes(transform::AnyOpType::get(builder.getContext()));
}

void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTypes, Value target,
                               ArrayRef<StringRef> opNames) {
  result.addOperands(target);
  result.addAttribute(MatchOp::getOpsAttrName(result.name),
                      builder.getStrArrayAttr(opNames));
  result.addTypes(resultTypes);
}

DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformRewriter &rewriter,
                          transform::TransformResults &results,
                          transform::TransformState &state) {
  llvm::StringSet<> strs;
  if (getOps().has_value())
    strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
                getOps()->getAsValueRange<StringAttr>().end());

  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps)) {
    return emitDefiniteFailure("requires exactly one target handle");
  }

  SmallVector<Operation *> res;
  bool incorrectNumOperandTypes = false;
  auto matchFun = [&](Operation *op) {
    if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
      return;

    // Interfaces cannot be matched by name, just by ID.
    // So we specifically encode the interfaces we care about for this op.
    if (getInterface().has_value()) {
      auto iface = getInterface().value();
      if (iface == transform::MatchInterfaceEnum::LinalgOp &&
          !isa<LinalgOp>(op))
        return;
      if (iface == transform::MatchInterfaceEnum::TilingInterface &&
          !isa<TilingInterface>(op))
        return;
      if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
          !isa<LoopLikeOpInterface>(op))
        return;
    }

    // Check if all specified attributes match.
    if (getOpAttrs().has_value()) {
      DictionaryAttr opAttrs = getOpAttrs().value();
      for (NamedAttribute attr : opAttrs) {
        if (attr.getName() == getInterfaceAttrName() ||
            attr.getName() == getOpsAttrName())
          continue;
        if (!op->hasAttr(attr.getName()))
          return;
        if (op->getAttr(attr.getName()) != attr.getValue())
          return;
      }
    }

    if (getFilterResultType().has_value()) {
      Type t = getFilterResultType().value();
      if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
        return;
    }

    if (getFilterOperandTypes().has_value()) {
      mlir::ArrayAttr types = getFilterOperandTypes().value();
      auto operandTypes = op->getOperandTypes();

      if (types.size() == 1) {
        // All the operands must be equal to the specified type.
        auto typeattr =
            dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
        Type t = cast<::mlir::Type>(typeattr.getValue());
        if (!llvm::all_of(op->getOperandTypes(),
                          [&](Type operandType) { return operandType == t; }))
          return;
      } else {
        // The operand types must match all the types in the list (in the same
        // order in which they are specified).
        if (types.size() != operandTypes.size()) {
          incorrectNumOperandTypes = true;
          return;
        }

        for (auto [attr, operandType] :
             llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
          auto typeattr = cast<mlir::TypeAttr>(attr);
          Type type = cast<::mlir::Type>(typeattr.getValue());

          if (type != operandType)
            return;
        }
      }
    }

    // All constraints are satisfied.
    res.push_back(op);
    return;
  };

  (*payloadOps.begin())->walk(matchFun);
  if (incorrectNumOperandTypes)
    return emitDefiniteFailure("If filter_operand_types contains more than one "
                               "type, then it must contain as many types as "
                               "the number of operands in the target ops");
  results.set(cast<OpResult>(getResult()), res);
  return DiagnosedSilenceableFailure::success();
}
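
// A canonical use matches payload ops by name under a module handle (this
// form appears throughout the transform dialect tests):
//
//   %matmuls = transform.structured.match ops{["linalg.matmul"]} in %module
//       : (!transform.any_op) -> !transform.any_op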

//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//

static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
                                     Type targetType, Type lowSizeType, Type,
                                     Type) {
  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
}

static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
                                            Type &targetType, Type &lowSizeType,
                                            Type &highSizeType,
                                            Type &splitPointType) {
  FunctionType funcType;
  llvm::SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseType<FunctionType>(funcType)))
    return failure();

  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
  }
  targetType = funcType.getInput(0);
  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);

  return success();
}

DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results, TransformState &state) {
  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
    if (target.hasDynamicShape()) {
      auto diag = emitSilenceableError()
                  << "cannot compute parametric tile sizes for dynamically "
                     "shaped payload op";
      diag.attachNote(target->getLoc()) << "payload op";
      return diag;
    }

    FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
        target, getDimension(), getTargetSize(), getDivisor());
    if (failed(spec)) {
      return emitSilenceableError()
             << "failed to compute multi-size tiling sizes";
    }

    Builder builder(target.getContext());
    results.assign(llvm::map_range(
        ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
                           spec->lowTileSize * spec->lowTripCount}),
        [&builder, this](int64_t value) {
          return builder.getIntegerAttr(
              cast<ParamType>(getLowSize().getType()).getType(), value);
        }));
    return DiagnosedSilenceableFailure::success();
  }

  OpBuilder builder(target.getContext());
  builder.setInsertionPoint(target);
  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
      builder, target, getDimension(), targetSize, divisor);
  if (failed(spec)) {
    return emitSilenceableError() << "could not generate tile size computation";
  }

  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  AffineExpr s1 = builder.getAffineSymbolExpr(1);
  Operation *splitPoint =
      affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
                                      {spec->lowTileSize, spec->lowTripCount});
  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
  Operation *highTileSize = spec->highTileSize.getDefiningOp();
  assert(lowTileSize && highTileSize && splitPoint &&
         "tile sizes are not produced by operations");
  results.reserve(results.size() + 3);
  results.push_back(lowTileSize);
  results.push_back(highTileSize);
  results.push_back(splitPoint);
  return DiagnosedSilenceableFailure::success();
}

void transform::MultiTileSizesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTarget(), effects);
  producesHandle(getResults(), effects);
  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
    onlyReadsPayload(effects);
  else
    modifiesPayload(effects);
}
LogicalResult transform::MultiTileSizesOp::verify() {
  if (getLowSize().getType() != getHighSize().getType() ||
      getLowSize().getType() != getSplitPoint().getType()) {
    return emitOpError() << "expects all result types to be the same";
  }
  return success();
}
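
// Sketch (assumed syntax; the trailing functional type is produced by
// printMultitileSizesTypes above):
//
//   %low, %high, %split = transform.structured.multitile_sizes %target
//       { dimension = 0, target_size = 32, divisor = 8 }
//       : (!transform.any_op) -> !transform.param<i64>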
1361
1362//===---------------------------------------------------------------------===//
1363// PackOp
1364//===---------------------------------------------------------------------===//
1365
1366void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1367 Value target,
1368 ArrayRef<OpFoldResult> mixedPackedSizes) {
1369 SmallVector<int64_t> staticPackedSizes;
1370 SmallVector<Value> dynamicPackedSizes;
1371 dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1372 staticPackedSizes);
1373 // Call the default builder which sets up the proper operands segment sizes
1374 // attributes for multiple variadic operands. In the absence of this, horrible
1375 // bugs ensue.
1376 Type linalgOpHType = transform::OperationType::get(
1377 builder.getContext(), GenericOp::getOperationName());
1378 build(builder, result,
1379 /*resultType=*/linalgOpHType,
1380 /*target=*/target,
1381 /*dynamic_sizes=*/dynamicPackedSizes,
1382 /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1383}
1384
1385SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1386 Builder b(getContext());
1387 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1388}
1389
1390DiagnosedSilenceableFailure
1391transform::PackOp::apply(transform::TransformRewriter &rewriter,
1392 transform::TransformResults &transformResults,
1393 transform::TransformState &state) {
1394 auto targetOps = state.getPayloadOps(getTarget());
1395 // If nothing to pack, propagate success.
1396 if (std::empty(targetOps)) {
1397 transformResults.set(cast<OpResult>(getPackedOp()),
1398 ArrayRef<Operation *>({}));
1399 return DiagnosedSilenceableFailure::success();
1400 }
1401 // Fail on multi-op handles.
1402 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1403 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1404 return emitSilenceableError()
1405 << "requires target to map to exactly 1 LinalgOp (got "
1406 << llvm::range_size(targetOps) << ")";
1407 }
1408 // Fail on mismatched number of pack sizes.
1409 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1410 return emitSilenceableError()
           << "requires the number of packed sizes to match the number of "
              "loops ("
           << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
           << ")";
  }

  // Unpack handles to constants or actual SSA index values.
  SmallVector<OpFoldResult> packedSizes;
  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
      state, *this, packedSizes, getMixedPackedSizes());
  if (!status.succeeded())
    return status;

  rewriter.setInsertionPoint(linalgOp);
  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
  if (failed(maybeResult))
    return emitDefiniteFailure("data tiling failed");

  transformResults.set(cast<OpResult>(getPackedOp()),
                       {maybeResult->packedLinalgOp.getOperation()});
  return DiagnosedSilenceableFailure::success();
}

void transform::PackOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTarget(), effects);
  transform::onlyReadsHandle(getPackedSizes(), effects);
  transform::producesHandle(getPackedOp(), effects);
  transform::modifiesPayload(effects);
}

//===---------------------------------------------------------------------===//
// PackGreedilyOp.
//===---------------------------------------------------------------------===//

LogicalResult transform::PackGreedilyOp::verify() {
  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
    return emitOpError() << getMatmulInnerDimsOrderAttrName()
                         << " is not a valid permutation";
  }
  // TODO: relax to allow empty once we have a strategy other than matmul.
  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
    for (auto [s, nmo] :
         llvm::zip_equal(getMixedMatmulPackedSizes(),
                         getMatmulPaddedSizesNextMultipleOf())) {
      std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
      if (nmo != 0 &&
          (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
        return emitOpError() << "at most one of the packed_size and the "
                                "padded_sizes_next_multiple_of can be nonzero "
                                "for the matmul strategy";
      }
    }
  }
  return success();
}

DiagnosedSilenceableFailure
PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
                      transform::TransformResults &transformResults,
                      transform::TransformState &state) {
  SmallVector<Operation *> results;
  for (Operation *op : state.getPayloadOps(getTarget())) {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      continue;
    // linalgOp will be replaced, which may invalidate an insertion point set
    // before it -> set the insertion point after linalgOp instead.
    rewriter.setInsertionPointAfter(linalgOp);
    // Failing to pack greedily is perfectly fine.
    // In the future we will want to order packings according to some metric.
    FailureOr<PackResult> packResult = packMatmulGreedily(
        /*rewriter=*/rewriter,
        /*linalgOp=*/linalgOp,
        /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
        /*mnkPaddedSizesNextMultipleOf=*/
        getMatmulPaddedSizesNextMultipleOf(),
        /*mnkOrder=*/getMatmulInnerDimsOrder());
    if (succeeded(packResult)) {
      results.push_back(packResult->packedLinalgOp);
      continue;
    }
    results.push_back(linalgOp);
  }
  transformResults.set(cast<OpResult>(getPackedOp()), results);
  return DiagnosedSilenceableFailure::success();
}

SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
  Builder b(getContext());
  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
                        b);
}

void transform::PackGreedilyOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTarget(), effects);
  transform::onlyReadsHandle(getMatmulPackedSizes(), effects);
  transform::producesHandle(getPackedOp(), effects);
  transform::modifiesPayload(effects);
}

//===---------------------------------------------------------------------===//
// PackTransposeOp
//===---------------------------------------------------------------------===//

LogicalResult transform::PackTransposeOp::verify() {
  if (!isPermutationVector(getInnerPerm())) {
    return emitOpError() << getInnerPermAttrName()
                         << " is not a valid permutation";
  }
  if (!isPermutationVector(getOuterPerm())) {
    return emitOpError() << getOuterPermAttrName()
                         << " is not a valid permutation";
  }
  if (getInnerPerm().empty() && getOuterPerm().empty()) {
    return emitOpError() << " at least one of " << getInnerPermAttrName()
                         << " or " << getOuterPermAttrName()
                         << " must be specified";
  }
  return success();
}

namespace {
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace

/// Return true if `permutation` is a valid permutation of the
/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
/// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
/// This is the case when the `permutation` rank matches the rank expected by
/// `op` and `permutation` is itself a permutation vector.
/// Return true if either `op` or `permutation` is empty to allow a simpler
/// polymorphic implementation.
template <typename RelayoutOpTy>
bool isValidPackingPermutation(
    RelayoutOpTy op, ArrayRef<int64_t> permutation,
    OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
  static_assert(
      llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
      "applies to only pack or unpack operations");
  if (!op || permutation.empty())
    return true;
  size_t innerRank = op.getInnerDimsPos().size();
  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
    return permutation.size() == innerRank && isPermutationVector(permutation);
  // op.getOuterDimsPerm() may be empty, in which case it is identity.
  // Don't rely on it.
  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
    return permutation.size() == op.getSourceRank() &&
           isPermutationVector(permutation);
  }
  return permutation.size() == op.getDestRank() &&
         isPermutationVector(permutation);
}

DiagnosedSilenceableFailure
transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &transformResults,
                                  transform::TransformState &state) {
  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
  // Step 1. If nothing to pack, propagate success.
  if (std::empty(packOrUnpackOps)) {
    transformResults.set(cast<OpResult>(getPackedOp()), {});
    transformResults.set(cast<OpResult>(getPackOp()), {});
    transformResults.set(cast<OpResult>(getUnPackOp()), {});
    return DiagnosedSilenceableFailure::success();
  }

  // Step 2. A bunch of runtime sanity checks and error messages.
  // Step 2.1. Fail on multi-op handles.
  if (!llvm::hasSingleElement(packOrUnpackOps) ||
      !llvm::hasSingleElement(linalgOps)) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 "
              "packing op and 1 packed op ("
           << "got " << llvm::range_size(packOrUnpackOps) << " and "
           << llvm::range_size(linalgOps) << ")";
  }

  // Step 2.2. Fail on wrong type.
  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
  if (!packOp && !unPackOp) {
    return emitSilenceableError() << "requires target to map to a "
                                     "tensor.pack or tensor.unpack";
  }
  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
  if (!linalgOpTarget)
    return emitSilenceableError() << "requires a LinalgOp target";

  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
  LinalgOp linalgOp;
  if (packOp && packOp.getResult().hasOneUse())
    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
  else if (unPackOp)
    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
  if (linalgOp != linalgOpTarget) {
    auto errorMsg =
        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
               : StringLiteral{"not produced by the LinalgOp target"};
    return emitSilenceableError() << errorMsg;
  }

  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
  // PackOp.
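  // The unpack consumes a result of the packed LinalgOp; the pack that feeds
  // the init operand with the same result number is the symmetrical op we are
  // after.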
  if (unPackOp) {
    assert(!packOp &&
           "packOp must be null on entry when unPackOp is not null");
    OpOperand *packUse = linalgOp.getDpsInitOperand(
        cast<OpResult>(unPackOp.getSource()).getResultNumber());
    packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
    if (!packOp || !packOp.getResult().hasOneUse())
      return emitSilenceableError() << "could not find matching pack op";
  }

  // Step 2.5. Fail if any permutation does not validate.
  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
    ArrayRef<int64_t> perm =
        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
                        ? StringLiteral{"invalid outer_perm"}
                        : StringLiteral{"invalid inner_perm"};
    if (!isValidPackingPermutation(packOp, perm, permType) ||
        !isValidPackingPermutation(unPackOp, perm, permType)) {
      Operation *packOrUnpackOp =
          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
    }
  }

  // From here on, packOp and linalgOp are always present, unPackOp may or may
  // not be present.
  assert(packOp && linalgOp && "unexpected null op");

  // Step 3. Actually transpose the ops.
  FailureOr<PackTransposeResult> res = packTranspose(
      rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
  // Preconditions have been checked, it is an error to fail here.
  assert(succeeded(res) && "unexpected packTranspose failure");

  // Step 4. Return results.
  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
  transformResults.set(cast<OpResult>(getPackedOp()),
                       {res->transposedLinalgOp});
  if (unPackOp) {
    transformResults.set(cast<OpResult>(getUnPackOp()),
                         {res->transposedUnPackOp});
  } else {
    transformResults.set(cast<OpResult>(getUnPackOp()), {});
  }

  return DiagnosedSilenceableFailure::success();
}

//===---------------------------------------------------------------------===//
// PadOp
//===---------------------------------------------------------------------===//

void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                             ArrayRef<int64_t> paddingDimensions,
                             ArrayRef<int64_t> padToMultipleOf,
                             ArrayRef<int64_t> packPaddings,
                             ArrayRef<Attribute> transposePaddings,
                             StringRef copyBackOp) {
  auto resultType = transform::AnyOpType::get(b.getContext());
  return build(/*builder=*/b,
               /*result=*/result,
               /*types=*/TypeRange{resultType, resultType},
               /*target=*/target,
               /*paddingValues=*/ArrayAttr(), // let inference handle this
               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
               /*padToMultipleOf=*/
               (padToMultipleOf.empty() ? ArrayAttr()
                                        : b.getI64ArrayAttr(padToMultipleOf)),
               /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
               /*transposePaddings=*/b.getArrayAttr(transposePaddings),
               /*copyBackOp=*/b.getStringAttr(copyBackOp));
}

DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
                        transform::TransformResults &results,
                        transform::TransformState &state) {
  SmallVector<Operation *> paddedOps, padOps, copyBackOps;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto linalgTarget = dyn_cast<LinalgOp>(target);
    if (!linalgTarget) {
      auto diag = emitSilenceableError() << "expected LinalgOp target";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Convert the integer packing flags to booleans.
    SmallVector<bool> packPaddings;
    for (int64_t packPadding :
         extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
      packPaddings.push_back(static_cast<bool>(packPadding));

    // Convert the padding values to attributes.
    SmallVector<Attribute> paddingValues;
    for (const auto &it :
         llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
      if (!attr) {
        emitOpError("expects padding values to be typed attributes");
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      Type elementType = getElementTypeOrSelf(std::get<1>(it));
      // Try to parse string attributes to obtain an attribute of element type.
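      // For instance, a padding value provided as a string such as "0.0 : f32"
      // is parsed back into a typed float attribute here (illustrative
      // example; the parsed type must match the operand element type).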
      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
            stringAttr, getContext(), elementType,
            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
        if (!parsedAttr || parsedAttr.getType() != elementType) {
          auto diag = this->emitOpError("expects a padding that parses to ")
                      << elementType << ", got " << std::get<0>(it);
          diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
          return DiagnosedSilenceableFailure::definiteFailure();
        }
        paddingValues.push_back(parsedAttr);
        continue;
      }
      // Otherwise, add the attribute directly.
      if (attr.getType() != elementType) {
        auto diag = this->emitOpError("expects a padding value of type ")
                    << elementType << ", got " << attr;
        diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      paddingValues.push_back(attr);
    }

    // Extract the transpose vectors.
    SmallVector<SmallVector<int64_t>> transposePaddings;
    for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
      transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
          cast<ArrayAttr>(transposeVector)));

    LinalgOp paddedOp;
    LinalgPaddingOptions options;
    options.paddingDimensions =
        extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
    SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
    if (getPadToMultipleOf().has_value())
      padToMultipleOf =
          extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
    options.padToMultipleOf = padToMultipleOf;
    options.paddingValues = paddingValues;
    options.packPaddings = packPaddings;
    if (getCopyBackOp() ==
        bufferization::MaterializeInDestinationOp::getOperationName()) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
          BufferizationMaterializeInDestination;
    } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
    } else if (getCopyBackOp() == kCopyOpNone) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
    } else {
      llvm_unreachable("unsupported copy_back op");
    }

    SmallVector<Value> replacements;
    SmallVector<tensor::PadOp> newPadOps;
    if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
                                 replacements, newPadOps))) {
      auto diag = emitSilenceableError() << "failed to pad op";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
    // need to be different.
    // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
    rewriter.replaceOp(linalgTarget, replacements);
    paddedOps.push_back(paddedOp);
    padOps.append(newPadOps.begin(), newPadOps.end());
    if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
      for (Value v : replacements) {
        Operation *copyBackOp = v.getDefiningOp();
        if (!llvm::is_contained(copyBackOps, copyBackOp))
          copyBackOps.push_back(copyBackOp);
      }
    }
  }

  results.set(cast<OpResult>(getPadded()), paddedOps);
  results.set(cast<OpResult>(getPad()), padOps);
  results.set(cast<OpResult>(getCopy()), copyBackOps);
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::PadOp::verify() {
  SmallVector<int64_t> packPaddings =
      extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
  if (any_of(packPaddings, [](int64_t packPadding) {
        return packPadding != 0 && packPadding != 1;
      })) {
    return emitOpError()
           << "expects pack_paddings to contain booleans (0/1), found "
           << getPackPaddings();
  }

  SmallVector<int64_t> paddingDimensions =
      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
  if (any_of(paddingDimensions,
             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
    return emitOpError() << "expects padding_dimensions to contain "
                            "non-negative integers, found "
                         << getPaddingDimensions();
  }
  if (getPadToMultipleOf().has_value()) {
    if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
      return emitOpError() << "expects as many multiples as padding_dimensions";
    }
  }
  ArrayAttr transposes = getTransposePaddings();
  for (Attribute attr : transposes) {
    SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
    auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
    if (!std::is_permutation(sequence.begin(), sequence.end(),
                             transpose.begin(), transpose.end())) {
      return emitOpError()
             << "expects transpose_paddings to be a permutation, found "
             << attr;
    }
  }
  if (getCopyBackOp() !=
          bufferization::MaterializeInDestinationOp::getOperationName() &&
      getCopyBackOp() != linalg::CopyOp::getOperationName() &&
      getCopyBackOp() != kCopyOpNone)
    return emitOpError() << "invalid copy_back_op";
  return success();
}

//===---------------------------------------------------------------------===//
// HoistPadOp
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  auto loopOps = state.getPayloadOps(getLoop());
  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
    return emitDefiniteFailure()
           << "requires exactly one target and one loop handle (got "
           << llvm::range_size(targetOps) << " and "
           << llvm::range_size(loopOps) << ")";
  }

  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
  if (!padOp || !loopOp)
    return emitDefiniteFailure() << "requires exactly 2 non-null handles";

  FailureOr<linalg::detail::PackingResult> result =
      linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
                                           getTranspose());
  if (failed(result))
    return emitDefiniteFailure() << "could not build packing loop nest";

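  // If the hoisted pad is not wrapped in any cloned loop, return it directly;
  // otherwise return the outermost cloned loop.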
  if (result->clonedLoopIvs.empty()) {
    transformResults.set(cast<OpResult>(getPackingLoop()),
                         {result->hoistedPadOp.getOperation()});
    return DiagnosedSilenceableFailure::success();
  }
  auto outerPackedLoop =
      scf::getForInductionVarOwner(result->clonedLoopIvs.front());
  transformResults.set(cast<OpResult>(getPackingLoop()),
                       {outerPackedLoop.getOperation()});
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
  ArrayRef<int64_t> transpose = getTranspose();
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
                           transpose.end())) {
    return emitOpError() << "expects transpose to be a permutation, found "
                         << getTranspose();
  }
  return success();
}

void transform::HoistPadBuildPackingLoopNestOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTarget(), effects);
  transform::onlyReadsHandle(getLoop(), effects);
  transform::producesHandle(getPackingLoop(), effects);
  transform::modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
                                  tensor::PadOp target,
                                  transform::ApplyToEachResultList &results,
                                  transform::TransformState &state) {
  tensor::PadOp hoistedPadOp;
  SmallVector<GenericOp> transposeOps;
  FailureOr<Value> result =
      hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
                            hoistedPadOp, transposeOps);
  if (succeeded(result)) {
    // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
    // need to be different.
    // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
    rewriter.replaceOp(target, *result);
    results.push_back(hoistedPadOp);
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}

LogicalResult transform::HoistPadOp::verify() {
  ArrayRef<int64_t> transpose = getTranspose();
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
                           transpose.end())) {
    return emitOpError() << "expects transpose to be a permutation, found "
                         << getTranspose();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// PromoteOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
                                 LinalgOp target,
                                 transform::ApplyToEachResultList &results,
                                 transform::TransformState &state) {
  LinalgPromotionOptions promotionOptions;
  if (!getOperandsToPromote().empty())
    promotionOptions = promotionOptions.setOperandsToPromote(
        extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
  if (getUseFullTilesByDefault())
    promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
        getUseFullTilesByDefault());
  if (getUseAlloca())
    promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
  if (!getUseFullTileBuffers().empty())
    promotionOptions = promotionOptions.setUseFullTileBuffers(
        llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
  if (getAlignment().has_value())
    promotionOptions = promotionOptions.setAlignment(*getAlignment());
  if (getMemorySpace().has_value())
    promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());

  if (getMapping().has_value()) {
    // The mapping should contain exactly one element.
    auto mapping = *getMapping();
    if (mapping.size() != 1)
      return emitDefaultDefiniteFailure(target);

    auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);

    if (addressSpace.getAddressSpace() ==
        mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
      promotionOptions =
          promotionOptions
              .setAllocationDeallocationFns(allocateWorkgroupMemory,
                                            deallocateWorkgroupMemory)
              .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
              .setUseFullTileBuffers({false, false});
    } else if (addressSpace.getAddressSpace() ==
               mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
      promotionOptions =
          promotionOptions
              .setAllocationDeallocationFns(allocateGPUPrivateMemory,
                                            deallocateGPUPrivateMemory)
              .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
              .setUseFullTileBuffers({false, false});
    } else {
      return emitDefaultDefiniteFailure(target);
    }
  }

  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
    return emitDefaultDefiniteFailure(target);

  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
  if (failed(res))
    return emitDefaultDefiniteFailure(target);
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ReplaceOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
                            TransformResults &transformResults,
                            TransformState &state) {
  auto payload = state.getPayloadOps(getTarget());

  // Check for invalid targets.
  for (Operation *target : payload) {
    if (target->getNumOperands() > 0)
      return emitDefiniteFailure() << "expected target without operands";
    if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
        target->getNumRegions() > 0)
      return emitDefiniteFailure()
             << "expected target that is isolated from above";
  }

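  // The single op in the body region acts as a template: it is cloned once per
  // target op and the target's results are replaced with the clone's results.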
  // Clone and replace.
  Operation *pattern = &getBodyRegion().front().front();
  SmallVector<Operation *> replacements;
  for (Operation *target : payload) {
    if (getOperation()->isAncestor(target))
      continue;
    rewriter.setInsertionPoint(target);
    Operation *replacement = rewriter.clone(*pattern);
    rewriter.replaceOp(target, replacement->getResults());
    replacements.push_back(replacement);
  }
  transformResults.set(cast<OpResult>(getReplacement()), replacements);
  return DiagnosedSilenceableFailure::success();
}

void transform::ReplaceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTarget(), effects);
  producesHandle(getReplacement(), effects);
  modifiesPayload(effects);
}

LogicalResult transform::ReplaceOp::verify() {
  if (!getBodyRegion().hasOneBlock())
    return emitOpError() << "expected one block";
  if (std::distance(getBodyRegion().front().begin(),
                    getBodyRegion().front().end()) != 1)
    return emitOpError() << "expected one operation in block";
  Operation *replacement = &getBodyRegion().front().front();
  if (replacement->getNumOperands() > 0)
    return replacement->emitOpError()
           << "expected replacement without operands";
  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
      replacement->getNumRegions() > 0)
    return replacement->emitOpError()
           << "expected op that is isolated from above";
  return success();
}

//===----------------------------------------------------------------------===//
// ScalarizeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
    SmallVector<OpFoldResult> tileSizes;
    Location loc = target.getLoc();
    SmallVector<OpFoldResult> allShapeSizes =
        target.createFlatListOfOperandDims(b, loc);
    AffineMap map = target.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    SmallVector<OpFoldResult> shapeSizes =
        affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
                                                         allShapeSizes);
    // If the shape size is dynamic, tile by 1.
    // Otherwise, do not tile (i.e. tile size 0).
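    // E.g., shape sizes (4, ?, 8) yield tile sizes (0, 1, 0): only the dynamic
    // dimension is tiled, by 1, which scalarizes it.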
    for (OpFoldResult shapeSize : shapeSizes) {
      tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
                                                         : b.getIndexAttr(1));
    }
    return tileSizes;
  });
  rewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
      rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
  if (failed(maybeTilingResult))
    return emitDefaultDefiniteFailure(target);

  if (target->getNumResults())
    rewriter.replaceOp(target, maybeTilingResult->replacements);
  else
    rewriter.eraseOp(target);

  results.reserve(maybeTilingResult->tiledOps.size());
  for (Operation *tiled : maybeTilingResult->tiledOps)
    results.push_back(tiled);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ConvertToLoopsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  SmallVector<Operation *> loops;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto tilingOp = dyn_cast<TilingInterface>(target);
    if (!tilingOp) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected the payload to implement TilingInterface";
      diag.attachNote(target->getLoc()) << "payload op";
      return diag;
    }
    rewriter.setInsertionPoint(target);
    FailureOr<SmallVector<scf::ForOp>> generatedLoops =
        scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
    if (failed(generatedLoops))
      return emitDefaultDefiniteFailure(target);
    for (scf::ForOp &loop : *generatedLoops) {
      loops.push_back(loop.getOperation());
    }
    rewriter.eraseOp(target);
  }
  results.set(cast<OpResult>(getResult()), loops);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// RewriteInDestinationPassingStyleOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::RewriteInDestinationPassingStyleOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  SmallVector<Operation *> res;
  rewriter.setInsertionPoint(target);
  FailureOr<Operation *> maybeResult =
      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
              [&rewriter](auto op) {
                return rewriteInDestinationPassingStyle(rewriter, op);
              });
  if (failed(maybeResult))
    return emitDefaultSilenceableFailure(target);
  results.push_back(*maybeResult);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
SplitOp::apply(transform::TransformRewriter &rewriter,
               TransformResults &results, TransformState &state) {
  // Collect the dynamic split points if provided.
  SmallVector<Operation *> payload =
      llvm::to_vector(state.getPayloadOps(getTarget()));
  SmallVector<OpFoldResult> splitPoints;
  splitPoints.reserve(payload.size());
  if (getDynamicSplitPoint()) {
    auto diag = DiagnosedSilenceableFailure::success();
    if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
      splitPoints = llvm::to_vector(llvm::map_range(
          state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
            if (op->getNumResults() != 1 ||
                !op->getResult(0).getType().isIndex()) {
              diag = emitSilenceableError()
                     << "expected dynamic split point handle to point to a "
                        "single-result index-typed op";
              diag.attachNote(op->getLoc()) << "dynamic split point";
            }
            return OpFoldResult(op->getResult(0));
          }));
    } else {
      splitPoints = llvm::to_vector(
          llvm::map_range(state.getParams(getDynamicSplitPoint()),
                          [](Attribute attr) { return OpFoldResult(attr); }));
    }
    if (diag.isSilenceableFailure())
      return diag;

    if (splitPoints.size() != payload.size()) {
      return emitDefiniteFailure()
             << "expected the dynamic split point handle to point to as "
                "many operations ("
             << splitPoints.size() << ") as the target handle ("
             << payload.size() << ")";
    }
  } else {
    splitPoints.resize(payload.size(),
                       rewriter.getIndexAttr(getStaticSplitPoint()));
  }

  // Split each target operation.
  SmallVector<Operation *> first, second;
  Operation *noSecondPart = nullptr;
  for (const auto &pair : llvm::zip(payload, splitPoints)) {
    Operation *target = std::get<0>(pair);
    auto linalgOp = dyn_cast<LinalgOp>(target);
    if (!linalgOp) {
      auto diag = emitSilenceableError() << "only applies to structured ops";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    if (getDimension() >= linalgOp.getNumLoops()) {
      auto diag = emitSilenceableError() << "dimension " << getDimension()
                                         << " does not exist in target op";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    rewriter.setInsertionPoint(linalgOp);
    std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
        getDimension(), std::get<1>(pair));

    // Propagate errors.
    if (!first.back() && !second.back()) {
      auto diag = emitDefiniteFailure() << "internal failure in splitting";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Do not add null second parts.
    if (!second.back()) {
      noSecondPart = target;
      second.pop_back();
    }
  }

  if (second.size() != first.size() && !second.empty()) {
    auto diag = emitSilenceableError()
                << "splitting does not produce the second part for a subset "
                   "of targets";
    diag.attachNote() << "expected splitting to produce the second part of all "
                         "or none of the targets";
    diag.attachNote(noSecondPart->getLoc())
        << "first target with no second part";
    return diag;
  }

  results.set(cast<OpResult>(getFirst()), first);
  results.set(cast<OpResult>(getSecond()), second);
  return DiagnosedSilenceableFailure::success();
}

void SplitOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTarget(), effects);
  if (getDynamicSplitPoint())
    onlyReadsHandle(getDynamicSplitPoint(), effects);
  producesHandle(getResults(), effects);
  modifiesPayload(effects);
}

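// The custom assembly format parsed and printed below looks, illustratively
// (the SSA names are hypothetical), like:
//   %first, %second = transform.structured.split %target after 16
//       {dimension = 0} : !transform.any_op
// or, with a dynamic split point:
//   %first, %second = transform.structured.split %target after %pt
//       {dimension = 0} : !transform.any_op, !transform.param<i64>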
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
  IntegerAttr staticSplitPoint;
  if (parser.parseOperand(target) || parser.parseKeyword("after"))
    return failure();

  OptionalParseResult dynamicPointParseResult =
      parser.parseOptionalOperand(dynamicSplitPoint);
  if (!dynamicPointParseResult.has_value()) {
    int64_t staticSplitPointValue;
    if (failed(parser.parseInteger(staticSplitPointValue)))
      return failure();

    staticSplitPoint =
        parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
  }

  Type targetType;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(targetType) ||
      parser.resolveOperand(target, targetType, result.operands)) {
    return failure();
  }
  if (dynamicPointParseResult.has_value()) {
    Type splitPointType;
    if (failed(*dynamicPointParseResult) || parser.parseComma() ||
        parser.parseType(splitPointType) ||
        parser.resolveOperand(dynamicSplitPoint, splitPointType,
                              result.operands)) {
      return failure();
    }

    staticSplitPoint =
        parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
  }

  result.addAttribute(
      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
      staticSplitPoint);
  result.addTypes({targetType, targetType});
  return success();
}

void SplitOp::print(OpAsmPrinter &printer) {
  printer << " " << getTarget() << " after ";
  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
  if (staticSplitSize != ShapedType::kDynamic)
    printer << staticSplitSize;
  else
    printer << getDynamicSplitPoint();
  printer << " ";
  printer.printOptionalAttrDict(getOperation()->getAttrs(),
                                {getStaticSplitPointAttrName()});
  printer << " : " << getTarget().getType();
  if (staticSplitSize == ShapedType::kDynamic)
    printer << ", " << getDynamicSplitPoint().getType();
}

LogicalResult SplitOp::verify() {
  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
      (getDynamicSplitPoint() == nullptr)) {
    return emitOpError() << "expects either a dynamic or a static split "
                            "point to be provided";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SplitReductionOp
//===----------------------------------------------------------------------===//

void transform::SplitReductionOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
    bool useScalingAlgorithm, bool useAlloc) {
  MLIRContext *ctx = builder.getContext();
  result.addOperands(target);
  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
                      builder.getI64IntegerAttr(splitFactor));
  result.addAttribute(
      SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
      builder.getI64IntegerAttr(insertSplitDimension));
  if (innerParallel) {
    result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
                        builder.getUnitAttr());
  }
  if (useScalingAlgorithm) {
    result.addAttribute(
        SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
        builder.getUnitAttr());
  }
  if (useAlloc) {
    result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
                        builder.getUnitAttr());
  }
  auto resultType = transform::AnyOpType::get(ctx);
  result.addTypes({resultType, resultType, resultType, resultType});
}

DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  ControlSplitReductionFn splitFn = [&](LinalgOp) {
    return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
                                         unsigned(getInsertSplitDimension()),
                                         bool(getInnerParallel())};
  };
  rewriter.setInsertionPoint(target);
  FailureOr<SplitReductionResult> splitResult =
      (getUseScalingAlgorithm())
          ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
          : splitReduction(rewriter, target, splitFn, getUseAlloc());
  if (failed(splitResult))
    return emitDefaultDefiniteFailure(target);

  results.push_back(splitResult->initOrAlloc);
  results.push_back(splitResult->fillOp);
  results.push_back(splitResult->splitLinalgOp);
  results.push_back(splitResult->resultCombiningLinalgOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TileReductionUsingForOp
//===----------------------------------------------------------------------===//

void transform::TileReductionUsingForOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    ArrayRef<int64_t> staticTileSizes) {
  // Call the default builder.
  // This is future-proof with respect to mixed static-dynamic sizes and sets
  // up the proper operand segment size attributes for multiple variadic
  // operands. In the absence of this, horrible bugs ensue.
  // TODO: support mixed static-dynamic (see TileUsingForallOp).
  MLIRContext *ctx = builder.getContext();
  auto opTy = transform::AnyOpType::get(ctx);
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
  build(builder, result,
        /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
        /*target=*/target,
        /*tile_sizes=*/staticTileSizesAttr);
}

DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
      rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));

  if (failed(result))
    return emitDefaultSilenceableFailure(target);
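  // Expose the four results: the op initializing the partial accumulator, the
  // tiled op computing partial reductions, the op merging the partial results,
  // and the generated loop.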
  results.push_back(result->initialOp);
  results.push_back(result->parallelTiledOp);
  results.push_back(result->mergeOp);
  results.push_back(result->loops.front());
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TileReductionUsingForallOp
//===----------------------------------------------------------------------===//

void transform::TileReductionUsingForallOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
    ArrayAttr mapping) {
  // Call the default builder.
  // This is future-proof with respect to mixed static-dynamic sizes and sets
  // up the proper operand segment size attributes for multiple variadic
  // operands. In the absence of this, horrible bugs ensue.
  // TODO: support mixed static-dynamic (see TileUsingForallOp).
  MLIRContext *ctx = builder.getContext();
  auto opTy = transform::AnyOpType::get(ctx);
  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
  build(builder, result,
        /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
        /*target=*/target,
        /*num_threads=*/staticNumThreadsAttr,
        /*tile_sizes=*/staticTileSizesAttr,
        /*mapping=*/mapping);
}

DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  SmallVector<OpFoldResult> numThreads =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
  SmallVector<OpFoldResult> tileSizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
  FailureOr<linalg::ForallReductionTilingResult> result =
      linalg::tileReductionUsingForall(
          rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
          numThreads, tileSizes, getMapping());

  if (failed(result)) {
    auto diag = emitSilenceableError() << "could not tile reduction";
    diag.attachNote(target.getLoc()) << "target operation";
    return diag;
  }
  results.push_back(result->initialOp);
  results.push_back(result->parallelTiledOp);
  results.push_back(result->mergeOp);
  results.push_back(result->loops);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TileUsingForOp
//===----------------------------------------------------------------------===//

void transform::TileUsingForOp::build(
    OpBuilder &builder, OperationState &result, TypeRange loopTypes,
    Value target, ArrayRef<int64_t> staticTileSizes,
    ArrayRef<int64_t> interchange,
    std::optional<ArrayRef<bool>> scalableSizes) {
  return build(builder, result, loopTypes,
               /*target=*/target,
               /*mixedTileSizes=*/
               getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
               interchange, scalableSizes);
}

void transform::TileUsingForOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
    std::optional<ArrayRef<bool>> scalableSizes) {
  build(builder, result, target,
        getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
        interchange, scalableSizes);
}

void transform::TileUsingForOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
    std::optional<ArrayRef<bool>> scalableSizes) {
  // Loop types are automatically splat by the callee, setting up one is
  // enough.
  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
        scalableSizes);
}

void transform::TileUsingForOp::build(
    OpBuilder &builder, OperationState &result, TypeRange loopTypes,
    Value target, ArrayRef<OpFoldResult> mixedTileSizes,
    ArrayRef<int64_t> interchange,
    std::optional<ArrayRef<bool>> scalableSizes) {
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
  unsigned numExpectedLoops =
      staticTileSizes.size() - llvm::count(staticTileSizes, 0);
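  // A tile size of 0 means "do not tile this dimension" and thus produces no
  // loop, so the number of loop results equals the number of non-zero sizes.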
  SmallVector<Type> resultTypes;
  resultTypes.reserve(numExpectedLoops);
  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
         "expected one loop type or as many as loops");
  if (loopTypes.size() == 1)
    resultTypes.append(numExpectedLoops, loopTypes[0]);
  else
    llvm::append_range(resultTypes, loopTypes);
  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
  if (scalableSizes.has_value())
    expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
  build(builder, result, /*tiled_linalg_op=*/target.getType(),
        /*loops=*/resultTypes,
        /*target=*/target,
        /*dynamic_sizes=*/dynamicTileSizes,
        /*static_sizes=*/staticTileSizesAttr,
        /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
        /*scalable_sizes=*/expandedScalableSizes);
}

LogicalResult transform::TileUsingForOp::verify() {
  if (getMixedSizes().size() != getScalableSizes().size())
    return emitOpError("expected same number of sizes (")
           << getMixedSizes().size() << ") and scalable sizes ("
           << getScalableSizes().size() << ")";
  return success();
}

DiagnosedSilenceableFailure
transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
                                 TransformResults &transformResults,
                                 TransformState &state) {
  ArrayRef<int64_t> tileSizes = getStaticSizes();

  SmallVector<Operation *> targets =
      llvm::to_vector(state.getPayloadOps(getTarget()));
  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
  SmallVector<SmallVector<int64_t>> paramSizes;
  dynamicSizeProducers.reserve(getDynamicSizes().size());
  paramSizes.reserve(getDynamicSizes().size());
  for (Value transformValue : getDynamicSizes()) {
    if (isa<ParamType>(transformValue.getType())) {
      dynamicSizeProducers.push_back({});
      ArrayRef<Attribute> params = state.getParams(transformValue);
      paramSizes.push_back(
          llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
            return cast<IntegerAttr>(attr).getValue().getSExtValue();
          })));

      if (paramSizes.back().size() != targets.size()) {
        DiagnosedSilenceableFailure diag =
            emitSilenceableError()
            << "expected as many parameter values ("
            << paramSizes.back().size() << ") as target ops ("
            << targets.size() << ")";
        diag.attachNote(transformValue.getLoc()) << "for this parameter";
        return diag;
      }

      continue;
    }
    paramSizes.push_back({});
    dynamicSizeProducers.push_back(
        llvm::to_vector(state.getPayloadOps(transformValue)));

    if (dynamicSizeProducers.back().size() != targets.size()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected as many dynamic size-producing operations ("
          << dynamicSizeProducers.back().size() << ") as target ops ("
          << targets.size() << ")";
      diag.attachNote(transformValue.getLoc()) << "for this handle";
      return diag;
    }

    for (Operation *op : dynamicSizeProducers.back()) {
      if (op->getNumResults() == 1 &&
          isa<IndexType>(op->getResult(0).getType())) {
        continue;
      }

      DiagnosedSilenceableFailure diag =
          emitSilenceableError() << "expected sizes to be produced by ops "
                                    "with a single index-type result";
      diag.attachNote(op->getLoc()) << "size producer op";
      diag.attachNote(transformValue.getLoc()) << "for this handle";
      return diag;
    }
  }

  SmallVector<Operation *> tiled;
  SmallVector<SmallVector<Operation *, 4>, 4> loops;
  loops.resize(getLoops().size());
  auto scalableSizes = getScalableSizes();
  for (auto [i, op] : llvm::enumerate(targets)) {
    auto tilingInterface = dyn_cast<TilingInterface>(op);
    if (!tilingInterface) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "only ops implementing TilingInterface are supported";
      diag.attachNote(op->getLoc()) << "target op";
      return diag;
    }
    if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "too many tiles provided, expected at most "
          << tilingInterface.getLoopIteratorTypes().size() << " found "
          << tileSizes.size();
      diag.attachNote(op->getLoc()) << "target op";
      return diag;
    }

    scf::SCFTilingOptions tilingOptions;
    if (tileSizes.empty()) {
      tilingOptions.setTileSizeComputationFunction(
          [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
            return {};
          });
    } else {
      tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
                                                                  Operation *) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(tileSizes.size());
        unsigned dynamicIdx = 0;

        for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
          if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
            if (scalableSizes[ofrIdx]) {
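              // Materialize the scalable size as `attr * vector.vscale`,
              // e.g. a scalable tile size [4] becomes 4 * vscale at runtime.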
              auto val = b.create<arith::ConstantIndexOp>(
                  getLoc(), cast<IntegerAttr>(attr).getInt());
              Value vscale =
                  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
              sizes.push_back(
                  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
            } else {
              sizes.push_back(attr);
            }
            continue;
          }
          ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
          ArrayRef<int64_t> params = paramSizes[dynamicIdx];
          ++dynamicIdx;
          assert((dynamicSizes.empty() ^ params.empty()) &&
                 "expected either dynamic sizes or parameters");
          if (!params.empty()) {
            sizes.push_back(b.getIndexAttr(params[index]));
          } else {
            sizes.push_back(dynamicSizes[index]->getResult(0));
          }
        }
        return sizes;
      });
    }

    tilingOptions.setInterchange(getInterchange());
    FailureOr<scf::SCFTilingResult> maybeTilingResult =
        tileUsingSCF(rewriter, tilingInterface, tilingOptions);
    if (failed(maybeTilingResult))
      return DiagnosedSilenceableFailure::definiteFailure();

    rewriter.replaceOp(op, maybeTilingResult->replacements);

    tiled.append(maybeTilingResult->tiledOps);
    for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
      loops[en2.index()].push_back(en2.value());
  }

  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
  for (const auto &en : llvm::enumerate(loops))
    transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());

  return DiagnosedSilenceableFailure::success();
}

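/// Returns the mixed static/dynamic tile sizes: static sizes are emitted as
/// index attributes and each kDynamic placeholder is substituted with the
/// next dynamic size operand.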
SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
  ValueRange dynamic = getDynamicSizes();
  ArrayRef<int64_t> tileSizes = getStaticSizes();
  SmallVector<OpFoldResult> results;
  results.reserve(tileSizes.size());
  unsigned dynamicPos = 0;
  Builder builder(getContext());
  for (int64_t size : tileSizes) {
    if (size == ShapedType::kDynamic) {
      results.push_back(dynamic[dynamicPos++]);
    } else {
      results.push_back(builder.getIndexAttr(size));
    }
  }
  return results;
}

// We want to parse `DenseI64ArrayAttr` using the short form without the
// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
static ParseResult parseOptionalInterchange(OpAsmParser &parser,
                                            OperationState &result) {
  if (failed(parser.parseOptionalKeyword("interchange")))
    return success();
  if (failed(parser.parseEqual()))
    return failure();
  result.addAttribute(
      transform::TileUsingForOp::getInterchangeAttrName(result.name),
      DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}

static void printOptionalInterchange(OpAsmPrinter &p,
                                     ArrayRef<int64_t> interchangeVals) {
  if (!interchangeVals.empty()) {
    p << " interchange = [";
    llvm::interleaveComma(interchangeVals, p,
                          [&](int64_t integer) { p << integer; });
    p << "]";
  }
}
2753
ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
                                             OperationState &result) {
  OpAsmParser::UnresolvedOperand target;
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
  DenseI64ArrayAttr staticSizes;
  FunctionType functionalType;
  llvm::SMLoc operandLoc;
  DenseBoolArrayAttr scalableVals;

  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
      parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
      parseOptionalInterchange(parser, result) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(functionalType))
    return ParseResult::failure();

  size_t numExpectedLoops =
      staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
  if (functionalType.getNumResults() != numExpectedLoops + 1) {
    return parser.emitError(parser.getNameLoc())
           << "expected " << (numExpectedLoops + 1) << " result type(s)";
  }
  if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
    return parser.emitError(operandLoc)
           << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
  }
  if (parser.resolveOperand(target, functionalType.getInputs().front(),
                            result.operands) ||
      parser.resolveOperands(dynamicSizes,
                             functionalType.getInputs().drop_front(),
                             operandLoc, result.operands)) {
    return failure();
  }

  result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
  result.addTypes(functionalType.getResults());
  return success();
}

void TileUsingForOp::print(OpAsmPrinter &p) {
  p << ' ' << getTarget();
  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
                        /*valueTypes=*/{}, getScalableSizesAttr(),
                        OpAsmParser::Delimiter::Square);
  printOptionalInterchange(p, getInterchange());
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()),
                       getScalableSizesAttrName(getOperation()->getName()),
                       getStaticSizesAttrName(getOperation()->getName())});
  p << " : ";
  p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
}

void transform::TileUsingForOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTarget(), effects);
  onlyReadsHandle(getDynamicSizes(), effects);
  producesHandle(getTiledLinalgOp(), effects);
  producesHandle(getLoops(), effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// TileUsingForallOp
//===----------------------------------------------------------------------===//

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         OperationState &result, Value target,
                                         ArrayRef<int64_t> staticTileSizes,
                                         transform::TileSizesSpec,
                                         ArrayAttr mapping) {
  return build(builder, result,
               /*target=*/target,
               /*mixedTileSizes=*/
               getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
               /*_=*/TileSizesSpec(),
               /*mapping=*/mapping);
}

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         OperationState &result, Value target,
                                         ArrayRef<OpFoldResult> mixedTileSizes,
                                         transform::TileSizesSpec,
                                         ArrayAttr mapping) {
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
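  // `dispatchIndexOpFoldResults` splits the mixed sizes into the static-size
  // array (with `ShapedType::kDynamic` as the placeholder for dynamic entries)
  // and the list of dynamic size values, matching how the op stores them as an
  // attribute plus variadic operands.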
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  MLIRContext *ctx = builder.getContext();
  auto operationType = transform::AnyOpType::get(ctx);
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
  build(builder, result,
        /*resultTypes=*/TypeRange{operationType, operationType},
        /*target=*/target,
        /*num_threads=*/ValueRange{},
        /*tile_sizes=*/dynamicTileSizes,
        /*packed_num_threads=*/Value(),
        /*packed_tile_sizes=*/Value(),
        /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
        /*static_tile_sizes=*/staticTileSizesAttr,
        /*mapping=*/mapping);
}

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         OperationState &result, Value target,
                                         ArrayRef<int64_t> staticNumThreads,
                                         transform::NumThreadsSpec,
                                         ArrayAttr mapping) {
  return build(builder, result, target,
               getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
               NumThreadsSpec(), mapping);
}

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         OperationState &result, Value target,
                                         ArrayRef<OpFoldResult> mixedNumThreads,
                                         transform::NumThreadsSpec,
                                         ArrayAttr mapping) {
  SmallVector<int64_t> staticNumThreads;
  SmallVector<Value> dynamicNumThreads;
  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
                             staticNumThreads);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  MLIRContext *ctx = builder.getContext();
  auto operationType = transform::AnyOpType::get(ctx);
  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
  build(builder, result,
        /*resultTypes=*/TypeRange{operationType, operationType},
        /*target=*/target,
        /*num_threads=*/dynamicNumThreads,
        /*tile_sizes=*/ValueRange{},
        /*packed_num_threads=*/Value(),
        /*packed_tile_sizes=*/Value(),
        /*static_num_threads=*/staticNumThreadsAttr,
        /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
        /*mapping=*/mapping);
}

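/// Shared implementation for tiling a single `target` to an `scf.forall` op.
/// At most one of `mixedNumThreads` / `mixedTileSizes` is expected to be
/// non-empty: when `mixedNumThreads` is non-empty it drives the tiling,
/// otherwise the tile sizes do.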
DiagnosedSilenceableFailure transform::tileToForallOpImpl(
    RewriterBase &rewriter, transform::TransformState &state,
    TransformOpInterface transformOp, Operation *target,
    ArrayRef<OpFoldResult> mixedNumThreads,
    ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
    linalg::ForallTilingResult &tilingResult) {
  // Transform one target at a time; callers iterate over their payload ops.
  auto tileableOp = dyn_cast<TilingInterface>(target);
  if (!tileableOp) {
    DiagnosedSilenceableFailure diag =
        transformOp.emitSilenceableError()
        << "only TilingInterface ops are supported";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }
  rewriter.setInsertionPoint(tileableOp);
  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
  if (!mixedNumThreads.empty()) {
    maybeTilingResult =
        linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
  } else {
    maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
        rewriter, tileableOp, mixedTileSizes, mapping);
  }

  if (failed(maybeTilingResult))
    return transformOp.emitDefaultSilenceableFailure(tileableOp);
  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());

  tilingResult = *maybeTilingResult;
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Result payload ops.
  SmallVector<Operation *> tileOps;
  SmallVector<Operation *> tiledOps;

  // Unpack handles.
  SmallVector<OpFoldResult> mixedNumThreads;
  DiagnosedSilenceableFailure status =
      getPackedNumThreads()
          ? unpackSingleIndexResultPayloadOperations(
                state, transformOp, mixedNumThreads, getPackedNumThreads())
          : unpackSingleIndexResultPayloadOperations(
                state, transformOp, mixedNumThreads, getMixedNumThreads());
  if (!status.succeeded())
    return status;
  SmallVector<OpFoldResult> mixedTileSizes;
  status = getPackedTileSizes()
               ? unpackSingleIndexResultPayloadOperations(
                     state, transformOp, mixedTileSizes, getPackedTileSizes())
               : unpackSingleIndexResultPayloadOperations(
                     state, transformOp, mixedTileSizes, getMixedTileSizes());
  if (!status.succeeded())
    return status;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    linalg::ForallTilingResult tilingResult;
    DiagnosedSilenceableFailure diag = tileToForallOpImpl(
        rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
        getMapping(), tilingResult);
    if (!diag.succeeded())
      return diag;
    tileOps.push_back(tilingResult.tileOp);
    tiledOps.push_back(tilingResult.tiledOp);
  }

  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);

  return DiagnosedSilenceableFailure::success();
}

void transform::TileUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTarget(), effects);
  onlyReadsHandle(getTileSizes(), effects);
  onlyReadsHandle(getNumThreads(), effects);
  onlyReadsHandle(getPackedNumThreads(), effects);
  onlyReadsHandle(getPackedTileSizes(), effects);
  producesHandle(getResults(), effects);
  modifiesPayload(effects);
}

SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
  Builder b(getContext());
  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
}

SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
  Builder b(getContext());
  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
}

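// Note that the verifier below only rejects combining a packed specification
// with its unpacked counterpart and requires at least one of the two
// quantities; it does not reject specifying both num-threads and tile-sizes
// variants, in which case `tileToForallOpImpl` above prefers num-threads.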
LogicalResult TileUsingForallOp::verify() {
  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
                       static_cast<int>(getPackedNumThreads() != Value());
  if (numThreadsSpec > 1)
    return emitOpError(
        "num_threads and packed_num_threads are mutually exclusive");
  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
                      static_cast<int>(getPackedTileSizes() != Value());
  if (tileSizesSpec > 1)
    return emitOpError(
        "tile_sizes and packed_tile_sizes are mutually exclusive");
  if (numThreadsSpec == 0 && tileSizesSpec == 0)
    return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
                       "must be specified");
  return success();
}

//===----------------------------------------------------------------------===//
// VectorizeChildrenAndApplyPatternsOp
//===----------------------------------------------------------------------===//

void transform::VectorizeChildrenAndApplyPatternsOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
  result.addOperands(target);
  if (vectorizePadding) {
    result.addAttribute(
        VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
            result.name),
        builder.getUnitAttr());
  }
  if (vectorizeExtract) {
    result.addAttribute(
        VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
            result.name),
        builder.getUnitAttr());
  }
  if (flatten1DDepthwiseConv) {
    result.addAttribute(
        VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
            result.name),
        builder.getUnitAttr());
  }
  result.addTypes(transform::AnyOpType::get(builder.getContext()));
}

namespace {
/// This is a helper only to call vectorize via a pattern inside of
/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
struct VectorizationPattern : public RewritePattern {
  explicit VectorizationPattern(MLIRContext *context,
                                bool vectorizeExtract = false,
                                bool flattenConv = false)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        vectorizeNDExtract(vectorizeExtract),
        flatten1DDepthwiseConv(flattenConv) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return rewriter.notifyMatchFailure(op, "expected Linalg Op");
    return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
                     /*scalableVecDims=*/{}, vectorizeNDExtract,
                     flatten1DDepthwiseConv);
  }

private:
  /// Controls whether to vectorize `tensor.extract` when the input tensor is
  /// rank >= 2.
  bool vectorizeNDExtract = false;
  /// Controls whether to "flatten" the channel dimension when vectorizing 1D
  /// depthwise convolutions. This should lead to better vectorization for
  /// tensors with a small number of channels.
  bool flatten1DDepthwiseConv = false;
};
} // namespace

DiagnosedSilenceableFailure
transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
    auto diag = this->emitOpError("requires isolated-from-above targets");
    diag.attachNote(target->getLoc()) << "non-isolated target";
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  MLIRContext *ctx = getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
                                     getFlatten_1dDepthwiseConv());

  if (!getDisableTransferPermutationMapLoweringPatterns())
    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);

  if (!getDisableMultiReductionToContractPatterns())
    vector::populateVectorReductionToContractPatterns(patterns);

  vector::populateSinkVectorBroadcastPatterns(patterns);

  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
               linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                       /*benefit=*/2);
  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);

  patterns.add<CopyVectorizationPattern>(ctx);

  if (getVectorizePadding())
    linalg::populatePadOpVectorizationPatterns(patterns);

  TrackingListener listener(state, *this);
  GreedyRewriteConfig config;
  config.listener = &listener;
  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
    return emitDefaultDefiniteFailure(target);

  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// VectorizeOp
//===----------------------------------------------------------------------===//

static const StringLiteral kVectorSizesKeyword = "vector_sizes";

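// The custom assembly parsed below has the following general shape; the
// example is illustrative only (one dynamic and one static vector size):
//
//   transform.structured.vectorize %target vector_sizes [%sz, 4]
//     : !transform.any_op, !transform.any_op
//
// Trailing types are the target type followed by one type per dynamic size,
// which the operand-count check below enforces.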
ParseResult transform::VectorizeOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
  OpAsmParser::UnresolvedOperand target;
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
  DenseI64ArrayAttr staticSizes;
  SmallVector<Type> operandTypes;
  llvm::SMLoc operandLoc;
  DenseBoolArrayAttr scalableVals;

  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
    return ParseResult::failure();

  if (succeeded(parser.parseOptionalKeyword(kVectorSizesKeyword))) {
    if (failed(parseDynamicIndexList(parser, dynamicSizes, staticSizes,
                                     scalableVals)))
      return ParseResult::failure();
  }

  if (succeeded(parser.parseOptionalKeyword(
          getVectorizeNdExtractAttrName(result.name))))
    result.addAttribute(getVectorizeNdExtractAttrName(result.name),
                        parser.getBuilder().getUnitAttr());

  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(operandTypes))
    return ParseResult::failure();

  if (operandTypes.size() != dynamicSizes.size() + 1) {
    return parser.emitError(operandLoc)
           << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
  }
  if (parser.resolveOperand(target, operandTypes.front(), result.operands) ||
      parser.resolveOperands(dynamicSizes, ArrayRef(operandTypes).drop_front(),
                             operandLoc, result.operands)) {
    return failure();
  }

  if (scalableVals)
    result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
  if (staticSizes)
    result.addAttribute(getStaticVectorSizesAttrName(result.name), staticSizes);

  return success();
}

void transform::VectorizeOp::print(OpAsmPrinter &p) {
  p << ' ' << getTarget() << ' ';
  if (!getMixedVectorSizes().empty()) {
    p << kVectorSizesKeyword << ' ';
    printDynamicIndexList(p, getOperation(), getVectorSizes(),
                          getStaticVectorSizesAttr(),
                          /*valueTypes=*/{}, getScalableSizesAttr(),
                          OpAsmParser::Delimiter::Square);
  }

  if (getVectorizeNdExtract())
    p << getVectorizeNdExtractAttrName() << ' ';

  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{
          getScalableSizesAttrName(getOperation()->getName()),
          getStaticVectorSizesAttrName(getOperation()->getName())});
  p << " : ";
  p << getTarget().getType();
  if (!getVectorSizes().empty()) {
    p << ", ";
    llvm::interleaveComma(getVectorSizes(), p,
                          [&](Value operand) { p << operand.getType(); });
  }
}

DiagnosedSilenceableFailure transform::VectorizeOp::apply(
    transform::TransformRewriter &rewriter,
    mlir::transform::TransformResults &transformResults,
    mlir::transform::TransformState &state) {
  auto targets = state.getPayloadOps(getTarget());
  if (std::empty(targets))
    return DiagnosedSilenceableFailure::success();

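  // Resolve the mixed vector sizes to static integers. Each entry is either
  // (1) a static integer attribute, (2) a handle to a single integer transform
  // param, or (3) a handle to exactly one payload op defining a constant index
  // result.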
  SmallVector<int64_t> vectorSizes;
  for (OpFoldResult sz : getMixedVectorSizes()) {
    if (sz.is<Attribute>()) {
      auto attr = sz.get<Attribute>();
      vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
      continue;
    } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
      ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
      if (params.size() != 1)
        return emitSilenceableFailure(getLoc()) << "expected a single param";
      vectorSizes.push_back(
          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
      continue;
    }

    auto szPayloads = state.getPayloadOps(sz.get<Value>());
    if (!llvm::hasSingleElement(szPayloads)) {
      auto diag = this->emitOpError(
          "requires vector size handle that is mapped to 1 payload op");
      diag.attachNote(sz.get<Value>().getLoc())
          << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
      return DiagnosedSilenceableFailure::definiteFailure();
    }

    Operation *szPayloadOp = *szPayloads.begin();
    if (szPayloadOp->getNumResults() != 1 ||
        !szPayloadOp->getResult(0).getType().isIndex()) {
      auto diag = this->emitOpError(
          "requires vector size payload op with 1 index result");
      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
      return DiagnosedSilenceableFailure::definiteFailure();
    }

    IntegerAttr attr;
    if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
      auto diag = this->emitOpError("requires constant vector size");
      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
      return DiagnosedSilenceableFailure::definiteFailure();
    }

    vectorSizes.push_back(attr.getInt());
  }

  // TODO: Check that the correct number of vectorSizes was provided.
  for (Operation *target : targets) {
    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
            target)) {
      return mlir::emitSilenceableFailure(target->getLoc())
             << "Unsupported Op, cannot vectorize";
    }

    if (failed(linalg::vectorize(rewriter, target, vectorSizes,
                                 getScalableSizes(),
                                 getVectorizeNdExtract().has_value()
                                     ? getVectorizeNdExtract().value()
                                     : false))) {
      return mlir::emitSilenceableFailure(target->getLoc())
             << "Attempted to vectorize, but failed";
    }
  }

  return DiagnosedSilenceableFailure::success();
}

void transform::VectorizeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTarget(), effects);
  onlyReadsHandle(getVectorSizes(), effects);
  modifiesPayload(effects);
}

SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
  OpBuilder b(getContext());
  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
}

LogicalResult transform::VectorizeOp::verify() {
  if (getStaticVectorSizes().size() != getScalableSizes().size())
    return emitOpError("expected same number of vector sizes (")
           << getStaticVectorSizes().size() << ") and scalable sizes ("
           << getScalableSizes().size() << ")";
  return success();
}

//===----------------------------------------------------------------------===//
// HoistRedundantVectorTransfersOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::HoistRedundantVectorTransfersOp::applyToOne(
    transform::TransformRewriter &rewriter, func::FuncOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // WARNING: This hoisting does not model parallelism and is generally
  // incorrect when used on distributed loops with memref semantics!
  // TODO: obsolete and should be retired.
  linalg::hoistRedundantVectorTransfers(target);
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// HoistRedundantVectorBroadcastsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::HoistRedundantVectorBroadcastsOp::applyToOne(
    transform::TransformRewriter &rewriter, mlir::Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ConvertConv2DToImg2ColOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  auto maybeTransformed =
      TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
          target)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return rewriteInIm2Col(rewriter, op);
          })
          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
            return rewriteInIm2Col(rewriter, op);
          })
          .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
            return rewriteInIm2Col(rewriter, op);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return rewriteInIm2Col(rewriter, op);
          })
          .Default([&](Operation *op) {
            return rewriter.notifyMatchFailure(op, "not supported");
          });
  if (failed(maybeTransformed))
    return emitDefaultSilenceableFailure(target);
  // Handle to the operation producing the img2col tensor.
  results.push_back(maybeTransformed->first);
  // Handle to the operation that replaces the original convolution.
  results.push_back(maybeTransformed->second);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// FlattenElementwiseLinalgOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  if (!isElementwise(target))
    return mlir::emitSilenceableFailure(target->getLoc())
           << "only elementwise flattening is supported";

  // If rank <= 1, there is nothing to flatten.
  if (target.getNumLoops() <= 1) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }

  // Attempt to flatten all dims to one.
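  // E.g. (illustrative), a 3-D elementwise op is collapsed with reassociation
  // [[0, 1, 2]] into an equivalent 1-D op.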
  ReassociationIndices reassociation(target.getNumLoops());
  std::iota(reassociation.begin(), reassociation.end(), 0);
  auto maybeFlattened =
      collapseOpIterationDims(target, reassociation, rewriter);
  if (failed(maybeFlattened))
    return mlir::emitSilenceableFailure(target->getLoc())
           << "attempted to flatten, but failed";
  results.push_back(maybeFlattened->collapsedOp);
  rewriter.replaceOp(target, maybeFlattened->results);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TransposeConv2DOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  auto maybeTransformed =
      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
            return transposeConv2D(rewriter, op);
          })
          .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
            return transposeConv2D(rewriter, op);
          })
          .Default([&](Operation *op) {
            return rewriter.notifyMatchFailure(op, "not supported");
          });
  if (failed(maybeTransformed))
    return emitDefaultSilenceableFailure(target);
  // Handle to the new Conv2D operation with transposed filters.
  results.push_back(*maybeTransformed);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TransposeMatmulOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
  auto maybeTransformed =
      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
          .Case([&](linalg::MatmulOp op) {
            return transposeMatmul(rewriter, op, transposeLHS);
          })
          .Case([&](linalg::BatchMatmulOp op) {
            return transposeBatchMatmul(rewriter, op, transposeLHS);
          })
          .Default([&](Operation *op) { return failure(); });
  if (failed(maybeTransformed))
    return emitSilenceableFailure(target->getLoc()) << "not supported";
  // Handle to the new matmul operation with the transposed input.
  results.push_back(*maybeTransformed);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
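/// Rewrites an insert_slice-like op whose source is not already the result of
/// a `linalg.copy` into an explicit extract/copy/insert sequence. Sketch of
/// the produced IR (illustrative):
///
///   %slice = tensor.extract_slice %dest[...]
///   %copied = linalg.copy ins(%source) outs(%slice)
///   tensor.insert_slice %copied into %dest[...]
///
/// The handle pushed into `results` points at the `linalg.copy`.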
template <typename OpTy>
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
                                 transform::ApplyToEachResultList &results,
                                 transform::TransformState &state) {
  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
                                tensor::ParallelInsertSliceOp>() &&
                "wrong op type");

  if (auto copySource =
          target.getSource().template getDefiningOp<linalg::CopyOp>()) {
    results.push_back(copySource);
    return DiagnosedSilenceableFailure::success();
  }

  // If we are inside an InParallel region, temporarily set the insertion point
  // outside: only tensor.parallel_insert_slice ops are allowed in there.
  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
    rewriter.setInsertionPoint(
        target->template getParentOfType<scf::InParallelOp>());
  }

  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
      target.getLoc(), target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());
  Value copied = rewriter
                     .create<linalg::CopyOp>(target.getLoc(),
                                             target.getSource(), extracted)
                     .getResult(0);
  // Reset the insertion point.
  rewriter.setInsertionPoint(target);
  rewriter.replaceOpWithNewOp<OpTy>(
      target, copied, target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());

  results.push_back(copied.getDefiningOp());
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *targetOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {

  rewriter.setInsertionPoint(targetOp);
  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
    return doit(rewriter, target, results, state);
  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
    return doit(rewriter, target, results, state);

  DiagnosedSilenceableFailure diag =
      emitSilenceableError()
      << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
  diag.attachNote(targetOp->getLoc()) << "target op";
  return diag;
}

//===----------------------------------------------------------------------===//
// MapCopyToThreadsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Check if the op is supported.
  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "only linalg.copy and tensor.pad target ops are supported";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }
  assert(target->getNumResults() == 1 && "expected single result");
  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
  if (!resultShapedType.hasStaticShape()) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "only statically sized ops of rank <= 3 are supported";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // Conservatively set the minimum viable desired bitwidth alignment.
  int64_t desiredBitAlignment = getDesiredBitAlignment();
  int64_t eltBitwidth =
      resultShapedType.getElementType().getIntOrFloatBitWidth();
  if (desiredBitAlignment % eltBitwidth != 0) {
    desiredBitAlignment = eltBitwidth;
  }
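  // E.g. (illustrative), a desired 128-bit alignment is kept for f32 elements
  // (128 % 32 == 0), whereas an element type whose bitwidth does not divide
  // the request falls back to its own bitwidth.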

  gpu::CopyMappingInfo mapping(
      /*ctx=*/getContext(),
      /*totalNumThreads=*/getTotalNumThreads(),
      /*alignment=*/desiredBitAlignment,
      /*sizes=*/resultShapedType.getShape(),
      /*favorPredication=*/false,
      /*elementalBitwidth=*/
      resultShapedType.getElementType().getIntOrFloatBitWidth());
  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "too few threads to map copy op to threads on the most minor "
           "dimension, given alignment and vector size constraints; try a "
           "smaller tile size or mapping to more threads";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // OpBuilder only used to compute attributes.
  OpBuilder b(getContext());
  linalg::ForallTilingResult tilingResult;
  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
      /*rewriter=*/rewriter,
      /*state=*/state,
      /*transformOp=*/*this,
      /*target=*/target,
      /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
      /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
      /*mapping=*/b.getArrayAttr(mapping.threadMapping),
      /*tilingResult=*/tilingResult);
  if (!diag.succeeded())
    return diag;

  results.push_back(tilingResult.tileOp);
  results.push_back(tilingResult.tiledOp);
  return DiagnosedSilenceableFailure::success();
}

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"