1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformOps.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/IR/TransformTypes.h"
17#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
18#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/Dominance.h"
22#include "mlir/IR/OpImplementation.h"
23#include "mlir/IR/OperationSupport.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/IR/Verifier.h"
26#include "mlir/Interfaces/CallInterfaces.h"
27#include "mlir/Interfaces/ControlFlowInterfaces.h"
28#include "mlir/Interfaces/FunctionImplementation.h"
29#include "mlir/Interfaces/FunctionInterfaces.h"
30#include "mlir/Pass/Pass.h"
31#include "mlir/Pass/PassManager.h"
32#include "mlir/Pass/PassRegistry.h"
33#include "mlir/Transforms/CSE.h"
34#include "mlir/Transforms/DialectConversion.h"
35#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
36#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
37#include "llvm/ADT/DenseSet.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/ScopeExit.h"
40#include "llvm/ADT/SmallPtrSet.h"
41#include "llvm/ADT/TypeSwitch.h"
42#include "llvm/Support/Debug.h"
43#include "llvm/Support/ErrorHandling.h"
44#include "llvm/Support/InterleavedRange.h"
45#include <optional>
46
47#define DEBUG_TYPE "transform-dialect"
48#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
49
50#define DEBUG_TYPE_MATCHER "transform-matcher"
51#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
52#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
53
54using namespace mlir;
55
56static ParseResult parseApplyRegisteredPassOptions(
57 OpAsmParser &parser, DictionaryAttr &options,
58 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
59static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
60 Operation *op,
61 DictionaryAttr options,
62 ValueRange dynamicOptions);
63static ParseResult parseSequenceOpOperands(
64 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
65 Type &rootType,
66 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
67 SmallVectorImpl<Type> &extraBindingTypes);
68static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
69 Value root, Type rootType,
70 ValueRange extraBindings,
71 TypeRange extraBindingTypes);
72static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
73 ArrayAttr matchers, ArrayAttr actions);
74static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
75 ArrayAttr &matchers,
76 ArrayAttr &actions);
77
78/// Helper function to check if the given transform op is contained in (or
79/// equal to) the given payload target op. In that case, an error is returned.
80/// Transforming transform IR that is currently executing is generally unsafe.
81static DiagnosedSilenceableFailure
82ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
83 Operation *payload) {
84 Operation *transformAncestor = transform.getOperation();
85 while (transformAncestor) {
86 if (transformAncestor == payload) {
87 DiagnosedDefiniteFailure diag =
88 transform.emitDefiniteFailure()
89 << "cannot apply transform to itself (or one of its ancestors)";
90 diag.attachNote(loc: payload->getLoc()) << "target payload op";
91 return diag;
92 }
93 transformAncestor = transformAncestor->getParentOp();
94 }
95 return DiagnosedSilenceableFailure::success();
96}
97
98#define GET_OP_CLASSES
99#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
100
101//===----------------------------------------------------------------------===//
102// AlternativesOp
103//===----------------------------------------------------------------------===//
104
105OperandRange
106transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
107 if (!point.isParent() && getOperation()->getNumOperands() == 1)
108 return getOperation()->getOperands();
109 return OperandRange(getOperation()->operand_end(),
110 getOperation()->operand_end());
111}
112
113void transform::AlternativesOp::getSuccessorRegions(
114 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
115 for (Region &alternative : llvm::drop_begin(
116 getAlternatives(),
117 point.isParent() ? 0
118 : point.getRegionOrNull()->getRegionNumber() + 1)) {
119 regions.emplace_back(&alternative, !getOperands().empty()
120 ? alternative.getArguments()
121 : Block::BlockArgListType());
122 }
123 if (!point.isParent())
124 regions.emplace_back(getOperation()->getResults());
125}
126
127void transform::AlternativesOp::getRegionInvocationBounds(
128 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
129 (void)operands;
130 // The region corresponding to the first alternative is always executed, the
131 // remaining may or may not be executed.
132 bounds.reserve(getNumRegions());
133 bounds.emplace_back(1, 1);
134 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
135}
136
137static void forwardEmptyOperands(Block *block, transform::TransformState &state,
138 transform::TransformResults &results) {
139 for (const auto &res : block->getParentOp()->getOpResults())
140 results.set(value: res, ops: {});
141}
142
143DiagnosedSilenceableFailure
144transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
145 transform::TransformResults &results,
146 transform::TransformState &state) {
147 SmallVector<Operation *> originals;
148 if (Value scopeHandle = getScope())
149 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
150 else
151 originals.push_back(state.getTopLevel());
152
153 for (Operation *original : originals) {
154 if (original->isAncestor(getOperation())) {
155 auto diag = emitDefiniteFailure()
156 << "scope must not contain the transforms being applied";
157 diag.attachNote(original->getLoc()) << "scope";
158 return diag;
159 }
160 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
161 auto diag = emitDefiniteFailure()
162 << "only isolated-from-above ops can be alternative scopes";
163 diag.attachNote(original->getLoc()) << "scope";
164 return diag;
165 }
166 }
167
168 for (Region &reg : getAlternatives()) {
169 // Clone the scope operations and make the transforms in this alternative
170 // region apply to them by virtue of mapping the block argument (the only
171 // visible handle) to the cloned scope operations. This effectively prevents
172 // the transformation from accessing any IR outside the scope.
173 auto scope = state.make_region_scope(reg);
174 auto clones = llvm::to_vector(
175 llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
176 auto deleteClones = llvm::make_scope_exit([&] {
177 for (Operation *clone : clones)
178 clone->erase();
179 });
180 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
181 return DiagnosedSilenceableFailure::definiteFailure();
182
183 bool failed = false;
184 for (Operation &transform : reg.front().without_terminator()) {
185 DiagnosedSilenceableFailure result =
186 state.applyTransform(cast<TransformOpInterface>(transform));
187 if (result.isSilenceableFailure()) {
188 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
189 << "\n");
190 failed = true;
191 break;
192 }
193
194 if (::mlir::failed(result.silence()))
195 return DiagnosedSilenceableFailure::definiteFailure();
196 }
197
198 // If all operations in the given alternative succeeded, no need to consider
199 // the rest. Replace the original scoping operation with the clone on which
200 // the transformations were performed.
201 if (!failed) {
202 // We will be using the clones, so cancel their scheduled deletion.
203 deleteClones.release();
204 TrackingListener listener(state, *this);
205 IRRewriter rewriter(getContext(), &listener);
206 for (const auto &kvp : llvm::zip(originals, clones)) {
207 Operation *original = std::get<0>(kvp);
208 Operation *clone = std::get<1>(kvp);
209 original->getBlock()->getOperations().insert(original->getIterator(),
210 clone);
211 rewriter.replaceOp(original, clone->getResults());
212 }
213 detail::forwardTerminatorOperands(&reg.front(), state, results);
214 return DiagnosedSilenceableFailure::success();
215 }
216 }
217 return emitSilenceableError() << "all alternatives failed";
218}
219
220void transform::AlternativesOp::getEffects(
221 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
222 consumesHandle(getOperation()->getOpOperands(), effects);
223 producesHandle(getOperation()->getOpResults(), effects);
224 for (Region *region : getRegions()) {
225 if (!region->empty())
226 producesHandle(region->front().getArguments(), effects);
227 }
228 modifiesPayload(effects);
229}
230
231LogicalResult transform::AlternativesOp::verify() {
232 for (Region &alternative : getAlternatives()) {
233 Block &block = alternative.front();
234 Operation *terminator = block.getTerminator();
235 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
236 InFlightDiagnostic diag = emitOpError()
237 << "expects terminator operands to have the "
238 "same type as results of the operation";
239 diag.attachNote(terminator->getLoc()) << "terminator";
240 return diag;
241 }
242 }
243
244 return success();
245}
246
247//===----------------------------------------------------------------------===//
248// AnnotateOp
249//===----------------------------------------------------------------------===//
250
251DiagnosedSilenceableFailure
252transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
253 transform::TransformResults &results,
254 transform::TransformState &state) {
255 SmallVector<Operation *> targets =
256 llvm::to_vector(state.getPayloadOps(getTarget()));
257
258 Attribute attr = UnitAttr::get(getContext());
259 if (auto paramH = getParam()) {
260 ArrayRef<Attribute> params = state.getParams(paramH);
261 if (params.size() != 1) {
262 if (targets.size() != params.size()) {
263 return emitSilenceableError()
264 << "parameter and target have different payload lengths ("
265 << params.size() << " vs " << targets.size() << ")";
266 }
267 for (auto &&[target, attr] : llvm::zip_equal(targets, params))
268 target->setAttr(getName(), attr);
269 return DiagnosedSilenceableFailure::success();
270 }
271 attr = params[0];
272 }
273 for (auto *target : targets)
274 target->setAttr(getName(), attr);
275 return DiagnosedSilenceableFailure::success();
276}
277
278void transform::AnnotateOp::getEffects(
279 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
280 onlyReadsHandle(getTargetMutable(), effects);
281 onlyReadsHandle(getParamMutable(), effects);
282 modifiesPayload(effects);
283}
284
285//===----------------------------------------------------------------------===//
286// ApplyCommonSubexpressionEliminationOp
287//===----------------------------------------------------------------------===//
288
289DiagnosedSilenceableFailure
290transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
291 transform::TransformRewriter &rewriter, Operation *target,
292 ApplyToEachResultList &results, transform::TransformState &state) {
293 // Make sure that this transform is not applied to itself. Modifying the
294 // transform IR while it is being interpreted is generally dangerous.
295 DiagnosedSilenceableFailure payloadCheck =
296 ensurePayloadIsSeparateFromTransform(*this, target);
297 if (!payloadCheck.succeeded())
298 return payloadCheck;
299
300 DominanceInfo domInfo;
301 mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
302 return DiagnosedSilenceableFailure::success();
303}
304
305void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
306 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
307 transform::onlyReadsHandle(getTargetMutable(), effects);
308 transform::modifiesPayload(effects);
309}
310
311//===----------------------------------------------------------------------===//
312// ApplyDeadCodeEliminationOp
313//===----------------------------------------------------------------------===//
314
315DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
316 transform::TransformRewriter &rewriter, Operation *target,
317 ApplyToEachResultList &results, transform::TransformState &state) {
318 // Make sure that this transform is not applied to itself. Modifying the
319 // transform IR while it is being interpreted is generally dangerous.
320 DiagnosedSilenceableFailure payloadCheck =
321 ensurePayloadIsSeparateFromTransform(*this, target);
322 if (!payloadCheck.succeeded())
323 return payloadCheck;
324
325 // Maintain a worklist of potentially dead ops.
326 SetVector<Operation *> worklist;
327
328 // Helper function that adds all defining ops of used values (operands and
329 // operands of nested ops).
330 auto addDefiningOpsToWorklist = [&](Operation *op) {
331 op->walk([&](Operation *op) {
332 for (Value v : op->getOperands())
333 if (Operation *defOp = v.getDefiningOp())
334 if (target->isProperAncestor(defOp))
335 worklist.insert(defOp);
336 });
337 };
338
339 // Helper function that erases an op.
340 auto eraseOp = [&](Operation *op) {
341 // Remove op and nested ops from the worklist.
342 op->walk([&](Operation *op) {
343 const auto *it = llvm::find(worklist, op);
344 if (it != worklist.end())
345 worklist.erase(it);
346 });
347 rewriter.eraseOp(op);
348 };
349
350 // Initial walk over the IR.
351 target->walk<WalkOrder::PostOrder>([&](Operation *op) {
352 if (op != target && isOpTriviallyDead(op)) {
353 addDefiningOpsToWorklist(op);
354 eraseOp(op);
355 }
356 });
357
358 // Erase all ops that have become dead.
359 while (!worklist.empty()) {
360 Operation *op = worklist.pop_back_val();
361 if (!isOpTriviallyDead(op))
362 continue;
363 addDefiningOpsToWorklist(op);
364 eraseOp(op);
365 }
366
367 return DiagnosedSilenceableFailure::success();
368}
369
370void transform::ApplyDeadCodeEliminationOp::getEffects(
371 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
372 transform::onlyReadsHandle(getTargetMutable(), effects);
373 transform::modifiesPayload(effects);
374}
375
376//===----------------------------------------------------------------------===//
377// ApplyPatternsOp
378//===----------------------------------------------------------------------===//
379
380DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
381 transform::TransformRewriter &rewriter, Operation *target,
382 ApplyToEachResultList &results, transform::TransformState &state) {
383 // Make sure that this transform is not applied to itself. Modifying the
384 // transform IR while it is being interpreted is generally dangerous. Even
385 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
386 // performs many additional simplifications such as dead code elimination.
387 DiagnosedSilenceableFailure payloadCheck =
388 ensurePayloadIsSeparateFromTransform(*this, target);
389 if (!payloadCheck.succeeded())
390 return payloadCheck;
391
392 // Gather all specified patterns.
393 MLIRContext *ctx = target->getContext();
394 RewritePatternSet patterns(ctx);
395 if (!getRegion().empty()) {
396 for (Operation &op : getRegion().front()) {
397 cast<transform::PatternDescriptorOpInterface>(&op)
398 .populatePatternsWithState(patterns, state);
399 }
400 }
401
402 // Configure the GreedyPatternRewriteDriver.
403 GreedyRewriteConfig config;
404 config.setListener(
405 static_cast<RewriterBase::Listener *>(rewriter.getListener()));
406 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
407
408 config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
409 ? GreedyRewriteConfig::kNoLimit
410 : getMaxIterations());
411 config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
412 ? GreedyRewriteConfig::kNoLimit
413 : getMaxNumRewrites());
414
415 // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
416 // was requested, apply the greedy pattern rewrite only once. (The greedy
417 // pattern rewrite driver already iterates to a fixpoint internally.)
418 bool cseChanged = false;
419 // One or two iterations should be sufficient. Stop iterating after a certain
420 // threshold to make debugging easier.
421 static const int64_t kNumMaxIterations = 50;
422 int64_t iteration = 0;
423 do {
424 LogicalResult result = failure();
425 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
426 // Op is isolated from above. Apply patterns and also perform region
427 // simplification.
428 result = applyPatternsGreedily(target, frozenPatterns, config);
429 } else {
430 // Manually gather list of ops because the other
431 // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
432 // from above. This way, patterns can be applied to ops that are not
433 // isolated from above. Regions are not being simplified. Furthermore,
434 // only a single greedy rewrite iteration is performed.
435 SmallVector<Operation *> ops;
436 target->walk([&](Operation *nestedOp) {
437 if (target != nestedOp)
438 ops.push_back(nestedOp);
439 });
440 result = applyOpPatternsGreedily(ops, frozenPatterns, config);
441 }
442
443 // A failure typically indicates that the pattern application did not
444 // converge.
445 if (failed(result)) {
446 return emitSilenceableFailure(target)
447 << "greedy pattern application failed";
448 }
449
450 if (getApplyCse()) {
451 DominanceInfo domInfo;
452 mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
453 &cseChanged);
454 }
455 } while (cseChanged && ++iteration < kNumMaxIterations);
456
457 if (iteration == kNumMaxIterations)
458 return emitDefiniteFailure() << "fixpoint iteration did not converge";
459
460 return DiagnosedSilenceableFailure::success();
461}
462
463LogicalResult transform::ApplyPatternsOp::verify() {
464 if (!getRegion().empty()) {
465 for (Operation &op : getRegion().front()) {
466 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
467 InFlightDiagnostic diag = emitOpError()
468 << "expected children ops to implement "
469 "PatternDescriptorOpInterface";
470 diag.attachNote(op.getLoc()) << "op without interface";
471 return diag;
472 }
473 }
474 }
475 return success();
476}
477
478void transform::ApplyPatternsOp::getEffects(
479 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
480 transform::onlyReadsHandle(getTargetMutable(), effects);
481 transform::modifiesPayload(effects);
482}
483
484void transform::ApplyPatternsOp::build(
485 OpBuilder &builder, OperationState &result, Value target,
486 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
487 result.addOperands(target);
488
489 OpBuilder::InsertionGuard g(builder);
490 Region *region = result.addRegion();
491 builder.createBlock(region);
492 if (bodyBuilder)
493 bodyBuilder(builder, result.location);
494}
495
496//===----------------------------------------------------------------------===//
497// ApplyCanonicalizationPatternsOp
498//===----------------------------------------------------------------------===//
499
500void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
501 RewritePatternSet &patterns) {
502 MLIRContext *ctx = patterns.getContext();
503 for (Dialect *dialect : ctx->getLoadedDialects())
504 dialect->getCanonicalizationPatterns(patterns);
505 for (RegisteredOperationName op : ctx->getRegisteredOperations())
506 op.getCanonicalizationPatterns(patterns, ctx);
507}
508
509//===----------------------------------------------------------------------===//
510// ApplyConversionPatternsOp
511//===----------------------------------------------------------------------===//
512
513DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
514 transform::TransformRewriter &rewriter,
515 transform::TransformResults &results, transform::TransformState &state) {
516 MLIRContext *ctx = getContext();
517
518 // Instantiate the default type converter if a type converter builder is
519 // specified.
520 std::unique_ptr<TypeConverter> defaultTypeConverter;
521 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
522 getDefaultTypeConverter();
523 if (typeConverterBuilder)
524 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
525
526 // Configure conversion target.
527 ConversionTarget conversionTarget(*getContext());
528 if (getLegalOps())
529 for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
530 conversionTarget.addLegalOp(
531 OperationName(cast<StringAttr>(attr).getValue(), ctx));
532 if (getIllegalOps())
533 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
534 conversionTarget.addIllegalOp(
535 OperationName(cast<StringAttr>(attr).getValue(), ctx));
536 if (getLegalDialects())
537 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
538 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
539 if (getIllegalDialects())
540 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
541 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
542
543 // Gather all specified patterns.
544 RewritePatternSet patterns(ctx);
545 // Need to keep the converters alive until after pattern application because
546 // the patterns take a reference to an object that would otherwise get out of
547 // scope.
548 SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
549 if (!getPatterns().empty()) {
550 for (Operation &op : getPatterns().front()) {
551 auto descriptor =
552 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
553
554 // Check if this pattern set specifies a type converter.
555 std::unique_ptr<TypeConverter> typeConverter =
556 descriptor.getTypeConverter();
557 TypeConverter *converter = nullptr;
558 if (typeConverter) {
559 keepAliveConverters.emplace_back(std::move(typeConverter));
560 converter = keepAliveConverters.back().get();
561 } else {
562 // No type converter specified: Use the default type converter.
563 if (!defaultTypeConverter) {
564 auto diag = emitDefiniteFailure()
565 << "pattern descriptor does not specify type "
566 "converter and apply_conversion_patterns op has "
567 "no default type converter";
568 diag.attachNote(op.getLoc()) << "pattern descriptor op";
569 return diag;
570 }
571 converter = defaultTypeConverter.get();
572 }
573
574 // Add descriptor-specific updates to the conversion target, which may
575 // depend on the final type converter. In structural converters, the
576 // legality of types dictates the dynamic legality of an operation.
577 descriptor.populateConversionTargetRules(*converter, conversionTarget);
578
579 descriptor.populatePatterns(*converter, patterns);
580 }
581 }
582
583 // Attach a tracking listener if handles should be preserved. We configure the
584 // listener to allow op replacements with different names, as conversion
585 // patterns typically replace ops with replacement ops that have a different
586 // name.
587 TrackingListenerConfig trackingConfig;
588 trackingConfig.requireMatchingReplacementOpName = false;
589 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
590 ConversionConfig conversionConfig;
591 if (getPreserveHandles())
592 conversionConfig.listener = &trackingListener;
593
594 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
595 for (Operation *target : state.getPayloadOps(getTarget())) {
596 // Make sure that this transform is not applied to itself. Modifying the
597 // transform IR while it is being interpreted is generally dangerous.
598 DiagnosedSilenceableFailure payloadCheck =
599 ensurePayloadIsSeparateFromTransform(*this, target);
600 if (!payloadCheck.succeeded())
601 return payloadCheck;
602
603 LogicalResult status = failure();
604 if (getPartialConversion()) {
605 status = applyPartialConversion(target, conversionTarget, frozenPatterns,
606 conversionConfig);
607 } else {
608 status = applyFullConversion(target, conversionTarget, frozenPatterns,
609 conversionConfig);
610 }
611
612 // Check dialect conversion state.
613 DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
614 if (failed(status)) {
615 diag = emitSilenceableError() << "dialect conversion failed";
616 diag.attachNote(target->getLoc()) << "target op";
617 }
618
619 // Check tracking listener error state.
620 DiagnosedSilenceableFailure trackingFailure =
621 trackingListener.checkAndResetError();
622 if (!trackingFailure.succeeded()) {
623 if (diag.succeeded()) {
624 // Tracking failure is the only failure.
625 return trackingFailure;
626 } else {
627 diag.attachNote() << "tracking listener also failed: "
628 << trackingFailure.getMessage();
629 (void)trackingFailure.silence();
630 }
631 }
632
633 if (!diag.succeeded())
634 return diag;
635 }
636
637 return DiagnosedSilenceableFailure::success();
638}
639
640LogicalResult transform::ApplyConversionPatternsOp::verify() {
641 if (getNumRegions() != 1 && getNumRegions() != 2)
642 return emitOpError() << "expected 1 or 2 regions";
643 if (!getPatterns().empty()) {
644 for (Operation &op : getPatterns().front()) {
645 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
646 InFlightDiagnostic diag =
647 emitOpError() << "expected pattern children ops to implement "
648 "ConversionPatternDescriptorOpInterface";
649 diag.attachNote(op.getLoc()) << "op without interface";
650 return diag;
651 }
652 }
653 }
654 if (getNumRegions() == 2) {
655 Region &typeConverterRegion = getRegion(1);
656 if (!llvm::hasSingleElement(typeConverterRegion.front()))
657 return emitOpError()
658 << "expected exactly one op in default type converter region";
659 Operation *maybeTypeConverter = &typeConverterRegion.front().front();
660 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
661 maybeTypeConverter);
662 if (!typeConverterOp) {
663 InFlightDiagnostic diag = emitOpError()
664 << "expected default converter child op to "
665 "implement TypeConverterBuilderOpInterface";
666 diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
667 return diag;
668 }
669 // Check default type converter type.
670 if (!getPatterns().empty()) {
671 for (Operation &op : getPatterns().front()) {
672 auto descriptor =
673 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
674 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
675 return failure();
676 }
677 }
678 }
679 return success();
680}
681
682void transform::ApplyConversionPatternsOp::getEffects(
683 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
684 if (!getPreserveHandles()) {
685 transform::consumesHandle(getTargetMutable(), effects);
686 } else {
687 transform::onlyReadsHandle(getTargetMutable(), effects);
688 }
689 transform::modifiesPayload(effects);
690}
691
692void transform::ApplyConversionPatternsOp::build(
693 OpBuilder &builder, OperationState &result, Value target,
694 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
695 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
696 result.addOperands(target);
697
698 {
699 OpBuilder::InsertionGuard g(builder);
700 Region *region1 = result.addRegion();
701 builder.createBlock(region1);
702 if (patternsBodyBuilder)
703 patternsBodyBuilder(builder, result.location);
704 }
705 {
706 OpBuilder::InsertionGuard g(builder);
707 Region *region2 = result.addRegion();
708 builder.createBlock(region2);
709 if (typeConverterBodyBuilder)
710 typeConverterBodyBuilder(builder, result.location);
711 }
712}
713
714//===----------------------------------------------------------------------===//
715// ApplyToLLVMConversionPatternsOp
716//===----------------------------------------------------------------------===//
717
718void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
719 TypeConverter &typeConverter, RewritePatternSet &patterns) {
720 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
721 assert(dialect && "expected that dialect is loaded");
722 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
723 // ConversionTarget is currently ignored because the enclosing
724 // apply_conversion_patterns op sets up its own ConversionTarget.
725 ConversionTarget target(*getContext());
726 iface->populateConvertToLLVMConversionPatterns(
727 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
728}
729
730LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
731 transform::TypeConverterBuilderOpInterface builder) {
732 if (builder.getTypeConverterType() != "LLVMTypeConverter")
733 return emitOpError("expected LLVMTypeConverter");
734 return success();
735}
736
737LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
738 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
739 if (!dialect)
740 return emitOpError("unknown dialect or dialect not loaded: ")
741 << getDialectName();
742 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
743 if (!iface)
744 return emitOpError(
745 "dialect does not implement ConvertToLLVMPatternInterface or "
746 "extension was not loaded: ")
747 << getDialectName();
748 return success();
749}
750
751//===----------------------------------------------------------------------===//
752// ApplyLoopInvariantCodeMotionOp
753//===----------------------------------------------------------------------===//
754
755DiagnosedSilenceableFailure
756transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
757 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
758 transform::ApplyToEachResultList &results,
759 transform::TransformState &state) {
760 // Currently, LICM does not remove operations, so we don't need tracking.
761 // If this ever changes, add a LICM entry point that takes a rewriter.
762 moveLoopInvariantCode(target);
763 return DiagnosedSilenceableFailure::success();
764}
765
766void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
767 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
768 transform::onlyReadsHandle(getTargetMutable(), effects);
769 transform::modifiesPayload(effects);
770}
771
772//===----------------------------------------------------------------------===//
773// ApplyRegisteredPassOp
774//===----------------------------------------------------------------------===//
775
776void transform::ApplyRegisteredPassOp::getEffects(
777 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
778 consumesHandle(getTargetMutable(), effects);
779 onlyReadsHandle(getDynamicOptionsMutable(), effects);
780 producesHandle(getOperation()->getOpResults(), effects);
781 modifiesPayload(effects);
782}
783
784DiagnosedSilenceableFailure
785transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
786 transform::TransformResults &results,
787 transform::TransformState &state) {
788 // Obtain a single options-string to pass to the pass(-pipeline) from options
789 // passed in as a dictionary of keys mapping to values which are either
790 // attributes or param-operands pointing to attributes.
791
792 std::string options;
793 llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
794
795 OperandRange dynamicOptions = getDynamicOptions();
796 for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
797 if (idx > 0)
798 optionsStream << " "; // Interleave options separator.
799 optionsStream << namedAttribute.getName().str(); // Append the key.
800 optionsStream << "="; // And the key-value separator.
801
802 Attribute valueAttrToAppend;
803 if (auto paramOperandIndex =
804 dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
805 // The corresponding value attribute is passed in via a param.
806 // Obtain the param-operand via its specified index.
807 size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
808 assert(dynamicOptionIdx < dynamicOptions.size() &&
809 "number of dynamic option markers (UnitAttr) in options ArrayAttr "
810 "should be the same as the number of options passed as params");
811 ArrayRef<Attribute> dynamicOption =
812 state.getParams(dynamicOptions[dynamicOptionIdx]);
813 if (dynamicOption.size() != 1)
814 return emitSilenceableError()
815 << "options passed as a param must have "
816 "a single value associated, param "
817 << dynamicOptionIdx << " associates " << dynamicOption.size();
818 valueAttrToAppend = dynamicOption[0];
819 } else {
820 // Value is a static attribute.
821 valueAttrToAppend = namedAttribute.getValue();
822 }
823
824 // Append string representation of value attribute.
825 if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
826 optionsStream << strAttr.getValue().str();
827 } else {
828 valueAttrToAppend.print(optionsStream, /*elideType=*/true);
829 }
830 }
831 optionsStream.flush();
832
833 // Get pass or pass pipeline from registry.
834 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
835 if (!info)
836 info = PassInfo::lookup(getPassName());
837 if (!info)
838 return emitDefiniteFailure()
839 << "unknown pass or pass pipeline: " << getPassName();
840
841 // Create pass manager and add the pass or pass pipeline.
842 PassManager pm(getContext());
843 if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
844 emitError(msg);
845 return failure();
846 }))) {
847 return emitDefiniteFailure()
848 << "failed to add pass or pass pipeline to pipeline: "
849 << getPassName();
850 }
851
852 auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
853 for (Operation *target : targets) {
854 // Make sure that this transform is not applied to itself. Modifying the
855 // transform IR while it is being interpreted is generally dangerous. Even
856 // more so when applying passes because they may perform a wide range of IR
857 // modifications.
858 DiagnosedSilenceableFailure payloadCheck =
859 ensurePayloadIsSeparateFromTransform(*this, target);
860 if (!payloadCheck.succeeded())
861 return payloadCheck;
862
863 // Run the pass or pass pipeline on the current target operation.
864 if (failed(pm.run(target))) {
865 auto diag = emitSilenceableError() << "pass pipeline failed";
866 diag.attachNote(target->getLoc()) << "target op";
867 return diag;
868 }
869 }
870
871 // The applied pass will have directly modified the payload IR(s).
872 results.set(llvm::cast<OpResult>(getResult()), targets);
873 return DiagnosedSilenceableFailure::success();
874}
875
876static ParseResult parseApplyRegisteredPassOptions(
877 OpAsmParser &parser, DictionaryAttr &options,
878 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
879 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
880 SmallVector<NamedAttribute> keyValuePairs;
881
882 size_t dynamicOptionsIdx = 0;
883 auto parseKeyValuePair = [&]() -> ParseResult {
884 // Parse items of the form `key = value` where `key` is a bare identifier or
885 // a string and `value` is either an attribute or an operand.
886
887 std::string key;
888 Attribute valueAttr;
889 if (parser.parseOptionalKeywordOrString(result: &key))
890 return parser.emitError(loc: parser.getCurrentLocation())
891 << "expected key to either be an identifier or a string";
892 if (key.empty())
893 return failure();
894
895 if (parser.parseEqual())
896 return parser.emitError(loc: parser.getCurrentLocation())
897 << "expected '=' after key in key-value pair";
898
899 // Parse the value, which can be either an attribute or an operand.
900 OptionalParseResult parsedValueAttr =
901 parser.parseOptionalAttribute(result&: valueAttr);
902 if (!parsedValueAttr.has_value()) {
903 OpAsmParser::UnresolvedOperand operand;
904 ParseResult parsedOperand = parser.parseOperand(result&: operand);
905 if (failed(Result: parsedOperand))
906 return parser.emitError(loc: parser.getCurrentLocation())
907 << "expected a valid attribute or operand as value associated "
908 << "to key '" << key << "'";
909 // To make use of the operand, we need to store it in the options dict.
910 // As SSA-values cannot occur in attributes, what we do instead is store
911 // an attribute in its place that contains the index of the param-operand,
912 // so that an attr-value associated to the param can be resolved later on.
913 dynamicOptions.push_back(Elt: operand);
914 auto wrappedIndex = IntegerAttr::get(
915 IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
916 valueAttr =
917 transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
918 } else if (failed(Result: parsedValueAttr.value())) {
919 return failure(); // NB: Attempted parse should have output error message.
920 } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
921 return parser.emitError(loc: parser.getCurrentLocation())
922 << "the param_operand attribute is a marker reserved for "
923 << "indicating a value will be passed via params and is only used "
924 << "in the generic print format";
925 }
926
927 keyValuePairs.push_back(Elt: NamedAttribute(key, valueAttr));
928 return success();
929 };
930
931 if (parser.parseCommaSeparatedList(delimiter: AsmParser::Delimiter::Braces,
932 parseElementFn: parseKeyValuePair,
933 contextMessage: " in options dictionary"))
934 return failure(); // NB: Attempted parse should have output error message.
935
936 if (DictionaryAttr::findDuplicate(
937 keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
938 .has_value())
939 return parser.emitError(loc: parser.getCurrentLocation())
940 << "duplicate keys found in options dictionary";
941
942 options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
943
944 return success();
945}
946
947static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
948 Operation *op,
949 DictionaryAttr options,
950 ValueRange dynamicOptions) {
951 if (options.empty())
952 return;
953
954 printer << "{";
955 llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
956 printer << namedAttribute.getName() << " = ";
957 Attribute value = namedAttribute.getValue();
958 if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
959 // Resolve index of param-operand to its actual SSA-value and print that.
960 printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
961 } else {
962 printer.printAttribute(attr: value);
963 }
964 });
965 printer << "}";
966}
967
968LogicalResult transform::ApplyRegisteredPassOp::verify() {
969 // Check that there is a one-to-one correspondence between param operands
970 // and references to dynamic options in the options dictionary.
971
972 auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
973 for (NamedAttribute namedAttr : getOptions())
974 if (auto paramOperand =
975 dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
976 size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
977 if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
978 return emitOpError()
979 << "dynamic option index " << dynamicOptionIdx
980 << " is out of bounds for the number of dynamic options: "
981 << dynamicOptions.size();
982 if (dynamicOptions[dynamicOptionIdx] == nullptr)
983 return emitOpError() << "dynamic option index " << dynamicOptionIdx
984 << " is already used in options";
985 dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
986 }
987
988 for (Value dynamicOption : dynamicOptions)
989 if (dynamicOption)
990 return emitOpError() << "a param operand does not have a corresponding "
991 << "param_operand attr in the options dict";
992
993 return success();
994}
995
996//===----------------------------------------------------------------------===//
997// CastOp
998//===----------------------------------------------------------------------===//
999
1000DiagnosedSilenceableFailure
1001transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1002 Operation *target, ApplyToEachResultList &results,
1003 transform::TransformState &state) {
1004 results.push_back(target);
1005 return DiagnosedSilenceableFailure::success();
1006}
1007
1008void transform::CastOp::getEffects(
1009 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1010 onlyReadsPayload(effects);
1011 onlyReadsHandle(getInputMutable(), effects);
1012 producesHandle(getOperation()->getOpResults(), effects);
1013}
1014
1015bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1016 assert(inputs.size() == 1 && "expected one input");
1017 assert(outputs.size() == 1 && "expected one output");
1018 return llvm::all_of(
1019 std::initializer_list<Type>{inputs.front(), outputs.front()},
1020 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1021}
1022
//===----------------------------------------------------------------------===//
// Matching helpers (shared by CollectMatchingOp and ForeachMatchOp)
//===----------------------------------------------------------------------===//
1026
1027/// Applies matcher operations from the given `block` using
1028/// `blockArgumentMapping` to initialize block arguments. Updates `state`
1029/// accordingly. If any of the matcher produces a silenceable failure, discards
1030/// it (printing the content to the debug output stream) and returns failure. If
1031/// any of the matchers produces a definite failure, reports it and returns
1032/// failure. If all matchers in the block succeed, populates `mappings` with the
1033/// payload entities associated with the block terminator operands. Note that
1034/// `mappings` will be cleared before that.
1035static DiagnosedSilenceableFailure
1036matchBlock(Block &block,
1037 ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1038 transform::TransformState &state,
1039 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
1040 assert(block.getParent() && "cannot match using a detached block");
1041 auto matchScope = state.make_region_scope(region&: *block.getParent());
1042 if (failed(
1043 Result: state.mapBlockArguments(arguments: block.getArguments(), mapping: blockArgumentMapping)))
1044 return DiagnosedSilenceableFailure::definiteFailure();
1045
1046 for (Operation &match : block.without_terminator()) {
1047 if (!isa<transform::MatchOpInterface>(Val: match)) {
1048 return emitDefiniteFailure(loc: match.getLoc())
1049 << "expected operations in the match part to "
1050 "implement MatchOpInterface";
1051 }
1052 DiagnosedSilenceableFailure diag =
1053 state.applyTransform(transform: cast<transform::TransformOpInterface>(match));
1054 if (diag.succeeded())
1055 continue;
1056
1057 return diag;
1058 }
1059
1060 // Remember the values mapped to the terminator operands so we can
1061 // forward them to the action.
1062 ValueRange yieldedValues = block.getTerminator()->getOperands();
1063 // Our contract with the caller is that the mappings will contain only the
1064 // newly mapped values, clear the rest.
1065 mappings.clear();
1066 transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1067 return DiagnosedSilenceableFailure::success();
1068}
1069
1070/// Returns `true` if both types implement one of the interfaces provided as
1071/// template parameters.
1072template <typename... Tys>
1073static bool implementSameInterface(Type t1, Type t2) {
1074 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1075}
1076
1077/// Returns `true` if both types implement one of the transform dialect
1078/// interfaces.
1079static bool implementSameTransformInterface(Type t1, Type t2) {
1080 return implementSameInterface<transform::TransformHandleTypeInterface,
1081 transform::TransformParamTypeInterface,
1082 transform::TransformValueHandleTypeInterface>(
1083 t1, t2);
1084}
1085
1086//===----------------------------------------------------------------------===//
1087// CollectMatchingOp
1088//===----------------------------------------------------------------------===//
1089
1090DiagnosedSilenceableFailure
1091transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1092 transform::TransformResults &results,
1093 transform::TransformState &state) {
1094 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1095 getOperation(), getMatcher());
1096 if (matcher.isExternal()) {
1097 return emitDefiniteFailure()
1098 << "unresolved external symbol " << getMatcher();
1099 }
1100
1101 SmallVector<SmallVector<MappedValue>, 2> rawResults;
1102 rawResults.resize(getOperation()->getNumResults());
1103 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1104 for (Operation *root : state.getPayloadOps(getRoot())) {
1105 WalkResult walkResult = root->walk([&](Operation *op) {
1106 DEBUG_MATCHER({
1107 DBGS_MATCHER() << "matching ";
1108 op->print(llvm::dbgs(),
1109 OpPrintingFlags().assumeVerified().skipRegions());
1110 llvm::dbgs() << " @" << op << "\n";
1111 });
1112
1113 // Try matching.
1114 SmallVector<SmallVector<MappedValue>> mappings;
1115 SmallVector<transform::MappedValue> inputMapping({op});
1116 DiagnosedSilenceableFailure diag = matchBlock(
1117 matcher.getFunctionBody().front(),
1118 ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1119 mappings);
1120 if (diag.isDefiniteFailure())
1121 return WalkResult::interrupt();
1122 if (diag.isSilenceableFailure()) {
1123 DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1124 << " failed: " << diag.getMessage());
1125 return WalkResult::advance();
1126 }
1127
1128 // If succeeded, collect results.
1129 for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1130 if (mapping.size() != 1) {
1131 maybeFailure.emplace(emitSilenceableError()
1132 << "result #" << i << ", associated with "
1133 << mapping.size()
1134 << " payload objects, expected 1");
1135 return WalkResult::interrupt();
1136 }
1137 rawResults[i].push_back(mapping[0]);
1138 }
1139 return WalkResult::advance();
1140 });
1141 if (walkResult.wasInterrupted())
1142 return std::move(*maybeFailure);
1143 assert(!maybeFailure && "failure set but the walk was not interrupted");
1144
1145 for (auto &&[opResult, rawResult] :
1146 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1147 results.setMappedValues(opResult, rawResult);
1148 }
1149 }
1150 return DiagnosedSilenceableFailure::success();
1151}
1152
1153void transform::CollectMatchingOp::getEffects(
1154 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1155 onlyReadsHandle(getRootMutable(), effects);
1156 producesHandle(getOperation()->getOpResults(), effects);
1157 onlyReadsPayload(effects);
1158}
1159
1160LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1161 SymbolTableCollection &symbolTable) {
1162 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1163 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1164 if (!matcherSymbol ||
1165 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1166 return emitError() << "unresolved matcher symbol " << getMatcher();
1167
1168 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1169 if (argumentTypes.size() != 1 ||
1170 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1171 return emitError()
1172 << "expected the matcher to take one operation handle argument";
1173 }
1174 if (!matcherSymbol.getArgAttr(
1175 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1176 return emitError() << "expected the matcher argument to be marked readonly";
1177 }
1178
1179 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1180 if (resultTypes.size() != getOperation()->getNumResults()) {
1181 return emitError()
1182 << "expected the matcher to yield as many values as op has results ("
1183 << getOperation()->getNumResults() << "), got "
1184 << resultTypes.size();
1185 }
1186
1187 for (auto &&[i, matcherType, resultType] :
1188 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1189 if (implementSameTransformInterface(matcherType, resultType))
1190 continue;
1191
1192 return emitError()
1193 << "mismatching type interfaces for matcher result and op result #"
1194 << i;
1195 }
1196
1197 return success();
1198}
1199
1200//===----------------------------------------------------------------------===//
1201// ForeachMatchOp
1202//===----------------------------------------------------------------------===//
1203
1204// This is fine because nothing is actually consumed by this op.
1205bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1206
/// Resolves every `@matcher -> @action` pair, then walks each payload op
/// associated with the `root` handle.
///
/// The root op itself is skipped unless `getRestrictRoot()` is set. For every
/// other op visited by the walk, the matchers are tried in order. The action
/// paired with the first matcher that succeeds is applied, with its block
/// arguments mapped to the payload entities yielded by that matcher.
///
/// Failure handling:
///   - an external (body-less) matcher or action is a definite failure;
///   - any of the following interrupts the walk, and the op returns a definite
///     failure:
///       - a definite failure from a matcher;
///       - a failure to map the action arguments;
///       - a definite failure from an action op;
///       - action results associated with multiple payload entities without
///         `getFlattenResults()`;
///   - a silenceable failure from a matcher only means "no match", and the
///     next pair is tried;
///   - silenceable failures from action ops are silenced and accumulated, with
///     notes, into one silenceable error. That error is returned at the end,
///     after all results have been set.
///
/// The `updated` result is mapped to the same payload as `root`. Each
/// forwarded output accumulates the values yielded by the action terminators.
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Resolve all matcher/action symbols up front so that every pair is known to
  // have a body before any payload op is visited.
  SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
      matchActionPairs;
  matchActionPairs.reserve(getMatchers().size());
  SymbolTableCollection symbolTable;
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    auto matcherSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(matcher));
    auto actionSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(action));
    assert(matcherSymbol && actionSymbol &&
           "unresolved symbols not caught by the verifier");

    if (matcherSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << matcher;
    if (actionSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << action;

    matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
  }

  // Collects silenceable failures of actions; stays success if none occur.
  DiagnosedSilenceableFailure overallDiag =
      DiagnosedSilenceableFailure::success();

  SmallVector<SmallVector<MappedValue>> matchInputMapping;
  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
  SmallVector<SmallVector<MappedValue>> actionResultMapping;
  // Explicitly add the mapping for the first block argument (the op being
  // matched).
  matchInputMapping.emplace_back();
  transform::detail::prepareValueMappings(matchInputMapping,
                                          getForwardedInputs(), state);
  // Slot for the op being matched. It is refilled for every visited op below,
  // while the forwarded-input slots after it stay unchanged.
  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
  actionResultMapping.resize(getForwardedOutputs().size());

  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      // If getRestrictRoot is not present, skip over the root op itself so we
      // don't invalidate it.
      if (!getRestrictRoot() && op == root)
        return WalkResult::advance();

      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      firstMatchArgument.clear();
      firstMatchArgument.push_back(op);

      // Try all the match/action pairs until the first successful match.
      for (auto [matcher, action] : matchActionPairs) {
        DiagnosedSilenceableFailure diag =
            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
                       state, matchOutputMapping);
        if (diag.isDefiniteFailure())
          return WalkResult::interrupt();
        if (diag.isSilenceableFailure()) {
          DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                       << " failed: " << diag.getMessage());
          continue;
        }

        // Bind the action's block arguments to what the matcher yielded.
        auto scope = state.make_region_scope(action.getFunctionBody());
        if (failed(state.mapBlockArguments(
                action.getFunctionBody().front().getArguments(),
                matchOutputMapping))) {
          return WalkResult::interrupt();
        }

        for (Operation &transform :
             action.getFunctionBody().front().without_terminator()) {
          DiagnosedSilenceableFailure result =
              state.applyTransform(cast<TransformOpInterface>(transform));
          if (result.isDefiniteFailure())
            return WalkResult::interrupt();
          if (result.isSilenceableFailure()) {
            // Fold this failure into the aggregate diagnostic and keep going
            // with the remaining ops of the action.
            if (overallDiag.succeeded()) {
              overallDiag = emitSilenceableError() << "actions failed";
            }
            overallDiag.attachNote(action->getLoc())
                << "failed action: " << result.getMessage();
            overallDiag.attachNote(op->getLoc())
                << "when applied to this matching payload";
            (void)result.silence();
            continue;
          }
        }
        // Accumulate what the action yields into the forwarded outputs.
        if (failed(detail::appendValueMappings(
                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
                action.getFunctionBody().front().getTerminator()->getOperands(),
                state, getFlattenResults()))) {
          emitDefiniteFailure()
              << "action @" << action.getName()
              << " has results associated with multiple payload entities, "
                 "but flattening was not requested";
          return WalkResult::interrupt();
        }
        // Only the first successful matcher's action is applied to this op.
        break;
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  // The root operation should not have been affected, so we can just reassign
  // the payload to the result. Note that we need to consume the root handle to
  // make sure any handles to operations inside, that could have been affected
  // by actions, are invalidated.
  results.set(llvm::cast<OpResult>(getUpdated()),
              state.getPayloadOps(getRoot()));
  for (auto &&[result, mapping] :
       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
    results.setMappedValues(result, mapping);
  }
  return overallDiag;
}
1333
1334void transform::ForeachMatchOp::getAsmResultNames(
1335 OpAsmSetValueNameFn setNameFn) {
1336 setNameFn(getUpdated(), "updated_root");
1337 for (Value v : getForwardedOutputs()) {
1338 setNameFn(v, "yielded");
1339 }
1340}
1341
1342void transform::ForeachMatchOp::getEffects(
1343 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1344 // Bail if invalid.
1345 if (getOperation()->getNumOperands() < 1 ||
1346 getOperation()->getNumResults() < 1) {
1347 return modifiesPayload(effects);
1348 }
1349
1350 consumesHandle(getRootMutable(), effects);
1351 onlyReadsHandle(getForwardedInputsMutable(), effects);
1352 producesHandle(getOperation()->getOpResults(), effects);
1353 modifiesPayload(effects);
1354}
1355
1356/// Parses the comma-separated list of symbol reference pairs of the format
1357/// `@matcher -> @action`.
1358static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1359 ArrayAttr &matchers,
1360 ArrayAttr &actions) {
1361 StringAttr matcher;
1362 StringAttr action;
1363 SmallVector<Attribute> matcherList;
1364 SmallVector<Attribute> actionList;
1365 do {
1366 if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1367 parser.parseSymbolName(action)) {
1368 return failure();
1369 }
1370 matcherList.push_back(SymbolRefAttr::get(matcher));
1371 actionList.push_back(SymbolRefAttr::get(action));
1372 } while (parser.parseOptionalComma().succeeded());
1373
1374 matchers = parser.getBuilder().getArrayAttr(matcherList);
1375 actions = parser.getBuilder().getArrayAttr(actionList);
1376 return success();
1377}
1378
1379/// Prints the comma-separated list of symbol reference pairs of the format
1380/// `@matcher -> @action`.
1381static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
1382 ArrayAttr matchers, ArrayAttr actions) {
1383 printer.increaseIndent();
1384 printer.increaseIndent();
1385 for (auto &&[matcher, action, idx] : llvm::zip_equal(
1386 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1387 printer.printNewline();
1388 printer << cast<SymbolRefAttr>(matcher) << " -> "
1389 << cast<SymbolRefAttr>(action);
1390 if (idx != matchers.size() - 1)
1391 printer << ", ";
1392 }
1393 printer.decreaseIndent();
1394 printer.decreaseIndent();
1395}
1396
1397LogicalResult transform::ForeachMatchOp::verify() {
1398 if (getMatchers().size() != getActions().size())
1399 return emitOpError() << "expected the same number of matchers and actions";
1400 if (getMatchers().empty())
1401 return emitOpError() << "expected at least one match/action pair";
1402
1403 llvm::SmallPtrSet<Attribute, 8> matcherNames;
1404 for (Attribute name : getMatchers()) {
1405 if (matcherNames.insert(name).second)
1406 continue;
1407 emitWarning() << "matcher " << name
1408 << " is used more than once, only the first match will apply";
1409 }
1410
1411 return success();
1412}
1413
1414/// Checks that the attributes of the function-like operation have correct
1415/// consumption effect annotations. If `alsoVerifyInternal`, checks for
1416/// annotations being present even if they can be inferred from the body.
1417static DiagnosedSilenceableFailure
1418verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1419 bool alsoVerifyInternal = false) {
1420 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1421 llvm::SmallDenseSet<unsigned> consumedArguments;
1422 if (!op.isExternal()) {
1423 transform::getConsumedBlockArguments(block&: op.getFunctionBody().front(),
1424 consumedArguments);
1425 }
1426 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1427 bool isConsumed =
1428 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1429 nullptr;
1430 bool isReadOnly =
1431 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1432 nullptr;
1433 if (isConsumed && isReadOnly) {
1434 return transformOp.emitSilenceableError()
1435 << "argument #" << i << " cannot be both readonly and consumed";
1436 }
1437 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1438 return transformOp.emitSilenceableError()
1439 << "must provide consumed/readonly status for arguments of "
1440 "external or called ops";
1441 }
1442 if (op.isExternal())
1443 continue;
1444
1445 if (consumedArguments.contains(V: i) && !isConsumed && isReadOnly) {
1446 return transformOp.emitSilenceableError()
1447 << "argument #" << i
1448 << " is consumed in the body but is not marked as such";
1449 }
1450 if (emitWarnings && !consumedArguments.contains(V: i) && isConsumed) {
1451 // Cannot use op.emitWarning() here as it would attempt to verify the op
1452 // before printing, resulting in infinite recursion.
1453 emitWarning(op->getLoc())
1454 << "op argument #" << i
1455 << " is not consumed in the body but is marked as consumed";
1456 }
1457 }
1458 return DiagnosedSilenceableFailure::success();
1459}
1460
1461LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1462 SymbolTableCollection &symbolTable) {
1463 assert(getMatchers().size() == getActions().size());
1464 auto consumedAttr =
1465 StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1466 for (auto &&[matcher, action] :
1467 llvm::zip_equal(getMatchers(), getActions())) {
1468 // Presence and typing.
1469 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1470 symbolTable.lookupNearestSymbolFrom(getOperation(),
1471 cast<SymbolRefAttr>(matcher)));
1472 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1473 symbolTable.lookupNearestSymbolFrom(getOperation(),
1474 cast<SymbolRefAttr>(action)));
1475 if (!matcherSymbol ||
1476 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1477 return emitError() << "unresolved matcher symbol " << matcher;
1478 if (!actionSymbol ||
1479 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1480 return emitError() << "unresolved action symbol " << action;
1481
1482 if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1483 /*emitWarnings=*/false,
1484 /*alsoVerifyInternal=*/true)
1485 .checkAndReport())) {
1486 return failure();
1487 }
1488 if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
1489 /*emitWarnings=*/false,
1490 /*alsoVerifyInternal=*/true)
1491 .checkAndReport())) {
1492 return failure();
1493 }
1494
1495 // Input -> matcher forwarding.
1496 TypeRange operandTypes = getOperandTypes();
1497 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1498 if (operandTypes.size() != matcherArguments.size()) {
1499 InFlightDiagnostic diag =
1500 emitError() << "the number of operands (" << operandTypes.size()
1501 << ") doesn't match the number of matcher arguments ("
1502 << matcherArguments.size() << ") for " << matcher;
1503 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1504 return diag;
1505 }
1506 for (auto &&[i, operand, argument] :
1507 llvm::enumerate(operandTypes, matcherArguments)) {
1508 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1509 InFlightDiagnostic diag =
1510 emitOpError()
1511 << "does not expect matcher symbol to consume its operand #" << i;
1512 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1513 return diag;
1514 }
1515
1516 if (implementSameTransformInterface(operand, argument))
1517 continue;
1518
1519 InFlightDiagnostic diag =
1520 emitError()
1521 << "mismatching type interfaces for operand and matcher argument #"
1522 << i << " of matcher " << matcher;
1523 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1524 return diag;
1525 }
1526
1527 // Matcher -> action forwarding.
1528 TypeRange matcherResults = matcherSymbol.getResultTypes();
1529 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1530 if (matcherResults.size() != actionArguments.size()) {
1531 return emitError() << "mismatching number of matcher results and "
1532 "action arguments between "
1533 << matcher << " (" << matcherResults.size() << ") and "
1534 << action << " (" << actionArguments.size() << ")";
1535 }
1536 for (auto &&[i, matcherType, actionType] :
1537 llvm::enumerate(matcherResults, actionArguments)) {
1538 if (implementSameTransformInterface(matcherType, actionType))
1539 continue;
1540
1541 return emitError() << "mismatching type interfaces for matcher result "
1542 "and action argument #"
1543 << i << "of matcher " << matcher << " and action "
1544 << action;
1545 }
1546
1547 // Action -> result forwarding.
1548 TypeRange actionResults = actionSymbol.getResultTypes();
1549 auto resultTypes = TypeRange(getResultTypes()).drop_front();
1550 if (actionResults.size() != resultTypes.size()) {
1551 InFlightDiagnostic diag =
1552 emitError() << "the number of action results ("
1553 << actionResults.size() << ") for " << action
1554 << " doesn't match the number of extra op results ("
1555 << resultTypes.size() << ")";
1556 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1557 return diag;
1558 }
1559 for (auto &&[i, resultType, actionType] :
1560 llvm::enumerate(resultTypes, actionResults)) {
1561 if (implementSameTransformInterface(resultType, actionType))
1562 continue;
1563
1564 InFlightDiagnostic diag =
1565 emitError() << "mismatching type interfaces for action result #" << i
1566 << " of action " << action << " and op result";
1567 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1568 return diag;
1569 }
1570 }
1571 return success();
1572}
1573
1574//===----------------------------------------------------------------------===//
1575// ForeachOp
1576//===----------------------------------------------------------------------===//
1577
1578DiagnosedSilenceableFailure
1579transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1580 transform::TransformResults &results,
1581 transform::TransformState &state) {
1582 // We store the payloads before executing the body as ops may be removed from
1583 // the mapping by the TrackingRewriter while iteration is in progress.
1584 SmallVector<SmallVector<MappedValue>> payloads;
1585 detail::prepareValueMappings(payloads, getTargets(), state);
1586 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1587 bool withZipShortest = getWithZipShortest();
1588
1589 // In case of `zip_shortest`, set the number of iterations to the
1590 // smallest payload in the targets.
1591 if (withZipShortest) {
1592 numIterations =
1593 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1594 const SmallVector<MappedValue> &B) {
1595 return A.size() < B.size();
1596 })->size();
1597
1598 for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1599 payloads[argIdx].resize(numIterations);
1600 }
1601
1602 // As we will be "zipping" over them, check all payloads have the same size.
1603 // `zip_shortest` adjusts all payloads to the same size, so skip this check
1604 // when true.
1605 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1606 argIdx++) {
1607 if (payloads[argIdx].size() != numIterations) {
1608 return emitSilenceableError()
1609 << "prior targets' payload size (" << numIterations
1610 << ") differs from payload size (" << payloads[argIdx].size()
1611 << ") of target " << getTargets()[argIdx];
1612 }
1613 }
1614
1615 // Start iterating, indexing into payloads to obtain the right arguments to
1616 // call the body with - each slice of payloads at the same argument index
1617 // corresponding to a tuple to use as the body's block arguments.
1618 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1619 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1620 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1621 auto scope = state.make_region_scope(getBody());
1622 // Set up arguments to the region's block.
1623 for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1624 MappedValue argument = payloads[argIdx][iterIdx];
1625 // Note that each blockArg's handle gets associated with just a single
1626 // element from the corresponding target's payload.
1627 if (failed(state.mapBlockArgument(blockArg, {argument})))
1628 return DiagnosedSilenceableFailure::definiteFailure();
1629 }
1630
1631 // Execute loop body.
1632 for (Operation &transform : getBody().front().without_terminator()) {
1633 DiagnosedSilenceableFailure result = state.applyTransform(
1634 llvm::cast<transform::TransformOpInterface>(transform));
1635 if (!result.succeeded())
1636 return result;
1637 }
1638
1639 // Append yielded payloads to corresponding results from prior iterations.
1640 OperandRange yieldOperands = getYieldOp().getOperands();
1641 for (auto &&[result, yieldOperand, resTuple] :
1642 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1643 // NB: each iteration we add any number of ops/vals/params to a result.
1644 if (isa<TransformHandleTypeInterface>(result.getType()))
1645 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1646 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1647 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1648 else if (isa<TransformParamTypeInterface>(result.getType()))
1649 llvm::append_range(resTuple, state.getParams(yieldOperand));
1650 else
1651 assert(false && "unhandled handle type");
1652 }
1653
1654 // Associate the accumulated result payloads to the op's actual results.
1655 for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1656 results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1657
1658 return DiagnosedSilenceableFailure::success();
1659}
1660
1661void transform::ForeachOp::getEffects(
1662 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1663 // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1664 // arity errors, this method might get called before/in absence of `verify()`.
1665 for (auto &&[target, blockArg] :
1666 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1667 BlockArgument blockArgument = blockArg;
1668 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1669 return isHandleConsumed(blockArgument,
1670 cast<TransformOpInterface>(&op));
1671 })) {
1672 consumesHandle(target, effects);
1673 } else {
1674 onlyReadsHandle(target, effects);
1675 }
1676 }
1677
1678 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1679 return doesModifyPayload(cast<TransformOpInterface>(&op));
1680 })) {
1681 modifiesPayload(effects);
1682 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1683 return doesReadPayload(cast<TransformOpInterface>(&op));
1684 })) {
1685 onlyReadsPayload(effects);
1686 }
1687
1688 producesHandle(getOperation()->getOpResults(), effects);
1689}
1690
1691void transform::ForeachOp::getSuccessorRegions(
1692 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1693 Region *bodyRegion = &getBody();
1694 if (point.isParent()) {
1695 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1696 return;
1697 }
1698
1699 // Branch back to the region or the parent.
1700 assert(point == getBody() && "unexpected region index");
1701 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1702 regions.emplace_back();
1703}
1704
1705OperandRange
1706transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1707 // Each block argument handle is mapped to a subset (one op to be precise)
1708 // of the payload of the corresponding `targets` operand of ForeachOp.
1709 assert(point == getBody() && "unexpected region index");
1710 return getOperation()->getOperands();
1711}
1712
1713transform::YieldOp transform::ForeachOp::getYieldOp() {
1714 return cast<transform::YieldOp>(getBody().front().getTerminator());
1715}
1716
1717LogicalResult transform::ForeachOp::verify() {
1718 for (auto [targetOpt, bodyArgOpt] :
1719 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1720 if (!targetOpt || !bodyArgOpt)
1721 return emitOpError() << "expects the same number of targets as the body "
1722 "has block arguments";
1723 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1724 return emitOpError(
1725 "expects co-indexed targets and the body's "
1726 "block arguments to have the same op/value/param type");
1727 }
1728
1729 for (auto [resultOpt, yieldOperandOpt] :
1730 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1731 if (!resultOpt || !yieldOperandOpt)
1732 return emitOpError() << "expects the same number of results as the "
1733 "yield terminator has operands";
1734 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1735 return emitOpError("expects co-indexed results and yield "
1736 "operands to have the same op/value/param type");
1737 }
1738
1739 return success();
1740}
1741
1742//===----------------------------------------------------------------------===//
1743// GetParentOp
1744//===----------------------------------------------------------------------===//
1745
1746DiagnosedSilenceableFailure
1747transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1748 transform::TransformResults &results,
1749 transform::TransformState &state) {
1750 SmallVector<Operation *> parents;
1751 DenseSet<Operation *> resultSet;
1752 for (Operation *target : state.getPayloadOps(getTarget())) {
1753 Operation *parent = target;
1754 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1755 parent = parent->getParentOp();
1756 while (parent) {
1757 bool checkIsolatedFromAbove =
1758 !getIsolatedFromAbove() ||
1759 parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
1760 bool checkOpName = !getOpName().has_value() ||
1761 parent->getName().getStringRef() == *getOpName();
1762 if (checkIsolatedFromAbove && checkOpName)
1763 break;
1764 parent = parent->getParentOp();
1765 }
1766 if (!parent) {
1767 if (getAllowEmptyResults()) {
1768 results.set(llvm::cast<OpResult>(getResult()), parents);
1769 return DiagnosedSilenceableFailure::success();
1770 }
1771 DiagnosedSilenceableFailure diag =
1772 emitSilenceableError()
1773 << "could not find a parent op that matches all requirements";
1774 diag.attachNote(target->getLoc()) << "target op";
1775 return diag;
1776 }
1777 }
1778 if (getDeduplicate()) {
1779 if (resultSet.insert(parent).second)
1780 parents.push_back(parent);
1781 } else {
1782 parents.push_back(parent);
1783 }
1784 }
1785 results.set(llvm::cast<OpResult>(getResult()), parents);
1786 return DiagnosedSilenceableFailure::success();
1787}
1788
1789//===----------------------------------------------------------------------===//
1790// GetConsumersOfResult
1791//===----------------------------------------------------------------------===//
1792
1793DiagnosedSilenceableFailure
1794transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1795 transform::TransformResults &results,
1796 transform::TransformState &state) {
1797 int64_t resultNumber = getResultNumber();
1798 auto payloadOps = state.getPayloadOps(getTarget());
1799 if (std::empty(payloadOps)) {
1800 results.set(cast<OpResult>(getResult()), {});
1801 return DiagnosedSilenceableFailure::success();
1802 }
1803 if (!llvm::hasSingleElement(payloadOps))
1804 return emitDefiniteFailure()
1805 << "handle must be mapped to exactly one payload op";
1806
1807 Operation *target = *payloadOps.begin();
1808 if (target->getNumResults() <= resultNumber)
1809 return emitDefiniteFailure() << "result number overflow";
1810 results.set(llvm::cast<OpResult>(getResult()),
1811 llvm::to_vector(target->getResult(resultNumber).getUsers()));
1812 return DiagnosedSilenceableFailure::success();
1813}
1814
1815//===----------------------------------------------------------------------===//
1816// GetDefiningOp
1817//===----------------------------------------------------------------------===//
1818
1819DiagnosedSilenceableFailure
1820transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1821 transform::TransformResults &results,
1822 transform::TransformState &state) {
1823 SmallVector<Operation *> definingOps;
1824 for (Value v : state.getPayloadValues(getTarget())) {
1825 if (llvm::isa<BlockArgument>(v)) {
1826 DiagnosedSilenceableFailure diag =
1827 emitSilenceableError() << "cannot get defining op of block argument";
1828 diag.attachNote(v.getLoc()) << "target value";
1829 return diag;
1830 }
1831 definingOps.push_back(v.getDefiningOp());
1832 }
1833 results.set(llvm::cast<OpResult>(getResult()), definingOps);
1834 return DiagnosedSilenceableFailure::success();
1835}
1836
1837//===----------------------------------------------------------------------===//
1838// GetProducerOfOperand
1839//===----------------------------------------------------------------------===//
1840
1841DiagnosedSilenceableFailure
1842transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1843 transform::TransformResults &results,
1844 transform::TransformState &state) {
1845 int64_t operandNumber = getOperandNumber();
1846 SmallVector<Operation *> producers;
1847 for (Operation *target : state.getPayloadOps(getTarget())) {
1848 Operation *producer =
1849 target->getNumOperands() <= operandNumber
1850 ? nullptr
1851 : target->getOperand(operandNumber).getDefiningOp();
1852 if (!producer) {
1853 DiagnosedSilenceableFailure diag =
1854 emitSilenceableError()
1855 << "could not find a producer for operand number: " << operandNumber
1856 << " of " << *target;
1857 diag.attachNote(target->getLoc()) << "target op";
1858 return diag;
1859 }
1860 producers.push_back(producer);
1861 }
1862 results.set(llvm::cast<OpResult>(getResult()), producers);
1863 return DiagnosedSilenceableFailure::success();
1864}
1865
1866//===----------------------------------------------------------------------===//
1867// GetOperandOp
1868//===----------------------------------------------------------------------===//
1869
1870DiagnosedSilenceableFailure
1871transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1872 transform::TransformResults &results,
1873 transform::TransformState &state) {
1874 SmallVector<Value> operands;
1875 for (Operation *target : state.getPayloadOps(getTarget())) {
1876 SmallVector<int64_t> operandPositions;
1877 DiagnosedSilenceableFailure diag = expandTargetSpecification(
1878 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1879 target->getNumOperands(), operandPositions);
1880 if (diag.isSilenceableFailure()) {
1881 diag.attachNote(target->getLoc())
1882 << "while considering positions of this payload operation";
1883 return diag;
1884 }
1885 llvm::append_range(operands,
1886 llvm::map_range(operandPositions, [&](int64_t pos) {
1887 return target->getOperand(pos);
1888 }));
1889 }
1890 results.setValues(cast<OpResult>(getResult()), operands);
1891 return DiagnosedSilenceableFailure::success();
1892}
1893
1894LogicalResult transform::GetOperandOp::verify() {
1895 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1896 getIsInverted(), getIsAll());
1897}
1898
1899//===----------------------------------------------------------------------===//
1900// GetResultOp
1901//===----------------------------------------------------------------------===//
1902
1903DiagnosedSilenceableFailure
1904transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1905 transform::TransformResults &results,
1906 transform::TransformState &state) {
1907 SmallVector<Value> opResults;
1908 for (Operation *target : state.getPayloadOps(getTarget())) {
1909 SmallVector<int64_t> resultPositions;
1910 DiagnosedSilenceableFailure diag = expandTargetSpecification(
1911 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1912 target->getNumResults(), resultPositions);
1913 if (diag.isSilenceableFailure()) {
1914 diag.attachNote(target->getLoc())
1915 << "while considering positions of this payload operation";
1916 return diag;
1917 }
1918 llvm::append_range(opResults,
1919 llvm::map_range(resultPositions, [&](int64_t pos) {
1920 return target->getResult(pos);
1921 }));
1922 }
1923 results.setValues(cast<OpResult>(getResult()), opResults);
1924 return DiagnosedSilenceableFailure::success();
1925}
1926
1927LogicalResult transform::GetResultOp::verify() {
1928 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1929 getIsInverted(), getIsAll());
1930}
1931
1932//===----------------------------------------------------------------------===//
1933// GetTypeOp
1934//===----------------------------------------------------------------------===//
1935
/// Declares this op's effects, in this order:
/// - the value handle is only read;
/// - the result handle is produced;
/// - the payload is only read.
void transform::GetTypeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getValueMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  onlyReadsPayload(effects);
}
1942
1943DiagnosedSilenceableFailure
1944transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1945 transform::TransformResults &results,
1946 transform::TransformState &state) {
1947 SmallVector<Attribute> params;
1948 for (Value value : state.getPayloadValues(getValue())) {
1949 Type type = value.getType();
1950 if (getElemental()) {
1951 if (auto shaped = dyn_cast<ShapedType>(type)) {
1952 type = shaped.getElementType();
1953 }
1954 }
1955 params.push_back(TypeAttr::get(type));
1956 }
1957 results.setParams(cast<OpResult>(getResult()), params);
1958 return DiagnosedSilenceableFailure::success();
1959}
1960
1961//===----------------------------------------------------------------------===//
1962// IncludeOp
1963//===----------------------------------------------------------------------===//
1964
1965/// Applies the transform ops contained in `block`. Maps `results` to the same
1966/// values as the operands of the block terminator.
1967static DiagnosedSilenceableFailure
1968applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1969 transform::TransformState &state,
1970 transform::TransformResults &results) {
1971 // Apply the sequenced ops one by one.
1972 for (Operation &transform : block.without_terminator()) {
1973 DiagnosedSilenceableFailure result =
1974 state.applyTransform(transform: cast<transform::TransformOpInterface>(transform));
1975 if (result.isDefiniteFailure())
1976 return result;
1977
1978 if (result.isSilenceableFailure()) {
1979 if (mode == transform::FailurePropagationMode::Propagate) {
1980 // Propagate empty results in case of early exit.
1981 forwardEmptyOperands(block: &block, state, results);
1982 return result;
1983 }
1984 (void)result.silence();
1985 }
1986 }
1987
1988 // Forward the operation mapping for values yielded from the sequence to the
1989 // values produced by the sequence op.
1990 transform::detail::forwardTerminatorOperands(block: &block, state, results);
1991 return DiagnosedSilenceableFailure::success();
1992}
1993
1994DiagnosedSilenceableFailure
1995transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1996 transform::TransformResults &results,
1997 transform::TransformState &state) {
1998 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1999 getOperation(), getTarget());
2000 assert(callee && "unverified reference to unknown symbol");
2001
2002 if (callee.isExternal())
2003 return emitDefiniteFailure() << "unresolved external named sequence";
2004
2005 // Map operands to block arguments.
2006 SmallVector<SmallVector<MappedValue>> mappings;
2007 detail::prepareValueMappings(mappings, getOperands(), state);
2008 auto scope = state.make_region_scope(callee.getBody());
2009 for (auto &&[arg, map] :
2010 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2011 if (failed(state.mapBlockArgument(arg, map)))
2012 return DiagnosedSilenceableFailure::definiteFailure();
2013 }
2014
2015 DiagnosedSilenceableFailure result = applySequenceBlock(
2016 callee.getBody().front(), getFailurePropagationMode(), state, results);
2017 mappings.clear();
2018 detail::prepareValueMappings(
2019 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2020 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2021 results.setMappedValues(result, mapping);
2022 return result;
2023}
2024
2025static DiagnosedSilenceableFailure
2026verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2027
2028void transform::IncludeOp::getEffects(
2029 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2030 // Always mark as modifying the payload.
2031 // TODO: a mechanism to annotate effects on payload. Even when all handles are
2032 // only read, the payload may still be modified, so we currently stay on the
2033 // conservative side and always indicate modification. This may prevent some
2034 // code reordering.
2035 modifiesPayload(effects);
2036
2037 // Results are always produced.
2038 producesHandle(getOperation()->getOpResults(), effects);
2039
2040 // Adds default effects to operands and results. This will be added if
2041 // preconditions fail so the trait verifier doesn't complain about missing
2042 // effects and the real precondition failure is reported later on.
2043 auto defaultEffects = [&] {
2044 onlyReadsHandle(getOperation()->getOpOperands(), effects);
2045 };
2046
2047 // Bail if the callee is unknown. This may run as part of the verification
2048 // process before we verified the validity of the callee or of this op.
2049 auto target =
2050 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2051 if (!target)
2052 return defaultEffects();
2053 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2054 getOperation(), getTarget());
2055 if (!callee)
2056 return defaultEffects();
2057 DiagnosedSilenceableFailure earlyVerifierResult =
2058 verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
2059 if (!earlyVerifierResult.succeeded()) {
2060 (void)earlyVerifierResult.silence();
2061 return defaultEffects();
2062 }
2063
2064 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2065 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2066 consumesHandle(getOperation()->getOpOperand(i), effects);
2067 else
2068 onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2069 }
2070}
2071
2072LogicalResult
2073transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2074 // Access through indirection and do additional checking because this may be
2075 // running before the main op verifier.
2076 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2077 if (!targetAttr)
2078 return emitOpError() << "expects a 'target' symbol reference attribute";
2079
2080 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2081 *this, targetAttr);
2082 if (!target)
2083 return emitOpError() << "does not reference a named transform sequence";
2084
2085 FunctionType fnType = target.getFunctionType();
2086 if (fnType.getNumInputs() != getNumOperands())
2087 return emitError("incorrect number of operands for callee");
2088
2089 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2090 if (getOperand(i).getType() != fnType.getInput(i)) {
2091 return emitOpError("operand type mismatch: expected operand type ")
2092 << fnType.getInput(i) << ", but provided "
2093 << getOperand(i).getType() << " for operand number " << i;
2094 }
2095 }
2096
2097 if (fnType.getNumResults() != getNumResults())
2098 return emitError("incorrect number of results for callee");
2099
2100 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2101 Type resultType = getResult(i).getType();
2102 Type funcType = fnType.getResult(i);
2103 if (!implementSameTransformInterface(resultType, funcType)) {
2104 return emitOpError() << "type of result #" << i
2105 << " must implement the same transform dialect "
2106 "interface as the corresponding callee result";
2107 }
2108 }
2109
2110 return verifyFunctionLikeConsumeAnnotations(
2111 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2112 /*alsoVerifyInternal=*/true)
2113 .checkAndReport();
2114}
2115
2116//===----------------------------------------------------------------------===//
2117// MatchOperationEmptyOp
2118//===----------------------------------------------------------------------===//
2119
2120DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2121 ::std::optional<::mlir::Operation *> maybeCurrent,
2122 transform::TransformResults &results, transform::TransformState &state) {
2123 if (!maybeCurrent.has_value()) {
2124 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
2125 return DiagnosedSilenceableFailure::success();
2126 }
2127 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
2128 return emitSilenceableError() << "operation is not empty";
2129}
2130
2131//===----------------------------------------------------------------------===//
2132// MatchOperationNameOp
2133//===----------------------------------------------------------------------===//
2134
2135DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2136 Operation *current, transform::TransformResults &results,
2137 transform::TransformState &state) {
2138 StringRef currentOpName = current->getName().getStringRef();
2139 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2140 if (acceptedAttr.getValue() == currentOpName)
2141 return DiagnosedSilenceableFailure::success();
2142 }
2143 return emitSilenceableError() << "wrong operation name";
2144}
2145
2146//===----------------------------------------------------------------------===//
2147// MatchParamCmpIOp
2148//===----------------------------------------------------------------------===//
2149
2150DiagnosedSilenceableFailure
2151transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
2152 transform::TransformResults &results,
2153 transform::TransformState &state) {
2154 auto signedAPIntAsString = [&](const APInt &value) {
2155 std::string str;
2156 llvm::raw_string_ostream os(str);
2157 value.print(os, /*isSigned=*/true);
2158 return str;
2159 };
2160
2161 ArrayRef<Attribute> params = state.getParams(getParam());
2162 ArrayRef<Attribute> references = state.getParams(getReference());
2163
2164 if (params.size() != references.size()) {
2165 return emitSilenceableError()
2166 << "parameters have different payload lengths (" << params.size()
2167 << " vs " << references.size() << ")";
2168 }
2169
2170 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2171 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2172 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2173 if (!intAttr || !refAttr) {
2174 return emitDefiniteFailure()
2175 << "non-integer parameter value not expected";
2176 }
2177 if (intAttr.getType() != refAttr.getType()) {
2178 return emitDefiniteFailure()
2179 << "mismatching integer attribute types in parameter #" << i;
2180 }
2181 APInt value = intAttr.getValue();
2182 APInt refValue = refAttr.getValue();
2183
2184 // TODO: this copy will not be necessary in C++20.
2185 int64_t position = i;
2186 auto reportError = [&](StringRef direction) {
2187 DiagnosedSilenceableFailure diag =
2188 emitSilenceableError() << "expected parameter to be " << direction
2189 << " " << signedAPIntAsString(refValue)
2190 << ", got " << signedAPIntAsString(value);
2191 diag.attachNote(getParam().getLoc())
2192 << "value # " << position
2193 << " associated with the parameter defined here";
2194 return diag;
2195 };
2196
2197 switch (getPredicate()) {
2198 case MatchCmpIPredicate::eq:
2199 if (value.eq(refValue))
2200 break;
2201 return reportError("equal to");
2202 case MatchCmpIPredicate::ne:
2203 if (value.ne(refValue))
2204 break;
2205 return reportError("not equal to");
2206 case MatchCmpIPredicate::lt:
2207 if (value.slt(refValue))
2208 break;
2209 return reportError("less than");
2210 case MatchCmpIPredicate::le:
2211 if (value.sle(refValue))
2212 break;
2213 return reportError("less than or equal to");
2214 case MatchCmpIPredicate::gt:
2215 if (value.sgt(refValue))
2216 break;
2217 return reportError("greater than");
2218 case MatchCmpIPredicate::ge:
2219 if (value.sge(refValue))
2220 break;
2221 return reportError("greater than or equal to");
2222 }
2223 }
2224 return DiagnosedSilenceableFailure::success();
2225}
2226
/// Declares that both the `param` and `reference` handles are only read. No
/// payload or result effects are declared here.
void transform::MatchParamCmpIOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getParamMutable(), effects);
  onlyReadsHandle(getReferenceMutable(), effects);
}
2232
2233//===----------------------------------------------------------------------===//
2234// ParamConstantOp
2235//===----------------------------------------------------------------------===//
2236
2237DiagnosedSilenceableFailure
2238transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2239 transform::TransformResults &results,
2240 transform::TransformState &state) {
2241 results.setParams(cast<OpResult>(getParam()), {getValue()});
2242 return DiagnosedSilenceableFailure::success();
2243}
2244
2245//===----------------------------------------------------------------------===//
2246// MergeHandlesOp
2247//===----------------------------------------------------------------------===//
2248
2249DiagnosedSilenceableFailure
2250transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2251 transform::TransformResults &results,
2252 transform::TransformState &state) {
2253 ValueRange handles = getHandles();
2254 if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2255 SmallVector<Operation *> operations;
2256 for (Value operand : handles)
2257 llvm::append_range(operations, state.getPayloadOps(operand));
2258 if (!getDeduplicate()) {
2259 results.set(llvm::cast<OpResult>(getResult()), operations);
2260 return DiagnosedSilenceableFailure::success();
2261 }
2262
2263 SetVector<Operation *> uniqued(llvm::from_range, operations);
2264 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2265 return DiagnosedSilenceableFailure::success();
2266 }
2267
2268 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2269 SmallVector<Attribute> attrs;
2270 for (Value attribute : handles)
2271 llvm::append_range(attrs, state.getParams(attribute));
2272 if (!getDeduplicate()) {
2273 results.setParams(cast<OpResult>(getResult()), attrs);
2274 return DiagnosedSilenceableFailure::success();
2275 }
2276
2277 SetVector<Attribute> uniqued(llvm::from_range, attrs);
2278 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2279 return DiagnosedSilenceableFailure::success();
2280 }
2281
2282 assert(
2283 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2284 "expected value handle type");
2285 SmallVector<Value> payloadValues;
2286 for (Value value : handles)
2287 llvm::append_range(payloadValues, state.getPayloadValues(value));
2288 if (!getDeduplicate()) {
2289 results.setValues(cast<OpResult>(getResult()), payloadValues);
2290 return DiagnosedSilenceableFailure::success();
2291 }
2292
2293 SetVector<Value> uniqued(llvm::from_range, payloadValues);
2294 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2295 return DiagnosedSilenceableFailure::success();
2296}
2297
/// Returns whether the same handle may appear several times among the
/// operands. This is allowed exactly when `deduplicate` is set.
bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
  // Handles may be the same if deduplicating is enabled.
  return getDeduplicate();
}
2302
/// Declares that all input handles are only read and the result handle is
/// produced.
void transform::MergeHandlesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getHandlesMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);

  // There are no effects on the Payload IR as this is only a handle
  // manipulation.
}
2311
2312OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2313 if (getDeduplicate() || getHandles().size() != 1)
2314 return {};
2315
2316 // If deduplication is not required and there is only one operand, it can be
2317 // used directly instead of merging.
2318 return getHandles().front();
2319}
2320
2321//===----------------------------------------------------------------------===//
2322// NamedSequenceOp
2323//===----------------------------------------------------------------------===//
2324
2325DiagnosedSilenceableFailure
2326transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2327 transform::TransformResults &results,
2328 transform::TransformState &state) {
2329 if (isExternal())
2330 return emitDefiniteFailure() << "unresolved external named sequence";
2331
2332 // Map the entry block argument to the list of operations.
2333 // Note: this is the same implementation as PossibleTopLevelTransformOp but
2334 // without attaching the interface / trait since that is tailored to a
2335 // dangling top-level op that does not get "called".
2336 auto scope = state.make_region_scope(getBody());
2337 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2338 state, this->getOperation(), getBody())))
2339 return DiagnosedSilenceableFailure::definiteFailure();
2340
2341 return applySequenceBlock(getBody().front(),
2342 FailurePropagationMode::Propagate, state, results);
2343}
2344
/// Declares no memory effects at all for the named sequence op itself.
/// NOTE(review): presumably effects are accounted for at the including op
/// (see IncludeOp::getEffects); confirm this is intentional.
void transform::NamedSequenceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2347
2348ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2349 OperationState &result) {
2350 return function_interface_impl::parseFunctionOp(
2351 parser, result, /*allowVariadic=*/false,
2352 getFunctionTypeAttrName(result.name),
2353 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2354 function_interface_impl::VariadicFlag,
2355 std::string &) { return builder.getFunctionType(inputs, results); },
2356 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2357}
2358
2359void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2360 function_interface_impl::printFunctionOp(
2361 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2362 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2363 getResAttrsAttrName());
2364}
2365
2366/// Verifies that a symbol function-like transform dialect operation has the
2367/// signature and the terminator that have conforming types, i.e., types
2368/// implementing the same transform dialect type interface. If `allowExternal`
2369/// is set, allow external symbols (declarations) and don't check the terminator
2370/// as it may not exist.
2371static DiagnosedSilenceableFailure
2372verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2373 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2374 DiagnosedSilenceableFailure diag =
2375 emitSilenceableFailure(op)
2376 << "cannot be defined inside another transform op";
2377 diag.attachNote(loc: parent.getLoc()) << "ancestor transform op";
2378 return diag;
2379 }
2380
2381 if (op.isExternal() || op.getFunctionBody().empty()) {
2382 if (allowExternal)
2383 return DiagnosedSilenceableFailure::success();
2384
2385 return emitSilenceableFailure(op) << "cannot be external";
2386 }
2387
2388 if (op.getFunctionBody().front().empty())
2389 return emitSilenceableFailure(op) << "expected a non-empty body block";
2390
2391 Operation *terminator = &op.getFunctionBody().front().back();
2392 if (!isa<transform::YieldOp>(terminator)) {
2393 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2394 << "expected '"
2395 << transform::YieldOp::getOperationName()
2396 << "' as terminator";
2397 diag.attachNote(loc: terminator->getLoc()) << "terminator";
2398 return diag;
2399 }
2400
2401 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2402 return emitSilenceableFailure(op: terminator)
2403 << "expected terminator to have as many operands as the parent op "
2404 "has results";
2405 }
2406 for (auto [i, operandType, resultType] : llvm::zip_equal(
2407 llvm::seq<unsigned>(0, terminator->getNumOperands()),
2408 terminator->getOperands().getType(), op.getResultTypes())) {
2409 if (operandType == resultType)
2410 continue;
2411 return emitSilenceableFailure(terminator)
2412 << "the type of the terminator operand #" << i
2413 << " must match the type of the corresponding parent op result ("
2414 << operandType << " vs " << resultType << ")";
2415 }
2416
2417 return DiagnosedSilenceableFailure::success();
2418}
2419
2420/// Verification of a NamedSequenceOp. This does not report the error
2421/// immediately, so it can be used to check for op's well-formedness before the
2422/// verifier runs, e.g., during trait verification.
2423static DiagnosedSilenceableFailure
2424verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2425 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2426 if (!parent->getAttr(
2427 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2428 DiagnosedSilenceableFailure diag =
2429 emitSilenceableFailure(op)
2430 << "expects the parent symbol table to have the '"
2431 << transform::TransformDialect::kWithNamedSequenceAttrName
2432 << "' attribute";
2433 diag.attachNote(loc: parent->getLoc()) << "symbol table operation";
2434 return diag;
2435 }
2436 }
2437
2438 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2439 DiagnosedSilenceableFailure diag =
2440 emitSilenceableFailure(op)
2441 << "cannot be defined inside another transform op";
2442 diag.attachNote(loc: parent.getLoc()) << "ancestor transform op";
2443 return diag;
2444 }
2445
2446 if (op.isExternal() || op.getBody().empty())
2447 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2448 emitWarnings);
2449
2450 if (op.getBody().front().empty())
2451 return emitSilenceableFailure(op) << "expected a non-empty body block";
2452
2453 Operation *terminator = &op.getBody().front().back();
2454 if (!isa<transform::YieldOp>(terminator)) {
2455 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2456 << "expected '"
2457 << transform::YieldOp::getOperationName()
2458 << "' as terminator";
2459 diag.attachNote(loc: terminator->getLoc()) << "terminator";
2460 return diag;
2461 }
2462
2463 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2464 return emitSilenceableFailure(op: terminator)
2465 << "expected terminator to have as many operands as the parent op "
2466 "has results";
2467 }
2468 for (auto [i, operandType, resultType] :
2469 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2470 terminator->getOperands().getType(),
2471 op.getFunctionType().getResults())) {
2472 if (operandType == resultType)
2473 continue;
2474 return emitSilenceableFailure(terminator)
2475 << "the type of the terminator operand #" << i
2476 << " must match the type of the corresponding parent op result ("
2477 << operandType << " vs " << resultType << ")";
2478 }
2479
2480 auto funcOp = cast<FunctionOpInterface>(*op);
2481 DiagnosedSilenceableFailure diag =
2482 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2483 if (!diag.succeeded())
2484 return diag;
2485
2486 return verifyYieldingSingleBlockOp(funcOp,
2487 /*allowExternal=*/true);
2488}
2489
2490LogicalResult transform::NamedSequenceOp::verify() {
2491 // Actual verification happens in a separate function for reusability.
2492 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2493}
2494
2495template <typename FnTy>
2496static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2497 Type bbArgType, TypeRange extraBindingTypes,
2498 FnTy bodyBuilder) {
2499 SmallVector<Type> types;
2500 types.reserve(N: 1 + extraBindingTypes.size());
2501 types.push_back(Elt: bbArgType);
2502 llvm::append_range(C&: types, R&: extraBindingTypes);
2503
2504 OpBuilder::InsertionGuard guard(builder);
2505 Region *region = state.regions.back().get();
2506 Block *bodyBlock =
2507 builder.createBlock(parent: region, insertPt: region->begin(), argTypes: types,
2508 locs: SmallVector<Location>(types.size(), state.location));
2509
2510 // Populate body.
2511 builder.setInsertionPointToStart(bodyBlock);
2512 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2513 bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0));
2514 } else {
2515 bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0),
2516 bodyBlock->getArguments().drop_front());
2517 }
2518}
2519
2520void transform::NamedSequenceOp::build(OpBuilder &builder,
2521 OperationState &state, StringRef symName,
2522 Type rootType, TypeRange resultTypes,
2523 SequenceBodyBuilderFn bodyBuilder,
2524 ArrayRef<NamedAttribute> attrs,
2525 ArrayRef<DictionaryAttr> argAttrs) {
2526 state.addAttribute(SymbolTable::getSymbolAttrName(),
2527 builder.getStringAttr(symName));
2528 state.addAttribute(getFunctionTypeAttrName(state.name),
2529 TypeAttr::get(FunctionType::get(builder.getContext(),
2530 rootType, resultTypes)));
2531 state.attributes.append(attrs.begin(), attrs.end());
2532 state.addRegion();
2533
2534 buildSequenceBody(builder, state, rootType,
2535 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2536}
2537
2538//===----------------------------------------------------------------------===//
2539// NumAssociationsOp
2540//===----------------------------------------------------------------------===//
2541
2542DiagnosedSilenceableFailure
2543transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2544 transform::TransformResults &results,
2545 transform::TransformState &state) {
2546 size_t numAssociations =
2547 llvm::TypeSwitch<Type, size_t>(getHandle().getType())
2548 .Case([&](TransformHandleTypeInterface opHandle) {
2549 return llvm::range_size(state.getPayloadOps(getHandle()));
2550 })
2551 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2552 return llvm::range_size(state.getPayloadValues(getHandle()));
2553 })
2554 .Case([&](TransformParamTypeInterface param) {
2555 return llvm::range_size(state.getParams(getHandle()));
2556 })
2557 .Default([](Type) {
2558 llvm_unreachable("unknown kind of transform dialect type");
2559 return 0;
2560 });
2561 results.setParams(cast<OpResult>(getNum()),
2562 rewriter.getI64IntegerAttr(numAssociations));
2563 return DiagnosedSilenceableFailure::success();
2564}
2565
2566LogicalResult transform::NumAssociationsOp::verify() {
2567 // Verify that the result type accepts an i64 attribute as payload.
2568 auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2569 return resultType
2570 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2571 .checkAndReport();
2572}
2573
2574//===----------------------------------------------------------------------===//
2575// SelectOp
2576//===----------------------------------------------------------------------===//
2577
2578DiagnosedSilenceableFailure
2579transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2580 transform::TransformResults &results,
2581 transform::TransformState &state) {
2582 SmallVector<Operation *> result;
2583 auto payloadOps = state.getPayloadOps(getTarget());
2584 for (Operation *op : payloadOps) {
2585 if (op->getName().getStringRef() == getOpName())
2586 result.push_back(op);
2587 }
2588 results.set(cast<OpResult>(getResult()), result);
2589 return DiagnosedSilenceableFailure::success();
2590}
2591
2592//===----------------------------------------------------------------------===//
2593// SplitHandleOp
2594//===----------------------------------------------------------------------===//
2595
2596void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2597 Value target, int64_t numResultHandles) {
2598 result.addOperands(target);
2599 result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2600}
2601
2602DiagnosedSilenceableFailure
2603transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2604 transform::TransformResults &results,
2605 transform::TransformState &state) {
2606 int64_t numPayloads =
2607 llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
2608 .Case<TransformHandleTypeInterface>([&](auto x) {
2609 return llvm::range_size(state.getPayloadOps(getHandle()));
2610 })
2611 .Case<TransformValueHandleTypeInterface>([&](auto x) {
2612 return llvm::range_size(state.getPayloadValues(getHandle()));
2613 })
2614 .Case<TransformParamTypeInterface>([&](auto x) {
2615 return llvm::range_size(state.getParams(getHandle()));
2616 })
2617 .Default([](auto x) {
2618 llvm_unreachable("unknown transform dialect type interface");
2619 return -1;
2620 });
2621
2622 auto produceNumOpsError = [&]() {
2623 return emitSilenceableError()
2624 << getHandle() << " expected to contain " << this->getNumResults()
2625 << " payloads but it contains " << numPayloads << " payloads";
2626 };
2627
2628 // Fail if there are more payload ops than results and no overflow result was
2629 // specified.
2630 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2631 return produceNumOpsError();
2632
2633 // Fail if there are more results than payload ops. Unless:
2634 // - "fail_on_payload_too_small" is set to "false", or
2635 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2636 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2637 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2638 return produceNumOpsError();
2639
2640 // Distribute payloads.
2641 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2642 if (getOverflowResult())
2643 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2644
2645 auto container = [&]() {
2646 if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2647 return llvm::map_to_vector(
2648 state.getPayloadOps(getHandle()),
2649 [](Operation *op) -> MappedValue { return op; });
2650 }
2651 if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2652 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2653 [](Value v) -> MappedValue { return v; });
2654 }
2655 assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2656 "unsupported kind of transform dialect type");
2657 return llvm::map_to_vector(state.getParams(getHandle()),
2658 [](Attribute a) -> MappedValue { return a; });
2659 }();
2660
2661 for (auto &&en : llvm::enumerate(container)) {
2662 int64_t resultNum = en.index();
2663 if (resultNum >= getNumResults())
2664 resultNum = *getOverflowResult();
2665 resultHandles[resultNum].push_back(en.value());
2666 }
2667
2668 // Set transform op results.
2669 for (auto &&it : llvm::enumerate(resultHandles))
2670 results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2671 it.value());
2672
2673 return DiagnosedSilenceableFailure::success();
2674}
2675
2676void transform::SplitHandleOp::getEffects(
2677 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2678 onlyReadsHandle(getHandleMutable(), effects);
2679 producesHandle(getOperation()->getOpResults(), effects);
2680 // There are no effects on the Payload IR as this is only a handle
2681 // manipulation.
2682}
2683
2684LogicalResult transform::SplitHandleOp::verify() {
2685 if (getOverflowResult().has_value() &&
2686 !(*getOverflowResult() < getNumResults()))
2687 return emitOpError("overflow_result is not a valid result index");
2688
2689 for (Type resultType : getResultTypes()) {
2690 if (implementSameTransformInterface(getHandle().getType(), resultType))
2691 continue;
2692
2693 return emitOpError("expects result types to implement the same transform "
2694 "interface as the operand type");
2695 }
2696
2697 return success();
2698}
2699
2700//===----------------------------------------------------------------------===//
2701// ReplicateOp
2702//===----------------------------------------------------------------------===//
2703
2704DiagnosedSilenceableFailure
2705transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2706 transform::TransformResults &results,
2707 transform::TransformState &state) {
2708 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2709 for (const auto &en : llvm::enumerate(getHandles())) {
2710 Value handle = en.value();
2711 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2712 SmallVector<Operation *> current =
2713 llvm::to_vector(state.getPayloadOps(handle));
2714 SmallVector<Operation *> payload;
2715 payload.reserve(numRepetitions * current.size());
2716 for (unsigned i = 0; i < numRepetitions; ++i)
2717 llvm::append_range(payload, current);
2718 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2719 } else {
2720 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2721 "expected param type");
2722 ArrayRef<Attribute> current = state.getParams(handle);
2723 SmallVector<Attribute> params;
2724 params.reserve(numRepetitions * current.size());
2725 for (unsigned i = 0; i < numRepetitions; ++i)
2726 llvm::append_range(params, current);
2727 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2728 params);
2729 }
2730 }
2731 return DiagnosedSilenceableFailure::success();
2732}
2733
2734void transform::ReplicateOp::getEffects(
2735 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2736 onlyReadsHandle(getPatternMutable(), effects);
2737 onlyReadsHandle(getHandlesMutable(), effects);
2738 producesHandle(getOperation()->getOpResults(), effects);
2739}
2740
2741//===----------------------------------------------------------------------===//
2742// SequenceOp
2743//===----------------------------------------------------------------------===//
2744
2745DiagnosedSilenceableFailure
2746transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2747 transform::TransformResults &results,
2748 transform::TransformState &state) {
2749 // Map the entry block argument to the list of operations.
2750 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2751 if (failed(mapBlockArguments(state)))
2752 return DiagnosedSilenceableFailure::definiteFailure();
2753
2754 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2755 results);
2756}
2757
2758static ParseResult parseSequenceOpOperands(
2759 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2760 Type &rootType,
2761 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2762 SmallVectorImpl<Type> &extraBindingTypes) {
2763 OpAsmParser::UnresolvedOperand rootOperand;
2764 OptionalParseResult hasRoot = parser.parseOptionalOperand(result&: rootOperand);
2765 if (!hasRoot.has_value()) {
2766 root = std::nullopt;
2767 return success();
2768 }
2769 if (failed(Result: hasRoot.value()))
2770 return failure();
2771 root = rootOperand;
2772
2773 if (succeeded(Result: parser.parseOptionalComma())) {
2774 if (failed(Result: parser.parseOperandList(result&: extraBindings)))
2775 return failure();
2776 }
2777 if (failed(Result: parser.parseColon()))
2778 return failure();
2779
2780 // The paren is truly optional.
2781 (void)parser.parseOptionalLParen();
2782
2783 if (failed(Result: parser.parseType(result&: rootType))) {
2784 return failure();
2785 }
2786
2787 if (!extraBindings.empty()) {
2788 if (parser.parseComma() || parser.parseTypeList(result&: extraBindingTypes))
2789 return failure();
2790 }
2791
2792 if (extraBindingTypes.size() != extraBindings.size()) {
2793 return parser.emitError(loc: parser.getNameLoc(),
2794 message: "expected types to be provided for all operands");
2795 }
2796
2797 // The paren is truly optional.
2798 (void)parser.parseOptionalRParen();
2799 return success();
2800}
2801
2802static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
2803 Value root, Type rootType,
2804 ValueRange extraBindings,
2805 TypeRange extraBindingTypes) {
2806 if (!root)
2807 return;
2808
2809 printer << root;
2810 bool hasExtras = !extraBindings.empty();
2811 if (hasExtras) {
2812 printer << ", ";
2813 printer.printOperands(container: extraBindings);
2814 }
2815
2816 printer << " : ";
2817 if (hasExtras)
2818 printer << "(";
2819
2820 printer << rootType;
2821 if (hasExtras)
2822 printer << ", " << llvm::interleaved(R: extraBindingTypes) << ')';
2823}
2824
2825/// Returns `true` if the given op operand may be consuming the handle value in
2826/// the Transform IR. That is, if it may have a Free effect on it.
2827static bool isValueUsePotentialConsumer(OpOperand &use) {
2828 // Conservatively assume the effect being present in absence of the interface.
2829 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2830 if (!iface)
2831 return true;
2832
2833 return isHandleConsumed(use.get(), iface);
2834}
2835
2836LogicalResult
2837checkDoubleConsume(Value value,
2838 function_ref<InFlightDiagnostic()> reportError) {
2839 OpOperand *potentialConsumer = nullptr;
2840 for (OpOperand &use : value.getUses()) {
2841 if (!isValueUsePotentialConsumer(use))
2842 continue;
2843
2844 if (!potentialConsumer) {
2845 potentialConsumer = &use;
2846 continue;
2847 }
2848
2849 InFlightDiagnostic diag = reportError()
2850 << " has more than one potential consumer";
2851 diag.attachNote(noteLoc: potentialConsumer->getOwner()->getLoc())
2852 << "used here as operand #" << potentialConsumer->getOperandNumber();
2853 diag.attachNote(noteLoc: use.getOwner()->getLoc())
2854 << "used here as operand #" << use.getOperandNumber();
2855 return diag;
2856 }
2857
2858 return success();
2859}
2860
2861LogicalResult transform::SequenceOp::verify() {
2862 assert(getBodyBlock()->getNumArguments() >= 1 &&
2863 "the number of arguments must have been verified to be more than 1 by "
2864 "PossibleTopLevelTransformOpTrait");
2865
2866 if (!getRoot() && !getExtraBindings().empty()) {
2867 return emitOpError()
2868 << "does not expect extra operands when used as top-level";
2869 }
2870
2871 // Check if a block argument has more than one consuming use.
2872 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2873 if (failed(checkDoubleConsume(arg, [this, arg]() {
2874 return (emitOpError() << "block argument #" << arg.getArgNumber());
2875 }))) {
2876 return failure();
2877 }
2878 }
2879
2880 // Check properties of the nested operations they cannot check themselves.
2881 for (Operation &child : *getBodyBlock()) {
2882 if (!isa<TransformOpInterface>(child) &&
2883 &child != &getBodyBlock()->back()) {
2884 InFlightDiagnostic diag =
2885 emitOpError()
2886 << "expected children ops to implement TransformOpInterface";
2887 diag.attachNote(child.getLoc()) << "op without interface";
2888 return diag;
2889 }
2890
2891 for (OpResult result : child.getResults()) {
2892 auto report = [&]() {
2893 return (child.emitError() << "result #" << result.getResultNumber());
2894 };
2895 if (failed(checkDoubleConsume(result, report)))
2896 return failure();
2897 }
2898 }
2899
2900 if (!getBodyBlock()->mightHaveTerminator())
2901 return emitOpError() << "expects to have a terminator in the body";
2902
2903 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2904 getOperation()->getResultTypes()) {
2905 InFlightDiagnostic diag = emitOpError()
2906 << "expects the types of the terminator operands "
2907 "to match the types of the result";
2908 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2909 return diag;
2910 }
2911 return success();
2912}
2913
2914void transform::SequenceOp::getEffects(
2915 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2916 getPotentialTopLevelEffects(effects);
2917}
2918
2919OperandRange
2920transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2921 assert(point == getBody() && "unexpected region index");
2922 if (getOperation()->getNumOperands() > 0)
2923 return getOperation()->getOperands();
2924 return OperandRange(getOperation()->operand_end(),
2925 getOperation()->operand_end());
2926}
2927
2928void transform::SequenceOp::getSuccessorRegions(
2929 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2930 if (point.isParent()) {
2931 Region *bodyRegion = &getBody();
2932 regions.emplace_back(bodyRegion, getNumOperands() != 0
2933 ? bodyRegion->getArguments()
2934 : Block::BlockArgListType());
2935 return;
2936 }
2937
2938 assert(point == getBody() && "unexpected region index");
2939 regions.emplace_back(getOperation()->getResults());
2940}
2941
2942void transform::SequenceOp::getRegionInvocationBounds(
2943 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2944 (void)operands;
2945 bounds.emplace_back(1, 1);
2946}
2947
2948void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2949 TypeRange resultTypes,
2950 FailurePropagationMode failurePropagationMode,
2951 Value root,
2952 SequenceBodyBuilderFn bodyBuilder) {
2953 build(builder, state, resultTypes, failurePropagationMode, root,
2954 /*extra_bindings=*/ValueRange());
2955 Type bbArgType = root.getType();
2956 buildSequenceBody(builder, state, bbArgType,
2957 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2958}
2959
2960void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2961 TypeRange resultTypes,
2962 FailurePropagationMode failurePropagationMode,
2963 Value root, ValueRange extraBindings,
2964 SequenceBodyBuilderArgsFn bodyBuilder) {
2965 build(builder, state, resultTypes, failurePropagationMode, root,
2966 extraBindings);
2967 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2968 bodyBuilder);
2969}
2970
2971void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2972 TypeRange resultTypes,
2973 FailurePropagationMode failurePropagationMode,
2974 Type bbArgType,
2975 SequenceBodyBuilderFn bodyBuilder) {
2976 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2977 /*extra_bindings=*/ValueRange());
2978 buildSequenceBody(builder, state, bbArgType,
2979 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2980}
2981
2982void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2983 TypeRange resultTypes,
2984 FailurePropagationMode failurePropagationMode,
2985 Type bbArgType, TypeRange extraBindingTypes,
2986 SequenceBodyBuilderArgsFn bodyBuilder) {
2987 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2988 /*extra_bindings=*/ValueRange());
2989 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2990}
2991
2992//===----------------------------------------------------------------------===//
2993// PrintOp
2994//===----------------------------------------------------------------------===//
2995
2996void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2997 StringRef name) {
2998 if (!name.empty())
2999 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3000}
3001
3002void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3003 Value target, StringRef name) {
3004 result.addOperands({target});
3005 build(builder, result, name);
3006}
3007
3008DiagnosedSilenceableFailure
3009transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3010 transform::TransformResults &results,
3011 transform::TransformState &state) {
3012 llvm::outs() << "[[[ IR printer: ";
3013 if (getName().has_value())
3014 llvm::outs() << *getName() << " ";
3015
3016 OpPrintingFlags printFlags;
3017 if (getAssumeVerified().value_or(false))
3018 printFlags.assumeVerified();
3019 if (getUseLocalScope().value_or(false))
3020 printFlags.useLocalScope();
3021 if (getSkipRegions().value_or(false))
3022 printFlags.skipRegions();
3023
3024 if (!getTarget()) {
3025 llvm::outs() << "top-level ]]]\n";
3026 state.getTopLevel()->print(llvm::outs(), printFlags);
3027 llvm::outs() << "\n";
3028 llvm::outs().flush();
3029 return DiagnosedSilenceableFailure::success();
3030 }
3031
3032 llvm::outs() << "]]]\n";
3033 for (Operation *target : state.getPayloadOps(getTarget())) {
3034 target->print(llvm::outs(), printFlags);
3035 llvm::outs() << "\n";
3036 }
3037
3038 llvm::outs().flush();
3039 return DiagnosedSilenceableFailure::success();
3040}
3041
3042void transform::PrintOp::getEffects(
3043 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3044 // We don't really care about mutability here, but `getTarget` now
3045 // unconditionally casts to a specific type before verification could run
3046 // here.
3047 if (!getTargetMutable().empty())
3048 onlyReadsHandle(getTargetMutable()[0], effects);
3049 onlyReadsPayload(effects);
3050
3051 // There is no resource for stderr file descriptor, so just declare print
3052 // writes into the default resource.
3053 effects.emplace_back(MemoryEffects::Write::get());
3054}
3055
3056//===----------------------------------------------------------------------===//
3057// VerifyOp
3058//===----------------------------------------------------------------------===//
3059
3060DiagnosedSilenceableFailure
3061transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3062 Operation *target,
3063 transform::ApplyToEachResultList &results,
3064 transform::TransformState &state) {
3065 if (failed(::mlir::verify(target))) {
3066 DiagnosedDefiniteFailure diag = emitDefiniteFailure()
3067 << "failed to verify payload op";
3068 diag.attachNote(target->getLoc()) << "payload op";
3069 return diag;
3070 }
3071 return DiagnosedSilenceableFailure::success();
3072}
3073
3074void transform::VerifyOp::getEffects(
3075 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3076 transform::onlyReadsHandle(getTargetMutable(), effects);
3077}
3078
3079//===----------------------------------------------------------------------===//
3080// YieldOp
3081//===----------------------------------------------------------------------===//
3082
3083void transform::YieldOp::getEffects(
3084 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3085 onlyReadsHandle(getOperandsMutable(), effects);
3086}
3087

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Transform/IR/TransformOps.cpp