1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformOps.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
14#include "mlir/Dialect/Transform/IR/TransformDialect.h"
15#include "mlir/Dialect/Transform/IR/TransformTypes.h"
16#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
17#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/Dominance.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/IR/OperationSupport.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/IR/Verifier.h"
25#include "mlir/Interfaces/ControlFlowInterfaces.h"
26#include "mlir/Interfaces/FunctionImplementation.h"
27#include "mlir/Interfaces/FunctionInterfaces.h"
28#include "mlir/Pass/PassManager.h"
29#include "mlir/Pass/PassRegistry.h"
30#include "mlir/Transforms/CSE.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/Debug.h"
40#include "llvm/Support/ErrorHandling.h"
41#include "llvm/Support/InterleavedRange.h"
42#include <optional>
43
44#define DEBUG_TYPE "transform-dialect"
45#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
46
47#define DEBUG_TYPE_MATCHER "transform-matcher"
48#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
49#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
50
51using namespace mlir;
52
53static ParseResult parseApplyRegisteredPassOptions(
54 OpAsmParser &parser, DictionaryAttr &options,
55 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
56static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
57 Operation *op,
58 DictionaryAttr options,
59 ValueRange dynamicOptions);
60static ParseResult parseSequenceOpOperands(
61 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
62 Type &rootType,
63 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
64 SmallVectorImpl<Type> &extraBindingTypes);
65static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
66 Value root, Type rootType,
67 ValueRange extraBindings,
68 TypeRange extraBindingTypes);
69static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
70 ArrayAttr matchers, ArrayAttr actions);
71static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
72 ArrayAttr &matchers,
73 ArrayAttr &actions);
74
75/// Helper function to check if the given transform op is contained in (or
76/// equal to) the given payload target op. In that case, an error is returned.
77/// Transforming transform IR that is currently executing is generally unsafe.
78static DiagnosedSilenceableFailure
79ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
80 Operation *payload) {
81 Operation *transformAncestor = transform.getOperation();
82 while (transformAncestor) {
83 if (transformAncestor == payload) {
84 DiagnosedDefiniteFailure diag =
85 transform.emitDefiniteFailure()
86 << "cannot apply transform to itself (or one of its ancestors)";
87 diag.attachNote(loc: payload->getLoc()) << "target payload op";
88 return diag;
89 }
90 transformAncestor = transformAncestor->getParentOp();
91 }
92 return DiagnosedSilenceableFailure::success();
93}
94
95#define GET_OP_CLASSES
96#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
97
98//===----------------------------------------------------------------------===//
99// AlternativesOp
100//===----------------------------------------------------------------------===//
101
102OperandRange
103transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
104 if (!point.isParent() && getOperation()->getNumOperands() == 1)
105 return getOperation()->getOperands();
106 return OperandRange(getOperation()->operand_end(),
107 getOperation()->operand_end());
108}
109
110void transform::AlternativesOp::getSuccessorRegions(
111 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
112 for (Region &alternative : llvm::drop_begin(
113 RangeOrContainer: getAlternatives(),
114 N: point.isParent() ? 0
115 : point.getRegionOrNull()->getRegionNumber() + 1)) {
116 regions.emplace_back(Args: &alternative, Args: !getOperands().empty()
117 ? alternative.getArguments()
118 : Block::BlockArgListType());
119 }
120 if (!point.isParent())
121 regions.emplace_back(Args: getOperation()->getResults());
122}
123
124void transform::AlternativesOp::getRegionInvocationBounds(
125 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
126 (void)operands;
127 // The region corresponding to the first alternative is always executed, the
128 // remaining may or may not be executed.
129 bounds.reserve(N: getNumRegions());
130 bounds.emplace_back(Args: 1, Args: 1);
131 bounds.resize(N: getNumRegions(), NV: InvocationBounds(0, 1));
132}
133
134static void forwardEmptyOperands(Block *block, transform::TransformState &state,
135 transform::TransformResults &results) {
136 for (const auto &res : block->getParentOp()->getOpResults())
137 results.set(value: res, ops: {});
138}
139
140DiagnosedSilenceableFailure
141transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
142 transform::TransformResults &results,
143 transform::TransformState &state) {
144 SmallVector<Operation *> originals;
145 if (Value scopeHandle = getScope())
146 llvm::append_range(C&: originals, R: state.getPayloadOps(value: scopeHandle));
147 else
148 originals.push_back(Elt: state.getTopLevel());
149
150 for (Operation *original : originals) {
151 if (original->isAncestor(other: getOperation())) {
152 auto diag = emitDefiniteFailure()
153 << "scope must not contain the transforms being applied";
154 diag.attachNote(loc: original->getLoc()) << "scope";
155 return diag;
156 }
157 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
158 auto diag = emitDefiniteFailure()
159 << "only isolated-from-above ops can be alternative scopes";
160 diag.attachNote(loc: original->getLoc()) << "scope";
161 return diag;
162 }
163 }
164
165 for (Region &reg : getAlternatives()) {
166 // Clone the scope operations and make the transforms in this alternative
167 // region apply to them by virtue of mapping the block argument (the only
168 // visible handle) to the cloned scope operations. This effectively prevents
169 // the transformation from accessing any IR outside the scope.
170 auto scope = state.make_region_scope(region&: reg);
171 auto clones = llvm::to_vector(
172 Range: llvm::map_range(C&: originals, F: [](Operation *op) { return op->clone(); }));
173 auto deleteClones = llvm::make_scope_exit(F: [&] {
174 for (Operation *clone : clones)
175 clone->erase();
176 });
177 if (failed(Result: state.mapBlockArguments(argument: reg.front().getArgument(i: 0), operations: clones)))
178 return DiagnosedSilenceableFailure::definiteFailure();
179
180 bool failed = false;
181 for (Operation &transform : reg.front().without_terminator()) {
182 DiagnosedSilenceableFailure result =
183 state.applyTransform(transform: cast<TransformOpInterface>(Val&: transform));
184 if (result.isSilenceableFailure()) {
185 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
186 << "\n");
187 failed = true;
188 break;
189 }
190
191 if (::mlir::failed(Result: result.silence()))
192 return DiagnosedSilenceableFailure::definiteFailure();
193 }
194
195 // If all operations in the given alternative succeeded, no need to consider
196 // the rest. Replace the original scoping operation with the clone on which
197 // the transformations were performed.
198 if (!failed) {
199 // We will be using the clones, so cancel their scheduled deletion.
200 deleteClones.release();
201 TrackingListener listener(state, *this);
202 IRRewriter rewriter(getContext(), &listener);
203 for (const auto &kvp : llvm::zip(t&: originals, u&: clones)) {
204 Operation *original = std::get<0>(t: kvp);
205 Operation *clone = std::get<1>(t: kvp);
206 original->getBlock()->getOperations().insert(where: original->getIterator(),
207 New: clone);
208 rewriter.replaceOp(op: original, newValues: clone->getResults());
209 }
210 detail::forwardTerminatorOperands(block: &reg.front(), state, results);
211 return DiagnosedSilenceableFailure::success();
212 }
213 }
214 return emitSilenceableError() << "all alternatives failed";
215}
216
217void transform::AlternativesOp::getEffects(
218 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
219 consumesHandle(handles: getOperation()->getOpOperands(), effects);
220 producesHandle(handles: getOperation()->getOpResults(), effects);
221 for (Region *region : getRegions()) {
222 if (!region->empty())
223 producesHandle(handles: region->front().getArguments(), effects);
224 }
225 modifiesPayload(effects);
226}
227
228LogicalResult transform::AlternativesOp::verify() {
229 for (Region &alternative : getAlternatives()) {
230 Block &block = alternative.front();
231 Operation *terminator = block.getTerminator();
232 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
233 InFlightDiagnostic diag = emitOpError()
234 << "expects terminator operands to have the "
235 "same type as results of the operation";
236 diag.attachNote(noteLoc: terminator->getLoc()) << "terminator";
237 return diag;
238 }
239 }
240
241 return success();
242}
243
244//===----------------------------------------------------------------------===//
245// AnnotateOp
246//===----------------------------------------------------------------------===//
247
248DiagnosedSilenceableFailure
249transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
250 transform::TransformResults &results,
251 transform::TransformState &state) {
252 SmallVector<Operation *> targets =
253 llvm::to_vector(Range: state.getPayloadOps(value: getTarget()));
254
255 Attribute attr = UnitAttr::get(context: getContext());
256 if (auto paramH = getParam()) {
257 ArrayRef<Attribute> params = state.getParams(value: paramH);
258 if (params.size() != 1) {
259 if (targets.size() != params.size()) {
260 return emitSilenceableError()
261 << "parameter and target have different payload lengths ("
262 << params.size() << " vs " << targets.size() << ")";
263 }
264 for (auto &&[target, attr] : llvm::zip_equal(t&: targets, u&: params))
265 target->setAttr(name: getName(), value: attr);
266 return DiagnosedSilenceableFailure::success();
267 }
268 attr = params[0];
269 }
270 for (auto *target : targets)
271 target->setAttr(name: getName(), value: attr);
272 return DiagnosedSilenceableFailure::success();
273}
274
275void transform::AnnotateOp::getEffects(
276 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
277 onlyReadsHandle(handles: getTargetMutable(), effects);
278 onlyReadsHandle(handles: getParamMutable(), effects);
279 modifiesPayload(effects);
280}
281
282//===----------------------------------------------------------------------===//
283// ApplyCommonSubexpressionEliminationOp
284//===----------------------------------------------------------------------===//
285
286DiagnosedSilenceableFailure
287transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
288 transform::TransformRewriter &rewriter, Operation *target,
289 ApplyToEachResultList &results, transform::TransformState &state) {
290 // Make sure that this transform is not applied to itself. Modifying the
291 // transform IR while it is being interpreted is generally dangerous.
292 DiagnosedSilenceableFailure payloadCheck =
293 ensurePayloadIsSeparateFromTransform(transform: *this, payload: target);
294 if (!payloadCheck.succeeded())
295 return payloadCheck;
296
297 DominanceInfo domInfo;
298 mlir::eliminateCommonSubExpressions(rewriter, domInfo, op: target);
299 return DiagnosedSilenceableFailure::success();
300}
301
302void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
303 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
304 transform::onlyReadsHandle(handles: getTargetMutable(), effects);
305 transform::modifiesPayload(effects);
306}
307
308//===----------------------------------------------------------------------===//
309// ApplyDeadCodeEliminationOp
310//===----------------------------------------------------------------------===//
311
312DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
313 transform::TransformRewriter &rewriter, Operation *target,
314 ApplyToEachResultList &results, transform::TransformState &state) {
315 // Make sure that this transform is not applied to itself. Modifying the
316 // transform IR while it is being interpreted is generally dangerous.
317 DiagnosedSilenceableFailure payloadCheck =
318 ensurePayloadIsSeparateFromTransform(transform: *this, payload: target);
319 if (!payloadCheck.succeeded())
320 return payloadCheck;
321
322 // Maintain a worklist of potentially dead ops.
323 SetVector<Operation *> worklist;
324
325 // Helper function that adds all defining ops of used values (operands and
326 // operands of nested ops).
327 auto addDefiningOpsToWorklist = [&](Operation *op) {
328 op->walk(callback: [&](Operation *op) {
329 for (Value v : op->getOperands())
330 if (Operation *defOp = v.getDefiningOp())
331 if (target->isProperAncestor(other: defOp))
332 worklist.insert(X: defOp);
333 });
334 };
335
336 // Helper function that erases an op.
337 auto eraseOp = [&](Operation *op) {
338 // Remove op and nested ops from the worklist.
339 op->walk(callback: [&](Operation *op) {
340 const auto *it = llvm::find(Range&: worklist, Val: op);
341 if (it != worklist.end())
342 worklist.erase(I: it);
343 });
344 rewriter.eraseOp(op);
345 };
346
347 // Initial walk over the IR.
348 target->walk<WalkOrder::PostOrder>(callback: [&](Operation *op) {
349 if (op != target && isOpTriviallyDead(op)) {
350 addDefiningOpsToWorklist(op);
351 eraseOp(op);
352 }
353 });
354
355 // Erase all ops that have become dead.
356 while (!worklist.empty()) {
357 Operation *op = worklist.pop_back_val();
358 if (!isOpTriviallyDead(op))
359 continue;
360 addDefiningOpsToWorklist(op);
361 eraseOp(op);
362 }
363
364 return DiagnosedSilenceableFailure::success();
365}
366
367void transform::ApplyDeadCodeEliminationOp::getEffects(
368 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
369 transform::onlyReadsHandle(handles: getTargetMutable(), effects);
370 transform::modifiesPayload(effects);
371}
372
373//===----------------------------------------------------------------------===//
374// ApplyPatternsOp
375//===----------------------------------------------------------------------===//
376
377DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
378 transform::TransformRewriter &rewriter, Operation *target,
379 ApplyToEachResultList &results, transform::TransformState &state) {
380 // Make sure that this transform is not applied to itself. Modifying the
381 // transform IR while it is being interpreted is generally dangerous. Even
382 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
383 // performs many additional simplifications such as dead code elimination.
384 DiagnosedSilenceableFailure payloadCheck =
385 ensurePayloadIsSeparateFromTransform(transform: *this, payload: target);
386 if (!payloadCheck.succeeded())
387 return payloadCheck;
388
389 // Gather all specified patterns.
390 MLIRContext *ctx = target->getContext();
391 RewritePatternSet patterns(ctx);
392 if (!getRegion().empty()) {
393 for (Operation &op : getRegion().front()) {
394 cast<transform::PatternDescriptorOpInterface>(Val: &op)
395 .populatePatternsWithState(patterns, state);
396 }
397 }
398
399 // Configure the GreedyPatternRewriteDriver.
400 GreedyRewriteConfig config;
401 config.setListener(
402 static_cast<RewriterBase::Listener *>(rewriter.getListener()));
403 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
404
405 config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
406 ? GreedyRewriteConfig::kNoLimit
407 : getMaxIterations());
408 config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
409 ? GreedyRewriteConfig::kNoLimit
410 : getMaxNumRewrites());
411
412 // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
413 // was requested, apply the greedy pattern rewrite only once. (The greedy
414 // pattern rewrite driver already iterates to a fixpoint internally.)
415 bool cseChanged = false;
416 // One or two iterations should be sufficient. Stop iterating after a certain
417 // threshold to make debugging easier.
418 static const int64_t kNumMaxIterations = 50;
419 int64_t iteration = 0;
420 do {
421 LogicalResult result = failure();
422 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
423 // Op is isolated from above. Apply patterns and also perform region
424 // simplification.
425 result = applyPatternsGreedily(op: target, patterns: frozenPatterns, config);
426 } else {
427 // Manually gather list of ops because the other
428 // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
429 // from above. This way, patterns can be applied to ops that are not
430 // isolated from above. Regions are not being simplified. Furthermore,
431 // only a single greedy rewrite iteration is performed.
432 SmallVector<Operation *> ops;
433 target->walk(callback: [&](Operation *nestedOp) {
434 if (target != nestedOp)
435 ops.push_back(Elt: nestedOp);
436 });
437 result = applyOpPatternsGreedily(ops, patterns: frozenPatterns, config);
438 }
439
440 // A failure typically indicates that the pattern application did not
441 // converge.
442 if (failed(Result: result)) {
443 return emitSilenceableFailure(op: target)
444 << "greedy pattern application failed";
445 }
446
447 if (getApplyCse()) {
448 DominanceInfo domInfo;
449 mlir::eliminateCommonSubExpressions(rewriter, domInfo, op: target,
450 changed: &cseChanged);
451 }
452 } while (cseChanged && ++iteration < kNumMaxIterations);
453
454 if (iteration == kNumMaxIterations)
455 return emitDefiniteFailure() << "fixpoint iteration did not converge";
456
457 return DiagnosedSilenceableFailure::success();
458}
459
460LogicalResult transform::ApplyPatternsOp::verify() {
461 if (!getRegion().empty()) {
462 for (Operation &op : getRegion().front()) {
463 if (!isa<transform::PatternDescriptorOpInterface>(Val: &op)) {
464 InFlightDiagnostic diag = emitOpError()
465 << "expected children ops to implement "
466 "PatternDescriptorOpInterface";
467 diag.attachNote(noteLoc: op.getLoc()) << "op without interface";
468 return diag;
469 }
470 }
471 }
472 return success();
473}
474
475void transform::ApplyPatternsOp::getEffects(
476 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
477 transform::onlyReadsHandle(handles: getTargetMutable(), effects);
478 transform::modifiesPayload(effects);
479}
480
481void transform::ApplyPatternsOp::build(
482 OpBuilder &builder, OperationState &result, Value target,
483 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
484 result.addOperands(newOperands: target);
485
486 OpBuilder::InsertionGuard g(builder);
487 Region *region = result.addRegion();
488 builder.createBlock(parent: region);
489 if (bodyBuilder)
490 bodyBuilder(builder, result.location);
491}
492
493//===----------------------------------------------------------------------===//
494// ApplyCanonicalizationPatternsOp
495//===----------------------------------------------------------------------===//
496
497void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
498 RewritePatternSet &patterns) {
499 MLIRContext *ctx = patterns.getContext();
500 for (Dialect *dialect : ctx->getLoadedDialects())
501 dialect->getCanonicalizationPatterns(results&: patterns);
502 for (RegisteredOperationName op : ctx->getRegisteredOperations())
503 op.getCanonicalizationPatterns(results&: patterns, context: ctx);
504}
505
506//===----------------------------------------------------------------------===//
507// ApplyConversionPatternsOp
508//===----------------------------------------------------------------------===//
509
510DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
511 transform::TransformRewriter &rewriter,
512 transform::TransformResults &results, transform::TransformState &state) {
513 MLIRContext *ctx = getContext();
514
515 // Instantiate the default type converter if a type converter builder is
516 // specified.
517 std::unique_ptr<TypeConverter> defaultTypeConverter;
518 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
519 getDefaultTypeConverter();
520 if (typeConverterBuilder)
521 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
522
523 // Configure conversion target.
524 ConversionTarget conversionTarget(*getContext());
525 if (getLegalOps())
526 for (Attribute attr : cast<ArrayAttr>(Val: *getLegalOps()))
527 conversionTarget.addLegalOp(
528 op: OperationName(cast<StringAttr>(Val&: attr).getValue(), ctx));
529 if (getIllegalOps())
530 for (Attribute attr : cast<ArrayAttr>(Val: *getIllegalOps()))
531 conversionTarget.addIllegalOp(
532 op: OperationName(cast<StringAttr>(Val&: attr).getValue(), ctx));
533 if (getLegalDialects())
534 for (Attribute attr : cast<ArrayAttr>(Val: *getLegalDialects()))
535 conversionTarget.addLegalDialect(name: cast<StringAttr>(Val&: attr).getValue());
536 if (getIllegalDialects())
537 for (Attribute attr : cast<ArrayAttr>(Val: *getIllegalDialects()))
538 conversionTarget.addIllegalDialect(name: cast<StringAttr>(Val&: attr).getValue());
539
540 // Gather all specified patterns.
541 RewritePatternSet patterns(ctx);
542 // Need to keep the converters alive until after pattern application because
543 // the patterns take a reference to an object that would otherwise get out of
544 // scope.
545 SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
546 if (!getPatterns().empty()) {
547 for (Operation &op : getPatterns().front()) {
548 auto descriptor =
549 cast<transform::ConversionPatternDescriptorOpInterface>(Val: &op);
550
551 // Check if this pattern set specifies a type converter.
552 std::unique_ptr<TypeConverter> typeConverter =
553 descriptor.getTypeConverter();
554 TypeConverter *converter = nullptr;
555 if (typeConverter) {
556 keepAliveConverters.emplace_back(Args: std::move(typeConverter));
557 converter = keepAliveConverters.back().get();
558 } else {
559 // No type converter specified: Use the default type converter.
560 if (!defaultTypeConverter) {
561 auto diag = emitDefiniteFailure()
562 << "pattern descriptor does not specify type "
563 "converter and apply_conversion_patterns op has "
564 "no default type converter";
565 diag.attachNote(loc: op.getLoc()) << "pattern descriptor op";
566 return diag;
567 }
568 converter = defaultTypeConverter.get();
569 }
570
571 // Add descriptor-specific updates to the conversion target, which may
572 // depend on the final type converter. In structural converters, the
573 // legality of types dictates the dynamic legality of an operation.
574 descriptor.populateConversionTargetRules(typeConverter: *converter, conversionTarget);
575
576 descriptor.populatePatterns(typeConverter&: *converter, patterns);
577 }
578 }
579
580 // Attach a tracking listener if handles should be preserved. We configure the
581 // listener to allow op replacements with different names, as conversion
582 // patterns typically replace ops with replacement ops that have a different
583 // name.
584 TrackingListenerConfig trackingConfig;
585 trackingConfig.requireMatchingReplacementOpName = false;
586 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
587 ConversionConfig conversionConfig;
588 if (getPreserveHandles())
589 conversionConfig.listener = &trackingListener;
590
591 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
592 for (Operation *target : state.getPayloadOps(value: getTarget())) {
593 // Make sure that this transform is not applied to itself. Modifying the
594 // transform IR while it is being interpreted is generally dangerous.
595 DiagnosedSilenceableFailure payloadCheck =
596 ensurePayloadIsSeparateFromTransform(transform: *this, payload: target);
597 if (!payloadCheck.succeeded())
598 return payloadCheck;
599
600 LogicalResult status = failure();
601 if (getPartialConversion()) {
602 status = applyPartialConversion(op: target, target: conversionTarget, patterns: frozenPatterns,
603 config: conversionConfig);
604 } else {
605 status = applyFullConversion(op: target, target: conversionTarget, patterns: frozenPatterns,
606 config: conversionConfig);
607 }
608
609 // Check dialect conversion state.
610 DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
611 if (failed(Result: status)) {
612 diag = emitSilenceableError() << "dialect conversion failed";
613 diag.attachNote(loc: target->getLoc()) << "target op";
614 }
615
616 // Check tracking listener error state.
617 DiagnosedSilenceableFailure trackingFailure =
618 trackingListener.checkAndResetError();
619 if (!trackingFailure.succeeded()) {
620 if (diag.succeeded()) {
621 // Tracking failure is the only failure.
622 return trackingFailure;
623 } else {
624 diag.attachNote() << "tracking listener also failed: "
625 << trackingFailure.getMessage();
626 (void)trackingFailure.silence();
627 }
628 }
629
630 if (!diag.succeeded())
631 return diag;
632 }
633
634 return DiagnosedSilenceableFailure::success();
635}
636
637LogicalResult transform::ApplyConversionPatternsOp::verify() {
638 if (getNumRegions() != 1 && getNumRegions() != 2)
639 return emitOpError() << "expected 1 or 2 regions";
640 if (!getPatterns().empty()) {
641 for (Operation &op : getPatterns().front()) {
642 if (!isa<transform::ConversionPatternDescriptorOpInterface>(Val: &op)) {
643 InFlightDiagnostic diag =
644 emitOpError() << "expected pattern children ops to implement "
645 "ConversionPatternDescriptorOpInterface";
646 diag.attachNote(noteLoc: op.getLoc()) << "op without interface";
647 return diag;
648 }
649 }
650 }
651 if (getNumRegions() == 2) {
652 Region &typeConverterRegion = getRegion(i: 1);
653 if (!llvm::hasSingleElement(C&: typeConverterRegion.front()))
654 return emitOpError()
655 << "expected exactly one op in default type converter region";
656 Operation *maybeTypeConverter = &typeConverterRegion.front().front();
657 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
658 Val: maybeTypeConverter);
659 if (!typeConverterOp) {
660 InFlightDiagnostic diag = emitOpError()
661 << "expected default converter child op to "
662 "implement TypeConverterBuilderOpInterface";
663 diag.attachNote(noteLoc: maybeTypeConverter->getLoc()) << "op without interface";
664 return diag;
665 }
666 // Check default type converter type.
667 if (!getPatterns().empty()) {
668 for (Operation &op : getPatterns().front()) {
669 auto descriptor =
670 cast<transform::ConversionPatternDescriptorOpInterface>(Val: &op);
671 if (failed(Result: descriptor.verifyTypeConverter(builder: typeConverterOp)))
672 return failure();
673 }
674 }
675 }
676 return success();
677}
678
679void transform::ApplyConversionPatternsOp::getEffects(
680 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
681 if (!getPreserveHandles()) {
682 transform::consumesHandle(handles: getTargetMutable(), effects);
683 } else {
684 transform::onlyReadsHandle(handles: getTargetMutable(), effects);
685 }
686 transform::modifiesPayload(effects);
687}
688
689void transform::ApplyConversionPatternsOp::build(
690 OpBuilder &builder, OperationState &result, Value target,
691 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
692 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
693 result.addOperands(newOperands: target);
694
695 {
696 OpBuilder::InsertionGuard g(builder);
697 Region *region1 = result.addRegion();
698 builder.createBlock(parent: region1);
699 if (patternsBodyBuilder)
700 patternsBodyBuilder(builder, result.location);
701 }
702 {
703 OpBuilder::InsertionGuard g(builder);
704 Region *region2 = result.addRegion();
705 builder.createBlock(parent: region2);
706 if (typeConverterBodyBuilder)
707 typeConverterBodyBuilder(builder, result.location);
708 }
709}
710
711//===----------------------------------------------------------------------===//
712// ApplyToLLVMConversionPatternsOp
713//===----------------------------------------------------------------------===//
714
715void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
716 TypeConverter &typeConverter, RewritePatternSet &patterns) {
717 Dialect *dialect = getContext()->getLoadedDialect(name: getDialectName());
718 assert(dialect && "expected that dialect is loaded");
719 auto *iface = cast<ConvertToLLVMPatternInterface>(Val: dialect);
720 // ConversionTarget is currently ignored because the enclosing
721 // apply_conversion_patterns op sets up its own ConversionTarget.
722 ConversionTarget target(*getContext());
723 iface->populateConvertToLLVMConversionPatterns(
724 target, typeConverter&: static_cast<LLVMTypeConverter &>(typeConverter), patterns);
725}
726
727LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
728 transform::TypeConverterBuilderOpInterface builder) {
729 if (builder.getTypeConverterType() != "LLVMTypeConverter")
730 return emitOpError(message: "expected LLVMTypeConverter");
731 return success();
732}
733
734LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
735 Dialect *dialect = getContext()->getLoadedDialect(name: getDialectName());
736 if (!dialect)
737 return emitOpError(message: "unknown dialect or dialect not loaded: ")
738 << getDialectName();
739 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
740 if (!iface)
741 return emitOpError(
742 message: "dialect does not implement ConvertToLLVMPatternInterface or "
743 "extension was not loaded: ")
744 << getDialectName();
745 return success();
746}
747
748//===----------------------------------------------------------------------===//
749// ApplyLoopInvariantCodeMotionOp
750//===----------------------------------------------------------------------===//
751
752DiagnosedSilenceableFailure
753transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
754 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
755 transform::ApplyToEachResultList &results,
756 transform::TransformState &state) {
757 // Currently, LICM does not remove operations, so we don't need tracking.
758 // If this ever changes, add a LICM entry point that takes a rewriter.
759 moveLoopInvariantCode(loopLike: target);
760 return DiagnosedSilenceableFailure::success();
761}
762
763void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
764 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
765 transform::onlyReadsHandle(handles: getTargetMutable(), effects);
766 transform::modifiesPayload(effects);
767}
768
769//===----------------------------------------------------------------------===//
770// ApplyRegisteredPassOp
771//===----------------------------------------------------------------------===//
772
773void transform::ApplyRegisteredPassOp::getEffects(
774 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
775 consumesHandle(handles: getTargetMutable(), effects);
776 onlyReadsHandle(handles: getDynamicOptionsMutable(), effects);
777 producesHandle(handles: getOperation()->getOpResults(), effects);
778 modifiesPayload(effects);
779}
780
781DiagnosedSilenceableFailure
782transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
783 transform::TransformResults &results,
784 transform::TransformState &state) {
785 // Obtain a single options-string to pass to the pass(-pipeline) from options
786 // passed in as a dictionary of keys mapping to values which are either
787 // attributes or param-operands pointing to attributes.
788 OperandRange dynamicOptions = getDynamicOptions();
789
790 std::string options;
791 llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
792
793 // A helper to convert an option's attribute value into a corresponding
794 // string representation, with the ability to obtain the attr(s) from a param.
795 std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
796 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(Val&: valueAttr)) {
797 // The corresponding value attribute(s) is/are passed in via a param.
798 // Obtain the param-operand via its specified index.
799 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
800 assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
801 "the number of ParamOperandAttrs in the options DictionaryAttr"
802 "should be the same as the number of options passed as params");
803 ArrayRef<Attribute> attrsAssociatedToParam =
804 state.getParams(value: dynamicOptions[dynamicOptionIdx]);
805 // Recursive so as to append all attrs associated to the param.
806 llvm::interleave(c: attrsAssociatedToParam, os&: optionsStream, each_fn: appendValueAttr,
807 separator: ",");
808 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(Val&: valueAttr)) {
809 // Recursive so as to append all nested attrs of the array.
810 llvm::interleave(c: arrayAttr, os&: optionsStream, each_fn: appendValueAttr, separator: ",");
811 } else if (auto strAttr = dyn_cast<StringAttr>(Val&: valueAttr)) {
812 // Convert to unquoted string.
813 optionsStream << strAttr.getValue().str();
814 } else {
815 // For all other attributes, ask the attr to print itself (without type).
816 valueAttr.print(os&: optionsStream, /*elideType=*/true);
817 }
818 };
819
820 // Convert the options DictionaryAttr into a single string.
821 llvm::interleave(
822 c: getOptions(), os&: optionsStream,
823 each_fn: [&](auto namedAttribute) {
824 optionsStream << namedAttribute.getName().str(); // Append the key.
825 optionsStream << "="; // And the key-value separator.
826 appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
827 },
828 separator: " ");
829 optionsStream.flush();
830
831 // Get pass or pass pipeline from registry.
832 const PassRegistryEntry *info = PassPipelineInfo::lookup(pipelineArg: getPassName());
833 if (!info)
834 info = PassInfo::lookup(passArg: getPassName());
835 if (!info)
836 return emitDefiniteFailure()
837 << "unknown pass or pass pipeline: " << getPassName();
838
839 // Create pass manager and add the pass or pass pipeline.
840 PassManager pm(getContext());
841 if (failed(Result: info->addToPipeline(pm, options, errorHandler: [&](const Twine &msg) {
842 emitError(message: msg);
843 return failure();
844 }))) {
845 return emitDefiniteFailure()
846 << "failed to add pass or pass pipeline to pipeline: "
847 << getPassName();
848 }
849
850 auto targets = SmallVector<Operation *>(state.getPayloadOps(value: getTarget()));
851 for (Operation *target : targets) {
852 // Make sure that this transform is not applied to itself. Modifying the
853 // transform IR while it is being interpreted is generally dangerous. Even
854 // more so when applying passes because they may perform a wide range of IR
855 // modifications.
856 DiagnosedSilenceableFailure payloadCheck =
857 ensurePayloadIsSeparateFromTransform(transform: *this, payload: target);
858 if (!payloadCheck.succeeded())
859 return payloadCheck;
860
861 // Run the pass or pass pipeline on the current target operation.
862 if (failed(Result: pm.run(op: target))) {
863 auto diag = emitSilenceableError() << "pass pipeline failed";
864 diag.attachNote(loc: target->getLoc()) << "target op";
865 return diag;
866 }
867 }
868
869 // The applied pass will have directly modified the payload IR(s).
870 results.set(value: llvm::cast<OpResult>(Val: getResult()), ops&: targets);
871 return DiagnosedSilenceableFailure::success();
872}
873
874static ParseResult parseApplyRegisteredPassOptions(
875 OpAsmParser &parser, DictionaryAttr &options,
876 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
877 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
878 SmallVector<NamedAttribute> keyValuePairs;
879 size_t dynamicOptionsIdx = 0;
880
881 // Helper for allowing parsing of option values which can be of the form:
882 // - a normal attribute
883 // - an operand (which would be converted to an attr referring to the operand)
884 // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
885 std::function<ParseResult(Attribute &)> parseValue =
886 [&](Attribute &valueAttr) -> ParseResult {
887 // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
888 if (succeeded(Result: parser.parseOptionalLSquare())) {
889 SmallVector<Attribute> attrs;
890
891 // Recursively parse the array's elements, which might be operands.
892 if (parser.parseCommaSeparatedList(
893 delimiter: AsmParser::Delimiter::None,
894 parseElementFn: [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
895 contextMessage: " in options dictionary") ||
896 parser.parseRSquare())
897 return failure(); // NB: Attempted parse should've output error message.
898
899 valueAttr = ArrayAttr::get(context: parser.getContext(), value: attrs);
900
901 return success();
902 }
903
904 // Parse the value, which can be either an attribute or an operand.
905 OptionalParseResult parsedValueAttr =
906 parser.parseOptionalAttribute(result&: valueAttr);
907 if (!parsedValueAttr.has_value()) {
908 OpAsmParser::UnresolvedOperand operand;
909 ParseResult parsedOperand = parser.parseOperand(result&: operand);
910 if (failed(Result: parsedOperand))
911 return failure(); // NB: Attempted parse should've output error message.
912 // To make use of the operand, we need to store it in the options dict.
913 // As SSA-values cannot occur in attributes, what we do instead is store
914 // an attribute in its place that contains the index of the param-operand,
915 // so that an attr-value associated to the param can be resolved later on.
916 dynamicOptions.push_back(Elt: operand);
917 auto wrappedIndex = IntegerAttr::get(
918 type: IntegerType::get(context: parser.getContext(), width: 64), value: dynamicOptionsIdx++);
919 valueAttr =
920 transform::ParamOperandAttr::get(context: parser.getContext(), index: wrappedIndex);
921 } else if (failed(Result: parsedValueAttr.value())) {
922 return failure(); // NB: Attempted parse should have output error message.
923 } else if (isa<transform::ParamOperandAttr>(Val: valueAttr)) {
924 return parser.emitError(loc: parser.getCurrentLocation())
925 << "the param_operand attribute is a marker reserved for "
926 << "indicating a value will be passed via params and is only used "
927 << "in the generic print format";
928 }
929
930 return success();
931 };
932
933 // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
934 // string and `value` looks like either an attribute or an operand-in-an-attr.
935 std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
936 std::string key;
937 Attribute valueAttr;
938
939 if (failed(Result: parser.parseOptionalKeywordOrString(result: &key)) || key.empty())
940 return parser.emitError(loc: parser.getCurrentLocation())
941 << "expected key to either be an identifier or a string";
942
943 if (failed(Result: parser.parseEqual()))
944 return parser.emitError(loc: parser.getCurrentLocation())
945 << "expected '=' after key in key-value pair";
946
947 if (failed(Result: parseValue(valueAttr)))
948 return parser.emitError(loc: parser.getCurrentLocation())
949 << "expected a valid attribute or operand as value associated "
950 << "to key '" << key << "'";
951
952 keyValuePairs.push_back(Elt: NamedAttribute(key, valueAttr));
953
954 return success();
955 };
956
957 if (parser.parseCommaSeparatedList(delimiter: AsmParser::Delimiter::Braces,
958 parseElementFn: parseKeyValuePair,
959 contextMessage: " in options dictionary"))
960 return failure(); // NB: Attempted parse should have output error message.
961
962 if (DictionaryAttr::findDuplicate(
963 array&: keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
964 .has_value())
965 return parser.emitError(loc: parser.getCurrentLocation())
966 << "duplicate keys found in options dictionary";
967
968 options = DictionaryAttr::getWithSorted(context: parser.getContext(), value: keyValuePairs);
969
970 return success();
971}
972
973static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
974 Operation *op,
975 DictionaryAttr options,
976 ValueRange dynamicOptions) {
977 if (options.empty())
978 return;
979
980 std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
981 if (auto paramOperandAttr =
982 dyn_cast<transform::ParamOperandAttr>(Val&: valueAttr)) {
983 // Resolve index of param-operand to its actual SSA-value and print that.
984 printer.printOperand(
985 value: dynamicOptions[paramOperandAttr.getIndex().getInt()]);
986 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(Val&: valueAttr)) {
987 // This case is so that ArrayAttr-contained operands are pretty-printed.
988 printer << "[";
989 llvm::interleaveComma(c: arrayAttr, os&: printer, each_fn: printOptionValue);
990 printer << "]";
991 } else {
992 printer.printAttribute(attr: valueAttr);
993 }
994 };
995
996 printer << "{";
997 llvm::interleaveComma(c: options, os&: printer, each_fn: [&](NamedAttribute namedAttribute) {
998 printer << namedAttribute.getName();
999 printer << " = ";
1000 printOptionValue(namedAttribute.getValue());
1001 });
1002 printer << "}";
1003}
1004
1005LogicalResult transform::ApplyRegisteredPassOp::verify() {
1006 // Check that there is a one-to-one correspondence between param operands
1007 // and references to dynamic options in the options dictionary.
1008
1009 auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1010
1011 // Helper for option values to mark seen operands as having been seen (once).
1012 std::function<LogicalResult(Attribute)> checkOptionValue =
1013 [&](Attribute valueAttr) -> LogicalResult {
1014 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(Val&: valueAttr)) {
1015 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1016 if (dynamicOptionIdx < 0 ||
1017 dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1018 return emitOpError()
1019 << "dynamic option index " << dynamicOptionIdx
1020 << " is out of bounds for the number of dynamic options: "
1021 << dynamicOptions.size();
1022 if (dynamicOptions[dynamicOptionIdx] == nullptr)
1023 return emitOpError() << "dynamic option index " << dynamicOptionIdx
1024 << " is already used in options";
1025 dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1026 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(Val&: valueAttr)) {
1027 // Recurse into ArrayAttrs as they may contain references to operands.
1028 for (auto eltAttr : arrayAttr)
1029 if (failed(Result: checkOptionValue(eltAttr)))
1030 return failure();
1031 }
1032 return success();
1033 };
1034
1035 for (NamedAttribute namedAttr : getOptions())
1036 if (failed(Result: checkOptionValue(namedAttr.getValue())))
1037 return failure();
1038
1039 // All dynamicOptions-params seen in the dict will have been set to null.
1040 for (Value dynamicOption : dynamicOptions)
1041 if (dynamicOption)
1042 return emitOpError() << "a param operand does not have a corresponding "
1043 << "param_operand attr in the options dict";
1044
1045 return success();
1046}
1047
1048//===----------------------------------------------------------------------===//
1049// CastOp
1050//===----------------------------------------------------------------------===//
1051
1052DiagnosedSilenceableFailure
1053transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1054 Operation *target, ApplyToEachResultList &results,
1055 transform::TransformState &state) {
1056 results.push_back(op: target);
1057 return DiagnosedSilenceableFailure::success();
1058}
1059
1060void transform::CastOp::getEffects(
1061 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1062 onlyReadsPayload(effects);
1063 onlyReadsHandle(handles: getInputMutable(), effects);
1064 producesHandle(handles: getOperation()->getOpResults(), effects);
1065}
1066
1067bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1068 assert(inputs.size() == 1 && "expected one input");
1069 assert(outputs.size() == 1 && "expected one output");
1070 return llvm::all_of(
1071 Range: std::initializer_list<Type>{inputs.front(), outputs.front()},
1072 P: llvm::IsaPred<transform::TransformHandleTypeInterface>);
1073}
1074
//===----------------------------------------------------------------------===//
// Matcher helpers shared by CollectMatchingOp and ForeachMatchOp
//===----------------------------------------------------------------------===//
1078
1079/// Applies matcher operations from the given `block` using
1080/// `blockArgumentMapping` to initialize block arguments. Updates `state`
1081/// accordingly. If any of the matcher produces a silenceable failure, discards
1082/// it (printing the content to the debug output stream) and returns failure. If
1083/// any of the matchers produces a definite failure, reports it and returns
1084/// failure. If all matchers in the block succeed, populates `mappings` with the
1085/// payload entities associated with the block terminator operands. Note that
1086/// `mappings` will be cleared before that.
1087static DiagnosedSilenceableFailure
1088matchBlock(Block &block,
1089 ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1090 transform::TransformState &state,
1091 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
1092 assert(block.getParent() && "cannot match using a detached block");
1093 auto matchScope = state.make_region_scope(region&: *block.getParent());
1094 if (failed(
1095 Result: state.mapBlockArguments(arguments: block.getArguments(), mapping: blockArgumentMapping)))
1096 return DiagnosedSilenceableFailure::definiteFailure();
1097
1098 for (Operation &match : block.without_terminator()) {
1099 if (!isa<transform::MatchOpInterface>(Val: match)) {
1100 return emitDefiniteFailure(loc: match.getLoc())
1101 << "expected operations in the match part to "
1102 "implement MatchOpInterface";
1103 }
1104 DiagnosedSilenceableFailure diag =
1105 state.applyTransform(transform: cast<transform::TransformOpInterface>(Val&: match));
1106 if (diag.succeeded())
1107 continue;
1108
1109 return diag;
1110 }
1111
1112 // Remember the values mapped to the terminator operands so we can
1113 // forward them to the action.
1114 ValueRange yieldedValues = block.getTerminator()->getOperands();
1115 // Our contract with the caller is that the mappings will contain only the
1116 // newly mapped values, clear the rest.
1117 mappings.clear();
1118 transform::detail::prepareValueMappings(mappings, values: yieldedValues, state);
1119 return DiagnosedSilenceableFailure::success();
1120}
1121
1122/// Returns `true` if both types implement one of the interfaces provided as
1123/// template parameters.
1124template <typename... Tys>
1125static bool implementSameInterface(Type t1, Type t2) {
1126 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1127}
1128
1129/// Returns `true` if both types implement one of the transform dialect
1130/// interfaces.
1131static bool implementSameTransformInterface(Type t1, Type t2) {
1132 return implementSameInterface<transform::TransformHandleTypeInterface,
1133 transform::TransformParamTypeInterface,
1134 transform::TransformValueHandleTypeInterface>(
1135 t1, t2);
1136}
1137
1138//===----------------------------------------------------------------------===//
1139// CollectMatchingOp
1140//===----------------------------------------------------------------------===//
1141
1142DiagnosedSilenceableFailure
1143transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1144 transform::TransformResults &results,
1145 transform::TransformState &state) {
1146 auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1147 from: getOperation(), symbol: getMatcher());
1148 if (matcher.isExternal()) {
1149 return emitDefiniteFailure()
1150 << "unresolved external symbol " << getMatcher();
1151 }
1152
1153 SmallVector<SmallVector<MappedValue>, 2> rawResults;
1154 rawResults.resize(N: getOperation()->getNumResults());
1155 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1156 for (Operation *root : state.getPayloadOps(value: getRoot())) {
1157 WalkResult walkResult = root->walk(callback: [&](Operation *op) {
1158 DEBUG_MATCHER({
1159 DBGS_MATCHER() << "matching ";
1160 op->print(llvm::dbgs(),
1161 OpPrintingFlags().assumeVerified().skipRegions());
1162 llvm::dbgs() << " @" << op << "\n";
1163 });
1164
1165 // Try matching.
1166 SmallVector<SmallVector<MappedValue>> mappings;
1167 SmallVector<transform::MappedValue> inputMapping({op});
1168 DiagnosedSilenceableFailure diag = matchBlock(
1169 block&: matcher.getFunctionBody().front(),
1170 blockArgumentMapping: ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1171 mappings);
1172 if (diag.isDefiniteFailure())
1173 return WalkResult::interrupt();
1174 if (diag.isSilenceableFailure()) {
1175 DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1176 << " failed: " << diag.getMessage());
1177 return WalkResult::advance();
1178 }
1179
1180 // If succeeded, collect results.
1181 for (auto &&[i, mapping] : llvm::enumerate(First&: mappings)) {
1182 if (mapping.size() != 1) {
1183 maybeFailure.emplace(args: emitSilenceableError()
1184 << "result #" << i << ", associated with "
1185 << mapping.size()
1186 << " payload objects, expected 1");
1187 return WalkResult::interrupt();
1188 }
1189 rawResults[i].push_back(Elt: mapping[0]);
1190 }
1191 return WalkResult::advance();
1192 });
1193 if (walkResult.wasInterrupted())
1194 return std::move(*maybeFailure);
1195 assert(!maybeFailure && "failure set but the walk was not interrupted");
1196
1197 for (auto &&[opResult, rawResult] :
1198 llvm::zip_equal(t: getOperation()->getResults(), u&: rawResults)) {
1199 results.setMappedValues(handle: opResult, values: rawResult);
1200 }
1201 }
1202 return DiagnosedSilenceableFailure::success();
1203}
1204
1205void transform::CollectMatchingOp::getEffects(
1206 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1207 onlyReadsHandle(handles: getRootMutable(), effects);
1208 producesHandle(handles: getOperation()->getOpResults(), effects);
1209 onlyReadsPayload(effects);
1210}
1211
1212LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1213 SymbolTableCollection &symbolTable) {
1214 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1215 Val: symbolTable.lookupNearestSymbolFrom(from: getOperation(), symbol: getMatcher()));
1216 if (!matcherSymbol ||
1217 !isa<TransformOpInterface>(Val: matcherSymbol.getOperation()))
1218 return emitError() << "unresolved matcher symbol " << getMatcher();
1219
1220 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1221 if (argumentTypes.size() != 1 ||
1222 !isa<TransformHandleTypeInterface>(Val: argumentTypes[0])) {
1223 return emitError()
1224 << "expected the matcher to take one operation handle argument";
1225 }
1226 if (!matcherSymbol.getArgAttr(
1227 index: 0, name: transform::TransformDialect::kArgReadOnlyAttrName)) {
1228 return emitError() << "expected the matcher argument to be marked readonly";
1229 }
1230
1231 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1232 if (resultTypes.size() != getOperation()->getNumResults()) {
1233 return emitError()
1234 << "expected the matcher to yield as many values as op has results ("
1235 << getOperation()->getNumResults() << "), got "
1236 << resultTypes.size();
1237 }
1238
1239 for (auto &&[i, matcherType, resultType] :
1240 llvm::enumerate(First&: resultTypes, Rest: getOperation()->getResultTypes())) {
1241 if (implementSameTransformInterface(t1: matcherType, t2: resultType))
1242 continue;
1243
1244 return emitError()
1245 << "mismatching type interfaces for matcher result and op result #"
1246 << i;
1247 }
1248
1249 return success();
1250}
1251
1252//===----------------------------------------------------------------------===//
1253// ForeachMatchOp
1254//===----------------------------------------------------------------------===//
1255
1256// This is fine because nothing is actually consumed by this op.
1257bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1258
1259DiagnosedSilenceableFailure
1260transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1261 transform::TransformResults &results,
1262 transform::TransformState &state) {
1263 SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
1264 matchActionPairs;
1265 matchActionPairs.reserve(N: getMatchers().size());
1266 SymbolTableCollection symbolTable;
1267 for (auto &&[matcher, action] :
1268 llvm::zip_equal(t: getMatchers(), u: getActions())) {
1269 auto matcherSymbol =
1270 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1271 from: getOperation(), symbol: cast<SymbolRefAttr>(Val: matcher));
1272 auto actionSymbol =
1273 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1274 from: getOperation(), symbol: cast<SymbolRefAttr>(Val: action));
1275 assert(matcherSymbol && actionSymbol &&
1276 "unresolved symbols not caught by the verifier");
1277
1278 if (matcherSymbol.isExternal())
1279 return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1280 if (actionSymbol.isExternal())
1281 return emitDefiniteFailure() << "unresolved external symbol " << action;
1282
1283 matchActionPairs.emplace_back(Args&: matcherSymbol, Args&: actionSymbol);
1284 }
1285
1286 DiagnosedSilenceableFailure overallDiag =
1287 DiagnosedSilenceableFailure::success();
1288
1289 SmallVector<SmallVector<MappedValue>> matchInputMapping;
1290 SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1291 SmallVector<SmallVector<MappedValue>> actionResultMapping;
1292 // Explicitly add the mapping for the first block argument (the op being
1293 // matched).
1294 matchInputMapping.emplace_back();
1295 transform::detail::prepareValueMappings(mappings&: matchInputMapping,
1296 values: getForwardedInputs(), state);
1297 SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1298 actionResultMapping.resize(N: getForwardedOutputs().size());
1299
1300 for (Operation *root : state.getPayloadOps(value: getRoot())) {
1301 WalkResult walkResult = root->walk(callback: [&](Operation *op) {
1302 // If getRestrictRoot is not present, skip over the root op itself so we
1303 // don't invalidate it.
1304 if (!getRestrictRoot() && op == root)
1305 return WalkResult::advance();
1306
1307 DEBUG_MATCHER({
1308 DBGS_MATCHER() << "matching ";
1309 op->print(llvm::dbgs(),
1310 OpPrintingFlags().assumeVerified().skipRegions());
1311 llvm::dbgs() << " @" << op << "\n";
1312 });
1313
1314 firstMatchArgument.clear();
1315 firstMatchArgument.push_back(Elt: op);
1316
1317 // Try all the match/action pairs until the first successful match.
1318 for (auto [matcher, action] : matchActionPairs) {
1319 DiagnosedSilenceableFailure diag =
1320 matchBlock(block&: matcher.getFunctionBody().front(), blockArgumentMapping: matchInputMapping,
1321 state, mappings&: matchOutputMapping);
1322 if (diag.isDefiniteFailure())
1323 return WalkResult::interrupt();
1324 if (diag.isSilenceableFailure()) {
1325 DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1326 << " failed: " << diag.getMessage());
1327 continue;
1328 }
1329
1330 auto scope = state.make_region_scope(region&: action.getFunctionBody());
1331 if (failed(Result: state.mapBlockArguments(
1332 arguments: action.getFunctionBody().front().getArguments(),
1333 mapping: matchOutputMapping))) {
1334 return WalkResult::interrupt();
1335 }
1336
1337 for (Operation &transform :
1338 action.getFunctionBody().front().without_terminator()) {
1339 DiagnosedSilenceableFailure result =
1340 state.applyTransform(transform: cast<TransformOpInterface>(Val&: transform));
1341 if (result.isDefiniteFailure())
1342 return WalkResult::interrupt();
1343 if (result.isSilenceableFailure()) {
1344 if (overallDiag.succeeded()) {
1345 overallDiag = emitSilenceableError() << "actions failed";
1346 }
1347 overallDiag.attachNote(loc: action->getLoc())
1348 << "failed action: " << result.getMessage();
1349 overallDiag.attachNote(loc: op->getLoc())
1350 << "when applied to this matching payload";
1351 (void)result.silence();
1352 continue;
1353 }
1354 }
1355 if (failed(Result: detail::appendValueMappings(
1356 mappings: MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1357 values: action.getFunctionBody().front().getTerminator()->getOperands(),
1358 state, flatten: getFlattenResults()))) {
1359 emitDefiniteFailure()
1360 << "action @" << action.getName()
1361 << " has results associated with multiple payload entities, "
1362 "but flattening was not requested";
1363 return WalkResult::interrupt();
1364 }
1365 break;
1366 }
1367 return WalkResult::advance();
1368 });
1369 if (walkResult.wasInterrupted())
1370 return DiagnosedSilenceableFailure::definiteFailure();
1371 }
1372
1373 // The root operation should not have been affected, so we can just reassign
1374 // the payload to the result. Note that we need to consume the root handle to
1375 // make sure any handles to operations inside, that could have been affected
1376 // by actions, are invalidated.
1377 results.set(value: llvm::cast<OpResult>(Val: getUpdated()),
1378 ops: state.getPayloadOps(value: getRoot()));
1379 for (auto &&[result, mapping] :
1380 llvm::zip_equal(t: getForwardedOutputs(), u&: actionResultMapping)) {
1381 results.setMappedValues(handle: result, values: mapping);
1382 }
1383 return overallDiag;
1384}
1385
1386void transform::ForeachMatchOp::getAsmResultNames(
1387 OpAsmSetValueNameFn setNameFn) {
1388 setNameFn(getUpdated(), "updated_root");
1389 for (Value v : getForwardedOutputs()) {
1390 setNameFn(v, "yielded");
1391 }
1392}
1393
1394void transform::ForeachMatchOp::getEffects(
1395 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1396 // Bail if invalid.
1397 if (getOperation()->getNumOperands() < 1 ||
1398 getOperation()->getNumResults() < 1) {
1399 return modifiesPayload(effects);
1400 }
1401
1402 consumesHandle(handles: getRootMutable(), effects);
1403 onlyReadsHandle(handles: getForwardedInputsMutable(), effects);
1404 producesHandle(handles: getOperation()->getOpResults(), effects);
1405 modifiesPayload(effects);
1406}
1407
1408/// Parses the comma-separated list of symbol reference pairs of the format
1409/// `@matcher -> @action`.
1410static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1411 ArrayAttr &matchers,
1412 ArrayAttr &actions) {
1413 StringAttr matcher;
1414 StringAttr action;
1415 SmallVector<Attribute> matcherList;
1416 SmallVector<Attribute> actionList;
1417 do {
1418 if (parser.parseSymbolName(result&: matcher) || parser.parseArrow() ||
1419 parser.parseSymbolName(result&: action)) {
1420 return failure();
1421 }
1422 matcherList.push_back(Elt: SymbolRefAttr::get(value: matcher));
1423 actionList.push_back(Elt: SymbolRefAttr::get(value: action));
1424 } while (parser.parseOptionalComma().succeeded());
1425
1426 matchers = parser.getBuilder().getArrayAttr(value: matcherList);
1427 actions = parser.getBuilder().getArrayAttr(value: actionList);
1428 return success();
1429}
1430
1431/// Prints the comma-separated list of symbol reference pairs of the format
1432/// `@matcher -> @action`.
1433static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
1434 ArrayAttr matchers, ArrayAttr actions) {
1435 printer.increaseIndent();
1436 printer.increaseIndent();
1437 for (auto &&[matcher, action, idx] : llvm::zip_equal(
1438 t&: matchers, u&: actions, args: llvm::seq<unsigned>(Begin: 0, End: matchers.size()))) {
1439 printer.printNewline();
1440 printer << cast<SymbolRefAttr>(Val: matcher) << " -> "
1441 << cast<SymbolRefAttr>(Val: action);
1442 if (idx != matchers.size() - 1)
1443 printer << ", ";
1444 }
1445 printer.decreaseIndent();
1446 printer.decreaseIndent();
1447}
1448
1449LogicalResult transform::ForeachMatchOp::verify() {
1450 if (getMatchers().size() != getActions().size())
1451 return emitOpError() << "expected the same number of matchers and actions";
1452 if (getMatchers().empty())
1453 return emitOpError() << "expected at least one match/action pair";
1454
1455 llvm::SmallPtrSet<Attribute, 8> matcherNames;
1456 for (Attribute name : getMatchers()) {
1457 if (matcherNames.insert(Ptr: name).second)
1458 continue;
1459 emitWarning() << "matcher " << name
1460 << " is used more than once, only the first match will apply";
1461 }
1462
1463 return success();
1464}
1465
1466/// Checks that the attributes of the function-like operation have correct
1467/// consumption effect annotations. If `alsoVerifyInternal`, checks for
1468/// annotations being present even if they can be inferred from the body.
1469static DiagnosedSilenceableFailure
1470verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1471 bool alsoVerifyInternal = false) {
1472 auto transformOp = cast<transform::TransformOpInterface>(Val: op.getOperation());
1473 llvm::SmallDenseSet<unsigned> consumedArguments;
1474 if (!op.isExternal()) {
1475 transform::getConsumedBlockArguments(block&: op.getFunctionBody().front(),
1476 consumedArguments);
1477 }
1478 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1479 bool isConsumed =
1480 op.getArgAttr(index: i, name: transform::TransformDialect::kArgConsumedAttrName) !=
1481 nullptr;
1482 bool isReadOnly =
1483 op.getArgAttr(index: i, name: transform::TransformDialect::kArgReadOnlyAttrName) !=
1484 nullptr;
1485 if (isConsumed && isReadOnly) {
1486 return transformOp.emitSilenceableError()
1487 << "argument #" << i << " cannot be both readonly and consumed";
1488 }
1489 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1490 return transformOp.emitSilenceableError()
1491 << "must provide consumed/readonly status for arguments of "
1492 "external or called ops";
1493 }
1494 if (op.isExternal())
1495 continue;
1496
1497 if (consumedArguments.contains(V: i) && !isConsumed && isReadOnly) {
1498 return transformOp.emitSilenceableError()
1499 << "argument #" << i
1500 << " is consumed in the body but is not marked as such";
1501 }
1502 if (emitWarnings && !consumedArguments.contains(V: i) && isConsumed) {
1503 // Cannot use op.emitWarning() here as it would attempt to verify the op
1504 // before printing, resulting in infinite recursion.
1505 emitWarning(loc: op->getLoc())
1506 << "op argument #" << i
1507 << " is not consumed in the body but is marked as consumed";
1508 }
1509 }
1510 return DiagnosedSilenceableFailure::success();
1511}
1512
1513LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1514 SymbolTableCollection &symbolTable) {
1515 assert(getMatchers().size() == getActions().size());
1516 auto consumedAttr =
1517 StringAttr::get(context: getContext(), bytes: TransformDialect::kArgConsumedAttrName);
1518 for (auto &&[matcher, action] :
1519 llvm::zip_equal(t: getMatchers(), u: getActions())) {
1520 // Presence and typing.
1521 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1522 Val: symbolTable.lookupNearestSymbolFrom(from: getOperation(),
1523 symbol: cast<SymbolRefAttr>(Val: matcher)));
1524 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1525 Val: symbolTable.lookupNearestSymbolFrom(from: getOperation(),
1526 symbol: cast<SymbolRefAttr>(Val: action)));
1527 if (!matcherSymbol ||
1528 !isa<TransformOpInterface>(Val: matcherSymbol.getOperation()))
1529 return emitError() << "unresolved matcher symbol " << matcher;
1530 if (!actionSymbol ||
1531 !isa<TransformOpInterface>(Val: actionSymbol.getOperation()))
1532 return emitError() << "unresolved action symbol " << action;
1533
1534 if (failed(Result: verifyFunctionLikeConsumeAnnotations(op: matcherSymbol,
1535 /*emitWarnings=*/false,
1536 /*alsoVerifyInternal=*/true)
1537 .checkAndReport())) {
1538 return failure();
1539 }
1540 if (failed(Result: verifyFunctionLikeConsumeAnnotations(op: actionSymbol,
1541 /*emitWarnings=*/false,
1542 /*alsoVerifyInternal=*/true)
1543 .checkAndReport())) {
1544 return failure();
1545 }
1546
1547 // Input -> matcher forwarding.
1548 TypeRange operandTypes = getOperandTypes();
1549 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1550 if (operandTypes.size() != matcherArguments.size()) {
1551 InFlightDiagnostic diag =
1552 emitError() << "the number of operands (" << operandTypes.size()
1553 << ") doesn't match the number of matcher arguments ("
1554 << matcherArguments.size() << ") for " << matcher;
1555 diag.attachNote(noteLoc: matcherSymbol->getLoc()) << "symbol declaration";
1556 return diag;
1557 }
1558 for (auto &&[i, operand, argument] :
1559 llvm::enumerate(First&: operandTypes, Rest&: matcherArguments)) {
1560 if (matcherSymbol.getArgAttr(index: i, name: consumedAttr)) {
1561 InFlightDiagnostic diag =
1562 emitOpError()
1563 << "does not expect matcher symbol to consume its operand #" << i;
1564 diag.attachNote(noteLoc: matcherSymbol->getLoc()) << "symbol declaration";
1565 return diag;
1566 }
1567
1568 if (implementSameTransformInterface(t1: operand, t2: argument))
1569 continue;
1570
1571 InFlightDiagnostic diag =
1572 emitError()
1573 << "mismatching type interfaces for operand and matcher argument #"
1574 << i << " of matcher " << matcher;
1575 diag.attachNote(noteLoc: matcherSymbol->getLoc()) << "symbol declaration";
1576 return diag;
1577 }
1578
1579 // Matcher -> action forwarding.
1580 TypeRange matcherResults = matcherSymbol.getResultTypes();
1581 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1582 if (matcherResults.size() != actionArguments.size()) {
1583 return emitError() << "mismatching number of matcher results and "
1584 "action arguments between "
1585 << matcher << " (" << matcherResults.size() << ") and "
1586 << action << " (" << actionArguments.size() << ")";
1587 }
1588 for (auto &&[i, matcherType, actionType] :
1589 llvm::enumerate(First&: matcherResults, Rest&: actionArguments)) {
1590 if (implementSameTransformInterface(t1: matcherType, t2: actionType))
1591 continue;
1592
1593 return emitError() << "mismatching type interfaces for matcher result "
1594 "and action argument #"
1595 << i << "of matcher " << matcher << " and action "
1596 << action;
1597 }
1598
1599 // Action -> result forwarding.
1600 TypeRange actionResults = actionSymbol.getResultTypes();
1601 auto resultTypes = TypeRange(getResultTypes()).drop_front();
1602 if (actionResults.size() != resultTypes.size()) {
1603 InFlightDiagnostic diag =
1604 emitError() << "the number of action results ("
1605 << actionResults.size() << ") for " << action
1606 << " doesn't match the number of extra op results ("
1607 << resultTypes.size() << ")";
1608 diag.attachNote(noteLoc: actionSymbol->getLoc()) << "symbol declaration";
1609 return diag;
1610 }
1611 for (auto &&[i, resultType, actionType] :
1612 llvm::enumerate(First&: resultTypes, Rest&: actionResults)) {
1613 if (implementSameTransformInterface(t1: resultType, t2: actionType))
1614 continue;
1615
1616 InFlightDiagnostic diag =
1617 emitError() << "mismatching type interfaces for action result #" << i
1618 << " of action " << action << " and op result";
1619 diag.attachNote(noteLoc: actionSymbol->getLoc()) << "symbol declaration";
1620 return diag;
1621 }
1622 }
1623 return success();
1624}
1625
1626//===----------------------------------------------------------------------===//
1627// ForeachOp
1628//===----------------------------------------------------------------------===//
1629
1630DiagnosedSilenceableFailure
1631transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1632 transform::TransformResults &results,
1633 transform::TransformState &state) {
1634 // We store the payloads before executing the body as ops may be removed from
1635 // the mapping by the TrackingRewriter while iteration is in progress.
1636 SmallVector<SmallVector<MappedValue>> payloads;
1637 detail::prepareValueMappings(mappings&: payloads, values: getTargets(), state);
1638 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1639 bool withZipShortest = getWithZipShortest();
1640
1641 // In case of `zip_shortest`, set the number of iterations to the
1642 // smallest payload in the targets.
1643 if (withZipShortest) {
1644 numIterations =
1645 llvm::min_element(Range&: payloads, C: [&](const SmallVector<MappedValue> &A,
1646 const SmallVector<MappedValue> &B) {
1647 return A.size() < B.size();
1648 })->size();
1649
1650 for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1651 payloads[argIdx].resize(N: numIterations);
1652 }
1653
1654 // As we will be "zipping" over them, check all payloads have the same size.
1655 // `zip_shortest` adjusts all payloads to the same size, so skip this check
1656 // when true.
1657 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1658 argIdx++) {
1659 if (payloads[argIdx].size() != numIterations) {
1660 return emitSilenceableError()
1661 << "prior targets' payload size (" << numIterations
1662 << ") differs from payload size (" << payloads[argIdx].size()
1663 << ") of target " << getTargets()[argIdx];
1664 }
1665 }
1666
1667 // Start iterating, indexing into payloads to obtain the right arguments to
1668 // call the body with - each slice of payloads at the same argument index
1669 // corresponding to a tuple to use as the body's block arguments.
1670 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1671 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1672 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1673 auto scope = state.make_region_scope(region&: getBody());
1674 // Set up arguments to the region's block.
1675 for (auto &&[argIdx, blockArg] : llvm::enumerate(First&: blockArguments)) {
1676 MappedValue argument = payloads[argIdx][iterIdx];
1677 // Note that each blockArg's handle gets associated with just a single
1678 // element from the corresponding target's payload.
1679 if (failed(Result: state.mapBlockArgument(argument: blockArg, values: {argument})))
1680 return DiagnosedSilenceableFailure::definiteFailure();
1681 }
1682
1683 // Execute loop body.
1684 for (Operation &transform : getBody().front().without_terminator()) {
1685 DiagnosedSilenceableFailure result = state.applyTransform(
1686 transform: llvm::cast<transform::TransformOpInterface>(Val&: transform));
1687 if (!result.succeeded())
1688 return result;
1689 }
1690
1691 // Append yielded payloads to corresponding results from prior iterations.
1692 OperandRange yieldOperands = getYieldOp().getOperands();
1693 for (auto &&[result, yieldOperand, resTuple] :
1694 llvm::zip_equal(t: getResults(), u&: yieldOperands, args&: zippedResults))
1695 // NB: each iteration we add any number of ops/vals/params to a result.
1696 if (isa<TransformHandleTypeInterface>(Val: result.getType()))
1697 llvm::append_range(C&: resTuple, R: state.getPayloadOps(value: yieldOperand));
1698 else if (isa<TransformValueHandleTypeInterface>(Val: result.getType()))
1699 llvm::append_range(C&: resTuple, R: state.getPayloadValues(handleValue: yieldOperand));
1700 else if (isa<TransformParamTypeInterface>(Val: result.getType()))
1701 llvm::append_range(C&: resTuple, R: state.getParams(value: yieldOperand));
1702 else
1703 assert(false && "unhandled handle type");
1704 }
1705
1706 // Associate the accumulated result payloads to the op's actual results.
1707 for (auto &&[result, resPayload] : zip_equal(t: getResults(), u&: zippedResults))
1708 results.setMappedValues(handle: llvm::cast<OpResult>(Val&: result), values: resPayload);
1709
1710 return DiagnosedSilenceableFailure::success();
1711}
1712
1713void transform::ForeachOp::getEffects(
1714 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1715 // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1716 // arity errors, this method might get called before/in absence of `verify()`.
1717 for (auto &&[target, blockArg] :
1718 llvm::zip(t: getTargetsMutable(), u: getBody().front().getArguments())) {
1719 BlockArgument blockArgument = blockArg;
1720 if (any_of(Range: getBody().front().without_terminator(), P: [&](Operation &op) {
1721 return isHandleConsumed(handle: blockArgument,
1722 transform: cast<TransformOpInterface>(Val: &op));
1723 })) {
1724 consumesHandle(handles: target, effects);
1725 } else {
1726 onlyReadsHandle(handles: target, effects);
1727 }
1728 }
1729
1730 if (any_of(Range: getBody().front().without_terminator(), P: [&](Operation &op) {
1731 return doesModifyPayload(transform: cast<TransformOpInterface>(Val: &op));
1732 })) {
1733 modifiesPayload(effects);
1734 } else if (any_of(Range: getBody().front().without_terminator(), P: [&](Operation &op) {
1735 return doesReadPayload(transform: cast<TransformOpInterface>(Val: &op));
1736 })) {
1737 onlyReadsPayload(effects);
1738 }
1739
1740 producesHandle(handles: getOperation()->getOpResults(), effects);
1741}
1742
1743void transform::ForeachOp::getSuccessorRegions(
1744 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1745 Region *bodyRegion = &getBody();
1746 if (point.isParent()) {
1747 regions.emplace_back(Args&: bodyRegion, Args: bodyRegion->getArguments());
1748 return;
1749 }
1750
1751 // Branch back to the region or the parent.
1752 assert(point == getBody() && "unexpected region index");
1753 regions.emplace_back(Args&: bodyRegion, Args: bodyRegion->getArguments());
1754 regions.emplace_back();
1755}
1756
1757OperandRange
1758transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1759 // Each block argument handle is mapped to a subset (one op to be precise)
1760 // of the payload of the corresponding `targets` operand of ForeachOp.
1761 assert(point == getBody() && "unexpected region index");
1762 return getOperation()->getOperands();
1763}
1764
1765transform::YieldOp transform::ForeachOp::getYieldOp() {
1766 return cast<transform::YieldOp>(Val: getBody().front().getTerminator());
1767}
1768
1769LogicalResult transform::ForeachOp::verify() {
1770 for (auto [targetOpt, bodyArgOpt] :
1771 llvm::zip_longest(t: getTargets(), u: getBody().front().getArguments())) {
1772 if (!targetOpt || !bodyArgOpt)
1773 return emitOpError() << "expects the same number of targets as the body "
1774 "has block arguments";
1775 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1776 return emitOpError(
1777 message: "expects co-indexed targets and the body's "
1778 "block arguments to have the same op/value/param type");
1779 }
1780
1781 for (auto [resultOpt, yieldOperandOpt] :
1782 llvm::zip_longest(t: getResults(), u: getYieldOp().getOperands())) {
1783 if (!resultOpt || !yieldOperandOpt)
1784 return emitOpError() << "expects the same number of results as the "
1785 "yield terminator has operands";
1786 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1787 return emitOpError(message: "expects co-indexed results and yield "
1788 "operands to have the same op/value/param type");
1789 }
1790
1791 return success();
1792}
1793
1794//===----------------------------------------------------------------------===//
1795// GetParentOp
1796//===----------------------------------------------------------------------===//
1797
1798DiagnosedSilenceableFailure
1799transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1800 transform::TransformResults &results,
1801 transform::TransformState &state) {
1802 SmallVector<Operation *> parents;
1803 DenseSet<Operation *> resultSet;
1804 for (Operation *target : state.getPayloadOps(value: getTarget())) {
1805 Operation *parent = target;
1806 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1807 parent = parent->getParentOp();
1808 while (parent) {
1809 bool checkIsolatedFromAbove =
1810 !getIsolatedFromAbove() ||
1811 parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
1812 bool checkOpName = !getOpName().has_value() ||
1813 parent->getName().getStringRef() == *getOpName();
1814 if (checkIsolatedFromAbove && checkOpName)
1815 break;
1816 parent = parent->getParentOp();
1817 }
1818 if (!parent) {
1819 if (getAllowEmptyResults()) {
1820 results.set(value: llvm::cast<OpResult>(Val: getResult()), ops&: parents);
1821 return DiagnosedSilenceableFailure::success();
1822 }
1823 DiagnosedSilenceableFailure diag =
1824 emitSilenceableError()
1825 << "could not find a parent op that matches all requirements";
1826 diag.attachNote(loc: target->getLoc()) << "target op";
1827 return diag;
1828 }
1829 }
1830 if (getDeduplicate()) {
1831 if (resultSet.insert(V: parent).second)
1832 parents.push_back(Elt: parent);
1833 } else {
1834 parents.push_back(Elt: parent);
1835 }
1836 }
1837 results.set(value: llvm::cast<OpResult>(Val: getResult()), ops&: parents);
1838 return DiagnosedSilenceableFailure::success();
1839}
1840
1841//===----------------------------------------------------------------------===//
1842// GetConsumersOfResult
1843//===----------------------------------------------------------------------===//
1844
1845DiagnosedSilenceableFailure
1846transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1847 transform::TransformResults &results,
1848 transform::TransformState &state) {
1849 int64_t resultNumber = getResultNumber();
1850 auto payloadOps = state.getPayloadOps(value: getTarget());
1851 if (std::empty(cont: payloadOps)) {
1852 results.set(value: cast<OpResult>(Val: getResult()), ops: {});
1853 return DiagnosedSilenceableFailure::success();
1854 }
1855 if (!llvm::hasSingleElement(C&: payloadOps))
1856 return emitDefiniteFailure()
1857 << "handle must be mapped to exactly one payload op";
1858
1859 Operation *target = *payloadOps.begin();
1860 if (target->getNumResults() <= resultNumber)
1861 return emitDefiniteFailure() << "result number overflow";
1862 results.set(value: llvm::cast<OpResult>(Val: getResult()),
1863 ops: llvm::to_vector(Range: target->getResult(idx: resultNumber).getUsers()));
1864 return DiagnosedSilenceableFailure::success();
1865}
1866
1867//===----------------------------------------------------------------------===//
1868// GetDefiningOp
1869//===----------------------------------------------------------------------===//
1870
1871DiagnosedSilenceableFailure
1872transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1873 transform::TransformResults &results,
1874 transform::TransformState &state) {
1875 SmallVector<Operation *> definingOps;
1876 for (Value v : state.getPayloadValues(handleValue: getTarget())) {
1877 if (llvm::isa<BlockArgument>(Val: v)) {
1878 DiagnosedSilenceableFailure diag =
1879 emitSilenceableError() << "cannot get defining op of block argument";
1880 diag.attachNote(loc: v.getLoc()) << "target value";
1881 return diag;
1882 }
1883 definingOps.push_back(Elt: v.getDefiningOp());
1884 }
1885 results.set(value: llvm::cast<OpResult>(Val: getResult()), ops&: definingOps);
1886 return DiagnosedSilenceableFailure::success();
1887}
1888
1889//===----------------------------------------------------------------------===//
1890// GetProducerOfOperand
1891//===----------------------------------------------------------------------===//
1892
1893DiagnosedSilenceableFailure
1894transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1895 transform::TransformResults &results,
1896 transform::TransformState &state) {
1897 int64_t operandNumber = getOperandNumber();
1898 SmallVector<Operation *> producers;
1899 for (Operation *target : state.getPayloadOps(value: getTarget())) {
1900 Operation *producer =
1901 target->getNumOperands() <= operandNumber
1902 ? nullptr
1903 : target->getOperand(idx: operandNumber).getDefiningOp();
1904 if (!producer) {
1905 DiagnosedSilenceableFailure diag =
1906 emitSilenceableError()
1907 << "could not find a producer for operand number: " << operandNumber
1908 << " of " << *target;
1909 diag.attachNote(loc: target->getLoc()) << "target op";
1910 return diag;
1911 }
1912 producers.push_back(Elt: producer);
1913 }
1914 results.set(value: llvm::cast<OpResult>(Val: getResult()), ops&: producers);
1915 return DiagnosedSilenceableFailure::success();
1916}
1917
1918//===----------------------------------------------------------------------===//
1919// GetOperandOp
1920//===----------------------------------------------------------------------===//
1921
1922DiagnosedSilenceableFailure
1923transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1924 transform::TransformResults &results,
1925 transform::TransformState &state) {
1926 SmallVector<Value> operands;
1927 for (Operation *target : state.getPayloadOps(value: getTarget())) {
1928 SmallVector<int64_t> operandPositions;
1929 DiagnosedSilenceableFailure diag = expandTargetSpecification(
1930 loc: getLoc(), isAll: getIsAll(), isInverted: getIsInverted(), rawList: getRawPositionList(),
1931 maxNumber: target->getNumOperands(), result&: operandPositions);
1932 if (diag.isSilenceableFailure()) {
1933 diag.attachNote(loc: target->getLoc())
1934 << "while considering positions of this payload operation";
1935 return diag;
1936 }
1937 llvm::append_range(C&: operands,
1938 R: llvm::map_range(C&: operandPositions, F: [&](int64_t pos) {
1939 return target->getOperand(idx: pos);
1940 }));
1941 }
1942 results.setValues(handle: cast<OpResult>(Val: getResult()), values&: operands);
1943 return DiagnosedSilenceableFailure::success();
1944}
1945
1946LogicalResult transform::GetOperandOp::verify() {
1947 return verifyTransformMatchDimsOp(op: getOperation(), raw: getRawPositionList(),
1948 inverted: getIsInverted(), all: getIsAll());
1949}
1950
1951//===----------------------------------------------------------------------===//
1952// GetResultOp
1953//===----------------------------------------------------------------------===//
1954
1955DiagnosedSilenceableFailure
1956transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1957 transform::TransformResults &results,
1958 transform::TransformState &state) {
1959 SmallVector<Value> opResults;
1960 for (Operation *target : state.getPayloadOps(value: getTarget())) {
1961 SmallVector<int64_t> resultPositions;
1962 DiagnosedSilenceableFailure diag = expandTargetSpecification(
1963 loc: getLoc(), isAll: getIsAll(), isInverted: getIsInverted(), rawList: getRawPositionList(),
1964 maxNumber: target->getNumResults(), result&: resultPositions);
1965 if (diag.isSilenceableFailure()) {
1966 diag.attachNote(loc: target->getLoc())
1967 << "while considering positions of this payload operation";
1968 return diag;
1969 }
1970 llvm::append_range(C&: opResults,
1971 R: llvm::map_range(C&: resultPositions, F: [&](int64_t pos) {
1972 return target->getResult(idx: pos);
1973 }));
1974 }
1975 results.setValues(handle: cast<OpResult>(Val: getResult()), values&: opResults);
1976 return DiagnosedSilenceableFailure::success();
1977}
1978
1979LogicalResult transform::GetResultOp::verify() {
1980 return verifyTransformMatchDimsOp(op: getOperation(), raw: getRawPositionList(),
1981 inverted: getIsInverted(), all: getIsAll());
1982}
1983
1984//===----------------------------------------------------------------------===//
1985// GetTypeOp
1986//===----------------------------------------------------------------------===//
1987
1988void transform::GetTypeOp::getEffects(
1989 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1990 onlyReadsHandle(handles: getValueMutable(), effects);
1991 producesHandle(handles: getOperation()->getOpResults(), effects);
1992 onlyReadsPayload(effects);
1993}
1994
1995DiagnosedSilenceableFailure
1996transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1997 transform::TransformResults &results,
1998 transform::TransformState &state) {
1999 SmallVector<Attribute> params;
2000 for (Value value : state.getPayloadValues(handleValue: getValue())) {
2001 Type type = value.getType();
2002 if (getElemental()) {
2003 if (auto shaped = dyn_cast<ShapedType>(Val&: type)) {
2004 type = shaped.getElementType();
2005 }
2006 }
2007 params.push_back(Elt: TypeAttr::get(type));
2008 }
2009 results.setParams(value: cast<OpResult>(Val: getResult()), params);
2010 return DiagnosedSilenceableFailure::success();
2011}
2012
2013//===----------------------------------------------------------------------===//
2014// IncludeOp
2015//===----------------------------------------------------------------------===//
2016
2017/// Applies the transform ops contained in `block`. Maps `results` to the same
2018/// values as the operands of the block terminator.
2019static DiagnosedSilenceableFailure
2020applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2021 transform::TransformState &state,
2022 transform::TransformResults &results) {
2023 // Apply the sequenced ops one by one.
2024 for (Operation &transform : block.without_terminator()) {
2025 DiagnosedSilenceableFailure result =
2026 state.applyTransform(transform: cast<transform::TransformOpInterface>(Val&: transform));
2027 if (result.isDefiniteFailure())
2028 return result;
2029
2030 if (result.isSilenceableFailure()) {
2031 if (mode == transform::FailurePropagationMode::Propagate) {
2032 // Propagate empty results in case of early exit.
2033 forwardEmptyOperands(block: &block, state, results);
2034 return result;
2035 }
2036 (void)result.silence();
2037 }
2038 }
2039
2040 // Forward the operation mapping for values yielded from the sequence to the
2041 // values produced by the sequence op.
2042 transform::detail::forwardTerminatorOperands(block: &block, state, results);
2043 return DiagnosedSilenceableFailure::success();
2044}
2045
2046DiagnosedSilenceableFailure
2047transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2048 transform::TransformResults &results,
2049 transform::TransformState &state) {
2050 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2051 from: getOperation(), symbol: getTarget());
2052 assert(callee && "unverified reference to unknown symbol");
2053
2054 if (callee.isExternal())
2055 return emitDefiniteFailure() << "unresolved external named sequence";
2056
2057 // Map operands to block arguments.
2058 SmallVector<SmallVector<MappedValue>> mappings;
2059 detail::prepareValueMappings(mappings, values: getOperands(), state);
2060 auto scope = state.make_region_scope(region&: callee.getBody());
2061 for (auto &&[arg, map] :
2062 llvm::zip_equal(t: callee.getBody().front().getArguments(), u&: mappings)) {
2063 if (failed(Result: state.mapBlockArgument(argument: arg, values: map)))
2064 return DiagnosedSilenceableFailure::definiteFailure();
2065 }
2066
2067 DiagnosedSilenceableFailure result = applySequenceBlock(
2068 block&: callee.getBody().front(), mode: getFailurePropagationMode(), state, results);
2069 mappings.clear();
2070 detail::prepareValueMappings(
2071 mappings, values: callee.getBody().front().getTerminator()->getOperands(), state);
2072 for (auto &&[result, mapping] : llvm::zip_equal(t: getResults(), u&: mappings))
2073 results.setMappedValues(handle: result, values: mapping);
2074 return result;
2075}
2076
/// Forward declaration of a file-local helper defined elsewhere in this
/// translation unit. It is declared here because IncludeOp::getEffects below
/// calls it on the callee, with warnings disabled, before relying on the
/// callee's argument attributes.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2079
2080void transform::IncludeOp::getEffects(
2081 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2082 // Always mark as modifying the payload.
2083 // TODO: a mechanism to annotate effects on payload. Even when all handles are
2084 // only read, the payload may still be modified, so we currently stay on the
2085 // conservative side and always indicate modification. This may prevent some
2086 // code reordering.
2087 modifiesPayload(effects);
2088
2089 // Results are always produced.
2090 producesHandle(handles: getOperation()->getOpResults(), effects);
2091
2092 // Adds default effects to operands and results. This will be added if
2093 // preconditions fail so the trait verifier doesn't complain about missing
2094 // effects and the real precondition failure is reported later on.
2095 auto defaultEffects = [&] {
2096 onlyReadsHandle(handles: getOperation()->getOpOperands(), effects);
2097 };
2098
2099 // Bail if the callee is unknown. This may run as part of the verification
2100 // process before we verified the validity of the callee or of this op.
2101 auto target =
2102 getOperation()->getAttrOfType<SymbolRefAttr>(name: getTargetAttrName());
2103 if (!target)
2104 return defaultEffects();
2105 auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2106 from: getOperation(), symbol: getTarget());
2107 if (!callee)
2108 return defaultEffects();
2109 DiagnosedSilenceableFailure earlyVerifierResult =
2110 verifyNamedSequenceOp(op: callee, /*emitWarnings=*/false);
2111 if (!earlyVerifierResult.succeeded()) {
2112 (void)earlyVerifierResult.silence();
2113 return defaultEffects();
2114 }
2115
2116 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2117 if (callee.getArgAttr(index: i, name: TransformDialect::kArgConsumedAttrName))
2118 consumesHandle(handles: getOperation()->getOpOperand(idx: i), effects);
2119 else
2120 onlyReadsHandle(handles: getOperation()->getOpOperand(idx: i), effects);
2121 }
2122}
2123
2124LogicalResult
2125transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2126 // Access through indirection and do additional checking because this may be
2127 // running before the main op verifier.
2128 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>(name: "target");
2129 if (!targetAttr)
2130 return emitOpError() << "expects a 'target' symbol reference attribute";
2131
2132 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2133 from: *this, symbol: targetAttr);
2134 if (!target)
2135 return emitOpError() << "does not reference a named transform sequence";
2136
2137 FunctionType fnType = target.getFunctionType();
2138 if (fnType.getNumInputs() != getNumOperands())
2139 return emitError(message: "incorrect number of operands for callee");
2140
2141 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2142 if (getOperand(i).getType() != fnType.getInput(i)) {
2143 return emitOpError(message: "operand type mismatch: expected operand type ")
2144 << fnType.getInput(i) << ", but provided "
2145 << getOperand(i).getType() << " for operand number " << i;
2146 }
2147 }
2148
2149 if (fnType.getNumResults() != getNumResults())
2150 return emitError(message: "incorrect number of results for callee");
2151
2152 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2153 Type resultType = getResult(i).getType();
2154 Type funcType = fnType.getResult(i);
2155 if (!implementSameTransformInterface(t1: resultType, t2: funcType)) {
2156 return emitOpError() << "type of result #" << i
2157 << " must implement the same transform dialect "
2158 "interface as the corresponding callee result";
2159 }
2160 }
2161
2162 return verifyFunctionLikeConsumeAnnotations(
2163 op: cast<FunctionOpInterface>(Val&: *target), /*emitWarnings=*/false,
2164 /*alsoVerifyInternal=*/true)
2165 .checkAndReport();
2166}
2167
2168//===----------------------------------------------------------------------===//
2169// MatchOperationEmptyOp
2170//===----------------------------------------------------------------------===//
2171
2172DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2173 ::std::optional<::mlir::Operation *> maybeCurrent,
2174 transform::TransformResults &results, transform::TransformState &state) {
2175 if (!maybeCurrent.has_value()) {
2176 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
2177 return DiagnosedSilenceableFailure::success();
2178 }
2179 DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
2180 return emitSilenceableError() << "operation is not empty";
2181}
2182
2183//===----------------------------------------------------------------------===//
2184// MatchOperationNameOp
2185//===----------------------------------------------------------------------===//
2186
2187DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2188 Operation *current, transform::TransformResults &results,
2189 transform::TransformState &state) {
2190 StringRef currentOpName = current->getName().getStringRef();
2191 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2192 if (acceptedAttr.getValue() == currentOpName)
2193 return DiagnosedSilenceableFailure::success();
2194 }
2195 return emitSilenceableError() << "wrong operation name";
2196}
2197
2198//===----------------------------------------------------------------------===//
2199// MatchParamCmpIOp
2200//===----------------------------------------------------------------------===//
2201
2202DiagnosedSilenceableFailure
2203transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
2204 transform::TransformResults &results,
2205 transform::TransformState &state) {
2206 auto signedAPIntAsString = [&](const APInt &value) {
2207 std::string str;
2208 llvm::raw_string_ostream os(str);
2209 value.print(OS&: os, /*isSigned=*/true);
2210 return str;
2211 };
2212
2213 ArrayRef<Attribute> params = state.getParams(value: getParam());
2214 ArrayRef<Attribute> references = state.getParams(value: getReference());
2215
2216 if (params.size() != references.size()) {
2217 return emitSilenceableError()
2218 << "parameters have different payload lengths (" << params.size()
2219 << " vs " << references.size() << ")";
2220 }
2221
2222 for (auto &&[i, param, reference] : llvm::enumerate(First&: params, Rest&: references)) {
2223 auto intAttr = llvm::dyn_cast<IntegerAttr>(Val: param);
2224 auto refAttr = llvm::dyn_cast<IntegerAttr>(Val: reference);
2225 if (!intAttr || !refAttr) {
2226 return emitDefiniteFailure()
2227 << "non-integer parameter value not expected";
2228 }
2229 if (intAttr.getType() != refAttr.getType()) {
2230 return emitDefiniteFailure()
2231 << "mismatching integer attribute types in parameter #" << i;
2232 }
2233 APInt value = intAttr.getValue();
2234 APInt refValue = refAttr.getValue();
2235
2236 // TODO: this copy will not be necessary in C++20.
2237 int64_t position = i;
2238 auto reportError = [&](StringRef direction) {
2239 DiagnosedSilenceableFailure diag =
2240 emitSilenceableError() << "expected parameter to be " << direction
2241 << " " << signedAPIntAsString(refValue)
2242 << ", got " << signedAPIntAsString(value);
2243 diag.attachNote(loc: getParam().getLoc())
2244 << "value # " << position
2245 << " associated with the parameter defined here";
2246 return diag;
2247 };
2248
2249 switch (getPredicate()) {
2250 case MatchCmpIPredicate::eq:
2251 if (value.eq(RHS: refValue))
2252 break;
2253 return reportError("equal to");
2254 case MatchCmpIPredicate::ne:
2255 if (value.ne(RHS: refValue))
2256 break;
2257 return reportError("not equal to");
2258 case MatchCmpIPredicate::lt:
2259 if (value.slt(RHS: refValue))
2260 break;
2261 return reportError("less than");
2262 case MatchCmpIPredicate::le:
2263 if (value.sle(RHS: refValue))
2264 break;
2265 return reportError("less than or equal to");
2266 case MatchCmpIPredicate::gt:
2267 if (value.sgt(RHS: refValue))
2268 break;
2269 return reportError("greater than");
2270 case MatchCmpIPredicate::ge:
2271 if (value.sge(RHS: refValue))
2272 break;
2273 return reportError("greater than or equal to");
2274 }
2275 }
2276 return DiagnosedSilenceableFailure::success();
2277}
2278
2279void transform::MatchParamCmpIOp::getEffects(
2280 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2281 onlyReadsHandle(handles: getParamMutable(), effects);
2282 onlyReadsHandle(handles: getReferenceMutable(), effects);
2283}
2284
2285//===----------------------------------------------------------------------===//
2286// ParamConstantOp
2287//===----------------------------------------------------------------------===//
2288
2289DiagnosedSilenceableFailure
2290transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2291 transform::TransformResults &results,
2292 transform::TransformState &state) {
2293 results.setParams(value: cast<OpResult>(Val: getParam()), params: {getValue()});
2294 return DiagnosedSilenceableFailure::success();
2295}
2296
2297//===----------------------------------------------------------------------===//
2298// MergeHandlesOp
2299//===----------------------------------------------------------------------===//
2300
2301DiagnosedSilenceableFailure
2302transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2303 transform::TransformResults &results,
2304 transform::TransformState &state) {
2305 ValueRange handles = getHandles();
2306 if (isa<TransformHandleTypeInterface>(Val: handles.front().getType())) {
2307 SmallVector<Operation *> operations;
2308 for (Value operand : handles)
2309 llvm::append_range(C&: operations, R: state.getPayloadOps(value: operand));
2310 if (!getDeduplicate()) {
2311 results.set(value: llvm::cast<OpResult>(Val: getResult()), ops&: operations);
2312 return DiagnosedSilenceableFailure::success();
2313 }
2314
2315 SetVector<Operation *> uniqued(llvm::from_range, operations);
2316 results.set(value: llvm::cast<OpResult>(Val: getResult()), ops: uniqued.getArrayRef());
2317 return DiagnosedSilenceableFailure::success();
2318 }
2319
2320 if (llvm::isa<TransformParamTypeInterface>(Val: handles.front().getType())) {
2321 SmallVector<Attribute> attrs;
2322 for (Value attribute : handles)
2323 llvm::append_range(C&: attrs, R: state.getParams(value: attribute));
2324 if (!getDeduplicate()) {
2325 results.setParams(value: cast<OpResult>(Val: getResult()), params: attrs);
2326 return DiagnosedSilenceableFailure::success();
2327 }
2328
2329 SetVector<Attribute> uniqued(llvm::from_range, attrs);
2330 results.setParams(value: cast<OpResult>(Val: getResult()), params: uniqued.getArrayRef());
2331 return DiagnosedSilenceableFailure::success();
2332 }
2333
2334 assert(
2335 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2336 "expected value handle type");
2337 SmallVector<Value> payloadValues;
2338 for (Value value : handles)
2339 llvm::append_range(C&: payloadValues, R: state.getPayloadValues(handleValue: value));
2340 if (!getDeduplicate()) {
2341 results.setValues(handle: cast<OpResult>(Val: getResult()), values&: payloadValues);
2342 return DiagnosedSilenceableFailure::success();
2343 }
2344
2345 SetVector<Value> uniqued(llvm::from_range, payloadValues);
2346 results.setValues(handle: cast<OpResult>(Val: getResult()), values: uniqued.getArrayRef());
2347 return DiagnosedSilenceableFailure::success();
2348}
2349
2350bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2351 // Handles may be the same if deduplicating is enabled.
2352 return getDeduplicate();
2353}
2354
2355void transform::MergeHandlesOp::getEffects(
2356 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2357 onlyReadsHandle(handles: getHandlesMutable(), effects);
2358 producesHandle(handles: getOperation()->getOpResults(), effects);
2359
2360 // There are no effects on the Payload IR as this is only a handle
2361 // manipulation.
2362}
2363
2364OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2365 if (getDeduplicate() || getHandles().size() != 1)
2366 return {};
2367
2368 // If deduplication is not required and there is only one operand, it can be
2369 // used directly instead of merging.
2370 return getHandles().front();
2371}
2372
2373//===----------------------------------------------------------------------===//
2374// NamedSequenceOp
2375//===----------------------------------------------------------------------===//
2376
2377DiagnosedSilenceableFailure
2378transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2379 transform::TransformResults &results,
2380 transform::TransformState &state) {
2381 if (isExternal())
2382 return emitDefiniteFailure() << "unresolved external named sequence";
2383
2384 // Map the entry block argument to the list of operations.
2385 // Note: this is the same implementation as PossibleTopLevelTransformOp but
2386 // without attaching the interface / trait since that is tailored to a
2387 // dangling top-level op that does not get "called".
2388 auto scope = state.make_region_scope(region&: getBody());
2389 if (failed(Result: detail::mapPossibleTopLevelTransformOpBlockArguments(
2390 state, op: this->getOperation(), region&: getBody())))
2391 return DiagnosedSilenceableFailure::definiteFailure();
2392
2393 return applySequenceBlock(block&: getBody().front(),
2394 mode: FailurePropagationMode::Propagate, state, results);
2395}
2396
2397void transform::NamedSequenceOp::getEffects(
2398 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2399
2400ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2401 OperationState &result) {
2402 return function_interface_impl::parseFunctionOp(
2403 parser, result, /*allowVariadic=*/false,
2404 typeAttrName: getFunctionTypeAttrName(name: result.name),
2405 funcTypeBuilder: [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2406 function_interface_impl::VariadicFlag,
2407 std::string &) { return builder.getFunctionType(inputs, results); },
2408 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
2409}
2410
2411void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2412 function_interface_impl::printFunctionOp(
2413 p&: printer, op: cast<FunctionOpInterface>(Val: getOperation()), /*isVariadic=*/false,
2414 typeAttrName: getFunctionTypeAttrName().getValue(), argAttrsName: getArgAttrsAttrName(),
2415 resAttrsName: getResAttrsAttrName());
2416}
2417
2418/// Verifies that a symbol function-like transform dialect operation has the
2419/// signature and the terminator that have conforming types, i.e., types
2420/// implementing the same transform dialect type interface. If `allowExternal`
2421/// is set, allow external symbols (declarations) and don't check the terminator
2422/// as it may not exist.
2423static DiagnosedSilenceableFailure
2424verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2425 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2426 DiagnosedSilenceableFailure diag =
2427 emitSilenceableFailure(op)
2428 << "cannot be defined inside another transform op";
2429 diag.attachNote(loc: parent.getLoc()) << "ancestor transform op";
2430 return diag;
2431 }
2432
2433 if (op.isExternal() || op.getFunctionBody().empty()) {
2434 if (allowExternal)
2435 return DiagnosedSilenceableFailure::success();
2436
2437 return emitSilenceableFailure(op) << "cannot be external";
2438 }
2439
2440 if (op.getFunctionBody().front().empty())
2441 return emitSilenceableFailure(op) << "expected a non-empty body block";
2442
2443 Operation *terminator = &op.getFunctionBody().front().back();
2444 if (!isa<transform::YieldOp>(Val: terminator)) {
2445 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2446 << "expected '"
2447 << transform::YieldOp::getOperationName()
2448 << "' as terminator";
2449 diag.attachNote(loc: terminator->getLoc()) << "terminator";
2450 return diag;
2451 }
2452
2453 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2454 return emitSilenceableFailure(op: terminator)
2455 << "expected terminator to have as many operands as the parent op "
2456 "has results";
2457 }
2458 for (auto [i, operandType, resultType] : llvm::zip_equal(
2459 t: llvm::seq<unsigned>(Begin: 0, End: terminator->getNumOperands()),
2460 u: terminator->getOperands().getType(), args: op.getResultTypes())) {
2461 if (operandType == resultType)
2462 continue;
2463 return emitSilenceableFailure(op: terminator)
2464 << "the type of the terminator operand #" << i
2465 << " must match the type of the corresponding parent op result ("
2466 << operandType << " vs " << resultType << ")";
2467 }
2468
2469 return DiagnosedSilenceableFailure::success();
2470}
2471
2472/// Verification of a NamedSequenceOp. This does not report the error
2473/// immediately, so it can be used to check for op's well-formedness before the
2474/// verifier runs, e.g., during trait verification.
2475static DiagnosedSilenceableFailure
2476verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2477 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2478 if (!parent->getAttr(
2479 name: transform::TransformDialect::kWithNamedSequenceAttrName)) {
2480 DiagnosedSilenceableFailure diag =
2481 emitSilenceableFailure(op)
2482 << "expects the parent symbol table to have the '"
2483 << transform::TransformDialect::kWithNamedSequenceAttrName
2484 << "' attribute";
2485 diag.attachNote(loc: parent->getLoc()) << "symbol table operation";
2486 return diag;
2487 }
2488 }
2489
2490 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2491 DiagnosedSilenceableFailure diag =
2492 emitSilenceableFailure(op)
2493 << "cannot be defined inside another transform op";
2494 diag.attachNote(loc: parent.getLoc()) << "ancestor transform op";
2495 return diag;
2496 }
2497
2498 if (op.isExternal() || op.getBody().empty())
2499 return verifyFunctionLikeConsumeAnnotations(op: cast<FunctionOpInterface>(Val&: *op),
2500 emitWarnings);
2501
2502 if (op.getBody().front().empty())
2503 return emitSilenceableFailure(op) << "expected a non-empty body block";
2504
2505 Operation *terminator = &op.getBody().front().back();
2506 if (!isa<transform::YieldOp>(Val: terminator)) {
2507 DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
2508 << "expected '"
2509 << transform::YieldOp::getOperationName()
2510 << "' as terminator";
2511 diag.attachNote(loc: terminator->getLoc()) << "terminator";
2512 return diag;
2513 }
2514
2515 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2516 return emitSilenceableFailure(op: terminator)
2517 << "expected terminator to have as many operands as the parent op "
2518 "has results";
2519 }
2520 for (auto [i, operandType, resultType] :
2521 llvm::zip_equal(t: llvm::seq<unsigned>(Begin: 0, End: terminator->getNumOperands()),
2522 u: terminator->getOperands().getType(),
2523 args: op.getFunctionType().getResults())) {
2524 if (operandType == resultType)
2525 continue;
2526 return emitSilenceableFailure(op: terminator)
2527 << "the type of the terminator operand #" << i
2528 << " must match the type of the corresponding parent op result ("
2529 << operandType << " vs " << resultType << ")";
2530 }
2531
2532 auto funcOp = cast<FunctionOpInterface>(Val&: *op);
2533 DiagnosedSilenceableFailure diag =
2534 verifyFunctionLikeConsumeAnnotations(op: funcOp, emitWarnings);
2535 if (!diag.succeeded())
2536 return diag;
2537
2538 return verifyYieldingSingleBlockOp(op: funcOp,
2539 /*allowExternal=*/true);
2540}
2541
2542LogicalResult transform::NamedSequenceOp::verify() {
2543 // Actual verification happens in a separate function for reusability.
2544 return verifyNamedSequenceOp(op: *this, /*emitWarnings=*/true).checkAndReport();
2545}
2546
2547template <typename FnTy>
2548static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2549 Type bbArgType, TypeRange extraBindingTypes,
2550 FnTy bodyBuilder) {
2551 SmallVector<Type> types;
2552 types.reserve(N: 1 + extraBindingTypes.size());
2553 types.push_back(Elt: bbArgType);
2554 llvm::append_range(C&: types, R&: extraBindingTypes);
2555
2556 OpBuilder::InsertionGuard guard(builder);
2557 Region *region = state.regions.back().get();
2558 Block *bodyBlock =
2559 builder.createBlock(parent: region, insertPt: region->begin(), argTypes: types,
2560 locs: SmallVector<Location>(types.size(), state.location));
2561
2562 // Populate body.
2563 builder.setInsertionPointToStart(bodyBlock);
2564 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2565 bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0));
2566 } else {
2567 bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0),
2568 bodyBlock->getArguments().drop_front());
2569 }
2570}
2571
2572void transform::NamedSequenceOp::build(OpBuilder &builder,
2573 OperationState &state, StringRef symName,
2574 Type rootType, TypeRange resultTypes,
2575 SequenceBodyBuilderFn bodyBuilder,
2576 ArrayRef<NamedAttribute> attrs,
2577 ArrayRef<DictionaryAttr> argAttrs) {
2578 state.addAttribute(name: SymbolTable::getSymbolAttrName(),
2579 attr: builder.getStringAttr(bytes: symName));
2580 state.addAttribute(name: getFunctionTypeAttrName(name: state.name),
2581 attr: TypeAttr::get(type: FunctionType::get(context: builder.getContext(),
2582 inputs: rootType, results: resultTypes)));
2583 state.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
2584 state.addRegion();
2585
2586 buildSequenceBody(builder, state, bbArgType: rootType,
2587 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2588}
2589
2590//===----------------------------------------------------------------------===//
2591// NumAssociationsOp
2592//===----------------------------------------------------------------------===//
2593
2594DiagnosedSilenceableFailure
2595transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2596 transform::TransformResults &results,
2597 transform::TransformState &state) {
2598 size_t numAssociations =
2599 llvm::TypeSwitch<Type, size_t>(getHandle().getType())
2600 .Case(caseFn: [&](TransformHandleTypeInterface opHandle) {
2601 return llvm::range_size(Range: state.getPayloadOps(value: getHandle()));
2602 })
2603 .Case(caseFn: [&](TransformValueHandleTypeInterface valueHandle) {
2604 return llvm::range_size(Range: state.getPayloadValues(handleValue: getHandle()));
2605 })
2606 .Case(caseFn: [&](TransformParamTypeInterface param) {
2607 return llvm::range_size(Range: state.getParams(value: getHandle()));
2608 })
2609 .Default(defaultFn: [](Type) {
2610 llvm_unreachable("unknown kind of transform dialect type");
2611 return 0;
2612 });
2613 results.setParams(value: cast<OpResult>(Val: getNum()),
2614 params: rewriter.getI64IntegerAttr(value: numAssociations));
2615 return DiagnosedSilenceableFailure::success();
2616}
2617
2618LogicalResult transform::NumAssociationsOp::verify() {
2619 // Verify that the result type accepts an i64 attribute as payload.
2620 auto resultType = cast<TransformParamTypeInterface>(Val: getNum().getType());
2621 return resultType
2622 .checkPayload(loc: getLoc(), payload: {Builder(getContext()).getI64IntegerAttr(value: 0)})
2623 .checkAndReport();
2624}
2625
2626//===----------------------------------------------------------------------===//
2627// SelectOp
2628//===----------------------------------------------------------------------===//
2629
2630DiagnosedSilenceableFailure
2631transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2632 transform::TransformResults &results,
2633 transform::TransformState &state) {
2634 SmallVector<Operation *> result;
2635 auto payloadOps = state.getPayloadOps(value: getTarget());
2636 for (Operation *op : payloadOps) {
2637 if (op->getName().getStringRef() == getOpName())
2638 result.push_back(Elt: op);
2639 }
2640 results.set(value: cast<OpResult>(Val: getResult()), ops&: result);
2641 return DiagnosedSilenceableFailure::success();
2642}
2643
2644//===----------------------------------------------------------------------===//
2645// SplitHandleOp
2646//===----------------------------------------------------------------------===//
2647
2648void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2649 Value target, int64_t numResultHandles) {
2650 result.addOperands(newOperands: target);
2651 result.addTypes(newTypes: SmallVector<Type>(numResultHandles, target.getType()));
2652}
2653
2654DiagnosedSilenceableFailure
2655transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2656 transform::TransformResults &results,
2657 transform::TransformState &state) {
2658 int64_t numPayloads =
2659 llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
2660 .Case<TransformHandleTypeInterface>(caseFn: [&](auto x) {
2661 return llvm::range_size(Range: state.getPayloadOps(value: getHandle()));
2662 })
2663 .Case<TransformValueHandleTypeInterface>(caseFn: [&](auto x) {
2664 return llvm::range_size(Range: state.getPayloadValues(handleValue: getHandle()));
2665 })
2666 .Case<TransformParamTypeInterface>(caseFn: [&](auto x) {
2667 return llvm::range_size(Range: state.getParams(value: getHandle()));
2668 })
2669 .Default(defaultFn: [](auto x) {
2670 llvm_unreachable("unknown transform dialect type interface");
2671 return -1;
2672 });
2673
2674 auto produceNumOpsError = [&]() {
2675 return emitSilenceableError()
2676 << getHandle() << " expected to contain " << this->getNumResults()
2677 << " payloads but it contains " << numPayloads << " payloads";
2678 };
2679
2680 // Fail if there are more payload ops than results and no overflow result was
2681 // specified.
2682 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2683 return produceNumOpsError();
2684
2685 // Fail if there are more results than payload ops. Unless:
2686 // - "fail_on_payload_too_small" is set to "false", or
2687 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2688 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2689 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2690 return produceNumOpsError();
2691
2692 // Distribute payloads.
2693 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2694 if (getOverflowResult())
2695 resultHandles[*getOverflowResult()].reserve(N: numPayloads - getNumResults());
2696
2697 auto container = [&]() {
2698 if (isa<TransformHandleTypeInterface>(Val: getHandle().getType())) {
2699 return llvm::map_to_vector(
2700 C: state.getPayloadOps(value: getHandle()),
2701 F: [](Operation *op) -> MappedValue { return op; });
2702 }
2703 if (isa<TransformValueHandleTypeInterface>(Val: getHandle().getType())) {
2704 return llvm::map_to_vector(C: state.getPayloadValues(handleValue: getHandle()),
2705 F: [](Value v) -> MappedValue { return v; });
2706 }
2707 assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2708 "unsupported kind of transform dialect type");
2709 return llvm::map_to_vector(C: state.getParams(value: getHandle()),
2710 F: [](Attribute a) -> MappedValue { return a; });
2711 }();
2712
2713 for (auto &&en : llvm::enumerate(First&: container)) {
2714 int64_t resultNum = en.index();
2715 if (resultNum >= getNumResults())
2716 resultNum = *getOverflowResult();
2717 resultHandles[resultNum].push_back(Elt: en.value());
2718 }
2719
2720 // Set transform op results.
2721 for (auto &&it : llvm::enumerate(First&: resultHandles))
2722 results.setMappedValues(handle: llvm::cast<OpResult>(Val: getResult(i: it.index())),
2723 values: it.value());
2724
2725 return DiagnosedSilenceableFailure::success();
2726}
2727
2728void transform::SplitHandleOp::getEffects(
2729 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2730 onlyReadsHandle(handles: getHandleMutable(), effects);
2731 producesHandle(handles: getOperation()->getOpResults(), effects);
2732 // There are no effects on the Payload IR as this is only a handle
2733 // manipulation.
2734}
2735
2736LogicalResult transform::SplitHandleOp::verify() {
2737 if (getOverflowResult().has_value() &&
2738 !(*getOverflowResult() < getNumResults()))
2739 return emitOpError(message: "overflow_result is not a valid result index");
2740
2741 for (Type resultType : getResultTypes()) {
2742 if (implementSameTransformInterface(t1: getHandle().getType(), t2: resultType))
2743 continue;
2744
2745 return emitOpError(message: "expects result types to implement the same transform "
2746 "interface as the operand type");
2747 }
2748
2749 return success();
2750}
2751
2752//===----------------------------------------------------------------------===//
2753// ReplicateOp
2754//===----------------------------------------------------------------------===//
2755
2756DiagnosedSilenceableFailure
2757transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2758 transform::TransformResults &results,
2759 transform::TransformState &state) {
2760 unsigned numRepetitions = llvm::range_size(Range: state.getPayloadOps(value: getPattern()));
2761 for (const auto &en : llvm::enumerate(First: getHandles())) {
2762 Value handle = en.value();
2763 if (isa<TransformHandleTypeInterface>(Val: handle.getType())) {
2764 SmallVector<Operation *> current =
2765 llvm::to_vector(Range: state.getPayloadOps(value: handle));
2766 SmallVector<Operation *> payload;
2767 payload.reserve(N: numRepetitions * current.size());
2768 for (unsigned i = 0; i < numRepetitions; ++i)
2769 llvm::append_range(C&: payload, R&: current);
2770 results.set(value: llvm::cast<OpResult>(Val: getReplicated()[en.index()]), ops&: payload);
2771 } else {
2772 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2773 "expected param type");
2774 ArrayRef<Attribute> current = state.getParams(value: handle);
2775 SmallVector<Attribute> params;
2776 params.reserve(N: numRepetitions * current.size());
2777 for (unsigned i = 0; i < numRepetitions; ++i)
2778 llvm::append_range(C&: params, R&: current);
2779 results.setParams(value: llvm::cast<OpResult>(Val: getReplicated()[en.index()]),
2780 params);
2781 }
2782 }
2783 return DiagnosedSilenceableFailure::success();
2784}
2785
2786void transform::ReplicateOp::getEffects(
2787 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2788 onlyReadsHandle(handles: getPatternMutable(), effects);
2789 onlyReadsHandle(handles: getHandlesMutable(), effects);
2790 producesHandle(handles: getOperation()->getOpResults(), effects);
2791}
2792
2793//===----------------------------------------------------------------------===//
2794// SequenceOp
2795//===----------------------------------------------------------------------===//
2796
2797DiagnosedSilenceableFailure
2798transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2799 transform::TransformResults &results,
2800 transform::TransformState &state) {
2801 // Map the entry block argument to the list of operations.
2802 auto scope = state.make_region_scope(region&: *getBodyBlock()->getParent());
2803 if (failed(Result: mapBlockArguments(state)))
2804 return DiagnosedSilenceableFailure::definiteFailure();
2805
2806 return applySequenceBlock(block&: *getBodyBlock(), mode: getFailurePropagationMode(), state,
2807 results);
2808}
2809
2810static ParseResult parseSequenceOpOperands(
2811 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2812 Type &rootType,
2813 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2814 SmallVectorImpl<Type> &extraBindingTypes) {
2815 OpAsmParser::UnresolvedOperand rootOperand;
2816 OptionalParseResult hasRoot = parser.parseOptionalOperand(result&: rootOperand);
2817 if (!hasRoot.has_value()) {
2818 root = std::nullopt;
2819 return success();
2820 }
2821 if (failed(Result: hasRoot.value()))
2822 return failure();
2823 root = rootOperand;
2824
2825 if (succeeded(Result: parser.parseOptionalComma())) {
2826 if (failed(Result: parser.parseOperandList(result&: extraBindings)))
2827 return failure();
2828 }
2829 if (failed(Result: parser.parseColon()))
2830 return failure();
2831
2832 // The paren is truly optional.
2833 (void)parser.parseOptionalLParen();
2834
2835 if (failed(Result: parser.parseType(result&: rootType))) {
2836 return failure();
2837 }
2838
2839 if (!extraBindings.empty()) {
2840 if (parser.parseComma() || parser.parseTypeList(result&: extraBindingTypes))
2841 return failure();
2842 }
2843
2844 if (extraBindingTypes.size() != extraBindings.size()) {
2845 return parser.emitError(loc: parser.getNameLoc(),
2846 message: "expected types to be provided for all operands");
2847 }
2848
2849 // The paren is truly optional.
2850 (void)parser.parseOptionalRParen();
2851 return success();
2852}
2853
2854static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
2855 Value root, Type rootType,
2856 ValueRange extraBindings,
2857 TypeRange extraBindingTypes) {
2858 if (!root)
2859 return;
2860
2861 printer << root;
2862 bool hasExtras = !extraBindings.empty();
2863 if (hasExtras) {
2864 printer << ", ";
2865 printer.printOperands(container: extraBindings);
2866 }
2867
2868 printer << " : ";
2869 if (hasExtras)
2870 printer << "(";
2871
2872 printer << rootType;
2873 if (hasExtras)
2874 printer << ", " << llvm::interleaved(R: extraBindingTypes) << ')';
2875}
2876
2877/// Returns `true` if the given op operand may be consuming the handle value in
2878/// the Transform IR. That is, if it may have a Free effect on it.
2879static bool isValueUsePotentialConsumer(OpOperand &use) {
2880 // Conservatively assume the effect being present in absence of the interface.
2881 auto iface = dyn_cast<transform::TransformOpInterface>(Val: use.getOwner());
2882 if (!iface)
2883 return true;
2884
2885 return isHandleConsumed(handle: use.get(), transform: iface);
2886}
2887
2888LogicalResult
2889checkDoubleConsume(Value value,
2890 function_ref<InFlightDiagnostic()> reportError) {
2891 OpOperand *potentialConsumer = nullptr;
2892 for (OpOperand &use : value.getUses()) {
2893 if (!isValueUsePotentialConsumer(use))
2894 continue;
2895
2896 if (!potentialConsumer) {
2897 potentialConsumer = &use;
2898 continue;
2899 }
2900
2901 InFlightDiagnostic diag = reportError()
2902 << " has more than one potential consumer";
2903 diag.attachNote(noteLoc: potentialConsumer->getOwner()->getLoc())
2904 << "used here as operand #" << potentialConsumer->getOperandNumber();
2905 diag.attachNote(noteLoc: use.getOwner()->getLoc())
2906 << "used here as operand #" << use.getOperandNumber();
2907 return diag;
2908 }
2909
2910 return success();
2911}
2912
2913LogicalResult transform::SequenceOp::verify() {
2914 assert(getBodyBlock()->getNumArguments() >= 1 &&
2915 "the number of arguments must have been verified to be more than 1 by "
2916 "PossibleTopLevelTransformOpTrait");
2917
2918 if (!getRoot() && !getExtraBindings().empty()) {
2919 return emitOpError()
2920 << "does not expect extra operands when used as top-level";
2921 }
2922
2923 // Check if a block argument has more than one consuming use.
2924 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2925 if (failed(Result: checkDoubleConsume(value: arg, reportError: [this, arg]() {
2926 return (emitOpError() << "block argument #" << arg.getArgNumber());
2927 }))) {
2928 return failure();
2929 }
2930 }
2931
2932 // Check properties of the nested operations they cannot check themselves.
2933 for (Operation &child : *getBodyBlock()) {
2934 if (!isa<TransformOpInterface>(Val: child) &&
2935 &child != &getBodyBlock()->back()) {
2936 InFlightDiagnostic diag =
2937 emitOpError()
2938 << "expected children ops to implement TransformOpInterface";
2939 diag.attachNote(noteLoc: child.getLoc()) << "op without interface";
2940 return diag;
2941 }
2942
2943 for (OpResult result : child.getResults()) {
2944 auto report = [&]() {
2945 return (child.emitError() << "result #" << result.getResultNumber());
2946 };
2947 if (failed(Result: checkDoubleConsume(value: result, reportError: report)))
2948 return failure();
2949 }
2950 }
2951
2952 if (!getBodyBlock()->mightHaveTerminator())
2953 return emitOpError() << "expects to have a terminator in the body";
2954
2955 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2956 getOperation()->getResultTypes()) {
2957 InFlightDiagnostic diag = emitOpError()
2958 << "expects the types of the terminator operands "
2959 "to match the types of the result";
2960 diag.attachNote(noteLoc: getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2961 return diag;
2962 }
2963 return success();
2964}
2965
2966void transform::SequenceOp::getEffects(
2967 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2968 getPotentialTopLevelEffects(effects);
2969}
2970
2971OperandRange
2972transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2973 assert(point == getBody() && "unexpected region index");
2974 if (getOperation()->getNumOperands() > 0)
2975 return getOperation()->getOperands();
2976 return OperandRange(getOperation()->operand_end(),
2977 getOperation()->operand_end());
2978}
2979
2980void transform::SequenceOp::getSuccessorRegions(
2981 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2982 if (point.isParent()) {
2983 Region *bodyRegion = &getBody();
2984 regions.emplace_back(Args&: bodyRegion, Args: getNumOperands() != 0
2985 ? bodyRegion->getArguments()
2986 : Block::BlockArgListType());
2987 return;
2988 }
2989
2990 assert(point == getBody() && "unexpected region index");
2991 regions.emplace_back(Args: getOperation()->getResults());
2992}
2993
2994void transform::SequenceOp::getRegionInvocationBounds(
2995 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2996 (void)operands;
2997 bounds.emplace_back(Args: 1, Args: 1);
2998}
2999
3000void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3001 TypeRange resultTypes,
3002 FailurePropagationMode failurePropagationMode,
3003 Value root,
3004 SequenceBodyBuilderFn bodyBuilder) {
3005 build(odsBuilder&: builder, odsState&: state, results: resultTypes, failure_propagation_mode: failurePropagationMode, root,
3006 /*extra_bindings=*/ValueRange());
3007 Type bbArgType = root.getType();
3008 buildSequenceBody(builder, state, bbArgType,
3009 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3010}
3011
3012void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3013 TypeRange resultTypes,
3014 FailurePropagationMode failurePropagationMode,
3015 Value root, ValueRange extraBindings,
3016 SequenceBodyBuilderArgsFn bodyBuilder) {
3017 build(odsBuilder&: builder, odsState&: state, results: resultTypes, failure_propagation_mode: failurePropagationMode, root,
3018 extra_bindings: extraBindings);
3019 buildSequenceBody(builder, state, bbArgType: root.getType(), extraBindingTypes: extraBindings.getTypes(),
3020 bodyBuilder);
3021}
3022
3023void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3024 TypeRange resultTypes,
3025 FailurePropagationMode failurePropagationMode,
3026 Type bbArgType,
3027 SequenceBodyBuilderFn bodyBuilder) {
3028 build(odsBuilder&: builder, odsState&: state, results: resultTypes, failure_propagation_mode: failurePropagationMode, /*root=*/Value(),
3029 /*extra_bindings=*/ValueRange());
3030 buildSequenceBody(builder, state, bbArgType,
3031 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3032}
3033
3034void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3035 TypeRange resultTypes,
3036 FailurePropagationMode failurePropagationMode,
3037 Type bbArgType, TypeRange extraBindingTypes,
3038 SequenceBodyBuilderArgsFn bodyBuilder) {
3039 build(odsBuilder&: builder, odsState&: state, results: resultTypes, failure_propagation_mode: failurePropagationMode, /*root=*/Value(),
3040 /*extra_bindings=*/ValueRange());
3041 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3042}
3043
3044//===----------------------------------------------------------------------===//
3045// PrintOp
3046//===----------------------------------------------------------------------===//
3047
3048void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3049 StringRef name) {
3050 if (!name.empty())
3051 result.getOrAddProperties<Properties>().name = builder.getStringAttr(bytes: name);
3052}
3053
3054void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3055 Value target, StringRef name) {
3056 result.addOperands(newOperands: {target});
3057 build(builder, result, name);
3058}
3059
3060DiagnosedSilenceableFailure
3061transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3062 transform::TransformResults &results,
3063 transform::TransformState &state) {
3064 llvm::outs() << "[[[ IR printer: ";
3065 if (getName().has_value())
3066 llvm::outs() << *getName() << " ";
3067
3068 OpPrintingFlags printFlags;
3069 if (getAssumeVerified().value_or(u: false))
3070 printFlags.assumeVerified();
3071 if (getUseLocalScope().value_or(u: false))
3072 printFlags.useLocalScope();
3073 if (getSkipRegions().value_or(u: false))
3074 printFlags.skipRegions();
3075
3076 if (!getTarget()) {
3077 llvm::outs() << "top-level ]]]\n";
3078 state.getTopLevel()->print(os&: llvm::outs(), flags: printFlags);
3079 llvm::outs() << "\n";
3080 llvm::outs().flush();
3081 return DiagnosedSilenceableFailure::success();
3082 }
3083
3084 llvm::outs() << "]]]\n";
3085 for (Operation *target : state.getPayloadOps(value: getTarget())) {
3086 target->print(os&: llvm::outs(), flags: printFlags);
3087 llvm::outs() << "\n";
3088 }
3089
3090 llvm::outs().flush();
3091 return DiagnosedSilenceableFailure::success();
3092}
3093
3094void transform::PrintOp::getEffects(
3095 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3096 // We don't really care about mutability here, but `getTarget` now
3097 // unconditionally casts to a specific type before verification could run
3098 // here.
3099 if (!getTargetMutable().empty())
3100 onlyReadsHandle(handles: getTargetMutable()[0], effects);
3101 onlyReadsPayload(effects);
3102
3103 // There is no resource for stderr file descriptor, so just declare print
3104 // writes into the default resource.
3105 effects.emplace_back(Args: MemoryEffects::Write::get());
3106}
3107
3108//===----------------------------------------------------------------------===//
3109// VerifyOp
3110//===----------------------------------------------------------------------===//
3111
3112DiagnosedSilenceableFailure
3113transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3114 Operation *target,
3115 transform::ApplyToEachResultList &results,
3116 transform::TransformState &state) {
3117 if (failed(Result: ::mlir::verify(op: target))) {
3118 DiagnosedDefiniteFailure diag = emitDefiniteFailure()
3119 << "failed to verify payload op";
3120 diag.attachNote(loc: target->getLoc()) << "payload op";
3121 return diag;
3122 }
3123 return DiagnosedSilenceableFailure::success();
3124}
3125
3126void transform::VerifyOp::getEffects(
3127 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3128 transform::onlyReadsHandle(handles: getTargetMutable(), effects);
3129}
3130
3131//===----------------------------------------------------------------------===//
3132// YieldOp
3133//===----------------------------------------------------------------------===//
3134
3135void transform::YieldOp::getEffects(
3136 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3137 onlyReadsHandle(handles: getOperandsMutable(), effects);
3138}
3139

// Source: mlir/lib/Dialect/Transform/IR/TransformOps.cpp