1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformOps.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/IR/TransformTypes.h"
17#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
18#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/Dominance.h"
22#include "mlir/IR/OperationSupport.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/IR/Verifier.h"
25#include "mlir/Interfaces/CallInterfaces.h"
26#include "mlir/Interfaces/ControlFlowInterfaces.h"
27#include "mlir/Interfaces/FunctionImplementation.h"
28#include "mlir/Interfaces/FunctionInterfaces.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Pass/PassRegistry.h"
32#include "mlir/Transforms/CSE.h"
33#include "mlir/Transforms/DialectConversion.h"
34#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
36#include "llvm/ADT/DenseSet.h"
37#include "llvm/ADT/STLExtras.h"
38#include "llvm/ADT/ScopeExit.h"
39#include "llvm/ADT/SmallPtrSet.h"
40#include "llvm/ADT/TypeSwitch.h"
41#include "llvm/Support/Debug.h"
42#include "llvm/Support/ErrorHandling.h"
43#include <optional>
44
45#define DEBUG_TYPE "transform-dialect"
46#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
47
48#define DEBUG_TYPE_MATCHER "transform-matcher"
49#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
50#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
51
52using namespace mlir;
53
// Custom assembly-format hooks. They are declared before the ODS-generated op
// definitions are included below (TransformOps.cpp.inc), presumably because the
// generated parsers/printers call them. Being `static`, they have internal
// linkage and are defined elsewhere in this translation unit.

/// Parses an optional `root` operand (with its type) followed by extra binding
/// operands and their types; the exact syntax lives with the definition.
static ParseResult parseSequenceOpOperands(
    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
    Type &rootType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
    SmallVectorImpl<Type> &extraBindingTypes);
/// Printer counterpart of parseSequenceOpOperands.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
                                    Value root, Type rootType,
                                    ValueRange extraBindings,
                                    TypeRange extraBindingTypes);
/// Prints the `matchers` and `actions` symbol arrays (presumably as
/// matcher/action pairs — confirm against the definition).
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
                                     ArrayAttr matchers, ArrayAttr actions);
/// Parser counterpart of printForeachMatchSymbols; fills `matchers` and
/// `actions`.
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
                                            ArrayAttr &matchers,
                                            ArrayAttr &actions);
68
69/// Helper function to check if the given transform op is contained in (or
70/// equal to) the given payload target op. In that case, an error is returned.
71/// Transforming transform IR that is currently executing is generally unsafe.
72static DiagnosedSilenceableFailure
73ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
74 Operation *payload) {
75 Operation *transformAncestor = transform.getOperation();
76 while (transformAncestor) {
77 if (transformAncestor == payload) {
78 DiagnosedDefiniteFailure diag =
79 transform.emitDefiniteFailure()
80 << "cannot apply transform to itself (or one of its ancestors)";
81 diag.attachNote(loc: payload->getLoc()) << "target payload op";
82 return diag;
83 }
84 transformAncestor = transformAncestor->getParentOp();
85 }
86 return DiagnosedSilenceableFailure::success();
87}
88
89#define GET_OP_CLASSES
90#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
91
92//===----------------------------------------------------------------------===//
93// AlternativesOp
94//===----------------------------------------------------------------------===//
95
96OperandRange
97transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
98 if (!point.isParent() && getOperation()->getNumOperands() == 1)
99 return getOperation()->getOperands();
100 return OperandRange(getOperation()->operand_end(),
101 getOperation()->operand_end());
102}
103
104void transform::AlternativesOp::getSuccessorRegions(
105 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
106 for (Region &alternative : llvm::drop_begin(
107 getAlternatives(),
108 point.isParent() ? 0
109 : point.getRegionOrNull()->getRegionNumber() + 1)) {
110 regions.emplace_back(&alternative, !getOperands().empty()
111 ? alternative.getArguments()
112 : Block::BlockArgListType());
113 }
114 if (!point.isParent())
115 regions.emplace_back(getOperation()->getResults());
116}
117
118void transform::AlternativesOp::getRegionInvocationBounds(
119 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
120 (void)operands;
121 // The region corresponding to the first alternative is always executed, the
122 // remaining may or may not be executed.
123 bounds.reserve(getNumRegions());
124 bounds.emplace_back(1, 1);
125 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
126}
127
128static void forwardEmptyOperands(Block *block, transform::TransformState &state,
129 transform::TransformResults &results) {
130 for (const auto &res : block->getParentOp()->getOpResults())
131 results.set(value: res, ops: {});
132}
133
/// Tries the alternative regions in order until one of them succeeds.
///
/// The scope ops are the payload of the optional scope handle, or the
/// top-level payload op when no scope handle is given. Each scope op must not
/// contain this transform op and must be isolated-from-above; otherwise a
/// definite failure is emitted.
///
/// For each alternative:
///   - the scope ops are cloned;
///   - the region's first block argument is mapped to the clones;
///   - the region's transform ops are applied to the clones in order.
///
/// A silenceable failure abandons the alternative, its clones are erased, and
/// the next alternative is tried. A definite failure, or a failure to map the
/// block argument, aborts immediately.
///
/// For the first alternative whose transforms all succeed:
///   - each clone is spliced in before its original;
///   - the original is replaced through an IRRewriter with a TrackingListener;
///   - the terminator's operands are forwarded as this op's results.
///
/// If every alternative fails, a silenceable "all alternatives failed" error
/// is returned.
DiagnosedSilenceableFailure
transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Determine the scope ops: the payload of the scope handle if present,
  // otherwise the top-level payload op.
  SmallVector<Operation *> originals;
  if (Value scopeHandle = getScope())
    llvm::append_range(originals, state.getPayloadOps(scopeHandle));
  else
    originals.push_back(state.getTopLevel());

  // Validate every scope op before doing any work.
  for (Operation *original : originals) {
    // The scope gets replaced below, so it must not enclose the transform IR
    // that is currently being interpreted.
    if (original->isAncestor(getOperation())) {
      auto diag = emitDefiniteFailure()
                  << "scope must not contain the transforms being applied";
      diag.attachNote(original->getLoc()) << "scope";
      return diag;
    }
    // NOTE(review): presumably required so that the clones cannot reference
    // values defined outside of them — confirm the rationale.
    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      auto diag = emitDefiniteFailure()
                  << "only isolated-from-above ops can be alternative scopes";
      diag.attachNote(original->getLoc()) << "scope";
      return diag;
    }
  }

  for (Region &reg : getAlternatives()) {
    // Clone the scope operations and make the transforms in this alternative
    // region apply to them by virtue of mapping the block argument (the only
    // visible handle) to the cloned scope operations. This effectively prevents
    // the transformation from accessing any IR outside the scope.
    auto scope = state.make_region_scope(reg);
    auto clones = llvm::to_vector(
        llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
    // By default the clones are discarded when this iteration ends (i.e. when
    // the alternative fails or we bail out).
    auto deleteClones = llvm::make_scope_exit([&] {
      for (Operation *clone : clones)
        clone->erase();
    });
    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
      return DiagnosedSilenceableFailure::definiteFailure();

    bool failed = false;
    for (Operation &transform : reg.front().without_terminator()) {
      DiagnosedSilenceableFailure result =
          state.applyTransform(cast<TransformOpInterface>(transform));
      // A silenceable failure only disqualifies this alternative. Its message
      // is printed under LLVM_DEBUG and otherwise dropped.
      if (result.isSilenceableFailure()) {
        LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                          << "\n");
        failed = true;
        break;
      }

      // Any remaining failure is definite and aborts the whole op. The local
      // `failed` flag shadows mlir::failed, hence the qualified call.
      if (::mlir::failed(result.silence()))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // If all operations in the given alternative succeeded, no need to consider
    // the rest. Replace the original scoping operation with the clone on which
    // the transformations were performed.
    if (!failed) {
      // We will be using the clones, so cancel their scheduled deletion.
      deleteClones.release();
      // NOTE(review): this local `rewriter` shadows the `rewriter` parameter.
      TrackingListener listener(state, *this);
      IRRewriter rewriter(getContext(), &listener);
      for (const auto &kvp : llvm::zip(originals, clones)) {
        Operation *original = std::get<0>(kvp);
        Operation *clone = std::get<1>(kvp);
        // Splice the unlinked clone into the block right before the original
        // (bypassing the rewriter), then replace the original with it.
        original->getBlock()->getOperations().insert(original->getIterator(),
                                                     clone);
        rewriter.replaceOp(original, clone->getResults());
      }
      detail::forwardTerminatorOperands(&reg.front(), state, results);
      return DiagnosedSilenceableFailure::success();
    }
  }
  return emitSilenceableError() << "all alternatives failed";
}
210
211void transform::AlternativesOp::getEffects(
212 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
213 consumesHandle(getOperands(), effects);
214 producesHandle(getResults(), effects);
215 for (Region *region : getRegions()) {
216 if (!region->empty())
217 producesHandle(region->front().getArguments(), effects);
218 }
219 modifiesPayload(effects);
220}
221
222LogicalResult transform::AlternativesOp::verify() {
223 for (Region &alternative : getAlternatives()) {
224 Block &block = alternative.front();
225 Operation *terminator = block.getTerminator();
226 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
227 InFlightDiagnostic diag = emitOpError()
228 << "expects terminator operands to have the "
229 "same type as results of the operation";
230 diag.attachNote(terminator->getLoc()) << "terminator";
231 return diag;
232 }
233 }
234
235 return success();
236}
237
238//===----------------------------------------------------------------------===//
239// AnnotateOp
240//===----------------------------------------------------------------------===//
241
242DiagnosedSilenceableFailure
243transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
244 transform::TransformResults &results,
245 transform::TransformState &state) {
246 SmallVector<Operation *> targets =
247 llvm::to_vector(state.getPayloadOps(getTarget()));
248
249 Attribute attr = UnitAttr::get(getContext());
250 if (auto paramH = getParam()) {
251 ArrayRef<Attribute> params = state.getParams(paramH);
252 if (params.size() != 1) {
253 if (targets.size() != params.size()) {
254 return emitSilenceableError()
255 << "parameter and target have different payload lengths ("
256 << params.size() << " vs " << targets.size() << ")";
257 }
258 for (auto &&[target, attr] : llvm::zip_equal(targets, params))
259 target->setAttr(getName(), attr);
260 return DiagnosedSilenceableFailure::success();
261 }
262 attr = params[0];
263 }
264 for (auto *target : targets)
265 target->setAttr(getName(), attr);
266 return DiagnosedSilenceableFailure::success();
267}
268
269void transform::AnnotateOp::getEffects(
270 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
271 onlyReadsHandle(getTarget(), effects);
272 onlyReadsHandle(getParam(), effects);
273 modifiesPayload(effects);
274}
275
276//===----------------------------------------------------------------------===//
277// ApplyCommonSubexpressionEliminationOp
278//===----------------------------------------------------------------------===//
279
280DiagnosedSilenceableFailure
281transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
282 transform::TransformRewriter &rewriter, Operation *target,
283 ApplyToEachResultList &results, transform::TransformState &state) {
284 // Make sure that this transform is not applied to itself. Modifying the
285 // transform IR while it is being interpreted is generally dangerous.
286 DiagnosedSilenceableFailure payloadCheck =
287 ensurePayloadIsSeparateFromTransform(*this, target);
288 if (!payloadCheck.succeeded())
289 return payloadCheck;
290
291 DominanceInfo domInfo;
292 mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
293 return DiagnosedSilenceableFailure::success();
294}
295
296void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
297 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
298 transform::onlyReadsHandle(getTarget(), effects);
299 transform::modifiesPayload(effects);
300}
301
302//===----------------------------------------------------------------------===//
303// ApplyDeadCodeEliminationOp
304//===----------------------------------------------------------------------===//
305
306DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
307 transform::TransformRewriter &rewriter, Operation *target,
308 ApplyToEachResultList &results, transform::TransformState &state) {
309 // Make sure that this transform is not applied to itself. Modifying the
310 // transform IR while it is being interpreted is generally dangerous.
311 DiagnosedSilenceableFailure payloadCheck =
312 ensurePayloadIsSeparateFromTransform(*this, target);
313 if (!payloadCheck.succeeded())
314 return payloadCheck;
315
316 // Maintain a worklist of potentially dead ops.
317 SetVector<Operation *> worklist;
318
319 // Helper function that adds all defining ops of used values (operands and
320 // operands of nested ops).
321 auto addDefiningOpsToWorklist = [&](Operation *op) {
322 op->walk([&](Operation *op) {
323 for (Value v : op->getOperands())
324 if (Operation *defOp = v.getDefiningOp())
325 if (target->isProperAncestor(defOp))
326 worklist.insert(defOp);
327 });
328 };
329
330 // Helper function that erases an op.
331 auto eraseOp = [&](Operation *op) {
332 // Remove op and nested ops from the worklist.
333 op->walk([&](Operation *op) {
334 const auto *it = llvm::find(worklist, op);
335 if (it != worklist.end())
336 worklist.erase(it);
337 });
338 rewriter.eraseOp(op);
339 };
340
341 // Initial walk over the IR.
342 target->walk<WalkOrder::PostOrder>([&](Operation *op) {
343 if (op != target && isOpTriviallyDead(op)) {
344 addDefiningOpsToWorklist(op);
345 eraseOp(op);
346 }
347 });
348
349 // Erase all ops that have become dead.
350 while (!worklist.empty()) {
351 Operation *op = worklist.pop_back_val();
352 if (!isOpTriviallyDead(op))
353 continue;
354 addDefiningOpsToWorklist(op);
355 eraseOp(op);
356 }
357
358 return DiagnosedSilenceableFailure::success();
359}
360
361void transform::ApplyDeadCodeEliminationOp::getEffects(
362 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
363 transform::onlyReadsHandle(getTarget(), effects);
364 transform::modifiesPayload(effects);
365}
366
367//===----------------------------------------------------------------------===//
368// ApplyPatternsOp
369//===----------------------------------------------------------------------===//
370
/// Applies the patterns described by the ops in the body region to `target`
/// using the greedy pattern rewrite driver, optionally interleaved with CSE.
///
/// - Isolated-from-above target: patterns are applied to the target itself.
/// - Otherwise: patterns are applied to the list of ops properly nested in it.
///
/// When CSE is requested, pattern application and CSE are repeated until CSE
/// no longer changes the IR, bounded by kNumMaxIterations rounds. Hitting the
/// bound is a definite failure. A failed greedy rewrite (typically
/// non-convergence) is a silenceable failure on the target.
DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous. Even
  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
  // performs many additional simplifications such as dead code elimination.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  // Gather all specified patterns.
  MLIRContext *ctx = target->getContext();
  RewritePatternSet patterns(ctx);
  if (!getRegion().empty()) {
    for (Operation &op : getRegion().front()) {
      // verify() guarantees that every child implements the interface.
      cast<transform::PatternDescriptorOpInterface>(&op)
          .populatePatternsWithState(patterns, state);
    }
  }

  // Configure the GreedyPatternRewriteDriver.
  GreedyRewriteConfig config;
  // Forward rewrite notifications to the listener of the transform rewriter
  // (presumably its handle-tracking listener).
  config.listener =
      static_cast<RewriterBase::Listener *>(rewriter.getListener());
  FrozenRewritePatternSet frozenPatterns(std::move(patterns));

  // An attribute value of -1 (all bits set) means "no limit".
  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
                             ? GreedyRewriteConfig::kNoLimit
                             : getMaxIterations();
  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
                              ? GreedyRewriteConfig::kNoLimit
                              : getMaxNumRewrites();

  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
  // was requested, apply the greedy pattern rewrite only once. (The greedy
  // pattern rewrite driver already iterates to a fixpoint internally.)
  bool cseChanged = false;
  // One or two iterations should be sufficient. Stop iterating after a certain
  // threshold to make debugging easier.
  static const int64_t kNumMaxIterations = 50;
  int64_t iteration = 0;
  do {
    LogicalResult result = failure();
    if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      // Op is isolated from above. Apply patterns and also perform region
      // simplification.
      result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
    } else {
      // Manually gather list of ops because the other
      // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
      // from above. This way, patterns can be applied to ops that are not
      // isolated from above. Regions are not being simplified. Furthermore,
      // only a single greedy rewrite iteration is performed.
      SmallVector<Operation *> ops;
      target->walk([&](Operation *nestedOp) {
        if (target != nestedOp)
          ops.push_back(nestedOp);
      });
      result = applyOpPatternsAndFold(ops, frozenPatterns, config);
    }

    // A failure typically indicates that the pattern application did not
    // converge.
    if (failed(result)) {
      return emitSilenceableFailure(target)
             << "greedy pattern application failed";
    }

    // `cseChanged` is only updated when CSE runs; otherwise it stays false and
    // the loop runs exactly once.
    if (getApplyCse()) {
      DominanceInfo domInfo;
      mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
                                          &cseChanged);
    }
  } while (cseChanged && ++iteration < kNumMaxIterations);

  // Reaching the bound means CSE still changed the IR in the last round.
  if (iteration == kNumMaxIterations)
    return emitDefiniteFailure() << "fixpoint iteration did not converge";

  return DiagnosedSilenceableFailure::success();
}
453
454LogicalResult transform::ApplyPatternsOp::verify() {
455 if (!getRegion().empty()) {
456 for (Operation &op : getRegion().front()) {
457 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
458 InFlightDiagnostic diag = emitOpError()
459 << "expected children ops to implement "
460 "PatternDescriptorOpInterface";
461 diag.attachNote(op.getLoc()) << "op without interface";
462 return diag;
463 }
464 }
465 }
466 return success();
467}
468
469void transform::ApplyPatternsOp::getEffects(
470 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
471 transform::onlyReadsHandle(getTarget(), effects);
472 transform::modifiesPayload(effects);
473}
474
475void transform::ApplyPatternsOp::build(
476 OpBuilder &builder, OperationState &result, Value target,
477 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
478 result.addOperands(target);
479
480 OpBuilder::InsertionGuard g(builder);
481 Region *region = result.addRegion();
482 builder.createBlock(region);
483 if (bodyBuilder)
484 bodyBuilder(builder, result.location);
485}
486
487//===----------------------------------------------------------------------===//
488// ApplyCanonicalizationPatternsOp
489//===----------------------------------------------------------------------===//
490
491void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
492 RewritePatternSet &patterns) {
493 MLIRContext *ctx = patterns.getContext();
494 for (Dialect *dialect : ctx->getLoadedDialects())
495 dialect->getCanonicalizationPatterns(patterns);
496 for (RegisteredOperationName op : ctx->getRegisteredOperations())
497 op.getCanonicalizationPatterns(patterns, ctx);
498}
499
500//===----------------------------------------------------------------------===//
501// ApplyConversionPatternsOp
502//===----------------------------------------------------------------------===//
503
/// Runs a dialect conversion on every payload op associated with the target
/// handle.
///
/// Conversion target:
///   - built from the optional legal/illegal op and dialect name lists;
///   - extended by rules contributed by each pattern descriptor.
///
/// Type converters:
///   - each descriptor either supplies its own type converter or uses the
///     default one from the default type converter builder;
///   - a descriptor with neither is a definite failure.
///
/// Conversion mode: partial when getPartialConversion() is set, full
/// otherwise.
///
/// Failure handling:
///   - when handles are preserved, a tracking listener observes the
///     conversion;
///   - the first target whose conversion (or tracking) fails produces a
///     silenceable error.
DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  MLIRContext *ctx = getContext();

  // Instantiate the default type converter if a type converter builder is
  // specified.
  std::unique_ptr<TypeConverter> defaultTypeConverter;
  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
      getDefaultTypeConverter();
  if (typeConverterBuilder)
    defaultTypeConverter = typeConverterBuilder.getTypeConverter();

  // Configure conversion target.
  ConversionTarget conversionTarget(*getContext());
  if (getLegalOps())
    for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
      conversionTarget.addLegalOp(
          OperationName(cast<StringAttr>(attr).getValue(), ctx));
  if (getIllegalOps())
    for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
      conversionTarget.addIllegalOp(
          OperationName(cast<StringAttr>(attr).getValue(), ctx));
  if (getLegalDialects())
    for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
      conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
  if (getIllegalDialects())
    for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
      conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());

  // Gather all specified patterns.
  RewritePatternSet patterns(ctx);
  // Need to keep the converters alive until after pattern application because
  // the patterns take a reference to an object that would otherwise get out of
  // scope.
  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
  if (!getPatterns().empty()) {
    for (Operation &op : getPatterns().front()) {
      // verify() guarantees that every child implements the interface.
      auto descriptor =
          cast<transform::ConversionPatternDescriptorOpInterface>(&op);

      // Check if this pattern set specifies a type converter.
      std::unique_ptr<TypeConverter> typeConverter =
          descriptor.getTypeConverter();
      TypeConverter *converter = nullptr;
      if (typeConverter) {
        keepAliveConverters.emplace_back(std::move(typeConverter));
        converter = keepAliveConverters.back().get();
      } else {
        // No type converter specified: Use the default type converter.
        if (!defaultTypeConverter) {
          auto diag = emitDefiniteFailure()
                      << "pattern descriptor does not specify type "
                         "converter and apply_conversion_patterns op has "
                         "no default type converter";
          diag.attachNote(op.getLoc()) << "pattern descriptor op";
          return diag;
        }
        converter = defaultTypeConverter.get();
      }

      // Add descriptor-specific updates to the conversion target, which may
      // depend on the final type converter. In structural converters, the
      // legality of types dictates the dynamic legality of an operation.
      descriptor.populateConversionTargetRules(*converter, conversionTarget);

      descriptor.populatePatterns(*converter, patterns);
    }
  }

  // Attach a tracking listener if handles should be preserved. We configure the
  // listener to allow op replacements with different names, as conversion
  // patterns typically replace ops with replacement ops that have a different
  // name.
  TrackingListenerConfig trackingConfig;
  trackingConfig.requireMatchingReplacementOpName = false;
  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
  ConversionConfig conversionConfig;
  if (getPreserveHandles())
    conversionConfig.listener = &trackingListener;

  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Make sure that this transform is not applied to itself. Modifying the
    // transform IR while it is being interpreted is generally dangerous.
    DiagnosedSilenceableFailure payloadCheck =
        ensurePayloadIsSeparateFromTransform(*this, target);
    if (!payloadCheck.succeeded())
      return payloadCheck;

    LogicalResult status = failure();
    if (getPartialConversion()) {
      status = applyPartialConversion(target, conversionTarget, frozenPatterns,
                                      conversionConfig);
    } else {
      status = applyFullConversion(target, conversionTarget, frozenPatterns,
                                   conversionConfig);
    }

    // Check dialect conversion state.
    DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
    if (failed(status)) {
      diag = emitSilenceableError() << "dialect conversion failed";
      diag.attachNote(target->getLoc()) << "target op";
    }

    // Check tracking listener error state. The listener is only attached to
    // the conversion when handles are preserved; this check runs regardless.
    DiagnosedSilenceableFailure trackingFailure =
        trackingListener.checkAndResetError();
    if (!trackingFailure.succeeded()) {
      if (diag.succeeded()) {
        // Tracking failure is the only failure.
        return trackingFailure;
      } else {
        // Fold the tracking error into `diag` as a note and mark the original
        // failure as handled.
        diag.attachNote() << "tracking listener also failed: "
                          << trackingFailure.getMessage();
        (void)trackingFailure.silence();
      }
    }

    if (!diag.succeeded())
      return diag;
  }

  return DiagnosedSilenceableFailure::success();
}
630
631LogicalResult transform::ApplyConversionPatternsOp::verify() {
632 if (getNumRegions() != 1 && getNumRegions() != 2)
633 return emitOpError() << "expected 1 or 2 regions";
634 if (!getPatterns().empty()) {
635 for (Operation &op : getPatterns().front()) {
636 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
637 InFlightDiagnostic diag =
638 emitOpError() << "expected pattern children ops to implement "
639 "ConversionPatternDescriptorOpInterface";
640 diag.attachNote(op.getLoc()) << "op without interface";
641 return diag;
642 }
643 }
644 }
645 if (getNumRegions() == 2) {
646 Region &typeConverterRegion = getRegion(1);
647 if (!llvm::hasSingleElement(typeConverterRegion.front()))
648 return emitOpError()
649 << "expected exactly one op in default type converter region";
650 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
651 &typeConverterRegion.front().front());
652 if (!typeConverterOp) {
653 InFlightDiagnostic diag = emitOpError()
654 << "expected default converter child op to "
655 "implement TypeConverterBuilderOpInterface";
656 diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
657 return diag;
658 }
659 // Check default type converter type.
660 if (!getPatterns().empty()) {
661 for (Operation &op : getPatterns().front()) {
662 auto descriptor =
663 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
664 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
665 return failure();
666 }
667 }
668 }
669 return success();
670}
671
672void transform::ApplyConversionPatternsOp::getEffects(
673 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
674 if (!getPreserveHandles()) {
675 transform::consumesHandle(getTarget(), effects);
676 } else {
677 transform::onlyReadsHandle(getTarget(), effects);
678 }
679 transform::modifiesPayload(effects);
680}
681
682void transform::ApplyConversionPatternsOp::build(
683 OpBuilder &builder, OperationState &result, Value target,
684 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
685 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
686 result.addOperands(target);
687
688 {
689 OpBuilder::InsertionGuard g(builder);
690 Region *region1 = result.addRegion();
691 builder.createBlock(region1);
692 if (patternsBodyBuilder)
693 patternsBodyBuilder(builder, result.location);
694 }
695 {
696 OpBuilder::InsertionGuard g(builder);
697 Region *region2 = result.addRegion();
698 builder.createBlock(region2);
699 if (typeConverterBodyBuilder)
700 typeConverterBodyBuilder(builder, result.location);
701 }
702}
703
704//===----------------------------------------------------------------------===//
705// ApplyToLLVMConversionPatternsOp
706//===----------------------------------------------------------------------===//
707
708void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
709 TypeConverter &typeConverter, RewritePatternSet &patterns) {
710 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
711 assert(dialect && "expected that dialect is loaded");
712 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
713 // ConversionTarget is currently ignored because the enclosing
714 // apply_conversion_patterns op sets up its own ConversionTarget.
715 ConversionTarget target(*getContext());
716 iface->populateConvertToLLVMConversionPatterns(
717 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
718}
719
720LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
721 transform::TypeConverterBuilderOpInterface builder) {
722 if (builder.getTypeConverterType() != "LLVMTypeConverter")
723 return emitOpError("expected LLVMTypeConverter");
724 return success();
725}
726
727LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
728 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
729 if (!dialect)
730 return emitOpError("unknown dialect or dialect not loaded: ")
731 << getDialectName();
732 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
733 if (!iface)
734 return emitOpError(
735 "dialect does not implement ConvertToLLVMPatternInterface or "
736 "extension was not loaded: ")
737 << getDialectName();
738 return success();
739}
740
741//===----------------------------------------------------------------------===//
742// ApplyLoopInvariantCodeMotionOp
743//===----------------------------------------------------------------------===//
744
745DiagnosedSilenceableFailure
746transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
747 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
748 transform::ApplyToEachResultList &results,
749 transform::TransformState &state) {
750 // Currently, LICM does not remove operations, so we don't need tracking.
751 // If this ever changes, add a LICM entry point that takes a rewriter.
752 moveLoopInvariantCode(target);
753 return DiagnosedSilenceableFailure::success();
754}
755
756void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
757 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
758 transform::onlyReadsHandle(getTarget(), effects);
759 transform::modifiesPayload(effects);
760}
761
762//===----------------------------------------------------------------------===//
763// ApplyRegisteredPassOp
764//===----------------------------------------------------------------------===//
765
766DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
767 transform::TransformRewriter &rewriter, Operation *target,
768 ApplyToEachResultList &results, transform::TransformState &state) {
769 // Make sure that this transform is not applied to itself. Modifying the
770 // transform IR while it is being interpreted is generally dangerous. Even
771 // more so when applying passes because they may perform a wide range of IR
772 // modifications.
773 DiagnosedSilenceableFailure payloadCheck =
774 ensurePayloadIsSeparateFromTransform(*this, target);
775 if (!payloadCheck.succeeded())
776 return payloadCheck;
777
778 // Get pass or pass pipeline from registry.
779 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
780 if (!info)
781 info = PassInfo::lookup(getPassName());
782 if (!info)
783 return emitDefiniteFailure()
784 << "unknown pass or pass pipeline: " << getPassName();
785
786 // Create pass manager and run the pass or pass pipeline.
787 PassManager pm(getContext());
788 if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
789 emitError(msg);
790 return failure();
791 }))) {
792 return emitDefiniteFailure()
793 << "failed to add pass or pass pipeline to pipeline: "
794 << getPassName();
795 }
796 if (failed(pm.run(target))) {
797 auto diag = emitSilenceableError() << "pass pipeline failed";
798 diag.attachNote(target->getLoc()) << "target op";
799 return diag;
800 }
801
802 results.push_back(target);
803 return DiagnosedSilenceableFailure::success();
804}
805
806//===----------------------------------------------------------------------===//
807// CastOp
808//===----------------------------------------------------------------------===//
809
810DiagnosedSilenceableFailure
811transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
812 Operation *target, ApplyToEachResultList &results,
813 transform::TransformState &state) {
814 results.push_back(target);
815 return DiagnosedSilenceableFailure::success();
816}
817
818void transform::CastOp::getEffects(
819 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
820 onlyReadsPayload(effects);
821 onlyReadsHandle(getInput(), effects);
822 producesHandle(getOutput(), effects);
823}
824
825bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
826 assert(inputs.size() == 1 && "expected one input");
827 assert(outputs.size() == 1 && "expected one output");
828 return llvm::all_of(
829 std::initializer_list<Type>{inputs.front(), outputs.front()},
830 llvm::IsaPred<transform::TransformHandleTypeInterface>);
831}
832
833//===----------------------------------------------------------------------===//
834// CollectMatchingOp
835//===----------------------------------------------------------------------===//
836
837/// Applies matcher operations from the given `block` assigning `op` as the
838/// payload of the block's first argument. Updates `state` accordingly. If any
839/// of the matcher produces a silenceable failure, discards it (printing the
840/// content to the debug output stream) and returns failure. If any of the
841/// matchers produces a definite failure, reports it and returns failure. If all
842/// matchers in the block succeed, populates `mappings` with the payload
843/// entities associated with the block terminator operands.
844static DiagnosedSilenceableFailure
845matchBlock(Block &block, Operation *op, transform::TransformState &state,
846 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
847 assert(block.getParent() && "cannot match using a detached block");
848 auto matchScope = state.make_region_scope(region&: *block.getParent());
849 if (failed(result: state.mapBlockArgument(argument: block.getArgument(i: 0), values: {op})))
850 return DiagnosedSilenceableFailure::definiteFailure();
851
852 for (Operation &match : block.without_terminator()) {
853 if (!isa<transform::MatchOpInterface>(Val: match)) {
854 return emitDefiniteFailure(loc: match.getLoc())
855 << "expected operations in the match part to "
856 "implement MatchOpInterface";
857 }
858 DiagnosedSilenceableFailure diag =
859 state.applyTransform(transform: cast<transform::TransformOpInterface>(match));
860 if (diag.succeeded())
861 continue;
862
863 return diag;
864 }
865
866 // Remember the values mapped to the terminator operands so we can
867 // forward them to the action.
868 ValueRange yieldedValues = block.getTerminator()->getOperands();
869 transform::detail::prepareValueMappings(mappings, yieldedValues, state);
870 return DiagnosedSilenceableFailure::success();
871}
872
873/// Returns `true` if both types implement one of the interfaces provided as
874/// template parameters.
875template <typename... Tys>
876static bool implementSameInterface(Type t1, Type t2) {
877 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
878}
879
880/// Returns `true` if both types implement one of the transform dialect
881/// interfaces.
882static bool implementSameTransformInterface(Type t1, Type t2) {
883 return implementSameInterface<transform::TransformHandleTypeInterface,
884 transform::TransformParamTypeInterface,
885 transform::TransformValueHandleTypeInterface>(
886 t1, t2);
887}
888
889//===----------------------------------------------------------------------===//
890// CollectMatchingOp
891//===----------------------------------------------------------------------===//
892
893DiagnosedSilenceableFailure
894transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
895 transform::TransformResults &results,
896 transform::TransformState &state) {
897 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
898 getOperation(), getMatcher());
899 if (matcher.isExternal()) {
900 return emitDefiniteFailure()
901 << "unresolved external symbol " << getMatcher();
902 }
903
904 SmallVector<SmallVector<MappedValue>, 2> rawResults;
905 rawResults.resize(getOperation()->getNumResults());
906 std::optional<DiagnosedSilenceableFailure> maybeFailure;
907 for (Operation *root : state.getPayloadOps(getRoot())) {
908 WalkResult walkResult = root->walk([&](Operation *op) {
909 DEBUG_MATCHER({
910 DBGS_MATCHER() << "matching ";
911 op->print(llvm::dbgs(),
912 OpPrintingFlags().assumeVerified().skipRegions());
913 llvm::dbgs() << " @" << op << "\n";
914 });
915
916 // Try matching.
917 SmallVector<SmallVector<MappedValue>> mappings;
918 DiagnosedSilenceableFailure diag =
919 matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
920 if (diag.isDefiniteFailure())
921 return WalkResult::interrupt();
922 if (diag.isSilenceableFailure()) {
923 DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
924 << " failed: " << diag.getMessage());
925 return WalkResult::advance();
926 }
927
928 // If succeeded, collect results.
929 for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
930 if (mapping.size() != 1) {
931 maybeFailure.emplace(emitSilenceableError()
932 << "result #" << i << ", associated with "
933 << mapping.size()
934 << " payload objects, expected 1");
935 return WalkResult::interrupt();
936 }
937 rawResults[i].push_back(mapping[0]);
938 }
939 return WalkResult::advance();
940 });
941 if (walkResult.wasInterrupted())
942 return std::move(*maybeFailure);
943 assert(!maybeFailure && "failure set but the walk was not interrupted");
944
945 for (auto &&[opResult, rawResult] :
946 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
947 results.setMappedValues(opResult, rawResult);
948 }
949 }
950 return DiagnosedSilenceableFailure::success();
951}
952
953void transform::CollectMatchingOp::getEffects(
954 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
955 onlyReadsHandle(getRoot(), effects);
956 producesHandle(getResults(), effects);
957 onlyReadsPayload(effects);
958}
959
960LogicalResult transform::CollectMatchingOp::verifySymbolUses(
961 SymbolTableCollection &symbolTable) {
962 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
963 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
964 if (!matcherSymbol ||
965 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
966 return emitError() << "unresolved matcher symbol " << getMatcher();
967
968 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
969 if (argumentTypes.size() != 1 ||
970 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
971 return emitError()
972 << "expected the matcher to take one operation handle argument";
973 }
974 if (!matcherSymbol.getArgAttr(
975 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
976 return emitError() << "expected the matcher argument to be marked readonly";
977 }
978
979 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
980 if (resultTypes.size() != getOperation()->getNumResults()) {
981 return emitError()
982 << "expected the matcher to yield as many values as op has results ("
983 << getOperation()->getNumResults() << "), got "
984 << resultTypes.size();
985 }
986
987 for (auto &&[i, matcherType, resultType] :
988 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
989 if (implementSameTransformInterface(matcherType, resultType))
990 continue;
991
992 return emitError()
993 << "mismatching type interfaces for matcher result and op result #"
994 << i;
995 }
996
997 return success();
998}
999
1000//===----------------------------------------------------------------------===//
1001// ForeachMatchOp
1002//===----------------------------------------------------------------------===//
1003
/// Walks the payload nested under each op associated with the `root` handle
/// and, for every visited payload op, tries the matcher/action pairs in the
/// order they are listed. The action paired with the first matcher that
/// succeeds is applied, with its block arguments mapped to the values the
/// matcher yielded. The remaining pairs are not tried for that op. The root op
/// itself is visited only when `restrict_root` is set.
///
/// Silenceable failures of action transforms are silenced and accumulated into
/// a single silenceable error ("actions failed") that is returned after the
/// whole walk. The remaining transforms of the failing action still run.
/// Definite failures from matchers, argument mapping, or action transforms
/// abort the walk with a definite failure.
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Resolve every matcher/action symbol up front, before walking any payload.
  SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
      matchActionPairs;
  matchActionPairs.reserve(getMatchers().size());
  SymbolTableCollection symbolTable;
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    auto matcherSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(matcher));
    auto actionSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(action));
    assert(matcherSymbol && actionSymbol &&
           "unresolved symbols not caught by the verifier");

    // Declarations without a body cannot be interpreted.
    if (matcherSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << matcher;
    if (actionSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << action;

    matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
  }

  // Accumulates notes for every failed action; stays a success if none fails.
  DiagnosedSilenceableFailure overallDiag =
      DiagnosedSilenceableFailure::success();
  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      // If getRestrictRoot is not present, skip over the root op itself so we
      // don't invalidate it.
      if (!getRestrictRoot() && op == root)
        return WalkResult::advance();

      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      // Try all the match/action pairs until the first successful match.
      for (auto [matcher, action] : matchActionPairs) {
        SmallVector<SmallVector<MappedValue>> mappings;
        DiagnosedSilenceableFailure diag =
            matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
        if (diag.isDefiniteFailure())
          return WalkResult::interrupt();
        // A silenceable failure means "no match"; move on to the next pair.
        if (diag.isSilenceableFailure()) {
          DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                       << " failed: " << diag.getMessage());
          continue;
        }

        // Bind the values yielded by the matcher to the action's arguments.
        auto scope = state.make_region_scope(action.getFunctionBody());
        for (auto &&[arg, map] : llvm::zip_equal(
                 action.getFunctionBody().front().getArguments(), mappings)) {
          if (failed(state.mapBlockArgument(arg, map)))
            return WalkResult::interrupt();
        }

        for (Operation &transform :
             action.getFunctionBody().front().without_terminator()) {
          DiagnosedSilenceableFailure result =
              state.applyTransform(cast<TransformOpInterface>(transform));
          if (result.isDefiniteFailure())
            return WalkResult::interrupt();
          // Record the failure with notes for the action and the payload op,
          // silence it, and keep running the rest of the action.
          if (result.isSilenceableFailure()) {
            if (overallDiag.succeeded()) {
              overallDiag = emitSilenceableError() << "actions failed";
            }
            overallDiag.attachNote(action->getLoc())
                << "failed action: " << result.getMessage();
            overallDiag.attachNote(op->getLoc())
                << "when applied to this matching payload";
            (void)result.silence();
            continue;
          }
        }
        // Only the first matching pair is applied to `op`.
        break;
      }
      return WalkResult::advance();
    });
    // NOTE(review): on this path any silenceable diagnostics already collected
    // in `overallDiag` are not reported; confirm this is intended.
    if (walkResult.wasInterrupted())
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  // The root operation should not have been affected, so we can just reassign
  // the payload to the result. Note that we need to consume the root handle to
  // make sure any handles to operations inside, that could have been affected
  // by actions, are invalidated.
  results.set(llvm::cast<OpResult>(getUpdated()),
              state.getPayloadOps(getRoot()));
  return overallDiag;
}
1101
1102void transform::ForeachMatchOp::getEffects(
1103 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1104 // Bail if invalid.
1105 if (getOperation()->getNumOperands() < 1 ||
1106 getOperation()->getNumResults() < 1) {
1107 return modifiesPayload(effects);
1108 }
1109
1110 consumesHandle(getRoot(), effects);
1111 producesHandle(getUpdated(), effects);
1112 modifiesPayload(effects);
1113}
1114
1115/// Parses the comma-separated list of symbol reference pairs of the format
1116/// `@matcher -> @action`.
1117static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1118 ArrayAttr &matchers,
1119 ArrayAttr &actions) {
1120 StringAttr matcher;
1121 StringAttr action;
1122 SmallVector<Attribute> matcherList;
1123 SmallVector<Attribute> actionList;
1124 do {
1125 if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1126 parser.parseSymbolName(action)) {
1127 return failure();
1128 }
1129 matcherList.push_back(SymbolRefAttr::get(matcher));
1130 actionList.push_back(SymbolRefAttr::get(action));
1131 } while (parser.parseOptionalComma().succeeded());
1132
1133 matchers = parser.getBuilder().getArrayAttr(matcherList);
1134 actions = parser.getBuilder().getArrayAttr(actionList);
1135 return success();
1136}
1137
1138/// Prints the comma-separated list of symbol reference pairs of the format
1139/// `@matcher -> @action`.
1140static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
1141 ArrayAttr matchers, ArrayAttr actions) {
1142 printer.increaseIndent();
1143 printer.increaseIndent();
1144 for (auto &&[matcher, action, idx] : llvm::zip_equal(
1145 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1146 printer.printNewline();
1147 printer << cast<SymbolRefAttr>(matcher) << " -> "
1148 << cast<SymbolRefAttr>(action);
1149 if (idx != matchers.size() - 1)
1150 printer << ", ";
1151 }
1152 printer.decreaseIndent();
1153 printer.decreaseIndent();
1154}
1155
1156LogicalResult transform::ForeachMatchOp::verify() {
1157 if (getMatchers().size() != getActions().size())
1158 return emitOpError() << "expected the same number of matchers and actions";
1159 if (getMatchers().empty())
1160 return emitOpError() << "expected at least one match/action pair";
1161
1162 llvm::SmallPtrSet<Attribute, 8> matcherNames;
1163 for (Attribute name : getMatchers()) {
1164 if (matcherNames.insert(name).second)
1165 continue;
1166 emitWarning() << "matcher " << name
1167 << " is used more than once, only the first match will apply";
1168 }
1169
1170 return success();
1171}
1172
1173/// Checks that the attributes of the function-like operation have correct
1174/// consumption effect annotations. If `alsoVerifyInternal`, checks for
1175/// annotations being present even if they can be inferred from the body.
1176static DiagnosedSilenceableFailure
1177verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1178 bool alsoVerifyInternal = false) {
1179 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1180 llvm::SmallDenseSet<unsigned> consumedArguments;
1181 if (!op.isExternal()) {
1182 transform::getConsumedBlockArguments(block&: op.getFunctionBody().front(),
1183 consumedArguments);
1184 }
1185 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1186 bool isConsumed =
1187 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1188 nullptr;
1189 bool isReadOnly =
1190 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1191 nullptr;
1192 if (isConsumed && isReadOnly) {
1193 return transformOp.emitSilenceableError()
1194 << "argument #" << i << " cannot be both readonly and consumed";
1195 }
1196 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1197 return transformOp.emitSilenceableError()
1198 << "must provide consumed/readonly status for arguments of "
1199 "external or called ops";
1200 }
1201 if (op.isExternal())
1202 continue;
1203
1204 if (consumedArguments.contains(V: i) && !isConsumed && isReadOnly) {
1205 return transformOp.emitSilenceableError()
1206 << "argument #" << i
1207 << " is consumed in the body but is not marked as such";
1208 }
1209 if (emitWarnings && !consumedArguments.contains(V: i) && isConsumed) {
1210 // Cannot use op.emitWarning() here as it would attempt to verify the op
1211 // before printing, resulting in infinite recursion.
1212 emitWarning(op->getLoc())
1213 << "op argument #" << i
1214 << " is not consumed in the body but is marked as consumed";
1215 }
1216 }
1217 return DiagnosedSilenceableFailure::success();
1218}
1219
/// Verifies every matcher/action symbol pair referenced by this op:
///   - both symbols resolve to function-like transform ops;
///   - their argument consumption annotations are complete and consistent;
///   - matcher results and action arguments agree in number and pairwise
///     implement the same transform type interface;
///   - actions have no results;
///   - matchers take exactly one argument implementing the same transform
///     interface as the `root` operand, and do not consume it.
/// Reports the first violation found and returns failure.
LogicalResult transform::ForeachMatchOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  assert(getMatchers().size() == getActions().size());
  auto consumedAttr =
      StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    // Both symbols must resolve to function-like transform ops.
    auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(matcher)));
    auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(action)));
    if (!matcherSymbol ||
        !isa<TransformOpInterface>(matcherSymbol.getOperation()))
      return emitError() << "unresolved matcher symbol " << matcher;
    if (!actionSymbol ||
        !isa<TransformOpInterface>(actionSymbol.getOperation()))
      return emitError() << "unresolved action symbol " << action;

    // Callees must annotate every argument as consumed or readonly.
    if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
                                                    /*emitWarnings=*/false,
                                                    /*alsoVerifyInternal=*/true)
                   .checkAndReport())) {
      return failure();
    }
    if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
                                                    /*emitWarnings=*/false,
                                                    /*alsoVerifyInternal=*/true)
                   .checkAndReport())) {
      return failure();
    }

    // Matcher results are forwarded positionally to the action arguments.
    ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
    ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
    if (matcherResults.size() != actionArguments.size()) {
      return emitError() << "mismatching number of matcher results and "
                            "action arguments between "
                         << matcher << " (" << matcherResults.size() << ") and "
                         << action << " (" << actionArguments.size() << ")";
    }
    for (auto &&[i, matcherType, actionType] :
         llvm::enumerate(matcherResults, actionArguments)) {
      if (implementSameTransformInterface(matcherType, actionType))
        continue;

      return emitError() << "mismatching type interfaces for matcher result "
                            "and action argument #"
                         << i;
    }

    if (!actionSymbol.getResultTypes().empty()) {
      InFlightDiagnostic diag =
          emitError() << "action symbol is not expected to have results";
      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
      return diag;
    }

    // The matcher receives one payload entity of the same kind as `root`.
    if (matcherSymbol.getArgumentTypes().size() != 1 ||
        !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
                                         getRoot().getType())) {
      InFlightDiagnostic diag =
          emitOpError() << "expects matcher symbol to have one argument with "
                           "the same transform interface as the first operand";
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }

    // Safe to query argument 0: exactly one argument was checked above.
    if (matcherSymbol.getArgAttr(0, consumedAttr)) {
      InFlightDiagnostic diag =
          emitOpError()
          << "does not expect matcher symbol to consume its operand";
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
  }
  return success();
}
1298
1299//===----------------------------------------------------------------------===//
1300// ForeachOp
1301//===----------------------------------------------------------------------===//
1302
1303DiagnosedSilenceableFailure
1304transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1305 transform::TransformResults &results,
1306 transform::TransformState &state) {
1307 SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
1308 // Store payload ops in a vector because ops may be removed from the mapping
1309 // by the TrackingRewriter while the iteration is in progress.
1310 SmallVector<Operation *> targets =
1311 llvm::to_vector(state.getPayloadOps(getTarget()));
1312 for (Operation *op : targets) {
1313 auto scope = state.make_region_scope(getBody());
1314 if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
1315 return DiagnosedSilenceableFailure::definiteFailure();
1316
1317 // Execute loop body.
1318 for (Operation &transform : getBody().front().without_terminator()) {
1319 DiagnosedSilenceableFailure result = state.applyTransform(
1320 cast<transform::TransformOpInterface>(transform));
1321 if (!result.succeeded())
1322 return result;
1323 }
1324
1325 // Append yielded payload ops to result list (if any).
1326 for (unsigned i = 0; i < getNumResults(); ++i) {
1327 auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
1328 resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
1329 }
1330 }
1331
1332 for (unsigned i = 0; i < getNumResults(); ++i)
1333 results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
1334
1335 return DiagnosedSilenceableFailure::success();
1336}
1337
1338void transform::ForeachOp::getEffects(
1339 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1340 BlockArgument iterVar = getIterationVariable();
1341 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1342 return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
1343 })) {
1344 consumesHandle(getTarget(), effects);
1345 } else {
1346 onlyReadsHandle(getTarget(), effects);
1347 }
1348
1349 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1350 return doesModifyPayload(cast<TransformOpInterface>(&op));
1351 })) {
1352 modifiesPayload(effects);
1353 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1354 return doesReadPayload(cast<TransformOpInterface>(&op));
1355 })) {
1356 onlyReadsPayload(effects);
1357 }
1358
1359 for (Value result : getResults())
1360 producesHandle(result, effects);
1361}
1362
1363void transform::ForeachOp::getSuccessorRegions(
1364 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1365 Region *bodyRegion = &getBody();
1366 if (point.isParent()) {
1367 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1368 return;
1369 }
1370
1371 // Branch back to the region or the parent.
1372 assert(point == getBody() && "unexpected region index");
1373 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1374 regions.emplace_back();
1375}
1376
1377OperandRange
1378transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1379 // The iteration variable op handle is mapped to a subset (one op to be
1380 // precise) of the payload ops of the ForeachOp operand.
1381 assert(point == getBody() && "unexpected region index");
1382 return getOperation()->getOperands();
1383}
1384
1385transform::YieldOp transform::ForeachOp::getYieldOp() {
1386 return cast<transform::YieldOp>(getBody().front().getTerminator());
1387}
1388
1389LogicalResult transform::ForeachOp::verify() {
1390 auto yieldOp = getYieldOp();
1391 if (getNumResults() != yieldOp.getNumOperands())
1392 return emitOpError() << "expects the same number of results as the "
1393 "terminator has operands";
1394 for (Value v : yieldOp.getOperands())
1395 if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
1396 return yieldOp->emitOpError("expects operands to have types implementing "
1397 "TransformHandleTypeInterface");
1398 return success();
1399}
1400
1401//===----------------------------------------------------------------------===//
1402// GetParentOp
1403//===----------------------------------------------------------------------===//
1404
1405DiagnosedSilenceableFailure
1406transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1407 transform::TransformResults &results,
1408 transform::TransformState &state) {
1409 SmallVector<Operation *> parents;
1410 DenseSet<Operation *> resultSet;
1411 for (Operation *target : state.getPayloadOps(getTarget())) {
1412 Operation *parent = target;
1413 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1414 parent = parent->getParentOp();
1415 while (parent) {
1416 bool checkIsolatedFromAbove =
1417 !getIsolatedFromAbove() ||
1418 parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
1419 bool checkOpName = !getOpName().has_value() ||
1420 parent->getName().getStringRef() == *getOpName();
1421 if (checkIsolatedFromAbove && checkOpName)
1422 break;
1423 parent = parent->getParentOp();
1424 }
1425 if (!parent) {
1426 if (getAllowEmptyResults()) {
1427 results.set(llvm::cast<OpResult>(getResult()), parents);
1428 return DiagnosedSilenceableFailure::success();
1429 }
1430 DiagnosedSilenceableFailure diag =
1431 emitSilenceableError()
1432 << "could not find a parent op that matches all requirements";
1433 diag.attachNote(target->getLoc()) << "target op";
1434 return diag;
1435 }
1436 }
1437 if (getDeduplicate()) {
1438 if (!resultSet.contains(parent)) {
1439 parents.push_back(parent);
1440 resultSet.insert(parent);
1441 }
1442 } else {
1443 parents.push_back(parent);
1444 }
1445 }
1446 results.set(llvm::cast<OpResult>(getResult()), parents);
1447 return DiagnosedSilenceableFailure::success();
1448}
1449
1450//===----------------------------------------------------------------------===//
1451// GetConsumersOfResult
1452//===----------------------------------------------------------------------===//
1453
1454DiagnosedSilenceableFailure
1455transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1456 transform::TransformResults &results,
1457 transform::TransformState &state) {
1458 int64_t resultNumber = getResultNumber();
1459 auto payloadOps = state.getPayloadOps(getTarget());
1460 if (std::empty(payloadOps)) {
1461 results.set(cast<OpResult>(getResult()), {});
1462 return DiagnosedSilenceableFailure::success();
1463 }
1464 if (!llvm::hasSingleElement(payloadOps))
1465 return emitDefiniteFailure()
1466 << "handle must be mapped to exactly one payload op";
1467
1468 Operation *target = *payloadOps.begin();
1469 if (target->getNumResults() <= resultNumber)
1470 return emitDefiniteFailure() << "result number overflow";
1471 results.set(llvm::cast<OpResult>(getResult()),
1472 llvm::to_vector(target->getResult(resultNumber).getUsers()));
1473 return DiagnosedSilenceableFailure::success();
1474}
1475
1476//===----------------------------------------------------------------------===//
1477// GetDefiningOp
1478//===----------------------------------------------------------------------===//
1479
1480DiagnosedSilenceableFailure
1481transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1482 transform::TransformResults &results,
1483 transform::TransformState &state) {
1484 SmallVector<Operation *> definingOps;
1485 for (Value v : state.getPayloadValues(getTarget())) {
1486 if (llvm::isa<BlockArgument>(v)) {
1487 DiagnosedSilenceableFailure diag =
1488 emitSilenceableError() << "cannot get defining op of block argument";
1489 diag.attachNote(v.getLoc()) << "target value";
1490 return diag;
1491 }
1492 definingOps.push_back(v.getDefiningOp());
1493 }
1494 results.set(llvm::cast<OpResult>(getResult()), definingOps);
1495 return DiagnosedSilenceableFailure::success();
1496}
1497
1498//===----------------------------------------------------------------------===//
1499// GetProducerOfOperand
1500//===----------------------------------------------------------------------===//
1501
1502DiagnosedSilenceableFailure
1503transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1504 transform::TransformResults &results,
1505 transform::TransformState &state) {
1506 int64_t operandNumber = getOperandNumber();
1507 SmallVector<Operation *> producers;
1508 for (Operation *target : state.getPayloadOps(getTarget())) {
1509 Operation *producer =
1510 target->getNumOperands() <= operandNumber
1511 ? nullptr
1512 : target->getOperand(operandNumber).getDefiningOp();
1513 if (!producer) {
1514 DiagnosedSilenceableFailure diag =
1515 emitSilenceableError()
1516 << "could not find a producer for operand number: " << operandNumber
1517 << " of " << *target;
1518 diag.attachNote(target->getLoc()) << "target op";
1519 return diag;
1520 }
1521 producers.push_back(producer);
1522 }
1523 results.set(llvm::cast<OpResult>(getResult()), producers);
1524 return DiagnosedSilenceableFailure::success();
1525}
1526
1527//===----------------------------------------------------------------------===//
1528// GetOperandOp
1529//===----------------------------------------------------------------------===//
1530
1531DiagnosedSilenceableFailure
1532transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1533 transform::TransformResults &results,
1534 transform::TransformState &state) {
1535 SmallVector<Value> operands;
1536 for (Operation *target : state.getPayloadOps(getTarget())) {
1537 SmallVector<int64_t> operandPositions;
1538 DiagnosedSilenceableFailure diag = expandTargetSpecification(
1539 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1540 target->getNumOperands(), operandPositions);
1541 if (diag.isSilenceableFailure()) {
1542 diag.attachNote(target->getLoc())
1543 << "while considering positions of this payload operation";
1544 return diag;
1545 }
1546 llvm::append_range(operands,
1547 llvm::map_range(operandPositions, [&](int64_t pos) {
1548 return target->getOperand(pos);
1549 }));
1550 }
1551 results.setValues(cast<OpResult>(getResult()), operands);
1552 return DiagnosedSilenceableFailure::success();
1553}
1554
1555LogicalResult transform::GetOperandOp::verify() {
1556 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1557 getIsInverted(), getIsAll());
1558}
1559
1560//===----------------------------------------------------------------------===//
1561// GetResultOp
1562//===----------------------------------------------------------------------===//
1563
1564DiagnosedSilenceableFailure
1565transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1566 transform::TransformResults &results,
1567 transform::TransformState &state) {
1568 SmallVector<Value> opResults;
1569 for (Operation *target : state.getPayloadOps(getTarget())) {
1570 SmallVector<int64_t> resultPositions;
1571 DiagnosedSilenceableFailure diag = expandTargetSpecification(
1572 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1573 target->getNumResults(), resultPositions);
1574 if (diag.isSilenceableFailure()) {
1575 diag.attachNote(target->getLoc())
1576 << "while considering positions of this payload operation";
1577 return diag;
1578 }
1579 llvm::append_range(opResults,
1580 llvm::map_range(resultPositions, [&](int64_t pos) {
1581 return target->getResult(pos);
1582 }));
1583 }
1584 results.setValues(cast<OpResult>(getResult()), opResults);
1585 return DiagnosedSilenceableFailure::success();
1586}
1587
1588LogicalResult transform::GetResultOp::verify() {
1589 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1590 getIsInverted(), getIsAll());
1591}
1592
1593//===----------------------------------------------------------------------===//
1594// GetTypeOp
1595//===----------------------------------------------------------------------===//
1596
1597void transform::GetTypeOp::getEffects(
1598 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1599 onlyReadsHandle(getValue(), effects);
1600 producesHandle(getResult(), effects);
1601 onlyReadsPayload(effects);
1602}
1603
1604DiagnosedSilenceableFailure
1605transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1606 transform::TransformResults &results,
1607 transform::TransformState &state) {
1608 SmallVector<Attribute> params;
1609 for (Value value : state.getPayloadValues(getValue())) {
1610 Type type = value.getType();
1611 if (getElemental()) {
1612 if (auto shaped = dyn_cast<ShapedType>(type)) {
1613 type = shaped.getElementType();
1614 }
1615 }
1616 params.push_back(TypeAttr::get(type));
1617 }
1618 results.setParams(cast<OpResult>(getResult()), params);
1619 return DiagnosedSilenceableFailure::success();
1620}
1621
1622//===----------------------------------------------------------------------===//
1623// IncludeOp
1624//===----------------------------------------------------------------------===//
1625
1626/// Applies the transform ops contained in `block`. Maps `results` to the same
1627/// values as the operands of the block terminator.
1628static DiagnosedSilenceableFailure
1629applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1630 transform::TransformState &state,
1631 transform::TransformResults &results) {
1632 // Apply the sequenced ops one by one.
1633 for (Operation &transform : block.without_terminator()) {
1634 DiagnosedSilenceableFailure result =
1635 state.applyTransform(transform: cast<transform::TransformOpInterface>(transform));
1636 if (result.isDefiniteFailure())
1637 return result;
1638
1639 if (result.isSilenceableFailure()) {
1640 if (mode == transform::FailurePropagationMode::Propagate) {
1641 // Propagate empty results in case of early exit.
1642 forwardEmptyOperands(block: &block, state, results);
1643 return result;
1644 }
1645 (void)result.silence();
1646 }
1647 }
1648
1649 // Forward the operation mapping for values yielded from the sequence to the
1650 // values produced by the sequence op.
1651 transform::detail::forwardTerminatorOperands(block: &block, state, results);
1652 return DiagnosedSilenceableFailure::success();
1653}
1654
1655DiagnosedSilenceableFailure
1656transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1657 transform::TransformResults &results,
1658 transform::TransformState &state) {
1659 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1660 getOperation(), getTarget());
1661 assert(callee && "unverified reference to unknown symbol");
1662
1663 if (callee.isExternal())
1664 return emitDefiniteFailure() << "unresolved external named sequence";
1665
1666 // Map operands to block arguments.
1667 SmallVector<SmallVector<MappedValue>> mappings;
1668 detail::prepareValueMappings(mappings, getOperands(), state);
1669 auto scope = state.make_region_scope(callee.getBody());
1670 for (auto &&[arg, map] :
1671 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1672 if (failed(state.mapBlockArgument(arg, map)))
1673 return DiagnosedSilenceableFailure::definiteFailure();
1674 }
1675
1676 DiagnosedSilenceableFailure result = applySequenceBlock(
1677 callee.getBody().front(), getFailurePropagationMode(), state, results);
1678 mappings.clear();
1679 detail::prepareValueMappings(
1680 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1681 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1682 results.setMappedValues(result, mapping);
1683 return result;
1684}
1685
// Forward declaration: IncludeOp::getEffects below runs this verifier early,
// without reporting, before trusting the callee's argument annotations. The
// definition lives further down in the NamedSequenceOp section.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1688
1689void transform::IncludeOp::getEffects(
1690 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1691 // Always mark as modifying the payload.
1692 // TODO: a mechanism to annotate effects on payload. Even when all handles are
1693 // only read, the payload may still be modified, so we currently stay on the
1694 // conservative side and always indicate modification. This may prevent some
1695 // code reordering.
1696 modifiesPayload(effects);
1697
1698 // Results are always produced.
1699 producesHandle(getResults(), effects);
1700
1701 // Adds default effects to operands and results. This will be added if
1702 // preconditions fail so the trait verifier doesn't complain about missing
1703 // effects and the real precondition failure is reported later on.
1704 auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };
1705
1706 // Bail if the callee is unknown. This may run as part of the verification
1707 // process before we verified the validity of the callee or of this op.
1708 auto target =
1709 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1710 if (!target)
1711 return defaultEffects();
1712 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1713 getOperation(), getTarget());
1714 if (!callee)
1715 return defaultEffects();
1716 DiagnosedSilenceableFailure earlyVerifierResult =
1717 verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1718 if (!earlyVerifierResult.succeeded()) {
1719 (void)earlyVerifierResult.silence();
1720 return defaultEffects();
1721 }
1722
1723 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1724 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1725 consumesHandle(getOperand(i), effects);
1726 else
1727 onlyReadsHandle(getOperand(i), effects);
1728 }
1729}
1730
1731LogicalResult
1732transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1733 // Access through indirection and do additional checking because this may be
1734 // running before the main op verifier.
1735 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1736 if (!targetAttr)
1737 return emitOpError() << "expects a 'target' symbol reference attribute";
1738
1739 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1740 *this, targetAttr);
1741 if (!target)
1742 return emitOpError() << "does not reference a named transform sequence";
1743
1744 FunctionType fnType = target.getFunctionType();
1745 if (fnType.getNumInputs() != getNumOperands())
1746 return emitError("incorrect number of operands for callee");
1747
1748 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1749 if (getOperand(i).getType() != fnType.getInput(i)) {
1750 return emitOpError("operand type mismatch: expected operand type ")
1751 << fnType.getInput(i) << ", but provided "
1752 << getOperand(i).getType() << " for operand number " << i;
1753 }
1754 }
1755
1756 if (fnType.getNumResults() != getNumResults())
1757 return emitError("incorrect number of results for callee");
1758
1759 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1760 Type resultType = getResult(i).getType();
1761 Type funcType = fnType.getResult(i);
1762 if (!implementSameTransformInterface(resultType, funcType)) {
1763 return emitOpError() << "type of result #" << i
1764 << " must implement the same transform dialect "
1765 "interface as the corresponding callee result";
1766 }
1767 }
1768
1769 return verifyFunctionLikeConsumeAnnotations(
1770 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
1771 /*alsoVerifyInternal=*/true)
1772 .checkAndReport();
1773}
1774
1775//===----------------------------------------------------------------------===//
1776// MatchOperationEmptyOp
1777//===----------------------------------------------------------------------===//
1778
1779DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
1780 ::std::optional<::mlir::Operation *> maybeCurrent,
1781 transform::TransformResults &results, transform::TransformState &state) {
1782 if (!maybeCurrent.has_value()) {
1783 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
1784 return DiagnosedSilenceableFailure::success();
1785 }
1786 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
1787 return emitSilenceableError() << "operation is not empty";
1788}
1789
1790//===----------------------------------------------------------------------===//
1791// MatchOperationNameOp
1792//===----------------------------------------------------------------------===//
1793
1794DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
1795 Operation *current, transform::TransformResults &results,
1796 transform::TransformState &state) {
1797 StringRef currentOpName = current->getName().getStringRef();
1798 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1799 if (acceptedAttr.getValue() == currentOpName)
1800 return DiagnosedSilenceableFailure::success();
1801 }
1802 return emitSilenceableError() << "wrong operation name";
1803}
1804
1805//===----------------------------------------------------------------------===//
1806// MatchParamCmpIOp
1807//===----------------------------------------------------------------------===//
1808
1809DiagnosedSilenceableFailure
1810transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1811 transform::TransformResults &results,
1812 transform::TransformState &state) {
1813 auto signedAPIntAsString = [&](const APInt &value) {
1814 std::string str;
1815 llvm::raw_string_ostream os(str);
1816 value.print(os, /*isSigned=*/true);
1817 return os.str();
1818 };
1819
1820 ArrayRef<Attribute> params = state.getParams(getParam());
1821 ArrayRef<Attribute> references = state.getParams(getReference());
1822
1823 if (params.size() != references.size()) {
1824 return emitSilenceableError()
1825 << "parameters have different payload lengths (" << params.size()
1826 << " vs " << references.size() << ")";
1827 }
1828
1829 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1830 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1831 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1832 if (!intAttr || !refAttr) {
1833 return emitDefiniteFailure()
1834 << "non-integer parameter value not expected";
1835 }
1836 if (intAttr.getType() != refAttr.getType()) {
1837 return emitDefiniteFailure()
1838 << "mismatching integer attribute types in parameter #" << i;
1839 }
1840 APInt value = intAttr.getValue();
1841 APInt refValue = refAttr.getValue();
1842
1843 // TODO: this copy will not be necessary in C++20.
1844 int64_t position = i;
1845 auto reportError = [&](StringRef direction) {
1846 DiagnosedSilenceableFailure diag =
1847 emitSilenceableError() << "expected parameter to be " << direction
1848 << " " << signedAPIntAsString(refValue)
1849 << ", got " << signedAPIntAsString(value);
1850 diag.attachNote(getParam().getLoc())
1851 << "value # " << position
1852 << " associated with the parameter defined here";
1853 return diag;
1854 };
1855
1856 switch (getPredicate()) {
1857 case MatchCmpIPredicate::eq:
1858 if (value.eq(refValue))
1859 break;
1860 return reportError("equal to");
1861 case MatchCmpIPredicate::ne:
1862 if (value.ne(refValue))
1863 break;
1864 return reportError("not equal to");
1865 case MatchCmpIPredicate::lt:
1866 if (value.slt(refValue))
1867 break;
1868 return reportError("less than");
1869 case MatchCmpIPredicate::le:
1870 if (value.sle(refValue))
1871 break;
1872 return reportError("less than or equal to");
1873 case MatchCmpIPredicate::gt:
1874 if (value.sgt(refValue))
1875 break;
1876 return reportError("greater than");
1877 case MatchCmpIPredicate::ge:
1878 if (value.sge(refValue))
1879 break;
1880 return reportError("greater than or equal to");
1881 }
1882 }
1883 return DiagnosedSilenceableFailure::success();
1884}
1885
1886void transform::MatchParamCmpIOp::getEffects(
1887 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1888 onlyReadsHandle(getParam(), effects);
1889 onlyReadsHandle(getReference(), effects);
1890}
1891
1892//===----------------------------------------------------------------------===//
1893// ParamConstantOp
1894//===----------------------------------------------------------------------===//
1895
1896DiagnosedSilenceableFailure
1897transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
1898 transform::TransformResults &results,
1899 transform::TransformState &state) {
1900 results.setParams(cast<OpResult>(getParam()), {getValue()});
1901 return DiagnosedSilenceableFailure::success();
1902}
1903
1904//===----------------------------------------------------------------------===//
1905// MergeHandlesOp
1906//===----------------------------------------------------------------------===//
1907
1908DiagnosedSilenceableFailure
1909transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
1910 transform::TransformResults &results,
1911 transform::TransformState &state) {
1912 ValueRange handles = getHandles();
1913 if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
1914 SmallVector<Operation *> operations;
1915 for (Value operand : handles)
1916 llvm::append_range(operations, state.getPayloadOps(operand));
1917 if (!getDeduplicate()) {
1918 results.set(llvm::cast<OpResult>(getResult()), operations);
1919 return DiagnosedSilenceableFailure::success();
1920 }
1921
1922 SetVector<Operation *> uniqued(operations.begin(), operations.end());
1923 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
1924 return DiagnosedSilenceableFailure::success();
1925 }
1926
1927 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
1928 SmallVector<Attribute> attrs;
1929 for (Value attribute : handles)
1930 llvm::append_range(attrs, state.getParams(attribute));
1931 if (!getDeduplicate()) {
1932 results.setParams(cast<OpResult>(getResult()), attrs);
1933 return DiagnosedSilenceableFailure::success();
1934 }
1935
1936 SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
1937 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
1938 return DiagnosedSilenceableFailure::success();
1939 }
1940
1941 assert(
1942 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
1943 "expected value handle type");
1944 SmallVector<Value> payloadValues;
1945 for (Value value : handles)
1946 llvm::append_range(payloadValues, state.getPayloadValues(value));
1947 if (!getDeduplicate()) {
1948 results.setValues(cast<OpResult>(getResult()), payloadValues);
1949 return DiagnosedSilenceableFailure::success();
1950 }
1951
1952 SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
1953 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
1954 return DiagnosedSilenceableFailure::success();
1955}
1956
1957bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
1958 // Handles may be the same if deduplicating is enabled.
1959 return getDeduplicate();
1960}
1961
1962void transform::MergeHandlesOp::getEffects(
1963 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1964 onlyReadsHandle(getHandles(), effects);
1965 producesHandle(getResult(), effects);
1966
1967 // There are no effects on the Payload IR as this is only a handle
1968 // manipulation.
1969}
1970
1971OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
1972 if (getDeduplicate() || getHandles().size() != 1)
1973 return {};
1974
1975 // If deduplication is not required and there is only one operand, it can be
1976 // used directly instead of merging.
1977 return getHandles().front();
1978}
1979
1980//===----------------------------------------------------------------------===//
1981// NamedSequenceOp
1982//===----------------------------------------------------------------------===//
1983
1984DiagnosedSilenceableFailure
1985transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
1986 transform::TransformResults &results,
1987 transform::TransformState &state) {
1988 if (isExternal())
1989 return emitDefiniteFailure() << "unresolved external named sequence";
1990
1991 // Map the entry block argument to the list of operations.
1992 // Note: this is the same implementation as PossibleTopLevelTransformOp but
1993 // without attaching the interface / trait since that is tailored to a
1994 // dangling top-level op that does not get "called".
1995 auto scope = state.make_region_scope(getBody());
1996 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
1997 state, this->getOperation(), getBody())))
1998 return DiagnosedSilenceableFailure::definiteFailure();
1999
2000 return applySequenceBlock(getBody().front(),
2001 FailurePropagationMode::Propagate, state, results);
2002}
2003
2004void transform::NamedSequenceOp::getEffects(
2005 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2006
2007ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2008 OperationState &result) {
2009 return function_interface_impl::parseFunctionOp(
2010 parser, result, /*allowVariadic=*/false,
2011 getFunctionTypeAttrName(result.name),
2012 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2013 function_interface_impl::VariadicFlag,
2014 std::string &) { return builder.getFunctionType(inputs, results); },
2015 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2016}
2017
2018void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2019 function_interface_impl::printFunctionOp(
2020 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2021 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2022 getResAttrsAttrName());
2023}
2024
2025/// Verifies that a symbol function-like transform dialect operation has the
2026/// signature and the terminator that have conforming types, i.e., types
2027/// implementing the same transform dialect type interface. If `allowExternal`
2028/// is set, allow external symbols (declarations) and don't check the terminator
2029/// as it may not exist.
2030static DiagnosedSilenceableFailure
2031verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2032 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2033 DiagnosedSilenceableFailure diag =
2034 emitSilenceableFailure(op)
2035 << "cannot be defined inside another transform op";
2036 diag.attachNote(loc: parent.getLoc()) << "ancestor transform op";
2037 return diag;
2038 }
2039
2040 if (op.isExternal() || op.getFunctionBody().empty()) {
2041 if (allowExternal)
2042 return DiagnosedSilenceableFailure::success();
2043
2044 return emitSilenceableFailure(op) << "cannot be external";
2045 }
2046
2047 if (op.getFunctionBody().front().empty())
2048 return emitSilenceableFailure(op) << "expected a non-empty body block";
2049
2050 Operation *terminator = &op.getFunctionBody().front().back();
2051 if (!isa<transform::YieldOp>(terminator)) {
2052 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2053 << "expected '"
2054 << transform::YieldOp::getOperationName()
2055 << "' as terminator";
2056 diag.attachNote(loc: terminator->getLoc()) << "terminator";
2057 return diag;
2058 }
2059
2060 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2061 return emitSilenceableFailure(op: terminator)
2062 << "expected terminator to have as many operands as the parent op "
2063 "has results";
2064 }
2065 for (auto [i, operandType, resultType] : llvm::zip_equal(
2066 llvm::seq<unsigned>(0, terminator->getNumOperands()),
2067 terminator->getOperands().getType(), op.getResultTypes())) {
2068 if (operandType == resultType)
2069 continue;
2070 return emitSilenceableFailure(terminator)
2071 << "the type of the terminator operand #" << i
2072 << " must match the type of the corresponding parent op result ("
2073 << operandType << " vs " << resultType << ")";
2074 }
2075
2076 return DiagnosedSilenceableFailure::success();
2077}
2078
2079/// Verification of a NamedSequenceOp. This does not report the error
2080/// immediately, so it can be used to check for op's well-formedness before the
2081/// verifier runs, e.g., during trait verification.
2082static DiagnosedSilenceableFailure
2083verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2084 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2085 if (!parent->getAttr(
2086 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2087 DiagnosedSilenceableFailure diag =
2088 emitSilenceableFailure(op)
2089 << "expects the parent symbol table to have the '"
2090 << transform::TransformDialect::kWithNamedSequenceAttrName
2091 << "' attribute";
2092 diag.attachNote(loc: parent->getLoc()) << "symbol table operation";
2093 return diag;
2094 }
2095 }
2096
2097 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2098 DiagnosedSilenceableFailure diag =
2099 emitSilenceableFailure(op)
2100 << "cannot be defined inside another transform op";
2101 diag.attachNote(loc: parent.getLoc()) << "ancestor transform op";
2102 return diag;
2103 }
2104
2105 if (op.isExternal() || op.getBody().empty())
2106 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2107 emitWarnings);
2108
2109 if (op.getBody().front().empty())
2110 return emitSilenceableFailure(op) << "expected a non-empty body block";
2111
2112 Operation *terminator = &op.getBody().front().back();
2113 if (!isa<transform::YieldOp>(terminator)) {
2114 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2115 << "expected '"
2116 << transform::YieldOp::getOperationName()
2117 << "' as terminator";
2118 diag.attachNote(loc: terminator->getLoc()) << "terminator";
2119 return diag;
2120 }
2121
2122 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2123 return emitSilenceableFailure(op: terminator)
2124 << "expected terminator to have as many operands as the parent op "
2125 "has results";
2126 }
2127 for (auto [i, operandType, resultType] :
2128 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2129 terminator->getOperands().getType(),
2130 op.getFunctionType().getResults())) {
2131 if (operandType == resultType)
2132 continue;
2133 return emitSilenceableFailure(terminator)
2134 << "the type of the terminator operand #" << i
2135 << " must match the type of the corresponding parent op result ("
2136 << operandType << " vs " << resultType << ")";
2137 }
2138
2139 auto funcOp = cast<FunctionOpInterface>(*op);
2140 DiagnosedSilenceableFailure diag =
2141 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2142 if (!diag.succeeded())
2143 return diag;
2144
2145 return verifyYieldingSingleBlockOp(funcOp,
2146 /*allowExternal=*/true);
2147}
2148
2149LogicalResult transform::NamedSequenceOp::verify() {
2150 // Actual verification happens in a separate function for reusability.
2151 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2152}
2153
2154template <typename FnTy>
2155static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2156 Type bbArgType, TypeRange extraBindingTypes,
2157 FnTy bodyBuilder) {
2158 SmallVector<Type> types;
2159 types.reserve(N: 1 + extraBindingTypes.size());
2160 types.push_back(Elt: bbArgType);
2161 llvm::append_range(C&: types, R&: extraBindingTypes);
2162
2163 OpBuilder::InsertionGuard guard(builder);
2164 Region *region = state.regions.back().get();
2165 Block *bodyBlock =
2166 builder.createBlock(parent: region, insertPt: region->begin(), argTypes: types,
2167 locs: SmallVector<Location>(types.size(), state.location));
2168
2169 // Populate body.
2170 builder.setInsertionPointToStart(bodyBlock);
2171 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2172 bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0));
2173 } else {
2174 bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0),
2175 bodyBlock->getArguments().drop_front());
2176 }
2177}
2178
2179void transform::NamedSequenceOp::build(OpBuilder &builder,
2180 OperationState &state, StringRef symName,
2181 Type rootType, TypeRange resultTypes,
2182 SequenceBodyBuilderFn bodyBuilder,
2183 ArrayRef<NamedAttribute> attrs,
2184 ArrayRef<DictionaryAttr> argAttrs) {
2185 state.addAttribute(SymbolTable::getSymbolAttrName(),
2186 builder.getStringAttr(symName));
2187 state.addAttribute(getFunctionTypeAttrName(state.name),
2188 TypeAttr::get(FunctionType::get(builder.getContext(),
2189 rootType, resultTypes)));
2190 state.attributes.append(attrs.begin(), attrs.end());
2191 state.addRegion();
2192
2193 buildSequenceBody(builder, state, rootType,
2194 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2195}
2196
2197//===----------------------------------------------------------------------===//
2198// NumAssociationsOp
2199//===----------------------------------------------------------------------===//
2200
2201DiagnosedSilenceableFailure
2202transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2203 transform::TransformResults &results,
2204 transform::TransformState &state) {
2205 size_t numAssociations =
2206 llvm::TypeSwitch<Type, size_t>(getHandle().getType())
2207 .Case([&](TransformHandleTypeInterface opHandle) {
2208 return llvm::range_size(state.getPayloadOps(getHandle()));
2209 })
2210 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2211 return llvm::range_size(state.getPayloadValues(getHandle()));
2212 })
2213 .Case([&](TransformParamTypeInterface param) {
2214 return llvm::range_size(state.getParams(getHandle()));
2215 })
2216 .Default([](Type) {
2217 llvm_unreachable("unknown kind of transform dialect type");
2218 return 0;
2219 });
2220 results.setParams(cast<OpResult>(getNum()),
2221 rewriter.getI64IntegerAttr(numAssociations));
2222 return DiagnosedSilenceableFailure::success();
2223}
2224
2225LogicalResult transform::NumAssociationsOp::verify() {
2226 // Verify that the result type accepts an i64 attribute as payload.
2227 auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2228 return resultType
2229 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2230 .checkAndReport();
2231}
2232
2233//===----------------------------------------------------------------------===//
2234// SelectOp
2235//===----------------------------------------------------------------------===//
2236
2237DiagnosedSilenceableFailure
2238transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2239 transform::TransformResults &results,
2240 transform::TransformState &state) {
2241 SmallVector<Operation *> result;
2242 auto payloadOps = state.getPayloadOps(getTarget());
2243 for (Operation *op : payloadOps) {
2244 if (op->getName().getStringRef() == getOpName())
2245 result.push_back(op);
2246 }
2247 results.set(cast<OpResult>(getResult()), result);
2248 return DiagnosedSilenceableFailure::success();
2249}
2250
2251//===----------------------------------------------------------------------===//
2252// SplitHandleOp
2253//===----------------------------------------------------------------------===//
2254
2255void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2256 Value target, int64_t numResultHandles) {
2257 result.addOperands(target);
2258 result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2259}
2260
2261DiagnosedSilenceableFailure
2262transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2263 transform::TransformResults &results,
2264 transform::TransformState &state) {
2265 int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2266 auto produceNumOpsError = [&]() {
2267 return emitSilenceableError()
2268 << getHandle() << " expected to contain " << this->getNumResults()
2269 << " payload ops but it contains " << numPayloadOps
2270 << " payload ops";
2271 };
2272
2273 // Fail if there are more payload ops than results and no overflow result was
2274 // specified.
2275 if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2276 return produceNumOpsError();
2277
2278 // Fail if there are more results than payload ops. Unless:
2279 // - "fail_on_payload_too_small" is set to "false", or
2280 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2281 if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2282 (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
2283 return produceNumOpsError();
2284
2285 // Distribute payload ops.
2286 SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
2287 if (getOverflowResult())
2288 resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2289 getNumResults());
2290 for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
2291 int64_t resultNum = en.index();
2292 if (resultNum >= getNumResults())
2293 resultNum = *getOverflowResult();
2294 resultHandles[resultNum].push_back(en.value());
2295 }
2296
2297 // Set transform op results.
2298 for (auto &&it : llvm::enumerate(resultHandles))
2299 results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2300
2301 return DiagnosedSilenceableFailure::success();
2302}
2303
2304void transform::SplitHandleOp::getEffects(
2305 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2306 onlyReadsHandle(getHandle(), effects);
2307 producesHandle(getResults(), effects);
2308 // There are no effects on the Payload IR as this is only a handle
2309 // manipulation.
2310}
2311
2312LogicalResult transform::SplitHandleOp::verify() {
2313 if (getOverflowResult().has_value() &&
2314 !(*getOverflowResult() < getNumResults()))
2315 return emitOpError("overflow_result is not a valid result index");
2316 return success();
2317}
2318
2319//===----------------------------------------------------------------------===//
2320// ReplicateOp
2321//===----------------------------------------------------------------------===//
2322
2323DiagnosedSilenceableFailure
2324transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2325 transform::TransformResults &results,
2326 transform::TransformState &state) {
2327 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2328 for (const auto &en : llvm::enumerate(getHandles())) {
2329 Value handle = en.value();
2330 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2331 SmallVector<Operation *> current =
2332 llvm::to_vector(state.getPayloadOps(handle));
2333 SmallVector<Operation *> payload;
2334 payload.reserve(numRepetitions * current.size());
2335 for (unsigned i = 0; i < numRepetitions; ++i)
2336 llvm::append_range(payload, current);
2337 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2338 } else {
2339 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2340 "expected param type");
2341 ArrayRef<Attribute> current = state.getParams(handle);
2342 SmallVector<Attribute> params;
2343 params.reserve(numRepetitions * current.size());
2344 for (unsigned i = 0; i < numRepetitions; ++i)
2345 llvm::append_range(params, current);
2346 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2347 params);
2348 }
2349 }
2350 return DiagnosedSilenceableFailure::success();
2351}
2352
2353void transform::ReplicateOp::getEffects(
2354 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2355 onlyReadsHandle(getPattern(), effects);
2356 onlyReadsHandle(getHandles(), effects);
2357 producesHandle(getReplicated(), effects);
2358}
2359
2360//===----------------------------------------------------------------------===//
2361// SequenceOp
2362//===----------------------------------------------------------------------===//
2363
2364DiagnosedSilenceableFailure
2365transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2366 transform::TransformResults &results,
2367 transform::TransformState &state) {
2368 // Map the entry block argument to the list of operations.
2369 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2370 if (failed(mapBlockArguments(state)))
2371 return DiagnosedSilenceableFailure::definiteFailure();
2372
2373 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2374 results);
2375}
2376
2377static ParseResult parseSequenceOpOperands(
2378 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2379 Type &rootType,
2380 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2381 SmallVectorImpl<Type> &extraBindingTypes) {
2382 OpAsmParser::UnresolvedOperand rootOperand;
2383 OptionalParseResult hasRoot = parser.parseOptionalOperand(result&: rootOperand);
2384 if (!hasRoot.has_value()) {
2385 root = std::nullopt;
2386 return success();
2387 }
2388 if (failed(result: hasRoot.value()))
2389 return failure();
2390 root = rootOperand;
2391
2392 if (succeeded(result: parser.parseOptionalComma())) {
2393 if (failed(result: parser.parseOperandList(result&: extraBindings)))
2394 return failure();
2395 }
2396 if (failed(result: parser.parseColon()))
2397 return failure();
2398
2399 // The paren is truly optional.
2400 (void)parser.parseOptionalLParen();
2401
2402 if (failed(result: parser.parseType(result&: rootType))) {
2403 return failure();
2404 }
2405
2406 if (!extraBindings.empty()) {
2407 if (parser.parseComma() || parser.parseTypeList(result&: extraBindingTypes))
2408 return failure();
2409 }
2410
2411 if (extraBindingTypes.size() != extraBindings.size()) {
2412 return parser.emitError(loc: parser.getNameLoc(),
2413 message: "expected types to be provided for all operands");
2414 }
2415
2416 // The paren is truly optional.
2417 (void)parser.parseOptionalRParen();
2418 return success();
2419}
2420
2421static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
2422 Value root, Type rootType,
2423 ValueRange extraBindings,
2424 TypeRange extraBindingTypes) {
2425 if (!root)
2426 return;
2427
2428 printer << root;
2429 bool hasExtras = !extraBindings.empty();
2430 if (hasExtras) {
2431 printer << ", ";
2432 printer.printOperands(container: extraBindings);
2433 }
2434
2435 printer << " : ";
2436 if (hasExtras)
2437 printer << "(";
2438
2439 printer << rootType;
2440 if (hasExtras) {
2441 printer << ", ";
2442 llvm::interleaveComma(c: extraBindingTypes, os&: printer.getStream());
2443 printer << ")";
2444 }
2445}
2446
2447/// Returns `true` if the given op operand may be consuming the handle value in
2448/// the Transform IR. That is, if it may have a Free effect on it.
2449static bool isValueUsePotentialConsumer(OpOperand &use) {
2450 // Conservatively assume the effect being present in absence of the interface.
2451 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2452 if (!iface)
2453 return true;
2454
2455 return isHandleConsumed(use.get(), iface);
2456}
2457
2458LogicalResult
2459checkDoubleConsume(Value value,
2460 function_ref<InFlightDiagnostic()> reportError) {
2461 OpOperand *potentialConsumer = nullptr;
2462 for (OpOperand &use : value.getUses()) {
2463 if (!isValueUsePotentialConsumer(use))
2464 continue;
2465
2466 if (!potentialConsumer) {
2467 potentialConsumer = &use;
2468 continue;
2469 }
2470
2471 InFlightDiagnostic diag = reportError()
2472 << " has more than one potential consumer";
2473 diag.attachNote(noteLoc: potentialConsumer->getOwner()->getLoc())
2474 << "used here as operand #" << potentialConsumer->getOperandNumber();
2475 diag.attachNote(noteLoc: use.getOwner()->getLoc())
2476 << "used here as operand #" << use.getOperandNumber();
2477 return diag;
2478 }
2479
2480 return success();
2481}
2482
2483LogicalResult transform::SequenceOp::verify() {
2484 assert(getBodyBlock()->getNumArguments() >= 1 &&
2485 "the number of arguments must have been verified to be more than 1 by "
2486 "PossibleTopLevelTransformOpTrait");
2487
2488 if (!getRoot() && !getExtraBindings().empty()) {
2489 return emitOpError()
2490 << "does not expect extra operands when used as top-level";
2491 }
2492
2493 // Check if a block argument has more than one consuming use.
2494 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2495 if (failed(checkDoubleConsume(arg, [this, arg]() {
2496 return (emitOpError() << "block argument #" << arg.getArgNumber());
2497 }))) {
2498 return failure();
2499 }
2500 }
2501
2502 // Check properties of the nested operations they cannot check themselves.
2503 for (Operation &child : *getBodyBlock()) {
2504 if (!isa<TransformOpInterface>(child) &&
2505 &child != &getBodyBlock()->back()) {
2506 InFlightDiagnostic diag =
2507 emitOpError()
2508 << "expected children ops to implement TransformOpInterface";
2509 diag.attachNote(child.getLoc()) << "op without interface";
2510 return diag;
2511 }
2512
2513 for (OpResult result : child.getResults()) {
2514 auto report = [&]() {
2515 return (child.emitError() << "result #" << result.getResultNumber());
2516 };
2517 if (failed(checkDoubleConsume(result, report)))
2518 return failure();
2519 }
2520 }
2521
2522 if (!getBodyBlock()->mightHaveTerminator())
2523 return emitOpError() << "expects to have a terminator in the body";
2524
2525 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2526 getOperation()->getResultTypes()) {
2527 InFlightDiagnostic diag = emitOpError()
2528 << "expects the types of the terminator operands "
2529 "to match the types of the result";
2530 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2531 return diag;
2532 }
2533 return success();
2534}
2535
2536void transform::SequenceOp::getEffects(
2537 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2538 getPotentialTopLevelEffects(effects);
2539}
2540
2541OperandRange
2542transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2543 assert(point == getBody() && "unexpected region index");
2544 if (getOperation()->getNumOperands() > 0)
2545 return getOperation()->getOperands();
2546 return OperandRange(getOperation()->operand_end(),
2547 getOperation()->operand_end());
2548}
2549
2550void transform::SequenceOp::getSuccessorRegions(
2551 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2552 if (point.isParent()) {
2553 Region *bodyRegion = &getBody();
2554 regions.emplace_back(bodyRegion, getNumOperands() != 0
2555 ? bodyRegion->getArguments()
2556 : Block::BlockArgListType());
2557 return;
2558 }
2559
2560 assert(point == getBody() && "unexpected region index");
2561 regions.emplace_back(getOperation()->getResults());
2562}
2563
2564void transform::SequenceOp::getRegionInvocationBounds(
2565 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2566 (void)operands;
2567 bounds.emplace_back(1, 1);
2568}
2569
2570void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2571 TypeRange resultTypes,
2572 FailurePropagationMode failurePropagationMode,
2573 Value root,
2574 SequenceBodyBuilderFn bodyBuilder) {
2575 build(builder, state, resultTypes, failurePropagationMode, root,
2576 /*extra_bindings=*/ValueRange());
2577 Type bbArgType = root.getType();
2578 buildSequenceBody(builder, state, bbArgType,
2579 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2580}
2581
2582void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2583 TypeRange resultTypes,
2584 FailurePropagationMode failurePropagationMode,
2585 Value root, ValueRange extraBindings,
2586 SequenceBodyBuilderArgsFn bodyBuilder) {
2587 build(builder, state, resultTypes, failurePropagationMode, root,
2588 extraBindings);
2589 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2590 bodyBuilder);
2591}
2592
2593void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2594 TypeRange resultTypes,
2595 FailurePropagationMode failurePropagationMode,
2596 Type bbArgType,
2597 SequenceBodyBuilderFn bodyBuilder) {
2598 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2599 /*extra_bindings=*/ValueRange());
2600 buildSequenceBody(builder, state, bbArgType,
2601 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2602}
2603
2604void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2605 TypeRange resultTypes,
2606 FailurePropagationMode failurePropagationMode,
2607 Type bbArgType, TypeRange extraBindingTypes,
2608 SequenceBodyBuilderArgsFn bodyBuilder) {
2609 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2610 /*extra_bindings=*/ValueRange());
2611 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2612}
2613
2614//===----------------------------------------------------------------------===//
2615// PrintOp
2616//===----------------------------------------------------------------------===//
2617
2618void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2619 StringRef name) {
2620 if (!name.empty())
2621 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2622}
2623
2624void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2625 Value target, StringRef name) {
2626 result.addOperands({target});
2627 build(builder, result, name);
2628}
2629
2630DiagnosedSilenceableFailure
2631transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2632 transform::TransformResults &results,
2633 transform::TransformState &state) {
2634 llvm::outs() << "[[[ IR printer: ";
2635 if (getName().has_value())
2636 llvm::outs() << *getName() << " ";
2637
2638 OpPrintingFlags printFlags;
2639 if (getAssumeVerified().value_or(false))
2640 printFlags.assumeVerified();
2641 if (getUseLocalScope().value_or(false))
2642 printFlags.useLocalScope();
2643 if (getSkipRegions().value_or(false))
2644 printFlags.skipRegions();
2645
2646 if (!getTarget()) {
2647 llvm::outs() << "top-level ]]]\n";
2648 state.getTopLevel()->print(llvm::outs(), printFlags);
2649 llvm::outs() << "\n";
2650 return DiagnosedSilenceableFailure::success();
2651 }
2652
2653 llvm::outs() << "]]]\n";
2654 for (Operation *target : state.getPayloadOps(getTarget())) {
2655 target->print(llvm::outs(), printFlags);
2656 llvm::outs() << "\n";
2657 }
2658
2659 return DiagnosedSilenceableFailure::success();
2660}
2661
2662void transform::PrintOp::getEffects(
2663 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2664 // We don't really care about mutability here, but `getTarget` now
2665 // unconditionally casts to a specific type before verification could run
2666 // here.
2667 if (!getTargetMutable().empty())
2668 onlyReadsHandle(getTargetMutable()[0].get(), effects);
2669 onlyReadsPayload(effects);
2670
2671 // There is no resource for stderr file descriptor, so just declare print
2672 // writes into the default resource.
2673 effects.emplace_back(MemoryEffects::Write::get());
2674}
2675
2676//===----------------------------------------------------------------------===//
2677// VerifyOp
2678//===----------------------------------------------------------------------===//
2679
2680DiagnosedSilenceableFailure
2681transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
2682 Operation *target,
2683 transform::ApplyToEachResultList &results,
2684 transform::TransformState &state) {
2685 if (failed(::mlir::verify(target))) {
2686 DiagnosedDefiniteFailure diag = emitDefiniteFailure()
2687 << "failed to verify payload op";
2688 diag.attachNote(target->getLoc()) << "payload op";
2689 return diag;
2690 }
2691 return DiagnosedSilenceableFailure::success();
2692}
2693
2694void transform::VerifyOp::getEffects(
2695 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2696 transform::onlyReadsHandle(getTarget(), effects);
2697}
2698
2699//===----------------------------------------------------------------------===//
2700// YieldOp
2701//===----------------------------------------------------------------------===//
2702
2703void transform::YieldOp::getEffects(
2704 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2705 onlyReadsHandle(getOperands(), effects);
2706}
2707

// source code of mlir/lib/Dialect/Transform/IR/TransformOps.cpp