1 | //===- TransformOps.cpp - Transform dialect operations --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
10 | |
11 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
12 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
13 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
14 | #include "mlir/Dialect/Transform/IR/TransformAttrs.h" |
15 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
16 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
17 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
18 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/Diagnostics.h" |
21 | #include "mlir/IR/Dominance.h" |
22 | #include "mlir/IR/OperationSupport.h" |
23 | #include "mlir/IR/PatternMatch.h" |
24 | #include "mlir/IR/Verifier.h" |
25 | #include "mlir/Interfaces/CallInterfaces.h" |
26 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
27 | #include "mlir/Interfaces/FunctionImplementation.h" |
28 | #include "mlir/Interfaces/FunctionInterfaces.h" |
29 | #include "mlir/Pass/Pass.h" |
30 | #include "mlir/Pass/PassManager.h" |
31 | #include "mlir/Pass/PassRegistry.h" |
32 | #include "mlir/Transforms/CSE.h" |
33 | #include "mlir/Transforms/DialectConversion.h" |
34 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
35 | #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
36 | #include "llvm/ADT/DenseSet.h" |
37 | #include "llvm/ADT/STLExtras.h" |
38 | #include "llvm/ADT/ScopeExit.h" |
39 | #include "llvm/ADT/SmallPtrSet.h" |
40 | #include "llvm/ADT/TypeSwitch.h" |
41 | #include "llvm/Support/Debug.h" |
42 | #include "llvm/Support/ErrorHandling.h" |
43 | #include <optional> |
44 | |
// Debug category used for general transform dialect messages. DBGS() yields
// the LLVM debug stream after printing the category as a bracketed prefix.
#define DEBUG_TYPE "transform-dialect"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

// Separate debug category for matcher-related messages. DBGS_MATCHER() is the
// prefixed stream for it; DEBUG_MATCHER(x) forwards `x` to DEBUG_WITH_TYPE
// under this category.
#define DEBUG_TYPE_MATCHER "transform-matcher"
#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
51 | |
52 | using namespace mlir; |
53 | |
54 | static ParseResult parseSequenceOpOperands( |
55 | OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, |
56 | Type &rootType, |
57 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &, |
58 | SmallVectorImpl<Type> &); |
59 | static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, |
60 | Value root, Type rootType, |
61 | ValueRange , |
62 | TypeRange ); |
63 | static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, |
64 | ArrayAttr matchers, ArrayAttr actions); |
65 | static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, |
66 | ArrayAttr &matchers, |
67 | ArrayAttr &actions); |
68 | |
69 | /// Helper function to check if the given transform op is contained in (or |
70 | /// equal to) the given payload target op. In that case, an error is returned. |
71 | /// Transforming transform IR that is currently executing is generally unsafe. |
72 | static DiagnosedSilenceableFailure |
73 | ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, |
74 | Operation *payload) { |
75 | Operation *transformAncestor = transform.getOperation(); |
76 | while (transformAncestor) { |
77 | if (transformAncestor == payload) { |
78 | DiagnosedDefiniteFailure diag = |
79 | transform.emitDefiniteFailure() |
80 | << "cannot apply transform to itself (or one of its ancestors)" ; |
81 | diag.attachNote(loc: payload->getLoc()) << "target payload op" ; |
82 | return diag; |
83 | } |
84 | transformAncestor = transformAncestor->getParentOp(); |
85 | } |
86 | return DiagnosedSilenceableFailure::success(); |
87 | } |
88 | |
89 | #define GET_OP_CLASSES |
90 | #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" |
91 | |
92 | //===----------------------------------------------------------------------===// |
93 | // AlternativesOp |
94 | //===----------------------------------------------------------------------===// |
95 | |
96 | OperandRange |
97 | transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
98 | if (!point.isParent() && getOperation()->getNumOperands() == 1) |
99 | return getOperation()->getOperands(); |
100 | return OperandRange(getOperation()->operand_end(), |
101 | getOperation()->operand_end()); |
102 | } |
103 | |
104 | void transform::AlternativesOp::getSuccessorRegions( |
105 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
106 | for (Region &alternative : llvm::drop_begin( |
107 | getAlternatives(), |
108 | point.isParent() ? 0 |
109 | : point.getRegionOrNull()->getRegionNumber() + 1)) { |
110 | regions.emplace_back(&alternative, !getOperands().empty() |
111 | ? alternative.getArguments() |
112 | : Block::BlockArgListType()); |
113 | } |
114 | if (!point.isParent()) |
115 | regions.emplace_back(getOperation()->getResults()); |
116 | } |
117 | |
118 | void transform::AlternativesOp::getRegionInvocationBounds( |
119 | ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { |
120 | (void)operands; |
121 | // The region corresponding to the first alternative is always executed, the |
122 | // remaining may or may not be executed. |
123 | bounds.reserve(getNumRegions()); |
124 | bounds.emplace_back(1, 1); |
125 | bounds.resize(getNumRegions(), InvocationBounds(0, 1)); |
126 | } |
127 | |
128 | static void forwardEmptyOperands(Block *block, transform::TransformState &state, |
129 | transform::TransformResults &results) { |
130 | for (const auto &res : block->getParentOp()->getOpResults()) |
131 | results.set(value: res, ops: {}); |
132 | } |
133 | |
134 | DiagnosedSilenceableFailure |
135 | transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, |
136 | transform::TransformResults &results, |
137 | transform::TransformState &state) { |
138 | SmallVector<Operation *> originals; |
139 | if (Value scopeHandle = getScope()) |
140 | llvm::append_range(originals, state.getPayloadOps(scopeHandle)); |
141 | else |
142 | originals.push_back(state.getTopLevel()); |
143 | |
144 | for (Operation *original : originals) { |
145 | if (original->isAncestor(getOperation())) { |
146 | auto diag = emitDefiniteFailure() |
147 | << "scope must not contain the transforms being applied" ; |
148 | diag.attachNote(original->getLoc()) << "scope" ; |
149 | return diag; |
150 | } |
151 | if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { |
152 | auto diag = emitDefiniteFailure() |
153 | << "only isolated-from-above ops can be alternative scopes" ; |
154 | diag.attachNote(original->getLoc()) << "scope" ; |
155 | return diag; |
156 | } |
157 | } |
158 | |
159 | for (Region ® : getAlternatives()) { |
160 | // Clone the scope operations and make the transforms in this alternative |
161 | // region apply to them by virtue of mapping the block argument (the only |
162 | // visible handle) to the cloned scope operations. This effectively prevents |
163 | // the transformation from accessing any IR outside the scope. |
164 | auto scope = state.make_region_scope(reg); |
165 | auto clones = llvm::to_vector( |
166 | llvm::map_range(originals, [](Operation *op) { return op->clone(); })); |
167 | auto deleteClones = llvm::make_scope_exit([&] { |
168 | for (Operation *clone : clones) |
169 | clone->erase(); |
170 | }); |
171 | if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) |
172 | return DiagnosedSilenceableFailure::definiteFailure(); |
173 | |
174 | bool failed = false; |
175 | for (Operation &transform : reg.front().without_terminator()) { |
176 | DiagnosedSilenceableFailure result = |
177 | state.applyTransform(cast<TransformOpInterface>(transform)); |
178 | if (result.isSilenceableFailure()) { |
179 | LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() |
180 | << "\n" ); |
181 | failed = true; |
182 | break; |
183 | } |
184 | |
185 | if (::mlir::failed(result.silence())) |
186 | return DiagnosedSilenceableFailure::definiteFailure(); |
187 | } |
188 | |
189 | // If all operations in the given alternative succeeded, no need to consider |
190 | // the rest. Replace the original scoping operation with the clone on which |
191 | // the transformations were performed. |
192 | if (!failed) { |
193 | // We will be using the clones, so cancel their scheduled deletion. |
194 | deleteClones.release(); |
195 | TrackingListener listener(state, *this); |
196 | IRRewriter rewriter(getContext(), &listener); |
197 | for (const auto &kvp : llvm::zip(originals, clones)) { |
198 | Operation *original = std::get<0>(kvp); |
199 | Operation *clone = std::get<1>(kvp); |
200 | original->getBlock()->getOperations().insert(original->getIterator(), |
201 | clone); |
202 | rewriter.replaceOp(original, clone->getResults()); |
203 | } |
204 | detail::forwardTerminatorOperands(®.front(), state, results); |
205 | return DiagnosedSilenceableFailure::success(); |
206 | } |
207 | } |
208 | return emitSilenceableError() << "all alternatives failed" ; |
209 | } |
210 | |
211 | void transform::AlternativesOp::getEffects( |
212 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
213 | consumesHandle(getOperands(), effects); |
214 | producesHandle(getResults(), effects); |
215 | for (Region *region : getRegions()) { |
216 | if (!region->empty()) |
217 | producesHandle(region->front().getArguments(), effects); |
218 | } |
219 | modifiesPayload(effects); |
220 | } |
221 | |
222 | LogicalResult transform::AlternativesOp::verify() { |
223 | for (Region &alternative : getAlternatives()) { |
224 | Block &block = alternative.front(); |
225 | Operation *terminator = block.getTerminator(); |
226 | if (terminator->getOperands().getTypes() != getResults().getTypes()) { |
227 | InFlightDiagnostic diag = emitOpError() |
228 | << "expects terminator operands to have the " |
229 | "same type as results of the operation" ; |
230 | diag.attachNote(terminator->getLoc()) << "terminator" ; |
231 | return diag; |
232 | } |
233 | } |
234 | |
235 | return success(); |
236 | } |
237 | |
238 | //===----------------------------------------------------------------------===// |
239 | // AnnotateOp |
240 | //===----------------------------------------------------------------------===// |
241 | |
242 | DiagnosedSilenceableFailure |
243 | transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, |
244 | transform::TransformResults &results, |
245 | transform::TransformState &state) { |
246 | SmallVector<Operation *> targets = |
247 | llvm::to_vector(state.getPayloadOps(getTarget())); |
248 | |
249 | Attribute attr = UnitAttr::get(getContext()); |
250 | if (auto paramH = getParam()) { |
251 | ArrayRef<Attribute> params = state.getParams(paramH); |
252 | if (params.size() != 1) { |
253 | if (targets.size() != params.size()) { |
254 | return emitSilenceableError() |
255 | << "parameter and target have different payload lengths (" |
256 | << params.size() << " vs " << targets.size() << ")" ; |
257 | } |
258 | for (auto &&[target, attr] : llvm::zip_equal(targets, params)) |
259 | target->setAttr(getName(), attr); |
260 | return DiagnosedSilenceableFailure::success(); |
261 | } |
262 | attr = params[0]; |
263 | } |
264 | for (auto *target : targets) |
265 | target->setAttr(getName(), attr); |
266 | return DiagnosedSilenceableFailure::success(); |
267 | } |
268 | |
269 | void transform::AnnotateOp::getEffects( |
270 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
271 | onlyReadsHandle(getTarget(), effects); |
272 | onlyReadsHandle(getParam(), effects); |
273 | modifiesPayload(effects); |
274 | } |
275 | |
276 | //===----------------------------------------------------------------------===// |
277 | // ApplyCommonSubexpressionEliminationOp |
278 | //===----------------------------------------------------------------------===// |
279 | |
280 | DiagnosedSilenceableFailure |
281 | transform::ApplyCommonSubexpressionEliminationOp::applyToOne( |
282 | transform::TransformRewriter &rewriter, Operation *target, |
283 | ApplyToEachResultList &results, transform::TransformState &state) { |
284 | // Make sure that this transform is not applied to itself. Modifying the |
285 | // transform IR while it is being interpreted is generally dangerous. |
286 | DiagnosedSilenceableFailure payloadCheck = |
287 | ensurePayloadIsSeparateFromTransform(*this, target); |
288 | if (!payloadCheck.succeeded()) |
289 | return payloadCheck; |
290 | |
291 | DominanceInfo domInfo; |
292 | mlir::eliminateCommonSubExpressions(rewriter, domInfo, target); |
293 | return DiagnosedSilenceableFailure::success(); |
294 | } |
295 | |
296 | void transform::ApplyCommonSubexpressionEliminationOp::getEffects( |
297 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
298 | transform::onlyReadsHandle(getTarget(), effects); |
299 | transform::modifiesPayload(effects); |
300 | } |
301 | |
302 | //===----------------------------------------------------------------------===// |
303 | // ApplyDeadCodeEliminationOp |
304 | //===----------------------------------------------------------------------===// |
305 | |
306 | DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne( |
307 | transform::TransformRewriter &rewriter, Operation *target, |
308 | ApplyToEachResultList &results, transform::TransformState &state) { |
309 | // Make sure that this transform is not applied to itself. Modifying the |
310 | // transform IR while it is being interpreted is generally dangerous. |
311 | DiagnosedSilenceableFailure payloadCheck = |
312 | ensurePayloadIsSeparateFromTransform(*this, target); |
313 | if (!payloadCheck.succeeded()) |
314 | return payloadCheck; |
315 | |
316 | // Maintain a worklist of potentially dead ops. |
317 | SetVector<Operation *> worklist; |
318 | |
319 | // Helper function that adds all defining ops of used values (operands and |
320 | // operands of nested ops). |
321 | auto addDefiningOpsToWorklist = [&](Operation *op) { |
322 | op->walk([&](Operation *op) { |
323 | for (Value v : op->getOperands()) |
324 | if (Operation *defOp = v.getDefiningOp()) |
325 | if (target->isProperAncestor(defOp)) |
326 | worklist.insert(defOp); |
327 | }); |
328 | }; |
329 | |
330 | // Helper function that erases an op. |
331 | auto eraseOp = [&](Operation *op) { |
332 | // Remove op and nested ops from the worklist. |
333 | op->walk([&](Operation *op) { |
334 | const auto *it = llvm::find(worklist, op); |
335 | if (it != worklist.end()) |
336 | worklist.erase(it); |
337 | }); |
338 | rewriter.eraseOp(op); |
339 | }; |
340 | |
341 | // Initial walk over the IR. |
342 | target->walk<WalkOrder::PostOrder>([&](Operation *op) { |
343 | if (op != target && isOpTriviallyDead(op)) { |
344 | addDefiningOpsToWorklist(op); |
345 | eraseOp(op); |
346 | } |
347 | }); |
348 | |
349 | // Erase all ops that have become dead. |
350 | while (!worklist.empty()) { |
351 | Operation *op = worklist.pop_back_val(); |
352 | if (!isOpTriviallyDead(op)) |
353 | continue; |
354 | addDefiningOpsToWorklist(op); |
355 | eraseOp(op); |
356 | } |
357 | |
358 | return DiagnosedSilenceableFailure::success(); |
359 | } |
360 | |
361 | void transform::ApplyDeadCodeEliminationOp::getEffects( |
362 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
363 | transform::onlyReadsHandle(getTarget(), effects); |
364 | transform::modifiesPayload(effects); |
365 | } |
366 | |
367 | //===----------------------------------------------------------------------===// |
368 | // ApplyPatternsOp |
369 | //===----------------------------------------------------------------------===// |
370 | |
/// Applies the rewrite patterns described by the ops in this op's region to
/// `target` with the greedy pattern rewrite driver, optionally followed by
/// CSE, repeating while CSE reports a change.
///
/// Returns the failure from ensurePayloadIsSeparateFromTransform if `target`
/// is this transform op or one of its ancestors, a silenceable failure if the
/// greedy driver fails, a definite failure if the loop reaches
/// kNumMaxIterations, and success otherwise.
DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous. Even
  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
  // performs many additional simplifications such as dead code elimination.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  // Gather all specified patterns.
  MLIRContext *ctx = target->getContext();
  RewritePatternSet patterns(ctx);
  if (!getRegion().empty()) {
    for (Operation &op : getRegion().front()) {
      // The cast relies on verify() having checked that every child op
      // implements PatternDescriptorOpInterface.
      cast<transform::PatternDescriptorOpInterface>(&op)
          .populatePatternsWithState(patterns, state);
    }
  }

  // Configure the GreedyPatternRewriteDriver.
  GreedyRewriteConfig config;
  // Forward the listener of the transform rewriter to the greedy driver.
  config.listener =
      static_cast<RewriterBase::Listener *>(rewriter.getListener());
  FrozenRewritePatternSet frozenPatterns(std::move(patterns));

  // An attribute value of all ones (-1 as uint64_t) selects "no limit".
  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
                             ? GreedyRewriteConfig::kNoLimit
                             : getMaxIterations();
  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
                              ? GreedyRewriteConfig::kNoLimit
                              : getMaxNumRewrites();

  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
  // was requested, apply the greedy pattern rewrite only once. (The greedy
  // pattern rewrite driver already iterates to a fixpoint internally.)
  bool cseChanged = false;
  // One or two iterations should be sufficient. Stop iterating after a certain
  // threshold to make debugging easier.
  static const int64_t kNumMaxIterations = 50;
  int64_t iteration = 0;
  do {
    LogicalResult result = failure();
    if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      // Op is isolated from above. Apply patterns and also perform region
      // simplification.
      result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
    } else {
      // Manually gather list of ops because the other
      // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
      // from above. This way, patterns can be applied to ops that are not
      // isolated from above. Regions are not being simplified. Furthermore,
      // only a single greedy rewrite iteration is performed.
      SmallVector<Operation *> ops;
      target->walk([&](Operation *nestedOp) {
        if (target != nestedOp)
          ops.push_back(nestedOp);
      });
      result = applyOpPatternsAndFold(ops, frozenPatterns, config);
    }

    // A failure typically indicates that the pattern application did not
    // converge.
    if (failed(result)) {
      return emitSilenceableFailure(target)
             << "greedy pattern application failed" ;
    }

    if (getApplyCse()) {
      DominanceInfo domInfo;
      mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
                                          &cseChanged);
    }
    // `cseChanged` stays false when CSE is not requested, so the loop body
    // then runs exactly once.
  } while (cseChanged && ++iteration < kNumMaxIterations);

  if (iteration == kNumMaxIterations)
    return emitDefiniteFailure() << "fixpoint iteration did not converge" ;

  return DiagnosedSilenceableFailure::success();
}
453 | |
454 | LogicalResult transform::ApplyPatternsOp::verify() { |
455 | if (!getRegion().empty()) { |
456 | for (Operation &op : getRegion().front()) { |
457 | if (!isa<transform::PatternDescriptorOpInterface>(&op)) { |
458 | InFlightDiagnostic diag = emitOpError() |
459 | << "expected children ops to implement " |
460 | "PatternDescriptorOpInterface" ; |
461 | diag.attachNote(op.getLoc()) << "op without interface" ; |
462 | return diag; |
463 | } |
464 | } |
465 | } |
466 | return success(); |
467 | } |
468 | |
469 | void transform::ApplyPatternsOp::getEffects( |
470 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
471 | transform::onlyReadsHandle(getTarget(), effects); |
472 | transform::modifiesPayload(effects); |
473 | } |
474 | |
475 | void transform::ApplyPatternsOp::build( |
476 | OpBuilder &builder, OperationState &result, Value target, |
477 | function_ref<void(OpBuilder &, Location)> bodyBuilder) { |
478 | result.addOperands(target); |
479 | |
480 | OpBuilder::InsertionGuard g(builder); |
481 | Region *region = result.addRegion(); |
482 | builder.createBlock(region); |
483 | if (bodyBuilder) |
484 | bodyBuilder(builder, result.location); |
485 | } |
486 | |
487 | //===----------------------------------------------------------------------===// |
488 | // ApplyCanonicalizationPatternsOp |
489 | //===----------------------------------------------------------------------===// |
490 | |
491 | void transform::ApplyCanonicalizationPatternsOp::populatePatterns( |
492 | RewritePatternSet &patterns) { |
493 | MLIRContext *ctx = patterns.getContext(); |
494 | for (Dialect *dialect : ctx->getLoadedDialects()) |
495 | dialect->getCanonicalizationPatterns(patterns); |
496 | for (RegisteredOperationName op : ctx->getRegisteredOperations()) |
497 | op.getCanonicalizationPatterns(patterns, ctx); |
498 | } |
499 | |
500 | //===----------------------------------------------------------------------===// |
501 | // ApplyConversionPatternsOp |
502 | //===----------------------------------------------------------------------===// |
503 | |
/// Runs a dialect conversion (partial or full, depending on
/// getPartialConversion()) on every payload op of the target handle.
///
/// The conversion target is configured from the legal/illegal op and dialect
/// attributes and refined by each pattern descriptor op. Patterns come from the
/// descriptor ops in the patterns region; each descriptor uses its own type
/// converter if it provides one and the default type converter otherwise.
/// Returns a definite failure if a descriptor needs the default converter and
/// none exists, the failure from ensurePayloadIsSeparateFromTransform if a
/// target is this op or one of its ancestors, a silenceable error if the
/// conversion fails, the tracking listener failure if that is the only
/// failure, and success otherwise.
DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  MLIRContext *ctx = getContext();

  // Instantiate the default type converter if a type converter builder is
  // specified.
  std::unique_ptr<TypeConverter> defaultTypeConverter;
  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
      getDefaultTypeConverter();
  if (typeConverterBuilder)
    defaultTypeConverter = typeConverterBuilder.getTypeConverter();

  // Configure conversion target.
  ConversionTarget conversionTarget(*getContext());
  if (getLegalOps())
    for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
      conversionTarget.addLegalOp(
          OperationName(cast<StringAttr>(attr).getValue(), ctx));
  if (getIllegalOps())
    for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
      conversionTarget.addIllegalOp(
          OperationName(cast<StringAttr>(attr).getValue(), ctx));
  if (getLegalDialects())
    for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
      conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
  if (getIllegalDialects())
    for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
      conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());

  // Gather all specified patterns.
  RewritePatternSet patterns(ctx);
  // Need to keep the converters alive until after pattern application because
  // the patterns take a reference to an object that would otherwise get out of
  // scope.
  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
  if (!getPatterns().empty()) {
    for (Operation &op : getPatterns().front()) {
      // The cast relies on verify() having checked that every child op
      // implements ConversionPatternDescriptorOpInterface.
      auto descriptor =
          cast<transform::ConversionPatternDescriptorOpInterface>(&op);

      // Check if this pattern set specifies a type converter.
      std::unique_ptr<TypeConverter> typeConverter =
          descriptor.getTypeConverter();
      TypeConverter *converter = nullptr;
      if (typeConverter) {
        // Ownership moves to keepAliveConverters; `converter` is a
        // non-owning pointer to the stored object.
        keepAliveConverters.emplace_back(std::move(typeConverter));
        converter = keepAliveConverters.back().get();
      } else {
        // No type converter specified: Use the default type converter.
        if (!defaultTypeConverter) {
          auto diag = emitDefiniteFailure()
                      << "pattern descriptor does not specify type "
                         "converter and apply_conversion_patterns op has "
                         "no default type converter" ;
          diag.attachNote(op.getLoc()) << "pattern descriptor op" ;
          return diag;
        }
        converter = defaultTypeConverter.get();
      }

      // Add descriptor-specific updates to the conversion target, which may
      // depend on the final type converter. In structural converters, the
      // legality of types dictates the dynamic legality of an operation.
      descriptor.populateConversionTargetRules(*converter, conversionTarget);

      descriptor.populatePatterns(*converter, patterns);
    }
  }

  // Attach a tracking listener if handles should be preserved. We configure the
  // listener to allow op replacements with different names, as conversion
  // patterns typically replace ops with replacement ops that have a different
  // name.
  TrackingListenerConfig trackingConfig;
  trackingConfig.requireMatchingReplacementOpName = false;
  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
  ConversionConfig conversionConfig;
  if (getPreserveHandles())
    conversionConfig.listener = &trackingListener;

  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Make sure that this transform is not applied to itself. Modifying the
    // transform IR while it is being interpreted is generally dangerous.
    DiagnosedSilenceableFailure payloadCheck =
        ensurePayloadIsSeparateFromTransform(*this, target);
    if (!payloadCheck.succeeded())
      return payloadCheck;

    LogicalResult status = failure();
    if (getPartialConversion()) {
      status = applyPartialConversion(target, conversionTarget, frozenPatterns,
                                      conversionConfig);
    } else {
      status = applyFullConversion(target, conversionTarget, frozenPatterns,
                                   conversionConfig);
    }

    // Check dialect conversion state.
    DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
    if (failed(status)) {
      diag = emitSilenceableError() << "dialect conversion failed" ;
      diag.attachNote(target->getLoc()) << "target op" ;
    }

    // Check tracking listener error state.
    DiagnosedSilenceableFailure trackingFailure =
        trackingListener.checkAndResetError();
    if (!trackingFailure.succeeded()) {
      if (diag.succeeded()) {
        // Tracking failure is the only failure.
        return trackingFailure;
      } else {
        // Both failed: report the conversion failure and fold the tracking
        // message into it as a note.
        diag.attachNote() << "tracking listener also failed: "
                          << trackingFailure.getMessage();
        (void)trackingFailure.silence();
      }
    }

    // Stop at the first target that failed to convert.
    if (!diag.succeeded())
      return diag;
  }

  return DiagnosedSilenceableFailure::success();
}
630 | |
631 | LogicalResult transform::ApplyConversionPatternsOp::verify() { |
632 | if (getNumRegions() != 1 && getNumRegions() != 2) |
633 | return emitOpError() << "expected 1 or 2 regions" ; |
634 | if (!getPatterns().empty()) { |
635 | for (Operation &op : getPatterns().front()) { |
636 | if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) { |
637 | InFlightDiagnostic diag = |
638 | emitOpError() << "expected pattern children ops to implement " |
639 | "ConversionPatternDescriptorOpInterface" ; |
640 | diag.attachNote(op.getLoc()) << "op without interface" ; |
641 | return diag; |
642 | } |
643 | } |
644 | } |
645 | if (getNumRegions() == 2) { |
646 | Region &typeConverterRegion = getRegion(1); |
647 | if (!llvm::hasSingleElement(typeConverterRegion.front())) |
648 | return emitOpError() |
649 | << "expected exactly one op in default type converter region" ; |
650 | auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>( |
651 | &typeConverterRegion.front().front()); |
652 | if (!typeConverterOp) { |
653 | InFlightDiagnostic diag = emitOpError() |
654 | << "expected default converter child op to " |
655 | "implement TypeConverterBuilderOpInterface" ; |
656 | diag.attachNote(typeConverterOp->getLoc()) << "op without interface" ; |
657 | return diag; |
658 | } |
659 | // Check default type converter type. |
660 | if (!getPatterns().empty()) { |
661 | for (Operation &op : getPatterns().front()) { |
662 | auto descriptor = |
663 | cast<transform::ConversionPatternDescriptorOpInterface>(&op); |
664 | if (failed(descriptor.verifyTypeConverter(typeConverterOp))) |
665 | return failure(); |
666 | } |
667 | } |
668 | } |
669 | return success(); |
670 | } |
671 | |
672 | void transform::ApplyConversionPatternsOp::getEffects( |
673 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
674 | if (!getPreserveHandles()) { |
675 | transform::consumesHandle(getTarget(), effects); |
676 | } else { |
677 | transform::onlyReadsHandle(getTarget(), effects); |
678 | } |
679 | transform::modifiesPayload(effects); |
680 | } |
681 | |
682 | void transform::ApplyConversionPatternsOp::build( |
683 | OpBuilder &builder, OperationState &result, Value target, |
684 | function_ref<void(OpBuilder &, Location)> patternsBodyBuilder, |
685 | function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) { |
686 | result.addOperands(target); |
687 | |
688 | { |
689 | OpBuilder::InsertionGuard g(builder); |
690 | Region *region1 = result.addRegion(); |
691 | builder.createBlock(region1); |
692 | if (patternsBodyBuilder) |
693 | patternsBodyBuilder(builder, result.location); |
694 | } |
695 | { |
696 | OpBuilder::InsertionGuard g(builder); |
697 | Region *region2 = result.addRegion(); |
698 | builder.createBlock(region2); |
699 | if (typeConverterBodyBuilder) |
700 | typeConverterBodyBuilder(builder, result.location); |
701 | } |
702 | } |
703 | |
704 | //===----------------------------------------------------------------------===// |
705 | // ApplyToLLVMConversionPatternsOp |
706 | //===----------------------------------------------------------------------===// |
707 | |
708 | void transform::ApplyToLLVMConversionPatternsOp::populatePatterns( |
709 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
710 | Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); |
711 | assert(dialect && "expected that dialect is loaded" ); |
712 | auto *iface = cast<ConvertToLLVMPatternInterface>(dialect); |
713 | // ConversionTarget is currently ignored because the enclosing |
714 | // apply_conversion_patterns op sets up its own ConversionTarget. |
715 | ConversionTarget target(*getContext()); |
716 | iface->populateConvertToLLVMConversionPatterns( |
717 | target, static_cast<LLVMTypeConverter &>(typeConverter), patterns); |
718 | } |
719 | |
720 | LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter( |
721 | transform::TypeConverterBuilderOpInterface builder) { |
722 | if (builder.getTypeConverterType() != "LLVMTypeConverter" ) |
723 | return emitOpError("expected LLVMTypeConverter" ); |
724 | return success(); |
725 | } |
726 | |
727 | LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() { |
728 | Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); |
729 | if (!dialect) |
730 | return emitOpError("unknown dialect or dialect not loaded: " ) |
731 | << getDialectName(); |
732 | auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); |
733 | if (!iface) |
734 | return emitOpError( |
735 | "dialect does not implement ConvertToLLVMPatternInterface or " |
736 | "extension was not loaded: " ) |
737 | << getDialectName(); |
738 | return success(); |
739 | } |
740 | |
741 | //===----------------------------------------------------------------------===// |
742 | // ApplyLoopInvariantCodeMotionOp |
743 | //===----------------------------------------------------------------------===// |
744 | |
745 | DiagnosedSilenceableFailure |
746 | transform::ApplyLoopInvariantCodeMotionOp::applyToOne( |
747 | transform::TransformRewriter &rewriter, LoopLikeOpInterface target, |
748 | transform::ApplyToEachResultList &results, |
749 | transform::TransformState &state) { |
750 | // Currently, LICM does not remove operations, so we don't need tracking. |
751 | // If this ever changes, add a LICM entry point that takes a rewriter. |
752 | moveLoopInvariantCode(target); |
753 | return DiagnosedSilenceableFailure::success(); |
754 | } |
755 | |
756 | void transform::ApplyLoopInvariantCodeMotionOp::getEffects( |
757 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
758 | transform::onlyReadsHandle(getTarget(), effects); |
759 | transform::modifiesPayload(effects); |
760 | } |
761 | |
762 | //===----------------------------------------------------------------------===// |
763 | // ApplyRegisteredPassOp |
764 | //===----------------------------------------------------------------------===// |
765 | |
766 | DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( |
767 | transform::TransformRewriter &rewriter, Operation *target, |
768 | ApplyToEachResultList &results, transform::TransformState &state) { |
769 | // Make sure that this transform is not applied to itself. Modifying the |
770 | // transform IR while it is being interpreted is generally dangerous. Even |
771 | // more so when applying passes because they may perform a wide range of IR |
772 | // modifications. |
773 | DiagnosedSilenceableFailure payloadCheck = |
774 | ensurePayloadIsSeparateFromTransform(*this, target); |
775 | if (!payloadCheck.succeeded()) |
776 | return payloadCheck; |
777 | |
778 | // Get pass or pass pipeline from registry. |
779 | const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); |
780 | if (!info) |
781 | info = PassInfo::lookup(getPassName()); |
782 | if (!info) |
783 | return emitDefiniteFailure() |
784 | << "unknown pass or pass pipeline: " << getPassName(); |
785 | |
786 | // Create pass manager and run the pass or pass pipeline. |
787 | PassManager pm(getContext()); |
788 | if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) { |
789 | emitError(msg); |
790 | return failure(); |
791 | }))) { |
792 | return emitDefiniteFailure() |
793 | << "failed to add pass or pass pipeline to pipeline: " |
794 | << getPassName(); |
795 | } |
796 | if (failed(pm.run(target))) { |
797 | auto diag = emitSilenceableError() << "pass pipeline failed" ; |
798 | diag.attachNote(target->getLoc()) << "target op" ; |
799 | return diag; |
800 | } |
801 | |
802 | results.push_back(target); |
803 | return DiagnosedSilenceableFailure::success(); |
804 | } |
805 | |
806 | //===----------------------------------------------------------------------===// |
807 | // CastOp |
808 | //===----------------------------------------------------------------------===// |
809 | |
810 | DiagnosedSilenceableFailure |
811 | transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, |
812 | Operation *target, ApplyToEachResultList &results, |
813 | transform::TransformState &state) { |
814 | results.push_back(target); |
815 | return DiagnosedSilenceableFailure::success(); |
816 | } |
817 | |
818 | void transform::CastOp::getEffects( |
819 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
820 | onlyReadsPayload(effects); |
821 | onlyReadsHandle(getInput(), effects); |
822 | producesHandle(getOutput(), effects); |
823 | } |
824 | |
825 | bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
826 | assert(inputs.size() == 1 && "expected one input" ); |
827 | assert(outputs.size() == 1 && "expected one output" ); |
828 | return llvm::all_of( |
829 | std::initializer_list<Type>{inputs.front(), outputs.front()}, |
830 | llvm::IsaPred<transform::TransformHandleTypeInterface>); |
831 | } |
832 | |
833 | //===----------------------------------------------------------------------===// |
834 | // CollectMatchingOp |
835 | //===----------------------------------------------------------------------===// |
836 | |
837 | /// Applies matcher operations from the given `block` assigning `op` as the |
838 | /// payload of the block's first argument. Updates `state` accordingly. If any |
839 | /// of the matcher produces a silenceable failure, discards it (printing the |
840 | /// content to the debug output stream) and returns failure. If any of the |
841 | /// matchers produces a definite failure, reports it and returns failure. If all |
842 | /// matchers in the block succeed, populates `mappings` with the payload |
843 | /// entities associated with the block terminator operands. |
844 | static DiagnosedSilenceableFailure |
845 | matchBlock(Block &block, Operation *op, transform::TransformState &state, |
846 | SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) { |
847 | assert(block.getParent() && "cannot match using a detached block" ); |
848 | auto matchScope = state.make_region_scope(region&: *block.getParent()); |
849 | if (failed(result: state.mapBlockArgument(argument: block.getArgument(i: 0), values: {op}))) |
850 | return DiagnosedSilenceableFailure::definiteFailure(); |
851 | |
852 | for (Operation &match : block.without_terminator()) { |
853 | if (!isa<transform::MatchOpInterface>(Val: match)) { |
854 | return emitDefiniteFailure(loc: match.getLoc()) |
855 | << "expected operations in the match part to " |
856 | "implement MatchOpInterface" ; |
857 | } |
858 | DiagnosedSilenceableFailure diag = |
859 | state.applyTransform(transform: cast<transform::TransformOpInterface>(match)); |
860 | if (diag.succeeded()) |
861 | continue; |
862 | |
863 | return diag; |
864 | } |
865 | |
866 | // Remember the values mapped to the terminator operands so we can |
867 | // forward them to the action. |
868 | ValueRange yieldedValues = block.getTerminator()->getOperands(); |
869 | transform::detail::prepareValueMappings(mappings, yieldedValues, state); |
870 | return DiagnosedSilenceableFailure::success(); |
871 | } |
872 | |
873 | /// Returns `true` if both types implement one of the interfaces provided as |
874 | /// template parameters. |
875 | template <typename... Tys> |
876 | static bool implementSameInterface(Type t1, Type t2) { |
877 | return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); |
878 | } |
879 | |
880 | /// Returns `true` if both types implement one of the transform dialect |
881 | /// interfaces. |
882 | static bool implementSameTransformInterface(Type t1, Type t2) { |
883 | return implementSameInterface<transform::TransformHandleTypeInterface, |
884 | transform::TransformParamTypeInterface, |
885 | transform::TransformValueHandleTypeInterface>( |
886 | t1, t2); |
887 | } |
888 | |
889 | //===----------------------------------------------------------------------===// |
890 | // CollectMatchingOp |
891 | //===----------------------------------------------------------------------===// |
892 | |
893 | DiagnosedSilenceableFailure |
894 | transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, |
895 | transform::TransformResults &results, |
896 | transform::TransformState &state) { |
897 | auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>( |
898 | getOperation(), getMatcher()); |
899 | if (matcher.isExternal()) { |
900 | return emitDefiniteFailure() |
901 | << "unresolved external symbol " << getMatcher(); |
902 | } |
903 | |
904 | SmallVector<SmallVector<MappedValue>, 2> rawResults; |
905 | rawResults.resize(getOperation()->getNumResults()); |
906 | std::optional<DiagnosedSilenceableFailure> maybeFailure; |
907 | for (Operation *root : state.getPayloadOps(getRoot())) { |
908 | WalkResult walkResult = root->walk([&](Operation *op) { |
909 | DEBUG_MATCHER({ |
910 | DBGS_MATCHER() << "matching " ; |
911 | op->print(llvm::dbgs(), |
912 | OpPrintingFlags().assumeVerified().skipRegions()); |
913 | llvm::dbgs() << " @" << op << "\n" ; |
914 | }); |
915 | |
916 | // Try matching. |
917 | SmallVector<SmallVector<MappedValue>> mappings; |
918 | DiagnosedSilenceableFailure diag = |
919 | matchBlock(matcher.getFunctionBody().front(), op, state, mappings); |
920 | if (diag.isDefiniteFailure()) |
921 | return WalkResult::interrupt(); |
922 | if (diag.isSilenceableFailure()) { |
923 | DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() |
924 | << " failed: " << diag.getMessage()); |
925 | return WalkResult::advance(); |
926 | } |
927 | |
928 | // If succeeded, collect results. |
929 | for (auto &&[i, mapping] : llvm::enumerate(mappings)) { |
930 | if (mapping.size() != 1) { |
931 | maybeFailure.emplace(emitSilenceableError() |
932 | << "result #" << i << ", associated with " |
933 | << mapping.size() |
934 | << " payload objects, expected 1" ); |
935 | return WalkResult::interrupt(); |
936 | } |
937 | rawResults[i].push_back(mapping[0]); |
938 | } |
939 | return WalkResult::advance(); |
940 | }); |
941 | if (walkResult.wasInterrupted()) |
942 | return std::move(*maybeFailure); |
943 | assert(!maybeFailure && "failure set but the walk was not interrupted" ); |
944 | |
945 | for (auto &&[opResult, rawResult] : |
946 | llvm::zip_equal(getOperation()->getResults(), rawResults)) { |
947 | results.setMappedValues(opResult, rawResult); |
948 | } |
949 | } |
950 | return DiagnosedSilenceableFailure::success(); |
951 | } |
952 | |
953 | void transform::CollectMatchingOp::getEffects( |
954 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
955 | onlyReadsHandle(getRoot(), effects); |
956 | producesHandle(getResults(), effects); |
957 | onlyReadsPayload(effects); |
958 | } |
959 | |
960 | LogicalResult transform::CollectMatchingOp::verifySymbolUses( |
961 | SymbolTableCollection &symbolTable) { |
962 | auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( |
963 | symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); |
964 | if (!matcherSymbol || |
965 | !isa<TransformOpInterface>(matcherSymbol.getOperation())) |
966 | return emitError() << "unresolved matcher symbol " << getMatcher(); |
967 | |
968 | ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes(); |
969 | if (argumentTypes.size() != 1 || |
970 | !isa<TransformHandleTypeInterface>(argumentTypes[0])) { |
971 | return emitError() |
972 | << "expected the matcher to take one operation handle argument" ; |
973 | } |
974 | if (!matcherSymbol.getArgAttr( |
975 | 0, transform::TransformDialect::kArgReadOnlyAttrName)) { |
976 | return emitError() << "expected the matcher argument to be marked readonly" ; |
977 | } |
978 | |
979 | ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes(); |
980 | if (resultTypes.size() != getOperation()->getNumResults()) { |
981 | return emitError() |
982 | << "expected the matcher to yield as many values as op has results (" |
983 | << getOperation()->getNumResults() << "), got " |
984 | << resultTypes.size(); |
985 | } |
986 | |
987 | for (auto &&[i, matcherType, resultType] : |
988 | llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { |
989 | if (implementSameTransformInterface(matcherType, resultType)) |
990 | continue; |
991 | |
992 | return emitError() |
993 | << "mismatching type interfaces for matcher result and op result #" |
994 | << i; |
995 | } |
996 | |
997 | return success(); |
998 | } |
999 | |
1000 | //===----------------------------------------------------------------------===// |
1001 | // ForeachMatchOp |
1002 | //===----------------------------------------------------------------------===// |
1003 | |
1004 | DiagnosedSilenceableFailure |
1005 | transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, |
1006 | transform::TransformResults &results, |
1007 | transform::TransformState &state) { |
1008 | SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>> |
1009 | matchActionPairs; |
1010 | matchActionPairs.reserve(getMatchers().size()); |
1011 | SymbolTableCollection symbolTable; |
1012 | for (auto &&[matcher, action] : |
1013 | llvm::zip_equal(getMatchers(), getActions())) { |
1014 | auto matcherSymbol = |
1015 | symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>( |
1016 | getOperation(), cast<SymbolRefAttr>(matcher)); |
1017 | auto actionSymbol = |
1018 | symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>( |
1019 | getOperation(), cast<SymbolRefAttr>(action)); |
1020 | assert(matcherSymbol && actionSymbol && |
1021 | "unresolved symbols not caught by the verifier" ); |
1022 | |
1023 | if (matcherSymbol.isExternal()) |
1024 | return emitDefiniteFailure() << "unresolved external symbol " << matcher; |
1025 | if (actionSymbol.isExternal()) |
1026 | return emitDefiniteFailure() << "unresolved external symbol " << action; |
1027 | |
1028 | matchActionPairs.emplace_back(matcherSymbol, actionSymbol); |
1029 | } |
1030 | |
1031 | DiagnosedSilenceableFailure overallDiag = |
1032 | DiagnosedSilenceableFailure::success(); |
1033 | for (Operation *root : state.getPayloadOps(getRoot())) { |
1034 | WalkResult walkResult = root->walk([&](Operation *op) { |
1035 | // If getRestrictRoot is not present, skip over the root op itself so we |
1036 | // don't invalidate it. |
1037 | if (!getRestrictRoot() && op == root) |
1038 | return WalkResult::advance(); |
1039 | |
1040 | DEBUG_MATCHER({ |
1041 | DBGS_MATCHER() << "matching " ; |
1042 | op->print(llvm::dbgs(), |
1043 | OpPrintingFlags().assumeVerified().skipRegions()); |
1044 | llvm::dbgs() << " @" << op << "\n" ; |
1045 | }); |
1046 | |
1047 | // Try all the match/action pairs until the first successful match. |
1048 | for (auto [matcher, action] : matchActionPairs) { |
1049 | SmallVector<SmallVector<MappedValue>> mappings; |
1050 | DiagnosedSilenceableFailure diag = |
1051 | matchBlock(matcher.getFunctionBody().front(), op, state, mappings); |
1052 | if (diag.isDefiniteFailure()) |
1053 | return WalkResult::interrupt(); |
1054 | if (diag.isSilenceableFailure()) { |
1055 | DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() |
1056 | << " failed: " << diag.getMessage()); |
1057 | continue; |
1058 | } |
1059 | |
1060 | auto scope = state.make_region_scope(action.getFunctionBody()); |
1061 | for (auto &&[arg, map] : llvm::zip_equal( |
1062 | action.getFunctionBody().front().getArguments(), mappings)) { |
1063 | if (failed(state.mapBlockArgument(arg, map))) |
1064 | return WalkResult::interrupt(); |
1065 | } |
1066 | |
1067 | for (Operation &transform : |
1068 | action.getFunctionBody().front().without_terminator()) { |
1069 | DiagnosedSilenceableFailure result = |
1070 | state.applyTransform(cast<TransformOpInterface>(transform)); |
1071 | if (result.isDefiniteFailure()) |
1072 | return WalkResult::interrupt(); |
1073 | if (result.isSilenceableFailure()) { |
1074 | if (overallDiag.succeeded()) { |
1075 | overallDiag = emitSilenceableError() << "actions failed" ; |
1076 | } |
1077 | overallDiag.attachNote(action->getLoc()) |
1078 | << "failed action: " << result.getMessage(); |
1079 | overallDiag.attachNote(op->getLoc()) |
1080 | << "when applied to this matching payload" ; |
1081 | (void)result.silence(); |
1082 | continue; |
1083 | } |
1084 | } |
1085 | break; |
1086 | } |
1087 | return WalkResult::advance(); |
1088 | }); |
1089 | if (walkResult.wasInterrupted()) |
1090 | return DiagnosedSilenceableFailure::definiteFailure(); |
1091 | } |
1092 | |
1093 | // The root operation should not have been affected, so we can just reassign |
1094 | // the payload to the result. Note that we need to consume the root handle to |
1095 | // make sure any handles to operations inside, that could have been affected |
1096 | // by actions, are invalidated. |
1097 | results.set(llvm::cast<OpResult>(getUpdated()), |
1098 | state.getPayloadOps(getRoot())); |
1099 | return overallDiag; |
1100 | } |
1101 | |
1102 | void transform::ForeachMatchOp::getEffects( |
1103 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
1104 | // Bail if invalid. |
1105 | if (getOperation()->getNumOperands() < 1 || |
1106 | getOperation()->getNumResults() < 1) { |
1107 | return modifiesPayload(effects); |
1108 | } |
1109 | |
1110 | consumesHandle(getRoot(), effects); |
1111 | producesHandle(getUpdated(), effects); |
1112 | modifiesPayload(effects); |
1113 | } |
1114 | |
1115 | /// Parses the comma-separated list of symbol reference pairs of the format |
1116 | /// `@matcher -> @action`. |
1117 | static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, |
1118 | ArrayAttr &matchers, |
1119 | ArrayAttr &actions) { |
1120 | StringAttr matcher; |
1121 | StringAttr action; |
1122 | SmallVector<Attribute> matcherList; |
1123 | SmallVector<Attribute> actionList; |
1124 | do { |
1125 | if (parser.parseSymbolName(matcher) || parser.parseArrow() || |
1126 | parser.parseSymbolName(action)) { |
1127 | return failure(); |
1128 | } |
1129 | matcherList.push_back(SymbolRefAttr::get(matcher)); |
1130 | actionList.push_back(SymbolRefAttr::get(action)); |
1131 | } while (parser.parseOptionalComma().succeeded()); |
1132 | |
1133 | matchers = parser.getBuilder().getArrayAttr(matcherList); |
1134 | actions = parser.getBuilder().getArrayAttr(actionList); |
1135 | return success(); |
1136 | } |
1137 | |
1138 | /// Prints the comma-separated list of symbol reference pairs of the format |
1139 | /// `@matcher -> @action`. |
1140 | static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, |
1141 | ArrayAttr matchers, ArrayAttr actions) { |
1142 | printer.increaseIndent(); |
1143 | printer.increaseIndent(); |
1144 | for (auto &&[matcher, action, idx] : llvm::zip_equal( |
1145 | matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) { |
1146 | printer.printNewline(); |
1147 | printer << cast<SymbolRefAttr>(matcher) << " -> " |
1148 | << cast<SymbolRefAttr>(action); |
1149 | if (idx != matchers.size() - 1) |
1150 | printer << ", " ; |
1151 | } |
1152 | printer.decreaseIndent(); |
1153 | printer.decreaseIndent(); |
1154 | } |
1155 | |
1156 | LogicalResult transform::ForeachMatchOp::verify() { |
1157 | if (getMatchers().size() != getActions().size()) |
1158 | return emitOpError() << "expected the same number of matchers and actions" ; |
1159 | if (getMatchers().empty()) |
1160 | return emitOpError() << "expected at least one match/action pair" ; |
1161 | |
1162 | llvm::SmallPtrSet<Attribute, 8> matcherNames; |
1163 | for (Attribute name : getMatchers()) { |
1164 | if (matcherNames.insert(name).second) |
1165 | continue; |
1166 | emitWarning() << "matcher " << name |
1167 | << " is used more than once, only the first match will apply" ; |
1168 | } |
1169 | |
1170 | return success(); |
1171 | } |
1172 | |
1173 | /// Checks that the attributes of the function-like operation have correct |
1174 | /// consumption effect annotations. If `alsoVerifyInternal`, checks for |
1175 | /// annotations being present even if they can be inferred from the body. |
1176 | static DiagnosedSilenceableFailure |
1177 | verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, |
1178 | bool alsoVerifyInternal = false) { |
1179 | auto transformOp = cast<transform::TransformOpInterface>(op.getOperation()); |
1180 | llvm::SmallDenseSet<unsigned> consumedArguments; |
1181 | if (!op.isExternal()) { |
1182 | transform::getConsumedBlockArguments(block&: op.getFunctionBody().front(), |
1183 | consumedArguments); |
1184 | } |
1185 | for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { |
1186 | bool isConsumed = |
1187 | op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) != |
1188 | nullptr; |
1189 | bool isReadOnly = |
1190 | op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) != |
1191 | nullptr; |
1192 | if (isConsumed && isReadOnly) { |
1193 | return transformOp.emitSilenceableError() |
1194 | << "argument #" << i << " cannot be both readonly and consumed" ; |
1195 | } |
1196 | if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) { |
1197 | return transformOp.emitSilenceableError() |
1198 | << "must provide consumed/readonly status for arguments of " |
1199 | "external or called ops" ; |
1200 | } |
1201 | if (op.isExternal()) |
1202 | continue; |
1203 | |
1204 | if (consumedArguments.contains(V: i) && !isConsumed && isReadOnly) { |
1205 | return transformOp.emitSilenceableError() |
1206 | << "argument #" << i |
1207 | << " is consumed in the body but is not marked as such" ; |
1208 | } |
1209 | if (emitWarnings && !consumedArguments.contains(V: i) && isConsumed) { |
1210 | // Cannot use op.emitWarning() here as it would attempt to verify the op |
1211 | // before printing, resulting in infinite recursion. |
1212 | emitWarning(op->getLoc()) |
1213 | << "op argument #" << i |
1214 | << " is not consumed in the body but is marked as consumed" ; |
1215 | } |
1216 | } |
1217 | return DiagnosedSilenceableFailure::success(); |
1218 | } |
1219 | |
1220 | LogicalResult transform::ForeachMatchOp::verifySymbolUses( |
1221 | SymbolTableCollection &symbolTable) { |
1222 | assert(getMatchers().size() == getActions().size()); |
1223 | auto consumedAttr = |
1224 | StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName); |
1225 | for (auto &&[matcher, action] : |
1226 | llvm::zip_equal(getMatchers(), getActions())) { |
1227 | auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( |
1228 | symbolTable.lookupNearestSymbolFrom(getOperation(), |
1229 | cast<SymbolRefAttr>(matcher))); |
1230 | auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>( |
1231 | symbolTable.lookupNearestSymbolFrom(getOperation(), |
1232 | cast<SymbolRefAttr>(action))); |
1233 | if (!matcherSymbol || |
1234 | !isa<TransformOpInterface>(matcherSymbol.getOperation())) |
1235 | return emitError() << "unresolved matcher symbol " << matcher; |
1236 | if (!actionSymbol || |
1237 | !isa<TransformOpInterface>(actionSymbol.getOperation())) |
1238 | return emitError() << "unresolved action symbol " << action; |
1239 | |
1240 | if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol, |
1241 | /*emitWarnings=*/false, |
1242 | /*alsoVerifyInternal=*/true) |
1243 | .checkAndReport())) { |
1244 | return failure(); |
1245 | } |
1246 | if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol, |
1247 | /*emitWarnings=*/false, |
1248 | /*alsoVerifyInternal=*/true) |
1249 | .checkAndReport())) { |
1250 | return failure(); |
1251 | } |
1252 | |
1253 | ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes(); |
1254 | ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes(); |
1255 | if (matcherResults.size() != actionArguments.size()) { |
1256 | return emitError() << "mismatching number of matcher results and " |
1257 | "action arguments between " |
1258 | << matcher << " (" << matcherResults.size() << ") and " |
1259 | << action << " (" << actionArguments.size() << ")" ; |
1260 | } |
1261 | for (auto &&[i, matcherType, actionType] : |
1262 | llvm::enumerate(matcherResults, actionArguments)) { |
1263 | if (implementSameTransformInterface(matcherType, actionType)) |
1264 | continue; |
1265 | |
1266 | return emitError() << "mismatching type interfaces for matcher result " |
1267 | "and action argument #" |
1268 | << i; |
1269 | } |
1270 | |
1271 | if (!actionSymbol.getResultTypes().empty()) { |
1272 | InFlightDiagnostic diag = |
1273 | emitError() << "action symbol is not expected to have results" ; |
1274 | diag.attachNote(actionSymbol->getLoc()) << "symbol declaration" ; |
1275 | return diag; |
1276 | } |
1277 | |
1278 | if (matcherSymbol.getArgumentTypes().size() != 1 || |
1279 | !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0], |
1280 | getRoot().getType())) { |
1281 | InFlightDiagnostic diag = |
1282 | emitOpError() << "expects matcher symbol to have one argument with " |
1283 | "the same transform interface as the first operand" ; |
1284 | diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration" ; |
1285 | return diag; |
1286 | } |
1287 | |
1288 | if (matcherSymbol.getArgAttr(0, consumedAttr)) { |
1289 | InFlightDiagnostic diag = |
1290 | emitOpError() |
1291 | << "does not expect matcher symbol to consume its operand" ; |
1292 | diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration" ; |
1293 | return diag; |
1294 | } |
1295 | } |
1296 | return success(); |
1297 | } |
1298 | |
1299 | //===----------------------------------------------------------------------===// |
1300 | // ForeachOp |
1301 | //===----------------------------------------------------------------------===// |
1302 | |
1303 | DiagnosedSilenceableFailure |
1304 | transform::ForeachOp::apply(transform::TransformRewriter &rewriter, |
1305 | transform::TransformResults &results, |
1306 | transform::TransformState &state) { |
1307 | SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {}); |
1308 | // Store payload ops in a vector because ops may be removed from the mapping |
1309 | // by the TrackingRewriter while the iteration is in progress. |
1310 | SmallVector<Operation *> targets = |
1311 | llvm::to_vector(state.getPayloadOps(getTarget())); |
1312 | for (Operation *op : targets) { |
1313 | auto scope = state.make_region_scope(getBody()); |
1314 | if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) |
1315 | return DiagnosedSilenceableFailure::definiteFailure(); |
1316 | |
1317 | // Execute loop body. |
1318 | for (Operation &transform : getBody().front().without_terminator()) { |
1319 | DiagnosedSilenceableFailure result = state.applyTransform( |
1320 | cast<transform::TransformOpInterface>(transform)); |
1321 | if (!result.succeeded()) |
1322 | return result; |
1323 | } |
1324 | |
1325 | // Append yielded payload ops to result list (if any). |
1326 | for (unsigned i = 0; i < getNumResults(); ++i) { |
1327 | auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i)); |
1328 | resultOps[i].append(yieldedOps.begin(), yieldedOps.end()); |
1329 | } |
1330 | } |
1331 | |
1332 | for (unsigned i = 0; i < getNumResults(); ++i) |
1333 | results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]); |
1334 | |
1335 | return DiagnosedSilenceableFailure::success(); |
1336 | } |
1337 | |
1338 | void transform::ForeachOp::getEffects( |
1339 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
1340 | BlockArgument iterVar = getIterationVariable(); |
1341 | if (any_of(getBody().front().without_terminator(), [&](Operation &op) { |
1342 | return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op)); |
1343 | })) { |
1344 | consumesHandle(getTarget(), effects); |
1345 | } else { |
1346 | onlyReadsHandle(getTarget(), effects); |
1347 | } |
1348 | |
1349 | if (any_of(getBody().front().without_terminator(), [&](Operation &op) { |
1350 | return doesModifyPayload(cast<TransformOpInterface>(&op)); |
1351 | })) { |
1352 | modifiesPayload(effects); |
1353 | } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) { |
1354 | return doesReadPayload(cast<TransformOpInterface>(&op)); |
1355 | })) { |
1356 | onlyReadsPayload(effects); |
1357 | } |
1358 | |
1359 | for (Value result : getResults()) |
1360 | producesHandle(result, effects); |
1361 | } |
1362 | |
1363 | void transform::ForeachOp::getSuccessorRegions( |
1364 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
1365 | Region *bodyRegion = &getBody(); |
1366 | if (point.isParent()) { |
1367 | regions.emplace_back(bodyRegion, bodyRegion->getArguments()); |
1368 | return; |
1369 | } |
1370 | |
1371 | // Branch back to the region or the parent. |
1372 | assert(point == getBody() && "unexpected region index" ); |
1373 | regions.emplace_back(bodyRegion, bodyRegion->getArguments()); |
1374 | regions.emplace_back(); |
1375 | } |
1376 | |
1377 | OperandRange |
1378 | transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
1379 | // The iteration variable op handle is mapped to a subset (one op to be |
1380 | // precise) of the payload ops of the ForeachOp operand. |
1381 | assert(point == getBody() && "unexpected region index" ); |
1382 | return getOperation()->getOperands(); |
1383 | } |
1384 | |
1385 | transform::YieldOp transform::ForeachOp::getYieldOp() { |
1386 | return cast<transform::YieldOp>(getBody().front().getTerminator()); |
1387 | } |
1388 | |
1389 | LogicalResult transform::ForeachOp::verify() { |
1390 | auto yieldOp = getYieldOp(); |
1391 | if (getNumResults() != yieldOp.getNumOperands()) |
1392 | return emitOpError() << "expects the same number of results as the " |
1393 | "terminator has operands" ; |
1394 | for (Value v : yieldOp.getOperands()) |
1395 | if (!llvm::isa<TransformHandleTypeInterface>(v.getType())) |
1396 | return yieldOp->emitOpError("expects operands to have types implementing " |
1397 | "TransformHandleTypeInterface" ); |
1398 | return success(); |
1399 | } |
1400 | |
1401 | //===----------------------------------------------------------------------===// |
1402 | // GetParentOp |
1403 | //===----------------------------------------------------------------------===// |
1404 | |
1405 | DiagnosedSilenceableFailure |
1406 | transform::GetParentOp::apply(transform::TransformRewriter &rewriter, |
1407 | transform::TransformResults &results, |
1408 | transform::TransformState &state) { |
1409 | SmallVector<Operation *> parents; |
1410 | DenseSet<Operation *> resultSet; |
1411 | for (Operation *target : state.getPayloadOps(getTarget())) { |
1412 | Operation *parent = target; |
1413 | for (int64_t i = 0, e = getNthParent(); i < e; ++i) { |
1414 | parent = parent->getParentOp(); |
1415 | while (parent) { |
1416 | bool checkIsolatedFromAbove = |
1417 | !getIsolatedFromAbove() || |
1418 | parent->hasTrait<OpTrait::IsIsolatedFromAbove>(); |
1419 | bool checkOpName = !getOpName().has_value() || |
1420 | parent->getName().getStringRef() == *getOpName(); |
1421 | if (checkIsolatedFromAbove && checkOpName) |
1422 | break; |
1423 | parent = parent->getParentOp(); |
1424 | } |
1425 | if (!parent) { |
1426 | if (getAllowEmptyResults()) { |
1427 | results.set(llvm::cast<OpResult>(getResult()), parents); |
1428 | return DiagnosedSilenceableFailure::success(); |
1429 | } |
1430 | DiagnosedSilenceableFailure diag = |
1431 | emitSilenceableError() |
1432 | << "could not find a parent op that matches all requirements" ; |
1433 | diag.attachNote(target->getLoc()) << "target op" ; |
1434 | return diag; |
1435 | } |
1436 | } |
1437 | if (getDeduplicate()) { |
1438 | if (!resultSet.contains(parent)) { |
1439 | parents.push_back(parent); |
1440 | resultSet.insert(parent); |
1441 | } |
1442 | } else { |
1443 | parents.push_back(parent); |
1444 | } |
1445 | } |
1446 | results.set(llvm::cast<OpResult>(getResult()), parents); |
1447 | return DiagnosedSilenceableFailure::success(); |
1448 | } |
1449 | |
1450 | //===----------------------------------------------------------------------===// |
1451 | // GetConsumersOfResult |
1452 | //===----------------------------------------------------------------------===// |
1453 | |
1454 | DiagnosedSilenceableFailure |
1455 | transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, |
1456 | transform::TransformResults &results, |
1457 | transform::TransformState &state) { |
1458 | int64_t resultNumber = getResultNumber(); |
1459 | auto payloadOps = state.getPayloadOps(getTarget()); |
1460 | if (std::empty(payloadOps)) { |
1461 | results.set(cast<OpResult>(getResult()), {}); |
1462 | return DiagnosedSilenceableFailure::success(); |
1463 | } |
1464 | if (!llvm::hasSingleElement(payloadOps)) |
1465 | return emitDefiniteFailure() |
1466 | << "handle must be mapped to exactly one payload op" ; |
1467 | |
1468 | Operation *target = *payloadOps.begin(); |
1469 | if (target->getNumResults() <= resultNumber) |
1470 | return emitDefiniteFailure() << "result number overflow" ; |
1471 | results.set(llvm::cast<OpResult>(getResult()), |
1472 | llvm::to_vector(target->getResult(resultNumber).getUsers())); |
1473 | return DiagnosedSilenceableFailure::success(); |
1474 | } |
1475 | |
1476 | //===----------------------------------------------------------------------===// |
1477 | // GetDefiningOp |
1478 | //===----------------------------------------------------------------------===// |
1479 | |
1480 | DiagnosedSilenceableFailure |
1481 | transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, |
1482 | transform::TransformResults &results, |
1483 | transform::TransformState &state) { |
1484 | SmallVector<Operation *> definingOps; |
1485 | for (Value v : state.getPayloadValues(getTarget())) { |
1486 | if (llvm::isa<BlockArgument>(v)) { |
1487 | DiagnosedSilenceableFailure diag = |
1488 | emitSilenceableError() << "cannot get defining op of block argument" ; |
1489 | diag.attachNote(v.getLoc()) << "target value" ; |
1490 | return diag; |
1491 | } |
1492 | definingOps.push_back(v.getDefiningOp()); |
1493 | } |
1494 | results.set(llvm::cast<OpResult>(getResult()), definingOps); |
1495 | return DiagnosedSilenceableFailure::success(); |
1496 | } |
1497 | |
1498 | //===----------------------------------------------------------------------===// |
1499 | // GetProducerOfOperand |
1500 | //===----------------------------------------------------------------------===// |
1501 | |
1502 | DiagnosedSilenceableFailure |
1503 | transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, |
1504 | transform::TransformResults &results, |
1505 | transform::TransformState &state) { |
1506 | int64_t operandNumber = getOperandNumber(); |
1507 | SmallVector<Operation *> producers; |
1508 | for (Operation *target : state.getPayloadOps(getTarget())) { |
1509 | Operation *producer = |
1510 | target->getNumOperands() <= operandNumber |
1511 | ? nullptr |
1512 | : target->getOperand(operandNumber).getDefiningOp(); |
1513 | if (!producer) { |
1514 | DiagnosedSilenceableFailure diag = |
1515 | emitSilenceableError() |
1516 | << "could not find a producer for operand number: " << operandNumber |
1517 | << " of " << *target; |
1518 | diag.attachNote(target->getLoc()) << "target op" ; |
1519 | return diag; |
1520 | } |
1521 | producers.push_back(producer); |
1522 | } |
1523 | results.set(llvm::cast<OpResult>(getResult()), producers); |
1524 | return DiagnosedSilenceableFailure::success(); |
1525 | } |
1526 | |
1527 | //===----------------------------------------------------------------------===// |
1528 | // GetOperandOp |
1529 | //===----------------------------------------------------------------------===// |
1530 | |
1531 | DiagnosedSilenceableFailure |
1532 | transform::GetOperandOp::apply(transform::TransformRewriter &rewriter, |
1533 | transform::TransformResults &results, |
1534 | transform::TransformState &state) { |
1535 | SmallVector<Value> operands; |
1536 | for (Operation *target : state.getPayloadOps(getTarget())) { |
1537 | SmallVector<int64_t> operandPositions; |
1538 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
1539 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
1540 | target->getNumOperands(), operandPositions); |
1541 | if (diag.isSilenceableFailure()) { |
1542 | diag.attachNote(target->getLoc()) |
1543 | << "while considering positions of this payload operation" ; |
1544 | return diag; |
1545 | } |
1546 | llvm::append_range(operands, |
1547 | llvm::map_range(operandPositions, [&](int64_t pos) { |
1548 | return target->getOperand(pos); |
1549 | })); |
1550 | } |
1551 | results.setValues(cast<OpResult>(getResult()), operands); |
1552 | return DiagnosedSilenceableFailure::success(); |
1553 | } |
1554 | |
1555 | LogicalResult transform::GetOperandOp::verify() { |
1556 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
1557 | getIsInverted(), getIsAll()); |
1558 | } |
1559 | |
1560 | //===----------------------------------------------------------------------===// |
1561 | // GetResultOp |
1562 | //===----------------------------------------------------------------------===// |
1563 | |
1564 | DiagnosedSilenceableFailure |
1565 | transform::GetResultOp::apply(transform::TransformRewriter &rewriter, |
1566 | transform::TransformResults &results, |
1567 | transform::TransformState &state) { |
1568 | SmallVector<Value> opResults; |
1569 | for (Operation *target : state.getPayloadOps(getTarget())) { |
1570 | SmallVector<int64_t> resultPositions; |
1571 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
1572 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
1573 | target->getNumResults(), resultPositions); |
1574 | if (diag.isSilenceableFailure()) { |
1575 | diag.attachNote(target->getLoc()) |
1576 | << "while considering positions of this payload operation" ; |
1577 | return diag; |
1578 | } |
1579 | llvm::append_range(opResults, |
1580 | llvm::map_range(resultPositions, [&](int64_t pos) { |
1581 | return target->getResult(pos); |
1582 | })); |
1583 | } |
1584 | results.setValues(cast<OpResult>(getResult()), opResults); |
1585 | return DiagnosedSilenceableFailure::success(); |
1586 | } |
1587 | |
1588 | LogicalResult transform::GetResultOp::verify() { |
1589 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
1590 | getIsInverted(), getIsAll()); |
1591 | } |
1592 | |
1593 | //===----------------------------------------------------------------------===// |
1594 | // GetTypeOp |
1595 | //===----------------------------------------------------------------------===// |
1596 | |
/// Appends this op's effects to `effects` through the shared helper
/// functions: `onlyReadsHandle` on the `value` operand, `producesHandle` on
/// the result, and `onlyReadsPayload`.
void transform::GetTypeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getValue(), effects);
  producesHandle(getResult(), effects);
  onlyReadsPayload(effects);
}
1603 | |
1604 | DiagnosedSilenceableFailure |
1605 | transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, |
1606 | transform::TransformResults &results, |
1607 | transform::TransformState &state) { |
1608 | SmallVector<Attribute> params; |
1609 | for (Value value : state.getPayloadValues(getValue())) { |
1610 | Type type = value.getType(); |
1611 | if (getElemental()) { |
1612 | if (auto shaped = dyn_cast<ShapedType>(type)) { |
1613 | type = shaped.getElementType(); |
1614 | } |
1615 | } |
1616 | params.push_back(TypeAttr::get(type)); |
1617 | } |
1618 | results.setParams(cast<OpResult>(getResult()), params); |
1619 | return DiagnosedSilenceableFailure::success(); |
1620 | } |
1621 | |
1622 | //===----------------------------------------------------------------------===// |
1623 | // IncludeOp |
1624 | //===----------------------------------------------------------------------===// |
1625 | |
1626 | /// Applies the transform ops contained in `block`. Maps `results` to the same |
1627 | /// values as the operands of the block terminator. |
1628 | static DiagnosedSilenceableFailure |
1629 | applySequenceBlock(Block &block, transform::FailurePropagationMode mode, |
1630 | transform::TransformState &state, |
1631 | transform::TransformResults &results) { |
1632 | // Apply the sequenced ops one by one. |
1633 | for (Operation &transform : block.without_terminator()) { |
1634 | DiagnosedSilenceableFailure result = |
1635 | state.applyTransform(transform: cast<transform::TransformOpInterface>(transform)); |
1636 | if (result.isDefiniteFailure()) |
1637 | return result; |
1638 | |
1639 | if (result.isSilenceableFailure()) { |
1640 | if (mode == transform::FailurePropagationMode::Propagate) { |
1641 | // Propagate empty results in case of early exit. |
1642 | forwardEmptyOperands(block: &block, state, results); |
1643 | return result; |
1644 | } |
1645 | (void)result.silence(); |
1646 | } |
1647 | } |
1648 | |
1649 | // Forward the operation mapping for values yielded from the sequence to the |
1650 | // values produced by the sequence op. |
1651 | transform::detail::forwardTerminatorOperands(block: &block, state, results); |
1652 | return DiagnosedSilenceableFailure::success(); |
1653 | } |
1654 | |
1655 | DiagnosedSilenceableFailure |
1656 | transform::IncludeOp::apply(transform::TransformRewriter &rewriter, |
1657 | transform::TransformResults &results, |
1658 | transform::TransformState &state) { |
1659 | auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( |
1660 | getOperation(), getTarget()); |
1661 | assert(callee && "unverified reference to unknown symbol" ); |
1662 | |
1663 | if (callee.isExternal()) |
1664 | return emitDefiniteFailure() << "unresolved external named sequence" ; |
1665 | |
1666 | // Map operands to block arguments. |
1667 | SmallVector<SmallVector<MappedValue>> mappings; |
1668 | detail::prepareValueMappings(mappings, getOperands(), state); |
1669 | auto scope = state.make_region_scope(callee.getBody()); |
1670 | for (auto &&[arg, map] : |
1671 | llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) { |
1672 | if (failed(state.mapBlockArgument(arg, map))) |
1673 | return DiagnosedSilenceableFailure::definiteFailure(); |
1674 | } |
1675 | |
1676 | DiagnosedSilenceableFailure result = applySequenceBlock( |
1677 | callee.getBody().front(), getFailurePropagationMode(), state, results); |
1678 | mappings.clear(); |
1679 | detail::prepareValueMappings( |
1680 | mappings, callee.getBody().front().getTerminator()->getOperands(), state); |
1681 | for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings)) |
1682 | results.setMappedValues(result, mapping); |
1683 | return result; |
1684 | } |
1685 | |
/// Forward declaration; the definition appears further down in this file. It
/// is declared here because `IncludeOp::getEffects` calls it before the
/// definition is seen.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1688 | |
1689 | void transform::IncludeOp::getEffects( |
1690 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
1691 | // Always mark as modifying the payload. |
1692 | // TODO: a mechanism to annotate effects on payload. Even when all handles are |
1693 | // only read, the payload may still be modified, so we currently stay on the |
1694 | // conservative side and always indicate modification. This may prevent some |
1695 | // code reordering. |
1696 | modifiesPayload(effects); |
1697 | |
1698 | // Results are always produced. |
1699 | producesHandle(getResults(), effects); |
1700 | |
1701 | // Adds default effects to operands and results. This will be added if |
1702 | // preconditions fail so the trait verifier doesn't complain about missing |
1703 | // effects and the real precondition failure is reported later on. |
1704 | auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); }; |
1705 | |
1706 | // Bail if the callee is unknown. This may run as part of the verification |
1707 | // process before we verified the validity of the callee or of this op. |
1708 | auto target = |
1709 | getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName()); |
1710 | if (!target) |
1711 | return defaultEffects(); |
1712 | auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( |
1713 | getOperation(), getTarget()); |
1714 | if (!callee) |
1715 | return defaultEffects(); |
1716 | DiagnosedSilenceableFailure earlyVerifierResult = |
1717 | verifyNamedSequenceOp(callee, /*emitWarnings=*/false); |
1718 | if (!earlyVerifierResult.succeeded()) { |
1719 | (void)earlyVerifierResult.silence(); |
1720 | return defaultEffects(); |
1721 | } |
1722 | |
1723 | for (unsigned i = 0, e = getNumOperands(); i < e; ++i) { |
1724 | if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName)) |
1725 | consumesHandle(getOperand(i), effects); |
1726 | else |
1727 | onlyReadsHandle(getOperand(i), effects); |
1728 | } |
1729 | } |
1730 | |
1731 | LogicalResult |
1732 | transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1733 | // Access through indirection and do additional checking because this may be |
1734 | // running before the main op verifier. |
1735 | auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target" ); |
1736 | if (!targetAttr) |
1737 | return emitOpError() << "expects a 'target' symbol reference attribute" ; |
1738 | |
1739 | auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>( |
1740 | *this, targetAttr); |
1741 | if (!target) |
1742 | return emitOpError() << "does not reference a named transform sequence" ; |
1743 | |
1744 | FunctionType fnType = target.getFunctionType(); |
1745 | if (fnType.getNumInputs() != getNumOperands()) |
1746 | return emitError("incorrect number of operands for callee" ); |
1747 | |
1748 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { |
1749 | if (getOperand(i).getType() != fnType.getInput(i)) { |
1750 | return emitOpError("operand type mismatch: expected operand type " ) |
1751 | << fnType.getInput(i) << ", but provided " |
1752 | << getOperand(i).getType() << " for operand number " << i; |
1753 | } |
1754 | } |
1755 | |
1756 | if (fnType.getNumResults() != getNumResults()) |
1757 | return emitError("incorrect number of results for callee" ); |
1758 | |
1759 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { |
1760 | Type resultType = getResult(i).getType(); |
1761 | Type funcType = fnType.getResult(i); |
1762 | if (!implementSameTransformInterface(resultType, funcType)) { |
1763 | return emitOpError() << "type of result #" << i |
1764 | << " must implement the same transform dialect " |
1765 | "interface as the corresponding callee result" ; |
1766 | } |
1767 | } |
1768 | |
1769 | return verifyFunctionLikeConsumeAnnotations( |
1770 | cast<FunctionOpInterface>(*target), /*emitWarnings=*/false, |
1771 | /*alsoVerifyInternal=*/true) |
1772 | .checkAndReport(); |
1773 | } |
1774 | |
1775 | //===----------------------------------------------------------------------===// |
1776 | // MatchOperationEmptyOp |
1777 | //===----------------------------------------------------------------------===// |
1778 | |
1779 | DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( |
1780 | ::std::optional<::mlir::Operation *> maybeCurrent, |
1781 | transform::TransformResults &results, transform::TransformState &state) { |
1782 | if (!maybeCurrent.has_value()) { |
1783 | DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n" ; }); |
1784 | return DiagnosedSilenceableFailure::success(); |
1785 | } |
1786 | DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n" ; }); |
1787 | return emitSilenceableError() << "operation is not empty" ; |
1788 | } |
1789 | |
1790 | //===----------------------------------------------------------------------===// |
1791 | // MatchOperationNameOp |
1792 | //===----------------------------------------------------------------------===// |
1793 | |
1794 | DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation( |
1795 | Operation *current, transform::TransformResults &results, |
1796 | transform::TransformState &state) { |
1797 | StringRef currentOpName = current->getName().getStringRef(); |
1798 | for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) { |
1799 | if (acceptedAttr.getValue() == currentOpName) |
1800 | return DiagnosedSilenceableFailure::success(); |
1801 | } |
1802 | return emitSilenceableError() << "wrong operation name" ; |
1803 | } |
1804 | |
1805 | //===----------------------------------------------------------------------===// |
1806 | // MatchParamCmpIOp |
1807 | //===----------------------------------------------------------------------===// |
1808 | |
1809 | DiagnosedSilenceableFailure |
1810 | transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, |
1811 | transform::TransformResults &results, |
1812 | transform::TransformState &state) { |
1813 | auto signedAPIntAsString = [&](const APInt &value) { |
1814 | std::string str; |
1815 | llvm::raw_string_ostream os(str); |
1816 | value.print(os, /*isSigned=*/true); |
1817 | return os.str(); |
1818 | }; |
1819 | |
1820 | ArrayRef<Attribute> params = state.getParams(getParam()); |
1821 | ArrayRef<Attribute> references = state.getParams(getReference()); |
1822 | |
1823 | if (params.size() != references.size()) { |
1824 | return emitSilenceableError() |
1825 | << "parameters have different payload lengths (" << params.size() |
1826 | << " vs " << references.size() << ")" ; |
1827 | } |
1828 | |
1829 | for (auto &&[i, param, reference] : llvm::enumerate(params, references)) { |
1830 | auto intAttr = llvm::dyn_cast<IntegerAttr>(param); |
1831 | auto refAttr = llvm::dyn_cast<IntegerAttr>(reference); |
1832 | if (!intAttr || !refAttr) { |
1833 | return emitDefiniteFailure() |
1834 | << "non-integer parameter value not expected" ; |
1835 | } |
1836 | if (intAttr.getType() != refAttr.getType()) { |
1837 | return emitDefiniteFailure() |
1838 | << "mismatching integer attribute types in parameter #" << i; |
1839 | } |
1840 | APInt value = intAttr.getValue(); |
1841 | APInt refValue = refAttr.getValue(); |
1842 | |
1843 | // TODO: this copy will not be necessary in C++20. |
1844 | int64_t position = i; |
1845 | auto reportError = [&](StringRef direction) { |
1846 | DiagnosedSilenceableFailure diag = |
1847 | emitSilenceableError() << "expected parameter to be " << direction |
1848 | << " " << signedAPIntAsString(refValue) |
1849 | << ", got " << signedAPIntAsString(value); |
1850 | diag.attachNote(getParam().getLoc()) |
1851 | << "value # " << position |
1852 | << " associated with the parameter defined here" ; |
1853 | return diag; |
1854 | }; |
1855 | |
1856 | switch (getPredicate()) { |
1857 | case MatchCmpIPredicate::eq: |
1858 | if (value.eq(refValue)) |
1859 | break; |
1860 | return reportError("equal to" ); |
1861 | case MatchCmpIPredicate::ne: |
1862 | if (value.ne(refValue)) |
1863 | break; |
1864 | return reportError("not equal to" ); |
1865 | case MatchCmpIPredicate::lt: |
1866 | if (value.slt(refValue)) |
1867 | break; |
1868 | return reportError("less than" ); |
1869 | case MatchCmpIPredicate::le: |
1870 | if (value.sle(refValue)) |
1871 | break; |
1872 | return reportError("less than or equal to" ); |
1873 | case MatchCmpIPredicate::gt: |
1874 | if (value.sgt(refValue)) |
1875 | break; |
1876 | return reportError("greater than" ); |
1877 | case MatchCmpIPredicate::ge: |
1878 | if (value.sge(refValue)) |
1879 | break; |
1880 | return reportError("greater than or equal to" ); |
1881 | } |
1882 | } |
1883 | return DiagnosedSilenceableFailure::success(); |
1884 | } |
1885 | |
/// Appends this op's effects to `effects`: `onlyReadsHandle` is applied to
/// both the `param` and the `reference` operands. No other effect is added.
void transform::MatchParamCmpIOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getParam(), effects);
  onlyReadsHandle(getReference(), effects);
}
1891 | |
1892 | //===----------------------------------------------------------------------===// |
1893 | // ParamConstantOp |
1894 | //===----------------------------------------------------------------------===// |
1895 | |
1896 | DiagnosedSilenceableFailure |
1897 | transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, |
1898 | transform::TransformResults &results, |
1899 | transform::TransformState &state) { |
1900 | results.setParams(cast<OpResult>(getParam()), {getValue()}); |
1901 | return DiagnosedSilenceableFailure::success(); |
1902 | } |
1903 | |
1904 | //===----------------------------------------------------------------------===// |
1905 | // MergeHandlesOp |
1906 | //===----------------------------------------------------------------------===// |
1907 | |
1908 | DiagnosedSilenceableFailure |
1909 | transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, |
1910 | transform::TransformResults &results, |
1911 | transform::TransformState &state) { |
1912 | ValueRange handles = getHandles(); |
1913 | if (isa<TransformHandleTypeInterface>(handles.front().getType())) { |
1914 | SmallVector<Operation *> operations; |
1915 | for (Value operand : handles) |
1916 | llvm::append_range(operations, state.getPayloadOps(operand)); |
1917 | if (!getDeduplicate()) { |
1918 | results.set(llvm::cast<OpResult>(getResult()), operations); |
1919 | return DiagnosedSilenceableFailure::success(); |
1920 | } |
1921 | |
1922 | SetVector<Operation *> uniqued(operations.begin(), operations.end()); |
1923 | results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef()); |
1924 | return DiagnosedSilenceableFailure::success(); |
1925 | } |
1926 | |
1927 | if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) { |
1928 | SmallVector<Attribute> attrs; |
1929 | for (Value attribute : handles) |
1930 | llvm::append_range(attrs, state.getParams(attribute)); |
1931 | if (!getDeduplicate()) { |
1932 | results.setParams(cast<OpResult>(getResult()), attrs); |
1933 | return DiagnosedSilenceableFailure::success(); |
1934 | } |
1935 | |
1936 | SetVector<Attribute> uniqued(attrs.begin(), attrs.end()); |
1937 | results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef()); |
1938 | return DiagnosedSilenceableFailure::success(); |
1939 | } |
1940 | |
1941 | assert( |
1942 | llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) && |
1943 | "expected value handle type" ); |
1944 | SmallVector<Value> payloadValues; |
1945 | for (Value value : handles) |
1946 | llvm::append_range(payloadValues, state.getPayloadValues(value)); |
1947 | if (!getDeduplicate()) { |
1948 | results.setValues(cast<OpResult>(getResult()), payloadValues); |
1949 | return DiagnosedSilenceableFailure::success(); |
1950 | } |
1951 | |
1952 | SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end()); |
1953 | results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef()); |
1954 | return DiagnosedSilenceableFailure::success(); |
1955 | } |
1956 | |
1957 | bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() { |
1958 | // Handles may be the same if deduplicating is enabled. |
1959 | return getDeduplicate(); |
1960 | } |
1961 | |
/// Appends this op's effects to `effects`: `onlyReadsHandle` on the operand
/// handles and `producesHandle` on the result. No payload effect is added.
void transform::MergeHandlesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getHandles(), effects);
  producesHandle(getResult(), effects);

  // There are no effects on the Payload IR as this is only a handle
  // manipulation.
}
1970 | |
1971 | OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { |
1972 | if (getDeduplicate() || getHandles().size() != 1) |
1973 | return {}; |
1974 | |
1975 | // If deduplication is not required and there is only one operand, it can be |
1976 | // used directly instead of merging. |
1977 | return getHandles().front(); |
1978 | } |
1979 | |
1980 | //===----------------------------------------------------------------------===// |
1981 | // NamedSequenceOp |
1982 | //===----------------------------------------------------------------------===// |
1983 | |
1984 | DiagnosedSilenceableFailure |
1985 | transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, |
1986 | transform::TransformResults &results, |
1987 | transform::TransformState &state) { |
1988 | if (isExternal()) |
1989 | return emitDefiniteFailure() << "unresolved external named sequence" ; |
1990 | |
1991 | // Map the entry block argument to the list of operations. |
1992 | // Note: this is the same implementation as PossibleTopLevelTransformOp but |
1993 | // without attaching the interface / trait since that is tailored to a |
1994 | // dangling top-level op that does not get "called". |
1995 | auto scope = state.make_region_scope(getBody()); |
1996 | if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments( |
1997 | state, this->getOperation(), getBody()))) |
1998 | return DiagnosedSilenceableFailure::definiteFailure(); |
1999 | |
2000 | return applySequenceBlock(getBody().front(), |
2001 | FailurePropagationMode::Propagate, state, results); |
2002 | } |
2003 | |
/// Adds nothing to `effects`: the body is intentionally empty.
void transform::NamedSequenceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2006 | |
2007 | ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser, |
2008 | OperationState &result) { |
2009 | return function_interface_impl::parseFunctionOp( |
2010 | parser, result, /*allowVariadic=*/false, |
2011 | getFunctionTypeAttrName(result.name), |
2012 | [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results, |
2013 | function_interface_impl::VariadicFlag, |
2014 | std::string &) { return builder.getFunctionType(inputs, results); }, |
2015 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
2016 | } |
2017 | |
2018 | void transform::NamedSequenceOp::print(OpAsmPrinter &printer) { |
2019 | function_interface_impl::printFunctionOp( |
2020 | printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false, |
2021 | getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(), |
2022 | getResAttrsAttrName()); |
2023 | } |
2024 | |
2025 | /// Verifies that a symbol function-like transform dialect operation has the |
2026 | /// signature and the terminator that have conforming types, i.e., types |
2027 | /// implementing the same transform dialect type interface. If `allowExternal` |
2028 | /// is set, allow external symbols (declarations) and don't check the terminator |
2029 | /// as it may not exist. |
2030 | static DiagnosedSilenceableFailure |
2031 | verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) { |
2032 | if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { |
2033 | DiagnosedSilenceableFailure diag = |
2034 | emitSilenceableFailure(op) |
2035 | << "cannot be defined inside another transform op" ; |
2036 | diag.attachNote(loc: parent.getLoc()) << "ancestor transform op" ; |
2037 | return diag; |
2038 | } |
2039 | |
2040 | if (op.isExternal() || op.getFunctionBody().empty()) { |
2041 | if (allowExternal) |
2042 | return DiagnosedSilenceableFailure::success(); |
2043 | |
2044 | return emitSilenceableFailure(op) << "cannot be external" ; |
2045 | } |
2046 | |
2047 | if (op.getFunctionBody().front().empty()) |
2048 | return emitSilenceableFailure(op) << "expected a non-empty body block" ; |
2049 | |
2050 | Operation *terminator = &op.getFunctionBody().front().back(); |
2051 | if (!isa<transform::YieldOp>(terminator)) { |
2052 | DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) |
2053 | << "expected '" |
2054 | << transform::YieldOp::getOperationName() |
2055 | << "' as terminator" ; |
2056 | diag.attachNote(loc: terminator->getLoc()) << "terminator" ; |
2057 | return diag; |
2058 | } |
2059 | |
2060 | if (terminator->getNumOperands() != op.getResultTypes().size()) { |
2061 | return emitSilenceableFailure(op: terminator) |
2062 | << "expected terminator to have as many operands as the parent op " |
2063 | "has results" ; |
2064 | } |
2065 | for (auto [i, operandType, resultType] : llvm::zip_equal( |
2066 | llvm::seq<unsigned>(0, terminator->getNumOperands()), |
2067 | terminator->getOperands().getType(), op.getResultTypes())) { |
2068 | if (operandType == resultType) |
2069 | continue; |
2070 | return emitSilenceableFailure(terminator) |
2071 | << "the type of the terminator operand #" << i |
2072 | << " must match the type of the corresponding parent op result (" |
2073 | << operandType << " vs " << resultType << ")" ; |
2074 | } |
2075 | |
2076 | return DiagnosedSilenceableFailure::success(); |
2077 | } |
2078 | |
2079 | /// Verification of a NamedSequenceOp. This does not report the error |
2080 | /// immediately, so it can be used to check for op's well-formedness before the |
2081 | /// verifier runs, e.g., during trait verification. |
2082 | static DiagnosedSilenceableFailure |
2083 | verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) { |
2084 | if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) { |
2085 | if (!parent->getAttr( |
2086 | transform::TransformDialect::kWithNamedSequenceAttrName)) { |
2087 | DiagnosedSilenceableFailure diag = |
2088 | emitSilenceableFailure(op) |
2089 | << "expects the parent symbol table to have the '" |
2090 | << transform::TransformDialect::kWithNamedSequenceAttrName |
2091 | << "' attribute" ; |
2092 | diag.attachNote(loc: parent->getLoc()) << "symbol table operation" ; |
2093 | return diag; |
2094 | } |
2095 | } |
2096 | |
2097 | if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { |
2098 | DiagnosedSilenceableFailure diag = |
2099 | emitSilenceableFailure(op) |
2100 | << "cannot be defined inside another transform op" ; |
2101 | diag.attachNote(loc: parent.getLoc()) << "ancestor transform op" ; |
2102 | return diag; |
2103 | } |
2104 | |
2105 | if (op.isExternal() || op.getBody().empty()) |
2106 | return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op), |
2107 | emitWarnings); |
2108 | |
2109 | if (op.getBody().front().empty()) |
2110 | return emitSilenceableFailure(op) << "expected a non-empty body block" ; |
2111 | |
2112 | Operation *terminator = &op.getBody().front().back(); |
2113 | if (!isa<transform::YieldOp>(terminator)) { |
2114 | DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) |
2115 | << "expected '" |
2116 | << transform::YieldOp::getOperationName() |
2117 | << "' as terminator" ; |
2118 | diag.attachNote(loc: terminator->getLoc()) << "terminator" ; |
2119 | return diag; |
2120 | } |
2121 | |
2122 | if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) { |
2123 | return emitSilenceableFailure(op: terminator) |
2124 | << "expected terminator to have as many operands as the parent op " |
2125 | "has results" ; |
2126 | } |
2127 | for (auto [i, operandType, resultType] : |
2128 | llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()), |
2129 | terminator->getOperands().getType(), |
2130 | op.getFunctionType().getResults())) { |
2131 | if (operandType == resultType) |
2132 | continue; |
2133 | return emitSilenceableFailure(terminator) |
2134 | << "the type of the terminator operand #" << i |
2135 | << " must match the type of the corresponding parent op result (" |
2136 | << operandType << " vs " << resultType << ")" ; |
2137 | } |
2138 | |
2139 | auto funcOp = cast<FunctionOpInterface>(*op); |
2140 | DiagnosedSilenceableFailure diag = |
2141 | verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings); |
2142 | if (!diag.succeeded()) |
2143 | return diag; |
2144 | |
2145 | return verifyYieldingSingleBlockOp(funcOp, |
2146 | /*allowExternal=*/true); |
2147 | } |
2148 | |
2149 | LogicalResult transform::NamedSequenceOp::verify() { |
2150 | // Actual verification happens in a separate function for reusability. |
2151 | return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport(); |
2152 | } |
2153 | |
2154 | template <typename FnTy> |
2155 | static void buildSequenceBody(OpBuilder &builder, OperationState &state, |
2156 | Type bbArgType, TypeRange , |
2157 | FnTy bodyBuilder) { |
2158 | SmallVector<Type> types; |
2159 | types.reserve(N: 1 + extraBindingTypes.size()); |
2160 | types.push_back(Elt: bbArgType); |
2161 | llvm::append_range(C&: types, R&: extraBindingTypes); |
2162 | |
2163 | OpBuilder::InsertionGuard guard(builder); |
2164 | Region *region = state.regions.back().get(); |
2165 | Block *bodyBlock = |
2166 | builder.createBlock(parent: region, insertPt: region->begin(), argTypes: types, |
2167 | locs: SmallVector<Location>(types.size(), state.location)); |
2168 | |
2169 | // Populate body. |
2170 | builder.setInsertionPointToStart(bodyBlock); |
2171 | if constexpr (llvm::function_traits<FnTy>::num_args == 3) { |
2172 | bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0)); |
2173 | } else { |
2174 | bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0), |
2175 | bodyBlock->getArguments().drop_front()); |
2176 | } |
2177 | } |
2178 | |
2179 | void transform::NamedSequenceOp::build(OpBuilder &builder, |
2180 | OperationState &state, StringRef symName, |
2181 | Type rootType, TypeRange resultTypes, |
2182 | SequenceBodyBuilderFn bodyBuilder, |
2183 | ArrayRef<NamedAttribute> attrs, |
2184 | ArrayRef<DictionaryAttr> argAttrs) { |
2185 | state.addAttribute(SymbolTable::getSymbolAttrName(), |
2186 | builder.getStringAttr(symName)); |
2187 | state.addAttribute(getFunctionTypeAttrName(state.name), |
2188 | TypeAttr::get(FunctionType::get(builder.getContext(), |
2189 | rootType, resultTypes))); |
2190 | state.attributes.append(attrs.begin(), attrs.end()); |
2191 | state.addRegion(); |
2192 | |
2193 | buildSequenceBody(builder, state, rootType, |
2194 | /*extraBindingTypes=*/TypeRange(), bodyBuilder); |
2195 | } |
2196 | |
2197 | //===----------------------------------------------------------------------===// |
2198 | // NumAssociationsOp |
2199 | //===----------------------------------------------------------------------===// |
2200 | |
2201 | DiagnosedSilenceableFailure |
2202 | transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter, |
2203 | transform::TransformResults &results, |
2204 | transform::TransformState &state) { |
2205 | size_t numAssociations = |
2206 | llvm::TypeSwitch<Type, size_t>(getHandle().getType()) |
2207 | .Case([&](TransformHandleTypeInterface opHandle) { |
2208 | return llvm::range_size(state.getPayloadOps(getHandle())); |
2209 | }) |
2210 | .Case([&](TransformValueHandleTypeInterface valueHandle) { |
2211 | return llvm::range_size(state.getPayloadValues(getHandle())); |
2212 | }) |
2213 | .Case([&](TransformParamTypeInterface param) { |
2214 | return llvm::range_size(state.getParams(getHandle())); |
2215 | }) |
2216 | .Default([](Type) { |
2217 | llvm_unreachable("unknown kind of transform dialect type" ); |
2218 | return 0; |
2219 | }); |
2220 | results.setParams(cast<OpResult>(getNum()), |
2221 | rewriter.getI64IntegerAttr(numAssociations)); |
2222 | return DiagnosedSilenceableFailure::success(); |
2223 | } |
2224 | |
2225 | LogicalResult transform::NumAssociationsOp::verify() { |
2226 | // Verify that the result type accepts an i64 attribute as payload. |
2227 | auto resultType = cast<TransformParamTypeInterface>(getNum().getType()); |
2228 | return resultType |
2229 | .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)}) |
2230 | .checkAndReport(); |
2231 | } |
2232 | |
2233 | //===----------------------------------------------------------------------===// |
2234 | // SelectOp |
2235 | //===----------------------------------------------------------------------===// |
2236 | |
2237 | DiagnosedSilenceableFailure |
2238 | transform::SelectOp::apply(transform::TransformRewriter &rewriter, |
2239 | transform::TransformResults &results, |
2240 | transform::TransformState &state) { |
2241 | SmallVector<Operation *> result; |
2242 | auto payloadOps = state.getPayloadOps(getTarget()); |
2243 | for (Operation *op : payloadOps) { |
2244 | if (op->getName().getStringRef() == getOpName()) |
2245 | result.push_back(op); |
2246 | } |
2247 | results.set(cast<OpResult>(getResult()), result); |
2248 | return DiagnosedSilenceableFailure::success(); |
2249 | } |
2250 | |
2251 | //===----------------------------------------------------------------------===// |
2252 | // SplitHandleOp |
2253 | //===----------------------------------------------------------------------===// |
2254 | |
2255 | void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, |
2256 | Value target, int64_t numResultHandles) { |
2257 | result.addOperands(target); |
2258 | result.addTypes(SmallVector<Type>(numResultHandles, target.getType())); |
2259 | } |
2260 | |
2261 | DiagnosedSilenceableFailure |
2262 | transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, |
2263 | transform::TransformResults &results, |
2264 | transform::TransformState &state) { |
2265 | int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle())); |
2266 | auto produceNumOpsError = [&]() { |
2267 | return emitSilenceableError() |
2268 | << getHandle() << " expected to contain " << this->getNumResults() |
2269 | << " payload ops but it contains " << numPayloadOps |
2270 | << " payload ops" ; |
2271 | }; |
2272 | |
2273 | // Fail if there are more payload ops than results and no overflow result was |
2274 | // specified. |
2275 | if (numPayloadOps > getNumResults() && !getOverflowResult().has_value()) |
2276 | return produceNumOpsError(); |
2277 | |
2278 | // Fail if there are more results than payload ops. Unless: |
2279 | // - "fail_on_payload_too_small" is set to "false", or |
2280 | // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops. |
2281 | if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() && |
2282 | (numPayloadOps != 0 || !getPassThroughEmptyHandle())) |
2283 | return produceNumOpsError(); |
2284 | |
2285 | // Distribute payload ops. |
2286 | SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {}); |
2287 | if (getOverflowResult()) |
2288 | resultHandles[*getOverflowResult()].reserve(numPayloadOps - |
2289 | getNumResults()); |
2290 | for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) { |
2291 | int64_t resultNum = en.index(); |
2292 | if (resultNum >= getNumResults()) |
2293 | resultNum = *getOverflowResult(); |
2294 | resultHandles[resultNum].push_back(en.value()); |
2295 | } |
2296 | |
2297 | // Set transform op results. |
2298 | for (auto &&it : llvm::enumerate(resultHandles)) |
2299 | results.set(llvm::cast<OpResult>(getResult(it.index())), it.value()); |
2300 | |
2301 | return DiagnosedSilenceableFailure::success(); |
2302 | } |
2303 | |
/// Declares the effects of this op: the input handle is only read and the
/// result handles are produced.
void transform::SplitHandleOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getHandle(), effects);
  producesHandle(getResults(), effects);
  // There are no effects on the Payload IR as this is only a handle
  // manipulation.
}
2311 | |
2312 | LogicalResult transform::SplitHandleOp::verify() { |
2313 | if (getOverflowResult().has_value() && |
2314 | !(*getOverflowResult() < getNumResults())) |
2315 | return emitOpError("overflow_result is not a valid result index" ); |
2316 | return success(); |
2317 | } |
2318 | |
2319 | //===----------------------------------------------------------------------===// |
2320 | // ReplicateOp |
2321 | //===----------------------------------------------------------------------===// |
2322 | |
2323 | DiagnosedSilenceableFailure |
2324 | transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, |
2325 | transform::TransformResults &results, |
2326 | transform::TransformState &state) { |
2327 | unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); |
2328 | for (const auto &en : llvm::enumerate(getHandles())) { |
2329 | Value handle = en.value(); |
2330 | if (isa<TransformHandleTypeInterface>(handle.getType())) { |
2331 | SmallVector<Operation *> current = |
2332 | llvm::to_vector(state.getPayloadOps(handle)); |
2333 | SmallVector<Operation *> payload; |
2334 | payload.reserve(numRepetitions * current.size()); |
2335 | for (unsigned i = 0; i < numRepetitions; ++i) |
2336 | llvm::append_range(payload, current); |
2337 | results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload); |
2338 | } else { |
2339 | assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) && |
2340 | "expected param type" ); |
2341 | ArrayRef<Attribute> current = state.getParams(handle); |
2342 | SmallVector<Attribute> params; |
2343 | params.reserve(numRepetitions * current.size()); |
2344 | for (unsigned i = 0; i < numRepetitions; ++i) |
2345 | llvm::append_range(params, current); |
2346 | results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]), |
2347 | params); |
2348 | } |
2349 | } |
2350 | return DiagnosedSilenceableFailure::success(); |
2351 | } |
2352 | |
/// Declares the effects of this op: the pattern handle and the input handles
/// are only read, and the replicated result handles are produced.
void transform::ReplicateOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getPattern(), effects);
  onlyReadsHandle(getHandles(), effects);
  producesHandle(getReplicated(), effects);
}
2359 | |
2360 | //===----------------------------------------------------------------------===// |
2361 | // SequenceOp |
2362 | //===----------------------------------------------------------------------===// |
2363 | |
2364 | DiagnosedSilenceableFailure |
2365 | transform::SequenceOp::apply(transform::TransformRewriter &rewriter, |
2366 | transform::TransformResults &results, |
2367 | transform::TransformState &state) { |
2368 | // Map the entry block argument to the list of operations. |
2369 | auto scope = state.make_region_scope(*getBodyBlock()->getParent()); |
2370 | if (failed(mapBlockArguments(state))) |
2371 | return DiagnosedSilenceableFailure::definiteFailure(); |
2372 | |
2373 | return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state, |
2374 | results); |
2375 | } |
2376 | |
2377 | static ParseResult parseSequenceOpOperands( |
2378 | OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, |
2379 | Type &rootType, |
2380 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &, |
2381 | SmallVectorImpl<Type> &) { |
2382 | OpAsmParser::UnresolvedOperand rootOperand; |
2383 | OptionalParseResult hasRoot = parser.parseOptionalOperand(result&: rootOperand); |
2384 | if (!hasRoot.has_value()) { |
2385 | root = std::nullopt; |
2386 | return success(); |
2387 | } |
2388 | if (failed(result: hasRoot.value())) |
2389 | return failure(); |
2390 | root = rootOperand; |
2391 | |
2392 | if (succeeded(result: parser.parseOptionalComma())) { |
2393 | if (failed(result: parser.parseOperandList(result&: extraBindings))) |
2394 | return failure(); |
2395 | } |
2396 | if (failed(result: parser.parseColon())) |
2397 | return failure(); |
2398 | |
2399 | // The paren is truly optional. |
2400 | (void)parser.parseOptionalLParen(); |
2401 | |
2402 | if (failed(result: parser.parseType(result&: rootType))) { |
2403 | return failure(); |
2404 | } |
2405 | |
2406 | if (!extraBindings.empty()) { |
2407 | if (parser.parseComma() || parser.parseTypeList(result&: extraBindingTypes)) |
2408 | return failure(); |
2409 | } |
2410 | |
2411 | if (extraBindingTypes.size() != extraBindings.size()) { |
2412 | return parser.emitError(loc: parser.getNameLoc(), |
2413 | message: "expected types to be provided for all operands" ); |
2414 | } |
2415 | |
2416 | // The paren is truly optional. |
2417 | (void)parser.parseOptionalRParen(); |
2418 | return success(); |
2419 | } |
2420 | |
2421 | static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, |
2422 | Value root, Type rootType, |
2423 | ValueRange , |
2424 | TypeRange ) { |
2425 | if (!root) |
2426 | return; |
2427 | |
2428 | printer << root; |
2429 | bool = !extraBindings.empty(); |
2430 | if (hasExtras) { |
2431 | printer << ", " ; |
2432 | printer.printOperands(container: extraBindings); |
2433 | } |
2434 | |
2435 | printer << " : " ; |
2436 | if (hasExtras) |
2437 | printer << "(" ; |
2438 | |
2439 | printer << rootType; |
2440 | if (hasExtras) { |
2441 | printer << ", " ; |
2442 | llvm::interleaveComma(c: extraBindingTypes, os&: printer.getStream()); |
2443 | printer << ")" ; |
2444 | } |
2445 | } |
2446 | |
2447 | /// Returns `true` if the given op operand may be consuming the handle value in |
2448 | /// the Transform IR. That is, if it may have a Free effect on it. |
2449 | static bool isValueUsePotentialConsumer(OpOperand &use) { |
2450 | // Conservatively assume the effect being present in absence of the interface. |
2451 | auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner()); |
2452 | if (!iface) |
2453 | return true; |
2454 | |
2455 | return isHandleConsumed(use.get(), iface); |
2456 | } |
2457 | |
2458 | LogicalResult |
2459 | checkDoubleConsume(Value value, |
2460 | function_ref<InFlightDiagnostic()> reportError) { |
2461 | OpOperand *potentialConsumer = nullptr; |
2462 | for (OpOperand &use : value.getUses()) { |
2463 | if (!isValueUsePotentialConsumer(use)) |
2464 | continue; |
2465 | |
2466 | if (!potentialConsumer) { |
2467 | potentialConsumer = &use; |
2468 | continue; |
2469 | } |
2470 | |
2471 | InFlightDiagnostic diag = reportError() |
2472 | << " has more than one potential consumer" ; |
2473 | diag.attachNote(noteLoc: potentialConsumer->getOwner()->getLoc()) |
2474 | << "used here as operand #" << potentialConsumer->getOperandNumber(); |
2475 | diag.attachNote(noteLoc: use.getOwner()->getLoc()) |
2476 | << "used here as operand #" << use.getOperandNumber(); |
2477 | return diag; |
2478 | } |
2479 | |
2480 | return success(); |
2481 | } |
2482 | |
2483 | LogicalResult transform::SequenceOp::verify() { |
2484 | assert(getBodyBlock()->getNumArguments() >= 1 && |
2485 | "the number of arguments must have been verified to be more than 1 by " |
2486 | "PossibleTopLevelTransformOpTrait" ); |
2487 | |
2488 | if (!getRoot() && !getExtraBindings().empty()) { |
2489 | return emitOpError() |
2490 | << "does not expect extra operands when used as top-level" ; |
2491 | } |
2492 | |
2493 | // Check if a block argument has more than one consuming use. |
2494 | for (BlockArgument arg : getBodyBlock()->getArguments()) { |
2495 | if (failed(checkDoubleConsume(arg, [this, arg]() { |
2496 | return (emitOpError() << "block argument #" << arg.getArgNumber()); |
2497 | }))) { |
2498 | return failure(); |
2499 | } |
2500 | } |
2501 | |
2502 | // Check properties of the nested operations they cannot check themselves. |
2503 | for (Operation &child : *getBodyBlock()) { |
2504 | if (!isa<TransformOpInterface>(child) && |
2505 | &child != &getBodyBlock()->back()) { |
2506 | InFlightDiagnostic diag = |
2507 | emitOpError() |
2508 | << "expected children ops to implement TransformOpInterface" ; |
2509 | diag.attachNote(child.getLoc()) << "op without interface" ; |
2510 | return diag; |
2511 | } |
2512 | |
2513 | for (OpResult result : child.getResults()) { |
2514 | auto report = [&]() { |
2515 | return (child.emitError() << "result #" << result.getResultNumber()); |
2516 | }; |
2517 | if (failed(checkDoubleConsume(result, report))) |
2518 | return failure(); |
2519 | } |
2520 | } |
2521 | |
2522 | if (!getBodyBlock()->mightHaveTerminator()) |
2523 | return emitOpError() << "expects to have a terminator in the body" ; |
2524 | |
2525 | if (getBodyBlock()->getTerminator()->getOperandTypes() != |
2526 | getOperation()->getResultTypes()) { |
2527 | InFlightDiagnostic diag = emitOpError() |
2528 | << "expects the types of the terminator operands " |
2529 | "to match the types of the result" ; |
2530 | diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator" ; |
2531 | return diag; |
2532 | } |
2533 | return success(); |
2534 | } |
2535 | |
/// Populates `effects` by delegating to `getPotentialTopLevelEffects`.
void transform::SequenceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  getPotentialTopLevelEffects(effects);
}
2540 | |
2541 | OperandRange |
2542 | transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
2543 | assert(point == getBody() && "unexpected region index" ); |
2544 | if (getOperation()->getNumOperands() > 0) |
2545 | return getOperation()->getOperands(); |
2546 | return OperandRange(getOperation()->operand_end(), |
2547 | getOperation()->operand_end()); |
2548 | } |
2549 | |
2550 | void transform::SequenceOp::getSuccessorRegions( |
2551 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
2552 | if (point.isParent()) { |
2553 | Region *bodyRegion = &getBody(); |
2554 | regions.emplace_back(bodyRegion, getNumOperands() != 0 |
2555 | ? bodyRegion->getArguments() |
2556 | : Block::BlockArgListType()); |
2557 | return; |
2558 | } |
2559 | |
2560 | assert(point == getBody() && "unexpected region index" ); |
2561 | regions.emplace_back(getOperation()->getResults()); |
2562 | } |
2563 | |
/// Appends one invocation-bounds entry constructed from (1, 1) to `bounds`,
/// regardless of `operands`.
/// NOTE(review): presumably the two values are the lower and upper bound, i.e.
/// the body runs exactly once — confirm against `InvocationBounds`.
void transform::SequenceOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  // The operands are deliberately unused.
  (void)operands;
  bounds.emplace_back(1, 1);
}
2569 | |
2570 | void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, |
2571 | TypeRange resultTypes, |
2572 | FailurePropagationMode failurePropagationMode, |
2573 | Value root, |
2574 | SequenceBodyBuilderFn bodyBuilder) { |
2575 | build(builder, state, resultTypes, failurePropagationMode, root, |
2576 | /*extra_bindings=*/ValueRange()); |
2577 | Type bbArgType = root.getType(); |
2578 | buildSequenceBody(builder, state, bbArgType, |
2579 | /*extraBindingTypes=*/TypeRange(), bodyBuilder); |
2580 | } |
2581 | |
2582 | void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, |
2583 | TypeRange resultTypes, |
2584 | FailurePropagationMode failurePropagationMode, |
2585 | Value root, ValueRange extraBindings, |
2586 | SequenceBodyBuilderArgsFn bodyBuilder) { |
2587 | build(builder, state, resultTypes, failurePropagationMode, root, |
2588 | extraBindings); |
2589 | buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(), |
2590 | bodyBuilder); |
2591 | } |
2592 | |
2593 | void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, |
2594 | TypeRange resultTypes, |
2595 | FailurePropagationMode failurePropagationMode, |
2596 | Type bbArgType, |
2597 | SequenceBodyBuilderFn bodyBuilder) { |
2598 | build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), |
2599 | /*extra_bindings=*/ValueRange()); |
2600 | buildSequenceBody(builder, state, bbArgType, |
2601 | /*extraBindingTypes=*/TypeRange(), bodyBuilder); |
2602 | } |
2603 | |
2604 | void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, |
2605 | TypeRange resultTypes, |
2606 | FailurePropagationMode failurePropagationMode, |
2607 | Type bbArgType, TypeRange extraBindingTypes, |
2608 | SequenceBodyBuilderArgsFn bodyBuilder) { |
2609 | build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), |
2610 | /*extra_bindings=*/ValueRange()); |
2611 | buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder); |
2612 | } |
2613 | |
2614 | //===----------------------------------------------------------------------===// |
2615 | // PrintOp |
2616 | //===----------------------------------------------------------------------===// |
2617 | |
2618 | void transform::PrintOp::build(OpBuilder &builder, OperationState &result, |
2619 | StringRef name) { |
2620 | if (!name.empty()) |
2621 | result.getOrAddProperties<Properties>().name = builder.getStringAttr(name); |
2622 | } |
2623 | |
2624 | void transform::PrintOp::build(OpBuilder &builder, OperationState &result, |
2625 | Value target, StringRef name) { |
2626 | result.addOperands({target}); |
2627 | build(builder, result, name); |
2628 | } |
2629 | |
2630 | DiagnosedSilenceableFailure |
2631 | transform::PrintOp::apply(transform::TransformRewriter &rewriter, |
2632 | transform::TransformResults &results, |
2633 | transform::TransformState &state) { |
2634 | llvm::outs() << "[[[ IR printer: " ; |
2635 | if (getName().has_value()) |
2636 | llvm::outs() << *getName() << " " ; |
2637 | |
2638 | OpPrintingFlags printFlags; |
2639 | if (getAssumeVerified().value_or(false)) |
2640 | printFlags.assumeVerified(); |
2641 | if (getUseLocalScope().value_or(false)) |
2642 | printFlags.useLocalScope(); |
2643 | if (getSkipRegions().value_or(false)) |
2644 | printFlags.skipRegions(); |
2645 | |
2646 | if (!getTarget()) { |
2647 | llvm::outs() << "top-level ]]]\n" ; |
2648 | state.getTopLevel()->print(llvm::outs(), printFlags); |
2649 | llvm::outs() << "\n" ; |
2650 | return DiagnosedSilenceableFailure::success(); |
2651 | } |
2652 | |
2653 | llvm::outs() << "]]]\n" ; |
2654 | for (Operation *target : state.getPayloadOps(getTarget())) { |
2655 | target->print(llvm::outs(), printFlags); |
2656 | llvm::outs() << "\n" ; |
2657 | } |
2658 | |
2659 | return DiagnosedSilenceableFailure::success(); |
2660 | } |
2661 | |
2662 | void transform::PrintOp::getEffects( |
2663 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
2664 | // We don't really care about mutability here, but `getTarget` now |
2665 | // unconditionally casts to a specific type before verification could run |
2666 | // here. |
2667 | if (!getTargetMutable().empty()) |
2668 | onlyReadsHandle(getTargetMutable()[0].get(), effects); |
2669 | onlyReadsPayload(effects); |
2670 | |
2671 | // There is no resource for stderr file descriptor, so just declare print |
2672 | // writes into the default resource. |
2673 | effects.emplace_back(MemoryEffects::Write::get()); |
2674 | } |
2675 | |
2676 | //===----------------------------------------------------------------------===// |
2677 | // VerifyOp |
2678 | //===----------------------------------------------------------------------===// |
2679 | |
2680 | DiagnosedSilenceableFailure |
2681 | transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter, |
2682 | Operation *target, |
2683 | transform::ApplyToEachResultList &results, |
2684 | transform::TransformState &state) { |
2685 | if (failed(::mlir::verify(target))) { |
2686 | DiagnosedDefiniteFailure diag = emitDefiniteFailure() |
2687 | << "failed to verify payload op" ; |
2688 | diag.attachNote(target->getLoc()) << "payload op" ; |
2689 | return diag; |
2690 | } |
2691 | return DiagnosedSilenceableFailure::success(); |
2692 | } |
2693 | |
/// Declares the only effect registered by this op: the target handle is read.
void transform::VerifyOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTarget(), effects);
}
2698 | |
2699 | //===----------------------------------------------------------------------===// |
2700 | // YieldOp |
2701 | //===----------------------------------------------------------------------===// |
2702 | |
/// Declares the only effect registered by this op: all operand handles are
/// read.
void transform::YieldOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getOperands(), effects);
}
2707 | |