//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include "mlir/AsmParser/AsmParser.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;

#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
55
56/// Attempts to apply the pattern specified as template argument to the given
57/// operation. The pattern is expected to have a `returningMatchAndRewrite`
58/// function that returns the "main" result or failure. Returns failure if the
59/// pattern failed to apply. Extra arguments are forwarded to the pattern
60/// constructor.
61template <typename PatternTy, typename... Args>
62static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
63 // Check if the given operation has the type expected by the pattern.
64 using OpTy = typename llvm::function_traits<
65 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
66 auto op = dyn_cast<OpTy>(operation);
67 if (!op)
68 return failure();
69
70 // Apply the pattern directly to the op.
71 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
72 // We want to discourage direct use of PatternRewriter in APIs but In this
73 // very specific case, an IRRewriter is not enough.
74 struct TrivialPatternRewriter : public PatternRewriter {
75 public:
76 explicit TrivialPatternRewriter(MLIRContext *context)
77 : PatternRewriter(context) {}
78 };
79 TrivialPatternRewriter rewriter(operation->getContext());
80 rewriter.setInsertionPoint(operation);
81 auto result = pattern.returningMatchAndRewrite(op, rewriter);
82 if (failed(result))
83 return failure();
84 return cast<LinalgOp>(result->getOperation());
85}
86
87/// Assuming that `ofr` is an index attr or a param of index type
88/// or a transform dialect handle mapped to exactly one op
89/// with one index result, return that value.
90static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
91 transform::TransformState &state, TransformOpInterface transformOp,
92 SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
93 for (OpFoldResult ofr : ofrs) {
94 if (auto attr = dyn_cast<Attribute>(Val&: ofr)) {
95 if (!isa<IntegerAttr>(Val: attr))
96 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
97 result.push_back(Elt: ofr);
98 continue;
99 }
100
101 Value transformValue = cast<Value>(Val&: ofr);
102 if (isa<TransformParamTypeInterface>(Val: transformValue.getType())) {
103 ArrayRef<Attribute> params = state.getParams(value: transformValue);
104 if (params.size() != 1)
105 return transformOp.emitDefiniteFailure()
106 << "requires exactly one parameter associated";
107 result.push_back(Elt: params[0]);
108 continue;
109 }
110
111 auto payloadOps = state.getPayloadOps(value: transformValue);
112 if (!llvm::hasSingleElement(C&: payloadOps)) {
113 DiagnosedSilenceableFailure diag =
114 transformOp.emitSilenceableError()
115 << "handle must be mapped to exactly one payload op";
116 diag.attachNote(loc: transformValue.getLoc())
117 << "mapped to " << llvm::range_size(Range&: payloadOps) << " payload ops";
118 return diag;
119 }
120
121 Operation *op = *payloadOps.begin();
122 if (op->getNumResults() != 1 || !op->getResult(idx: 0).getType().isIndex()) {
123 DiagnosedSilenceableFailure diag =
124 transformOp.emitSilenceableError()
125 << "payload op must have exactly 1 index result";
126 diag.attachNote(loc: op->getLoc())
127 << "has " << op->getNumResults() << " results";
128 return diag;
129 }
130 result.push_back(Elt: op->getResult(idx: 0));
131 }
132
133 return DiagnosedSilenceableFailure::success();
134}
135
136// Given a list of params that are index attrs or a list of OpFoldResults
137// that are either index attrs or op handles, return a list of OpFoldResults
138// of index attrs or a list of OpFoldResults where all op handles are
139// replaced with the first (and only) OpResult of that payload op.
140// (There must be exactly one parameter associated with the AnyParamType or
141// one mapped payload op which must have exactly one index result.)
142static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
143 transform::TransformState &state, TransformOpInterface transformOp,
144 SmallVector<OpFoldResult> &result, Value packedHandle) {
145 if (isa<TransformParamTypeInterface>(Val: packedHandle.getType())) {
146 ArrayRef<Attribute> params = state.getParams(value: packedHandle);
147 for (auto param : params) {
148 if (!isa<IntegerAttr>(Val: param))
149 return transformOp.emitDefiniteFailure()
150 << "expected the parameter to be associated with an integer "
151 "attribute";
152 result.push_back(Elt: param);
153 }
154 return DiagnosedSilenceableFailure::success();
155 }
156
157 for (Operation *op : state.getPayloadOps(value: packedHandle)) {
158 if (op->getNumResults() != 1 || !op->getResult(idx: 0).getType().isIndex()) {
159 DiagnosedSilenceableFailure diag =
160 transformOp.emitSilenceableError()
161 << "payload op must have exactly 1 index result";
162 diag.attachNote(loc: op->getLoc())
163 << "has " << op->getNumResults() << " results";
164 return diag;
165 }
166 result.push_back(Elt: op->getResult(idx: 0));
167 }
168
169 return DiagnosedSilenceableFailure::success();
170}
171
172/// When possible, converts each `OpFoldResult` in `mixedResult` to
173/// an integer if the value can be statically inferred. If a result
174/// is a `Value` then it must be either a `ParamType` or a handle
175/// to an a constant like op.
176static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
177 TransformState &state, TransformOpInterface &transformOp,
178 ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
179 for (OpFoldResult paramOrHandle : mixedResults) {
180 if (auto attr = dyn_cast<Attribute>(Val&: paramOrHandle)) {
181 reified.push_back(Elt: cast<IntegerAttr>(Val&: attr).getInt());
182 continue;
183 } else if (isa<ParamType>(Val: cast<Value>(Val&: paramOrHandle).getType())) {
184 ArrayRef<Attribute> params = state.getParams(value: cast<Value>(Val&: paramOrHandle));
185 if (params.size() != 1)
186 return transformOp.emitSilenceableError() << "expected a single param";
187 reified.push_back(
188 Elt: cast<IntegerAttr>(Val: params.front()).getValue().getSExtValue());
189 continue;
190 }
191
192 Value handle = cast<Value>(Val&: paramOrHandle);
193 if (!isa<TransformHandleTypeInterface>(Val: handle.getType()))
194 return transformOp.emitSilenceableError() << "unexpected value handle";
195 auto payload = state.getPayloadOps(value: handle);
196 if (!llvm::hasSingleElement(C&: payload))
197 return transformOp.emitSilenceableError()
198 << "requires param or handle that is mapped to 1 payload op";
199
200 Operation *paramOrHandlePayloadOp = *payload.begin();
201 if (paramOrHandlePayloadOp->getNumResults() != 1 ||
202 !paramOrHandlePayloadOp->getResult(idx: 0).getType().isIndex()) {
203 return transformOp.emitSilenceableError()
204 << "requires param or handle to be result of op with 1 index "
205 "result";
206 }
207
208 IntegerAttr attr;
209 if (!matchPattern(value: paramOrHandlePayloadOp->getResult(idx: 0), pattern: m_Constant(bind_value: &attr)))
210 return transformOp.emitSilenceableError()
211 << "requires param or handle to be the result of a constant like "
212 "op";
213
214 reified.push_back(Elt: attr.getInt());
215 }
216 return DiagnosedSilenceableFailure::success();
217}
218
219//===----------------------------------------------------------------------===//
220// Apply...PatternsOp
221//===----------------------------------------------------------------------===//
222
223void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
224 RewritePatternSet &patterns) {
225 linalg::populateEraseUnnecessaryInputsPatterns(patterns);
226}
227
228void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
229 RewritePatternSet &patterns) {
230 linalg::populateDecomposePackUnpackPatterns(patterns);
231}
232
233void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
234 RewritePatternSet &patterns) {
235 linalg::populateDecomposePadPatterns(patterns);
236}
237
238void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
239 RewritePatternSet &patterns) {
240 linalg::ControlDropUnitDims options;
241 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
242}
243
244void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
245 RewritePatternSet &patterns) {
246 linalg::ControlDropUnitDims options;
247 options.rankReductionStrategy =
248 linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
249 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
250}
251
252void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
253 RewritePatternSet &patterns) {
254 linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
255}
256
257void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
258 RewritePatternSet &patterns) {
259 linalg::populateFoldAddIntoDestPatterns(patterns);
260}
261
262void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
263 RewritePatternSet &patterns) {
264 linalg::populatePadOpVectorizationPatterns(patterns);
265}
266
267void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
268 RewritePatternSet &patterns) {
269 linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
270}
271
272void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
273 RewritePatternSet &patterns) {
274 linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
275}
276
277//===----------------------------------------------------------------------===//
278// BufferizeToAllocationOp
279//===----------------------------------------------------------------------===//
280
281void transform::BufferizeToAllocationOp::build(OpBuilder &b,
282 OperationState &result,
283 Value target,
284 Attribute memorySpace) {
285 SmallVector<Type> resultTypes;
286 resultTypes.push_back(Elt: b.getType<transform::AnyValueType>());
287 resultTypes.push_back(Elt: b.getType<transform::AnyOpType>());
288 return build(odsBuilder&: b, odsState&: result,
289 /*resultTypes=*/resultTypes,
290 /*target=*/target,
291 /*memorySpace=*/memory_space: memorySpace);
292}
293
294void transform::BufferizeToAllocationOp::build(OpBuilder &b,
295 OperationState &result,
296 Value target,
297 int64_t memorySpace) {
298 SmallVector<Type> resultTypes;
299 resultTypes.push_back(Elt: b.getType<transform::AnyValueType>());
300 resultTypes.push_back(Elt: b.getType<transform::AnyOpType>());
301 return build(odsBuilder&: b, odsState&: result,
302 /*resultTypes=*/resultTypes,
303 /*target=*/target,
304 /*memorySpace=*/memory_space: b.getI64IntegerAttr(value: memorySpace));
305}
306
307namespace {
308class NewOpsListener : public RewriterBase::ForwardingListener {
309public:
310 using RewriterBase::ForwardingListener::ForwardingListener;
311
312 SmallVector<Operation *> getNewOps() const {
313 return SmallVector<Operation *>(newOps.begin(), newOps.end());
314 }
315
316private:
317 void notifyOperationInserted(Operation *op,
318 OpBuilder::InsertPoint previous) override {
319 ForwardingListener::notifyOperationInserted(op, previous);
320 // We only care about newly created ops.
321 if (previous.isSet())
322 return;
323 auto inserted = newOps.insert(V: op);
324 (void)inserted;
325 assert(inserted.second && "expected newly created op");
326 }
327
328 void notifyOperationErased(Operation *op) override {
329 ForwardingListener::notifyOperationErased(op);
330 op->walk(callback: [&](Operation *op) { newOps.erase(V: op); });
331 }
332
333 DenseSet<Operation *> newOps;
334};
335} // namespace
336
337DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
338 transform::TransformRewriter &rewriter,
339 transform::TransformResults &results, transform::TransformState &state) {
340 // Attach listener to keep track of newly created ops.
341 OpBuilder::Listener *previousListener = rewriter.getListener();
342 auto resetListener =
343 llvm::make_scope_exit(F: [&]() { rewriter.setListener(previousListener); });
344 NewOpsListener newOpsListener(previousListener);
345 rewriter.setListener(&newOpsListener);
346
347 linalg::BufferizeToAllocationOptions options;
348 if (getMemcpyOp() == "bufferization.materialize_in_destination") {
349 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
350 MaterializeInDestination;
351 } else if (getMemcpyOp() == "memref.copy") {
352 options.memcpyOp =
353 linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
354 } else if (getMemcpyOp() == "linalg.copy") {
355 options.memcpyOp =
356 linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
357 } else {
358 llvm_unreachable("invalid memcpy op");
359 }
360 if (getAllocOp() == "memref.alloc") {
361 options.allocOp =
362 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
363 } else if (getAllocOp() == "memref.alloca") {
364 options.allocOp =
365 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
366 } else {
367 llvm_unreachable("invalid alloc op");
368 }
369 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
370 options.emitDealloc = getEmitDealloc();
371
372 // Bufferize ops.
373 Attribute memorySpace =
374 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
375 SmallVector<Value> allocatedBuffers;
376 for (Operation *op : state.getPayloadOps(value: getTarget())) {
377 Value buffer =
378 linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
379 if (!buffer) {
380 DiagnosedSilenceableFailure diag = emitSilenceableError()
381 << "failed to bufferize operation";
382 diag.attachNote(loc: op->getLoc()) << "target payload op";
383 return diag;
384 }
385 allocatedBuffers.push_back(Elt: buffer);
386 }
387
388 // Set results.
389 results.setValues(handle: cast<OpResult>(Val: getAllocatedBuffer()), values&: allocatedBuffers);
390 results.set(value: cast<OpResult>(Val: getNewOps()), ops: newOpsListener.getNewOps());
391 return DiagnosedSilenceableFailure::success();
392}
393
394void transform::BufferizeToAllocationOp::getEffects(
395 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
396 if (getBufferizeDestinationOnly()) {
397 // The destination is replaced with a newly allocated buffer, but the op
398 // itself remains in place.
399 onlyReadsHandle(handles: getTargetMutable(), effects);
400 } else {
401 consumesHandle(handles: getTargetMutable(), effects);
402 }
403 producesHandle(handles: getOperation()->getOpResults(), effects);
404 modifiesPayload(effects);
405}
406
407LogicalResult transform::BufferizeToAllocationOp::verify() {
408 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
409 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
410 return emitOpError() << "unsupported memcpy op";
411 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
412 return emitOpError() << "unsupported alloc op";
413 return success();
414}
415
416//===----------------------------------------------------------------------===//
417// DecomposeOp
418//===----------------------------------------------------------------------===//
419
420DiagnosedSilenceableFailure
421transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
422 LinalgOp target,
423 transform::ApplyToEachResultList &results,
424 transform::TransformState &state) {
425#define DOWNSCALE(trans) \
426 { \
427 FailureOr<LinalgOp> res = tryApply<trans>(target); \
428 if (succeeded(res)) { \
429 results.push_back(*res); \
430 return DiagnosedSilenceableFailure::success(); \
431 } \
432 }
433
434#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
435#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
436
437 DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
438 DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
439 DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
440 DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
441 DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
442 DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
443 DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
444 DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
445 DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
446 DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
447 DOWNSCALE(DownscaleConv2DOp)
448#undef DOWNSCALE_NORMAL
449#undef DOWNSCALE_CALL
450#undef DOWNSCALE
451 return emitDefaultSilenceableFailure(target);
452}
453
454//===----------------------------------------------------------------------===//
455// DecomposeInterfaceOp
456//===----------------------------------------------------------------------===//
457
458// Decompose the target operation if it implements the AggregatedOpInterface.
459// Push the decomposed operations (the ones that replaces the values produced by
460// \p target) in the `results`.
461DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
462 transform::TransformRewriter &rewriter, Operation *target,
463 transform::ApplyToEachResultList &results,
464 transform::TransformState &state) {
465 auto decomposableOp = dyn_cast<AggregatedOpInterface>(Val: target);
466 if (!decomposableOp) {
467 failed(Result: rewriter.notifyMatchFailure(arg&: target,
468 msg: "payload is not a decomposable op"));
469 return emitDefaultSilenceableFailure(target);
470 }
471
472 FailureOr<SmallVector<Value>> maybeNewResults =
473 decomposableOp.decomposeOperation(b&: rewriter);
474 if (failed(Result: maybeNewResults))
475 return emitDefaultSilenceableFailure(target);
476
477 rewriter.replaceOp(op: decomposableOp, newValues: *maybeNewResults);
478 for (Value val : *maybeNewResults) {
479 Operation *definition = val.getDefiningOp();
480 if (definition)
481 results.push_back(op: definition);
482 }
483 return DiagnosedSilenceableFailure::success();
484}
485
486//===----------------------------------------------------------------------===//
487// EliminateLinalgOpAnchoredEmptyTensorsOp
488//===----------------------------------------------------------------------===//
489
490void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
491 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
492 onlyReadsHandle(handles: getTargetMutable(), effects);
493 modifiesPayload(effects);
494}
495
496DiagnosedSilenceableFailure
497transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
498 transform::TransformRewriter &rewriter, TransformResults &transformResults,
499 TransformState &state) {
500 bufferization::OneShotBufferizationOptions options;
501 options.allowReturnAllocsFromLoops = true;
502
503 for (Operation *target : state.getPayloadOps(value: getTarget())) {
504 bufferization::OneShotAnalysisState state(target, options);
505 if (failed(Result: analyzeOp(op: target, state)))
506 return mlir::emitSilenceableFailure(loc: target->getLoc())
507 << "failed to analyze op";
508 if (failed(Result: linalg::linalgOpAnchoredEmptyTensorEliminationStep(
509 rewriter, op: target, state)))
510 return mlir::emitSilenceableFailure(loc: target->getLoc())
511 << "failed to eliminate LinalgOp anchored tensor.empty ops";
512 }
513 return DiagnosedSilenceableFailure::success();
514}
515
516//===----------------------------------------------------------------------===//
517// FuseOp
518//===----------------------------------------------------------------------===//
519
520/// Apply a tiling transformation to all payload ops and store both the
521/// tiled operation as well as the created tile loops.
522template <typename Range>
523static LogicalResult applyTilingToAll(
524 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
525 unsigned numLoops, transform::TransformResults &transformResults,
526 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
527 applyFn) {
528 SmallVector<Operation *> tiledLinalgOps;
529 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
530
531 for (Operation *target : payloadOps) {
532 auto tilingInterfaceOp = dyn_cast<TilingInterface>(Val: target);
533 if (!tilingInterfaceOp)
534 return transformOp->emitError(message: "only TilingInterface ops are supported");
535
536 rewriter.setInsertionPoint(target);
537 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
538 applyFn(tilingInterfaceOp);
539 if (failed(Result: tiledResults))
540 return failure();
541
542 // Perform the replacement of tiled and fused values.
543 SmallVector<Operation *> opsToReplace{target};
544 llvm::append_range(C&: opsToReplace, R&: tiledResults->fusedProducers);
545 for (Operation *toReplace : opsToReplace) {
546 for (OpResult res : toReplace->getResults())
547 if (auto replacement = tiledResults->replacements.lookup(Val: res))
548 rewriter.replaceAllUsesWith(from: res, to: replacement);
549 if (toReplace->use_empty()) {
550 rewriter.eraseOp(op: toReplace);
551 }
552 }
553
554 // Report back the relevant handles to the transform op.
555 tiledLinalgOps.push_back(Elt: tiledResults->tiledAndFusedOps.front());
556 assert(tiledResults->loops.size() == numLoops &&
557 "Mismatched number of loops, tile and fuse transform should have "
558 "failed");
559 for (unsigned int i = 0; i < numLoops; ++i)
560 loopOps[i].push_back(Elt: tiledResults->loops[i]);
561 }
562
563 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledLinalgOps);
564 for (unsigned int i = 0; i < numLoops; ++i)
565 transformResults.set(value: transformOp->getOpResult(idx: i + 1), ops&: loopOps[i]);
566
567 return success();
568}
569
570DiagnosedSilenceableFailure
571transform::FuseOp::apply(transform::TransformRewriter &rewriter,
572 mlir::transform::TransformResults &transformResults,
573 mlir::transform::TransformState &state) {
574 SmallVector<int64_t> tileSizes =
575 extractFromIntegerArrayAttr<int64_t>(attr: getTileSizes());
576 SmallVector<int64_t> tileInterchange =
577 extractFromIntegerArrayAttr<int64_t>(attr: getTileInterchange());
578
579 scf::SCFTilingOptions tilingOptions;
580 tilingOptions.interchangeVector = tileInterchange;
581 SmallVector<OpFoldResult> tileSizesOfr =
582 getAsIndexOpFoldResult(ctx: rewriter.getContext(), values: tileSizes);
583 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
584 scf::SCFTileAndFuseOptions tileAndFuseOptions;
585 tileAndFuseOptions.tilingOptions = tilingOptions;
586
587 if (getApplyCleanup()) {
588 MLIRContext *context = rewriter.getContext();
589 RewritePatternSet patterns(context);
590 tensor::ExtractSliceOp::getCanonicalizationPatterns(results&: patterns, context);
591 tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
592 tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
593 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
594 }
595
596 LogicalResult result = applyTilingToAll(
597 rewriter, transformOp: getOperation(), payloadOps: state.getPayloadOps(value: getTarget()),
598 numLoops: tileSizes.size() - llvm::count(Range&: tileSizes, Element: 0), transformResults,
599 applyFn: [&](TilingInterface tilingInterfaceOp)
600 -> FailureOr<scf::SCFTileAndFuseResult> {
601 return tileConsumerAndFuseProducersUsingSCF(rewriter, consumer: tilingInterfaceOp,
602 options: tileAndFuseOptions);
603 });
604 return failed(Result: result) ? DiagnosedSilenceableFailure::definiteFailure()
605 : DiagnosedSilenceableFailure::success();
606}
607
608LogicalResult transform::FuseOp::verify() {
609 SmallVector<int64_t> permutation =
610 extractFromIntegerArrayAttr<int64_t>(attr: getTileInterchange());
611 auto sequence = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: permutation.size()));
612 if (!std::is_permutation(first1: sequence.begin(), last1: sequence.end(),
613 first2: permutation.begin(), last2: permutation.end())) {
614 return emitOpError() << "expects interchange to be a permutation, found "
615 << getTileInterchange();
616 }
617
618 SmallVector<int64_t> sizes =
619 extractFromIntegerArrayAttr<int64_t>(attr: getTileSizes());
620 size_t numExpectedLoops = sizes.size() - llvm::count(Range&: sizes, Element: 0);
621 if (numExpectedLoops != getNumResults() - 1)
622 return emitOpError() << "expects " << numExpectedLoops << " loop results";
623
624 return success();
625}
626
627//===----------------------------------------------------------------------===//
628// FuseIntoContainingOp
629//===----------------------------------------------------------------------===//
630
631void transform::FuseIntoContainingOp::build(OpBuilder &builder,
632 OperationState &result,
633 Value producerOp,
634 Value containingOp) {
635 result.addOperands(newOperands: {producerOp, containingOp});
636 auto resultType = transform::AnyOpType::get(ctx: builder.getContext());
637 result.addTypes(newTypes: {resultType, resultType});
638}
639
640/// Add new operands to the forall op for users of the producerOp
641/// that are dominated by the containing scf.forall op.
642static Operation *replaceForAllWithNewSignature(
643 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
644 Operation *containingOp, TilingResult &tileAndFuseResult,
645 int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
646 SmallVector<OpFoldResult> &sizes) {
647
648 // Count number of users not including the containing op
649 SetVector<Operation *> dominatedUsers;
650 DominanceInfo domInfo(containingOp);
651 for (Operation *user : producerOp->getResult(idx: resultNumber).getUsers()) {
652 if (!containingOp->isAncestor(other: user) &&
653 (domInfo.dominates(a: containingOp, b: user))) {
654 dominatedUsers.insert(X: user);
655 }
656 }
657 if (dominatedUsers.empty())
658 return nullptr;
659
660 // Create new scf.forall op
661 auto forallOp = cast<scf::ForallOp>(Val: containingOp);
662 OpBuilder::InsertionGuard g(rewriter);
663 rewriter.setInsertionPoint(forallOp);
664
665 // Get new output
666 Location loc = forallOp.getLoc();
667 auto genericOp = dyn_cast<linalg::GenericOp>(Val: producerOp);
668 if (!genericOp)
669 return nullptr;
670 SmallVector<Value> outputs = genericOp.getOutputs();
671 SmallVector<Value> newOuts(forallOp.getOutputs());
672 newOuts.push_back(Elt: outputs[resultNumber]);
673
674 // Create new scf.forall op
675 auto newforallOp = rewriter.create<scf::ForallOp>(
676 location: loc, args: forallOp.getMixedLowerBound(), args: forallOp.getMixedUpperBound(),
677 args: forallOp.getMixedStep(), args&: newOuts, args: forallOp.getMapping());
678 rewriter.eraseBlock(block: newforallOp.getBody());
679 newforallOp.getRegion().takeBody(other&: forallOp.getRegion());
680
681 // Add additional block argument for new value being returned
682 // and replaces all uses of the new output with corresponding bbArg
683 // inside the scf.forall to enable fusion into this new scf.forall.
684 newforallOp.getBody()->addArgument(type: newOuts.back().getType(),
685 loc: newOuts.back().getLoc());
686 auto bbArgs = newforallOp.getBody()->getArguments();
687 rewriter.replaceUsesWithIf(from: newOuts.back(), to: bbArgs.back(),
688 functor: [&](OpOperand &use) {
689 Operation *op = use.getOwner();
690 return newforallOp->isProperAncestor(other: op);
691 });
692
693 // Fix terminator
694 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
695 SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(Range: llvm::map_range(
696 C: terminatorOp.getYieldingOps(), F: [](Operation &op) { return &op; }));
697 Operation *firstYieldOp = yieldingOps.front();
698 rewriter.setInsertionPoint(firstYieldOp);
699 Value src = tileAndFuseResult.tiledValues[0];
700 Value dst = newforallOp.getRegionIterArgs().back();
701 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(value: 1));
702 rewriter.create<tensor::ParallelInsertSliceOp>(location: firstYieldOp->getLoc(), args&: src,
703 args&: dst, args&: offsets, args&: sizes, args&: strides);
704
705 for (auto result : llvm::enumerate(First: forallOp.getResults())) {
706 rewriter.replaceAllUsesWith(from: result.value(),
707 to: newforallOp->getResult(idx: result.index()));
708 }
709 rewriter.replaceUsesWithIf(from: producerOp->getResult(idx: resultNumber),
710 to: newforallOp->getResults().back(),
711 functor: [&](OpOperand &use) {
712 Operation *user = use.getOwner();
713 return dominatedUsers.contains(key: user);
714 });
715 return newforallOp;
716}
717
718/// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
719/// if the operand 'src' is equal to 'dst' or equal to a iter arg present in a
720/// outer loop. To determine the second condition, this function iterates
721/// using a worklist over the enclosing loops, trying to find 'src' in any of
722/// the parent loop's iter args.
723static bool sameOrEquivalentIterArg(Value src, Value dst) {
724 // Stack like vector containing possible iterArgs candidates. The first one
725 // is dst, and we will transverse the IR from there.
726 SmallVector<Value> destWorklist;
727 destWorklist.push_back(Elt: dst);
728
729 while (!destWorklist.empty()) {
730 Value currentDst = destWorklist.pop_back_val();
731
732 // We have found the same operand in some iter arg in the loop structure,
733 // so src and dst are equivalent.
734 if (src == currentDst)
735 return true;
736
737 // The operands are not equivalent, look for enclosing loops over
738 // currentDst.
739 auto bbArg = dyn_cast<BlockArgument>(Val&: currentDst);
740 if (!bbArg)
741 continue;
742
743 Block *parentBlock = bbArg.getOwner();
744 assert(parentBlock && "unlinked block argument");
745
746 Operation *parentOp = parentBlock->getParentOp();
747 assert(parentOp && "expected block argument with parent operation");
748
749 // Check if parent is loop-like. If it's not, do not add it to the worklist.
750 auto parentLoop = dyn_cast<LoopLikeOpInterface>(Val: parentOp);
751 if (!parentLoop)
752 continue;
753
754 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
755 // No need to check for null as innerIterArg is tied to parentLoop.
756 OpOperand *operand = parentLoop.getTiedLoopInit(bbArg: innerIterArg);
757 Value loopBlockArgument =
758 parentLoop->getOperand(idx: operand->getOperandNumber());
759 destWorklist.push_back(Elt: loopBlockArgument);
760 }
761 }
762
763 return false;
764}
765
/// Find the first "extract" (tensor.extract_slice) user of `producerOp` inside
/// `containingOp` and tile the producer right before that use; the tiled op is
/// thereby fused under `containingOp`.
/// Returns the list of tiled ops and, when other uses of the producer are
/// dominated by `containingOp`, a new `containingOp` with the fused op's
/// results appended to its own (nullptr otherwise). Returns an empty tuple on
/// failure; failure details are attached as notes to `diag`.
static std::tuple<SmallVector<Operation *>, Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
                           Operation *producerOp, Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
  // Tiling is driven through the TilingInterface; a producer that does not
  // implement it cannot be fused this way.
  auto tileableProducer = dyn_cast<TilingInterface>(Val: producerOp);
  if (!tileableProducer) {
    diag.attachNote(noteLoc: producerOp->getLoc())
        << "producer is not a TileableInterface: " << *producerOp;
    return {};
  }

  // Search the producer slices accessed within the containing operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
  // evolve into an interface.
  auto it = llvm::find_if(Range: tileableProducer->getUsers(), P: [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(Val: user);
    return sliceOp && containingOp->isProperAncestor(other: sliceOp);
  });

  // Find a fusion opportunity.
  if (it == tileableProducer->getUsers().end()) {
    diag.attachNote(noteLoc: tileableProducer->getLoc())
        << "could not find fusion opportunity for: " << *tileableProducer;
    return {};
  }
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(Val: *it);

  // Try to fuse the producer in-place.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Clone the producer inside the consumer and try to update the producer init
  // operands using the loop bbArgs if applicable. More precisely, if the bbArg
  // of the container loop points to a value that it is used by the consumer op,
  // then, instead of using such value on the consumer, use the value coming
  // from the bbArg instead. This allows to reuse the output tensor (instead of
  // creating a new one) of the container when both producer and container write
  // to the same output.
  if (LoopLikeOpInterface containerLoop =
          dyn_cast<LoopLikeOpInterface>(Val: sliceOpToTile->getParentOp())) {
    Operation *clone = rewriter.clone(op&: *producerOp);
    rewriter.modifyOpInPlace(root: clone, callable: [&]() {
      // Iterate over the outputs of the producer and over the loop bbArgs and
      // check if any bbArg points to the same value as the producer output. In
      // such case, make the producer output point to the bbArg directly.
      for (OpOperand &initOperandPtr :
           cast<DestinationStyleOpInterface>(Val: clone).getDpsInitsMutable()) {
        Value producerOperand =
            clone->getOperand(idx: initOperandPtr.getOperandNumber());
        for (BlockArgument containerIterArg :
             containerLoop.getRegionIterArgs()) {
          OpOperand *bbArg = containerLoop.getTiedLoopInit(bbArg: containerIterArg);
          Value consumerOperand =
              containerLoop->getOperand(idx: bbArg->getOperandNumber());
          // The producer has the same init as the loop bbArg, use it.
          if (sameOrEquivalentIterArg(src: producerOperand, dst: consumerOperand)) {
            initOperandPtr.set(containerIterArg);
          }
        }
      }
    });

    // From here on, tile the clone rather than the original producer.
    tileableProducer = dyn_cast<TilingInterface>(Val: clone);
  }

  // Tile the producer. The tiled result replaces exactly the region of the
  // producer that the extract_slice reads.
  int64_t resultNumber =
      cast<OpResult>(Val: sliceOpToTile.getSource()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();

  FailureOr<TilingResult> tileAndFuseResult =
      tileableProducer.generateResultTileValue(b&: rewriter, resultNumber, offsets,
                                               sizes);

  if (failed(Result: tileAndFuseResult)) {
    diag.attachNote(noteLoc: tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;
    return {};
  }

#ifndef NDEBUG
  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
    LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
  }
#endif

  // Replace the extract op, inserting a rank-reducing slice if the tiled
  // value's rank does not match the slice's result type.
  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      b&: rewriter, loc: sliceOpToTile->getLoc(), value: tileAndFuseResult->tiledValues[0],
      desiredShape: cast<RankedTensorType>(Val: sliceOpToTile->getResult(idx: 0).getType()).getShape());
  if (failed(Result: maybeRankReduced)) {
    diag.attachNote(noteLoc: producerOp->getLoc())
        << "shape types don't match (missing canonicalization?):\nTiledOp: "
        << tileAndFuseResult->tiledValues[0]
        << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
    return {};
  }
  rewriter.replaceOp(op: sliceOpToTile, newValues: *maybeRankReduced);

  // Add new outputs to containing op, if required
  Operation *newContainingOp = replaceForAllWithNewSignature(
      rewriter, diag, producerOp, containingOp, tileAndFuseResult&: *tileAndFuseResult,
      resultNumber, offsets, sizes);

  // Cleanup clone. A clone was only created in the loop-like branch above; in
  // that case `tileableProducer` refers to the now-dead clone. Otherwise it
  // still refers to the original producer, which must stay.
  if (dyn_cast<LoopLikeOpInterface>(Val: containingOp))
    rewriter.eraseOp(op: tileableProducer);

  return std::make_tuple(args&: tileAndFuseResult->tiledOps, args&: newContainingOp);
}
885
/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
/// it is exactly the `containingOp`, otherwise bail.
/// Then, find the first "extract" user of the tied block argument and tile it
/// right before its "extract" use. The tiled op is fused under the
/// `containingOp`.
/// Return the tiled ops on success or an empty vector if anything fails;
/// failure details are attached as notes to `diag`.
static SmallVector<Operation *>
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
    Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");

  // Tiling is driven through the TilingInterface; a producer that does not
  // implement it cannot be fused this way.
  auto tileableProducer = dyn_cast<TilingInterface>(Val: producerOp);
  if (!tileableProducer) {
    diag.attachNote(noteLoc: producerOp->getLoc())
        << "producer is not a TileableInterface: " << *producerOp;
    return {};
  }

  // Search the first use by a "scf::ForallOp" user.
  scf::ForallOp forallOp;
  auto itProducerUses =
      llvm::find_if(Range: tileableProducer->getUses(), P: [&](OpOperand &use) {
        forallOp = dyn_cast<scf::ForallOp>(Val: use.getOwner());
        return forallOp;
      });
  // If it's not from the containing op, return.
  if (!forallOp || forallOp != containingOp) {
    diag.attachNote(noteLoc: tileableProducer->getLoc())
        << "could not find a use by the containing op: " << *tileableProducer;
    return {};
  }

  // The fusion happens through the block argument of the forall that is tied
  // to the producer-use operand.
  OpOperand *pUse = &(*itProducerUses);
  BlockArgument bbArg = forallOp.getTiedBlockArgument(opOperand: pUse);

  // Search the producer slices accessed within the containing operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
  // evolve into an interface.
  auto itBBArgUsers = llvm::find_if(Range: bbArg.getUsers(), P: [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(Val: user);
    return sliceOp && containingOp->isProperAncestor(other: sliceOp);
  });

  // Find a fusion opportunity.
  if (itBBArgUsers == bbArg.getUsers().end()) {
    diag.attachNote(noteLoc: containingOp->getLoc())
        << "could not find fusion opportunity for bbArg: " << bbArg;
    return {};
  }
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(Val: *itBBArgUsers);

  // Try to fuse the producer in-place.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Replace the use in the tileableProducer before tiling: clone, replace and
  // then tile.
  int64_t resultNumber = cast<OpResult>(Val: pUse->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  // Gather destination tensors.
  SmallVector<Value> destinationTensors;
  if (failed(Result: tensor::getOrCreateDestinations(
          b&: rewriter, loc: tileableProducer->getLoc(), op: tileableProducer,
          result&: destinationTensors))) {
    diag.attachNote(noteLoc: tileableProducer->getLoc())
        << "failed to get destination tensors for: " << *tileableProducer;
    return {};
  }

  // Clone the producer with its destination remapped to the forall bbArg so
  // the tiled implementation computes into the shared output.
  IRMapping bvm;
  bvm.map(from: destinationTensors[resultNumber], to: bbArg);
  auto tileableProducerClone =
      cast<TilingInterface>(Val: rewriter.clone(op&: *tileableProducer, mapper&: bvm));
  // The clone only serves as a tiling template; erase it on every exit path.
  auto scopeGuard =
      llvm::make_scope_exit(F: [&]() { rewriter.eraseOp(op: tileableProducerClone); });

  // Tile the producer.
  FailureOr<TilingResult> tileAndFuseResult =
      tileableProducerClone.generateResultTileValue(
          b&: rewriter, resultNumber, offsets: sliceOpToTile.getMixedOffsets(),
          sizes: sliceOpToTile.getMixedSizes());
  if (failed(Result: tileAndFuseResult)) {
    diag.attachNote(noteLoc: tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;
    return {};
  }

  // Replace the extract op, inserting a rank-reducing slice if needed.
  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      b&: rewriter, loc: sliceOpToTile->getLoc(), value: tileAndFuseResult->tiledValues[0],
      desiredShape: cast<RankedTensorType>(Val: sliceOpToTile->getResult(idx: 0).getType()).getShape());
  assert(succeeded(maybeRankReduced) && "unexpected shape");
  rewriter.replaceOp(op: sliceOpToTile, newValues: *maybeRankReduced);

  // Replace the use in containingOp.
  rewriter.modifyOpInPlace(root: containingOp, callable: [&]() {
    containingOp->setOperand(idx: pUse->getOperandNumber(),
                             value: destinationTensors.front());
  });

  return tileAndFuseResult->tiledOps;
}
994
995static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
996 Operation *producerOp,
997 Operation *containingOp) {
998 LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
999
1000 // Gather all uses inside the containing op.
1001 SmallVector<OpOperand *> uses;
1002 for (OpResult result : producerOp->getOpResults()) {
1003 for (OpOperand &use : result.getUses()) {
1004 if (containingOp->isProperAncestor(other: use.getOwner())) {
1005 uses.push_back(Elt: &use);
1006 continue;
1007 }
1008 // Cannot clone and fuse if the use is by the containing op itself: fail
1009 // immediately.
1010 if (containingOp == use.getOwner()) {
1011 diag.attachNote(noteLoc: producerOp->getLoc())
1012 << "producer op use by containing op cannot be fused by cloning";
1013 return nullptr;
1014 }
1015 }
1016 }
1017
1018 // Check for a non-empty list of fusion opportunities.
1019 if (uses.empty()) {
1020 diag.attachNote(noteLoc: producerOp->getLoc()) << "no fusion opportunity by cloning";
1021 return nullptr;
1022 }
1023
1024 // Clone and fuse inside the containing op.
1025 Operation *fusedOp = nullptr;
1026 OpOperand *use = uses.front();
1027 // Parallel insert slice is not a valid clone destination.
1028 // TODO: Generalize to other type of ops.
1029 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1030 "Parallel insert slice is not a valid clone destination");
1031 unsigned resultNumber = cast<OpResult>(Val: use->get()).getResultNumber();
1032 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
1033
1034 OpBuilder::InsertionGuard guard(rewriter);
1035 rewriter.setInsertionPoint(use->getOwner());
1036 fusedOp = rewriter.clone(op&: *producerOp);
1037 rewriter.modifyOpInPlace(
1038 root: use->getOwner(), callable: [&] { use->set(fusedOp->getOpResult(idx: resultNumber)); });
1039
1040 return fusedOp;
1041}
1042
/// Repeated producer handles are acceptable for this op: every producer with a
/// use inside the containing op gets fused regardless, so duplicates simply
/// resolve to the same payload op.
bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
  // Allow repeated handles since we are fusing everything anyway.
  return true;
}
1047
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  SmallVector<Operation *> fusedOps;
  auto producerOps = state.getPayloadOps(value: getProducerOp());
  auto containingOps = state.getPayloadOps(value: getContainingOp());
  // Fusion happens into exactly one destination payload op.
  if (!llvm::hasSingleElement(C&: containingOps)) {
    return emitDefiniteFailure()
           << "requires exactly one containing_op handle (got "
           << llvm::range_size(Range&: containingOps) << ")";
  }
  Operation *containingOp = *containingOps.begin();

  // If nothing to fuse, propagate success.
  if (std::empty(cont: producerOps)) {
    results.set(value: cast<OpResult>(Val: getFusedOp()), ops: SmallVector<mlir::Operation *>{});
    results.set(value: cast<OpResult>(Val: getNewContainingOp()), ops: {containingOp});
    return DiagnosedSilenceableFailure::success();
  }

  // Helper function to find the next producer that should be fused. Take any
  // producer that has a use inside the containing op.
  SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
  auto getNextProducer = [&]() -> FailureOr<Operation *> {
    for (const auto &it : enumerate(First&: remainingProducers)) {
      Operation *producerOp = it.value();
      // The containing op may be a user of producerOp: use isAncestor.
      int64_t numUsesInContainingOp =
          llvm::count_if(Range: producerOp->getUsers(), P: [&](Operation *op) {
            return containingOp->isAncestor(other: op);
          });
      // TODO: When resolving the TODO below (no duplicate ops), take an op
      // that has no use among the remaining producers. This is a topological
      // sorting.
      if (numUsesInContainingOp > 0) {
        // Only drop the producer from the worklist once its last remaining
        // use inside the containing op is about to be fused.
        if (numUsesInContainingOp == 1)
          remainingProducers.erase(I: remainingProducers.begin() + it.index());
        return producerOp;
      }
    }
    return failure();
  };

  while (!remainingProducers.empty()) {
    auto nextProducer = getNextProducer();
    if (failed(Result: nextProducer)) {
      auto diag = mlir::emitSilenceableFailure(loc: getLoc())
                  << "could not find next producer to fuse into container";
      diag.attachNote(loc: containingOp->getLoc()) << "containing op";
      return diag;
    }

    Operation *producerOp = *nextProducer;

    // Default diagnostic, to be complemented with more failure information.
    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
    diag << "could not fuse " << *producerOp << " into " << *containingOp;

    // TODO: If there are multiple uses of the producer in the containing op,
    // we currently tile/clone the op multiple times (once per use). In some
    // cases, we can tile/clone once and reuse the value for each use.
    // Furthermore, producers should then be traversed according to a
    // topological sorting.
    //
    // Fusion strategies are tried in decreasing order of preference:
    // (1) tile a direct extract_slice use, (2) tile an extract_slice of the
    // forall block argument tied to the producer, (3) plainly clone the
    // producer next to its first use.
    auto [tiledOps, newContainingOp] =
        tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
    if (!tiledOps.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
      fusedOps.append(RHS: tiledOps);
      if (newContainingOp) {
        // Update handles associated with the containing op so we don't need to
        // invalidate them. This is a hack to support better composability
        // between tiling and fusion while a proper mechanism is being
        // investigated.
        //
        // DO NOT replicate this elsewhere unless you understand what you are
        // doing.
        LogicalResult replacementStatus =
            rewriter.notifyPayloadOperationReplaced(op: containingOp,
                                                    replacement: newContainingOp);
        (void)replacementStatus;
        assert(succeeded(replacementStatus) &&
               "unable to update transform state mapping");
        rewriter.eraseOp(op: containingOp);
        containingOp = newContainingOp;
      }
      continue;
    }

    SmallVector<Operation *> tiledContainingOpOperand =
        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
            rewriter, diag, producerOp, containingOp);
    if (!tiledContainingOpOperand.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
                        << *containingOp);
      fusedOps.append(RHS: tiledContainingOpOperand);
      continue;
    }

    Operation *cloned =
        cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
    if (cloned) {
      LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
      fusedOps.push_back(Elt: cloned);
      continue;
    }
    // All strategies failed for this producer; report the accumulated notes.
    return DiagnosedSilenceableFailure::silenceableFailure(diag: std::move(diag));
  }

  results.set(value: cast<OpResult>(Val: getFusedOp()), ops&: fusedOps);
  results.set(value: cast<OpResult>(Val: getNewContainingOp()), ops: {containingOp});
  return DiagnosedSilenceableFailure::success();
}
1161
void transform::FuseIntoContainingOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The producer handle is consumed: fused producers may be rewritten or
  // erased during apply(), invalidating the original handle.
  consumesHandle(handles: getProducerOpMutable(), effects);
  // The containing op handle is only read; when the containing op is replaced
  // during apply(), its mapping is updated in place instead of invalidated.
  onlyReadsHandle(handles: getContainingOpMutable(), effects);
  producesHandle(handles: getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
1169
1170//===----------------------------------------------------------------------===//
1171// GeneralizeOp
1172//===----------------------------------------------------------------------===//
1173
1174DiagnosedSilenceableFailure
1175transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1176 LinalgOp target,
1177 transform::ApplyToEachResultList &results,
1178 transform::TransformState &state) {
1179 // Exit early if no transformation is needed.
1180 if (isa<GenericOp>(Val: target)) {
1181 results.push_back(op: target);
1182 return DiagnosedSilenceableFailure::success();
1183 }
1184 rewriter.setInsertionPoint(target);
1185 FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, linalgOp: target);
1186 if (succeeded(Result: generic)) {
1187 results.push_back(op: generic->getOperation());
1188 return DiagnosedSilenceableFailure::success();
1189 }
1190 return emitDefaultSilenceableFailure(target);
1191}
1192
1193//===----------------------------------------------------------------------===//
1194// SpecializeOp
1195//===----------------------------------------------------------------------===/
1196
1197DiagnosedSilenceableFailure
1198transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1199 LinalgOp target,
1200 transform::ApplyToEachResultList &results,
1201 transform::TransformState &state) {
1202 // Exit early if the operation is not a generic.
1203 if (!isa<GenericOp>(Val: target)) {
1204 results.push_back(op: target);
1205 return DiagnosedSilenceableFailure::success();
1206 }
1207 rewriter.setInsertionPoint(target);
1208 FailureOr<LinalgOp> named =
1209 specializeGenericOp(rewriter, genericOp: cast<GenericOp>(Val&: target));
1210 if (succeeded(Result: named)) {
1211 results.push_back(op: named->getOperation());
1212 return DiagnosedSilenceableFailure::success();
1213 }
1214 return emitDefaultSilenceableFailure(target);
1215}
1216
1217//===----------------------------------------------------------------------===//
1218// InterchangeOp
1219//===----------------------------------------------------------------------===//
1220
1221DiagnosedSilenceableFailure
1222transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1223 GenericOp target,
1224 transform::ApplyToEachResultList &results,
1225 transform::TransformState &state) {
1226 ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1227 // Exit early if no transformation is needed.
1228 if (interchangeVector.empty()) {
1229 results.push_back(op: target);
1230 return DiagnosedSilenceableFailure::success();
1231 }
1232
1233 unsigned numLoops = cast<LinalgOp>(Val: target.getOperation()).getNumLoops();
1234 if (interchangeVector.size() != numLoops) {
1235 return emitSilenceableError()
1236 << getIteratorInterchangeAttrName() << " has length ("
1237 << interchangeVector.size()
1238 << ") different from the number of loops in the target operation ("
1239 << numLoops << ")";
1240 }
1241 FailureOr<GenericOp> res = interchangeGenericOp(
1242 rewriter, genericOp: target, interchangeVector: SmallVector<unsigned>(interchangeVector));
1243 if (failed(Result: res))
1244 return emitDefiniteFailure() << "failed to apply";
1245 results.push_back(op: res->getOperation());
1246 return DiagnosedSilenceableFailure::success();
1247}
1248
1249LogicalResult transform::InterchangeOp::verify() {
1250 ArrayRef<int64_t> permutation = getIteratorInterchange();
1251 auto sequence = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: permutation.size()));
1252 if (!std::is_permutation(first1: sequence.begin(), last1: sequence.end(),
1253 first2: permutation.begin(), last2: permutation.end())) {
1254 return emitOpError()
1255 << "expects iterator_interchange to be a permutation, found "
1256 << getIteratorInterchange();
1257 }
1258 return success();
1259}
1260
1261//===----------------------------------------------------------------------===//
1262// LinalgCopyToMemrefOp
1263//===----------------------------------------------------------------------===//
1264
1265DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1266 transform::TransformRewriter &rewriter, Operation *targetOp,
1267 transform::ApplyToEachResultList &results,
1268 transform::TransformState &state) {
1269
1270 // Check if the target can be converted.
1271 if (!isa<linalg::CopyOp>(Val: targetOp)) {
1272 DiagnosedSilenceableFailure diag =
1273 emitSilenceableError() << "only linalg.copy target ops are supported";
1274 diag.attachNote(loc: targetOp->getLoc()) << "target op";
1275 return diag;
1276 }
1277
1278 auto copyOp = dyn_cast<linalg::CopyOp>(Val: targetOp);
1279 if (!copyOp.hasPureBufferSemantics()) {
1280 DiagnosedSilenceableFailure diag =
1281 emitSilenceableError()
1282 << "cannot transform a linalg.copy on tensors into a memref.copy";
1283 diag.attachNote(loc: targetOp->getLoc()) << "target op";
1284 return diag;
1285 }
1286
1287 SmallVector<Value> inputs = copyOp.getInputs();
1288 SmallVector<Value> outputs = copyOp.getOutputs();
1289 assert(inputs.size() == 1 && "expected linalg copy op with one input");
1290 assert(outputs.size() == 1 && "expected memref copy op with one output");
1291 Value input = inputs.front();
1292 Value output = outputs.front();
1293
1294 // linalg.copy supports different element types on source/dest whereas
1295 // memref.copy does not, so we must check that the source and dest types can
1296 // be handled by memref.copy and otherwise reject the transformation.
1297 if (!isa<ShapedType>(Val: input.getType())) {
1298 DiagnosedSilenceableFailure diag =
1299 emitSilenceableError()
1300 << "cannot transform a linalg.copy which input has no shape";
1301 diag.attachNote(loc: targetOp->getLoc()) << "target op";
1302 return diag;
1303 }
1304
1305 // linalg.copy destination must be a shaped type.
1306 assert(isa<ShapedType>(output.getType()));
1307
1308 if (cast<ShapedType>(Val: input.getType()).getElementType() !=
1309 cast<ShapedType>(Val: output.getType()).getElementType()) {
1310 DiagnosedSilenceableFailure diag =
1311 emitSilenceableError()
1312 << "cannot transform a linalg.copy with different source and "
1313 "destination element types ";
1314 diag.attachNote(loc: targetOp->getLoc()) << "target op";
1315 return diag;
1316 }
1317
1318 // Target can be converted, do it.
1319 auto memrefCopyOp =
1320 rewriter.replaceOpWithNewOp<memref::CopyOp>(op: targetOp, args&: input, args&: output);
1321
1322 results.push_back(op: memrefCopyOp);
1323 return DiagnosedSilenceableFailure::success();
1324}
1325
1326//===----------------------------------------------------------------------===//
1327// LowerPackOp
1328//===----------------------------------------------------------------------===//
1329
1330DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1331 transform::TransformRewriter &rewriter, linalg::PackOp target,
1332 transform::ApplyToEachResultList &transformResults,
1333 transform::TransformState &state) {
1334 rewriter.setInsertionPoint(target);
1335 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1336 FailureOr<LowerPackResult> res =
1337 lowerPack(rewriter, packOp: target, lowerPadLikeWithInsertSlice);
1338 if (failed(Result: res)) {
1339 return mlir::emitSilenceableFailure(loc: target->getLoc())
1340 << "cannot lower to pad + expand + transpose";
1341 }
1342 transformResults.push_back(op: res->padOp);
1343 transformResults.push_back(op: res->expandShapeOp);
1344 transformResults.push_back(op: res->transposeOp);
1345 return DiagnosedSilenceableFailure::success();
1346}
1347
1348//===----------------------------------------------------------------------===//
1349// LowerUnPackOp
1350//===----------------------------------------------------------------------===//
1351
1352DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1353 transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1354 transform::ApplyToEachResultList &transformResults,
1355 transform::TransformState &state) {
1356 rewriter.setInsertionPoint(target);
1357 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1358 FailureOr<LowerUnPackOpResult> res =
1359 lowerUnPack(rewriter, unPackOp: target, lowerUnpadLikeWithExtractSlice);
1360 if (failed(Result: res)) {
1361 DiagnosedSilenceableFailure diag =
1362 emitSilenceableError()
1363 << "cannot lower to transpose + collapse + extract";
1364 diag.attachNote(loc: target->getLoc()) << "target payload op";
1365 return diag;
1366 }
1367 transformResults.push_back(op: res->emptyOp);
1368 transformResults.push_back(op: res->transposeOp);
1369 transformResults.push_back(op: res->collapseShapeOp);
1370 transformResults.push_back(op: res->extractSliceOp);
1371 return DiagnosedSilenceableFailure::success();
1372}
1373
1374//===---------------------------------------------------------------------===//
1375// MatchOp
1376//===---------------------------------------------------------------------===//
1377
1378void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1379 Value target, ArrayRef<StringRef> opNames) {
1380 result.addOperands(newOperands: target);
1381 result.addAttribute(name: MatchOp::getOpsAttrName(name: result.name),
1382 attr: builder.getStrArrayAttr(values: opNames));
1383 result.addTypes(newTypes: transform::AnyOpType::get(ctx: builder.getContext()));
1384}
1385
/// Builds a MatchOp with explicitly provided result types.
void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTypes, Value target,
                               ArrayRef<StringRef> opNames) {
  result.addOperands(newOperands: target);
  // The list of op names to match is stored as a string array attribute.
  result.addAttribute(name: MatchOp::getOpsAttrName(name: result.name),
                      attr: builder.getStrArrayAttr(values: opNames));
  result.addTypes(newTypes&: resultTypes);
}
1394
DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformRewriter &rewriter,
                          transform::TransformResults &results,
                          transform::TransformState &state) {
  // Set of op names to match against, populated only when the `ops` attribute
  // is present.
  llvm::StringSet<> strs;
  if (getOps().has_value())
    strs.insert_range(R: getOps()->getAsValueRange<StringAttr>());

  auto payloadOps = state.getPayloadOps(value: getTarget());
  if (!llvm::hasSingleElement(C&: payloadOps)) {
    return emitDefiniteFailure(message: "requires exactly one target handle");
  }

  SmallVector<Operation *> res;
  bool incorrectNumOperandTypes = false;
  // Predicate applied to every op nested under the target payload op; an op
  // is collected into `res` only if it passes all requested filters.
  auto matchFun = [&](Operation *op) {
    if (getOps().has_value() && !strs.contains(key: op->getName().getStringRef()))
      return;

    // Interfaces cannot be matched by name, just by ID.
    // So we specifically encode the interfaces we care about for this op.
    if (getInterface().has_value()) {
      auto iface = getInterface().value();
      if (iface == transform::MatchInterfaceEnum::LinalgOp &&
          !isa<LinalgOp>(Val: op))
        return;
      if (iface == transform::MatchInterfaceEnum::TilingInterface &&
          !isa<TilingInterface>(Val: op))
        return;
      if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
          !isa<LoopLikeOpInterface>(Val: op))
        return;
    }

    // Check if all specified attributes match.
    if (getOpAttrs().has_value()) {
      DictionaryAttr opAttrs = getOpAttrs().value();
      for (NamedAttribute attr : opAttrs) {
        // These two names configure the match itself and are not required to
        // be present on the payload op.
        if (attr.getName() == getInterfaceAttrName() ||
            attr.getName() == getOpsAttrName())
          continue;
        if (!op->hasAttr(name: attr.getName()))
          return;
        if (op->getAttr(name: attr.getName()) != attr.getValue())
          return;
      }
    }

    // Optionally require a single result of the exact given type.
    if (getFilterResultType().has_value()) {
      Type t = getFilterResultType().value();
      if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
        return;
    }

    if (getFilterOperandTypes().has_value()) {
      mlir::ArrayAttr types = getFilterOperandTypes().value();
      auto operandTypes = op->getOperandTypes();

      if (types.size() == 1) {
        // All the operands must be equal to the specified type
        auto typeattr =
            dyn_cast<mlir::TypeAttr>(Val: getFilterOperandTypes().value()[0]);
        Type t = cast<::mlir::Type>(Val: typeattr.getValue());
        if (!llvm::all_of(Range: op->getOperandTypes(),
                          P: [&](Type operandType) { return operandType == t; }))
          return;
      } else {
        // The operand types must match all the types in the list (in the same
        // order in which they are specified)
        if (types.size() != operandTypes.size()) {
          // Flag the size mismatch; it is reported as a definite failure
          // after the walk completes.
          incorrectNumOperandTypes = true;
          return;
        }

        for (auto [attr, operandType] :
             llvm::zip_equal(t: getFilterOperandTypes().value(), u&: operandTypes)) {
          auto typeattr = cast<mlir::TypeAttr>(Val: attr);
          Type type = cast<::mlir::Type>(Val: typeattr.getValue());

          if (type != operandType)
            return;
        }
      }
    }

    // All constraints are satisfied.
    res.push_back(Elt: op);
    return;
  };

  (*payloadOps.begin())->walk(callback&: matchFun);
  if (incorrectNumOperandTypes)
    return emitDefiniteFailure(message: "If filter_operand_types contains more than a "
                               "type, then it must contain as much types as "
                               "the number of operands in the target ops");
  results.set(value: cast<OpResult>(Val: getResult()), ops&: res);
  return DiagnosedSilenceableFailure::success();
}
1493
1494//===---------------------------------------------------------------------===//
1495// MultiTileSizesOp
1496//===---------------------------------------------------------------------===//
1497
/// Prints the trailing types of MultiTileSizesOp as a functional type with a
/// single result, e.g. `(!transform.any_op) -> !transform.param<i64>`. Only
/// `lowSizeType` is printed; the parser reconstructs the other two result
/// types from it (see parseMultitileSizesTypes).
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
                                     Type targetType, Type lowSizeType, Type,
                                     Type) {
  printer.printFunctionalType(inputs: TypeRange{targetType}, results: TypeRange{lowSizeType});
}
1503
1504static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1505 Type &targetType, Type &lowSizeType,
1506 Type &highSizeType,
1507 Type &splitPointType) {
1508 FunctionType funcType;
1509 llvm::SMLoc typeLoc = parser.getCurrentLocation();
1510 if (failed(Result: parser.parseType<FunctionType>(result&: funcType)))
1511 return failure();
1512
1513 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1514 parser.emitError(loc: typeLoc) << "expects a trailing functional type with one "
1515 "argument and one result";
1516 }
1517 targetType = funcType.getInput(i: 0);
1518 lowSizeType = highSizeType = splitPointType = funcType.getResult(i: 0);
1519
1520 return success();
1521}
1522
DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results, TransformState &state) {
  // Parametric case: the results are transform params, so the tile sizes are
  // computed statically without touching the payload IR.
  if (isa<TransformParamTypeInterface>(Val: getLowSize().getType())) {
    // Static size computation requires a fully static payload shape.
    if (target.hasDynamicShape()) {
      auto diag = emitSilenceableError()
                  << "cannot compute parametric tile sizes for dynamically "
                     "shaped payload op";
      diag.attachNote(loc: target->getLoc()) << "payload op";
      return diag;
    }

    FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
        op: target, dimension: getDimension(), targetSize: getTargetSize(), divisor: getDivisor());
    if (failed(Result: spec)) {
      return emitSilenceableError()
             << "failed to compute multi-size tiling sizes";
    }

    // Wrap the three values (low size, high size, split point =
    // lowTileSize * lowTripCount) as integer attributes of the param's
    // element type.
    Builder builder(target.getContext());
    results.assign(range: llvm::map_range(
        C: ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
                            spec->lowTileSize * spec->lowTripCount}),
        F: [&builder, this](int64_t value) {
          return builder.getIntegerAttr(
              type: cast<ParamType>(Val: getLowSize().getType()).getType(), value);
        }));
    return DiagnosedSilenceableFailure::success();
  }

  // Non-parametric case: materialize the size computation as IR immediately
  // before the target op.
  OpBuilder builder(target.getContext());
  builder.setInsertionPoint(target);
  OpFoldResult targetSize = builder.getIndexAttr(value: getTargetSize());
  OpFoldResult divisor = builder.getIndexAttr(value: getDivisor());
  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
      builder, op: target, dimension: getDimension(), targetSize, divisor);
  if (failed(Result: spec)) {
    return emitSilenceableError() << "could not generate tile size computation";
  }

  // Split point = lowTileSize * lowTripCount, materialized as an affine apply
  // over the two SSA values.
  AffineExpr s0 = builder.getAffineSymbolExpr(position: 0);
  AffineExpr s1 = builder.getAffineSymbolExpr(position: 1);
  Operation *splitPoint =
      affine::makeComposedAffineApply(b&: builder, loc: target.getLoc(), e: s0 * s1,
                                      operands: {spec->lowTileSize, spec->lowTripCount});
  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
  Operation *highTileSize = spec->highTileSize.getDefiningOp();
  assert(lowTileSize && highTileSize && splitPoint &&
         "tile sizes are not produced by operations");
  // The three op results are returned as operation handles.
  results.reserve(size: results.size() + 3);
  results.push_back(op: lowTileSize);
  results.push_back(op: highTileSize);
  results.push_back(op: splitPoint);
  return DiagnosedSilenceableFailure::success();
}
1578
1579void transform::MultiTileSizesOp::getEffects(
1580 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1581 onlyReadsHandle(handles: getTargetMutable(), effects);
1582 producesHandle(handles: getOperation()->getOpResults(), effects);
1583 if (isa<TransformParamTypeInterface>(Val: getLowSize().getType()))
1584 onlyReadsPayload(effects);
1585 else
1586 modifiesPayload(effects);
1587}
1588
1589LogicalResult transform::MultiTileSizesOp::verify() {
1590 if (getLowSize().getType() != getHighSize().getType() ||
1591 getLowSize().getType() != getSplitPoint().getType()) {
1592 return emitOpError() << "expects all results type to be the same";
1593 }
1594 return success();
1595}
1596
1597//===---------------------------------------------------------------------===//
1598// PackOp
1599//===---------------------------------------------------------------------===//
1600
1601void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1602 Value target,
1603 ArrayRef<OpFoldResult> mixedPackedSizes) {
1604 SmallVector<int64_t> staticPackedSizes;
1605 SmallVector<Value> dynamicPackedSizes;
1606 dispatchIndexOpFoldResults(ofrs: mixedPackedSizes, dynamicVec&: dynamicPackedSizes,
1607 staticVec&: staticPackedSizes);
1608 // Call the default builder which sets up the proper operands segment sizes
1609 // attributes for multiple variadic operands. In the absence of this, horrible
1610 // bugs ensue.
1611 Type linalgOpHType = transform::OperationType::get(
1612 context: builder.getContext(), operation_name: GenericOp::getOperationName());
1613 build(odsBuilder&: builder, odsState&: result,
1614 /*resultType=*/packed_op: linalgOpHType,
1615 /*target=*/target,
1616 /*dynamic_sizes=*/packed_sizes: dynamicPackedSizes,
1617 /*static_sizes=*/static_packed_sizes: builder.getDenseI64ArrayAttr(values: staticPackedSizes));
1618}
1619
1620SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1621 Builder b(getContext());
1622 return getMixedValues(staticValues: getStaticPackedSizes(), dynamicValues: getPackedSizes(), b);
1623}
1624
1625DiagnosedSilenceableFailure
1626transform::PackOp::apply(transform::TransformRewriter &rewriter,
1627 transform::TransformResults &transformResults,
1628 transform::TransformState &state) {
1629 auto targetOps = state.getPayloadOps(value: getTarget());
1630 // If nothing to pack, propagate success.
1631 if (std::empty(cont: targetOps)) {
1632 transformResults.set(value: cast<OpResult>(Val: getPackedOp()),
1633 ops: ArrayRef<Operation *>({}));
1634 return DiagnosedSilenceableFailure::success();
1635 }
1636 // Fail on multi-op handles.
1637 auto linalgOp = dyn_cast<LinalgOp>(Val: *targetOps.begin());
1638 if (!llvm::hasSingleElement(C&: targetOps) || !linalgOp) {
1639 return emitSilenceableError()
1640 << "requires target to map to exactly 1 LinalgOp (got "
1641 << llvm::range_size(Range&: targetOps) << ")";
1642 }
1643 // Fail on mismatched number of pack sizes.
1644 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1645 return emitSilenceableError()
1646 << "requires number of packed sizes match the number of loops ("
1647 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1648 << ")";
1649 }
1650
1651 // Unpack handles to constants or actual SSA index values.
1652 SmallVector<OpFoldResult> packedSizes;
1653 DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
1654 state, transformOp: *this, result&: packedSizes, ofrs: getMixedPackedSizes());
1655
1656 rewriter.setInsertionPoint(linalgOp);
1657 FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1658 if (failed(Result: maybeResult))
1659 return emitDefiniteFailure(message: "data tiling failed");
1660
1661 transformResults.set(value: cast<OpResult>(Val: getPackedOp()),
1662 ops: {maybeResult->packedLinalgOp.getOperation()});
1663 return DiagnosedSilenceableFailure::success();
1664}
1665
1666void transform::PackOp::getEffects(
1667 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1668 transform::consumesHandle(handles: getTargetMutable(), effects);
1669 transform::onlyReadsHandle(handles: getPackedSizesMutable(), effects);
1670 transform::producesHandle(handles: getOperation()->getOpResults(), effects);
1671 transform::modifiesPayload(effects);
1672}
1673
1674//===---------------------------------------------------------------------===//
1675// PackGreedilyOp.
1676//===---------------------------------------------------------------------===//
1677
1678LogicalResult transform::PackGreedilyOp::verify() {
1679 if (!isPermutationVector(interchange: getMatmulInnerDimsOrder())) {
1680 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1681 << " is not a valid permutation";
1682 }
1683 // TODO: relax to allow empty once we have another strategy than just matmul.
1684 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1685 for (auto [s, nmo] :
1686 llvm::zip_equal(t: getMixedMatmulPackedSizes(),
1687 u: getMatmulPaddedSizesNextMultipleOf())) {
1688 std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(ofr: s);
1689 if (nmo != 0 &&
1690 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1691 return emitOpError() << "at most one of the packed_size and the "
1692 "padded_sizes_next_multiple_of can be nonzero "
1693 "for the matmul strategy";
1694 }
1695 }
1696 }
1697 return success();
1698}
1699
1700DiagnosedSilenceableFailure
1701PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1702 transform::TransformResults &transformResults,
1703 transform::TransformState &state) {
1704 SmallVector<Operation *> results;
1705 for (Operation *op : state.getPayloadOps(value: getTarget())) {
1706 auto linalgOp = dyn_cast<LinalgOp>(Val: op);
1707 if (!linalgOp)
1708 continue;
1709 // linalgOp will be replaced and the insertion point may be invalidated if
1710 // we set it before -> set it after.
1711 rewriter.setInsertionPointAfter(linalgOp);
1712 // Failing to pack greedily is perfectly fine.
1713 // In the future we will want to order packings according to some metric.
1714 FailureOr<PackResult> packResult = packMatmulGreedily(
1715 /*rewriter=*/rewriter,
1716 /*linalgOp=*/linalgOp,
1717 /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1718 /*mnkPaddedSizesNextMultipleOf=*/
1719 getMatmulPaddedSizesNextMultipleOf(),
1720 /*mnkOrder=*/getMatmulInnerDimsOrder());
1721 if (succeeded(Result: packResult)) {
1722 results.push_back(Elt: packResult->packedLinalgOp);
1723 continue;
1724 }
1725 results.push_back(Elt: linalgOp);
1726 }
1727 transformResults.set(value: cast<OpResult>(Val: getPackedOp()), ops&: results);
1728 return DiagnosedSilenceableFailure::success();
1729}
1730
1731SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1732 Builder b(getContext());
1733 return getMixedValues(staticValues: getStaticMatmulPackedSizes(), dynamicValues: getMatmulPackedSizes(),
1734 b);
1735}
1736
void transform::PackGreedilyOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Matched ops may be rewritten in place -> the target handle is consumed,
  // while the packed-size handles are only read.
  transform::consumesHandle(handles: getTargetMutable(), effects);
  transform::onlyReadsHandle(handles: getMatmulPackedSizesMutable(), effects);
  transform::producesHandle(handles: getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}
1744
1745//===---------------------------------------------------------------------===//
1746// PackTransposeOp
1747//===---------------------------------------------------------------------===//
1748
1749LogicalResult transform::PackTransposeOp::verify() {
1750 if (!isPermutationVector(interchange: getInnerPerm())) {
1751 return emitOpError() << getInnerPermAttrName()
1752 << " is not a valid permutation";
1753 }
1754 if (!isPermutationVector(interchange: getOuterPerm())) {
1755 return emitOpError() << getOuterPermAttrName()
1756 << " is not a valid permutation";
1757 }
1758 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1759 return emitOpError() << " at least one of " << getInnerPermAttrName()
1760 << " or " << getOuterPermAttrName()
1761 << " must be specified";
1762 }
1763 return success();
1764}
1765
namespace {
/// Selector disambiguating whether a permutation applies to the outer dims
/// (`outer_dims_perm`) or the inner dims (`inner_dims_pos`) of a pack/unpack.
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
1769
1770/// Return true if `permutation` is a valid permutation of the
1771/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1772/// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op.
1773/// This is the case when the `permutation` rank matches the rank expected by
1774/// `op` and `permutation` is itself a permutation vector.
1775/// Return true if either `op` or `permutation` are empty to allow a simpler
1776/// polymorphic implementation.
1777template <typename RelayoutOpTy>
1778bool isValidPackingPermutation(
1779 RelayoutOpTy op, ArrayRef<int64_t> permutation,
1780 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1781 static_assert(
1782 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1783 "applies to only pack or unpack operations");
1784 if (!op || permutation.empty())
1785 return true;
1786 size_t innerRank = op.getInnerDimsPos().size();
1787 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1788 return permutation.size() == innerRank && isPermutationVector(interchange: permutation);
1789 // op.getOuterDimsPerm() may be empty, in which case it is identity.
1790 // Don't rely on it.
1791 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1792 return permutation.size() == op.getSourceRank() &&
1793 isPermutationVector(interchange: permutation);
1794 }
1795 return permutation.size() == op.getDestRank() &&
1796 isPermutationVector(interchange: permutation);
1797}
1798
DiagnosedSilenceableFailure
transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &transformResults,
                                  transform::TransformState &state) {
  auto packOrUnpackOps = state.getPayloadOps(value: getTargetPackOrUnPackOp());
  auto linalgOps = state.getPayloadOps(value: getTargetLinalgOp());
  // Step 1. If nothing to pack, propagate success.
  if (std::empty(cont: packOrUnpackOps)) {
    transformResults.set(value: cast<OpResult>(Val: getPackedOp()), ops: {});
    transformResults.set(value: cast<OpResult>(Val: getPackOp()), ops: {});
    transformResults.set(value: cast<OpResult>(Val: getUnPackOp()), ops: {});
    return DiagnosedSilenceableFailure::success();
  }

  // Step 2. Bunch of runtime sanity check and error messages.
  // Step 2.1. Fail on multi-op handles.
  if (!llvm::hasSingleElement(C&: packOrUnpackOps) ||
      !llvm::hasSingleElement(C&: linalgOps)) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 "
              "packing op and 1 packed op ("
           << "got " << llvm::range_size(Range&: packOrUnpackOps) << " and "
           << llvm::range_size(Range&: linalgOps) << ")";
  }

  // Step 2.2. Fail on wrong type.
  auto packOp = dyn_cast<linalg::PackOp>(Val: *packOrUnpackOps.begin());
  auto unPackOp = dyn_cast<linalg::UnPackOp>(Val: *packOrUnpackOps.begin());
  if ((!packOp && !unPackOp)) {
    return emitSilenceableError() << "requires target to map to a "
                                     "linalg.pack or linalg.unpack";
  }
  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(Val: *linalgOps.begin());
  if (!linalgOpTarget)
    return emitSilenceableError() << "requires a LinalgOp target";

  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
  // For a pack, the target LinalgOp must be the single user of the packed
  // value; for an unpack, it must be the producer of the unpacked source.
  LinalgOp linalgOp;
  if (packOp && packOp.getResult().hasOneUse())
    linalgOp = dyn_cast<LinalgOp>(Val: *(packOp.getResult().getUsers().begin()));
  else if (unPackOp)
    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
  if (linalgOp != linalgOpTarget) {
    auto errorMsg =
        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
               : StringLiteral{"not produced by the LinalgOp target"};
    return emitSilenceableError() << errorMsg;
  }

  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
  // PackOp.
  if (unPackOp) {
    assert(!packOp && "packOp must be null on entry when unPackOp is not null");
    // The matching pack op is the producer of the init operand corresponding
    // to the result that feeds the unpack.
    OpOperand *packUse = linalgOp.getDpsInitOperand(
        i: cast<OpResult>(Val: unPackOp.getSource()).getResultNumber());
    packOp = dyn_cast_or_null<linalg::PackOp>(Val: packUse->get().getDefiningOp());
    if (!packOp || !packOp.getResult().hasOneUse())
      return emitSilenceableError() << "could not find matching pack op";
  }

  // Step 2.5. Fail if any permutation does not validate.
  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
    ArrayRef<int64_t> perm =
        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
                        ? StringLiteral{"invalid outer_perm"}
                        : StringLiteral{"invalid inner_perm"};
    if (!isValidPackingPermutation(op: packOp, permutation: perm, outerOrInnerPerm: permType) ||
        !isValidPackingPermutation(op: unPackOp, permutation: perm, outerOrInnerPerm: permType)) {
      Operation *packOrUnpackOp =
          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
    }
  }

  // From here on, packOp and linalgOp are always present, unPackOp may or may
  // not be present.
  assert(packOp && linalgOp && "unexpected null op");

  // Step 3. Actually transpose the ops.
  FailureOr<PackTransposeResult> res = packTranspose(
      rewriter, packOp, linalgOp, maybeUnPackOp: unPackOp, outerPerm: getOuterPerm(), innerPerm: getInnerPerm());
  // Preconditions have been checked, it is an error to fail here.
  assert(succeeded(res) && "unexpected packTranspose failure");

  // Step 4. Return results.
  transformResults.set(value: cast<OpResult>(Val: getPackOp()), ops: {res->transposedPackOp});
  transformResults.set(value: cast<OpResult>(Val: getPackedOp()),
                       ops: {res->transposedLinalgOp});
  if (unPackOp) {
    transformResults.set(value: cast<OpResult>(Val: getUnPackOp()),
                         ops: {res->transposedUnPackOp});
  } else {
    // No unpack involved -> the unpack result handle maps to nothing.
    transformResults.set(value: cast<OpResult>(Val: getUnPackOp()), ops: {});
  }

  return DiagnosedSilenceableFailure::success();
}
1897
1898//===---------------------------------------------------------------------===//
1899// PadOp
1900//===---------------------------------------------------------------------===//
1901
1902void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1903 ArrayRef<int64_t> paddingDimensions,
1904 ArrayRef<int64_t> padToMultipleOf,
1905 ArrayRef<int64_t> nofoldFlags,
1906 ArrayRef<Attribute> transposePaddings,
1907 StringRef copyBackOp,
1908 bool usePrescribedTensorShapes) {
1909 auto resultType = transform::AnyOpType::get(ctx: b.getContext());
1910 return build(/*builder=*/odsBuilder&: b,
1911 /*result=*/odsState&: result,
1912 /*types=*/resultTypes: TypeRange{resultType, resultType},
1913 /*target=*/target,
1914 /*paddingValues=*/padding_values: ArrayAttr(), // let inference handle this
1915 /*paddingDimensions=*/padding_dimensions: b.getI64ArrayAttr(values: paddingDimensions),
1916 /*padToMultipleOf=*/pad_to_multiple_of: ValueRange{},
1917 /*padToMultipleOf=*/
1918 static_pad_to_multiple_of: (padToMultipleOf.empty()
1919 ? DenseI64ArrayAttr()
1920 : b.getDenseI64ArrayAttr(values: padToMultipleOf)),
1921 /*nofoldFlags=*/nofold_flags: b.getI64ArrayAttr(values: nofoldFlags),
1922 /*transposePaddings=*/transpose_paddings: b.getArrayAttr(value: transposePaddings),
1923 /*copyBackOp=*/copy_back_op: b.getStringAttr(bytes: copyBackOp),
1924 /*usePrescribedTensorShapes=*/
1925 use_prescribed_tensor_shapes: usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
1926}
1927
1928void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1929 ArrayRef<int64_t> paddingDimensions,
1930 ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1931 ArrayRef<int64_t> nofoldFlags,
1932 ArrayRef<Attribute> transposePaddings,
1933 StringRef copyBackOp,
1934 bool usePrescribedTensorShapes) {
1935 auto resultType = transform::AnyOpType::get(ctx: b.getContext());
1936 SmallVector<int64_t> staticPadToMultipleOf;
1937 SmallVector<Value> dynamicPadToMultipleOf;
1938 dispatchIndexOpFoldResults(ofrs: mixedPadToMultipleOf, dynamicVec&: dynamicPadToMultipleOf,
1939 staticVec&: staticPadToMultipleOf);
1940 return build(/*builder=*/odsBuilder&: b,
1941 /*result=*/odsState&: result,
1942 /*types=*/resultTypes: TypeRange{resultType, resultType},
1943 /*target=*/target,
1944 /*paddingValues=*/padding_values: ArrayAttr(), // let inference handle this
1945 /*paddingDimensions=*/padding_dimensions: b.getI64ArrayAttr(values: paddingDimensions),
1946 /*padToMultipleOf=*/pad_to_multiple_of: dynamicPadToMultipleOf,
1947 /*padToMultipleOf=*/static_pad_to_multiple_of: staticPadToMultipleOf,
1948 /*nofoldFlags=*/nofold_flags: b.getI64ArrayAttr(values: nofoldFlags),
1949 /*transposePaddings=*/transpose_paddings: b.getArrayAttr(value: transposePaddings),
1950 /*copyBackOp=*/copy_back_op: copyBackOp,
1951 /*usePrescribedTensorShapes=*/use_prescribed_tensor_shapes: usePrescribedTensorShapes);
1952}
1953
void PadOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The padded op replaces the target -> consume the target handle; the
  // pad-to-multiple-of handles are only inspected.
  consumesHandle(handles: getTargetMutable(), effects);
  onlyReadsHandle(handles: getPadToMultipleOfMutable(), effects);
  producesHandle(handles: getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
1961
1962SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1963 Builder b(getContext());
1964 return getMixedValues(staticValues: getStaticPadToMultipleOf(), dynamicValues: getPadToMultipleOf(), b);
1965}
1966
DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
                        transform::TransformResults &results,
                        transform::TransformState &state) {
  auto transformOp = cast<TransformOpInterface>(Val: getOperation());
  SmallVector<Operation *> paddedOps, padOps, copyBackOps;

  for (Operation *target : state.getPayloadOps(value: getTarget())) {
    auto linalgTarget = dyn_cast<LinalgOp>(Val: target);
    if (!linalgTarget) {
      auto diag = emitSilenceableError() << "expected LinalgOp target";
      diag.attachNote(loc: target->getLoc()) << "target op";
      return diag;
    }

    // Convert the integer packing flags to booleans.
    SmallVector<bool> nofoldFlags;
    for (int64_t packPadding :
         extractFromIntegerArrayAttr<int64_t>(attr: getNofoldFlags()))
      nofoldFlags.push_back(Elt: static_cast<bool>(packPadding));

    // Convert the padding values to attributes.
    SmallVector<Attribute> paddingValues;
    for (auto const &it :
         llvm::zip(t: getPaddingValues(), u: linalgTarget->getOperandTypes())) {
      auto attr = dyn_cast<TypedAttr>(Val: std::get<0>(t: it));
      if (!attr) {
        emitOpError(message: "expects padding values to be typed attributes");
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      Type elementType = getElementTypeOrSelf(type: std::get<1>(t: it));
      // Try to parse string attributes to obtain an attribute of element type.
      if (auto stringAttr = dyn_cast<StringAttr>(Val&: attr)) {
        auto parsedAttr = dyn_cast_if_present<TypedAttr>(Val: parseAttribute(
            attrStr: stringAttr, context: getContext(), type: elementType,
            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
        // The parsed attribute must match the operand element type exactly.
        if (!parsedAttr || parsedAttr.getType() != elementType) {
          auto diag = this->emitOpError(message: "expects a padding that parses to ")
                      << elementType << ", got " << std::get<0>(t: it);
          diag.attachNote(noteLoc: linalgTarget.getLoc()) << "when applied to this op";
          return DiagnosedSilenceableFailure::definiteFailure();
        }
        paddingValues.push_back(Elt: parsedAttr);
        continue;
      }
      // Otherwise, add the attribute directly.
      if (attr.getType() != elementType) {
        auto diag = this->emitOpError(message: "expects a padding value of type ")
                    << elementType << ", got " << attr;
        diag.attachNote(noteLoc: linalgTarget.getLoc()) << "when applied to this op";
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      paddingValues.push_back(Elt: attr);
    }

    // Extract the transpose vectors.
    SmallVector<SmallVector<int64_t>> transposePaddings;
    for (Attribute transposeVector : cast<ArrayAttr>(Val: getTransposePaddings()))
      transposePaddings.push_back(Elt: extractFromIntegerArrayAttr<int64_t>(
          attr: cast<ArrayAttr>(Val&: transposeVector)));

    LinalgOp paddedOp;
    LinalgPaddingOptions options;
    options.paddingDimensions =
        extractFromIntegerArrayAttr<int64_t>(attr: getPaddingDimensions());

    // Resolve pad_to_multiple_of, which may mix static values and transform
    // params/handles; default to 1 (no multiple constraint) per dimension.
    SmallVector<int64_t> padToMultipleOf;
    DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
        state, transformOp, mixedResults: getMixedPadToMultipleOf(), reified&: padToMultipleOf);
    if (!status.succeeded())
      return status;
    if (padToMultipleOf.empty())
      padToMultipleOf =
          SmallVector<int64_t>(options.paddingDimensions.size(), 1);

    options.padToMultipleOf = padToMultipleOf;
    options.paddingValues = paddingValues;
    options.nofoldFlags = nofoldFlags;
    if (getCopyBackOp() ==
        bufferization::MaterializeInDestinationOp::getOperationName()) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
          BufferizationMaterializeInDestination;
    } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
    } else if (getCopyBackOp() == kCopyOpNone) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
    } else {
      // The verifier rejects any other copy_back_op value.
      llvm_unreachable("unsupported copy_back op");
    }
    // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
    bool irChanged = false;
    if (getUsePrescribedTensorShapes() &&
        linalgTarget.hasPureTensorSemantics()) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(linalgTarget);
      for (OpOperand &operand : linalgTarget->getOpOperands()) {
        for (auto [i, dim] : llvm::enumerate(First: linalgTarget.getShape(opOperand: &operand))) {
          if (ShapedType::isStatic(dValue: dim))
            continue;
          options.setSizeToPadTo(operandIndex: operand.getOperandNumber(), dimIndex: i,
                                 size: tensor::getMixedSize(builder&: rewriter,
                                                       loc: operand.get().getLoc(),
                                                       value: operand.get(), dim: i));
          irChanged = true;
        }
      }
    }

    SmallVector<Value> replacements;
    SmallVector<tensor::PadOp> newPadOps;
    if (failed(Result: rewriteAsPaddedOp(rewriter, opToPad: linalgTarget, options, paddedOp,
                                  replacements, padOps&: newPadOps))) {
      // If size-computation IR was already emitted above, the failure cannot
      // be silenced without leaving the IR modified -> definite failure.
      if (irChanged) {
        auto diag = emitDefiniteFailure() << "failed to pad op";
        diag.attachNote(loc: target->getLoc()) << "target op";
        return diag;
      }
      auto diag = emitSilenceableError() << "failed to pad op";
      diag.attachNote(loc: target->getLoc()) << "target op";
      return diag;
    }

    // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
    // need to be different.
    // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
    rewriter.replaceOp(op: linalgTarget, newValues: replacements);
    paddedOps.push_back(Elt: paddedOp);
    padOps.append(in_start: newPadOps.begin(), in_end: newPadOps.end());
    if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
      // Collect the deduplicated copy-back ops producing the replacements.
      for (Value v : replacements) {
        Operation *copyBackOp = v.getDefiningOp();
        if (!llvm::is_contained(Range&: copyBackOps, Element: copyBackOp))
          copyBackOps.push_back(Elt: copyBackOp);
      }
    }
  }

  results.set(value: cast<OpResult>(Val: getPadded()), ops&: paddedOps);
  results.set(value: cast<OpResult>(Val: getPad()), ops&: padOps);
  results.set(value: cast<OpResult>(Val: getCopy()), ops&: copyBackOps);
  return DiagnosedSilenceableFailure::success();
}
2111
2112LogicalResult transform::PadOp::verify() {
2113 SmallVector<int64_t> nofoldFlags =
2114 extractFromIntegerArrayAttr<int64_t>(attr: getNofoldFlags());
2115 if (any_of(Range&: nofoldFlags, P: [](int64_t packPadding) {
2116 return packPadding != 0 && packPadding != 1;
2117 })) {
2118 return emitOpError()
2119 << "expects nofold_flags to contain booleans (0/1), found "
2120 << getNofoldFlags();
2121 }
2122
2123 SmallVector<int64_t> paddingDimensions =
2124 extractFromIntegerArrayAttr<int64_t>(attr: getPaddingDimensions());
2125 if (any_of(Range&: paddingDimensions,
2126 P: [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2127 return emitOpError() << "expects padding_dimensions to contain positive "
2128 "integers, found "
2129 << getPaddingDimensions();
2130 }
2131 if (!getMixedPadToMultipleOf().empty()) {
2132 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2133 return emitOpError() << "expects as many multiples as padding_dimensions";
2134 }
2135 }
2136 ArrayAttr transposes = getTransposePaddings();
2137 for (Attribute attr : transposes) {
2138 SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2139 auto sequence = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: transpose.size()));
2140 if (!std::is_permutation(first1: sequence.begin(), last1: sequence.end(),
2141 first2: transpose.begin(), last2: transpose.end())) {
2142 return emitOpError()
2143 << "expects transpose_paddings to be a permutation, found "
2144 << attr;
2145 }
2146 }
2147 if (getCopyBackOp() !=
2148 bufferization::MaterializeInDestinationOp::getOperationName() &&
2149 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2150 getCopyBackOp() != kCopyOpNone)
2151 return emitOpError() << "invalid copy_back_op";
2152 return success();
2153}
2154
2155//===---------------------------------------------------------------------===//
2156// PadTilingInterfaceOp
2157//===---------------------------------------------------------------------===//
2158
2159void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2160 OperationState &result,
2161 Value target,
2162 ArrayRef<int64_t> paddingSizes,
2163 bool padToMultipleOf) {
2164 auto resultType = transform::AnyOpType::get(ctx: b.getContext());
2165 return build(/*builder=*/odsBuilder&: b,
2166 /*result=*/odsState&: result,
2167 /*types=*/resultTypes: TypeRange{resultType, resultType},
2168 /*target=*/target,
2169 /*paddingValues=*/padding_values: ArrayAttr(), // let inference handle this
2170 /*paddingSizes=*/padding_sizes: ValueRange{},
2171 /*paddingSizes=*/
2172 static_padding_sizes: (paddingSizes.empty() ? DenseI64ArrayAttr()
2173 : b.getDenseI64ArrayAttr(values: paddingSizes)),
2174 /*padToMultipleOf=*/
2175 pad_to_multiple_of: padToMultipleOf ? b.getUnitAttr() : nullptr);
2176}
2177
void transform::PadTilingInterfaceOp::build(
    OpBuilder &b, OperationState &result, Value target,
    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
  auto resultType = transform::AnyOpType::get(ctx: b.getContext());
  // Split the mixed sizes into static attribute values and dynamic operands.
  SmallVector<int64_t> staticPaddingSizes;
  SmallVector<Value> dynamicPaddingSizes;
  dispatchIndexOpFoldResults(ofrs: mixedPaddingSizes, dynamicVec&: dynamicPaddingSizes,
                             staticVec&: staticPaddingSizes);
  return build(/*builder=*/odsBuilder&: b,
               /*result=*/odsState&: result,
               /*types=*/resultTypes: TypeRange{resultType, resultType},
               /*target=*/target,
               /*paddingValues=*/padding_values: ArrayAttr(), // let inference handle this
               /*paddingSizes=*/padding_sizes: dynamicPaddingSizes,
               /*paddingSizes=*/static_padding_sizes: staticPaddingSizes,
               /*padToMultipleOf=*/pad_to_multiple_of: padToMultipleOf);
}
2195
void transform::PadTilingInterfaceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The padded op replaces the target -> consume the target handle; the
  // padding-size handles are only inspected.
  consumesHandle(handles: getTargetMutable(), effects);
  onlyReadsHandle(handles: getPaddingSizesMutable(), effects);
  producesHandle(handles: getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
2203
2204SmallVector<OpFoldResult>
2205transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2206 Builder b(getContext());
2207 return getMixedValues(staticValues: getStaticPaddingSizes(), dynamicValues: getPaddingSizes(), b);
2208}
2209
/// Pads the operands of every payload op (which must implement both
/// TilingInterface and IndexingMapOpInterface) according to the padding
/// values and sizes carried by this transform op; maps the padded ops and the
/// generated tensor.pad ops to the two results.
DiagnosedSilenceableFailure
transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  SmallVector<Operation *> paddedOps, padOps;

  for (Operation *target : state.getPayloadOps(value: getTarget())) {
    auto targetOp = dyn_cast<TilingInterface>(Val: target);
    if (!targetOp) {
      auto diag = emitSilenceableError() << "expected TilingInterface target";
      diag.attachNote(loc: target->getLoc()) << "target op";
      return diag;
    }

    // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
    // loopsToOperand map / C++ APIs to compute the effect of padding on
    // operands.
    if (!isa<IndexingMapOpInterface>(Val: targetOp.getOperation())) {
      auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
                                            "supported atm";
      diag.attachNote(loc: target->getLoc()) << "target op";
      return diag;
    }

    // Convert the padding values to attributes. Values are zipped with the
    // operand types so each padding value can be checked against (or parsed
    // as) the corresponding operand's element type.
    SmallVector<Attribute> paddingValues;
    for (auto const &[untypedAttr, elementOrTensorType] :
         llvm::zip(t: getPaddingValues(), u: targetOp->getOperandTypes())) {
      auto attr = dyn_cast<TypedAttr>(Val: untypedAttr);
      Type elementType = getElementTypeOrSelf(type: elementOrTensorType);
      if (!attr) {
        emitOpError(message: "expects padding values to be typed attributes");
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      // Try to parse string attributes to obtain an attribute of element type.
      if (auto stringAttr = dyn_cast<StringAttr>(Val&: attr)) {
        auto parsedAttr = dyn_cast_if_present<TypedAttr>(Val: parseAttribute(
            attrStr: stringAttr, context: getContext(), type: elementType,
            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
        // The parsed attribute must exist and exactly match the element type.
        if (!parsedAttr || parsedAttr.getType() != elementType) {
          auto diag = this->emitOpError(message: "expects a padding that parses to ")
                      << elementType << ", got " << attr;
          diag.attachNote(noteLoc: targetOp.getLoc()) << "when applied to this op";
          return DiagnosedSilenceableFailure::definiteFailure();
        }
        paddingValues.push_back(Elt: parsedAttr);
        continue;
      }
      // Otherwise, add the attribute directly.
      if (attr.getType() != elementType) {
        auto diag = this->emitOpError(message: "expects a padding value of type ")
                    << elementType << ", got " << attr;
        diag.attachNote(noteLoc: targetOp.getLoc()) << "when applied to this op";
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      paddingValues.push_back(Elt: attr);
    }

    // Set options.
    TilingInterface paddedOp;
    PadTilingInterfaceOptions options;
    options.setPaddingValues(paddingValues)
        .setPaddingSizes(getMixedPaddingSizes())
        .setPadToMultipleOf(getPadToMultipleOf());

    // Apply padding.
    SmallVector<tensor::PadOp> newPadOps;
    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
        rewriter, opToPad: cast<TilingInterface>(Val: targetOp.getOperation()), constOptions: options,
        padOps&: newPadOps);
    if (failed(Result: maybePaddedOp)) {
      auto diag = emitSilenceableError() << "failed to pad op";
      diag.attachNote(loc: target->getLoc()) << "target op";
      return diag;
    }

    // Set transform results.
    paddedOps.push_back(Elt: cast<TilingInterface>(Val: maybePaddedOp->getOperation()));
    padOps.append(in_start: newPadOps.begin(), in_end: newPadOps.end());
  }

  results.set(value: cast<OpResult>(Val: getPadded()), ops&: paddedOps);
  results.set(value: cast<OpResult>(Val: getPad()), ops&: padOps);
  return DiagnosedSilenceableFailure::success();
}
2295
2296LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2297
//===----------------------------------------------------------------------===//
// HoistPadOp
//===----------------------------------------------------------------------===//
2301
/// Builds a packing loop nest hoisting the payload tensor.pad out of the
/// given scf.for loop, and maps the outermost packing loop (or, when no loop
/// was cloned, the hoisted pad op itself) to the result.
DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(value: getTarget());
  auto loopOps = state.getPayloadOps(value: getLoop());
  // Both handles must map to exactly one payload op.
  if (!llvm::hasSingleElement(C&: targetOps) || !llvm::hasSingleElement(C&: loopOps)) {
    return emitDefiniteFailure()
           << "requires exactly one target and one loop handle (got "
           << llvm::range_size(Range&: targetOps) << " and "
           << llvm::range_size(Range&: loopOps) << ")";
  }

  // The payloads must be a tensor.pad and an scf.for, respectively.
  auto padOp = dyn_cast_or_null<tensor::PadOp>(Val: *targetOps.begin());
  auto loopOp = dyn_cast_or_null<scf::ForOp>(Val: *loopOps.begin());
  if (!padOp || !loopOp)
    return emitDefiniteFailure() << "requires exactly 2 non-null handles";

  FailureOr<linalg::detail::PackingResult> result =
      linalg::detail::buildPackingLoopNest(rewriter, opToHoist: padOp, outermostEnclosingForOp: loopOp,
                                           transposeVector: getTranspose());
  if (failed(Result: result))
    return emitDefiniteFailure() << "could not build packing loop nest";

  // Degenerate case: no loop was cloned; map the hoisted pad op directly.
  if (result->clonedLoopIvs.empty()) {
    transformResults.set(value: cast<OpResult>(Val: getPackingLoop()),
                         ops: {result->hoistedPadOp.getOperation()});
    return DiagnosedSilenceableFailure::success();
  }
  // Otherwise map the outermost cloned packing loop.
  auto outerPackedLoop =
      scf::getForInductionVarOwner(val: result->clonedLoopIvs.front());
  transformResults.set(value: cast<OpResult>(Val: getPackingLoop()),
                       ops: {outerPackedLoop.getOperation()});
  return DiagnosedSilenceableFailure::success();
}
2337
2338LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2339 ArrayRef<int64_t> transpose = getTranspose();
2340 auto sequence = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: transpose.size()));
2341 if (!std::is_permutation(first1: sequence.begin(), last1: sequence.end(), first2: transpose.begin(),
2342 last2: transpose.end())) {
2343 return emitOpError() << "expects transpose to be a permutation, found "
2344 << getTranspose();
2345 }
2346 return success();
2347}
2348
2349void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2350 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2351 transform::onlyReadsHandle(handles: getTargetMutable(), effects);
2352 transform::onlyReadsHandle(handles: getLoopMutable(), effects);
2353 transform::producesHandle(handles: getOperation()->getOpResults(), effects);
2354 transform::modifiesPayload(effects);
2355}
2356
/// Hoists the target tensor.pad op out of `num_loops` enclosing loops,
/// optionally transposing the hoisted tensor, and maps the hoisted pad op to
/// the result.
DiagnosedSilenceableFailure
transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
                                  tensor::PadOp target,
                                  transform::ApplyToEachResultList &results,
                                  transform::TransformState &state) {
  tensor::PadOp hoistedPadOp;
  SmallVector<TransposeOp> transposeOps;
  FailureOr<Value> result =
      hoistPaddingOnTensors(rewriter, opToHoist: target, numLoops: getNumLoops(), transposeVector: getTranspose(),
                            hoistedOp&: hoistedPadOp, transposeOps);
  if (succeeded(Result: result)) {
    // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
    // need to be different.
    // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
    rewriter.replaceOp(op: target, newValues: *result);
    results.push_back(op: hoistedPadOp);
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}
2379
2380LogicalResult transform::HoistPadOp::verify() {
2381 ArrayRef<int64_t> transpose = getTranspose();
2382 auto sequence = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: transpose.size()));
2383 if (!std::is_permutation(first1: sequence.begin(), last1: sequence.end(), first2: transpose.begin(),
2384 last2: transpose.end())) {
2385 return emitOpError() << "expects transpose to be a permutation, found "
2386 << getTranspose();
2387 }
2388 return success();
2389}
2390
2391//===----------------------------------------------------------------------===//
2392// PromoteOp
2393//===----------------------------------------------------------------------===//
2394
2395DiagnosedSilenceableFailure
2396transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2397 LinalgOp target,
2398 transform::ApplyToEachResultList &results,
2399 transform::TransformState &state) {
2400 LinalgPromotionOptions promotionOptions;
2401 if (!getOperandsToPromote().empty())
2402 promotionOptions = promotionOptions.setOperandsToPromote(
2403 extractFromIntegerArrayAttr<int64_t>(attr: getOperandsToPromote()));
2404 if (getUseFullTilesByDefault())
2405 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2406 getUseFullTilesByDefault());
2407 if (getUseOriginalSubviewSize())
2408 promotionOptions =
2409 promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
2410 if (getUseAlloca())
2411 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2412 if (!getUseFullTileBuffers().empty())
2413 promotionOptions = promotionOptions.setUseFullTileBuffers(
2414 llvm::to_vector(Range: getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2415 if (getAlignment().has_value())
2416 promotionOptions = promotionOptions.setAlignment(*getAlignment());
2417 if (getMemorySpace().has_value())
2418 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2419
2420 if (getMapping().has_value()) {
2421 // The mapping should only contain an element
2422 auto mapping = *getMapping();
2423 if (mapping.size() > 1)
2424 return emitDefaultDefiniteFailure(target);
2425
2426 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(Val: mapping[0]);
2427
2428 if (addressSpace.getAddressSpace() ==
2429 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2430 promotionOptions =
2431 promotionOptions
2432 .setAllocationDeallocationFns(allocFn: allocateWorkgroupMemory,
2433 deallocFn: deallocateWorkgroupMemory)
2434 .setCopyInOutFns(copyIn: copyToWorkgroupMemory, copyOut: copyToWorkgroupMemory)
2435 .setUseFullTileBuffers({false, false});
2436 } else if (addressSpace.getAddressSpace() ==
2437 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2438 promotionOptions =
2439 promotionOptions
2440 .setAllocationDeallocationFns(allocFn: allocateGPUPrivateMemory,
2441 deallocFn: deallocateGPUPrivateMemory)
2442 .setCopyInOutFns(copyIn: copyToGPUPrivateMemory, copyOut: copyToGPUPrivateMemory)
2443 .setUseFullTileBuffers({false, false});
2444 } else {
2445 return emitDefaultDefiniteFailure(target);
2446 }
2447 }
2448
2449 if (failed(Result: promoteSubviewsPrecondition(op: target, options: promotionOptions)))
2450 return emitDefaultDefiniteFailure(target);
2451
2452 rewriter.setInsertionPoint(target);
2453 FailureOr<LinalgOp> res = promoteSubViews(b&: rewriter, op: target, options: promotionOptions);
2454 if (failed(Result: res))
2455 return emitDefaultDefiniteFailure(target);
2456 results.push_back(op: target);
2457 return DiagnosedSilenceableFailure::success();
2458}
2459
2460//===----------------------------------------------------------------------===//
2461// ReplaceOp
2462//===----------------------------------------------------------------------===//
2463
/// Replaces each payload op with a clone of the single op held in this
/// transform's body region; maps the clones to the result handle.
DiagnosedSilenceableFailure
transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
                            TransformResults &transformResults,
                            TransformState &state) {
  auto payload = state.getPayloadOps(value: getTarget());

  // Check for invalid targets.
  for (Operation *target : payload) {
    // Targets are erased wholesale, so they must not consume SSA values and
    // must not capture values from above through unisolated regions.
    if (target->getNumOperands() > 0)
      return emitDefiniteFailure() << "expected target without operands";
    if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
        target->getNumRegions() > 0)
      return emitDefiniteFailure()
             << "expected target that is isolated from above";
  }

  // Clone and replace.
  Operation *pattern = &getBodyRegion().front().front();
  SmallVector<Operation *> replacements;
  for (Operation *target : payload) {
    // Skip targets nested inside this transform op itself: replacing them
    // would invalidate the pattern op being cloned.
    if (getOperation()->isAncestor(other: target))
      continue;
    rewriter.setInsertionPoint(target);
    Operation *replacement = rewriter.clone(op&: *pattern);
    rewriter.replaceOp(op: target, newValues: replacement->getResults());
    replacements.push_back(Elt: replacement);
  }
  transformResults.set(value: cast<OpResult>(Val: getReplacement()), ops&: replacements);
  return DiagnosedSilenceableFailure::success();
}
2494
2495void transform::ReplaceOp::getEffects(
2496 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2497 consumesHandle(handles: getTargetMutable(), effects);
2498 producesHandle(handles: getOperation()->getOpResults(), effects);
2499 modifiesPayload(effects);
2500}
2501
2502LogicalResult transform::ReplaceOp::verify() {
2503 if (!getBodyRegion().hasOneBlock())
2504 return emitOpError() << "expected one block";
2505 if (std::distance(first: getBodyRegion().front().begin(),
2506 last: getBodyRegion().front().end()) != 1)
2507 return emitOpError() << "expected one operation in block";
2508 Operation *replacement = &getBodyRegion().front().front();
2509 if (replacement->getNumOperands() > 0)
2510 return replacement->emitOpError()
2511 << "expected replacement without operands";
2512 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2513 replacement->getNumRegions() > 0)
2514 return replacement->emitOpError()
2515 << "expect op that is isolated from above";
2516 return success();
2517}
2518
2519//===----------------------------------------------------------------------===//
2520// ScalarizeOp
2521//===----------------------------------------------------------------------===//
2522
/// "Scalarizes" the target op by tiling every dynamically-shaped loop by 1
/// while leaving statically-shaped loops untiled; maps the tiled ops to the
/// results.
DiagnosedSilenceableFailure
transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
    SmallVector<OpFoldResult> tileSizes;
    Location loc = target.getLoc();
    SmallVector<OpFoldResult> allShapeSizes =
        target.createFlatListOfOperandDims(b, loc);
    AffineMap map = target.getShapesToLoopsMap();
    // No shapes-to-loops map: return empty sizes, i.e. no tiling.
    if (!map)
      return tileSizes;
    // Compute the loop range sizes by applying the map to the operand dims.
    SmallVector<OpFoldResult> shapeSizes =
        affine::makeComposedFoldedMultiResultAffineApply(b&: rewriter, loc, map,
                                                         operands: allShapeSizes);
    // If the shape size is dynamic, tile by 1.
    // Otherwise, do not tile (i.e. tile size 0).
    for (OpFoldResult shapeSize : shapeSizes) {
      tileSizes.push_back(Elt: getConstantIntValue(ofr: shapeSize) ? b.getIndexAttr(value: 0)
                                                        : b.getIndexAttr(value: 1));
    }
    return tileSizes;
  });
  rewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
      rewriter, op: cast<TilingInterface>(Val: target.getOperation()), options: tilingOptions);
  if (failed(Result: maybeTilingResult))
    return emitDefaultDefiniteFailure(target);

  // Replace the target with the tiled computation, or simply erase it when it
  // produces no results.
  if (target->getNumResults())
    rewriter.replaceOp(op: target, newValues: maybeTilingResult->replacements);
  else
    rewriter.eraseOp(op: target);

  results.reserve(size: maybeTilingResult->tiledOps.size());
  for (Operation *tiled : maybeTilingResult->tiledOps)
    results.push_back(op: tiled);
  return DiagnosedSilenceableFailure::success();
}
2564
2565//===----------------------------------------------------------------------===//
2566// ConvertToLoopsOp
2567//===----------------------------------------------------------------------===//
2568
2569DiagnosedSilenceableFailure
2570transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2571 transform::TransformResults &results,
2572 transform::TransformState &state) {
2573 SmallVector<Operation *> loops;
2574 for (Operation *target : state.getPayloadOps(value: getTarget())) {
2575 auto tilingOp = dyn_cast<TilingInterface>(Val&: *target);
2576 if (!tilingOp) {
2577 DiagnosedSilenceableFailure diag =
2578 emitSilenceableError()
2579 << "expected the payload to implement TilingInterface";
2580 diag.attachNote(loc: target->getLoc()) << "payload op";
2581 return diag;
2582 }
2583 rewriter.setInsertionPoint(target);
2584 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2585 scf::lowerToLoopsUsingSCFForOp(rewriter, op: tilingOp);
2586 if (failed(Result: generatedLoops))
2587 return emitDefaultDefiniteFailure(target);
2588 for (scf::ForOp &loop : *generatedLoops) {
2589 loops.push_back(Elt: loop.getOperation());
2590 }
2591 rewriter.eraseOp(op: target);
2592 }
2593 results.set(value: cast<OpResult>(Val: getResult()), ops&: loops);
2594 return DiagnosedSilenceableFailure::success();
2595}
2596
2597//===----------------------------------------------------------------------===//
2598// RewriteInDestinationPassingStyleOp
2599//===----------------------------------------------------------------------===//
2600
2601DiagnosedSilenceableFailure
2602transform::RewriteInDestinationPassingStyleOp::applyToOne(
2603 transform::TransformRewriter &rewriter, Operation *target,
2604 transform::ApplyToEachResultList &results,
2605 transform::TransformState &state) {
2606 rewriter.setInsertionPoint(target);
2607 FailureOr<Operation *> maybeResult =
2608 TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2609 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2610 caseFn: [&rewriter](auto op) {
2611 return rewriteInDestinationPassingStyle(rewriter, op);
2612 });
2613 if (failed(Result: maybeResult))
2614 return emitDefaultSilenceableFailure(target);
2615 results.push_back(op: *maybeResult);
2616 return DiagnosedSilenceableFailure::success();
2617}
2618
2619//===----------------------------------------------------------------------===//
2620// SplitOp
2621//===----------------------------------------------------------------------===//
2622
2623DiagnosedSilenceableFailure
2624SplitOp::apply(transform::TransformRewriter &rewriter,
2625 TransformResults &results, TransformState &state) {
2626 // Collect the dynamic split points if provided.
2627 SmallVector<Operation *> payload =
2628 llvm::to_vector(Range: state.getPayloadOps(value: getTarget()));
2629
2630 bool isMultiwaySplit = getMultiway();
2631
2632 if (isMultiwaySplit && !llvm::hasSingleElement(C&: payload)) {
2633 return mlir::emitSilenceableFailure(loc: getLoc())
2634 << "requires exactly one target when "
2635 "multiway split is enabled (got "
2636 << llvm::range_size(Range&: payload) << ")";
2637 }
2638
2639 SmallVector<OpFoldResult> chunkSizes;
2640
2641 if (!isMultiwaySplit)
2642 chunkSizes.reserve(N: payload.size());
2643
2644 if (getDynamicChunkSizes()) {
2645 auto diag = DiagnosedSilenceableFailure::success();
2646 if (isa<TransformHandleTypeInterface>(Val: getDynamicChunkSizes().getType())) {
2647 chunkSizes = llvm::to_vector(Range: llvm::map_range(
2648 C: state.getPayloadOps(value: getDynamicChunkSizes()), F: [&](Operation *op) {
2649 if (op->getNumResults() != 1 ||
2650 !op->getResult(idx: 0).getType().isIndex()) {
2651 diag = emitSilenceableError()
2652 << "expected dynamic split point handle to point to a "
2653 "single-result index-typed op";
2654 diag.attachNote(loc: op->getLoc()) << "dynamic split point";
2655 }
2656 return OpFoldResult(op->getResult(idx: 0));
2657 }));
2658 } else {
2659 chunkSizes = llvm::to_vector(
2660 Range: llvm::map_range(C: state.getParams(value: getDynamicChunkSizes()),
2661 F: [](Attribute attr) { return OpFoldResult(attr); }));
2662 }
2663 if (diag.isSilenceableFailure())
2664 return diag;
2665
2666 // For multiway split, a single payload is expected to have multiple
2667 // split points.
2668 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2669 return emitDefiniteFailure()
2670 << "expected the dynamic split point handle to point to as "
2671 "many operations ("
2672 << chunkSizes.size() << ") as the target handle ("
2673 << payload.size() << ")";
2674 }
2675 } else {
2676 chunkSizes.resize(N: payload.size(),
2677 NV: rewriter.getIndexAttr(value: getStaticChunkSizes()));
2678 }
2679
2680 auto checkStructuredOpAndDimensions =
2681 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2682 if (!linalgOp) {
2683 auto diag = emitSilenceableError() << "only applies to structured ops";
2684 diag.attachNote(loc) << "target op";
2685 return diag;
2686 }
2687
2688 if (getDimension() >= linalgOp.getNumLoops()) {
2689 auto diag = emitSilenceableError() << "dimension " << getDimension()
2690 << " does not exist in target op";
2691 diag.attachNote(loc) << "target op";
2692 return diag;
2693 }
2694 return DiagnosedSilenceableFailure::success();
2695 };
2696
2697 auto checkFailureInSplitting =
2698 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2699 if (hasFailed) {
2700 auto diag = emitDefiniteFailure() << "internal failure in splitting";
2701 diag.attachNote(loc) << "target op";
2702 return diag;
2703 }
2704 return DiagnosedSilenceableFailure::success();
2705 };
2706
2707 SmallVector<Operation *> opList;
2708 if (isMultiwaySplit) {
2709
2710 // Split a single target operation at multiple points.
2711 TilingInterface head, tail;
2712 Operation *target = payload.front();
2713
2714 LinalgOp linalgOp = dyn_cast<LinalgOp>(Val: target);
2715
2716 // Check that the target is a valid LinalgOp with correct dimensions.
2717 DiagnosedSilenceableFailure diag =
2718 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2719 if (diag.isSilenceableFailure())
2720 return diag;
2721
2722 for (auto &&[idx, chunkSize] : llvm::enumerate(First&: chunkSizes)) {
2723
2724 if (idx > 0)
2725 target = tail.getOperation();
2726
2727 if (!target)
2728 break;
2729
2730 linalgOp = cast<LinalgOp>(Val: target);
2731 Location loc = target->getLoc();
2732
2733 rewriter.setInsertionPoint(linalgOp);
2734 std::tie(args&: head, args&: tail) = linalg::splitOp(
2735 rewriter, op: cast<TilingInterface>(Val: linalgOp.getOperation()),
2736 dimension: getDimension(), splitPoint: chunkSize);
2737
2738 // Propagate errors.
2739 DiagnosedSilenceableFailure diag =
2740 checkFailureInSplitting(!head && !tail, loc);
2741 if (diag.isDefiniteFailure())
2742 return diag;
2743
2744 opList.push_back(Elt: head.getOperation());
2745 }
2746
2747 // Append any leftover parts to the end of the result list.
2748 if (tail)
2749 opList.push_back(Elt: tail.getOperation());
2750
2751 } else {
2752 // Split each target operation.
2753 SmallVector<Operation *> first, second;
2754 Operation *noSecondPart = nullptr;
2755 for (const auto &pair : llvm::zip(t&: payload, u&: chunkSizes)) {
2756 Operation *target = std::get<0>(t: pair);
2757 Location loc = target->getLoc();
2758 LinalgOp linalgOp = dyn_cast<LinalgOp>(Val: target);
2759 DiagnosedSilenceableFailure diag =
2760 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2761
2762 if (diag.isSilenceableFailure())
2763 return diag;
2764
2765 rewriter.setInsertionPoint(linalgOp);
2766 std::tie(args&: first.emplace_back(), args&: second.emplace_back()) = linalg::splitOp(
2767 rewriter, op: cast<TilingInterface>(Val: linalgOp.getOperation()),
2768 dimension: getDimension(), splitPoint: std::get<1>(t: pair));
2769
2770 // Propagate errors.
2771 DiagnosedSilenceableFailure diagSplit =
2772 checkFailureInSplitting(!first.back() && !second.back(), loc);
2773 if (diagSplit.isDefiniteFailure())
2774 return diag;
2775
2776 // Do not add null second parts.
2777 if (!second.back()) {
2778 noSecondPart = target;
2779 second.pop_back();
2780 }
2781 }
2782
2783 if (second.size() != first.size() && !second.empty()) {
2784 auto diag = emitSilenceableError()
2785 << "splitting does not produce the second part for a subset "
2786 "of targets";
2787 diag.attachNote()
2788 << "expected splitting to produce the second part of all "
2789 "or none of the targets";
2790 diag.attachNote(loc: noSecondPart->getLoc())
2791 << "first target with no second part";
2792 return diag;
2793 }
2794
2795 opList.append(RHS: first);
2796 if (second.size())
2797 opList.append(RHS: second);
2798 }
2799 results.set(value: cast<OpResult>(Val: getSplitList()), ops&: opList);
2800 return DiagnosedSilenceableFailure::success();
2801}
2802
2803void SplitOp::getEffects(
2804 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2805 consumesHandle(handles: getTargetMutable(), effects);
2806 if (getDynamicChunkSizes())
2807 onlyReadsHandle(handles: getDynamicChunkSizesMutable(), effects);
2808 producesHandle(handles: getOperation()->getOpResults(), effects);
2809 modifiesPayload(effects);
2810}
2811
/// Custom parser for the form:
///   `<target> after (<static-int> | <dynamic-operand>) attr-dict
///    : <target-type> (, <chunk-sizes-type>)?`
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
  IntegerAttr staticChunkSizes;
  if (parser.parseOperand(result&: target) || parser.parseKeyword(keyword: "after"))
    return failure();

  // After "after" comes either a dynamic operand or a static integer.
  OptionalParseResult dynamicPointParseResult =
      parser.parseOptionalOperand(result&: dynamicChunkSizes);
  if (!dynamicPointParseResult.has_value()) {
    int64_t staticChunkSizesValue;
    if (failed(Result: parser.parseInteger(result&: staticChunkSizesValue)))
      return failure();

    staticChunkSizes =
        parser.getBuilder().getI64IntegerAttr(value: staticChunkSizesValue);
  }

  Type targetType;
  if (parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.parseColonType(result&: targetType) ||
      parser.resolveOperand(operand: target, type: targetType, result&: result.operands)) {
    return failure();
  }
  if (dynamicPointParseResult.has_value()) {
    // Dynamic form: resolve the operand's type after the comma and record
    // kDynamic in the static attribute to mark the dynamic form.
    Type ChunkSizesType;
    if (failed(Result: *dynamicPointParseResult) || parser.parseComma() ||
        parser.parseType(result&: ChunkSizesType) ||
        parser.resolveOperand(operand: dynamicChunkSizes, type: ChunkSizesType,
                              result&: result.operands)) {
      return failure();
    }

    staticChunkSizes =
        parser.getBuilder().getI64IntegerAttr(value: ShapedType::kDynamic);
  }

  result.addAttribute(
      name: SplitOp::getStaticChunkSizesAttrName(name: result.name).getValue(),
      attr: staticChunkSizes);
  result.addTypes(newTypes: targetType);
  return success();
}
2854
/// Custom printer, the inverse of the parser above.
void SplitOp::print(OpAsmPrinter &printer) {
  printer << " " << getTarget() << " after ";
  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
  // Print the static chunk size when set, otherwise the dynamic operand.
  if (staticChunkSize != ShapedType::kDynamic)
    printer << staticChunkSize;
  else
    printer << getDynamicChunkSizes();
  printer << " ";
  // The static chunk size was printed inline above, so elide its attribute.
  printer.printOptionalAttrDict(attrs: getOperation()->getAttrs(),
                                elidedAttrs: {getStaticChunkSizesAttrName()});
  printer << " : " << getTarget().getType();
  if (staticChunkSize == ShapedType::kDynamic)
    printer << ", " << getDynamicChunkSizes().getType();
}
2869
2870LogicalResult SplitOp::verify() {
2871 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2872 (getDynamicChunkSizes() == nullptr)) {
2873 return emitOpError() << "expects either a dynamic or a static split "
2874 "point to be provided";
2875 }
2876 return success();
2877}
2878
2879//===----------------------------------------------------------------------===//
2880// SplitReductionOp
2881//===----------------------------------------------------------------------===//
2882
/// Builder assembling the attribute dictionary from plain integers/flags and
/// creating the four any-op results (init/alloc, fill, split op, combiner).
void transform::SplitReductionOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
    bool useScalingAlgorithm, bool useAlloc) {
  MLIRContext *ctx = builder.getContext();
  result.addOperands(newOperands: target);
  result.addAttribute(name: SplitReductionOp::getSplitFactorAttrName(name: result.name),
                      attr: builder.getI64IntegerAttr(value: splitFactor));
  result.addAttribute(
      name: SplitReductionOp::getInsertSplitDimensionAttrName(name: result.name),
      attr: builder.getI64IntegerAttr(value: insertSplitDimension));
  // Unit attributes are added only when the corresponding flag is set.
  if (innerParallel) {
    result.addAttribute(name: SplitReductionOp::getInnerParallelAttrName(name: result.name),
                        attr: builder.getUnitAttr());
  }
  if (useScalingAlgorithm) {
    result.addAttribute(
        name: SplitReductionOp::getUseScalingAlgorithmAttrName(name: result.name),
        attr: builder.getUnitAttr());
  }
  if (useAlloc) {
    result.addAttribute(name: SplitReductionOp::getUseAllocAttrName(name: result.name),
                        attr: builder.getUnitAttr());
  }
  auto resultType = transform::AnyOpType::get(ctx);
  result.addTypes(newTypes: {resultType, resultType, resultType, resultType});
}
2910
/// Applies the split-reduction transformation to the target linalg op and
/// maps the four produced ops (init/alloc, fill, split op, combining op) to
/// the results.
DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // The control function returns the same options for every op: this op's
  // attributes fully determine the transformation.
  ControlSplitReductionFn splitFn = [&](LinalgOp) {
    return linalg::SplitReductionOptions{.ratio: int64_t(getSplitFactor()),
                                         .index: unsigned(getInsertSplitDimension()),
                                         .innerParallel: bool(getInnerParallel())};
  };
  rewriter.setInsertionPoint(target);
  // Pick the scaling-based variant when requested.
  FailureOr<SplitReductionResult> splitResult =
      (getUseScalingAlgorithm())
          ? splitReductionByScaling(b&: rewriter, op: target, controlSplitReductionFn: splitFn, useAlloc: getUseAlloc())
          : splitReduction(b&: rewriter, op: target, controlSplitReductionFn: splitFn, useAlloc: getUseAlloc());
  if (failed(Result: splitResult))
    return emitDefaultDefiniteFailure(target);

  results.push_back(op: splitResult->initOrAlloc);
  results.push_back(op: splitResult->fillOp);
  results.push_back(op: splitResult->splitLinalgOp);
  results.push_back(op: splitResult->resultCombiningLinalgOp);
  return DiagnosedSilenceableFailure::success();
}
2934
2935//===----------------------------------------------------------------------===//
2936// TileReductionUsingForOp
2937//===----------------------------------------------------------------------===//
2938
/// Convenience builder forwarding to the ODS-generated builder with
/// all-static tile sizes and four any-op result types.
void transform::TileReductionUsingForOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    ArrayRef<int64_t> staticTileSizes) {
  // Call the default builder.
  // This is future-proof re mixed static-dynamic and setting up the proper
  // operands segment sizes attributes for multiple variadic operands.
  // In the absence of this, horrible bugs ensue.
  // TODO: support mixed static-dynamic (see TileUsingForallOp).
  MLIRContext *ctx = builder.getContext();
  auto opTy = transform::AnyOpType::get(ctx);
  auto staticTileSizesAttr = builder.getI64ArrayAttr(values: staticTileSizes);
  build(odsBuilder&: builder, odsState&: result,
        /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
        /*target=*/target,
        /*reduction_dims=*/nullptr,
        /*tile_sizes=*/staticTileSizesAttr);
}
2956
/// Tiles the reduction dimensions of the target using scf.for with the
/// partial-reduction (outer reduction) strategy; maps the init values, the
/// tiled ops, the merge ops, and the outermost loop to the results.
DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);

  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(Val: target);
  if (!partialReductionOp) {
    return emitSilenceableFailure(
        loc: target->getLoc(),
        message: "Operation should implement PartialReductionOpInterface");
  }

  // When no reduction dims are specified, default to all reduction loops of
  // the target.
  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(attr: getReductionDims());
  if (reductionDims.empty()) {
    for (auto [idx, iteratorType] :
         llvm::enumerate(First: partialReductionOp.getLoopIteratorTypes())) {
      if (iteratorType == utils::IteratorType::reduction)
        reductionDims.push_back(Elt: idx);
    }
  }

  scf::SCFTilingOptions options;
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
  options.setReductionTilingStrategy(
      ReductionTilingStrategy::PartialReductionOuterReduction);
  options.setTileSizes(getAsOpFoldResult(arrayAttr: getTileSizesAttr()));
  options.setReductionDims(reductionDims);
  FailureOr<scf::SCFTilingResult> result =
      scf::tileUsingSCF(rewriter, op: partialReductionOp, options);

  if (failed(Result: result)) {
    return emitSilenceableFailure(loc: getLoc(),
                                  message: "failed to tile using partial reduction");
  }
  rewriter.replaceOp(op: target, newValues: result->replacements);
  // Result order: init ops, parallel tiled ops, merge ops, outermost loop.
  for (Value initValue : result->initialValues)
    results.push_back(op: initValue.getDefiningOp());
  for (auto parallelTiledOp : result->tiledOps)
    results.push_back(op: parallelTiledOp);
  for (auto mergeOp : result->mergeOps)
    results.push_back(op: mergeOp);
  results.push_back(op: result->loops.front());
  return DiagnosedSilenceableFailure::success();
}
3003
3004//===----------------------------------------------------------------------===//
3005// TileReductionUsingForallOp
3006//===----------------------------------------------------------------------===//
3007
/// Builder taking only static num_threads / tile sizes. All four results are
/// typed as any-op transform handles.
void transform::TileReductionUsingForallOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
    ArrayAttr mapping) {
  // Call the default builder.
  // This is future-proof re mixed static-dynamic and setting up the proper
  // operands segment sizes attributes for multiple variadic operands.
  // In the absence of this, horrible bugs ensue.
  // TODO: support mixed static-dynamic (see TileUsingForallOp).
  MLIRContext *ctx = builder.getContext();
  auto opTy = transform::AnyOpType::get(ctx);
  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(values: staticNumThreads);
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(values: staticTileSizes);
  build(odsBuilder&: builder, odsState&: result,
        /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
        /*target=*/target,
        /*reduction_dims=*/{},
        /*num_threads=*/staticNumThreadsAttr,
        /*tile_sizes=*/staticTileSizesAttr,
        /*mapping=*/mapping);
}
3029
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  SmallVector<OpFoldResult> numThreads =
      getAsOpFoldResult(arrayAttr: rewriter.getI64ArrayAttr(values: getNumThreads()));
  SmallVector<OpFoldResult> tileSizes =
      getAsOpFoldResult(arrayAttr: rewriter.getI64ArrayAttr(values: getTileSizes()));

  // Tile with an scf.forall loop, distributing the partial reductions over
  // the parallel dimension.
  scf::SCFTilingOptions options;
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  options.setReductionTilingStrategy(
      ReductionTilingStrategy::PartialReductionOuterParallel);
  // num_threads takes precedence; tile sizes are only used when no thread
  // counts were supplied.
  if (!getNumThreads().empty()) {
    options.setNumThreads(numThreads);
  } else {
    options.setTileSizes(tileSizes);
  }
  if (auto mapping = getMapping()) {
    options.setMapping(mapping.value().getValue());
  }
  // When no reduction dimensions were given, default to every reduction
  // iterator dimension of the payload op.
  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(attr: getReductionDims());
  if (reductionDims.empty()) {
    for (auto [idx, iteratorType] :
         llvm::enumerate(First: target.getIteratorTypesArray())) {
      if (iteratorType == utils::IteratorType::reduction)
        reductionDims.push_back(Elt: idx);
    }
  }
  options.setReductionDims(reductionDims);
  FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
      rewriter, op: cast<TilingInterface>(Val: target.getOperation()), options);

  if (failed(Result: result)) {
    auto diag = emitSilenceableError() << "could not tile reduction";
    return diag;
  }
  rewriter.replaceOp(op: target, newValues: result->replacements);

  // Populate results in the order the op declares them: init-defining ops,
  // partially tiled ops, merge ops, and the generated forall loop.
  for (Value initValue : result->initialValues)
    results.push_back(op: initValue.getDefiningOp());
  for (auto parallelTiledOp : result->tiledOps)
    results.push_back(op: parallelTiledOp);
  for (auto mergeOp : result->mergeOps)
    results.push_back(op: mergeOp);
  results.push_back(op: result->loops.front());
  return DiagnosedSilenceableFailure::success();
}
3080
3081//===----------------------------------------------------------------------===//
3082// ContinuousTileSizesOp
3083//===----------------------------------------------------------------------===//
3084
DiagnosedSilenceableFailure
transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {

  SmallVector<Operation *> targetOps =
      llvm::to_vector(Range: state.getPayloadOps(value: getTarget()));

  // This transform only makes sense for exactly one payload op.
  if (!llvm::hasSingleElement(C&: targetOps)) {
    return mlir::emitSilenceableFailure(loc: getLoc())
           << "requires exactly one target (got " << llvm::range_size(Range&: targetOps)
           << ")";
  }

  Operation *target = *targetOps.begin();
  auto linalgOp = dyn_cast<LinalgOp>(Val: target);
  auto tileableOp = dyn_cast<TilingInterface>(Val: target);

  if (!linalgOp)
    return emitDefiniteFailure() << "expected Linalg Op";

  OpBuilder builder(linalgOp.getContext());

  // Param-typed results: compute the sizes statically and return them as
  // attributes. This requires the payload shape to be fully static.
  if (isa<TransformParamTypeInterface>(Val: getChunkSizes().getType())) {
    if (linalgOp.hasDynamicShape()) {
      auto diag = emitSilenceableError()
                  << "cannot compute parametric tile sizes for dynamically "
                     "shaped payload op";
      diag.attachNote(loc: linalgOp->getLoc()) << "payload op";
      return diag;
    }

    FailureOr<StaticContinuousTileSizeSpecification> spec =
        computeStaticContinuousTileSizes(op: linalgOp, dimension: getDimension(),
                                         targetSize: getTargetSize());
    if (failed(Result: spec)) {
      return emitSilenceableError()
             << "failed to compute multi-size tiling sizes";
    }

    // Each chunk spans `tripCount` iterations of its tile size.
    SmallVector<int64_t> chunkSizes;

    for (auto &&[tileSize, tripCount] :
         llvm::zip_equal(t&: spec->tileSizes, u&: spec->tripCounts))
      chunkSizes.push_back(Elt: tileSize * tripCount);

    auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
      return llvm::map_to_vector(C&: values, F: [&](int64_t value) -> Attribute {
        return builder.getI64IntegerAttr(value);
      });
    };
    transformResults.setParams(value: cast<OpResult>(Val: getTileSizes()),
                               params: getI64AttrsFromI64(spec->tileSizes));
    transformResults.setParams(value: cast<OpResult>(Val: getChunkSizes()),
                               params: getI64AttrsFromI64(chunkSizes));

    return DiagnosedSilenceableFailure::success();
  }

  // Handle-typed results: emit IR computing the sizes at runtime, right
  // before the payload op.
  builder.setInsertionPoint(linalgOp);

  OpFoldResult targetSize = builder.getIndexAttr(value: getTargetSize());
  unsigned dimension = getDimension();

  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
      builder, op: tileableOp, dimension, targetSize, emitAssertions: true);
  if (failed(Result: spec)) {
    return emitSilenceableError() << "could not generate tile size computation";
  }

  // Materialize each chunk size as `tileSize * tripCount` via a composed
  // affine.apply.
  AffineExpr s0 = builder.getAffineSymbolExpr(position: 0);
  AffineExpr s1 = builder.getAffineSymbolExpr(position: 1);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(b&: builder, loc: linalgOp->getLoc(), e: expr,
                                           operands: ofrs);
  };

  SmallVector<Value> chunkSizes;
  Value splitPoint;
  for (auto &&[tileSize, tripCount] :
       llvm::zip_equal(t&: spec->tileSizes, u&: spec->tripCounts)) {
    splitPoint = apply(s0 * s1, {tileSize, tripCount});
    chunkSizes.push_back(Elt: splitPoint);
  }

  // Handle results are set to the ops defining the computed values.
  auto getDefiningOps = [&](ArrayRef<Value> values) {
    return llvm::map_to_vector(C&: values, F: [&](Value value) -> Operation * {
      return value.getDefiningOp();
    });
  };

  transformResults.set(value: cast<OpResult>(Val: getTileSizes()),
                       ops: getDefiningOps(spec->tileSizes));
  transformResults.set(value: cast<OpResult>(Val: getChunkSizes()),
                       ops: getDefiningOps(chunkSizes));

  return DiagnosedSilenceableFailure::success();
}
3183
3184LogicalResult transform::ContinuousTileSizesOp::verify() {
3185
3186 if (getTileSizes().getType() != getChunkSizes().getType()) {
3187 return emitOpError() << "expects all results type to be the same";
3188 }
3189
3190 return success();
3191}
3192
3193void transform::ContinuousTileSizesOp::getEffects(
3194 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3195 if (isa<TransformParamTypeInterface>(Val: getTileSizes().getType()))
3196 onlyReadsPayload(effects);
3197 else
3198 modifiesPayload(effects);
3199 onlyReadsHandle(handles: getTargetMutable(), effects);
3200 producesHandle(handles: getOperation()->getOpResults(), effects);
3201}
3202
/// Prints the trailing types of ContinuousTileSizesOp as a functional type
/// `(target) -> (tile_sizes)`. The chunk-sizes type (third parameter) is
/// deliberately not printed: the parser re-derives it as equal to the
/// tile-sizes type.
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
                                         Type targetType, Type tile_sizes,
                                         Type) {
  printer.printFunctionalType(inputs: TypeRange{targetType}, results: TypeRange{tile_sizes});
}
3208
3209static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
3210 Type &targetType,
3211 Type &tileSizesType,
3212 Type &chunkSizesType) {
3213 FunctionType funcType;
3214 llvm::SMLoc typeLoc = parser.getCurrentLocation();
3215 if (failed(Result: parser.parseType<FunctionType>(result&: funcType)))
3216 return failure();
3217
3218 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3219 parser.emitError(loc: typeLoc) << "expects a trailing functional type with one "
3220 "argument and one result";
3221 }
3222 targetType = funcType.getInput(i: 0);
3223 tileSizesType = chunkSizesType = funcType.getResult(i: 0);
3224
3225 return success();
3226}
3227
3228//===----------------------------------------------------------------------===//
3229// TileUsingForOp
3230//===----------------------------------------------------------------------===//
3231
3232void transform::TileUsingForOp::build(
3233 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3234 Value target, ArrayRef<int64_t> staticTileSizes,
3235 ArrayRef<int64_t> interchange,
3236 std::optional<ArrayRef<bool>> scalableSizes) {
3237 return build(odsBuilder&: builder, odsState&: result, loopTypes,
3238 /*target=*/target,
3239 /*mixedTileSizes=*/
3240 getAsOpFoldResult(arrayAttr: builder.getI64ArrayAttr(values: staticTileSizes)),
3241 interchange, scalableSizes);
3242}
3243
3244void transform::TileUsingForOp::build(
3245 OpBuilder &builder, OperationState &result, Value target,
3246 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3247 std::optional<ArrayRef<bool>> scalableSizes) {
3248 build(odsBuilder&: builder, odsState&: result, target,
3249 mixedTileSizes: getAsOpFoldResult(arrayAttr: builder.getI64ArrayAttr(values: staticTileSizes)),
3250 interchange, scalableSizes);
3251}
3252
3253void transform::TileUsingForOp::build(
3254 OpBuilder &builder, OperationState &result, Value target,
3255 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3256 std::optional<ArrayRef<bool>> scalableSizes) {
3257 // Loop types are automaticaly splat by the callee, setting up one is
3258 // enough.
3259 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3260 build(odsBuilder&: builder, odsState&: result, loopTypes, target, mixedTileSizes, interchange,
3261 scalableSizes);
3262}
3263
/// Main builder: accepts mixed static/dynamic tile sizes and explicit loop
/// result types. One `loops` result is created per non-zero static size.
void transform::TileUsingForOp::build(
    OpBuilder &builder, OperationState &result, TypeRange loopTypes,
    Value target, ArrayRef<OpFoldResult> mixedTileSizes,
    ArrayRef<int64_t> interchange,
    std::optional<ArrayRef<bool>> scalableSizes) {
  // Separate the mixed sizes into a static attribute and dynamic operands.
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(ofrs: mixedTileSizes, dynamicVec&: dynamicTileSizes, staticVec&: staticTileSizes);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(values: staticTileSizes);
  // A tile size of zero means "do not tile this dimension" and produces no
  // loop result.
  unsigned numExpectedLoops =
      staticTileSizes.size() - llvm::count(Range&: staticTileSizes, Element: 0);
  SmallVector<Type> resultTypes;
  resultTypes.reserve(N: numExpectedLoops);
  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
         "expected one loop type or as many as loops");
  // A single loop type is splat across all expected loop results.
  if (loopTypes.size() == 1)
    resultTypes.append(NumInputs: numExpectedLoops, Elt: loopTypes[0]);
  else
    llvm::append_range(C&: resultTypes, R&: loopTypes);
  // Scalability flags default to all-false when not provided.
  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
  if (scalableSizes.has_value())
    expandedScalableSizes.assign(in_start: scalableSizes->begin(), in_end: scalableSizes->end());
  build(odsBuilder&: builder, odsState&: result, /*tiled_linalg_op=*/target.getType(),
        /*loops=*/resultTypes,
        /*target=*/target,
        /*dynamic_sizes=*/dynamicTileSizes,
        /*static_sizes=*/staticTileSizesAttr,
        /*interchange=*/builder.getDenseI64ArrayAttr(values: interchange),
        /*scalable_sizes=*/expandedScalableSizes);
}
3297
3298LogicalResult transform::TileUsingForOp::verify() {
3299 if (getMixedSizes().size() != getScalableSizes().size())
3300 return emitOpError(message: "expected same number of sizes (")
3301 << getMixedSizes().size() << ") and scalable sizes ("
3302 << getScalableSizes().size() << ")";
3303 ArrayRef<int64_t> staticSizes = getStaticSizes();
3304 unsigned numExpectedLoops = staticSizes.size() - llvm::count(Range&: staticSizes, Element: 0);
3305 if (getLoops().size() != numExpectedLoops)
3306 return emitOpError(message: "expected number of loops to tile (")
3307 << numExpectedLoops << ") to match number of `loops` results ("
3308 << getLoops().size() << ")";
3309 return success();
3310}
3311
DiagnosedSilenceableFailure
transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
                                 TransformResults &transformResults,
                                 TransformState &state) {
  ArrayRef<int64_t> tileSizes = getStaticSizes();

  SmallVector<Operation *> targets =
      llvm::to_vector(Range: state.getPayloadOps(value: getTarget()));
  // For each dynamic size operand, collect either the payload ops producing
  // the size values (op handle) or the integer values (param). The two lists
  // stay index-aligned: exactly one of them gets a non-empty entry per
  // operand.
  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
  SmallVector<SmallVector<int64_t>> paramSizes;
  dynamicSizeProducers.reserve(N: getDynamicSizes().size());
  paramSizes.reserve(N: getDynamicSizes().size());
  for (Value transformValue : getDynamicSizes()) {
    if (isa<ParamType>(Val: transformValue.getType())) {
      dynamicSizeProducers.push_back(Elt: {});
      ArrayRef<Attribute> params = state.getParams(value: transformValue);
      paramSizes.push_back(
          Elt: llvm::to_vector(Range: llvm::map_range(C&: params, F: [](Attribute attr) {
            return cast<IntegerAttr>(Val&: attr).getValue().getSExtValue();
          })));

      // One size value is required per target op.
      if (paramSizes.back().size() != targets.size()) {
        DiagnosedSilenceableFailure diag =
            emitSilenceableError()
            << "expected as many parameter values ("
            << dynamicSizeProducers.back().size() << ") as target ops ("
            << targets.size() << ")";
        diag.attachNote(loc: transformValue.getLoc()) << "for this parameter";
        return diag;
      }

      continue;
    }
    paramSizes.push_back(Elt: {});
    dynamicSizeProducers.push_back(
        Elt: llvm::to_vector(Range: state.getPayloadOps(value: transformValue)));

    // One size-producing op is required per target op.
    if (dynamicSizeProducers.back().size() != targets.size()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected as many dynamic size-producing operations ("
          << dynamicSizeProducers.back().size() << ") as target ops ("
          << targets.size() << ")";
      diag.attachNote(loc: transformValue.getLoc()) << "for this handle";
      return diag;
    }

    // Each size must be produced by an op with a single index-typed result.
    for (Operation *op : dynamicSizeProducers.back()) {
      if (op->getNumResults() == 1 &&
          isa<IndexType>(Val: op->getResult(idx: 0).getType())) {
        continue;
      }

      DiagnosedSilenceableFailure diag =
          emitSilenceableError() << "expected sizes to be produced by ops "
                                    "with a single index-type result";
      diag.attachNote(loc: op->getLoc()) << "size producer op";
      diag.attachNote(loc: transformValue.getLoc()) << "for this handle";
      return diag;
    }
  }

  SmallVector<Operation *> tiled;
  // loops[d] accumulates the d-th generated loop of every target.
  SmallVector<SmallVector<Operation *, 4>, 4> loops;
  loops.resize(N: getLoops().size());
  auto scalableSizes = getScalableSizes();
  for (auto [i, op] : llvm::enumerate(First&: targets)) {
    auto tilingInterface = dyn_cast<TilingInterface>(Val: op);
    if (!tilingInterface) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "only ops implementing TilingInterface are supported";
      diag.attachNote(loc: op->getLoc()) << "target op";
      return diag;
    }
    if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "too many tiles provided, expected at most "
          << tilingInterface.getLoopIteratorTypes().size() << " found "
          << tileSizes.size();
      diag.attachNote(loc: op->getLoc()) << "target op";
      return diag;
    }

    scf::SCFTilingOptions tilingOptions;
    if (tileSizes.empty()) {
      tilingOptions.setTileSizeComputationFunction(
          [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
            return {};
          });
    } else {
      // Resolve the mixed sizes for this particular target: static entries
      // are used as-is (multiplied by vector.vscale when flagged scalable),
      // dynamic entries pick the index-th producer result or param value.
      tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
                                                                 Operation *) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(N: tileSizes.size());
        unsigned dynamicIdx = 0;

        for (auto [ofrIdx, ofr] : llvm::enumerate(First: getMixedSizes())) {
          if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: ofr)) {
            if (scalableSizes[ofrIdx]) {
              // Scalable size: materialize `size * vector.vscale` as IR.
              auto val = b.create<arith::ConstantIndexOp>(
                  location: getLoc(), args: cast<IntegerAttr>(Val&: attr).getInt());
              Value vscale =
                  b.create<vector::VectorScaleOp>(location: getLoc(), args: b.getIndexType());
              sizes.push_back(
                  Elt: b.create<arith::MulIOp>(location: getLoc(), args&: val, args&: vscale).getResult());
            } else {
              sizes.push_back(Elt: attr);
            }
            continue;
          }
          ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
          ArrayRef<int64_t> params = paramSizes[dynamicIdx];
          ++dynamicIdx;
          assert((dynamicSizes.empty() ^ params.empty()) &&
                 "expected either dynamic sizes or parameters");
          if (!params.empty()) {
            sizes.push_back(Elt: b.getIndexAttr(value: params[index]));
          } else {
            sizes.push_back(Elt: dynamicSizes[index]->getResult(idx: 0));
          }
        }
        return sizes;
      });
    }

    tilingOptions.setInterchange(getInterchange());
    FailureOr<scf::SCFTilingResult> maybeTilingResult =
        tileUsingSCF(rewriter, op: tilingInterface, options: tilingOptions);
    if (failed(Result: maybeTilingResult))
      return DiagnosedSilenceableFailure::definiteFailure();

    rewriter.replaceOp(op, newValues: maybeTilingResult->replacements);

    tiled.append(RHS: maybeTilingResult->tiledOps);
    for (const auto &en2 : llvm::enumerate(First&: maybeTilingResult->loops))
      loops[en2.index()].push_back(Elt: en2.value());
  }

  // Publish results: all tiled ops in one handle, plus one handle per loop
  // depth collecting that loop across all targets.
  transformResults.set(value: cast<OpResult>(Val: getTiledLinalgOp()), ops&: tiled);
  for (const auto &en : llvm::enumerate(First&: loops))
    transformResults.set(value: cast<OpResult>(Val: getLoops()[en.index()]), ops&: en.value());

  return DiagnosedSilenceableFailure::success();
}
3458
3459SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3460 ValueRange dynamic = getDynamicSizes();
3461 ArrayRef<int64_t> tileSizes = getStaticSizes();
3462 SmallVector<OpFoldResult> results;
3463 results.reserve(N: tileSizes.size());
3464 unsigned dynamicPos = 0;
3465 Builder builder(getContext());
3466 for (int64_t size : tileSizes) {
3467 if (size == ShapedType::kDynamic) {
3468 results.push_back(Elt: dynamic[dynamicPos++]);
3469 } else {
3470 results.push_back(Elt: builder.getIndexAttr(value: size));
3471 }
3472 }
3473 return results;
3474}
3475
3476void transform::TileUsingForOp::getEffects(
3477 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3478 consumesHandle(handles: getTargetMutable(), effects);
3479 onlyReadsHandle(handles: getDynamicSizesMutable(), effects);
3480 producesHandle(handles: getOperation()->getOpResults(), effects);
3481 modifiesPayload(effects);
3482}
3483
3484//===----------------------------------------------------------------------===//
3485// TileUsingForallOp
3486//===----------------------------------------------------------------------===//
3487
3488void transform::TileUsingForallOp::build(OpBuilder &builder,
3489 OperationState &result, Value target,
3490 ArrayRef<int64_t> staticTileSizes,
3491 transform::TileSizesSpec,
3492 ArrayAttr mapping) {
3493 return build(odsBuilder&: builder, odsState&: result,
3494 /*target=*/target,
3495 /*mixedTileSizes=*/
3496 getAsOpFoldResult(arrayAttr: builder.getI64ArrayAttr(values: staticTileSizes)),
3497 /*_=*/odsArg2: TileSizesSpec(),
3498 /*mapping=*/mapping);
3499}
3500
/// Builder accepting mixed static/dynamic tile sizes. All num-threads
/// operands and attributes are left empty: this overload specifies tiling
/// via tile sizes only.
void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         OperationState &result, Value target,
                                         ArrayRef<OpFoldResult> mixedTileSizes,
                                         transform::TileSizesSpec,
                                         ArrayAttr mapping) {
  // Separate the mixed sizes into a static attribute and dynamic operands.
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(ofrs: mixedTileSizes, dynamicVec&: dynamicTileSizes, staticVec&: staticTileSizes);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  MLIRContext *ctx = builder.getContext();
  auto operationType = transform::AnyOpType::get(ctx);
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(values: staticTileSizes);
  build(odsBuilder&: builder, odsState&: result,
        /*resultTypes=*/TypeRange{operationType, operationType},
        /*target=*/target,
        /*num_threads=*/ValueRange{},
        /*tile_sizes=*/dynamicTileSizes,
        /*packed_num_threads=*/Value(),
        /*packed_tile_sizes=*/Value(),
        /*static_num_threads=*/builder.getDenseI64ArrayAttr(values: {}),
        /*static_tile_sizes=*/staticTileSizesAttr,
        /*mapping=*/mapping);
}
3526
3527void transform::TileUsingForallOp::build(OpBuilder &builder,
3528 OperationState &result, Value target,
3529 ArrayRef<int64_t> staticNumThreads,
3530 transform::NumThreadsSpec,
3531 ArrayAttr mapping) {
3532 return build(odsBuilder&: builder, odsState&: result, target,
3533 mixedNumThreads: getAsOpFoldResult(arrayAttr: builder.getI64ArrayAttr(values: staticNumThreads)),
3534 odsArg2: NumThreadsSpec(), mapping);
3535}
3536
/// Builder accepting mixed static/dynamic thread counts. All tile-sizes
/// operands and attributes are left empty: this overload specifies tiling
/// via the number of threads only.
void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         OperationState &result, Value target,
                                         ArrayRef<OpFoldResult> mixedNumThreads,
                                         transform::NumThreadsSpec,
                                         ArrayAttr mapping) {
  // Separate the mixed counts into a static attribute and dynamic operands.
  SmallVector<int64_t> staticNumThreads;
  SmallVector<Value> dynamicNumThreads;
  dispatchIndexOpFoldResults(ofrs: mixedNumThreads, dynamicVec&: dynamicNumThreads,
                             staticVec&: staticNumThreads);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  MLIRContext *ctx = builder.getContext();
  auto operationType = transform::AnyOpType::get(ctx);
  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(values: staticNumThreads);
  build(odsBuilder&: builder, odsState&: result,
        /*resultTypes=*/TypeRange{operationType, operationType},
        /*target=*/target,
        /*num_threads=*/dynamicNumThreads,
        /*tile_sizes=*/ValueRange{},
        /*packed_num_threads=*/Value(),
        /*packed_tile_sizes=*/Value(),
        /*static_num_threads=*/staticNumThreadsAttr,
        /*static_tile_sizes=*/builder.getDenseI64ArrayAttr(values: {}),
        /*mapping=*/mapping);
}
3563
3564/// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3565/// normalized upper bound.
3566static SmallVector<OpFoldResult>
3567normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3568 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3569 ArrayRef<OpFoldResult> steps) {
3570 AffineExpr s0, s1, s2;
3571 bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1, exprs&: s2);
3572 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(other: s2);
3573 SmallVector<OpFoldResult> normalizedUbs;
3574 for (auto [lb, ub, step] : llvm::zip_equal(t&: lbs, u&: ubs, args&: steps)) {
3575 OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
3576 b&: rewriter, loc, expr: normalizedUbExpr, operands: {lb, ub, step});
3577 normalizedUbs.push_back(Elt: normalizedUb);
3578 }
3579 return normalizedUbs;
3580}
3581
3582/// When a loop is normalized, the uses of the induction variable within the
3583/// loop need to replaced with `original_lb + old_iv * original_step`.
3584static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3585 Location loc, ValueRange ivs,
3586 ArrayRef<OpFoldResult> lbs,
3587 ArrayRef<OpFoldResult> steps) {
3588 AffineExpr s0, s1;
3589 AffineExpr d0;
3590 bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1);
3591 bindDims(ctx: rewriter.getContext(), exprs&: d0);
3592 AffineExpr denormExpr = s0 + d0 * s1;
3593 SmallVector<Value> denormalizedIvs;
3594
3595 for (auto [iv, lb, step] : llvm::zip_equal(t&: ivs, u&: lbs, args&: steps)) {
3596 OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
3597 b&: rewriter, loc, expr: denormExpr, operands: ArrayRef<OpFoldResult>{iv, lb, step});
3598 denormalizedIvs.push_back(
3599 Elt: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: denormValue));
3600 }
3601 return denormalizedIvs;
3602}
3603
/// Given a `scf.forall` loop return a loop op with the loop bounds
/// normalized.
/// TODO: Replace this with a general utility to normalize `scf.forall`.
/// At the time of writing, this wasnt done since adding this to `scf`
/// dialect would disallow using of `affine.apply` operations due
/// to cyclic dependencies. To avoid churn in lit tests
/// with the change this was added with, defer that to a follow up.
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
                                           scf::ForallOp loop) {
  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
  SmallVector<OpFoldResult> steps = loop.getMixedStep();

  // Already normalized (all-zero lower bounds, all-unit steps): nothing to
  // do.
  if (llvm::all_of(Range&: lbs, P: isZeroInteger) && llvm::all_of(Range&: steps, P: isOneInteger)) {
    return loop;
  }

  // Create a replacement loop over [0, ceildiv(ub - lb, step)) with unit
  // steps and an (initially) empty body.
  Location loc = loop.getLoc();
  SmallVector<OpFoldResult> normalizedUbs =
      normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
                                          rewriter.getIndexAttr(value: 0));
  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
                                            rewriter.getIndexAttr(value: 1));

  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
      location: loc, args&: normalizedLbs, args&: normalizedUbs, args&: normalizedSteps, args: loop.getOutputs(),
      args: loop.getMapping(), args: [](OpBuilder &, Location, ValueRange) {});

  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
  OpBuilder::InsertionGuard g(rewriter);
  Block *normalizedLoopBlock = normalizedForallOp.getBody();
  rewriter.setInsertionPointToStart(normalizedLoopBlock);

  // Move the original body into the new loop, substituting each old
  // induction variable with its denormalized value `lb + iv * step` and
  // rewiring the region iter args.
  SmallVector<Value> argValues =
      denormalizeIndVar(rewriter, loc, ivs: normalizedLoopIvs, lbs, steps);
  argValues.append(in_start: normalizedForallOp.getRegionIterArgs().begin(),
                   in_end: normalizedForallOp.getRegionIterArgs().end());
  Block *origLoopBlock = loop.getBody();
  rewriter.mergeBlocks(source: origLoopBlock, dest: normalizedLoopBlock, argValues);

  rewriter.replaceOp(op: loop, newOp: normalizedForallOp);
  return normalizedForallOp;
}
3648
/// Shared implementation of tiling a single payload op to an `scf.forall`
/// loop; used by TileUsingForallOp and related transforms. On success the
/// tiling artifacts are returned through `tilingResult`.
DiagnosedSilenceableFailure transform::tileToForallOpImpl(
    RewriterBase &rewriter, transform::TransformState &state,
    TransformOpInterface transformOp, Operation *target,
    ArrayRef<OpFoldResult> mixedNumThreads,
    ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
    scf::SCFTilingResult &tilingResult) {
  // Transform all targets one by one.
  auto tileableOp = dyn_cast<TilingInterface>(Val: target);
  if (!tileableOp) {
    DiagnosedSilenceableFailure diag =
        transformOp.emitSilenceableError()
        << "only TilingInterface ops are supported";
    diag.attachNote(loc: target->getLoc()) << "target op";
    return diag;
  }
  rewriter.setInsertionPoint(tileableOp);
  scf::SCFTilingOptions options;
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  // Thread counts take precedence; tile sizes are only used when no thread
  // counts were supplied.
  if (!mixedNumThreads.empty()) {
    options.setNumThreads(mixedNumThreads);
  } else {
    options.setTileSizes(mixedTileSizes);
  }
  if (mapping) {
    options.setMapping(mapping.value().getValue());
  }
  FailureOr<scf::SCFTilingResult> maybeTilingResult =
      scf::tileUsingSCF(rewriter, op: tileableOp, options);

  if (failed(Result: maybeTilingResult))
    return transformOp.emitDefaultSilenceableFailure(target: tileableOp);

  rewriter.replaceOp(op: tileableOp, newValues: maybeTilingResult->replacements);

  tilingResult = *maybeTilingResult;

  // When tiling by sizes, the generated forall may have non-normalized
  // bounds; normalize it (zero lower bounds, unit steps) and substitute the
  // new loop into the result.
  if (mixedNumThreads.empty()) {
    auto generatedForallOp = cast<scf::ForallOp>(Val&: tilingResult.loops.front());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(generatedForallOp);
    scf::ForallOp normalizedForallOp =
        normalizeForallLoopOp(rewriter, loop: generatedForallOp);
    tilingResult.loops.front() = normalizedForallOp;
  }

  return DiagnosedSilenceableFailure::success();
}
3696
DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  auto transformOp = cast<TransformOpInterface>(Val: getOperation());

  // Result payload ops.
  SmallVector<Operation *> tileOps;
  SmallVector<Operation *> tiledOps;

  // Unpack handles.
  // Thread counts and tile sizes may each come either as a single packed
  // handle or as a mixed list of static values and single-result handles.
  SmallVector<OpFoldResult> mixedNumThreads;
  DiagnosedSilenceableFailure status =
      getPackedNumThreads()
          ? unpackSingleIndexResultPayloadOperations(
                state, transformOp, result&: mixedNumThreads, packedHandle: getPackedNumThreads())
          : unpackSingleIndexResultPayloadOperations(
                state, transformOp, result&: mixedNumThreads, ofrs: getMixedNumThreads());
  if (!status.succeeded())
    return status;
  SmallVector<OpFoldResult> mixedTileSizes;
  status = getPackedTileSizes()
               ? unpackSingleIndexResultPayloadOperations(
                     state, transformOp, result&: mixedTileSizes, packedHandle: getPackedTileSizes())
               : unpackSingleIndexResultPayloadOperations(
                     state, transformOp, result&: mixedTileSizes, ofrs: getMixedTileSizes());
  if (!status.succeeded())
    return status;

  // Tile each payload op, collecting the generated forall loop and the
  // tiled op(s) for the transform results.
  for (Operation *target : state.getPayloadOps(value: getTarget())) {
    scf::SCFTilingResult tilingResult;
    DiagnosedSilenceableFailure diag = tileToForallOpImpl(
        rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
        mapping: getMapping(), tilingResult);
    if (!diag.succeeded())
      return diag;
    tileOps.push_back(Elt: tilingResult.loops.front());
    tiledOps.append(RHS: tilingResult.tiledOps);
  }

  transformResults.set(value: cast<OpResult>(Val: getForallOp()), ops&: tileOps);
  transformResults.set(value: cast<OpResult>(Val: getTiledOp()), ops&: tiledOps);

  return DiagnosedSilenceableFailure::success();
}
3742
3743void transform::TileUsingForallOp::getEffects(
3744 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3745 consumesHandle(handles: getTargetMutable(), effects);
3746 onlyReadsHandle(handles: getTileSizesMutable(), effects);
3747 onlyReadsHandle(handles: getNumThreadsMutable(), effects);
3748 onlyReadsHandle(handles: getPackedNumThreadsMutable(), effects);
3749 onlyReadsHandle(handles: getPackedTileSizesMutable(), effects);
3750 producesHandle(handles: getOperation()->getOpResults(), effects);
3751 modifiesPayload(effects);
3752}
3753
3754SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3755 Builder b(getContext());
3756 return getMixedValues(staticValues: getStaticNumThreads(), dynamicValues: getNumThreads(), b);
3757}
3758
3759SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3760 Builder b(getContext());
3761 return getMixedValues(staticValues: getStaticTileSizes(), dynamicValues: getTileSizes(), b);
3762}
3763
3764LogicalResult TileUsingForallOp::verify() {
3765 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3766 static_cast<int>(getPackedNumThreads() != Value());
3767 if (numThreadsSpec > 1)
3768 return emitOpError(
3769 message: "num_threads and packed_num_threads are mutually exclusive");
3770 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3771 static_cast<int>(getPackedTileSizes() != Value());
3772 if (tileSizesSpec > 1)
3773 return emitOpError(
3774 message: "tile_sizes and packed_tile_sizes are mutually exclusive");
3775 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3776 return emitOpError(message: "either (packed_)num_threads or (packed_)tile_sizes "
3777 "must be specified");
3778 return success();
3779}
3780
3781//===----------------------------------------------------------------------===//
3782// VectorizeChildrenAndApplyPatternsOp
3783//===----------------------------------------------------------------------===//
3784
3785void transform::VectorizeChildrenAndApplyPatternsOp::build(
3786 OpBuilder &builder, OperationState &result, Value target,
3787 bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3788 result.addOperands(newOperands: target);
3789 if (vectorizePadding) {
3790 result.addAttribute(
3791 name: VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3792 name: result.name),
3793 attr: builder.getUnitAttr());
3794 }
3795 if (vectorizeExtract) {
3796 result.addAttribute(
3797 name: VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3798 name: result.name),
3799 attr: builder.getUnitAttr());
3800 }
3801 if (flatten1DDepthwiseConv) {
3802 result.addAttribute(
3803 name: VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3804 name: result.name),
3805 attr: builder.getUnitAttr());
3806 }
3807 result.addTypes(newTypes: transform::AnyOpType::get(ctx: builder.getContext()));
3808}
3809
3810namespace {
3811/// This is an helper only to call vectorize via a pattern inside of
3812/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3813struct VectorizationPattern : public RewritePattern {
3814 explicit VectorizationPattern(MLIRContext *context,
3815 bool vectorizeExtract = false,
3816 bool flattenConv = false)
3817 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3818 vectorizeNDExtract(vectorizeExtract),
3819 flatten1DDepthwiseConv(flattenConv) {}
3820 LogicalResult matchAndRewrite(Operation *op,
3821 PatternRewriter &rewriter) const override {
3822 if (!linalg::hasVectorizationImpl(op))
3823 return rewriter.notifyMatchFailure(arg&: op,
3824 msg: "Unsupported Op, cannot vectorize");
3825 FailureOr<VectorizationResult> vectorResults =
3826 vectorize(rewriter, op, /*inputVectorSizes=*/{},
3827 /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3828 flatten1DDepthwiseConv);
3829 if (failed(Result: vectorResults))
3830 return failure();
3831 rewriter.replaceOp(op, newValues: vectorResults->replacements);
3832 return success();
3833 }
3834
3835private:
3836 /// Controls whether to vectorize `tensor.extract` when the input tensor is
3837 /// rank >= 2.
3838 bool vectorizeNDExtract = false;
3839 /// Controls whether to "flatten" the channel dimension when vectorising 1D
3840 /// depthwise convolutions. This should lead to bette vectorization for
3841 /// tensors with a low number of channel dimensions.
3842 bool flatten1DDepthwiseConv = false;
3843};
3844} // namespace
3845
3846DiagnosedSilenceableFailure
3847transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3848 transform::TransformRewriter &rewriter, Operation *target,
3849 transform::ApplyToEachResultList &results,
3850 transform::TransformState &state) {
3851 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3852 auto diag = this->emitOpError(message: "requires isolated-from-above targets");
3853 diag.attachNote(noteLoc: target->getLoc()) << "non-isolated target";
3854 return DiagnosedSilenceableFailure::definiteFailure();
3855 }
3856
3857 MLIRContext *ctx = getContext();
3858 RewritePatternSet patterns(ctx);
3859 patterns.add<VectorizationPattern>(arg&: ctx, args: getVectorizeNdExtract(),
3860 args: getFlatten_1dDepthwiseConv());
3861
3862 if (!getDisableTransferPermutationMapLoweringPatterns())
3863 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3864
3865 if (!getDisableMultiReductionToContractPatterns())
3866 vector::populateVectorReductionToContractPatterns(patterns);
3867
3868 vector::populateSinkVectorOpsPatterns(patterns);
3869
3870 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3871 linalg::LinalgCopyVTWForwardingPattern>(arg&: ctx,
3872 /*benefit=*/args: 2);
3873 vector::TransferReadOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
3874 vector::TransferWriteOp::getCanonicalizationPatterns(results&: patterns, context: ctx);
3875 tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
3876
3877 patterns.add<CopyVectorizationPattern>(arg&: ctx);
3878
3879 if (getVectorizePadding()) {
3880 linalg::populatePadOpVectorizationPatterns(patterns);
3881 // This creates an alternative path for lowering tensor.pad - by
3882 // decomposing it into e.g. linalg.fill.
3883 linalg::populateDecomposePadPatterns(patterns);
3884 }
3885 vector::populateVectorStepLoweringPatterns(patterns);
3886
3887 TrackingListener listener(state, *this);
3888 if (failed(
3889 Result: applyPatternsGreedily(op: target, patterns: std::move(patterns),
3890 config: GreedyRewriteConfig().setListener(&listener))))
3891 return emitDefaultDefiniteFailure(target);
3892
3893 results.push_back(op: target);
3894 return DiagnosedSilenceableFailure::success();
3895}
3896
3897//===----------------------------------------------------------------------===//
3898// VectorizeOp
3899//===----------------------------------------------------------------------===//
3900
3901DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3902 transform::TransformRewriter &rewriter,
3903 mlir::transform::TransformResults &transformResults,
3904 mlir::transform::TransformState &state) {
3905 auto targets = state.getPayloadOps(value: getTarget());
3906 if (std::empty(cont: targets))
3907 return DiagnosedSilenceableFailure::success();
3908 auto transformOp = cast<TransformOpInterface>(Val: getOperation());
3909 SmallVector<int64_t> vectorSizes;
3910 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3911 state, transformOp, mixedResults: getMixedVectorSizes(), reified&: vectorSizes);
3912 if (!status.succeeded())
3913 return status;
3914
3915 // TODO: Check that the correct number of vectorSizes was provided.
3916 for (Operation *target : targets) {
3917 if (!linalg::hasVectorizationImpl(target)) {
3918 return mlir::emitSilenceableFailure(loc: target->getLoc())
3919 << "Unsupported Op, cannot vectorize";
3920 }
3921 FailureOr<VectorizationResult> vectorResults =
3922 linalg::vectorize(rewriter, op: target, inputVectorSizes: vectorSizes, inputScalableVecDims: getScalableSizes(),
3923 vectorizeNDExtract: getVectorizeNdExtract().value_or(u: false));
3924 if (failed(Result: vectorResults)) {
3925 return mlir::emitSilenceableFailure(loc: target->getLoc())
3926 << "Attempted to vectorize, but failed";
3927 }
3928 rewriter.replaceOp(op: target, newValues: vectorResults->replacements);
3929 }
3930
3931 return DiagnosedSilenceableFailure::success();
3932}
3933
3934void transform::VectorizeOp::getEffects(
3935 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3936 consumesHandle(handles: getTargetMutable(), effects);
3937 onlyReadsHandle(handles: getVectorSizesMutable(), effects);
3938 modifiesPayload(effects);
3939}
3940
3941SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3942 OpBuilder b(getContext());
3943 return getMixedValues(staticValues: getStaticVectorSizes(), dynamicValues: getVectorSizes(), b);
3944}
3945
3946LogicalResult transform::VectorizeOp::verify() {
3947 if (getStaticVectorSizes().size() != getScalableSizes().size())
3948 return emitOpError(message: "expected same number of vector sizes (")
3949 << getStaticVectorSizes().size() << ") and scalable sizes ("
3950 << getScalableSizes().size() << ")";
3951 return success();
3952}
3953
3954//===----------------------------------------------------------------------===//
3955// HoistRedundantVectorTransfersOp
3956//===----------------------------------------------------------------------===//
3957
3958DiagnosedSilenceableFailure
3959transform::HoistRedundantVectorTransfersOp::applyToOne(
3960 transform::TransformRewriter &rewriter, func::FuncOp target,
3961 transform::ApplyToEachResultList &results,
3962 transform::TransformState &state) {
3963 // WARNING: This hoisting does not model parallelism and is generally
3964 // incorrect when used on distributed loops with memref semantics!
3965 // TODO: obsolete and should be retired.
3966 linalg::hoistRedundantVectorTransfers(root: target, verifyNonZeroTrip: getVerifyNonZeroTrip());
3967 results.push_back(op: target);
3968 return DiagnosedSilenceableFailure::success();
3969}
3970
3971//===----------------------------------------------------------------------===//
3972// HoistRedundantVectorBroadcastsOp
3973//===----------------------------------------------------------------------===//
3974
3975DiagnosedSilenceableFailure
3976transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3977 transform::TransformRewriter &rewriter, mlir::Operation *target,
3978 transform::ApplyToEachResultList &results,
3979 transform::TransformState &state) {
3980 rewriter.setInsertionPoint(target);
3981 linalg::hoistRedundantVectorBroadcasts(rewriter, root: target);
3982 results.push_back(op: target);
3983 return DiagnosedSilenceableFailure::success();
3984}
3985
3986//===----------------------------------------------------------------------===//
3987// ConvertConv2DToImg2ColOp.
3988//===----------------------------------------------------------------------===//
3989
3990DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3991 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3992 transform::ApplyToEachResultList &results,
3993 transform::TransformState &state) {
3994 rewriter.setInsertionPoint(target);
3995 auto maybeTransformed =
3996 TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3997 target)
3998 .Case(caseFn: [&](linalg::Conv2DNhwcHwcfOp op) {
3999 return rewriteInIm2Col(rewriter, convOp: op);
4000 })
4001 .Case(caseFn: [&](linalg::Conv2DNhwcFhwcOp op) {
4002 return rewriteInIm2Col(rewriter, convOp: op);
4003 })
4004 .Case(caseFn: [&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4005 return rewriteInIm2Col(rewriter, convOp: op);
4006 })
4007 .Case(caseFn: [&](linalg::Conv2DNchwFchwOp op) {
4008 return rewriteInIm2Col(rewriter, convOp: op);
4009 })
4010 .Default(defaultFn: [&](Operation *op) {
4011 return rewriter.notifyMatchFailure(arg&: op, msg: "not supported");
4012 });
4013 if (failed(Result: maybeTransformed))
4014 return emitDefaultSilenceableFailure(target);
4015 // Handle to the operation producing the img2col tensor.
4016 results.push_back(op: maybeTransformed->first);
4017 // Handle to the operation that replaces the original convolution.
4018 results.push_back(op: maybeTransformed->second);
4019 return DiagnosedSilenceableFailure::success();
4020}
4021
4022//===----------------------------------------------------------------------===//
4023// FlattenElementwiseLinalgOp.
4024//===----------------------------------------------------------------------===//
4025
4026DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4027 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4028 transform::ApplyToEachResultList &results,
4029 transform::TransformState &state) {
4030 rewriter.setInsertionPoint(target);
4031 if (!isElementwise(op: target))
4032 return mlir::emitSilenceableFailure(loc: target->getLoc())
4033 << "only elementwise flattening is supported";
4034
4035 // If rank <= 1, do nothing
4036 if (target.getNumLoops() <= 1) {
4037 results.push_back(op: target);
4038 return DiagnosedSilenceableFailure::success();
4039 }
4040
4041 // Attempt to flatten all dims to one.
4042 ReassociationIndices reassociation(target.getNumLoops());
4043 std::iota(first: reassociation.begin(), last: reassociation.end(), value: 0);
4044 auto maybeFlattened =
4045 collapseOpIterationDims(op: target, foldedIterationDims: reassociation, rewriter);
4046 if (failed(Result: maybeFlattened))
4047 return mlir::emitSilenceableFailure(loc: target->getLoc())
4048 << "attempted to flatten, but failed";
4049 results.push_back(op: maybeFlattened->collapsedOp);
4050 rewriter.replaceOp(op: target, newValues: maybeFlattened->results);
4051 return DiagnosedSilenceableFailure::success();
4052}
4053
4054//===----------------------------------------------------------------------===//
4055// TransposeConv2DOp
4056//===----------------------------------------------------------------------===//
4057
4058DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4059 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4060 transform::ApplyToEachResultList &results,
4061 transform::TransformState &state) {
4062 rewriter.setInsertionPoint(target);
4063 auto maybeTransformed =
4064 TypeSwitch<Operation *, FailureOr<Operation *>>(target)
4065 .Case(caseFn: [&](linalg::Conv2DNhwcFhwcOp op) {
4066 return transposeConv2D(rewriter, op);
4067 })
4068 .Case(caseFn: [&](linalg::Conv2DNhwcFhwcQOp op) {
4069 return transposeConv2D(rewriter, op);
4070 })
4071 .Default(defaultFn: [&](Operation *op) {
4072 return rewriter.notifyMatchFailure(arg&: op, msg: "not supported");
4073 });
4074 if (failed(Result: maybeTransformed))
4075 return emitDefaultSilenceableFailure(target);
4076 // Handle to the new Conv2D operation with transposed filters
4077 results.push_back(op: *maybeTransformed);
4078 return DiagnosedSilenceableFailure::success();
4079}
4080
4081//===----------------------------------------------------------------------===//
4082// TransposeMatmulOp
4083//===----------------------------------------------------------------------===//
4084
4085DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4086 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4087 transform::ApplyToEachResultList &results,
4088 transform::TransformState &state) {
4089 rewriter.setInsertionPoint(target);
4090 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4091 auto maybeTransformed =
4092 TypeSwitch<Operation *, FailureOr<Operation *>>(target)
4093 .Case(caseFn: [&](linalg::MatmulOp op) {
4094 return transposeMatmul(rewriter, op, transposeLHS);
4095 })
4096 .Case(caseFn: [&](linalg::BatchMatmulOp op) {
4097 return transposeBatchMatmul(rewriter, op, transposeLHS);
4098 })
4099 .Default(defaultFn: [&](Operation *op) { return failure(); });
4100 if (failed(Result: maybeTransformed))
4101 return emitSilenceableFailure(loc: target->getLoc()) << "not supported";
4102 // Handle to the new Matmul operation with transposed filters
4103 results.push_back(op: *maybeTransformed);
4104 return DiagnosedSilenceableFailure::success();
4105}
4106
4107//===----------------------------------------------------------------------===//
4108// InsertSliceToCopyOp
4109//===----------------------------------------------------------------------===//
4110template <typename OpTy>
4111DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
4112 transform::ApplyToEachResultList &results,
4113 transform::TransformState &state) {
4114 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4115 tensor::ParallelInsertSliceOp>() &&
4116 "wrong op type");
4117
4118 if (auto copySource =
4119 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4120 results.push_back(copySource);
4121 return DiagnosedSilenceableFailure::success();
4122 }
4123
4124 // If we are inside an InParallel region, temporarily set the insertion point
4125 // outside: only tensor.parallel_insert_slice ops are allowed in there.
4126 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
4127 rewriter.setInsertionPoint(
4128 target->template getParentOfType<scf::InParallelOp>());
4129 }
4130
4131 Value extracted = rewriter.create<tensor::ExtractSliceOp>(
4132 target.getLoc(), target.getDest(), target.getMixedOffsets(),
4133 target.getMixedSizes(), target.getMixedStrides());
4134 Value copied = rewriter
4135 .create<linalg::CopyOp>(target.getLoc(),
4136 target.getSource(), extracted)
4137 .getResult(0);
4138 // Reset the insertion point.
4139 rewriter.setInsertionPoint(target);
4140 rewriter.replaceOpWithNewOp<OpTy>(
4141 target, copied, target.getDest(), target.getMixedOffsets(),
4142 target.getMixedSizes(), target.getMixedStrides());
4143
4144 results.push_back(op: copied.getDefiningOp());
4145 return DiagnosedSilenceableFailure::success();
4146}
4147
4148DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4149 transform::TransformRewriter &rewriter, Operation *targetOp,
4150 transform::ApplyToEachResultList &results,
4151 transform::TransformState &state) {
4152
4153 rewriter.setInsertionPoint(targetOp);
4154 if (auto target = dyn_cast<tensor::InsertSliceOp>(Val: targetOp))
4155 return doit(rewriter, target, results, state);
4156 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(Val: targetOp))
4157 return doit(rewriter, target, results, state);
4158
4159 DiagnosedSilenceableFailure diag =
4160 emitSilenceableError()
4161 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4162 diag.attachNote(loc: targetOp->getLoc()) << "target op";
4163 return diag;
4164}
4165
4166//===----------------------------------------------------------------------===//
4167// MapCopyToThreadsOp
4168//===----------------------------------------------------------------------===//
4169
4170DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4171 transform::TransformRewriter &rewriter, Operation *target,
4172 transform::ApplyToEachResultList &results,
4173 transform::TransformState &state) {
4174 // Check if the op is supported.
4175 if (!isa<linalg::CopyOp, tensor::PadOp>(Val: target)) {
4176 DiagnosedSilenceableFailure diag =
4177 emitSilenceableError()
4178 << "only linalg.copy and tensor.pad target ops are supported";
4179 diag.attachNote(loc: target->getLoc()) << "target op";
4180 return diag;
4181 }
4182 assert(target->getNumResults() == 1 && "expected single result");
4183 auto resultShapedType = cast<ShapedType>(Val: target->getResult(idx: 0).getType());
4184 if (!resultShapedType.hasStaticShape()) {
4185 DiagnosedSilenceableFailure diag =
4186 emitSilenceableError()
4187 << "only statically sized ops of rank <= 3 are supported";
4188 diag.attachNote(loc: target->getLoc()) << "target op";
4189 return diag;
4190 }
4191
4192 // Conservatively set the minimum viable desired bitwidth alignment.
4193 int64_t desiredBitAlignment = getDesiredBitAlignment();
4194 int64_t eltBitwidth =
4195 resultShapedType.getElementType().getIntOrFloatBitWidth();
4196 if (desiredBitAlignment % eltBitwidth != 0) {
4197 desiredBitAlignment = eltBitwidth;
4198 }
4199
4200 gpu::CopyMappingInfo mapping(
4201 /*ctx=*/getContext(),
4202 /*totalNumThreads=*/getTotalNumThreads(),
4203 /*alignment=*/desiredBitAlignment,
4204 /*sizes=*/resultShapedType.getShape(),
4205 /*favorPredication=*/false,
4206 /*elementalBitwidth=*/
4207 resultShapedType.getElementType().getIntOrFloatBitWidth());
4208 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4209 DiagnosedSilenceableFailure diag =
4210 emitSilenceableError()
4211 << "too few threads to map copy op to threads on the most minor "
4212 "dimension, given alignment and vector size constraints, try "
4213 "smaller tile size of mapping to more threads";
4214 diag.attachNote(loc: target->getLoc()) << "target op";
4215 return diag;
4216 }
4217
4218 // OpBuilder only used to compute attributes.
4219 OpBuilder b(getContext());
4220 scf::SCFTilingResult tilingResult;
4221 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4222 /*rewriter=*/rewriter,
4223 /*state=*/state,
4224 /*transformOp=*/*this,
4225 /*target=*/target,
4226 /*mixedNumThreads=*/getMixedValues(staticValues: mapping.numThreads, dynamicValues: {}, b),
4227 /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4228 /*mapping=*/b.getArrayAttr(value: mapping.threadMapping),
4229 /*tilingResult=*/tilingResult);
4230 if (!diag.succeeded())
4231 return diag;
4232
4233 results.push_back(op: tilingResult.loops.front());
4234 for (auto op : tilingResult.tiledOps)
4235 results.push_back(op);
4236 return DiagnosedSilenceableFailure::success();
4237}
4238
4239//===----------------------------------------------------------------------===//
4240// WinogradConv2DOp
4241//===----------------------------------------------------------------------===//
4242
4243DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4244 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4245 transform::ApplyToEachResultList &results,
4246 transform::TransformState &state) {
4247 rewriter.setInsertionPoint(target);
4248 FailureOr<Operation *> maybeTransformed = failure();
4249 bool supported = TypeSwitch<Operation *, bool>(target)
4250 .Case(caseFn: [&](linalg::Conv2DNhwcFhwcOp op) {
4251 maybeTransformed =
4252 winogradConv2D(rewriter, op, fmr: getFmr());
4253 return true;
4254 })
4255 .Default(defaultFn: [&](Operation *op) { return false; });
4256
4257 if (!supported) {
4258 return emitSilenceableError()
4259 << "this operation is not supported to convert to Winograd Conv2D";
4260 }
4261
4262 if (failed(Result: maybeTransformed)) {
4263 return emitSilenceableError() << "apply Winograd Conv2D failed";
4264 }
4265
4266 results.push_back(op: *maybeTransformed);
4267 return DiagnosedSilenceableFailure::success();
4268}
4269
4270DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4271 transform::TransformRewriter &rewriter, Operation *target,
4272 transform::ApplyToEachResultList &results,
4273 transform::TransformState &state) {
4274 rewriter.setInsertionPoint(target);
4275 FailureOr<Operation *> maybeTransformed = failure();
4276 bool supported =
4277 TypeSwitch<Operation *, bool>(target)
4278 .Case(caseFn: [&](linalg::WinogradFilterTransformOp op) {
4279 maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4280 return true;
4281 })
4282 .Case(caseFn: [&](linalg::WinogradInputTransformOp op) {
4283 maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4284 return true;
4285 })
4286 .Case(caseFn: [&](linalg::WinogradOutputTransformOp op) {
4287 maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4288 return true;
4289 })
4290 .Default(defaultFn: [&](Operation *op) { return false; });
4291
4292 if (!supported) {
4293 DiagnosedSilenceableFailure diag =
4294 emitSilenceableError()
4295 << "this operation is not supported to decompose into other operations";
4296 diag.attachNote(loc: target->getLoc()) << "target op";
4297 return diag;
4298 }
4299
4300 if (failed(Result: maybeTransformed)) {
4301 DiagnosedSilenceableFailure diag =
4302 emitSilenceableError() << "decompose Winograd operations failed";
4303 diag.attachNote(loc: target->getLoc()) << "target op";
4304 return diag;
4305 }
4306
4307 results.push_back(op: *maybeTransformed);
4308 return DiagnosedSilenceableFailure::success();
4309}
4310
4311#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4312
4313#define GET_OP_CLASSES
4314#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
4315

source code of mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp