1//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10
11#include "mlir/AsmParser/AsmParser.h"
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
17#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
18#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19#include "mlir/Dialect/Linalg/IR/Linalg.h"
20#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
21#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
22#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
23#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
24#include "mlir/Dialect/Linalg/Utils/Utils.h"
25#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
26#include "mlir/Dialect/Tensor/IR/Tensor.h"
27#include "mlir/Dialect/Tensor/Utils/Utils.h"
28#include "mlir/Dialect/Transform/IR/TransformDialect.h"
29#include "mlir/Dialect/Transform/IR/TransformOps.h"
30#include "mlir/Dialect/Transform/IR/TransformTypes.h"
31#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
32#include "mlir/Dialect/Transform/Utils/Utils.h"
33#include "mlir/Dialect/Utils/IndexingUtils.h"
34#include "mlir/Dialect/Utils/StaticValueUtils.h"
35#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
36#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
37#include "mlir/IR/BuiltinTypeInterfaces.h"
38#include "mlir/IR/PatternMatch.h"
39#include "mlir/IR/TypeUtilities.h"
40#include "mlir/Interfaces/TilingInterface.h"
41#include "mlir/Support/LLVM.h"
42#include "mlir/Support/TypeID.h"
43#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/ScopeExit.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/Debug.h"
48#include <type_traits>
49
50using namespace mlir;
51using namespace mlir::linalg;
52using namespace mlir::transform;
53
54#define DEBUG_TYPE "linalg-transforms"
55#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56#define DBGSNL() (llvm::dbgs() << "\n")
57#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
58
59/// Attempts to apply the pattern specified as template argument to the given
60/// operation. The pattern is expected to have a `returningMatchAndRewrite`
61/// function that returns the "main" result or failure. Returns failure if the
62/// pattern failed to apply. Extra arguments are forwarded to the pattern
63/// constructor.
64template <typename PatternTy, typename... Args>
65static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
66 // Check if the given operation has the type expected by the pattern.
67 using OpTy = typename llvm::function_traits<
68 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69 auto op = dyn_cast<OpTy>(operation);
70 if (!op)
71 return failure();
72
73 // Apply the pattern directly to the op.
74 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
75 // We want to discourage direct use of PatternRewriter in APIs but In this
76 // very specific case, an IRRewriter is not enough.
77 struct TrivialPatternRewriter : public PatternRewriter {
78 public:
79 explicit TrivialPatternRewriter(MLIRContext *context)
80 : PatternRewriter(context) {}
81 };
82 TrivialPatternRewriter rewriter(operation->getContext());
83 rewriter.setInsertionPoint(operation);
84 auto result = pattern.returningMatchAndRewrite(op, rewriter);
85 if (failed(result))
86 return failure();
87 return cast<LinalgOp>(result->getOperation());
88}
89
90/// Assuming that `ofr` is an index attr or a param of index type
91/// or a transform dialect handle mapped to exactly one op
92/// with one index result, return that value.
93static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
94 transform::TransformState &state, TransformOpInterface transformOp,
95 SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
96 for (OpFoldResult ofr : ofrs) {
97 if (auto attr = dyn_cast<Attribute>(Val&: ofr)) {
98 if (!isa<IntegerAttr>(Val: attr))
99 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
100 result.push_back(Elt: ofr);
101 continue;
102 }
103
104 Value transformValue = cast<Value>(Val&: ofr);
105 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
106 ArrayRef<Attribute> params = state.getParams(value: transformValue);
107 if (params.size() != 1)
108 return transformOp.emitDefiniteFailure()
109 << "requires exactly one parameter associated";
110 result.push_back(Elt: params[0]);
111 continue;
112 }
113
114 auto payloadOps = state.getPayloadOps(value: transformValue);
115 if (!llvm::hasSingleElement(C&: payloadOps)) {
116 DiagnosedSilenceableFailure diag =
117 transformOp.emitSilenceableError()
118 << "handle must be mapped to exactly one payload op";
119 diag.attachNote(loc: transformValue.getLoc())
120 << "mapped to " << llvm::range_size(Range&: payloadOps) << " payload ops";
121 return diag;
122 }
123
124 Operation *op = *payloadOps.begin();
125 if (op->getNumResults() != 1 || !op->getResult(idx: 0).getType().isIndex()) {
126 DiagnosedSilenceableFailure diag =
127 transformOp.emitSilenceableError()
128 << "payload op must have exactly 1 index result";
129 diag.attachNote(loc: op->getLoc())
130 << "has " << op->getNumResults() << " results";
131 return diag;
132 }
133 result.push_back(Elt: op->getResult(idx: 0));
134 }
135
136 return DiagnosedSilenceableFailure::success();
137}
138
139// Given a list of params that are index attrs or a list of OpFoldResults
140// that are either index attrs or op handles, return a list of OpFoldResults
141// of index attrs or a list of OpFoldResults where all op handles are
142// replaced with the first (and only) OpResult of that payload op.
143// (There must be exactly one parameter associated with the AnyParamType or
144// one mapped payload op which must have exactly one index result.)
145static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
146 transform::TransformState &state, TransformOpInterface transformOp,
147 SmallVector<OpFoldResult> &result, Value packedHandle) {
148 if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
149 ArrayRef<Attribute> params = state.getParams(value: packedHandle);
150 for (auto param : params) {
151 if (!isa<IntegerAttr>(Val: param))
152 return transformOp.emitDefiniteFailure()
153 << "expected the parameter to be associated with an integer "
154 "attribute";
155 result.push_back(Elt: param);
156 }
157 return DiagnosedSilenceableFailure::success();
158 }
159
160 for (Operation *op : state.getPayloadOps(value: packedHandle)) {
161 if (op->getNumResults() != 1 || !op->getResult(idx: 0).getType().isIndex()) {
162 DiagnosedSilenceableFailure diag =
163 transformOp.emitSilenceableError()
164 << "payload op must have exactly 1 index result";
165 diag.attachNote(loc: op->getLoc())
166 << "has " << op->getNumResults() << " results";
167 return diag;
168 }
169 result.push_back(Elt: op->getResult(idx: 0));
170 }
171
172 return DiagnosedSilenceableFailure::success();
173}
174
175/// When possible, converts each `OpFoldResult` in `mixedResult` to
176/// an integer if the value can be statically inferred. If a result
177/// is a `Value` then it must be either a `ParamType` or a handle
178/// to an a constant like op.
179static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
180 TransformState &state, TransformOpInterface &transformOp,
181 ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
182 for (OpFoldResult paramOrHandle : mixedResults) {
183 if (auto attr = dyn_cast<Attribute>(Val&: paramOrHandle)) {
184 reified.push_back(Elt: cast<IntegerAttr>(attr).getInt());
185 continue;
186 } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
187 ArrayRef<Attribute> params = state.getParams(value: cast<Value>(Val&: paramOrHandle));
188 if (params.size() != 1)
189 return transformOp.emitSilenceableError() << "expected a single param";
190 reified.push_back(
191 Elt: cast<IntegerAttr>(params.front()).getValue().getSExtValue());
192 continue;
193 }
194
195 Value handle = cast<Value>(Val&: paramOrHandle);
196 if (!isa<TransformHandleTypeInterface>(handle.getType()))
197 return transformOp.emitSilenceableError() << "unexpected value handle";
198 auto payload = state.getPayloadOps(value: handle);
199 if (!llvm::hasSingleElement(C&: payload))
200 return transformOp.emitSilenceableError()
201 << "requires param or handle that is mapped to 1 payload op";
202
203 Operation *paramOrHandlePayloadOp = *payload.begin();
204 if (paramOrHandlePayloadOp->getNumResults() != 1 ||
205 !paramOrHandlePayloadOp->getResult(idx: 0).getType().isIndex()) {
206 return transformOp.emitSilenceableError()
207 << "requires param or handle to be result of op with 1 index "
208 "result";
209 }
210
211 IntegerAttr attr;
212 if (!matchPattern(paramOrHandlePayloadOp->getResult(idx: 0), m_Constant(&attr)))
213 return transformOp.emitSilenceableError()
214 << "requires param or handle to be the result of a constant like "
215 "op";
216
217 reified.push_back(Elt: attr.getInt());
218 }
219 return DiagnosedSilenceableFailure::success();
220}
221
222//===----------------------------------------------------------------------===//
223// Apply...PatternsOp
224//===----------------------------------------------------------------------===//
225
226void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
227 RewritePatternSet &patterns) {
228 linalg::populateEraseUnnecessaryInputsPatterns(patterns);
229}
230
231void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
232 RewritePatternSet &patterns) {
233 linalg::populateDecomposePackUnpackPatterns(patterns);
234}
235
236void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
237 RewritePatternSet &patterns) {
238 linalg::populateDecomposePadPatterns(patterns);
239}
240
241void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
242 RewritePatternSet &patterns) {
243 linalg::ControlDropUnitDims options;
244 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
245}
246
247void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
248 RewritePatternSet &patterns) {
249 linalg::ControlDropUnitDims options;
250 options.rankReductionStrategy =
251 linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
252 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
253}
254
255void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
256 RewritePatternSet &patterns) {
257 linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
258}
259
260void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
261 RewritePatternSet &patterns) {
262 linalg::populateFoldAddIntoDestPatterns(patterns);
263}
264
265void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
266 RewritePatternSet &patterns) {
267 linalg::populatePadOpVectorizationPatterns(patterns);
268}
269
270void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
271 RewritePatternSet &patterns) {
272 linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
273}
274
275void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
276 RewritePatternSet &patterns) {
277 linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
278}
279
280//===----------------------------------------------------------------------===//
281// BufferizeToAllocationOp
282//===----------------------------------------------------------------------===//
283
284void transform::BufferizeToAllocationOp::build(OpBuilder &b,
285 OperationState &result,
286 Value target,
287 Attribute memorySpace) {
288 SmallVector<Type> resultTypes;
289 resultTypes.push_back(b.getType<transform::AnyValueType>());
290 resultTypes.push_back(b.getType<transform::AnyOpType>());
291 return build(b, result,
292 /*resultTypes=*/resultTypes,
293 /*target=*/target,
294 /*memorySpace=*/memorySpace);
295}
296
297void transform::BufferizeToAllocationOp::build(OpBuilder &b,
298 OperationState &result,
299 Value target,
300 int64_t memorySpace) {
301 SmallVector<Type> resultTypes;
302 resultTypes.push_back(b.getType<transform::AnyValueType>());
303 resultTypes.push_back(b.getType<transform::AnyOpType>());
304 return build(b, result,
305 /*resultTypes=*/resultTypes,
306 /*target=*/target,
307 /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
308}
309
310namespace {
311class NewOpsListener : public RewriterBase::ForwardingListener {
312public:
313 using RewriterBase::ForwardingListener::ForwardingListener;
314
315 SmallVector<Operation *> getNewOps() const {
316 return SmallVector<Operation *>(newOps.begin(), newOps.end());
317 }
318
319private:
320 void notifyOperationInserted(Operation *op,
321 OpBuilder::InsertPoint previous) override {
322 ForwardingListener::notifyOperationInserted(op, previous);
323 // We only care about newly created ops.
324 if (previous.isSet())
325 return;
326 auto inserted = newOps.insert(V: op);
327 (void)inserted;
328 assert(inserted.second && "expected newly created op");
329 }
330
331 void notifyOperationErased(Operation *op) override {
332 ForwardingListener::notifyOperationErased(op);
333 op->walk(callback: [&](Operation *op) { newOps.erase(V: op); });
334 }
335
336 DenseSet<Operation *> newOps;
337};
338} // namespace
339
340DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
341 transform::TransformRewriter &rewriter,
342 transform::TransformResults &results, transform::TransformState &state) {
343 // Attach listener to keep track of newly created ops.
344 OpBuilder::Listener *previousListener = rewriter.getListener();
345 auto resetListener =
346 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
347 NewOpsListener newOpsListener(previousListener);
348 rewriter.setListener(&newOpsListener);
349
350 linalg::BufferizeToAllocationOptions options;
351 if (getMemcpyOp() == "bufferization.materialize_in_destination") {
352 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
353 MaterializeInDestination;
354 } else if (getMemcpyOp() == "memref.copy") {
355 options.memcpyOp =
356 linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
357 } else if (getMemcpyOp() == "linalg.copy") {
358 options.memcpyOp =
359 linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
360 } else {
361 llvm_unreachable("invalid memcpy op");
362 }
363 if (getAllocOp() == "memref.alloc") {
364 options.allocOp =
365 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
366 } else if (getAllocOp() == "memref.alloca") {
367 options.allocOp =
368 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
369 } else {
370 llvm_unreachable("invalid alloc op");
371 }
372 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
373 options.emitDealloc = getEmitDealloc();
374
375 // Bufferize ops.
376 Attribute memorySpace =
377 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
378 SmallVector<Value> allocatedBuffers;
379 for (Operation *op : state.getPayloadOps(getTarget())) {
380 Value buffer =
381 linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
382 if (!buffer) {
383 DiagnosedSilenceableFailure diag = emitSilenceableError()
384 << "failed to bufferize operation";
385 diag.attachNote(op->getLoc()) << "target payload op";
386 return diag;
387 }
388 allocatedBuffers.push_back(buffer);
389 }
390
391 // Set results.
392 results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
393 results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
394 return DiagnosedSilenceableFailure::success();
395}
396
397void transform::BufferizeToAllocationOp::getEffects(
398 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
399 if (getBufferizeDestinationOnly()) {
400 // The destination is replaced with a newly allocated buffer, but the op
401 // itself remains in place.
402 onlyReadsHandle(getTargetMutable(), effects);
403 } else {
404 consumesHandle(getTargetMutable(), effects);
405 }
406 producesHandle(getOperation()->getOpResults(), effects);
407 modifiesPayload(effects);
408}
409
410LogicalResult transform::BufferizeToAllocationOp::verify() {
411 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
412 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
413 return emitOpError() << "unsupported memcpy op";
414 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
415 return emitOpError() << "unsupported alloc op";
416 return success();
417}
418
419//===----------------------------------------------------------------------===//
420// DecomposeOp
421//===----------------------------------------------------------------------===//
422
423DiagnosedSilenceableFailure
424transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
425 LinalgOp target,
426 transform::ApplyToEachResultList &results,
427 transform::TransformState &state) {
428#define DOWNSCALE(trans) \
429 { \
430 FailureOr<LinalgOp> res = tryApply<trans>(target); \
431 if (succeeded(res)) { \
432 results.push_back(*res); \
433 return DiagnosedSilenceableFailure::success(); \
434 } \
435 }
436
437#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
438#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
439
440 DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
441 DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
442 DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
443 DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
444 DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
445 DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
446 DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
447 DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
448 DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
449 DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
450 DOWNSCALE(DownscaleConv2DOp)
451#undef DOWNSCALE_NORMAL
452#undef DOWNSCALE_CALL
453#undef DOWNSCALE
454 return emitDefaultSilenceableFailure(target);
455}
456
457//===----------------------------------------------------------------------===//
458// DecomposeInterfaceOp
459//===----------------------------------------------------------------------===//
460
461// Decompose the target operation if it implements the AggregatedOpInterface.
462// Push the decomposed operations (the ones that replaces the values produced by
463// \p target) in the `results`.
464DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
465 transform::TransformRewriter &rewriter, Operation *target,
466 transform::ApplyToEachResultList &results,
467 transform::TransformState &state) {
468 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
469 if (!decomposableOp) {
470 failed(rewriter.notifyMatchFailure(target,
471 "payload is not a decomposable op"));
472 return emitDefaultSilenceableFailure(target);
473 }
474
475 FailureOr<SmallVector<Value>> maybeNewResults =
476 decomposableOp.decomposeOperation(rewriter);
477 if (failed(maybeNewResults))
478 return emitDefaultSilenceableFailure(target);
479
480 rewriter.replaceOp(decomposableOp, *maybeNewResults);
481 for (Value val : *maybeNewResults) {
482 Operation *definition = val.getDefiningOp();
483 if (definition)
484 results.push_back(definition);
485 }
486 return DiagnosedSilenceableFailure::success();
487}
488
489//===----------------------------------------------------------------------===//
490// EliminateLinalgOpAnchoredEmptyTensorsOp
491//===----------------------------------------------------------------------===//
492
493void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
494 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
495 onlyReadsHandle(getTargetMutable(), effects);
496 modifiesPayload(effects);
497}
498
499DiagnosedSilenceableFailure
500transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
501 transform::TransformRewriter &rewriter, TransformResults &transformResults,
502 TransformState &state) {
503 bufferization::OneShotBufferizationOptions options;
504 options.allowReturnAllocsFromLoops = true;
505
506 for (Operation *target : state.getPayloadOps(getTarget())) {
507 bufferization::OneShotAnalysisState state(target, options);
508 if (failed(analyzeOp(target, state)))
509 return mlir::emitSilenceableFailure(target->getLoc())
510 << "failed to analyze op";
511 if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
512 rewriter, target, state)))
513 return mlir::emitSilenceableFailure(target->getLoc())
514 << "failed to eliminate LinalgOp anchored tensor.empty ops";
515 }
516 return DiagnosedSilenceableFailure::success();
517}
518
519//===----------------------------------------------------------------------===//
520// FuseOp
521//===----------------------------------------------------------------------===//
522
523/// Apply a tiling transformation to all payload ops and store both the
524/// tiled operation as well as the created tile loops.
525template <typename Range>
526static LogicalResult applyTilingToAll(
527 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
528 unsigned numLoops, transform::TransformResults &transformResults,
529 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
530 applyFn) {
531 SmallVector<Operation *> tiledLinalgOps;
532 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
533
534 for (Operation *target : payloadOps) {
535 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
536 if (!tilingInterfaceOp)
537 return transformOp->emitError(message: "only TilingInterface ops are supported");
538
539 rewriter.setInsertionPoint(target);
540 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
541 applyFn(tilingInterfaceOp);
542 if (failed(tiledResults))
543 return failure();
544
545 // Perform the replacement of tiled and fused values.
546 SmallVector<Operation *> opsToReplace{target};
547 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
548 for (Operation *toReplace : opsToReplace) {
549 for (OpResult res : toReplace->getResults())
550 if (auto replacement = tiledResults->replacements.lookup(res))
551 rewriter.replaceAllUsesWith(res, replacement);
552 if (toReplace->use_empty()) {
553 rewriter.eraseOp(op: toReplace);
554 }
555 }
556
557 // Report back the relevant handles to the transform op.
558 tiledLinalgOps.push_back(Elt: tiledResults->tiledAndFusedOps.front());
559 assert(tiledResults->loops.size() == numLoops &&
560 "Mismatched number of loops, tile and fuse transform should have "
561 "failed");
562 for (unsigned int i = 0; i < numLoops; ++i)
563 loopOps[i].push_back(Elt: tiledResults->loops[i]);
564 }
565
566 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledLinalgOps);
567 for (unsigned int i = 0; i < numLoops; ++i)
568 transformResults.set(value: transformOp->getOpResult(idx: i + 1), ops&: loopOps[i]);
569
570 return success();
571}
572
573DiagnosedSilenceableFailure
574transform::FuseOp::apply(transform::TransformRewriter &rewriter,
575 mlir::transform::TransformResults &transformResults,
576 mlir::transform::TransformState &state) {
577 SmallVector<int64_t> tileSizes =
578 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
579 SmallVector<int64_t> tileInterchange =
580 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
581
582 scf::SCFTilingOptions tilingOptions;
583 tilingOptions.interchangeVector = tileInterchange;
584 SmallVector<OpFoldResult> tileSizesOfr =
585 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
586 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
587 scf::SCFTileAndFuseOptions tileAndFuseOptions;
588 tileAndFuseOptions.tilingOptions = tilingOptions;
589
590 if (getApplyCleanup()) {
591 MLIRContext *context = rewriter.getContext();
592 RewritePatternSet patterns(context);
593 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
594 tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
595 tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
596 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
597 }
598
599 LogicalResult result = applyTilingToAll(
600 rewriter, getOperation(), state.getPayloadOps(getTarget()),
601 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
602 [&](TilingInterface tilingInterfaceOp)
603 -> FailureOr<scf::SCFTileAndFuseResult> {
604 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
605 tileAndFuseOptions);
606 });
607 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
608 : DiagnosedSilenceableFailure::success();
609}
610
611LogicalResult transform::FuseOp::verify() {
612 SmallVector<int64_t> permutation =
613 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
614 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
615 if (!std::is_permutation(sequence.begin(), sequence.end(),
616 permutation.begin(), permutation.end())) {
617 return emitOpError() << "expects interchange to be a permutation, found "
618 << getTileInterchange();
619 }
620
621 SmallVector<int64_t> sizes =
622 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
623 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
624 if (numExpectedLoops != getNumResults() - 1)
625 return emitOpError() << "expects " << numExpectedLoops << " loop results";
626
627 return success();
628}
629
630//===----------------------------------------------------------------------===//
631// FuseIntoContainingOp
632//===----------------------------------------------------------------------===//
633
634void transform::FuseIntoContainingOp::build(OpBuilder &builder,
635 OperationState &result,
636 Value producerOp,
637 Value containingOp) {
638 result.addOperands({producerOp, containingOp});
639 auto resultType = transform::AnyOpType::get(builder.getContext());
640 result.addTypes({resultType, resultType});
641}
642
643/// Add new operands to the forall op for users of the producerOp
644/// that are dominated by the containing scf.forall op.
645static Operation *replaceForAllWithNewSignature(
646 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
647 Operation *containingOp, TilingResult &tileAndFuseResult,
648 int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
649 SmallVector<OpFoldResult> &sizes) {
650
651 // Count number of users not including the containing op
652 SetVector<Operation *> dominatedUsers;
653 DominanceInfo domInfo(containingOp);
654 for (Operation *user : producerOp->getResult(idx: resultNumber).getUsers()) {
655 if (!containingOp->isAncestor(other: user) &&
656 (domInfo.dominates(a: containingOp, b: user))) {
657 dominatedUsers.insert(X: user);
658 }
659 }
660 if (dominatedUsers.empty())
661 return nullptr;
662
663 // Create new scf.forall op
664 auto forallOp = cast<scf::ForallOp>(containingOp);
665 OpBuilder::InsertionGuard g(rewriter);
666 rewriter.setInsertionPoint(forallOp);
667
668 // Get new output
669 Location loc = forallOp.getLoc();
670 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
671 if (!genericOp)
672 return nullptr;
673 SmallVector<Value> outputs = genericOp.getOutputs();
674 SmallVector<Value> newOuts(forallOp.getOutputs());
675 newOuts.push_back(Elt: outputs[resultNumber]);
676
677 // Create new scf.forall op
678 auto newforallOp = rewriter.create<scf::ForallOp>(
679 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
680 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
681 rewriter.eraseBlock(block: newforallOp.getBody());
682 newforallOp.getRegion().takeBody(forallOp.getRegion());
683
684 // Add additional block argument for new value being returned
685 // and replaces all uses of the new output with corresponding bbArg
686 // inside the scf.forall to enable fusion into this new scf.forall.
687 newforallOp.getBody()->addArgument(newOuts.back().getType(),
688 newOuts.back().getLoc());
689 auto bbArgs = newforallOp.getBody()->getArguments();
690 rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
691 [&](OpOperand &use) {
692 Operation *op = use.getOwner();
693 return newforallOp->isProperAncestor(op);
694 });
695
696 // Fix terminator
697 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
698 SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
699 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
700 Operation *firstYieldOp = yieldingOps.front();
701 rewriter.setInsertionPoint(firstYieldOp);
702 Value src = tileAndFuseResult.tiledValues[0];
703 Value dst = newforallOp.getRegionIterArgs().back();
704 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
705 rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
706 dst, offsets, sizes, strides);
707
708 for (auto result : llvm::enumerate(forallOp.getResults())) {
709 rewriter.replaceAllUsesWith(result.value(),
710 newforallOp->getResult(result.index()));
711 }
712 rewriter.replaceUsesWithIf(producerOp->getResult(idx: resultNumber),
713 newforallOp->getResults().back(),
714 [&](OpOperand &use) {
715 Operation *user = use.getOwner();
716 return dominatedUsers.contains(key: user);
717 });
718 return newforallOp;
719}
720
721/// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
722/// if the operand 'src' is equal to 'dst' or equal to a iter arg present in a
723/// outer loop. To determine the second condition, this function iterates
724/// using a worklist over the enclosing loops, trying to find 'src' in any of
725/// the parent loop's iter args.
726static bool sameOrEquivalentIterArg(Value src, Value dst) {
727 // Stack like vector containing possible iterArgs candidates. The first one
728 // is dst, and we will transverse the IR from there.
729 SmallVector<Value> destWorklist;
730 destWorklist.push_back(Elt: dst);
731
732 while (!destWorklist.empty()) {
733 Value currentDst = destWorklist.pop_back_val();
734
735 // We have found the same operand in some iter arg in the loop structure,
736 // so src and dst are equivalent.
737 if (src == currentDst)
738 return true;
739
740 // The operands are not equivalent, look for enclosing loops over
741 // currentDst.
742 auto bbArg = dyn_cast<BlockArgument>(Val&: currentDst);
743 if (!bbArg)
744 continue;
745
746 Block *parentBlock = bbArg.getOwner();
747 assert(parentBlock && "unlinked block argument");
748
749 Operation *parentOp = parentBlock->getParentOp();
750 assert(parentOp && "expected block argument with parent operation");
751
752 // Check if parent is loop-like. If it's not, do not add it to the worklist.
753 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
754 if (!parentLoop)
755 continue;
756
757 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
758 // No need to check for null as innerIterArg is tied to parentLoop.
759 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
760 Value loopBlockArgument =
761 parentLoop->getOperand(operand->getOperandNumber());
762 destWorklist.push_back(loopBlockArgument);
763 }
764 }
765
766 return false;
767}
768
769/// Find the first "extract" user of `producerOp` and tile it right before its
770/// use. The tiled op is fused under the `containingOp`.
771/// Return this fused op on success or nullptr if anything fails.
772/// If tiled op has uses that are dominated by `containingOp`, return
773/// a new `containingOp` with results of the fused op appended to
774/// results of the `containingOp` or nullptr if there are no dominated uses.
775static std::tuple<SmallVector<Operation *>, Operation *>
776tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
777 Operation *producerOp, Operation *containingOp) {
778 LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
779 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
780 if (!tileableProducer) {
781 diag.attachNote(noteLoc: producerOp->getLoc())
782 << "producer is not a TileableInterface: " << *producerOp;
783 return {};
784 }
785
786 // Search the producer slices accessed within the containing operation.
787 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
788 // evolve into an interface.
789 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
790 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
791 return sliceOp && containingOp->isProperAncestor(sliceOp);
792 });
793
794 // Find a fusion opportunity.
795 if (it == tileableProducer->getUsers().end()) {
796 diag.attachNote(noteLoc: tileableProducer->getLoc())
797 << "could not find fusion opportunity for: " << *tileableProducer;
798 return {};
799 }
800 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
801
802 // Try to fuse the producer in-place.
803 OpBuilder::InsertionGuard guard(rewriter);
804 rewriter.setInsertionPoint(sliceOpToTile);
805
806 // Clone the producer inside the consumer and try to update the producer init
807 // operands using the loop bbArgs if applicable. More precisely, if the bbArg
808 // of the container loop points to a value that it is used by the consumer op,
809 // then, instead of using such value on the consumer, use the value coming
810 // from the bbArg instead. This allows to reuse the output tensor (instead of
811 // creating a new one) of the container when both producer and container write
812 // to the same output.
813 if (LoopLikeOpInterface containerLoop =
814 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
815 Operation *clone = rewriter.clone(op&: *producerOp);
816 rewriter.modifyOpInPlace(root: clone, callable: [&]() {
817 // Iterate over the outputs of the producer and over the loop bbArgs and
818 // check if any bbArg points to the same value as the producer output. In
819 // such case, make the producer output point to the bbArg directly.
820 for (OpOperand &initOperandPtr :
821 cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
822 Value producerOperand =
823 clone->getOperand(initOperandPtr.getOperandNumber());
824 for (BlockArgument containerIterArg :
825 containerLoop.getRegionIterArgs()) {
826 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
827 Value consumerOperand =
828 containerLoop->getOperand(bbArg->getOperandNumber());
829 // The producer has the same init as the loop bbArg, use it.
830 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
831 initOperandPtr.set(containerIterArg);
832 }
833 }
834 }
835 });
836
837 tileableProducer = dyn_cast<TilingInterface>(clone);
838 }
839
840 // Tile the producer.
841 int64_t resultNumber =
842 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
843 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
844
845 SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
846 SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
847
848 FailureOr<TilingResult> tileAndFuseResult =
849 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
850 sizes);
851
852 if (failed(Result: tileAndFuseResult)) {
853 diag.attachNote(noteLoc: tileableProducer->getLoc())
854 << "failed to tile producer op: " << *tileableProducer;
855 return {};
856 }
857
858#ifndef NDEBUG
859 for (auto *tiledOp : tileAndFuseResult->tiledOps) {
860 LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
861 }
862#endif
863
864 // Replace the extract op.
865 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
866 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
867 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
868 if (failed(maybeRankReduced)) {
869 diag.attachNote(noteLoc: producerOp->getLoc())
870 << "shape types don't match (missing canonicalization?):\nTiledOp: "
871 << tileAndFuseResult->tiledValues[0]
872 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
873 return {};
874 }
875 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
876
877 // Add new outputs to containing op, if required
878 Operation *newContainingOp = replaceForAllWithNewSignature(
879 rewriter, diag, producerOp, containingOp, tileAndFuseResult&: *tileAndFuseResult,
880 resultNumber, offsets, sizes);
881
882 // Cleanup clone.
883 if (dyn_cast<LoopLikeOpInterface>(containingOp))
884 rewriter.eraseOp(op: tileableProducer);
885
886 return std::make_tuple(args&: tileAndFuseResult->tiledOps, args&: newContainingOp);
887}
888
889/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
890/// it is exactly the `containingOp`, otherwise bail.
891/// Then, find the first "extract" user of the tied block argument and tile it
892/// right before its "extract" use. The tiled op is fused under the
893/// `containingOp`.
894/// Return this fused op on success or nullptr if anything fails.
895static SmallVector<Operation *>
896tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
897 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
898 Operation *containingOp) {
899 LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
900
901 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
902 if (!tileableProducer) {
903 diag.attachNote(noteLoc: producerOp->getLoc())
904 << "producer is not a TileableInterface: " << *producerOp;
905 return {};
906 }
907
908 // Search the first use by a "scf::ForallOp" user.
909 scf::ForallOp forallOp;
910 auto itProducerUses =
911 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
912 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
913 return forallOp;
914 });
915 // If it's not from the containing op, return.
916 if (!forallOp || forallOp != containingOp) {
917 diag.attachNote(noteLoc: tileableProducer->getLoc())
918 << "could not find a use by the containing op: " << *tileableProducer;
919 return {};
920 }
921
922 // Search the producer slices accessed within the containing
923 // operation.
924 // TODO: Generalize to more extract/insert/parallel_insert triples.
925 // Maybe evolve into an interface.
926 OpOperand *pUse = &(*itProducerUses);
927 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
928
929 // Search the producer slices accessed within the containing operation.
930 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
931 // evolve into an interface.
932 auto itBBArgUsers = llvm::find_if(Range: bbArg.getUsers(), P: [&](Operation *user) {
933 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
934 return sliceOp && containingOp->isProperAncestor(sliceOp);
935 });
936
937 // Find a fusion opportunity.
938 if (itBBArgUsers == bbArg.getUsers().end()) {
939 diag.attachNote(noteLoc: containingOp->getLoc())
940 << "could not find fusion opportunity for bbArg: " << bbArg;
941 return {};
942 }
943 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
944
945 // Try to fuse the producer in-place.
946 OpBuilder::InsertionGuard guard(rewriter);
947 rewriter.setInsertionPoint(sliceOpToTile);
948
949 // Replace the use in the tileableProducer before tiling: clone, replace and
950 // then tile.
951 int64_t resultNumber = cast<OpResult>(Val: pUse->get()).getResultNumber();
952 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
953
954 // Gather destination tensors.
955 SmallVector<Value> destinationTensors;
956 if (failed(tensor::getOrCreateDestinations(
957 b&: rewriter, loc: tileableProducer->getLoc(), op: tileableProducer,
958 result&: destinationTensors))) {
959 diag.attachNote(noteLoc: tileableProducer->getLoc())
960 << "failed to get destination tensors for: " << *tileableProducer;
961 return {};
962 }
963
964 IRMapping bvm;
965 bvm.map(from: destinationTensors[resultNumber], to: bbArg);
966 auto tileableProducerClone =
967 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
968 auto scopeGuard =
969 llvm::make_scope_exit(F: [&]() { rewriter.eraseOp(op: tileableProducerClone); });
970
971 // Tile the producer.
972 FailureOr<TilingResult> tileAndFuseResult =
973 tileableProducerClone.generateResultTileValue(
974 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
975 sliceOpToTile.getMixedSizes());
976 if (failed(Result: tileAndFuseResult)) {
977 diag.attachNote(noteLoc: tileableProducer->getLoc())
978 << "failed to tile producer op: " << *tileableProducer;
979 return {};
980 }
981
982 // Replace the extract op.
983 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
984 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
985 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
986 assert(succeeded(maybeRankReduced) && "unexpected shape");
987 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
988
989 // Replace the use in containingOp.
990 rewriter.modifyOpInPlace(root: containingOp, callable: [&]() {
991 containingOp->setOperand(idx: pUse->getOperandNumber(),
992 value: destinationTensors.front());
993 });
994
995 return tileAndFuseResult->tiledOps;
996}
997
998static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
999 Operation *producerOp,
1000 Operation *containingOp) {
1001 LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
1002
1003 // Gather all uses inside the containing op.
1004 SmallVector<OpOperand *> uses;
1005 for (OpResult result : producerOp->getOpResults()) {
1006 for (OpOperand &use : result.getUses()) {
1007 if (containingOp->isProperAncestor(other: use.getOwner())) {
1008 uses.push_back(Elt: &use);
1009 continue;
1010 }
1011 // Cannot clone and fuse if the use is by the containing op itself: fail
1012 // immediately.
1013 if (containingOp == use.getOwner()) {
1014 diag.attachNote(noteLoc: producerOp->getLoc())
1015 << "producer op use by containing op cannot be fused by cloning";
1016 return nullptr;
1017 }
1018 }
1019 }
1020
1021 // Check for a non-empty list of fusion opportunities.
1022 if (uses.empty()) {
1023 diag.attachNote(noteLoc: producerOp->getLoc()) << "no fusion opportunity by cloning";
1024 return nullptr;
1025 }
1026
1027 // Clone and fuse inside the containing op.
1028 Operation *fusedOp = nullptr;
1029 OpOperand *use = uses.front();
1030 // Parallel insert slice is not a valid clone destination.
1031 // TODO: Generalize to other type of ops.
1032 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1033 "Parallel insert slice is not a valid clone destination");
1034 unsigned resultNumber = cast<OpResult>(Val: use->get()).getResultNumber();
1035 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
1036
1037 OpBuilder::InsertionGuard guard(rewriter);
1038 rewriter.setInsertionPoint(use->getOwner());
1039 fusedOp = rewriter.clone(op&: *producerOp);
1040 rewriter.modifyOpInPlace(
1041 root: use->getOwner(), callable: [&] { use->set(fusedOp->getOpResult(idx: resultNumber)); });
1042
1043 return fusedOp;
1044}
1045
1046bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1047 // Allow repeated handles since we are fusing everything anyway.
1048 return true;
1049}
1050
1051DiagnosedSilenceableFailure
1052transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1053 transform::TransformResults &results,
1054 transform::TransformState &state) {
1055 SmallVector<Operation *> fusedOps;
1056 auto producerOps = state.getPayloadOps(getProducerOp());
1057 auto containingOps = state.getPayloadOps(getContainingOp());
1058 if (!llvm::hasSingleElement(containingOps)) {
1059 return emitDefiniteFailure()
1060 << "requires exactly one containing_op handle (got "
1061 << llvm::range_size(containingOps) << ")";
1062 }
1063 Operation *containingOp = *containingOps.begin();
1064
1065 // If nothing to fuse, propagate success.
1066 if (std::empty(producerOps)) {
1067 results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1068 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1069 return DiagnosedSilenceableFailure::success();
1070 }
1071
1072 // Helper function to find the next producer that should be fused. Take any
1073 // producer that has a use inside the containing op.
1074 SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1075 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1076 for (const auto &it : enumerate(remainingProducers)) {
1077 Operation *producerOp = it.value();
1078 // The containing op may be a user of producerOp: use isAncestor.
1079 int64_t numUsesInContainingOp =
1080 llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1081 return containingOp->isAncestor(op);
1082 });
1083 // TODO: When resolving the TODO below (no duplicate ops), take an op
1084 // that has no use among the remaining producers. This is a topological
1085 // sorting.
1086 if (numUsesInContainingOp > 0) {
1087 if (numUsesInContainingOp == 1)
1088 remainingProducers.erase(remainingProducers.begin() + it.index());
1089 return producerOp;
1090 }
1091 }
1092 return failure();
1093 };
1094
1095 while (!remainingProducers.empty()) {
1096 auto nextProducer = getNextProducer();
1097 if (failed(nextProducer)) {
1098 auto diag = mlir::emitSilenceableFailure(getLoc())
1099 << "could not find next producer to fuse into container";
1100 diag.attachNote(containingOp->getLoc()) << "containing op";
1101 return diag;
1102 }
1103
1104 Operation *producerOp = *nextProducer;
1105
1106 // Default diagnostic, to be complemented with more failure information.
1107 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
1108 diag << "could not fuse " << *producerOp << " into " << *containingOp;
1109
1110 // TODO: If there are multiple uses of the producer in the containing op,
1111 // we currently tile/clone the op multiple times (once per use). In some
1112 // cases, we can tile/clone once and reuse the value for each use.
1113 // Futhermore, producers should then be traversed according to a
1114 // topological sorting.
1115 auto [tiledOps, newContainingOp] =
1116 tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1117 if (!tiledOps.empty()) {
1118 LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
1119 fusedOps.append(tiledOps);
1120 if (newContainingOp) {
1121 // Update handles associated with the containing op so we don't need to
1122 // invalidate them. This is a hack to support better composability
1123 // between tiling and fusion while a proper mechanism is being
1124 // investigated.
1125 //
1126 // DO NOT replicate this elsewhere unless you understand what you are
1127 // doing.
1128 LogicalResult replacementStatus =
1129 rewriter.notifyPayloadOperationReplaced(containingOp,
1130 newContainingOp);
1131 (void)replacementStatus;
1132 assert(succeeded(replacementStatus) &&
1133 "unable to update transform state mapping");
1134 rewriter.eraseOp(containingOp);
1135 containingOp = newContainingOp;
1136 }
1137 continue;
1138 }
1139
1140 SmallVector<Operation *> tiledContainingOpOperand =
1141 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1142 rewriter, diag, producerOp, containingOp);
1143 if (!tiledContainingOpOperand.empty()) {
1144 LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
1145 << *containingOp);
1146 fusedOps.append(tiledContainingOpOperand);
1147 continue;
1148 }
1149
1150 Operation *cloned =
1151 cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1152 if (cloned) {
1153 LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
1154 fusedOps.push_back(cloned);
1155 continue;
1156 }
1157 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1158 }
1159
1160 results.set(cast<OpResult>(getFusedOp()), fusedOps);
1161 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1162 return DiagnosedSilenceableFailure::success();
1163}
1164
1165void transform::FuseIntoContainingOp::getEffects(
1166 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1167 consumesHandle(getProducerOpMutable(), effects);
1168 onlyReadsHandle(getContainingOpMutable(), effects);
1169 producesHandle(getOperation()->getOpResults(), effects);
1170 modifiesPayload(effects);
1171}
1172
1173//===----------------------------------------------------------------------===//
1174// GeneralizeOp
1175//===----------------------------------------------------------------------===//
1176
1177DiagnosedSilenceableFailure
1178transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1179 LinalgOp target,
1180 transform::ApplyToEachResultList &results,
1181 transform::TransformState &state) {
1182 // Exit early if no transformation is needed.
1183 if (isa<GenericOp>(target)) {
1184 results.push_back(target);
1185 return DiagnosedSilenceableFailure::success();
1186 }
1187 rewriter.setInsertionPoint(target);
1188 FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1189 if (succeeded(generic)) {
1190 results.push_back(generic->getOperation());
1191 return DiagnosedSilenceableFailure::success();
1192 }
1193 return emitDefaultSilenceableFailure(target);
1194}
1195
1196//===----------------------------------------------------------------------===//
1197// SpecializeOp
1198//===----------------------------------------------------------------------===/
1199
1200DiagnosedSilenceableFailure
1201transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1202 LinalgOp target,
1203 transform::ApplyToEachResultList &results,
1204 transform::TransformState &state) {
1205 // Exit early if the operation is not a generic.
1206 if (!isa<GenericOp>(target)) {
1207 results.push_back(target);
1208 return DiagnosedSilenceableFailure::success();
1209 }
1210 rewriter.setInsertionPoint(target);
1211 FailureOr<LinalgOp> named =
1212 specializeGenericOp(rewriter, cast<GenericOp>(target));
1213 if (succeeded(named)) {
1214 results.push_back(named->getOperation());
1215 return DiagnosedSilenceableFailure::success();
1216 }
1217 return emitDefaultSilenceableFailure(target);
1218}
1219
1220//===----------------------------------------------------------------------===//
1221// InterchangeOp
1222//===----------------------------------------------------------------------===//
1223
1224DiagnosedSilenceableFailure
1225transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1226 GenericOp target,
1227 transform::ApplyToEachResultList &results,
1228 transform::TransformState &state) {
1229 ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1230 // Exit early if no transformation is needed.
1231 if (interchangeVector.empty()) {
1232 results.push_back(target);
1233 return DiagnosedSilenceableFailure::success();
1234 }
1235
1236 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1237 if (interchangeVector.size() != numLoops) {
1238 return emitSilenceableError()
1239 << getIteratorInterchangeAttrName() << " has length ("
1240 << interchangeVector.size()
1241 << ") different from the number of loops in the target operation ("
1242 << numLoops << ")";
1243 }
1244 FailureOr<GenericOp> res = interchangeGenericOp(
1245 rewriter, target, SmallVector<unsigned>(interchangeVector));
1246 if (failed(res))
1247 return emitDefiniteFailure() << "failed to apply";
1248 results.push_back(res->getOperation());
1249 return DiagnosedSilenceableFailure::success();
1250}
1251
1252LogicalResult transform::InterchangeOp::verify() {
1253 ArrayRef<int64_t> permutation = getIteratorInterchange();
1254 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1255 if (!std::is_permutation(sequence.begin(), sequence.end(),
1256 permutation.begin(), permutation.end())) {
1257 return emitOpError()
1258 << "expects iterator_interchange to be a permutation, found "
1259 << getIteratorInterchange();
1260 }
1261 return success();
1262}
1263
1264//===----------------------------------------------------------------------===//
1265// LinalgCopyToMemrefOp
1266//===----------------------------------------------------------------------===//
1267
1268DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1269 transform::TransformRewriter &rewriter, Operation *targetOp,
1270 transform::ApplyToEachResultList &results,
1271 transform::TransformState &state) {
1272
1273 // Check if the target can be converted.
1274 if (!isa<linalg::CopyOp>(targetOp)) {
1275 DiagnosedSilenceableFailure diag =
1276 emitSilenceableError() << "only linalg.copy target ops are supported";
1277 diag.attachNote(targetOp->getLoc()) << "target op";
1278 return diag;
1279 }
1280
1281 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1282 if (!copyOp.hasPureBufferSemantics()) {
1283 DiagnosedSilenceableFailure diag =
1284 emitSilenceableError()
1285 << "cannot transform a linalg.copy on tensors into a memref.copy";
1286 diag.attachNote(targetOp->getLoc()) << "target op";
1287 return diag;
1288 }
1289
1290 SmallVector<Value> inputs = copyOp.getInputs();
1291 SmallVector<Value> outputs = copyOp.getOutputs();
1292 assert(inputs.size() == 1 && "expected linalg copy op with one input");
1293 assert(outputs.size() == 1 && "expected memref copy op with one output");
1294 Value input = inputs.front();
1295 Value output = outputs.front();
1296
1297 // linalg.copy supports different element types on source/dest whereas
1298 // memref.copy does not, so we must check that the source and dest types can
1299 // be handled by memref.copy and otherwise reject the transformation.
1300 if (!isa<ShapedType>(input.getType())) {
1301 DiagnosedSilenceableFailure diag =
1302 emitSilenceableError()
1303 << "cannot transform a linalg.copy which input has no shape";
1304 diag.attachNote(targetOp->getLoc()) << "target op";
1305 return diag;
1306 }
1307
1308 // linalg.copy destination must be a shaped type.
1309 assert(isa<ShapedType>(output.getType()));
1310
1311 if (cast<ShapedType>(input.getType()).getElementType() !=
1312 cast<ShapedType>(output.getType()).getElementType()) {
1313 DiagnosedSilenceableFailure diag =
1314 emitSilenceableError()
1315 << "cannot transform a linalg.copy with different source and "
1316 "destination element types ";
1317 diag.attachNote(targetOp->getLoc()) << "target op";
1318 return diag;
1319 }
1320
1321 // Target can be converted, do it.
1322 auto memrefCopyOp =
1323 rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1324
1325 results.push_back(memrefCopyOp);
1326 return DiagnosedSilenceableFailure::success();
1327}
1328
1329//===----------------------------------------------------------------------===//
1330// LowerPackOp
1331//===----------------------------------------------------------------------===//
1332
1333DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1334 transform::TransformRewriter &rewriter, linalg::PackOp target,
1335 transform::ApplyToEachResultList &transformResults,
1336 transform::TransformState &state) {
1337 rewriter.setInsertionPoint(target);
1338 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1339 FailureOr<LowerPackResult> res =
1340 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1341 if (failed(res)) {
1342 return mlir::emitSilenceableFailure(target->getLoc())
1343 << "cannot lower to pad + expand + transpose";
1344 }
1345 transformResults.push_back(res->padOp);
1346 transformResults.push_back(res->expandShapeOp);
1347 transformResults.push_back(res->transposeOp);
1348 return DiagnosedSilenceableFailure::success();
1349}
1350
1351//===----------------------------------------------------------------------===//
1352// LowerUnPackOp
1353//===----------------------------------------------------------------------===//
1354
1355DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1356 transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1357 transform::ApplyToEachResultList &transformResults,
1358 transform::TransformState &state) {
1359 rewriter.setInsertionPoint(target);
1360 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1361 FailureOr<LowerUnPackOpResult> res =
1362 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1363 if (failed(res)) {
1364 DiagnosedSilenceableFailure diag =
1365 emitSilenceableError()
1366 << "cannot lower to transpose + collapse + extract";
1367 diag.attachNote(target->getLoc()) << "target payload op";
1368 return diag;
1369 }
1370 transformResults.push_back(res->emptyOp);
1371 transformResults.push_back(res->transposeOp);
1372 transformResults.push_back(res->collapseShapeOp);
1373 transformResults.push_back(res->extractSliceOp);
1374 return DiagnosedSilenceableFailure::success();
1375}
1376
1377//===---------------------------------------------------------------------===//
1378// MatchOp
1379//===---------------------------------------------------------------------===//
1380
1381void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1382 Value target, ArrayRef<StringRef> opNames) {
1383 result.addOperands(target);
1384 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1385 builder.getStrArrayAttr(opNames));
1386 result.addTypes(transform::AnyOpType::get(builder.getContext()));
1387}
1388
1389void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1390 TypeRange resultTypes, Value target,
1391 ArrayRef<StringRef> opNames) {
1392 result.addOperands(target);
1393 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1394 builder.getStrArrayAttr(opNames));
1395 result.addTypes(resultTypes);
1396}
1397
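// Walks the single payload op associated with the target handle and collects
// every nested op that satisfies all of the specified filters (op names,
// interface, attributes, result type, operand types). Illustrative sketch,
// with hypothetical payload handle names:
//
//   %matmuls = transform.structured.match ops{["linalg.matmul"]} in %module
//       : (!transform.any_op) -> !transform.any_op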
1398DiagnosedSilenceableFailure
1399transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1400 transform::TransformResults &results,
1401 transform::TransformState &state) {
1402 llvm::StringSet<> strs;
1403 if (getOps().has_value())
1404 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1405
1406 auto payloadOps = state.getPayloadOps(getTarget());
1407 if (!llvm::hasSingleElement(payloadOps)) {
1408 return emitDefiniteFailure("requires exactly one target handle");
1409 }
1410
1411 SmallVector<Operation *> res;
1412 bool incorrectNumOperandTypes = false;
1413 auto matchFun = [&](Operation *op) {
1414 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1415 return;
1416
1417 // Interfaces cannot be matched by name, just by ID.
1418 // So we specifically encode the interfaces we care about for this op.
1419 if (getInterface().has_value()) {
1420 auto iface = getInterface().value();
1421 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1422 !isa<LinalgOp>(op))
1423 return;
1424 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1425 !isa<TilingInterface>(op))
1426 return;
1427 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1428 !isa<LoopLikeOpInterface>(op))
1429 return;
1430 }
1431
1432 // Check if all specified attributes match.
1433 if (getOpAttrs().has_value()) {
1434 DictionaryAttr opAttrs = getOpAttrs().value();
1435 for (NamedAttribute attr : opAttrs) {
1436 if (attr.getName() == getInterfaceAttrName() ||
1437 attr.getName() == getOpsAttrName())
1438 continue;
1439 if (!op->hasAttr(attr.getName()))
1440 return;
1441 if (op->getAttr(attr.getName()) != attr.getValue())
1442 return;
1443 }
1444 }
1445
1446 if (getFilterResultType().has_value()) {
1447 Type t = getFilterResultType().value();
1448 if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1449 return;
1450 }
1451
1452 if (getFilterOperandTypes().has_value()) {
1453 mlir::ArrayAttr types = getFilterOperandTypes().value();
1454 auto operandTypes = op->getOperandTypes();
1455
1456 if (types.size() == 1) {
        // All the operands must be equal to the specified type.
1458 auto typeattr =
1459 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1460 Type t = cast<::mlir::Type>(typeattr.getValue());
1461 if (!llvm::all_of(op->getOperandTypes(),
1462 [&](Type operandType) { return operandType == t; }))
1463 return;
1464 } else {
        // The operand types must match all the types in the list (in the same
        // order in which they are specified).
1467 if (types.size() != operandTypes.size()) {
1468 incorrectNumOperandTypes = true;
1469 return;
1470 }
1471
1472 for (auto [attr, operandType] :
1473 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1474 auto typeattr = cast<mlir::TypeAttr>(attr);
1475 Type type = cast<::mlir::Type>(typeattr.getValue());
1476
1477 if (type != operandType)
1478 return;
1479 }
1480 }
1481 }
1482
1483 // All constraints are satisfied.
1484 res.push_back(op);
1485 return;
1486 };
1487
1488 (*payloadOps.begin())->walk(matchFun);
1489 if (incorrectNumOperandTypes)
    return emitDefiniteFailure("if filter_operand_types contains more than "
                               "one type, then it must contain as many types "
                               "as the number of operands in the target ops");
1493 results.set(cast<OpResult>(getResult()), res);
1494 return DiagnosedSilenceableFailure::success();
1495}
1496
1497//===---------------------------------------------------------------------===//
1498// MultiTileSizesOp
1499//===---------------------------------------------------------------------===//
1500
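// Custom assembly helpers for MultiTileSizesOp: all three results are
// required to share one type (enforced by the op verifier below), so the
// printer emits a single functional type `(target) -> result` and the parser
// assigns its result type to all three results.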
1501static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1502 Type targetType, Type lowSizeType, Type,
1503 Type) {
  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1505}
1506
1507static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1508 Type &targetType, Type &lowSizeType,
1509 Type &highSizeType,
1510 Type &splitPointType) {
1511 FunctionType funcType;
1512 llvm::SMLoc typeLoc = parser.getCurrentLocation();
1513 if (failed(parser.parseType<FunctionType>(funcType)))
1514 return failure();
1515
  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
    return failure();
  }
1520 targetType = funcType.getInput(0);
1521 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1522
1523 return success();
1524}
1525
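// Computes two tile sizes and a split point for the given dimension such
// that tiling the two halves covers the iteration space exactly. For
// param-typed results the sizes are computed statically; otherwise index
// computations are emitted in the IR. Illustrative sketch only; attribute
// names follow the op's TableGen definition and the payload handle is
// hypothetical:
//
//   %low, %high, %split = transform.structured.multitile_sizes %target
//       { dimension = 0, target_size = 32, divisor = 8 }
//       : (!transform.any_op) -> !transform.any_op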
1526DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1527 transform::TransformRewriter &rewriter, LinalgOp target,
1528 transform::ApplyToEachResultList &results, TransformState &state) {
1529 if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1530 if (target.hasDynamicShape()) {
1531 auto diag = emitSilenceableError()
1532 << "cannot compute parametric tile sizes for dynamically "
1533 "shaped payload op";
1534 diag.attachNote(target->getLoc()) << "payload op";
1535 return diag;
1536 }
1537
1538 FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1539 target, getDimension(), getTargetSize(), getDivisor());
1540 if (failed(spec)) {
1541 return emitSilenceableError()
1542 << "failed to compute multi-size tiling sizes";
1543 }
1544
1545 Builder builder(target.getContext());
1546 results.assign(llvm::map_range(
1547 ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1548 spec->lowTileSize * spec->lowTripCount}),
1549 [&builder, this](int64_t value) {
1550 return builder.getIntegerAttr(
1551 cast<ParamType>(getLowSize().getType()).getType(), value);
1552 }));
1553 return DiagnosedSilenceableFailure::success();
1554 }
1555
1556 OpBuilder builder(target.getContext());
1557 builder.setInsertionPoint(target);
1558 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1559 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1560 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1561 builder, target, getDimension(), targetSize, divisor);
1562 if (failed(spec)) {
1563 return emitSilenceableError() << "could not generate tile size computation";
1564 }
1565
1566 AffineExpr s0 = builder.getAffineSymbolExpr(0);
1567 AffineExpr s1 = builder.getAffineSymbolExpr(1);
1568 Operation *splitPoint =
1569 affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1570 {spec->lowTileSize, spec->lowTripCount});
1571 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1572 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1573 assert(lowTileSize && highTileSize && splitPoint &&
1574 "tile sizes are not produced by operations");
1575 results.reserve(results.size() + 3);
1576 results.push_back(lowTileSize);
1577 results.push_back(highTileSize);
1578 results.push_back(splitPoint);
1579 return DiagnosedSilenceableFailure::success();
1580}
1581
1582void transform::MultiTileSizesOp::getEffects(
1583 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1584 onlyReadsHandle(getTargetMutable(), effects);
1585 producesHandle(getOperation()->getOpResults(), effects);
1586 if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1587 onlyReadsPayload(effects);
1588 else
1589 modifiesPayload(effects);
1590}
1591
1592LogicalResult transform::MultiTileSizesOp::verify() {
1593 if (getLowSize().getType() != getHighSize().getType() ||
1594 getLowSize().getType() != getSplitPoint().getType()) {
1595 return emitOpError() << "expects all results type to be the same";
1596 }
1597 return success();
1598}
1599
1600//===---------------------------------------------------------------------===//
1601// PackOp
1602//===---------------------------------------------------------------------===//
1603
1604void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1605 Value target,
1606 ArrayRef<OpFoldResult> mixedPackedSizes) {
1607 SmallVector<int64_t> staticPackedSizes;
1608 SmallVector<Value> dynamicPackedSizes;
1609 dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1610 staticPackedSizes);
  // Call the default builder, which sets up the proper operand segment size
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
1614 Type linalgOpHType = transform::OperationType::get(
1615 builder.getContext(), GenericOp::getOperationName());
1616 build(builder, result,
1617 /*resultType=*/linalgOpHType,
1618 /*target=*/target,
1619 /*dynamic_sizes=*/dynamicPackedSizes,
1620 /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1621}
1622
1623SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1624 Builder b(getContext());
1625 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1626}
1627
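// Packs a single LinalgOp by the given packed sizes, one per loop dimension
// of the target (a packed size of 0 leaves the corresponding loop unpacked).
// Illustrative sketch only; packed sizes may also be passed as param or
// handle operands and the payload handle name is hypothetical:
//
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, 32]
//       : (!transform.any_op) -> !transform.op<"linalg.generic">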
1628DiagnosedSilenceableFailure
1629transform::PackOp::apply(transform::TransformRewriter &rewriter,
1630 transform::TransformResults &transformResults,
1631 transform::TransformState &state) {
1632 auto targetOps = state.getPayloadOps(getTarget());
1633 // If nothing to pack, propagate success.
1634 if (std::empty(targetOps)) {
1635 transformResults.set(cast<OpResult>(getPackedOp()),
1636 ArrayRef<Operation *>({}));
1637 return DiagnosedSilenceableFailure::success();
1638 }
1639 // Fail on multi-op handles.
1640 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1641 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1642 return emitSilenceableError()
1643 << "requires target to map to exactly 1 LinalgOp (got "
1644 << llvm::range_size(targetOps) << ")";
1645 }
1646 // Fail on mismatched number of pack sizes.
1647 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1648 return emitSilenceableError()
           << "requires number of packed sizes to match the number of loops ("
1650 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1651 << ")";
1652 }
1653
1654 // Unpack handles to constants or actual SSA index values.
1655 SmallVector<OpFoldResult> packedSizes;
  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
      state, *this, packedSizes, getMixedPackedSizes());
  if (!status.succeeded())
    return status;
1658
1659 rewriter.setInsertionPoint(linalgOp);
1660 FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1661 if (failed(maybeResult))
1662 return emitDefiniteFailure("data tiling failed");
1663
1664 transformResults.set(cast<OpResult>(getPackedOp()),
1665 {maybeResult->packedLinalgOp.getOperation()});
1666 return DiagnosedSilenceableFailure::success();
1667}
1668
1669void transform::PackOp::getEffects(
1670 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1671 transform::consumesHandle(getTargetMutable(), effects);
1672 transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1673 transform::producesHandle(getOperation()->getOpResults(), effects);
1674 transform::modifiesPayload(effects);
1675}
1676
1677//===---------------------------------------------------------------------===//
1678// PackGreedilyOp.
1679//===---------------------------------------------------------------------===//
1680
1681LogicalResult transform::PackGreedilyOp::verify() {
1682 if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1683 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1684 << " is not a valid permutation";
1685 }
1686 // TODO: relax to allow empty once we have another strategy than just matmul.
1687 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1688 for (auto [s, nmo] :
1689 llvm::zip_equal(getMixedMatmulPackedSizes(),
1690 getMatmulPaddedSizesNextMultipleOf())) {
1691 std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1692 if (nmo != 0 &&
1693 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1694 return emitOpError() << "at most one of the packed_size and the "
1695 "padded_sizes_next_multiple_of can be nonzero "
1696 "for the matmul strategy";
1697 }
1698 }
1699 }
1700 return success();
1701}
1702
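// Best-effort packing: every LinalgOp mapped to the target handle that can be
// recognized as matmul-like is packed to the requested m, n, k sizes; ops
// that cannot be packed are passed through unchanged. Illustrative sketch
// only, with hypothetical handle names; the exact assembly format is defined
// in LinalgTransformOps.td:
//
//   %packed = transform.structured.pack_greedily %candidates
//       matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [0, 1, 2]
//       : (!transform.any_op) -> !transform.op<"linalg.generic">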
1703DiagnosedSilenceableFailure
1704PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1705 transform::TransformResults &transformResults,
1706 transform::TransformState &state) {
1707 SmallVector<Operation *> results;
1708 for (Operation *op : state.getPayloadOps(getTarget())) {
1709 auto linalgOp = dyn_cast<LinalgOp>(op);
1710 if (!linalgOp)
1711 continue;
    // linalgOp will be replaced and the insertion point may be invalidated if
    // we set it before the op, so set it after it instead.
1714 rewriter.setInsertionPointAfter(linalgOp);
1715 // Failing to pack greedily is perfectly fine.
1716 // In the future we will want to order packings according to some metric.
1717 FailureOr<PackResult> packResult = packMatmulGreedily(
1718 /*rewriter=*/rewriter,
1719 /*linalgOp=*/linalgOp,
1720 /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1721 /*mnkPaddedSizesNextMultipleOf=*/
1722 getMatmulPaddedSizesNextMultipleOf(),
1723 /*mnkOrder=*/getMatmulInnerDimsOrder());
1724 if (succeeded(packResult)) {
1725 results.push_back(packResult->packedLinalgOp);
1726 continue;
1727 }
1728 results.push_back(linalgOp);
1729 }
1730 transformResults.set(cast<OpResult>(getPackedOp()), results);
1731 return DiagnosedSilenceableFailure::success();
1732}
1733
1734SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1735 Builder b(getContext());
1736 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1737 b);
1738}
1739
1740void transform::PackGreedilyOp::getEffects(
1741 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1742 transform::consumesHandle(getTargetMutable(), effects);
1743 transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1744 transform::producesHandle(getOperation()->getOpResults(), effects);
1745 transform::modifiesPayload(effects);
1746}
1747
1748//===---------------------------------------------------------------------===//
1749// PackTransposeOp
1750//===---------------------------------------------------------------------===//
1751
1752LogicalResult transform::PackTransposeOp::verify() {
1753 if (!isPermutationVector(getInnerPerm())) {
1754 return emitOpError() << getInnerPermAttrName()
1755 << " is not a valid permutation";
1756 }
1757 if (!isPermutationVector(getOuterPerm())) {
1758 return emitOpError() << getOuterPermAttrName()
1759 << " is not a valid permutation";
1760 }
1761 if (getInnerPerm().empty() && getOuterPerm().empty()) {
    return emitOpError() << "at least one of " << getInnerPermAttrName()
1763 << " or " << getOuterPermAttrName()
1764 << " must be specified";
1765 }
1766 return success();
1767}
1768
1769namespace {
1770enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1771} // namespace
1772
1773/// Return true if `permutation` is a valid permutation of the
1774/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
/// (OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack` `op`.
1776/// This is the case when the `permutation` rank matches the rank expected by
1777/// `op` and `permutation` is itself a permutation vector.
/// Return true if either `op` or `permutation` is empty to allow a simpler
1779/// polymorphic implementation.
1780template <typename RelayoutOpTy>
1781bool isValidPackingPermutation(
1782 RelayoutOpTy op, ArrayRef<int64_t> permutation,
1783 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1784 static_assert(
1785 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1786 "applies to only pack or unpack operations");
1787 if (!op || permutation.empty())
1788 return true;
1789 size_t innerRank = op.getInnerDimsPos().size();
1790 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
    return permutation.size() == innerRank && isPermutationVector(permutation);
1792 // op.getOuterDimsPerm() may be empty, in which case it is identity.
1793 // Don't rely on it.
1794 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
    return permutation.size() == op.getSourceRank() &&
           isPermutationVector(permutation);
1797 }
  return permutation.size() == op.getDestRank() &&
         isPermutationVector(permutation);
1800}
1801
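// Transposes a linalg.pack (and, when the handle points to a linalg.unpack,
// its matching pack found through the packed LinalgOp) together with the
// packed LinalgOp in between, according to the outer_perm / inner_perm
// attributes. The implementation performs extensive precondition checking up
// front so that the actual packTranspose rewrite below cannot fail.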
1802DiagnosedSilenceableFailure
1803transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1804 transform::TransformResults &transformResults,
1805 transform::TransformState &state) {
1806 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1807 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1808 // Step 1. If nothing to pack, propagate success.
1809 if (std::empty(packOrUnpackOps)) {
1810 transformResults.set(cast<OpResult>(getPackedOp()), {});
1811 transformResults.set(cast<OpResult>(getPackOp()), {});
1812 transformResults.set(cast<OpResult>(getUnPackOp()), {});
1813 return DiagnosedSilenceableFailure::success();
1814 }
1815
  // Step 2. A bunch of runtime sanity checks and error messages.
1817 // Step 2.1. Fail on multi-op handles.
1818 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1819 !llvm::hasSingleElement(linalgOps)) {
1820 return emitSilenceableError()
1821 << "requires target to map to exactly 1 "
1822 "packing op and 1 packed op ("
1823 << "got " << llvm::range_size(packOrUnpackOps) << " and "
1824 << llvm::range_size(linalgOps) << ")";
1825 }
1826
1827 // Step 2.2. Fail on wrong type.
1828 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1829 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1830 if ((!packOp && !unPackOp)) {
1831 return emitSilenceableError() << "requires target to map to a "
1832 "linalg.pack or linalg.unpack";
1833 }
1834 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1835 if (!linalgOpTarget)
1836 return emitSilenceableError() << "requires a LinalgOp target";
1837
1838 // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1839 LinalgOp linalgOp;
1840 if (packOp && packOp.getResult().hasOneUse())
1841 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1842 else if (unPackOp)
1843 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1844 if (linalgOp != linalgOpTarget) {
1845 auto errorMsg =
1846 packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1847 : StringLiteral{"not produced by the LinalgOp target"};
1848 return emitSilenceableError() << errorMsg;
1849 }
1850
1851 // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1852 // PackOp.
1853 if (unPackOp) {
1854 assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1855 OpOperand *packUse = linalgOp.getDpsInitOperand(
1856 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1857 packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
1858 if (!packOp || !packOp.getResult().hasOneUse())
1859 return emitSilenceableError() << "could not find matching pack op";
1860 }
1861
1862 // Step 2.5. Fail if any permutation does not validate.
1863 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1864 ArrayRef<int64_t> perm =
1865 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1866 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1867 ? StringLiteral{"invalid outer_perm"}
1868 : StringLiteral{"invalid inner_perm"};
1869 if (!isValidPackingPermutation(packOp, perm, permType) ||
1870 !isValidPackingPermutation(unPackOp, perm, permType)) {
1871 Operation *packOrUnpackOp =
1872 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1873 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1874 }
1875 }
1876
1877 // From here on, packOp and linalgOp are always present, unPackOp may or may
1878 // not be present.
1879 assert(packOp && linalgOp && "unexpected null op");
1880
1881 // Step 3. Actually transpose the ops.
1882 FailureOr<PackTransposeResult> res = packTranspose(
1883 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1884 // Preconditions have been checked, it is an error to fail here.
1885 assert(succeeded(res) && "unexpected packTranspose failure");
1886
1887 // Step 4. Return results.
1888 transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1889 transformResults.set(cast<OpResult>(getPackedOp()),
1890 {res->transposedLinalgOp});
1891 if (unPackOp) {
1892 transformResults.set(cast<OpResult>(getUnPackOp()),
1893 {res->transposedUnPackOp});
1894 } else {
1895 transformResults.set(cast<OpResult>(getUnPackOp()), {});
1896 }
1897
1898 return DiagnosedSilenceableFailure::success();
1899}
1900
1901//===---------------------------------------------------------------------===//
1902// PadOp
1903//===---------------------------------------------------------------------===//
1904
1905void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1906 ArrayRef<int64_t> paddingDimensions,
1907 ArrayRef<int64_t> padToMultipleOf,
1908 ArrayRef<int64_t> nofoldFlags,
1909 ArrayRef<Attribute> transposePaddings,
1910 StringRef copyBackOp) {
1911 auto resultType = transform::AnyOpType::get(b.getContext());
1912 return build(/*builder=*/b,
1913 /*result=*/result,
1914 /*types=*/TypeRange{resultType, resultType},
1915 /*target=*/target,
1916 /*paddingValues=*/ArrayAttr(), // let inference handle this
1917 /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1918 /*padToMultipleOf=*/ValueRange{},
1919 /*padToMultipleOf=*/
1920 (padToMultipleOf.empty()
1921 ? DenseI64ArrayAttr()
1922 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1923 /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1924 /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1925 /*copyBackOp=*/b.getStringAttr(copyBackOp));
1926}
1927
1928void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1929 ArrayRef<int64_t> paddingDimensions,
1930 ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1931 ArrayRef<int64_t> nofoldFlags,
1932 ArrayRef<Attribute> transposePaddings,
1933 StringRef copyBackOp) {
1934 auto resultType = transform::AnyOpType::get(b.getContext());
1935 SmallVector<int64_t> staticPadToMultipleOf;
1936 SmallVector<Value> dynamicPadToMultipleOf;
1937 dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1938 staticPadToMultipleOf);
1939 return build(/*builder=*/b,
1940 /*result=*/result,
1941 /*types=*/TypeRange{resultType, resultType},
1942 /*target=*/target,
1943 /*paddingValues=*/ArrayAttr(), // let inference handle this
1944 /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1945 /*padToMultipleOf=*/dynamicPadToMultipleOf,
1946 /*padToMultipleOf=*/staticPadToMultipleOf,
1947 /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1948 /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1949 /*copyBackOp=*/b.getStringAttr(copyBackOp));
1950}
1951
1952void PadOp::getEffects(
1953 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1954 consumesHandle(getTargetMutable(), effects);
1955 onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1956 producesHandle(getOperation()->getOpResults(), effects);
1957 modifiesPayload(effects);
1958}
1959
1960SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1961 Builder b(getContext());
1962 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1963}
1964
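// Pads the operands of every target LinalgOp along the requested iteration
// dimensions, optionally up to a multiple of a given value, and materializes
// a copy back into the original destination. Illustrative sketch only; the
// padding values, handle names and result types are hypothetical and the
// authoritative assembly format lives in LinalgTransformOps.td:
//
//   %padded, %pad, %copy = transform.structured.pad %matmul {
//       padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
//       padding_dimensions = [0, 1, 2],
//       copy_back_op = "linalg.copy"} : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)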
1965DiagnosedSilenceableFailure
1966transform::PadOp::apply(transform::TransformRewriter &rewriter,
1967 transform::TransformResults &results,
1968 transform::TransformState &state) {
1969 auto transformOp = cast<TransformOpInterface>(getOperation());
1970 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1971
1972 for (Operation *target : state.getPayloadOps(getTarget())) {
1973 auto linalgTarget = dyn_cast<LinalgOp>(target);
1974 if (!linalgTarget) {
1975 auto diag = emitSilenceableError() << "expected LinalgOp target";
1976 diag.attachNote(target->getLoc()) << "target op";
1977 return diag;
1978 }
1979
1980 // Convert the integer packing flags to booleans.
1981 SmallVector<bool> nofoldFlags;
1982 for (int64_t packPadding :
1983 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1984 nofoldFlags.push_back(static_cast<bool>(packPadding));
1985
1986 // Convert the padding values to attributes.
1987 SmallVector<Attribute> paddingValues;
1988 for (auto const &it :
1989 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1990 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1991 if (!attr) {
1992 emitOpError("expects padding values to be typed attributes");
1993 return DiagnosedSilenceableFailure::definiteFailure();
1994 }
1995 Type elementType = getElementTypeOrSelf(std::get<1>(it));
1996 // Try to parse string attributes to obtain an attribute of element type.
1997 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1998 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1999 stringAttr, getContext(), elementType,
2000 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2001 if (!parsedAttr || parsedAttr.getType() != elementType) {
2002 auto diag = this->emitOpError("expects a padding that parses to ")
2003 << elementType << ", got " << std::get<0>(it);
2004 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2005 return DiagnosedSilenceableFailure::definiteFailure();
2006 }
2007 paddingValues.push_back(parsedAttr);
2008 continue;
2009 }
2010 // Otherwise, add the attribute directly.
2011 if (attr.getType() != elementType) {
2012 auto diag = this->emitOpError("expects a padding value of type ")
2013 << elementType << ", got " << attr;
2014 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2015 return DiagnosedSilenceableFailure::definiteFailure();
2016 }
2017 paddingValues.push_back(attr);
2018 }
2019
2020 // Extract the transpose vectors.
2021 SmallVector<SmallVector<int64_t>> transposePaddings;
2022 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2023 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2024 cast<ArrayAttr>(transposeVector)));
2025
2026 LinalgOp paddedOp;
2027 LinalgPaddingOptions options;
2028 options.paddingDimensions =
2029 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2030
2031 SmallVector<int64_t> padToMultipleOf;
2032 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
2033 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2034 if (!status.succeeded())
2035 return status;
2036 if (padToMultipleOf.empty())
2037 padToMultipleOf =
2038 SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2039
2040 options.padToMultipleOf = padToMultipleOf;
2041 options.paddingValues = paddingValues;
2042 options.nofoldFlags = nofoldFlags;
2043 if (getCopyBackOp() ==
2044 bufferization::MaterializeInDestinationOp::getOperationName()) {
2045 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2046 BufferizationMaterializeInDestination;
2047 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2048 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2049 } else if (getCopyBackOp() == kCopyOpNone) {
2050 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2051 } else {
2052 llvm_unreachable("unsupported copy_back op");
2053 }
2054
2055 SmallVector<Value> replacements;
2056 SmallVector<tensor::PadOp> newPadOps;
2057 if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2058 replacements, newPadOps))) {
2059 auto diag = emitSilenceableError() << "failed to pad op";
2060 diag.attachNote(target->getLoc()) << "target op";
2061 return diag;
2062 }
2063
2064 // We need to perform our own replacement here because this API is still
2065 // used in patterns that "pad and hoist", for which the replacement values
2066 // need to be different.
2067 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2068 // that we have more composable abstractions.
2069 rewriter.replaceOp(linalgTarget, replacements);
2070 paddedOps.push_back(paddedOp);
2071 padOps.append(newPadOps.begin(), newPadOps.end());
2072 if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2073 for (Value v : replacements) {
2074 Operation *copyBackOp = v.getDefiningOp();
2075 if (!llvm::is_contained(copyBackOps, copyBackOp))
2076 copyBackOps.push_back(copyBackOp);
2077 }
2078 }
2079 }
2080
2081 results.set(cast<OpResult>(getPadded()), paddedOps);
2082 results.set(cast<OpResult>(getPad()), padOps);
2083 results.set(cast<OpResult>(getCopy()), copyBackOps);
2084 return DiagnosedSilenceableFailure::success();
2085}
2086
2087LogicalResult transform::PadOp::verify() {
2088 SmallVector<int64_t> nofoldFlags =
2089 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2090 if (any_of(nofoldFlags, [](int64_t packPadding) {
2091 return packPadding != 0 && packPadding != 1;
2092 })) {
2093 return emitOpError()
2094 << "expects nofold_flags to contain booleans (0/1), found "
2095 << getNofoldFlags();
2096 }
2097
2098 SmallVector<int64_t> paddingDimensions =
2099 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2100 if (any_of(paddingDimensions,
2101 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2102 return emitOpError() << "expects padding_dimensions to contain positive "
2103 "integers, found "
2104 << getPaddingDimensions();
2105 }
2106 if (!getMixedPadToMultipleOf().empty()) {
2107 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2108 return emitOpError() << "expects as many multiples as padding_dimensions";
2109 }
2110 }
2111 ArrayAttr transposes = getTransposePaddings();
2112 for (Attribute attr : transposes) {
2113 SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2114 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2115 if (!std::is_permutation(sequence.begin(), sequence.end(),
2116 transpose.begin(), transpose.end())) {
2117 return emitOpError()
2118 << "expects transpose_paddings to be a permutation, found "
2119 << attr;
2120 }
2121 }
2122 if (getCopyBackOp() !=
2123 bufferization::MaterializeInDestinationOp::getOperationName() &&
2124 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2125 getCopyBackOp() != kCopyOpNone)
2126 return emitOpError() << "invalid copy_back_op";
2127 return success();
2128}
2129
2130//===---------------------------------------------------------------------===//
2131// HoistPadOp
2132//===---------------------------------------------------------------------===//
2133
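// Two related ops live in this section: HoistPadBuildPackingLoopNestOp only
// builds the packing loop nest around a tensor.pad under the given loop and
// returns the outermost packing loop, leaving the original pad in place,
// while HoistPadOp below performs the full hoisting and replaces the original
// pad with the hoisted one.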
2134DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2135 transform::TransformRewriter &rewriter,
2136 transform::TransformResults &transformResults,
2137 transform::TransformState &state) {
2138 auto targetOps = state.getPayloadOps(getTarget());
2139 auto loopOps = state.getPayloadOps(getLoop());
2140 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2141 return emitDefiniteFailure()
2142 << "requires exactly one target and one loop handle (got "
2143 << llvm::range_size(targetOps) << " and "
2144 << llvm::range_size(loopOps) << ")";
2145 }
2146
2147 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2148 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2149 if (!padOp || !loopOp)
2150 return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2151
2152 FailureOr<linalg::detail::PackingResult> result =
2153 linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2154 getTranspose());
2155 if (failed(result))
2156 return emitDefiniteFailure() << "could not build packing loop nest";
2157
2158 if (result->clonedLoopIvs.empty()) {
2159 transformResults.set(cast<OpResult>(getPackingLoop()),
2160 {result->hoistedPadOp.getOperation()});
2161 return DiagnosedSilenceableFailure::success();
2162 }
2163 auto outerPackedLoop =
2164 scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2165 transformResults.set(cast<OpResult>(getPackingLoop()),
2166 {outerPackedLoop.getOperation()});
2167 return DiagnosedSilenceableFailure::success();
2168}
2169
2170LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2171 ArrayRef<int64_t> transpose = getTranspose();
2172 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2173 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2174 transpose.end())) {
2175 return emitOpError() << "expects transpose to be a permutation, found "
2176 << getTranspose();
2177 }
2178 return success();
2179}
2180
2181void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2182 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2183 transform::onlyReadsHandle(getTargetMutable(), effects);
2184 transform::onlyReadsHandle(getLoopMutable(), effects);
2185 transform::producesHandle(getOperation()->getOpResults(), effects);
2186 transform::modifiesPayload(effects);
2187}
2188
2189DiagnosedSilenceableFailure
2190transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2191 tensor::PadOp target,
2192 transform::ApplyToEachResultList &results,
2193 transform::TransformState &state) {
2194 tensor::PadOp hoistedPadOp;
2195 SmallVector<TransposeOp> transposeOps;
2196 FailureOr<Value> result =
2197 hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2198 hoistedPadOp, transposeOps);
2199 if (succeeded(result)) {
2200 // We need to perform our own replacement here because this API is still
2201 // used in patterns that "pad and hoist", for which the replacement values
2202 // need to be different.
2203 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2204 // that we have more composable abstractions.
2205 rewriter.replaceOp(target, *result);
2206 results.push_back(hoistedPadOp);
2207 return DiagnosedSilenceableFailure::success();
2208 }
2209 return emitDefaultSilenceableFailure(target);
2210}
2211
2212LogicalResult transform::HoistPadOp::verify() {
2213 ArrayRef<int64_t> transpose = getTranspose();
2214 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2215 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2216 transpose.end())) {
2217 return emitOpError() << "expects transpose to be a permutation, found "
2218 << getTranspose();
2219 }
2220 return success();
2221}
2222
2223//===----------------------------------------------------------------------===//
2224// PromoteOp
2225//===----------------------------------------------------------------------===//
2226
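// Promotes the selected operands of the target LinalgOp into new buffers,
// optionally placed in a GPU workgroup or private address space when a
// `mapping` attribute is provided. Illustrative sketch only; the attribute
// names follow the op's TableGen definition and the payload handle is
// hypothetical:
//
//   %promoted = transform.structured.promote %tiled_matmul
//       {operands_to_promote = [0, 1], use_full_tile_buffers = [false, false]}
//       : (!transform.any_op) -> !transform.any_op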
2227DiagnosedSilenceableFailure
2228transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2229 LinalgOp target,
2230 transform::ApplyToEachResultList &results,
2231 transform::TransformState &state) {
2232 LinalgPromotionOptions promotionOptions;
2233 if (!getOperandsToPromote().empty())
2234 promotionOptions = promotionOptions.setOperandsToPromote(
2235 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2236 if (getUseFullTilesByDefault())
2237 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2238 getUseFullTilesByDefault());
2239 if (getUseAlloca())
2240 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2241 if (!getUseFullTileBuffers().empty())
2242 promotionOptions = promotionOptions.setUseFullTileBuffers(
2243 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2244 if (getAlignment().has_value())
2245 promotionOptions = promotionOptions.setAlignment(*getAlignment());
2246 if (getMemorySpace().has_value())
2247 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2248
2249 if (getMapping().has_value()) {
    // The mapping should contain a single element specifying the address
    // space to promote to.
2251 auto mapping = *getMapping();
2252 if (mapping.size() > 1)
2253 return emitDefaultDefiniteFailure(target);
2254
2255 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2256
2257 if (addressSpace.getAddressSpace() ==
2258 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2259 promotionOptions =
2260 promotionOptions
2261 .setAllocationDeallocationFns(allocateWorkgroupMemory,
2262 deallocateWorkgroupMemory)
2263 .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
2264 .setUseFullTileBuffers({false, false});
2265 } else if (addressSpace.getAddressSpace() ==
2266 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2267 promotionOptions =
2268 promotionOptions
2269 .setAllocationDeallocationFns(allocateGPUPrivateMemory,
2270 deallocateGPUPrivateMemory)
2271 .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
2272 .setUseFullTileBuffers({false, false});
2273 } else {
2274 return emitDefaultDefiniteFailure(target);
2275 }
2276 }
2277
2278 if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2279 return emitDefaultDefiniteFailure(target);
2280
2281 rewriter.setInsertionPoint(target);
2282 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2283 if (failed(res))
2284 return emitDefaultDefiniteFailure(target);
2285 results.push_back(target);
2286 return DiagnosedSilenceableFailure::success();
2287}
2288
2289//===----------------------------------------------------------------------===//
2290// ReplaceOp
2291//===----------------------------------------------------------------------===//
2292
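// Replaces each payload op mapped to the target handle (which must have no
// operands and no non-isolated regions) with a clone of the single op held in
// this transform's body region, and returns handles to the clones. Targets
// nested inside the transform op itself are skipped.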
2293DiagnosedSilenceableFailure
2294transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2295 TransformResults &transformResults,
2296 TransformState &state) {
2297 auto payload = state.getPayloadOps(getTarget());
2298
2299 // Check for invalid targets.
2300 for (Operation *target : payload) {
2301 if (target->getNumOperands() > 0)
2302 return emitDefiniteFailure() << "expected target without operands";
2303 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2304 target->getNumRegions() > 0)
2305 return emitDefiniteFailure()
2306 << "expected target that is isolated from above";
2307 }
2308
2309 // Clone and replace.
2310 Operation *pattern = &getBodyRegion().front().front();
2311 SmallVector<Operation *> replacements;
2312 for (Operation *target : payload) {
2313 if (getOperation()->isAncestor(target))
2314 continue;
2315 rewriter.setInsertionPoint(target);
2316 Operation *replacement = rewriter.clone(*pattern);
2317 rewriter.replaceOp(target, replacement->getResults());
2318 replacements.push_back(replacement);
2319 }
2320 transformResults.set(cast<OpResult>(getReplacement()), replacements);
2321 return DiagnosedSilenceableFailure::success();
2322}
2323
2324void transform::ReplaceOp::getEffects(
2325 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2326 consumesHandle(getTargetMutable(), effects);
2327 producesHandle(getOperation()->getOpResults(), effects);
2328 modifiesPayload(effects);
2329}
2330
2331LogicalResult transform::ReplaceOp::verify() {
2332 if (!getBodyRegion().hasOneBlock())
2333 return emitOpError() << "expected one block";
2334 if (std::distance(getBodyRegion().front().begin(),
2335 getBodyRegion().front().end()) != 1)
2336 return emitOpError() << "expected one operation in block";
2337 Operation *replacement = &getBodyRegion().front().front();
2338 if (replacement->getNumOperands() > 0)
2339 return replacement->emitOpError()
2340 << "expected replacement without operands";
2341 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2342 replacement->getNumRegions() > 0)
2343 return replacement->emitOpError()
2344 << "expect op that is isolated from above";
2345 return success();
2346}
2347
2348//===----------------------------------------------------------------------===//
2349// ScalarizeOp
2350//===----------------------------------------------------------------------===//
2351
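// "Scalarizes" the target by tiling every loop whose trip count is not
// statically known by 1; statically sized loops are left untouched (tile size
// 0). Illustrative sketch only, with a hypothetical payload handle:
//
//   %scalarized = transform.structured.scalarize %target
//       : (!transform.any_op) -> !transform.any_op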
2352DiagnosedSilenceableFailure
2353transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2354 LinalgOp target,
2355 transform::ApplyToEachResultList &results,
2356 transform::TransformState &state) {
2357 scf::SCFTilingOptions tilingOptions;
2358 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2359 SmallVector<OpFoldResult> tileSizes;
2360 Location loc = target.getLoc();
2361 SmallVector<OpFoldResult> allShapeSizes =
2362 target.createFlatListOfOperandDims(b, loc);
2363 AffineMap map = target.getShapesToLoopsMap();
2364 if (!map)
2365 return tileSizes;
2366 SmallVector<OpFoldResult> shapeSizes =
2367 affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
2368 allShapeSizes);
2369 // If the shape size is dynamic, tile by 1.
2370 // Otherwise, do not tile (i.e. tile size 0).
2371 for (OpFoldResult shapeSize : shapeSizes) {
2372 tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2373 : b.getIndexAttr(1));
2374 }
2375 return tileSizes;
2376 });
2377 rewriter.setInsertionPoint(target);
2378 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2379 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2380 if (failed(maybeTilingResult))
2381 return emitDefaultDefiniteFailure(target);
2382
2383 if (target->getNumResults())
2384 rewriter.replaceOp(target, maybeTilingResult->replacements);
2385 else
2386 rewriter.eraseOp(target);
2387
2388 results.reserve(maybeTilingResult->tiledOps.size());
2389 for (Operation *tiled : maybeTilingResult->tiledOps)
2390 results.push_back(tiled);
2391 return DiagnosedSilenceableFailure::success();
2392}
2393
2394//===----------------------------------------------------------------------===//
2395// ConvertToLoopsOp
2396//===----------------------------------------------------------------------===//
2397
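// Lowers each payload op implementing TilingInterface to an explicit scf.for
// loop nest and returns a handle to all generated loops. Illustrative sketch
// only, with a hypothetical payload handle:
//
//   %loops = transform.structured.convert_to_loops %tiled
//       : (!transform.any_op) -> !transform.any_op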
2398DiagnosedSilenceableFailure
2399transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2400 transform::TransformResults &results,
2401 transform::TransformState &state) {
2402 SmallVector<Operation *> loops;
2403 for (Operation *target : state.getPayloadOps(getTarget())) {
2404 auto tilingOp = dyn_cast<TilingInterface>(*target);
2405 if (!tilingOp) {
2406 DiagnosedSilenceableFailure diag =
2407 emitSilenceableError()
2408 << "expected the payload to implement TilingInterface";
2409 diag.attachNote(target->getLoc()) << "payload op";
2410 return diag;
2411 }
2412 rewriter.setInsertionPoint(target);
2413 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2414 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2415 if (failed(generatedLoops))
2416 return emitDefaultDefiniteFailure(target);
2417 for (scf::ForOp &loop : *generatedLoops) {
2418 loops.push_back(loop.getOperation());
2419 }
2420 rewriter.eraseOp(target);
2421 }
2422 results.set(cast<OpResult>(getResult()), loops);
2423 return DiagnosedSilenceableFailure::success();
2424}
2425
2426//===----------------------------------------------------------------------===//
2427// RewriteInDestinationPassingStyleOp
2428//===----------------------------------------------------------------------===//
2429
2430DiagnosedSilenceableFailure
2431transform::RewriteInDestinationPassingStyleOp::applyToOne(
2432 transform::TransformRewriter &rewriter, Operation *target,
2433 transform::ApplyToEachResultList &results,
2434 transform::TransformState &state) {
2435 rewriter.setInsertionPoint(target);
2436 FailureOr<Operation *> maybeResult =
2437 TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2438 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2439 [&rewriter](auto op) {
2440 return rewriteInDestinationPassingStyle(rewriter, op);
2441 });
2442 if (failed(maybeResult))
2443 return emitDefaultSilenceableFailure(target);
2444 results.push_back(*maybeResult);
2445 return DiagnosedSilenceableFailure::success();
2446}
2447
2448//===----------------------------------------------------------------------===//
2449// SplitOp
2450//===----------------------------------------------------------------------===//
2451
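// Splits the target op(s) into two parts (or, with `multiway`, into several
// parts) along the given iteration-space dimension. The chunk sizes may be
// provided as a static attribute or as a dynamic handle/param operand.
// Illustrative sketch following the custom printer below; handle names are
// hypothetical:
//
//   %splits = transform.structured.split %target after 16 {dimension = 1}
//       : !transform.any_op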
2452DiagnosedSilenceableFailure
2453SplitOp::apply(transform::TransformRewriter &rewriter,
2454 TransformResults &results, TransformState &state) {
2455 // Collect the dynamic split points if provided.
2456 SmallVector<Operation *> payload =
2457 llvm::to_vector(state.getPayloadOps(getTarget()));
2458
2459 bool isMultiwaySplit = getMultiway();
2460
2461 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2462 return mlir::emitSilenceableFailure(getLoc())
2463 << "requires exactly one target when "
2464 "multiway split is enabled (got "
2465 << llvm::range_size(payload) << ")";
2466 }
2467
2468 SmallVector<OpFoldResult> chunkSizes;
2469
2470 if (!isMultiwaySplit)
2471 chunkSizes.reserve(payload.size());
2472
2473 if (getDynamicChunkSizes()) {
2474 auto diag = DiagnosedSilenceableFailure::success();
2475 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2476 chunkSizes = llvm::to_vector(llvm::map_range(
2477 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2478 if (op->getNumResults() != 1 ||
2479 !op->getResult(0).getType().isIndex()) {
2480 diag = emitSilenceableError()
2481 << "expected dynamic split point handle to point to a "
2482 "single-result index-typed op";
2483 diag.attachNote(op->getLoc()) << "dynamic split point";
2484 }
2485 return OpFoldResult(op->getResult(0));
2486 }));
2487 } else {
2488 chunkSizes = llvm::to_vector(
2489 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2490 [](Attribute attr) { return OpFoldResult(attr); }));
2491 }
2492 if (diag.isSilenceableFailure())
2493 return diag;
2494
2495 // For multiway split, a single payload is expected to have multiple
2496 // split points.
2497 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2498 return emitDefiniteFailure()
2499 << "expected the dynamic split point handle to point to as "
2500 "many operations ("
2501 << chunkSizes.size() << ") as the target handle ("
2502 << payload.size() << ")";
2503 }
2504 } else {
2505 chunkSizes.resize(payload.size(),
2506 rewriter.getIndexAttr(getStaticChunkSizes()));
2507 }
2508
2509 auto checkStructuredOpAndDimensions =
2510 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2511 if (!linalgOp) {
2512 auto diag = emitSilenceableError() << "only applies to structured ops";
2513 diag.attachNote(loc) << "target op";
2514 return diag;
2515 }
2516
2517 if (getDimension() >= linalgOp.getNumLoops()) {
2518 auto diag = emitSilenceableError() << "dimension " << getDimension()
2519 << " does not exist in target op";
2520 diag.attachNote(loc) << "target op";
2521 return diag;
2522 }
2523 return DiagnosedSilenceableFailure::success();
2524 };
2525
2526 auto checkFailureInSplitting =
2527 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2528 if (hasFailed) {
2529 auto diag = emitDefiniteFailure() << "internal failure in splitting";
2530 diag.attachNote(loc) << "target op";
2531 return diag;
2532 }
2533 return DiagnosedSilenceableFailure::success();
2534 };
2535
2536 SmallVector<Operation *> opList;
2537 if (isMultiwaySplit) {
2538
2539 // Split a single target operation at multiple points.
2540 TilingInterface head, tail;
2541 Operation *target = payload.front();
2542
2543 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2544
2545 // Check that the target is a valid LinalgOp with correct dimensions.
2546 DiagnosedSilenceableFailure diag =
2547 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2548 if (diag.isSilenceableFailure())
2549 return diag;
2550
2551 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2552
2553 if (idx > 0)
2554 target = tail.getOperation();
2555
2556 if (!target)
2557 break;
2558
2559 linalgOp = cast<LinalgOp>(target);
2560 Location loc = target->getLoc();
2561
2562 rewriter.setInsertionPoint(linalgOp);
2563 std::tie(head, tail) = linalg::splitOp(
2564 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2565 getDimension(), chunkSize);
2566
2567 // Propagate errors.
2568 DiagnosedSilenceableFailure diag =
2569 checkFailureInSplitting(!head && !tail, loc);
2570 if (diag.isDefiniteFailure())
2571 return diag;
2572
2573 opList.push_back(head.getOperation());
2574 }
2575
2576 // Append any leftover parts to the end of the result list.
2577 if (tail)
2578 opList.push_back(tail.getOperation());
2579
2580 } else {
2581 // Split each target operation.
2582 SmallVector<Operation *> first, second;
2583 Operation *noSecondPart = nullptr;
2584 for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2585 Operation *target = std::get<0>(pair);
2586 Location loc = target->getLoc();
2587 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2588 DiagnosedSilenceableFailure diag =
2589 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2590
2591 if (diag.isSilenceableFailure())
2592 return diag;
2593
2594 rewriter.setInsertionPoint(linalgOp);
2595 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2596 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2597 getDimension(), std::get<1>(pair));
2598
2599 // Propagate errors.
2600 DiagnosedSilenceableFailure diagSplit =
2601 checkFailureInSplitting(!first.back() && !second.back(), loc);
      if (diagSplit.isDefiniteFailure())
        return diagSplit;
2604
2605 // Do not add null second parts.
2606 if (!second.back()) {
2607 noSecondPart = target;
2608 second.pop_back();
2609 }
2610 }
2611
2612 if (second.size() != first.size() && !second.empty()) {
2613 auto diag = emitSilenceableError()
2614 << "splitting does not produce the second part for a subset "
2615 "of targets";
2616 diag.attachNote()
2617 << "expected splitting to produce the second part of all "
2618 "or none of the targets";
2619 diag.attachNote(noSecondPart->getLoc())
2620 << "first target with no second part";
2621 return diag;
2622 }
2623
2624 opList.append(first);
2625 if (second.size())
2626 opList.append(second);
2627 }
2628 results.set(cast<OpResult>(getSplitList()), opList);
2629 return DiagnosedSilenceableFailure::success();
2630}
2631
2632void SplitOp::getEffects(
2633 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2634 consumesHandle(getTargetMutable(), effects);
2635 if (getDynamicChunkSizes())
2636 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2637 producesHandle(getOperation()->getOpResults(), effects);
2638 modifiesPayload(effects);
2639}
2640
2641ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2642 OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2643 IntegerAttr staticChunkSizes;
2644 if (parser.parseOperand(target) || parser.parseKeyword("after"))
2645 return failure();
2646
2647 OptionalParseResult dynamicPointParseResult =
2648 parser.parseOptionalOperand(dynamicChunkSizes);
2649 if (!dynamicPointParseResult.has_value()) {
2650 int64_t staticChunkSizesValue;
2651 if (failed(parser.parseInteger(staticChunkSizesValue)))
2652 return failure();
2653
2654 staticChunkSizes =
2655 parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2656 }
2657
2658 Type targetType;
2659 if (parser.parseOptionalAttrDict(result.attributes) ||
2660 parser.parseColonType(targetType) ||
2661 parser.resolveOperand(target, targetType, result.operands)) {
2662 return failure();
2663 }
2664 if (dynamicPointParseResult.has_value()) {
    Type chunkSizesType;
    if (failed(*dynamicPointParseResult) || parser.parseComma() ||
        parser.parseType(chunkSizesType) ||
        parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
2669 result.operands)) {
2670 return failure();
2671 }
2672
2673 staticChunkSizes =
2674 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2675 }
2676
2677 result.addAttribute(
2678 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2679 staticChunkSizes);
2680 result.addTypes(targetType);
2681 return success();
2682}
2683
2684void SplitOp::print(OpAsmPrinter &printer) {
2685 printer << " " << getTarget() << " after ";
2686 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2687 if (staticChunkSize != ShapedType::kDynamic)
2688 printer << staticChunkSize;
2689 else
2690 printer << getDynamicChunkSizes();
2691 printer << " ";
2692 printer.printOptionalAttrDict(getOperation()->getAttrs(),
2693 {getStaticChunkSizesAttrName()});
2694 printer << " : " << getTarget().getType();
2695 if (staticChunkSize == ShapedType::kDynamic)
2696 printer << ", " << getDynamicChunkSizes().getType();
2697}
2698
2699LogicalResult SplitOp::verify() {
2700 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2701 (getDynamicChunkSizes() == nullptr)) {
2702 return emitOpError() << "expects either a dynamic or a static split "
2703 "point to be provided";
2704 }
2705 return success();
2706}
2707
2708//===----------------------------------------------------------------------===//
2709// SplitReductionOp
2710//===----------------------------------------------------------------------===//
2711
2712void transform::SplitReductionOp::build(
2713 OpBuilder &builder, OperationState &result, Value target,
2714 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2715 bool useScalingAlgorithm, bool useAlloc) {
2716 MLIRContext *ctx = builder.getContext();
2717 result.addOperands(target);
2718 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2719 builder.getI64IntegerAttr(splitFactor));
2720 result.addAttribute(
2721 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2722 builder.getI64IntegerAttr(insertSplitDimension));
2723 if (innerParallel) {
2724 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2725 builder.getUnitAttr());
2726 }
2727 if (useScalingAlgorithm) {
2728 result.addAttribute(
2729 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2730 builder.getUnitAttr());
2731 }
2732 if (useAlloc) {
2733 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2734 builder.getUnitAttr());
2735 }
2736 auto resultType = transform::AnyOpType::get(ctx);
2737 result.addTypes({resultType, resultType, resultType, resultType});
2738}
2739
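// Splits the reduction dimension of the target into a parallel dimension of
// size `split_factor` plus a smaller reduction, returning handles to the
// init/alloc op, the fill op, the split LinalgOp and the combining LinalgOp.
// Illustrative sketch only; attribute names follow the op's TableGen
// definition and the handle names are hypothetical:
//
//   %init, %fill, %split, %combine = transform.structured.split_reduction %r
//       { split_factor = 4, insert_split_dimension = 2 }
//       : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op,
//           !transform.any_op)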
2740DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2741 transform::TransformRewriter &rewriter, LinalgOp target,
2742 transform::ApplyToEachResultList &results,
2743 transform::TransformState &state) {
2744 ControlSplitReductionFn splitFn = [&](LinalgOp) {
2745 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2746 unsigned(getInsertSplitDimension()),
2747 bool(getInnerParallel())};
2748 };
2749 rewriter.setInsertionPoint(target);
2750 FailureOr<SplitReductionResult> splitResult =
2751 (getUseScalingAlgorithm())
2752 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2753 : splitReduction(rewriter, target, splitFn, getUseAlloc());
2754 if (failed(splitResult))
2755 return emitDefaultDefiniteFailure(target);
2756
2757 results.push_back(splitResult->initOrAlloc);
2758 results.push_back(splitResult->fillOp);
2759 results.push_back(splitResult->splitLinalgOp);
2760 results.push_back(splitResult->resultCombiningLinalgOp);
2761 return DiagnosedSilenceableFailure::success();
2762}
2763
2764//===----------------------------------------------------------------------===//
2765// TileReductionUsingForOp
2766//===----------------------------------------------------------------------===//
2767
2768void transform::TileReductionUsingForOp::build(
2769 OpBuilder &builder, OperationState &result, Value target,
2770 ArrayRef<int64_t> staticTileSizes) {
2771 // Call the default builder.
  // This is future-proof with respect to mixed static-dynamic sizes and sets
  // up the proper operand segment size attributes for multiple variadic
  // operands.
2774 // In the absence of this, horrible bugs ensue.
2775 // TODO: support mixed static-dynamic (see TileUsingForallOp).
2776 MLIRContext *ctx = builder.getContext();
2777 auto opTy = transform::AnyOpType::get(ctx);
2778 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2779 build(builder, result,
2780 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2781 /*target=*/target,
2782 /*tile_sizes=*/staticTileSizesAttr);
2783}
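// Sketch of creating this op with the builder above (hypothetical `b`, `loc`,
// `handle` and sizes; by convention a zero tile size leaves a loop untiled):
//   auto tileRed = b.create<transform::TileReductionUsingForOp>(
//       loc, /*target=*/handle, /*staticTileSizes=*/ArrayRef<int64_t>{0, 8});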
2784
2785DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2786 transform::TransformRewriter &rewriter, Operation *target,
2787 transform::ApplyToEachResultList &results,
2788 transform::TransformState &state) {
2789 rewriter.setInsertionPoint(target);
2790
2791 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2792 if (!partialReductionOp) {
2793 return emitSilenceableFailure(
2794 target->getLoc(),
2795 "Operation should implement PartialReductionOpInterface");
2796 }
2797 FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2798 rewriter, partialReductionOp,
2799 getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2800
2801 if (failed(result))
2802 return emitDefaultSilenceableFailure(target);
2803 rewriter.replaceOp(target, result->replacements);
2804 for (Value initValue : result->initialValues)
2805 results.push_back(initValue.getDefiningOp());
2806 for (auto parallelTiledOp : result->tiledOps)
2807 results.push_back(parallelTiledOp);
2808 for (auto mergeOp : result->mergeOps)
2809 results.push_back(mergeOp);
2810 results.push_back(result->loops.front());
2811 return DiagnosedSilenceableFailure::success();
2812}
2813
2814//===----------------------------------------------------------------------===//
2815// TileReductionUsingForallOp
2816//===----------------------------------------------------------------------===//
2817
2818void transform::TileReductionUsingForallOp::build(
2819 OpBuilder &builder, OperationState &result, Value target,
2820 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2821 ArrayAttr mapping) {
2822 // Call the default builder.
2823 // This is future-proof re mixed static-dynamic and setting up the proper
2824 // operands segment sizes attributes for multiple variadic operands.
2825 // In the absence of this, horrible bugs ensue.
2826 // TODO: support mixed static-dynamic (see TileUsingForallOp).
2827 MLIRContext *ctx = builder.getContext();
2828 auto opTy = transform::AnyOpType::get(ctx);
2829 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2830 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2831 build(builder, result,
2832 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2833 /*target=*/target,
2834 /*num_threads=*/staticNumThreadsAttr,
2835 /*tile_sizes=*/staticTileSizesAttr,
2836 /*mapping=*/mapping);
2837}
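// Sketch of creating this op with the builder above (hypothetical values):
//   auto tileRed = b.create<transform::TileReductionUsingForallOp>(
//       loc, /*target=*/handle, /*staticNumThreads=*/ArrayRef<int64_t>{0, 8},
//       /*staticTileSizes=*/ArrayRef<int64_t>{}, /*mapping=*/ArrayAttr());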
2838
2839DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2840 transform::TransformRewriter &rewriter, LinalgOp target,
2841 transform::ApplyToEachResultList &results,
2842 transform::TransformState &state) {
2843 rewriter.setInsertionPoint(target);
2844 SmallVector<OpFoldResult> numThreads =
2845 getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2846 SmallVector<OpFoldResult> tileSizes =
2847 getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2848 FailureOr<linalg::ForallReductionTilingResult> result =
2849 linalg::tileReductionUsingForall(
2850 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2851 numThreads, tileSizes, getMapping());
2852
2853 if (failed(result)) {
2854 auto diag = emitSilenceableError() << "could not tile reduction";
2855 diag.attachNote(target.getLoc()) << "target operation";
2856 return diag;
2857 }
2858 for (Value initValue : result->initialValues)
2859 results.push_back(initValue.getDefiningOp());
2860 for (auto parallelTiledOp : result->parallelTiledOps)
2861 results.push_back(parallelTiledOp);
2862 for (auto mergeOp : result->mergeOps)
2863 results.push_back(mergeOp);
2864 results.push_back(result->loops);
2865 return DiagnosedSilenceableFailure::success();
2866}
2867
2868//===----------------------------------------------------------------------===//
2869// ContinuousTileSizesOp
2870//===----------------------------------------------------------------------===//
2871
2872DiagnosedSilenceableFailure
2873transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2874 TransformResults &transformResults,
2875 TransformState &state) {
2876
2877 SmallVector<Operation *> targetOps =
2878 llvm::to_vector(state.getPayloadOps(getTarget()));
2879
2880 if (!llvm::hasSingleElement(targetOps)) {
2881 return mlir::emitSilenceableFailure(getLoc())
2882 << "requires exactly one target (got " << llvm::range_size(targetOps)
2883 << ")";
2884 }
2885
2886 Operation *target = *targetOps.begin();
2887 auto linalgOp = dyn_cast<LinalgOp>(target);
2888 auto tileableOp = dyn_cast<TilingInterface>(target);
2889
2890 if (!linalgOp)
2891 return emitDefiniteFailure() << "expected Linalg Op";
2892
2893 OpBuilder builder(linalgOp.getContext());
2894
2895 if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
2896 if (linalgOp.hasDynamicShape()) {
2897 auto diag = emitSilenceableError()
2898 << "cannot compute parametric tile sizes for dynamically "
2899 "shaped payload op";
2900 diag.attachNote(linalgOp->getLoc()) << "payload op";
2901 return diag;
2902 }
2903
2904 FailureOr<StaticContinuousTileSizeSpecification> spec =
2905 computeStaticContinuousTileSizes(linalgOp, getDimension(),
2906 getTargetSize());
2907 if (failed(spec)) {
2908 return emitSilenceableError()
2909 << "failed to compute multi-size tiling sizes";
2910 }
2911
2912 SmallVector<int64_t> chunkSizes;
2913
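    // Each chunk is the contiguous extent covered with a given tile size,
    // i.e. tileSize * tripCount; e.g. a tile size of 8 repeated 3 times
    // contributes a chunk of 24 (illustrative numbers).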
2914 for (auto &&[tileSize, tripCount] :
2915 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2916 chunkSizes.push_back(tileSize * tripCount);
2917
2918 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2919 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2920 return builder.getI64IntegerAttr(value);
2921 });
2922 };
2923 transformResults.setParams(cast<OpResult>(getTileSizes()),
2924 getI64AttrsFromI64(spec->tileSizes));
2925 transformResults.setParams(cast<OpResult>(getChunkSizes()),
2926 getI64AttrsFromI64(chunkSizes));
2927
2928 return DiagnosedSilenceableFailure::success();
2929 }
2930
2931 builder.setInsertionPoint(linalgOp);
2932
2933 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2934 unsigned dimension = getDimension();
2935
2936 FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2937 builder, tileableOp, dimension, targetSize, true);
2938 if (failed(spec)) {
2939 return emitSilenceableError() << "could not generate tile size computation";
2940 }
2941
2942 AffineExpr s0 = builder.getAffineSymbolExpr(0);
2943 AffineExpr s1 = builder.getAffineSymbolExpr(1);
2944 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2945 return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2946 ofrs);
2947 };
2948
2949 SmallVector<Value> chunkSizes;
2950 Value splitPoint;
2951 for (auto &&[tileSize, tripCount] :
2952 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2953 splitPoint = apply(s0 * s1, {tileSize, tripCount});
2954 chunkSizes.push_back(splitPoint);
2955 }
2956
2957 auto getDefiningOps = [&](ArrayRef<Value> values) {
2958 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2959 return value.getDefiningOp();
2960 });
2961 };
2962
2963 transformResults.set(cast<OpResult>(getTileSizes()),
2964 getDefiningOps(spec->tileSizes));
2965 transformResults.set(cast<OpResult>(getChunkSizes()),
2966 getDefiningOps(chunkSizes));
2967
2968 return DiagnosedSilenceableFailure::success();
2969}
2970
2971LogicalResult transform::ContinuousTileSizesOp::verify() {
2972
2973 if (getTileSizes().getType() != getChunkSizes().getType()) {
    return emitOpError() << "expects all result types to be the same";
2975 }
2976
2977 return success();
2978}
2979
2980void transform::ContinuousTileSizesOp::getEffects(
2981 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2982 if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2983 onlyReadsPayload(effects);
2984 else
2985 modifiesPayload(effects);
2986 onlyReadsHandle(getTargetMutable(), effects);
2987 producesHandle(getOperation()->getOpResults(), effects);
2988}
2989
2990static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2991 Type targetType, Type tile_sizes,
2992 Type) {
  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
2994}
2995
2996static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2997 Type &targetType,
2998 Type &tileSizesType,
2999 Type &chunkSizesType) {
3000 FunctionType funcType;
3001 llvm::SMLoc typeLoc = parser.getCurrentLocation();
3002 if (failed(parser.parseType<FunctionType>(funcType)))
3003 return failure();
3004
3005 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
3008 }
3009 targetType = funcType.getInput(0);
3010 tileSizesType = chunkSizesType = funcType.getResult(0);
3011
3012 return success();
3013}
3014
3015//===----------------------------------------------------------------------===//
3016// TileUsingForOp
3017//===----------------------------------------------------------------------===//
3018
3019void transform::TileUsingForOp::build(
3020 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3021 Value target, ArrayRef<int64_t> staticTileSizes,
3022 ArrayRef<int64_t> interchange,
3023 std::optional<ArrayRef<bool>> scalableSizes) {
3024 return build(builder, result, loopTypes,
3025 /*target=*/target,
3026 /*mixedTileSizes=*/
3027 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3028 interchange, scalableSizes);
3029}
3030
3031void transform::TileUsingForOp::build(
3032 OpBuilder &builder, OperationState &result, Value target,
3033 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3034 std::optional<ArrayRef<bool>> scalableSizes) {
3035 build(builder, result, target,
3036 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3037 interchange, scalableSizes);
3038}
3039
3040void transform::TileUsingForOp::build(
3041 OpBuilder &builder, OperationState &result, Value target,
3042 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3043 std::optional<ArrayRef<bool>> scalableSizes) {
  // Loop types are automatically splat by the callee, setting up one is
3045 // enough.
3046 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3047 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3048 scalableSizes);
3049}
3050
3051void transform::TileUsingForOp::build(
3052 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3053 Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3054 ArrayRef<int64_t> interchange,
3055 std::optional<ArrayRef<bool>> scalableSizes) {
3056 SmallVector<int64_t> staticTileSizes;
3057 SmallVector<Value> dynamicTileSizes;
3058 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3059 // Call the default builder which sets up the proper operands segment sizes
3060 // attributes for multiple variadic operands. In the absence of this,
3061 // horrible bugs ensue.
3062 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3063 unsigned numExpectedLoops =
3064 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3065 SmallVector<Type> resultTypes;
3066 resultTypes.reserve(numExpectedLoops);
3067 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3068 "expected one loop type or as many as loops");
3069 if (loopTypes.size() == 1)
3070 resultTypes.append(numExpectedLoops, loopTypes[0]);
3071 else
3072 llvm::append_range(resultTypes, loopTypes);
3073 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3074 if (scalableSizes.has_value())
3075 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3076 build(builder, result, /*tiled_linalg_op=*/target.getType(),
3077 /*loops=*/resultTypes,
3078 /*target=*/target,
3079 /*dynamic_sizes=*/dynamicTileSizes,
3080 /*static_sizes=*/staticTileSizesAttr,
3081 /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3082 /*scalable_sizes=*/expandedScalableSizes);
3083}
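// Sketch of creating this op with mixed static/dynamic tile sizes via the
// builder above (hypothetical `b`, `loc`, `handle` and `dynamicSize`):
//   SmallVector<OpFoldResult> sizes = {b.getIndexAttr(4), dynamicSize};
//   auto tileOp = b.create<transform::TileUsingForOp>(
//       loc, /*target=*/handle, /*mixedTileSizes=*/sizes,
//       /*interchange=*/ArrayRef<int64_t>{}, /*scalableSizes=*/std::nullopt);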
3084
3085LogicalResult transform::TileUsingForOp::verify() {
3086 if (getMixedSizes().size() != getScalableSizes().size())
3087 return emitOpError("expected same number of sizes (")
3088 << getMixedSizes().size() << ") and scalable sizes ("
3089 << getScalableSizes().size() << ")";
3090 ArrayRef<int64_t> staticSizes = getStaticSizes();
3091 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3092 if (getLoops().size() != numExpectedLoops)
3093 return emitOpError("expected number of loops to tile (")
3094 << numExpectedLoops << ") to match number of `loops` results ("
3095 << getLoops().size() << ")";
3096 return success();
3097}
3098
3099DiagnosedSilenceableFailure
3100transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3101 TransformResults &transformResults,
3102 TransformState &state) {
3103 ArrayRef<int64_t> tileSizes = getStaticSizes();
3104
3105 SmallVector<Operation *> targets =
3106 llvm::to_vector(state.getPayloadOps(getTarget()));
3107 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3108 SmallVector<SmallVector<int64_t>> paramSizes;
3109 dynamicSizeProducers.reserve(getDynamicSizes().size());
3110 paramSizes.reserve(getDynamicSizes().size());
3111 for (Value transformValue : getDynamicSizes()) {
3112 if (isa<ParamType>(transformValue.getType())) {
3113 dynamicSizeProducers.push_back({});
3114 ArrayRef<Attribute> params = state.getParams(transformValue);
3115 paramSizes.push_back(
3116 llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3117 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3118 })));
3119
3120 if (paramSizes.back().size() != targets.size()) {
3121 DiagnosedSilenceableFailure diag =
3122 emitSilenceableError()
3123 << "expected as many parameter values ("
3124 << dynamicSizeProducers.back().size() << ") as target ops ("
3125 << targets.size() << ")";
3126 diag.attachNote(transformValue.getLoc()) << "for this parameter";
3127 return diag;
3128 }
3129
3130 continue;
3131 }
3132 paramSizes.push_back({});
3133 dynamicSizeProducers.push_back(
3134 llvm::to_vector(state.getPayloadOps(transformValue)));
3135
3136 if (dynamicSizeProducers.back().size() != targets.size()) {
3137 DiagnosedSilenceableFailure diag =
3138 emitSilenceableError()
3139 << "expected as many dynamic size-producing operations ("
3140 << dynamicSizeProducers.back().size() << ") as target ops ("
3141 << targets.size() << ")";
3142 diag.attachNote(transformValue.getLoc()) << "for this handle";
3143 return diag;
3144 }
3145
3146 for (Operation *op : dynamicSizeProducers.back()) {
3147 if (op->getNumResults() == 1 &&
3148 isa<IndexType>(op->getResult(0).getType())) {
3149 continue;
3150 }
3151
3152 DiagnosedSilenceableFailure diag =
3153 emitSilenceableError() << "expected sizes to be produced by ops "
3154 "with a single index-type result";
3155 diag.attachNote(op->getLoc()) << "size producer op";
3156 diag.attachNote(transformValue.getLoc()) << "for this handle";
3157 return diag;
3158 }
3159 }
3160
3161 SmallVector<Operation *> tiled;
3162 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3163 loops.resize(getLoops().size());
3164 auto scalableSizes = getScalableSizes();
3165 for (auto [i, op] : llvm::enumerate(targets)) {
3166 auto tilingInterface = dyn_cast<TilingInterface>(op);
3167 if (!tilingInterface) {
3168 DiagnosedSilenceableFailure diag =
3169 emitSilenceableError()
3170 << "only ops implementing TilingInterface are supported";
3171 diag.attachNote(op->getLoc()) << "target op";
3172 return diag;
3173 }
3174 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3175 DiagnosedSilenceableFailure diag =
3176 emitSilenceableError()
3177 << "too many tiles provided, expected at most "
3178 << tilingInterface.getLoopIteratorTypes().size() << " found "
3179 << tileSizes.size();
3180 diag.attachNote(op->getLoc()) << "target op";
3181 return diag;
3182 }
3183
3184 scf::SCFTilingOptions tilingOptions;
3185 if (tileSizes.empty()) {
3186 tilingOptions.setTileSizeComputationFunction(
3187 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3188 return {};
3189 });
3190 } else {
3191 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3192 Operation *) {
3193 SmallVector<OpFoldResult> sizes;
3194 sizes.reserve(tileSizes.size());
3195 unsigned dynamicIdx = 0;
3196
3197 for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3198 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3199 if (scalableSizes[ofrIdx]) {
3200 auto val = b.create<arith::ConstantIndexOp>(
3201 getLoc(), cast<IntegerAttr>(attr).getInt());
3202 Value vscale =
3203 b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
3204 sizes.push_back(
3205 b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
3206 } else {
3207 sizes.push_back(attr);
3208 }
3209 continue;
3210 }
3211 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3212 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3213 ++dynamicIdx;
3214 assert((dynamicSizes.empty() ^ params.empty()) &&
3215 "expected either dynamic sizes or parameters");
3216 if (!params.empty()) {
3217 sizes.push_back(b.getIndexAttr(params[index]));
3218 } else {
3219 sizes.push_back(dynamicSizes[index]->getResult(0));
3220 }
3221 }
3222 return sizes;
3223 });
3224 }
3225
3226 tilingOptions.setInterchange(getInterchange());
3227 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3228 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3229 if (failed(maybeTilingResult))
3230 return DiagnosedSilenceableFailure::definiteFailure();
3231
3232 rewriter.replaceOp(op, maybeTilingResult->replacements);
3233
3234 tiled.append(maybeTilingResult->tiledOps);
3235 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3236 loops[en2.index()].push_back(en2.value());
3237 }
3238
3239 transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3240 for (const auto &en : llvm::enumerate(loops))
3241 transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3242
3243 return DiagnosedSilenceableFailure::success();
3244}
3245
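/// Reconstruct the full list of tile sizes by interleaving: every
/// ShapedType::kDynamic entry in `static_sizes` is substituted with the next
/// dynamic size operand, while all other entries become index attributes.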
3246SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3247 ValueRange dynamic = getDynamicSizes();
3248 ArrayRef<int64_t> tileSizes = getStaticSizes();
3249 SmallVector<OpFoldResult> results;
3250 results.reserve(tileSizes.size());
3251 unsigned dynamicPos = 0;
3252 Builder builder(getContext());
3253 for (int64_t size : tileSizes) {
3254 if (size == ShapedType::kDynamic) {
3255 results.push_back(dynamic[dynamicPos++]);
3256 } else {
3257 results.push_back(builder.getIndexAttr(size));
3258 }
3259 }
3260 return results;
3261}
3262
3263void transform::TileUsingForOp::getEffects(
3264 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3265 consumesHandle(getTargetMutable(), effects);
3266 onlyReadsHandle(getDynamicSizesMutable(), effects);
3267 producesHandle(getOperation()->getOpResults(), effects);
3268 modifiesPayload(effects);
3269}
3270
3271//===----------------------------------------------------------------------===//
3272// TileUsingForallOp
3273//===----------------------------------------------------------------------===//
3274
3275void transform::TileUsingForallOp::build(OpBuilder &builder,
3276 OperationState &result, Value target,
3277 ArrayRef<int64_t> staticTileSizes,
3278 transform::TileSizesSpec,
3279 ArrayAttr mapping) {
3280 return build(builder, result,
3281 /*target=*/target,
3282 /*mixedTileSizes=*/
3283 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3284 /*_=*/TileSizesSpec(),
3285 /*mapping=*/mapping);
3286}
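// Sketch of creating this op with static tile sizes (hypothetical values);
// the TileSizesSpec tag selects the tile-sizes overload:
//   auto forallTile = b.create<transform::TileUsingForallOp>(
//       loc, /*target=*/handle, /*staticTileSizes=*/ArrayRef<int64_t>{8, 16},
//       transform::TileSizesSpec(), /*mapping=*/ArrayAttr());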
3287
3288void transform::TileUsingForallOp::build(OpBuilder &builder,
3289 OperationState &result, Value target,
3290 ArrayRef<OpFoldResult> mixedTileSizes,
3291 transform::TileSizesSpec,
3292 ArrayAttr mapping) {
3293 SmallVector<int64_t> staticTileSizes;
3294 SmallVector<Value> dynamicTileSizes;
3295 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3296 // Call the default builder which sets up the proper operands segment sizes
3297 // attributes for multiple variadic operands. In the absence of this,
3298 // horrible bugs ensue.
3299 MLIRContext *ctx = builder.getContext();
3300 auto operationType = transform::AnyOpType::get(ctx);
3301 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3302 build(builder, result,
3303 /*resultTypes=*/TypeRange{operationType, operationType},
3304 /*target=*/target,
3305 /*num_threads=*/ValueRange{},
3306 /*tile_sizes=*/dynamicTileSizes,
3307 /*packed_num_threads=*/Value(),
3308 /*packed_tile_sizes=*/Value(),
3309 /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3310 /*static_tile_sizes=*/staticTileSizesAttr,
3311 /*mapping=*/mapping);
3312}
3313
3314void transform::TileUsingForallOp::build(OpBuilder &builder,
3315 OperationState &result, Value target,
3316 ArrayRef<int64_t> staticNumThreads,
3317 transform::NumThreadsSpec,
3318 ArrayAttr mapping) {
3319 return build(builder, result, target,
3320 getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3321 NumThreadsSpec(), mapping);
3322}
3323
3324void transform::TileUsingForallOp::build(OpBuilder &builder,
3325 OperationState &result, Value target,
3326 ArrayRef<OpFoldResult> mixedNumThreads,
3327 transform::NumThreadsSpec,
3328 ArrayAttr mapping) {
3329 SmallVector<int64_t> staticNumThreads;
3330 SmallVector<Value> dynamicNumThreads;
3331 dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3332 staticNumThreads);
3333 // Call the default builder which sets up the proper operands segment sizes
3334 // attributes for multiple variadic operands. In the absence of this,
3335 // horrible bugs ensue.
3336 MLIRContext *ctx = builder.getContext();
3337 auto operationType = transform::AnyOpType::get(ctx);
3338 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3339 build(builder, result,
3340 /*resultTypes=*/TypeRange{operationType, operationType},
3341 /*target=*/target,
3342 /*num_threads=*/dynamicNumThreads,
3343 /*tile_sizes=*/ValueRange{},
3344 /*packed_num_threads=*/Value(),
3345 /*packed_tile_sizes=*/Value(),
3346 /*static_num_threads=*/staticNumThreadsAttr,
3347 /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3348 /*mapping=*/mapping);
3349}
3350
3351/// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3352/// normalized upper bound.
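/// The normalized bound is ceilDiv(ub - lb, step); e.g. for (lb, ub, step) =
/// (2, 10, 3) the normalized loop runs from 0 to ceilDiv(10 - 2, 3) = 3 with
/// step 1, i.e. three iterations matching the original 2, 5, 8.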
3353static SmallVector<OpFoldResult>
3354normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3355 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3356 ArrayRef<OpFoldResult> steps) {
3357 AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
  SmallVector<OpFoldResult> normalizedUbs;
  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
    OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
        rewriter, loc, normalizedUbExpr, {lb, ub, step});
    normalizedUbs.push_back(normalizedUb);
3365 }
3366 return normalizedUbs;
3367}
3368
3369/// When a loop is normalized, the uses of the induction variable within the
/// loop need to be replaced with `original_lb + old_iv * original_step`.
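/// For example, with original_lb = 2 and original_step = 3, normalized IVs
/// 0, 1, 2 map back to 2, 5, 8.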
3371static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3372 Location loc, ValueRange ivs,
3373 ArrayRef<OpFoldResult> lbs,
3374 ArrayRef<OpFoldResult> steps) {
3375 AffineExpr s0, s1;
3376 AffineExpr d0;
  bindSymbols(rewriter.getContext(), s0, s1);
  bindDims(rewriter.getContext(), d0);
  AffineExpr denormExpr = s0 + d0 * s1;
  SmallVector<Value> denormalizedIvs;

  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
    OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
        rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
    denormalizedIvs.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3387 }
3388 return denormalizedIvs;
3389}
3390
3391/// Given a `scf.forall` loop return a loop op with the loop bounds
3392/// normalized.
3393/// TODO: Replace this with a general utility to normalize `scf.forall`.
/// At the time of writing, this wasn't done since adding it to the `scf`
/// dialect would disallow the use of `affine.apply` operations due
/// to cyclic dependencies. To avoid churn in lit tests
/// with the change this was added with, defer that to a follow-up.
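/// Illustrative example (syntax sketched, not verified against the exact
/// `scf.forall` assembly format): a loop of the form
///   scf.forall (%i) = (2) to (10) step (3)
/// is rewritten to
///   scf.forall (%i) in (3)
/// with uses of the original induction variable replaced by `2 + %i * 3`.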
3398static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3399 scf::ForallOp loop) {
3400 SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3401 SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3402 SmallVector<OpFoldResult> steps = loop.getMixedStep();
3403
  if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
3405 return loop;
3406 }
3407
3408 Location loc = loop.getLoc();
3409 SmallVector<OpFoldResult> normalizedUbs =
3410 normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3411 SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3412 rewriter.getIndexAttr(0));
3413 SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3414 rewriter.getIndexAttr(1));
3415
3416 auto normalizedForallOp = rewriter.create<scf::ForallOp>(
3417 loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3418 loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
3419
3420 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3421 OpBuilder::InsertionGuard g(rewriter);
3422 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3423 rewriter.setInsertionPointToStart(normalizedLoopBlock);
3424
3425 SmallVector<Value> argValues =
3426 denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3427 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3428 normalizedForallOp.getRegionIterArgs().end());
3429 Block *origLoopBlock = loop.getBody();
  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3431
3432 rewriter.replaceOp(loop, normalizedForallOp);
3433 return normalizedForallOp;
3434}
3435
3436DiagnosedSilenceableFailure transform::tileToForallOpImpl(
3437 RewriterBase &rewriter, transform::TransformState &state,
3438 TransformOpInterface transformOp, Operation *target,
3439 ArrayRef<OpFoldResult> mixedNumThreads,
3440 ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3441 scf::SCFTilingResult &tilingResult) {
3442 // Transform all targets one by one.
3443 auto tileableOp = dyn_cast<TilingInterface>(target);
3444 if (!tileableOp) {
3445 DiagnosedSilenceableFailure diag =
3446 transformOp.emitSilenceableError()
3447 << "only TilingInterface ops are supported";
    diag.attachNote(target->getLoc()) << "target op";
3449 return diag;
3450 }
3451 rewriter.setInsertionPoint(tileableOp);
3452 scf::SCFTilingOptions options;
3453 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3454 if (!mixedNumThreads.empty()) {
3455 options.setNumThreads(mixedNumThreads);
3456 } else {
3457 options.setTileSizes(mixedTileSizes);
3458 }
3459 if (mapping) {
3460 options.setMapping(mapping.value().getValue());
3461 }
3462 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3463 scf::tileUsingSCF(rewriter, tileableOp, options);
3464
3465 if (failed(maybeTilingResult))
3466 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3467
3468 rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3469
3470 tilingResult = *maybeTilingResult;
3471
3472 if (mixedNumThreads.empty()) {
3473 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3474 OpBuilder::InsertionGuard g(rewriter);
3475 rewriter.setInsertionPoint(generatedForallOp);
3476 scf::ForallOp normalizedForallOp =
3477 normalizeForallLoopOp(rewriter, generatedForallOp);
3478 tilingResult.loops.front() = normalizedForallOp;
3479 }
3480
3481 return DiagnosedSilenceableFailure::success();
3482}
3483
3484DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3485 transform::TransformRewriter &rewriter,
3486 transform::TransformResults &transformResults,
3487 transform::TransformState &state) {
3488 auto transformOp = cast<TransformOpInterface>(getOperation());
3489
3490 // Result payload ops.
3491 SmallVector<Operation *> tileOps;
3492 SmallVector<Operation *> tiledOps;
3493
3494 // Unpack handles.
3495 SmallVector<OpFoldResult> mixedNumThreads;
3496 DiagnosedSilenceableFailure status =
3497 getPackedNumThreads()
3498 ? unpackSingleIndexResultPayloadOperations(
3499 state, transformOp, mixedNumThreads, getPackedNumThreads())
3500 : unpackSingleIndexResultPayloadOperations(
3501 state, transformOp, mixedNumThreads, getMixedNumThreads());
3502 if (!status.succeeded())
3503 return status;
3504 SmallVector<OpFoldResult> mixedTileSizes;
3505 status = getPackedTileSizes()
3506 ? unpackSingleIndexResultPayloadOperations(
3507 state, transformOp, mixedTileSizes, getPackedTileSizes())
3508 : unpackSingleIndexResultPayloadOperations(
3509 state, transformOp, mixedTileSizes, getMixedTileSizes());
3510 if (!status.succeeded())
3511 return status;
3512
3513 for (Operation *target : state.getPayloadOps(getTarget())) {
3514 scf::SCFTilingResult tilingResult;
3515 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3516 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3517 getMapping(), tilingResult);
3518 if (!diag.succeeded())
3519 return diag;
3520 tileOps.push_back(tilingResult.loops.front());
3521 tiledOps.append(tilingResult.tiledOps);
3522 }
3523
3524 transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3525 transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3526
3527 return DiagnosedSilenceableFailure::success();
3528}
3529
3530void transform::TileUsingForallOp::getEffects(
3531 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3532 consumesHandle(getTargetMutable(), effects);
3533 onlyReadsHandle(getTileSizesMutable(), effects);
3534 onlyReadsHandle(getNumThreadsMutable(), effects);
3535 onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3536 onlyReadsHandle(getPackedTileSizesMutable(), effects);
3537 producesHandle(getOperation()->getOpResults(), effects);
3538 modifiesPayload(effects);
3539}
3540
3541SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3542 Builder b(getContext());
3543 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3544}
3545
3546SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3547 Builder b(getContext());
3548 return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3549}
3550
3551LogicalResult TileUsingForallOp::verify() {
3552 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3553 static_cast<int>(getPackedNumThreads() != Value());
3554 if (numThreadsSpec > 1)
3555 return emitOpError(
3556 "num_threads and packed_num_threads are mutually exclusive");
3557 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3558 static_cast<int>(getPackedTileSizes() != Value());
3559 if (tileSizesSpec > 1)
3560 return emitOpError(
3561 "tile_sizes and packed_tile_sizes are mutually exclusive");
3562 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3563 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3564 "must be specified");
3565 return success();
3566}
3567
3568//===----------------------------------------------------------------------===//
3569// VectorizeChildrenAndApplyPatternsOp
3570//===----------------------------------------------------------------------===//
3571
3572void transform::VectorizeChildrenAndApplyPatternsOp::build(
3573 OpBuilder &builder, OperationState &result, Value target,
3574 bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3575 result.addOperands(target);
3576 if (vectorizePadding) {
3577 result.addAttribute(
3578 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3579 result.name),
3580 builder.getUnitAttr());
3581 }
3582 if (vectorizeExtract) {
3583 result.addAttribute(
3584 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3585 result.name),
3586 builder.getUnitAttr());
3587 }
3588 if (flatten1DDepthwiseConv) {
3589 result.addAttribute(
3590 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3591 result.name),
3592 builder.getUnitAttr());
3593 }
3594 result.addTypes(transform::AnyOpType::get(builder.getContext()));
3595}
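// Sketch of creating this op with the builder above (hypothetical flags):
//   auto vecOp = b.create<transform::VectorizeChildrenAndApplyPatternsOp>(
//       loc, /*target=*/handle, /*vectorizePadding=*/true,
//       /*vectorizeExtract=*/false, /*flatten1DDepthwiseConv=*/false);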
3596
3597namespace {
/// This is a helper only to call vectorize via a pattern inside of
3599/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3600struct VectorizationPattern : public RewritePattern {
3601 explicit VectorizationPattern(MLIRContext *context,
3602 bool vectorizeExtract = false,
3603 bool flattenConv = false)
3604 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3605 vectorizeNDExtract(vectorizeExtract),
3606 flatten1DDepthwiseConv(flattenConv) {}
3607 LogicalResult matchAndRewrite(Operation *op,
3608 PatternRewriter &rewriter) const override {
3609 if (!linalg::hasVectorizationImpl(op))
      return rewriter.notifyMatchFailure(op,
                                         "Unsupported Op, cannot vectorize");
3612 return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3613 /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3614 flatten1DDepthwiseConv);
3615 }
3616
3617private:
3618 /// Controls whether to vectorize `tensor.extract` when the input tensor is
3619 /// rank >= 2.
3620 bool vectorizeNDExtract = false;
3621 /// Controls whether to "flatten" the channel dimension when vectorising 1D
3622 /// depthwise convolutions. This should lead to bette vectorization for
3623 /// tensors with a low number of channel dimensions.
3624 bool flatten1DDepthwiseConv = false;
3625};
3626} // namespace
3627
3628DiagnosedSilenceableFailure
3629transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3630 transform::TransformRewriter &rewriter, Operation *target,
3631 transform::ApplyToEachResultList &results,
3632 transform::TransformState &state) {
3633 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3634 auto diag = this->emitOpError("requires isolated-from-above targets");
3635 diag.attachNote(target->getLoc()) << "non-isolated target";
3636 return DiagnosedSilenceableFailure::definiteFailure();
3637 }
3638
3639 MLIRContext *ctx = getContext();
3640 RewritePatternSet patterns(ctx);
3641 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3642 getFlatten_1dDepthwiseConv());
3643
3644 if (!getDisableTransferPermutationMapLoweringPatterns())
3645 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3646
3647 if (!getDisableMultiReductionToContractPatterns())
3648 vector::populateVectorReductionToContractPatterns(patterns);
3649
3650 vector::populateSinkVectorOpsPatterns(patterns);
3651
3652 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3653 linalg::LinalgCopyVTWForwardingPattern>(ctx,
3654 /*benefit=*/2);
3655 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3656 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3657 tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
3658
3659 patterns.add<CopyVectorizationPattern>(ctx);
3660
3661 if (getVectorizePadding()) {
3662 linalg::populatePadOpVectorizationPatterns(patterns);
3663 // This creates an alternative path for lowering tensor.pad - by
3664 // decomposing it into e.g. linalg.fill.
3665 linalg::populateDecomposePadPatterns(patterns);
3666 }
3667 vector::populateVectorStepLoweringPatterns(patterns);
3668
3669 TrackingListener listener(state, *this);
3670 if (failed(
3671 applyPatternsGreedily(target, std::move(patterns),
3672 GreedyRewriteConfig().setListener(&listener))))
3673 return emitDefaultDefiniteFailure(target);
3674
3675 results.push_back(target);
3676 return DiagnosedSilenceableFailure::success();
3677}
3678
3679//===----------------------------------------------------------------------===//
3680// VectorizeOp
3681//===----------------------------------------------------------------------===//
3682
3683DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3684 transform::TransformRewriter &rewriter,
3685 mlir::transform::TransformResults &transformResults,
3686 mlir::transform::TransformState &state) {
3687 auto targets = state.getPayloadOps(getTarget());
3688 if (std::empty(targets))
3689 return DiagnosedSilenceableFailure::success();
3690 auto transformOp = cast<TransformOpInterface>(getOperation());
3691 SmallVector<int64_t> vectorSizes;
3692 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3693 state, transformOp, getMixedVectorSizes(), vectorSizes);
3694 if (!status.succeeded())
3695 return status;
3696
3697 // TODO: Check that the correct number of vectorSizes was provided.
3698 for (Operation *target : targets) {
3699 if (!linalg::hasVectorizationImpl(target)) {
3700 return mlir::emitSilenceableFailure(target->getLoc())
3701 << "Unsupported Op, cannot vectorize";
3702 }
3703
3704 if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3705 getScalableSizes(),
3706 getVectorizeNdExtract().value_or(false)))) {
3707 return mlir::emitSilenceableFailure(target->getLoc())
3708 << "Attempted to vectorize, but failed";
3709 }
3710 }
3711
3712 return DiagnosedSilenceableFailure::success();
3713}
3714
3715void transform::VectorizeOp::getEffects(
3716 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3717 consumesHandle(getTargetMutable(), effects);
3718 onlyReadsHandle(getVectorSizesMutable(), effects);
3719 modifiesPayload(effects);
3720}
3721
3722SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3723 OpBuilder b(getContext());
3724 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3725}
3726
3727LogicalResult transform::VectorizeOp::verify() {
3728 if (getStaticVectorSizes().size() != getScalableSizes().size())
3729 return emitOpError("expected same number of vector sizes (")
3730 << getStaticVectorSizes().size() << ") and scalable sizes ("
3731 << getScalableSizes().size() << ")";
3732 return success();
3733}
3734
3735//===----------------------------------------------------------------------===//
3736// HoistRedundantVectorTransfersOp
3737//===----------------------------------------------------------------------===//
3738
3739DiagnosedSilenceableFailure
3740transform::HoistRedundantVectorTransfersOp::applyToOne(
3741 transform::TransformRewriter &rewriter, func::FuncOp target,
3742 transform::ApplyToEachResultList &results,
3743 transform::TransformState &state) {
3744 // WARNING: This hoisting does not model parallelism and is generally
3745 // incorrect when used on distributed loops with memref semantics!
3746 // TODO: obsolete and should be retired.
3747 linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3748 results.push_back(target);
3749 return DiagnosedSilenceableFailure::success();
3750}
3751
3752//===----------------------------------------------------------------------===//
3753// HoistRedundantVectorBroadcastsOp
3754//===----------------------------------------------------------------------===//
3755
3756DiagnosedSilenceableFailure
3757transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3758 transform::TransformRewriter &rewriter, mlir::Operation *target,
3759 transform::ApplyToEachResultList &results,
3760 transform::TransformState &state) {
3761 rewriter.setInsertionPoint(target);
3762 linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3763 results.push_back(target);
3764 return DiagnosedSilenceableFailure::success();
3765}
3766
3767//===----------------------------------------------------------------------===//
3768// ConvertConv2DToImg2ColOp.
3769//===----------------------------------------------------------------------===//
3770
3771DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3772 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3773 transform::ApplyToEachResultList &results,
3774 transform::TransformState &state) {
3775 rewriter.setInsertionPoint(target);
3776 auto maybeTransformed =
3777 TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3778 target)
3779 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3780 return rewriteInIm2Col(rewriter, op);
3781 })
3782 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3783 return rewriteInIm2Col(rewriter, op);
3784 })
3785 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3786 return rewriteInIm2Col(rewriter, op);
3787 })
3788 .Case([&](linalg::Conv2DNchwFchwOp op) {
3789 return rewriteInIm2Col(rewriter, op);
3790 })
3791 .Default([&](Operation *op) {
3792 return rewriter.notifyMatchFailure(op, "not supported");
3793 });
3794 if (failed(maybeTransformed))
3795 return emitDefaultSilenceableFailure(target);
3796 // Handle to the operation producing the img2col tensor.
3797 results.push_back(maybeTransformed->first);
3798 // Handle to the operation that replaces the original convolution.
3799 results.push_back(maybeTransformed->second);
3800 return DiagnosedSilenceableFailure::success();
3801}
3802
3803//===----------------------------------------------------------------------===//
3804// FlattenElementwiseLinalgOp.
3805//===----------------------------------------------------------------------===//
3806
3807DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3808 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3809 transform::ApplyToEachResultList &results,
3810 transform::TransformState &state) {
3811 rewriter.setInsertionPoint(target);
3812 if (!isElementwise(target))
3813 return mlir::emitSilenceableFailure(target->getLoc())
3814 << "only elementwise flattening is supported";
3815
3816 // If rank <= 1, do nothing
3817 if (target.getNumLoops() <= 1) {
3818 results.push_back(target);
3819 return DiagnosedSilenceableFailure::success();
3820 }
3821
3822 // Attempt to flatten all dims to one.
3823 ReassociationIndices reassociation(target.getNumLoops());
3824 std::iota(reassociation.begin(), reassociation.end(), 0);
3825 auto maybeFlattened =
3826 collapseOpIterationDims(target, reassociation, rewriter);
3827 if (failed(maybeFlattened))
3828 return mlir::emitSilenceableFailure(target->getLoc())
3829 << "attempted to flatten, but failed";
3830 results.push_back(maybeFlattened->collapsedOp);
3831 rewriter.replaceOp(target, maybeFlattened->results);
3832 return DiagnosedSilenceableFailure::success();
3833}
3834
3835//===----------------------------------------------------------------------===//
3836// TransposeConv2DOp
3837//===----------------------------------------------------------------------===//
3838
3839DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3840 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3841 transform::ApplyToEachResultList &results,
3842 transform::TransformState &state) {
3843 rewriter.setInsertionPoint(target);
3844 auto maybeTransformed =
3845 TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3846 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3847 return transposeConv2D(rewriter, op);
3848 })
3849 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3850 return transposeConv2D(rewriter, op);
3851 })
3852 .Default([&](Operation *op) {
3853 return rewriter.notifyMatchFailure(op, "not supported");
3854 });
3855 if (failed(maybeTransformed))
3856 return emitDefaultSilenceableFailure(target);
3857 // Handle to the new Conv2D operation with transposed filters
3858 results.push_back(*maybeTransformed);
3859 return DiagnosedSilenceableFailure::success();
3860}
3861
3862//===----------------------------------------------------------------------===//
3863// TransposeMatmulOp
3864//===----------------------------------------------------------------------===//
3865
3866DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3867 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3868 transform::ApplyToEachResultList &results,
3869 transform::TransformState &state) {
3870 rewriter.setInsertionPoint(target);
3871 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3872 auto maybeTransformed =
3873 TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3874 .Case([&](linalg::MatmulOp op) {
3875 return transposeMatmul(rewriter, op, transposeLHS);
3876 })
3877 .Case([&](linalg::BatchMatmulOp op) {
3878 return transposeBatchMatmul(rewriter, op, transposeLHS);
3879 })
3880 .Default([&](Operation *op) { return failure(); });
3881 if (failed(maybeTransformed))
3882 return emitSilenceableFailure(target->getLoc()) << "not supported";
3883 // Handle to the new Matmul operation with transposed filters
3884 results.push_back(*maybeTransformed);
3885 return DiagnosedSilenceableFailure::success();
3886}
3887
3888//===----------------------------------------------------------------------===//
3889// InsertSliceToCopyOp
3890//===----------------------------------------------------------------------===//
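/// Rewrite the given tensor.insert_slice / tensor.parallel_insert_slice so
/// that the inserted value is produced by an explicit linalg.copy: extract the
/// destination slice, copy the source into it, and re-insert the copied value.
/// If the source already comes from a linalg.copy, just forward a handle to
/// that copy.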
3891template <typename OpTy>
3892DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3893 transform::ApplyToEachResultList &results,
3894 transform::TransformState &state) {
3895 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3896 tensor::ParallelInsertSliceOp>() &&
3897 "wrong op type");
3898
3899 if (auto copySource =
3900 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3901 results.push_back(copySource);
3902 return DiagnosedSilenceableFailure::success();
3903 }
3904
3905 // If we are inside an InParallel region, temporarily set the insertion point
3906 // outside: only tensor.parallel_insert_slice ops are allowed in there.
3907 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3908 rewriter.setInsertionPoint(
3909 target->template getParentOfType<scf::InParallelOp>());
3910 }
3911
3912 Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3913 target.getLoc(), target.getDest(), target.getMixedOffsets(),
3914 target.getMixedSizes(), target.getMixedStrides());
3915 Value copied = rewriter
3916 .create<linalg::CopyOp>(target.getLoc(),
3917 target.getSource(), extracted)
3918 .getResult(0);
3919 // Reset the insertion point.
3920 rewriter.setInsertionPoint(target);
3921 rewriter.replaceOpWithNewOp<OpTy>(
3922 target, copied, target.getDest(), target.getMixedOffsets(),
3923 target.getMixedSizes(), target.getMixedStrides());
3924
  results.push_back(copied.getDefiningOp());
3926 return DiagnosedSilenceableFailure::success();
3927}
3928
3929DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3930 transform::TransformRewriter &rewriter, Operation *targetOp,
3931 transform::ApplyToEachResultList &results,
3932 transform::TransformState &state) {
3933
3934 rewriter.setInsertionPoint(targetOp);
3935 if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3936 return doit(rewriter, target, results, state);
3937 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3938 return doit(rewriter, target, results, state);
3939
3940 DiagnosedSilenceableFailure diag =
3941 emitSilenceableError()
3942 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3943 diag.attachNote(targetOp->getLoc()) << "target op";
3944 return diag;
3945}
3946
3947//===----------------------------------------------------------------------===//
3948// MapCopyToThreadsOp
3949//===----------------------------------------------------------------------===//
3950
3951DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3952 transform::TransformRewriter &rewriter, Operation *target,
3953 transform::ApplyToEachResultList &results,
3954 transform::TransformState &state) {
3955 // Check if the op is supported.
3956 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3957 DiagnosedSilenceableFailure diag =
3958 emitSilenceableError()
3959 << "only linalg.copy and tensor.pad target ops are supported";
3960 diag.attachNote(target->getLoc()) << "target op";
3961 return diag;
3962 }
3963 assert(target->getNumResults() == 1 && "expected single result");
3964 auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3965 if (!resultShapedType.hasStaticShape()) {
3966 DiagnosedSilenceableFailure diag =
3967 emitSilenceableError()
3968 << "only statically sized ops of rank <= 3 are supported";
3969 diag.attachNote(target->getLoc()) << "target op";
3970 return diag;
3971 }
3972
3973 // Conservatively set the minimum viable desired bitwidth alignment.
3974 int64_t desiredBitAlignment = getDesiredBitAlignment();
3975 int64_t eltBitwidth =
3976 resultShapedType.getElementType().getIntOrFloatBitWidth();
3977 if (desiredBitAlignment % eltBitwidth != 0) {
3978 desiredBitAlignment = eltBitwidth;
3979 }
3980
3981 gpu::CopyMappingInfo mapping(
3982 /*ctx=*/getContext(),
3983 /*totalNumThreads=*/getTotalNumThreads(),
3984 /*alignment=*/desiredBitAlignment,
3985 /*sizes=*/resultShapedType.getShape(),
3986 /*favorPredication=*/false,
3987 /*elementalBitwidth=*/
3988 resultShapedType.getElementType().getIntOrFloatBitWidth());
3989 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3990 DiagnosedSilenceableFailure diag =
3991 emitSilenceableError()
3992 << "too few threads to map copy op to threads on the most minor "
3993 "dimension, given alignment and vector size constraints, try "
           "a smaller tile size or mapping to more threads";
3995 diag.attachNote(target->getLoc()) << "target op";
3996 return diag;
3997 }
3998
3999 // OpBuilder only used to compute attributes.
4000 OpBuilder b(getContext());
4001 scf::SCFTilingResult tilingResult;
4002 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4003 /*rewriter=*/rewriter,
4004 /*state=*/state,
4005 /*transformOp=*/*this,
4006 /*target=*/target,
4007 /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4008 /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4009 /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4010 /*tilingResult=*/tilingResult);
4011 if (!diag.succeeded())
4012 return diag;
4013
4014 results.push_back(tilingResult.loops.front());
4015 for (auto op : tilingResult.tiledOps)
4016 results.push_back(op);
4017 return DiagnosedSilenceableFailure::success();
4018}
4019
4020//===----------------------------------------------------------------------===//
4021// WinogradConv2DOp
4022//===----------------------------------------------------------------------===//
4023
4024DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4025 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4026 transform::ApplyToEachResultList &results,
4027 transform::TransformState &state) {
4028 rewriter.setInsertionPoint(target);
4029 FailureOr<Operation *> maybeTransformed = failure();
4030 bool supported = TypeSwitch<Operation *, bool>(target)
4031 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4032 maybeTransformed =
4033 winogradConv2D(rewriter, op, getM(), getR());
4034 return true;
4035 })
4036 .Default([&](Operation *op) { return false; });
4037
4038 if (!supported) {
4039 return emitSilenceableError()
           << "converting this operation to Winograd Conv2D is not supported";
4041 }
4042
4043 if (failed(maybeTransformed)) {
4044 return emitSilenceableError() << "apply Winograd Conv2D failed";
4045 }
4046
4047 results.push_back(*maybeTransformed);
4048 return DiagnosedSilenceableFailure::success();
4049}
4050
4051DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4052 transform::TransformRewriter &rewriter, Operation *target,
4053 transform::ApplyToEachResultList &results,
4054 transform::TransformState &state) {
4055 rewriter.setInsertionPoint(target);
4056 FailureOr<Operation *> maybeTransformed = failure();
4057 bool supported =
4058 TypeSwitch<Operation *, bool>(target)
4059 .Case([&](linalg::WinogradFilterTransformOp op) {
4060 maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4061 return true;
4062 })
4063 .Case([&](linalg::WinogradInputTransformOp op) {
4064 maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4065 return true;
4066 })
4067 .Case([&](linalg::WinogradOutputTransformOp op) {
4068 maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4069 return true;
4070 })
4071 .Default([&](Operation *op) { return false; });
4072
4073 if (!supported) {
4074 DiagnosedSilenceableFailure diag =
4075 emitSilenceableError()
        << "decomposing this operation into other operations is not supported";
4077 diag.attachNote(target->getLoc()) << "target op";
4078 return diag;
4079 }
4080
4081 if (failed(maybeTransformed)) {
4082 DiagnosedSilenceableFailure diag =
4083 emitSilenceableError() << "decompose Winograd operations failed";
4084 diag.attachNote(target->getLoc()) << "target op";
4085 return diag;
4086 }
4087
4088 results.push_back(*maybeTransformed);
4089 return DiagnosedSilenceableFailure::success();
4090}
4091
4092#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4093
4094#define GET_OP_CLASSES
4095#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
4096
