1 | //===- TransformOps.cpp - Transform dialect operations --------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
10 | |
11 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
12 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
13 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
14 | #include "mlir/Dialect/Transform/IR/TransformAttrs.h" |
15 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
16 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
17 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
18 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/Diagnostics.h" |
21 | #include "mlir/IR/Dominance.h" |
22 | #include "mlir/IR/OpImplementation.h" |
23 | #include "mlir/IR/OperationSupport.h" |
24 | #include "mlir/IR/PatternMatch.h" |
25 | #include "mlir/IR/Verifier.h" |
26 | #include "mlir/Interfaces/CallInterfaces.h" |
27 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
28 | #include "mlir/Interfaces/FunctionImplementation.h" |
29 | #include "mlir/Interfaces/FunctionInterfaces.h" |
30 | #include "mlir/Pass/Pass.h" |
31 | #include "mlir/Pass/PassManager.h" |
32 | #include "mlir/Pass/PassRegistry.h" |
33 | #include "mlir/Transforms/CSE.h" |
34 | #include "mlir/Transforms/DialectConversion.h" |
35 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
36 | #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
37 | #include "llvm/ADT/DenseSet.h" |
38 | #include "llvm/ADT/STLExtras.h" |
39 | #include "llvm/ADT/ScopeExit.h" |
40 | #include "llvm/ADT/SmallPtrSet.h" |
41 | #include "llvm/ADT/TypeSwitch.h" |
42 | #include "llvm/Support/Debug.h" |
43 | #include "llvm/Support/ErrorHandling.h" |
44 | #include "llvm/Support/InterleavedRange.h" |
45 | #include <optional> |
46 | |
47 | #define DEBUG_TYPE "transform-dialect" |
48 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") |
49 | |
50 | #define DEBUG_TYPE_MATCHER "transform-matcher" |
51 | #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") |
52 | #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) |
53 | |
54 | using namespace mlir; |
55 | |
56 | static ParseResult parseApplyRegisteredPassOptions( |
57 | OpAsmParser &parser, DictionaryAttr &options, |
58 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions); |
59 | static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, |
60 | Operation *op, |
61 | DictionaryAttr options, |
62 | ValueRange dynamicOptions); |
63 | static ParseResult parseSequenceOpOperands( |
64 | OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, |
65 | Type &rootType, |
66 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings, |
67 | SmallVectorImpl<Type> &extraBindingTypes); |
68 | static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, |
69 | Value root, Type rootType, |
70 | ValueRange extraBindings, |
71 | TypeRange extraBindingTypes); |
72 | static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, |
73 | ArrayAttr matchers, ArrayAttr actions); |
74 | static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, |
75 | ArrayAttr &matchers, |
76 | ArrayAttr &actions); |
77 | |
78 | /// Helper function to check if the given transform op is contained in (or |
79 | /// equal to) the given payload target op. In that case, an error is returned. |
80 | /// Transforming transform IR that is currently executing is generally unsafe. |
81 | static DiagnosedSilenceableFailure |
82 | ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, |
83 | Operation *payload) { |
84 | Operation *transformAncestor = transform.getOperation(); |
85 | while (transformAncestor) { |
86 | if (transformAncestor == payload) { |
87 | DiagnosedDefiniteFailure diag = |
88 | transform.emitDefiniteFailure() |
89 | << "cannot apply transform to itself (or one of its ancestors)"; |
90 | diag.attachNote(loc: payload->getLoc()) << "target payload op"; |
91 | return diag; |
92 | } |
93 | transformAncestor = transformAncestor->getParentOp(); |
94 | } |
95 | return DiagnosedSilenceableFailure::success(); |
96 | } |
97 | |
98 | #define GET_OP_CLASSES |
99 | #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" |
100 | |
101 | //===----------------------------------------------------------------------===// |
102 | // AlternativesOp |
103 | //===----------------------------------------------------------------------===// |
104 | |
105 | OperandRange |
106 | transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
107 | if (!point.isParent() && getOperation()->getNumOperands() == 1) |
108 | return getOperation()->getOperands(); |
109 | return OperandRange(getOperation()->operand_end(), |
110 | getOperation()->operand_end()); |
111 | } |
112 | |
113 | void transform::AlternativesOp::getSuccessorRegions( |
114 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
115 | for (Region &alternative : llvm::drop_begin( |
116 | getAlternatives(), |
117 | point.isParent() ? 0 |
118 | : point.getRegionOrNull()->getRegionNumber() + 1)) { |
119 | regions.emplace_back(&alternative, !getOperands().empty() |
120 | ? alternative.getArguments() |
121 | : Block::BlockArgListType()); |
122 | } |
123 | if (!point.isParent()) |
124 | regions.emplace_back(getOperation()->getResults()); |
125 | } |
126 | |
127 | void transform::AlternativesOp::getRegionInvocationBounds( |
128 | ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { |
129 | (void)operands; |
130 | // The region corresponding to the first alternative is always executed, the |
131 | // remaining may or may not be executed. |
132 | bounds.reserve(getNumRegions()); |
133 | bounds.emplace_back(1, 1); |
134 | bounds.resize(getNumRegions(), InvocationBounds(0, 1)); |
135 | } |
136 | |
137 | static void forwardEmptyOperands(Block *block, transform::TransformState &state, |
138 | transform::TransformResults &results) { |
139 | for (const auto &res : block->getParentOp()->getOpResults()) |
140 | results.set(value: res, ops: {}); |
141 | } |
142 | |
143 | DiagnosedSilenceableFailure |
144 | transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, |
145 | transform::TransformResults &results, |
146 | transform::TransformState &state) { |
147 | SmallVector<Operation *> originals; |
148 | if (Value scopeHandle = getScope()) |
149 | llvm::append_range(originals, state.getPayloadOps(scopeHandle)); |
150 | else |
151 | originals.push_back(state.getTopLevel()); |
152 | |
153 | for (Operation *original : originals) { |
154 | if (original->isAncestor(getOperation())) { |
155 | auto diag = emitDefiniteFailure() |
156 | << "scope must not contain the transforms being applied"; |
157 | diag.attachNote(original->getLoc()) << "scope"; |
158 | return diag; |
159 | } |
160 | if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { |
161 | auto diag = emitDefiniteFailure() |
162 | << "only isolated-from-above ops can be alternative scopes"; |
163 | diag.attachNote(original->getLoc()) << "scope"; |
164 | return diag; |
165 | } |
166 | } |
167 | |
168 | for (Region ® : getAlternatives()) { |
169 | // Clone the scope operations and make the transforms in this alternative |
170 | // region apply to them by virtue of mapping the block argument (the only |
171 | // visible handle) to the cloned scope operations. This effectively prevents |
172 | // the transformation from accessing any IR outside the scope. |
173 | auto scope = state.make_region_scope(reg); |
174 | auto clones = llvm::to_vector( |
175 | llvm::map_range(originals, [](Operation *op) { return op->clone(); })); |
176 | auto deleteClones = llvm::make_scope_exit([&] { |
177 | for (Operation *clone : clones) |
178 | clone->erase(); |
179 | }); |
180 | if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) |
181 | return DiagnosedSilenceableFailure::definiteFailure(); |
182 | |
183 | bool failed = false; |
184 | for (Operation &transform : reg.front().without_terminator()) { |
185 | DiagnosedSilenceableFailure result = |
186 | state.applyTransform(cast<TransformOpInterface>(transform)); |
187 | if (result.isSilenceableFailure()) { |
188 | LLVM_DEBUG(DBGS() << "alternative failed: "<< result.getMessage() |
189 | << "\n"); |
190 | failed = true; |
191 | break; |
192 | } |
193 | |
194 | if (::mlir::failed(result.silence())) |
195 | return DiagnosedSilenceableFailure::definiteFailure(); |
196 | } |
197 | |
198 | // If all operations in the given alternative succeeded, no need to consider |
199 | // the rest. Replace the original scoping operation with the clone on which |
200 | // the transformations were performed. |
201 | if (!failed) { |
202 | // We will be using the clones, so cancel their scheduled deletion. |
203 | deleteClones.release(); |
204 | TrackingListener listener(state, *this); |
205 | IRRewriter rewriter(getContext(), &listener); |
206 | for (const auto &kvp : llvm::zip(originals, clones)) { |
207 | Operation *original = std::get<0>(kvp); |
208 | Operation *clone = std::get<1>(kvp); |
209 | original->getBlock()->getOperations().insert(original->getIterator(), |
210 | clone); |
211 | rewriter.replaceOp(original, clone->getResults()); |
212 | } |
213 | detail::forwardTerminatorOperands(®.front(), state, results); |
214 | return DiagnosedSilenceableFailure::success(); |
215 | } |
216 | } |
217 | return emitSilenceableError() << "all alternatives failed"; |
218 | } |
219 | |
220 | void transform::AlternativesOp::getEffects( |
221 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
222 | consumesHandle(getOperation()->getOpOperands(), effects); |
223 | producesHandle(getOperation()->getOpResults(), effects); |
224 | for (Region *region : getRegions()) { |
225 | if (!region->empty()) |
226 | producesHandle(region->front().getArguments(), effects); |
227 | } |
228 | modifiesPayload(effects); |
229 | } |
230 | |
231 | LogicalResult transform::AlternativesOp::verify() { |
232 | for (Region &alternative : getAlternatives()) { |
233 | Block &block = alternative.front(); |
234 | Operation *terminator = block.getTerminator(); |
235 | if (terminator->getOperands().getTypes() != getResults().getTypes()) { |
236 | InFlightDiagnostic diag = emitOpError() |
237 | << "expects terminator operands to have the " |
238 | "same type as results of the operation"; |
239 | diag.attachNote(terminator->getLoc()) << "terminator"; |
240 | return diag; |
241 | } |
242 | } |
243 | |
244 | return success(); |
245 | } |
246 | |
247 | //===----------------------------------------------------------------------===// |
248 | // AnnotateOp |
249 | //===----------------------------------------------------------------------===// |
250 | |
251 | DiagnosedSilenceableFailure |
252 | transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, |
253 | transform::TransformResults &results, |
254 | transform::TransformState &state) { |
255 | SmallVector<Operation *> targets = |
256 | llvm::to_vector(state.getPayloadOps(getTarget())); |
257 | |
258 | Attribute attr = UnitAttr::get(getContext()); |
259 | if (auto paramH = getParam()) { |
260 | ArrayRef<Attribute> params = state.getParams(paramH); |
261 | if (params.size() != 1) { |
262 | if (targets.size() != params.size()) { |
263 | return emitSilenceableError() |
264 | << "parameter and target have different payload lengths (" |
265 | << params.size() << " vs "<< targets.size() << ")"; |
266 | } |
267 | for (auto &&[target, attr] : llvm::zip_equal(targets, params)) |
268 | target->setAttr(getName(), attr); |
269 | return DiagnosedSilenceableFailure::success(); |
270 | } |
271 | attr = params[0]; |
272 | } |
273 | for (auto *target : targets) |
274 | target->setAttr(getName(), attr); |
275 | return DiagnosedSilenceableFailure::success(); |
276 | } |
277 | |
278 | void transform::AnnotateOp::getEffects( |
279 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
280 | onlyReadsHandle(getTargetMutable(), effects); |
281 | onlyReadsHandle(getParamMutable(), effects); |
282 | modifiesPayload(effects); |
283 | } |
284 | |
285 | //===----------------------------------------------------------------------===// |
286 | // ApplyCommonSubexpressionEliminationOp |
287 | //===----------------------------------------------------------------------===// |
288 | |
289 | DiagnosedSilenceableFailure |
290 | transform::ApplyCommonSubexpressionEliminationOp::applyToOne( |
291 | transform::TransformRewriter &rewriter, Operation *target, |
292 | ApplyToEachResultList &results, transform::TransformState &state) { |
293 | // Make sure that this transform is not applied to itself. Modifying the |
294 | // transform IR while it is being interpreted is generally dangerous. |
295 | DiagnosedSilenceableFailure payloadCheck = |
296 | ensurePayloadIsSeparateFromTransform(*this, target); |
297 | if (!payloadCheck.succeeded()) |
298 | return payloadCheck; |
299 | |
300 | DominanceInfo domInfo; |
301 | mlir::eliminateCommonSubExpressions(rewriter, domInfo, target); |
302 | return DiagnosedSilenceableFailure::success(); |
303 | } |
304 | |
305 | void transform::ApplyCommonSubexpressionEliminationOp::getEffects( |
306 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
307 | transform::onlyReadsHandle(getTargetMutable(), effects); |
308 | transform::modifiesPayload(effects); |
309 | } |
310 | |
311 | //===----------------------------------------------------------------------===// |
312 | // ApplyDeadCodeEliminationOp |
313 | //===----------------------------------------------------------------------===// |
314 | |
315 | DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne( |
316 | transform::TransformRewriter &rewriter, Operation *target, |
317 | ApplyToEachResultList &results, transform::TransformState &state) { |
318 | // Make sure that this transform is not applied to itself. Modifying the |
319 | // transform IR while it is being interpreted is generally dangerous. |
320 | DiagnosedSilenceableFailure payloadCheck = |
321 | ensurePayloadIsSeparateFromTransform(*this, target); |
322 | if (!payloadCheck.succeeded()) |
323 | return payloadCheck; |
324 | |
325 | // Maintain a worklist of potentially dead ops. |
326 | SetVector<Operation *> worklist; |
327 | |
328 | // Helper function that adds all defining ops of used values (operands and |
329 | // operands of nested ops). |
330 | auto addDefiningOpsToWorklist = [&](Operation *op) { |
331 | op->walk([&](Operation *op) { |
332 | for (Value v : op->getOperands()) |
333 | if (Operation *defOp = v.getDefiningOp()) |
334 | if (target->isProperAncestor(defOp)) |
335 | worklist.insert(defOp); |
336 | }); |
337 | }; |
338 | |
339 | // Helper function that erases an op. |
340 | auto eraseOp = [&](Operation *op) { |
341 | // Remove op and nested ops from the worklist. |
342 | op->walk([&](Operation *op) { |
343 | const auto *it = llvm::find(worklist, op); |
344 | if (it != worklist.end()) |
345 | worklist.erase(it); |
346 | }); |
347 | rewriter.eraseOp(op); |
348 | }; |
349 | |
350 | // Initial walk over the IR. |
351 | target->walk<WalkOrder::PostOrder>([&](Operation *op) { |
352 | if (op != target && isOpTriviallyDead(op)) { |
353 | addDefiningOpsToWorklist(op); |
354 | eraseOp(op); |
355 | } |
356 | }); |
357 | |
358 | // Erase all ops that have become dead. |
359 | while (!worklist.empty()) { |
360 | Operation *op = worklist.pop_back_val(); |
361 | if (!isOpTriviallyDead(op)) |
362 | continue; |
363 | addDefiningOpsToWorklist(op); |
364 | eraseOp(op); |
365 | } |
366 | |
367 | return DiagnosedSilenceableFailure::success(); |
368 | } |
369 | |
370 | void transform::ApplyDeadCodeEliminationOp::getEffects( |
371 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
372 | transform::onlyReadsHandle(getTargetMutable(), effects); |
373 | transform::modifiesPayload(effects); |
374 | } |
375 | |
376 | //===----------------------------------------------------------------------===// |
377 | // ApplyPatternsOp |
378 | //===----------------------------------------------------------------------===// |
379 | |
380 | DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( |
381 | transform::TransformRewriter &rewriter, Operation *target, |
382 | ApplyToEachResultList &results, transform::TransformState &state) { |
383 | // Make sure that this transform is not applied to itself. Modifying the |
384 | // transform IR while it is being interpreted is generally dangerous. Even |
385 | // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver |
386 | // performs many additional simplifications such as dead code elimination. |
387 | DiagnosedSilenceableFailure payloadCheck = |
388 | ensurePayloadIsSeparateFromTransform(*this, target); |
389 | if (!payloadCheck.succeeded()) |
390 | return payloadCheck; |
391 | |
392 | // Gather all specified patterns. |
393 | MLIRContext *ctx = target->getContext(); |
394 | RewritePatternSet patterns(ctx); |
395 | if (!getRegion().empty()) { |
396 | for (Operation &op : getRegion().front()) { |
397 | cast<transform::PatternDescriptorOpInterface>(&op) |
398 | .populatePatternsWithState(patterns, state); |
399 | } |
400 | } |
401 | |
402 | // Configure the GreedyPatternRewriteDriver. |
403 | GreedyRewriteConfig config; |
404 | config.setListener( |
405 | static_cast<RewriterBase::Listener *>(rewriter.getListener())); |
406 | FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
407 | |
408 | config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1) |
409 | ? GreedyRewriteConfig::kNoLimit |
410 | : getMaxIterations()); |
411 | config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1) |
412 | ? GreedyRewriteConfig::kNoLimit |
413 | : getMaxNumRewrites()); |
414 | |
415 | // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE |
416 | // was requested, apply the greedy pattern rewrite only once. (The greedy |
417 | // pattern rewrite driver already iterates to a fixpoint internally.) |
418 | bool cseChanged = false; |
419 | // One or two iterations should be sufficient. Stop iterating after a certain |
420 | // threshold to make debugging easier. |
421 | static const int64_t kNumMaxIterations = 50; |
422 | int64_t iteration = 0; |
423 | do { |
424 | LogicalResult result = failure(); |
425 | if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { |
426 | // Op is isolated from above. Apply patterns and also perform region |
427 | // simplification. |
428 | result = applyPatternsGreedily(target, frozenPatterns, config); |
429 | } else { |
430 | // Manually gather list of ops because the other |
431 | // GreedyPatternRewriteDriver overloads only accepts ops that are isolated |
432 | // from above. This way, patterns can be applied to ops that are not |
433 | // isolated from above. Regions are not being simplified. Furthermore, |
434 | // only a single greedy rewrite iteration is performed. |
435 | SmallVector<Operation *> ops; |
436 | target->walk([&](Operation *nestedOp) { |
437 | if (target != nestedOp) |
438 | ops.push_back(nestedOp); |
439 | }); |
440 | result = applyOpPatternsGreedily(ops, frozenPatterns, config); |
441 | } |
442 | |
443 | // A failure typically indicates that the pattern application did not |
444 | // converge. |
445 | if (failed(result)) { |
446 | return emitSilenceableFailure(target) |
447 | << "greedy pattern application failed"; |
448 | } |
449 | |
450 | if (getApplyCse()) { |
451 | DominanceInfo domInfo; |
452 | mlir::eliminateCommonSubExpressions(rewriter, domInfo, target, |
453 | &cseChanged); |
454 | } |
455 | } while (cseChanged && ++iteration < kNumMaxIterations); |
456 | |
457 | if (iteration == kNumMaxIterations) |
458 | return emitDefiniteFailure() << "fixpoint iteration did not converge"; |
459 | |
460 | return DiagnosedSilenceableFailure::success(); |
461 | } |
462 | |
463 | LogicalResult transform::ApplyPatternsOp::verify() { |
464 | if (!getRegion().empty()) { |
465 | for (Operation &op : getRegion().front()) { |
466 | if (!isa<transform::PatternDescriptorOpInterface>(&op)) { |
467 | InFlightDiagnostic diag = emitOpError() |
468 | << "expected children ops to implement " |
469 | "PatternDescriptorOpInterface"; |
470 | diag.attachNote(op.getLoc()) << "op without interface"; |
471 | return diag; |
472 | } |
473 | } |
474 | } |
475 | return success(); |
476 | } |
477 | |
478 | void transform::ApplyPatternsOp::getEffects( |
479 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
480 | transform::onlyReadsHandle(getTargetMutable(), effects); |
481 | transform::modifiesPayload(effects); |
482 | } |
483 | |
484 | void transform::ApplyPatternsOp::build( |
485 | OpBuilder &builder, OperationState &result, Value target, |
486 | function_ref<void(OpBuilder &, Location)> bodyBuilder) { |
487 | result.addOperands(target); |
488 | |
489 | OpBuilder::InsertionGuard g(builder); |
490 | Region *region = result.addRegion(); |
491 | builder.createBlock(region); |
492 | if (bodyBuilder) |
493 | bodyBuilder(builder, result.location); |
494 | } |
495 | |
496 | //===----------------------------------------------------------------------===// |
497 | // ApplyCanonicalizationPatternsOp |
498 | //===----------------------------------------------------------------------===// |
499 | |
500 | void transform::ApplyCanonicalizationPatternsOp::populatePatterns( |
501 | RewritePatternSet &patterns) { |
502 | MLIRContext *ctx = patterns.getContext(); |
503 | for (Dialect *dialect : ctx->getLoadedDialects()) |
504 | dialect->getCanonicalizationPatterns(patterns); |
505 | for (RegisteredOperationName op : ctx->getRegisteredOperations()) |
506 | op.getCanonicalizationPatterns(patterns, ctx); |
507 | } |
508 | |
509 | //===----------------------------------------------------------------------===// |
510 | // ApplyConversionPatternsOp |
511 | //===----------------------------------------------------------------------===// |
512 | |
513 | DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( |
514 | transform::TransformRewriter &rewriter, |
515 | transform::TransformResults &results, transform::TransformState &state) { |
516 | MLIRContext *ctx = getContext(); |
517 | |
518 | // Instantiate the default type converter if a type converter builder is |
519 | // specified. |
520 | std::unique_ptr<TypeConverter> defaultTypeConverter; |
521 | transform::TypeConverterBuilderOpInterface typeConverterBuilder = |
522 | getDefaultTypeConverter(); |
523 | if (typeConverterBuilder) |
524 | defaultTypeConverter = typeConverterBuilder.getTypeConverter(); |
525 | |
526 | // Configure conversion target. |
527 | ConversionTarget conversionTarget(*getContext()); |
528 | if (getLegalOps()) |
529 | for (Attribute attr : cast<ArrayAttr>(*getLegalOps())) |
530 | conversionTarget.addLegalOp( |
531 | OperationName(cast<StringAttr>(attr).getValue(), ctx)); |
532 | if (getIllegalOps()) |
533 | for (Attribute attr : cast<ArrayAttr>(*getIllegalOps())) |
534 | conversionTarget.addIllegalOp( |
535 | OperationName(cast<StringAttr>(attr).getValue(), ctx)); |
536 | if (getLegalDialects()) |
537 | for (Attribute attr : cast<ArrayAttr>(*getLegalDialects())) |
538 | conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue()); |
539 | if (getIllegalDialects()) |
540 | for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects())) |
541 | conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue()); |
542 | |
543 | // Gather all specified patterns. |
544 | RewritePatternSet patterns(ctx); |
545 | // Need to keep the converters alive until after pattern application because |
546 | // the patterns take a reference to an object that would otherwise get out of |
547 | // scope. |
548 | SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters; |
549 | if (!getPatterns().empty()) { |
550 | for (Operation &op : getPatterns().front()) { |
551 | auto descriptor = |
552 | cast<transform::ConversionPatternDescriptorOpInterface>(&op); |
553 | |
554 | // Check if this pattern set specifies a type converter. |
555 | std::unique_ptr<TypeConverter> typeConverter = |
556 | descriptor.getTypeConverter(); |
557 | TypeConverter *converter = nullptr; |
558 | if (typeConverter) { |
559 | keepAliveConverters.emplace_back(std::move(typeConverter)); |
560 | converter = keepAliveConverters.back().get(); |
561 | } else { |
562 | // No type converter specified: Use the default type converter. |
563 | if (!defaultTypeConverter) { |
564 | auto diag = emitDefiniteFailure() |
565 | << "pattern descriptor does not specify type " |
566 | "converter and apply_conversion_patterns op has " |
567 | "no default type converter"; |
568 | diag.attachNote(op.getLoc()) << "pattern descriptor op"; |
569 | return diag; |
570 | } |
571 | converter = defaultTypeConverter.get(); |
572 | } |
573 | |
574 | // Add descriptor-specific updates to the conversion target, which may |
575 | // depend on the final type converter. In structural converters, the |
576 | // legality of types dictates the dynamic legality of an operation. |
577 | descriptor.populateConversionTargetRules(*converter, conversionTarget); |
578 | |
579 | descriptor.populatePatterns(*converter, patterns); |
580 | } |
581 | } |
582 | |
583 | // Attach a tracking listener if handles should be preserved. We configure the |
584 | // listener to allow op replacements with different names, as conversion |
585 | // patterns typically replace ops with replacement ops that have a different |
586 | // name. |
587 | TrackingListenerConfig trackingConfig; |
588 | trackingConfig.requireMatchingReplacementOpName = false; |
589 | ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig); |
590 | ConversionConfig conversionConfig; |
591 | if (getPreserveHandles()) |
592 | conversionConfig.listener = &trackingListener; |
593 | |
594 | FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
595 | for (Operation *target : state.getPayloadOps(getTarget())) { |
596 | // Make sure that this transform is not applied to itself. Modifying the |
597 | // transform IR while it is being interpreted is generally dangerous. |
598 | DiagnosedSilenceableFailure payloadCheck = |
599 | ensurePayloadIsSeparateFromTransform(*this, target); |
600 | if (!payloadCheck.succeeded()) |
601 | return payloadCheck; |
602 | |
603 | LogicalResult status = failure(); |
604 | if (getPartialConversion()) { |
605 | status = applyPartialConversion(target, conversionTarget, frozenPatterns, |
606 | conversionConfig); |
607 | } else { |
608 | status = applyFullConversion(target, conversionTarget, frozenPatterns, |
609 | conversionConfig); |
610 | } |
611 | |
612 | // Check dialect conversion state. |
613 | DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); |
614 | if (failed(status)) { |
615 | diag = emitSilenceableError() << "dialect conversion failed"; |
616 | diag.attachNote(target->getLoc()) << "target op"; |
617 | } |
618 | |
619 | // Check tracking listener error state. |
620 | DiagnosedSilenceableFailure trackingFailure = |
621 | trackingListener.checkAndResetError(); |
622 | if (!trackingFailure.succeeded()) { |
623 | if (diag.succeeded()) { |
624 | // Tracking failure is the only failure. |
625 | return trackingFailure; |
626 | } else { |
627 | diag.attachNote() << "tracking listener also failed: " |
628 | << trackingFailure.getMessage(); |
629 | (void)trackingFailure.silence(); |
630 | } |
631 | } |
632 | |
633 | if (!diag.succeeded()) |
634 | return diag; |
635 | } |
636 | |
637 | return DiagnosedSilenceableFailure::success(); |
638 | } |
639 | |
640 | LogicalResult transform::ApplyConversionPatternsOp::verify() { |
641 | if (getNumRegions() != 1 && getNumRegions() != 2) |
642 | return emitOpError() << "expected 1 or 2 regions"; |
643 | if (!getPatterns().empty()) { |
644 | for (Operation &op : getPatterns().front()) { |
645 | if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) { |
646 | InFlightDiagnostic diag = |
647 | emitOpError() << "expected pattern children ops to implement " |
648 | "ConversionPatternDescriptorOpInterface"; |
649 | diag.attachNote(op.getLoc()) << "op without interface"; |
650 | return diag; |
651 | } |
652 | } |
653 | } |
654 | if (getNumRegions() == 2) { |
655 | Region &typeConverterRegion = getRegion(1); |
656 | if (!llvm::hasSingleElement(typeConverterRegion.front())) |
657 | return emitOpError() |
658 | << "expected exactly one op in default type converter region"; |
659 | Operation *maybeTypeConverter = &typeConverterRegion.front().front(); |
660 | auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>( |
661 | maybeTypeConverter); |
662 | if (!typeConverterOp) { |
663 | InFlightDiagnostic diag = emitOpError() |
664 | << "expected default converter child op to " |
665 | "implement TypeConverterBuilderOpInterface"; |
666 | diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface"; |
667 | return diag; |
668 | } |
669 | // Check default type converter type. |
670 | if (!getPatterns().empty()) { |
671 | for (Operation &op : getPatterns().front()) { |
672 | auto descriptor = |
673 | cast<transform::ConversionPatternDescriptorOpInterface>(&op); |
674 | if (failed(descriptor.verifyTypeConverter(typeConverterOp))) |
675 | return failure(); |
676 | } |
677 | } |
678 | } |
679 | return success(); |
680 | } |
681 | |
682 | void transform::ApplyConversionPatternsOp::getEffects( |
683 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
684 | if (!getPreserveHandles()) { |
685 | transform::consumesHandle(getTargetMutable(), effects); |
686 | } else { |
687 | transform::onlyReadsHandle(getTargetMutable(), effects); |
688 | } |
689 | transform::modifiesPayload(effects); |
690 | } |
691 | |
692 | void transform::ApplyConversionPatternsOp::build( |
693 | OpBuilder &builder, OperationState &result, Value target, |
694 | function_ref<void(OpBuilder &, Location)> patternsBodyBuilder, |
695 | function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) { |
696 | result.addOperands(target); |
697 | |
698 | { |
699 | OpBuilder::InsertionGuard g(builder); |
700 | Region *region1 = result.addRegion(); |
701 | builder.createBlock(region1); |
702 | if (patternsBodyBuilder) |
703 | patternsBodyBuilder(builder, result.location); |
704 | } |
705 | { |
706 | OpBuilder::InsertionGuard g(builder); |
707 | Region *region2 = result.addRegion(); |
708 | builder.createBlock(region2); |
709 | if (typeConverterBodyBuilder) |
710 | typeConverterBodyBuilder(builder, result.location); |
711 | } |
712 | } |
713 | |
714 | //===----------------------------------------------------------------------===// |
715 | // ApplyToLLVMConversionPatternsOp |
716 | //===----------------------------------------------------------------------===// |
717 | |
718 | void transform::ApplyToLLVMConversionPatternsOp::populatePatterns( |
719 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
720 | Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); |
721 | assert(dialect && "expected that dialect is loaded"); |
722 | auto *iface = cast<ConvertToLLVMPatternInterface>(dialect); |
723 | // ConversionTarget is currently ignored because the enclosing |
724 | // apply_conversion_patterns op sets up its own ConversionTarget. |
725 | ConversionTarget target(*getContext()); |
726 | iface->populateConvertToLLVMConversionPatterns( |
727 | target, static_cast<LLVMTypeConverter &>(typeConverter), patterns); |
728 | } |
729 | |
730 | LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter( |
731 | transform::TypeConverterBuilderOpInterface builder) { |
732 | if (builder.getTypeConverterType() != "LLVMTypeConverter") |
733 | return emitOpError("expected LLVMTypeConverter"); |
734 | return success(); |
735 | } |
736 | |
737 | LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() { |
738 | Dialect *dialect = getContext()->getLoadedDialect(getDialectName()); |
739 | if (!dialect) |
740 | return emitOpError("unknown dialect or dialect not loaded: ") |
741 | << getDialectName(); |
742 | auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); |
743 | if (!iface) |
744 | return emitOpError( |
745 | "dialect does not implement ConvertToLLVMPatternInterface or " |
746 | "extension was not loaded: ") |
747 | << getDialectName(); |
748 | return success(); |
749 | } |
750 | |
751 | //===----------------------------------------------------------------------===// |
752 | // ApplyLoopInvariantCodeMotionOp |
753 | //===----------------------------------------------------------------------===// |
754 | |
755 | DiagnosedSilenceableFailure |
756 | transform::ApplyLoopInvariantCodeMotionOp::applyToOne( |
757 | transform::TransformRewriter &rewriter, LoopLikeOpInterface target, |
758 | transform::ApplyToEachResultList &results, |
759 | transform::TransformState &state) { |
760 | // Currently, LICM does not remove operations, so we don't need tracking. |
761 | // If this ever changes, add a LICM entry point that takes a rewriter. |
762 | moveLoopInvariantCode(target); |
763 | return DiagnosedSilenceableFailure::success(); |
764 | } |
765 | |
766 | void transform::ApplyLoopInvariantCodeMotionOp::getEffects( |
767 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
768 | transform::onlyReadsHandle(getTargetMutable(), effects); |
769 | transform::modifiesPayload(effects); |
770 | } |
771 | |
772 | //===----------------------------------------------------------------------===// |
773 | // ApplyRegisteredPassOp |
774 | //===----------------------------------------------------------------------===// |
775 | |
776 | void transform::ApplyRegisteredPassOp::getEffects( |
777 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
778 | consumesHandle(getTargetMutable(), effects); |
779 | onlyReadsHandle(getDynamicOptionsMutable(), effects); |
780 | producesHandle(getOperation()->getOpResults(), effects); |
781 | modifiesPayload(effects); |
782 | } |
783 | |
784 | DiagnosedSilenceableFailure |
785 | transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, |
786 | transform::TransformResults &results, |
787 | transform::TransformState &state) { |
788 | // Obtain a single options-string to pass to the pass(-pipeline) from options |
789 | // passed in as a dictionary of keys mapping to values which are either |
790 | // attributes or param-operands pointing to attributes. |
791 | |
792 | std::string options; |
793 | llvm::raw_string_ostream optionsStream(options); // For "printing" attrs. |
794 | |
795 | OperandRange dynamicOptions = getDynamicOptions(); |
796 | for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) { |
797 | if (idx > 0) |
798 | optionsStream << " "; // Interleave options separator. |
799 | optionsStream << namedAttribute.getName().str(); // Append the key. |
800 | optionsStream << "="; // And the key-value separator. |
801 | |
802 | Attribute valueAttrToAppend; |
803 | if (auto paramOperandIndex = |
804 | dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) { |
805 | // The corresponding value attribute is passed in via a param. |
806 | // Obtain the param-operand via its specified index. |
807 | size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt(); |
808 | assert(dynamicOptionIdx < dynamicOptions.size() && |
809 | "number of dynamic option markers (UnitAttr) in options ArrayAttr " |
810 | "should be the same as the number of options passed as params"); |
811 | ArrayRef<Attribute> dynamicOption = |
812 | state.getParams(dynamicOptions[dynamicOptionIdx]); |
813 | if (dynamicOption.size() != 1) |
814 | return emitSilenceableError() |
815 | << "options passed as a param must have " |
816 | "a single value associated, param " |
817 | << dynamicOptionIdx << " associates "<< dynamicOption.size(); |
818 | valueAttrToAppend = dynamicOption[0]; |
819 | } else { |
820 | // Value is a static attribute. |
821 | valueAttrToAppend = namedAttribute.getValue(); |
822 | } |
823 | |
824 | // Append string representation of value attribute. |
825 | if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) { |
826 | optionsStream << strAttr.getValue().str(); |
827 | } else { |
828 | valueAttrToAppend.print(optionsStream, /*elideType=*/true); |
829 | } |
830 | } |
831 | optionsStream.flush(); |
832 | |
833 | // Get pass or pass pipeline from registry. |
834 | const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); |
835 | if (!info) |
836 | info = PassInfo::lookup(getPassName()); |
837 | if (!info) |
838 | return emitDefiniteFailure() |
839 | << "unknown pass or pass pipeline: "<< getPassName(); |
840 | |
841 | // Create pass manager and add the pass or pass pipeline. |
842 | PassManager pm(getContext()); |
843 | if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) { |
844 | emitError(msg); |
845 | return failure(); |
846 | }))) { |
847 | return emitDefiniteFailure() |
848 | << "failed to add pass or pass pipeline to pipeline: " |
849 | << getPassName(); |
850 | } |
851 | |
852 | auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget())); |
853 | for (Operation *target : targets) { |
854 | // Make sure that this transform is not applied to itself. Modifying the |
855 | // transform IR while it is being interpreted is generally dangerous. Even |
856 | // more so when applying passes because they may perform a wide range of IR |
857 | // modifications. |
858 | DiagnosedSilenceableFailure payloadCheck = |
859 | ensurePayloadIsSeparateFromTransform(*this, target); |
860 | if (!payloadCheck.succeeded()) |
861 | return payloadCheck; |
862 | |
863 | // Run the pass or pass pipeline on the current target operation. |
864 | if (failed(pm.run(target))) { |
865 | auto diag = emitSilenceableError() << "pass pipeline failed"; |
866 | diag.attachNote(target->getLoc()) << "target op"; |
867 | return diag; |
868 | } |
869 | } |
870 | |
871 | // The applied pass will have directly modified the payload IR(s). |
872 | results.set(llvm::cast<OpResult>(getResult()), targets); |
873 | return DiagnosedSilenceableFailure::success(); |
874 | } |
875 | |
876 | static ParseResult parseApplyRegisteredPassOptions( |
877 | OpAsmParser &parser, DictionaryAttr &options, |
878 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) { |
879 | // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax. |
880 | SmallVector<NamedAttribute> keyValuePairs; |
881 | |
882 | size_t dynamicOptionsIdx = 0; |
883 | auto parseKeyValuePair = [&]() -> ParseResult { |
884 | // Parse items of the form `key = value` where `key` is a bare identifier or |
885 | // a string and `value` is either an attribute or an operand. |
886 | |
887 | std::string key; |
888 | Attribute valueAttr; |
889 | if (parser.parseOptionalKeywordOrString(result: &key)) |
890 | return parser.emitError(loc: parser.getCurrentLocation()) |
891 | << "expected key to either be an identifier or a string"; |
892 | if (key.empty()) |
893 | return failure(); |
894 | |
895 | if (parser.parseEqual()) |
896 | return parser.emitError(loc: parser.getCurrentLocation()) |
897 | << "expected '=' after key in key-value pair"; |
898 | |
899 | // Parse the value, which can be either an attribute or an operand. |
900 | OptionalParseResult parsedValueAttr = |
901 | parser.parseOptionalAttribute(result&: valueAttr); |
902 | if (!parsedValueAttr.has_value()) { |
903 | OpAsmParser::UnresolvedOperand operand; |
904 | ParseResult parsedOperand = parser.parseOperand(result&: operand); |
905 | if (failed(Result: parsedOperand)) |
906 | return parser.emitError(loc: parser.getCurrentLocation()) |
907 | << "expected a valid attribute or operand as value associated " |
908 | << "to key '"<< key << "'"; |
909 | // To make use of the operand, we need to store it in the options dict. |
910 | // As SSA-values cannot occur in attributes, what we do instead is store |
911 | // an attribute in its place that contains the index of the param-operand, |
912 | // so that an attr-value associated to the param can be resolved later on. |
913 | dynamicOptions.push_back(Elt: operand); |
914 | auto wrappedIndex = IntegerAttr::get( |
915 | IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++); |
916 | valueAttr = |
917 | transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex); |
918 | } else if (failed(Result: parsedValueAttr.value())) { |
919 | return failure(); // NB: Attempted parse should have output error message. |
920 | } else if (isa<transform::ParamOperandAttr>(valueAttr)) { |
921 | return parser.emitError(loc: parser.getCurrentLocation()) |
922 | << "the param_operand attribute is a marker reserved for " |
923 | << "indicating a value will be passed via params and is only used " |
924 | << "in the generic print format"; |
925 | } |
926 | |
927 | keyValuePairs.push_back(Elt: NamedAttribute(key, valueAttr)); |
928 | return success(); |
929 | }; |
930 | |
931 | if (parser.parseCommaSeparatedList(delimiter: AsmParser::Delimiter::Braces, |
932 | parseElementFn: parseKeyValuePair, |
933 | contextMessage: " in options dictionary")) |
934 | return failure(); // NB: Attempted parse should have output error message. |
935 | |
936 | if (DictionaryAttr::findDuplicate( |
937 | keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs. |
938 | .has_value()) |
939 | return parser.emitError(loc: parser.getCurrentLocation()) |
940 | << "duplicate keys found in options dictionary"; |
941 | |
942 | options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs); |
943 | |
944 | return success(); |
945 | } |
946 | |
947 | static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, |
948 | Operation *op, |
949 | DictionaryAttr options, |
950 | ValueRange dynamicOptions) { |
951 | if (options.empty()) |
952 | return; |
953 | |
954 | printer << "{"; |
955 | llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) { |
956 | printer << namedAttribute.getName() << " = "; |
957 | Attribute value = namedAttribute.getValue(); |
958 | if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) { |
959 | // Resolve index of param-operand to its actual SSA-value and print that. |
960 | printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]); |
961 | } else { |
962 | printer.printAttribute(attr: value); |
963 | } |
964 | }); |
965 | printer << "}"; |
966 | } |
967 | |
968 | LogicalResult transform::ApplyRegisteredPassOp::verify() { |
969 | // Check that there is a one-to-one correspondence between param operands |
970 | // and references to dynamic options in the options dictionary. |
971 | |
972 | auto dynamicOptions = SmallVector<Value>(getDynamicOptions()); |
973 | for (NamedAttribute namedAttr : getOptions()) |
974 | if (auto paramOperand = |
975 | dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) { |
976 | size_t dynamicOptionIdx = paramOperand.getIndex().getInt(); |
977 | if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size()) |
978 | return emitOpError() |
979 | << "dynamic option index "<< dynamicOptionIdx |
980 | << " is out of bounds for the number of dynamic options: " |
981 | << dynamicOptions.size(); |
982 | if (dynamicOptions[dynamicOptionIdx] == nullptr) |
983 | return emitOpError() << "dynamic option index "<< dynamicOptionIdx |
984 | << " is already used in options"; |
985 | dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used. |
986 | } |
987 | |
988 | for (Value dynamicOption : dynamicOptions) |
989 | if (dynamicOption) |
990 | return emitOpError() << "a param operand does not have a corresponding " |
991 | << "param_operand attr in the options dict"; |
992 | |
993 | return success(); |
994 | } |
995 | |
996 | //===----------------------------------------------------------------------===// |
997 | // CastOp |
998 | //===----------------------------------------------------------------------===// |
999 | |
1000 | DiagnosedSilenceableFailure |
1001 | transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, |
1002 | Operation *target, ApplyToEachResultList &results, |
1003 | transform::TransformState &state) { |
1004 | results.push_back(target); |
1005 | return DiagnosedSilenceableFailure::success(); |
1006 | } |
1007 | |
1008 | void transform::CastOp::getEffects( |
1009 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
1010 | onlyReadsPayload(effects); |
1011 | onlyReadsHandle(getInputMutable(), effects); |
1012 | producesHandle(getOperation()->getOpResults(), effects); |
1013 | } |
1014 | |
1015 | bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1016 | assert(inputs.size() == 1 && "expected one input"); |
1017 | assert(outputs.size() == 1 && "expected one output"); |
1018 | return llvm::all_of( |
1019 | std::initializer_list<Type>{inputs.front(), outputs.front()}, |
1020 | llvm::IsaPred<transform::TransformHandleTypeInterface>); |
1021 | } |
1022 | |
1023 | //===----------------------------------------------------------------------===// |
// Matching helpers shared by CollectMatchingOp and ForeachMatchOp
1025 | //===----------------------------------------------------------------------===// |
1026 | |
1027 | /// Applies matcher operations from the given `block` using |
1028 | /// `blockArgumentMapping` to initialize block arguments. Updates `state` |
1029 | /// accordingly. If any of the matcher produces a silenceable failure, discards |
1030 | /// it (printing the content to the debug output stream) and returns failure. If |
1031 | /// any of the matchers produces a definite failure, reports it and returns |
1032 | /// failure. If all matchers in the block succeed, populates `mappings` with the |
1033 | /// payload entities associated with the block terminator operands. Note that |
1034 | /// `mappings` will be cleared before that. |
1035 | static DiagnosedSilenceableFailure |
1036 | matchBlock(Block &block, |
1037 | ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping, |
1038 | transform::TransformState &state, |
1039 | SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) { |
1040 | assert(block.getParent() && "cannot match using a detached block"); |
1041 | auto matchScope = state.make_region_scope(region&: *block.getParent()); |
1042 | if (failed( |
1043 | Result: state.mapBlockArguments(arguments: block.getArguments(), mapping: blockArgumentMapping))) |
1044 | return DiagnosedSilenceableFailure::definiteFailure(); |
1045 | |
1046 | for (Operation &match : block.without_terminator()) { |
1047 | if (!isa<transform::MatchOpInterface>(Val: match)) { |
1048 | return emitDefiniteFailure(loc: match.getLoc()) |
1049 | << "expected operations in the match part to " |
1050 | "implement MatchOpInterface"; |
1051 | } |
1052 | DiagnosedSilenceableFailure diag = |
1053 | state.applyTransform(transform: cast<transform::TransformOpInterface>(match)); |
1054 | if (diag.succeeded()) |
1055 | continue; |
1056 | |
1057 | return diag; |
1058 | } |
1059 | |
1060 | // Remember the values mapped to the terminator operands so we can |
1061 | // forward them to the action. |
1062 | ValueRange yieldedValues = block.getTerminator()->getOperands(); |
1063 | // Our contract with the caller is that the mappings will contain only the |
1064 | // newly mapped values, clear the rest. |
1065 | mappings.clear(); |
1066 | transform::detail::prepareValueMappings(mappings, yieldedValues, state); |
1067 | return DiagnosedSilenceableFailure::success(); |
1068 | } |
1069 | |
1070 | /// Returns `true` if both types implement one of the interfaces provided as |
1071 | /// template parameters. |
1072 | template <typename... Tys> |
1073 | static bool implementSameInterface(Type t1, Type t2) { |
1074 | return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); |
1075 | } |
1076 | |
1077 | /// Returns `true` if both types implement one of the transform dialect |
1078 | /// interfaces. |
1079 | static bool implementSameTransformInterface(Type t1, Type t2) { |
1080 | return implementSameInterface<transform::TransformHandleTypeInterface, |
1081 | transform::TransformParamTypeInterface, |
1082 | transform::TransformValueHandleTypeInterface>( |
1083 | t1, t2); |
1084 | } |
1085 | |
1086 | //===----------------------------------------------------------------------===// |
1087 | // CollectMatchingOp |
1088 | //===----------------------------------------------------------------------===// |
1089 | |
1090 | DiagnosedSilenceableFailure |
1091 | transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, |
1092 | transform::TransformResults &results, |
1093 | transform::TransformState &state) { |
1094 | auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>( |
1095 | getOperation(), getMatcher()); |
1096 | if (matcher.isExternal()) { |
1097 | return emitDefiniteFailure() |
1098 | << "unresolved external symbol "<< getMatcher(); |
1099 | } |
1100 | |
1101 | SmallVector<SmallVector<MappedValue>, 2> rawResults; |
1102 | rawResults.resize(getOperation()->getNumResults()); |
1103 | std::optional<DiagnosedSilenceableFailure> maybeFailure; |
1104 | for (Operation *root : state.getPayloadOps(getRoot())) { |
1105 | WalkResult walkResult = root->walk([&](Operation *op) { |
1106 | DEBUG_MATCHER({ |
1107 | DBGS_MATCHER() << "matching "; |
1108 | op->print(llvm::dbgs(), |
1109 | OpPrintingFlags().assumeVerified().skipRegions()); |
1110 | llvm::dbgs() << " @"<< op << "\n"; |
1111 | }); |
1112 | |
1113 | // Try matching. |
1114 | SmallVector<SmallVector<MappedValue>> mappings; |
1115 | SmallVector<transform::MappedValue> inputMapping({op}); |
1116 | DiagnosedSilenceableFailure diag = matchBlock( |
1117 | matcher.getFunctionBody().front(), |
1118 | ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state, |
1119 | mappings); |
1120 | if (diag.isDefiniteFailure()) |
1121 | return WalkResult::interrupt(); |
1122 | if (diag.isSilenceableFailure()) { |
1123 | DEBUG_MATCHER(DBGS_MATCHER() << "matcher "<< matcher.getName() |
1124 | << " failed: "<< diag.getMessage()); |
1125 | return WalkResult::advance(); |
1126 | } |
1127 | |
1128 | // If succeeded, collect results. |
1129 | for (auto &&[i, mapping] : llvm::enumerate(mappings)) { |
1130 | if (mapping.size() != 1) { |
1131 | maybeFailure.emplace(emitSilenceableError() |
1132 | << "result #"<< i << ", associated with " |
1133 | << mapping.size() |
1134 | << " payload objects, expected 1"); |
1135 | return WalkResult::interrupt(); |
1136 | } |
1137 | rawResults[i].push_back(mapping[0]); |
1138 | } |
1139 | return WalkResult::advance(); |
1140 | }); |
1141 | if (walkResult.wasInterrupted()) |
1142 | return std::move(*maybeFailure); |
1143 | assert(!maybeFailure && "failure set but the walk was not interrupted"); |
1144 | |
1145 | for (auto &&[opResult, rawResult] : |
1146 | llvm::zip_equal(getOperation()->getResults(), rawResults)) { |
1147 | results.setMappedValues(opResult, rawResult); |
1148 | } |
1149 | } |
1150 | return DiagnosedSilenceableFailure::success(); |
1151 | } |
1152 | |
1153 | void transform::CollectMatchingOp::getEffects( |
1154 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
1155 | onlyReadsHandle(getRootMutable(), effects); |
1156 | producesHandle(getOperation()->getOpResults(), effects); |
1157 | onlyReadsPayload(effects); |
1158 | } |
1159 | |
1160 | LogicalResult transform::CollectMatchingOp::verifySymbolUses( |
1161 | SymbolTableCollection &symbolTable) { |
1162 | auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( |
1163 | symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); |
1164 | if (!matcherSymbol || |
1165 | !isa<TransformOpInterface>(matcherSymbol.getOperation())) |
1166 | return emitError() << "unresolved matcher symbol "<< getMatcher(); |
1167 | |
1168 | ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes(); |
1169 | if (argumentTypes.size() != 1 || |
1170 | !isa<TransformHandleTypeInterface>(argumentTypes[0])) { |
1171 | return emitError() |
1172 | << "expected the matcher to take one operation handle argument"; |
1173 | } |
1174 | if (!matcherSymbol.getArgAttr( |
1175 | 0, transform::TransformDialect::kArgReadOnlyAttrName)) { |
1176 | return emitError() << "expected the matcher argument to be marked readonly"; |
1177 | } |
1178 | |
1179 | ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes(); |
1180 | if (resultTypes.size() != getOperation()->getNumResults()) { |
1181 | return emitError() |
1182 | << "expected the matcher to yield as many values as op has results (" |
1183 | << getOperation()->getNumResults() << "), got " |
1184 | << resultTypes.size(); |
1185 | } |
1186 | |
1187 | for (auto &&[i, matcherType, resultType] : |
1188 | llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { |
1189 | if (implementSameTransformInterface(matcherType, resultType)) |
1190 | continue; |
1191 | |
1192 | return emitError() |
1193 | << "mismatching type interfaces for matcher result and op result #" |
1194 | << i; |
1195 | } |
1196 | |
1197 | return success(); |
1198 | } |
1199 | |
1200 | //===----------------------------------------------------------------------===// |
1201 | // ForeachMatchOp |
1202 | //===----------------------------------------------------------------------===// |
1203 | |
1204 | // This is fine because nothing is actually consumed by this op. |
1205 | bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; } |
1206 | |
1207 | DiagnosedSilenceableFailure |
1208 | transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, |
1209 | transform::TransformResults &results, |
1210 | transform::TransformState &state) { |
1211 | SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>> |
1212 | matchActionPairs; |
1213 | matchActionPairs.reserve(getMatchers().size()); |
1214 | SymbolTableCollection symbolTable; |
1215 | for (auto &&[matcher, action] : |
1216 | llvm::zip_equal(getMatchers(), getActions())) { |
1217 | auto matcherSymbol = |
1218 | symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>( |
1219 | getOperation(), cast<SymbolRefAttr>(matcher)); |
1220 | auto actionSymbol = |
1221 | symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>( |
1222 | getOperation(), cast<SymbolRefAttr>(action)); |
1223 | assert(matcherSymbol && actionSymbol && |
1224 | "unresolved symbols not caught by the verifier"); |
1225 | |
1226 | if (matcherSymbol.isExternal()) |
1227 | return emitDefiniteFailure() << "unresolved external symbol "<< matcher; |
1228 | if (actionSymbol.isExternal()) |
1229 | return emitDefiniteFailure() << "unresolved external symbol "<< action; |
1230 | |
1231 | matchActionPairs.emplace_back(matcherSymbol, actionSymbol); |
1232 | } |
1233 | |
1234 | DiagnosedSilenceableFailure overallDiag = |
1235 | DiagnosedSilenceableFailure::success(); |
1236 | |
1237 | SmallVector<SmallVector<MappedValue>> matchInputMapping; |
1238 | SmallVector<SmallVector<MappedValue>> matchOutputMapping; |
1239 | SmallVector<SmallVector<MappedValue>> actionResultMapping; |
1240 | // Explicitly add the mapping for the first block argument (the op being |
1241 | // matched). |
1242 | matchInputMapping.emplace_back(); |
1243 | transform::detail::prepareValueMappings(matchInputMapping, |
1244 | getForwardedInputs(), state); |
1245 | SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front(); |
1246 | actionResultMapping.resize(getForwardedOutputs().size()); |
1247 | |
1248 | for (Operation *root : state.getPayloadOps(getRoot())) { |
1249 | WalkResult walkResult = root->walk([&](Operation *op) { |
1250 | // If getRestrictRoot is not present, skip over the root op itself so we |
1251 | // don't invalidate it. |
1252 | if (!getRestrictRoot() && op == root) |
1253 | return WalkResult::advance(); |
1254 | |
1255 | DEBUG_MATCHER({ |
1256 | DBGS_MATCHER() << "matching "; |
1257 | op->print(llvm::dbgs(), |
1258 | OpPrintingFlags().assumeVerified().skipRegions()); |
1259 | llvm::dbgs() << " @"<< op << "\n"; |
1260 | }); |
1261 | |
1262 | firstMatchArgument.clear(); |
1263 | firstMatchArgument.push_back(op); |
1264 | |
1265 | // Try all the match/action pairs until the first successful match. |
1266 | for (auto [matcher, action] : matchActionPairs) { |
1267 | DiagnosedSilenceableFailure diag = |
1268 | matchBlock(matcher.getFunctionBody().front(), matchInputMapping, |
1269 | state, matchOutputMapping); |
1270 | if (diag.isDefiniteFailure()) |
1271 | return WalkResult::interrupt(); |
1272 | if (diag.isSilenceableFailure()) { |
1273 | DEBUG_MATCHER(DBGS_MATCHER() << "matcher "<< matcher.getName() |
1274 | << " failed: "<< diag.getMessage()); |
1275 | continue; |
1276 | } |
1277 | |
1278 | auto scope = state.make_region_scope(action.getFunctionBody()); |
1279 | if (failed(state.mapBlockArguments( |
1280 | action.getFunctionBody().front().getArguments(), |
1281 | matchOutputMapping))) { |
1282 | return WalkResult::interrupt(); |
1283 | } |
1284 | |
1285 | for (Operation &transform : |
1286 | action.getFunctionBody().front().without_terminator()) { |
1287 | DiagnosedSilenceableFailure result = |
1288 | state.applyTransform(cast<TransformOpInterface>(transform)); |
1289 | if (result.isDefiniteFailure()) |
1290 | return WalkResult::interrupt(); |
1291 | if (result.isSilenceableFailure()) { |
1292 | if (overallDiag.succeeded()) { |
1293 | overallDiag = emitSilenceableError() << "actions failed"; |
1294 | } |
1295 | overallDiag.attachNote(action->getLoc()) |
1296 | << "failed action: "<< result.getMessage(); |
1297 | overallDiag.attachNote(op->getLoc()) |
1298 | << "when applied to this matching payload"; |
1299 | (void)result.silence(); |
1300 | continue; |
1301 | } |
1302 | } |
1303 | if (failed(detail::appendValueMappings( |
1304 | MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping), |
1305 | action.getFunctionBody().front().getTerminator()->getOperands(), |
1306 | state, getFlattenResults()))) { |
1307 | emitDefiniteFailure() |
1308 | << "action @"<< action.getName() |
1309 | << " has results associated with multiple payload entities, " |
1310 | "but flattening was not requested"; |
1311 | return WalkResult::interrupt(); |
1312 | } |
1313 | break; |
1314 | } |
1315 | return WalkResult::advance(); |
1316 | }); |
1317 | if (walkResult.wasInterrupted()) |
1318 | return DiagnosedSilenceableFailure::definiteFailure(); |
1319 | } |
1320 | |
1321 | // The root operation should not have been affected, so we can just reassign |
1322 | // the payload to the result. Note that we need to consume the root handle to |
1323 | // make sure any handles to operations inside, that could have been affected |
1324 | // by actions, are invalidated. |
1325 | results.set(llvm::cast<OpResult>(getUpdated()), |
1326 | state.getPayloadOps(getRoot())); |
1327 | for (auto &&[result, mapping] : |
1328 | llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) { |
1329 | results.setMappedValues(result, mapping); |
1330 | } |
1331 | return overallDiag; |
1332 | } |
1333 | |
1334 | void transform::ForeachMatchOp::getAsmResultNames( |
1335 | OpAsmSetValueNameFn setNameFn) { |
1336 | setNameFn(getUpdated(), "updated_root"); |
1337 | for (Value v : getForwardedOutputs()) { |
1338 | setNameFn(v, "yielded"); |
1339 | } |
1340 | } |
1341 | |
1342 | void transform::ForeachMatchOp::getEffects( |
1343 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
1344 | // Bail if invalid. |
1345 | if (getOperation()->getNumOperands() < 1 || |
1346 | getOperation()->getNumResults() < 1) { |
1347 | return modifiesPayload(effects); |
1348 | } |
1349 | |
1350 | consumesHandle(getRootMutable(), effects); |
1351 | onlyReadsHandle(getForwardedInputsMutable(), effects); |
1352 | producesHandle(getOperation()->getOpResults(), effects); |
1353 | modifiesPayload(effects); |
1354 | } |
1355 | |
1356 | /// Parses the comma-separated list of symbol reference pairs of the format |
1357 | /// `@matcher -> @action`. |
1358 | static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, |
1359 | ArrayAttr &matchers, |
1360 | ArrayAttr &actions) { |
1361 | StringAttr matcher; |
1362 | StringAttr action; |
1363 | SmallVector<Attribute> matcherList; |
1364 | SmallVector<Attribute> actionList; |
1365 | do { |
1366 | if (parser.parseSymbolName(matcher) || parser.parseArrow() || |
1367 | parser.parseSymbolName(action)) { |
1368 | return failure(); |
1369 | } |
1370 | matcherList.push_back(SymbolRefAttr::get(matcher)); |
1371 | actionList.push_back(SymbolRefAttr::get(action)); |
1372 | } while (parser.parseOptionalComma().succeeded()); |
1373 | |
1374 | matchers = parser.getBuilder().getArrayAttr(matcherList); |
1375 | actions = parser.getBuilder().getArrayAttr(actionList); |
1376 | return success(); |
1377 | } |
1378 | |
1379 | /// Prints the comma-separated list of symbol reference pairs of the format |
1380 | /// `@matcher -> @action`. |
1381 | static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, |
1382 | ArrayAttr matchers, ArrayAttr actions) { |
1383 | printer.increaseIndent(); |
1384 | printer.increaseIndent(); |
1385 | for (auto &&[matcher, action, idx] : llvm::zip_equal( |
1386 | matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) { |
1387 | printer.printNewline(); |
1388 | printer << cast<SymbolRefAttr>(matcher) << " -> " |
1389 | << cast<SymbolRefAttr>(action); |
1390 | if (idx != matchers.size() - 1) |
1391 | printer << ", "; |
1392 | } |
1393 | printer.decreaseIndent(); |
1394 | printer.decreaseIndent(); |
1395 | } |
1396 | |
1397 | LogicalResult transform::ForeachMatchOp::verify() { |
1398 | if (getMatchers().size() != getActions().size()) |
1399 | return emitOpError() << "expected the same number of matchers and actions"; |
1400 | if (getMatchers().empty()) |
1401 | return emitOpError() << "expected at least one match/action pair"; |
1402 | |
1403 | llvm::SmallPtrSet<Attribute, 8> matcherNames; |
1404 | for (Attribute name : getMatchers()) { |
1405 | if (matcherNames.insert(name).second) |
1406 | continue; |
1407 | emitWarning() << "matcher "<< name |
1408 | << " is used more than once, only the first match will apply"; |
1409 | } |
1410 | |
1411 | return success(); |
1412 | } |
1413 | |
1414 | /// Checks that the attributes of the function-like operation have correct |
1415 | /// consumption effect annotations. If `alsoVerifyInternal`, checks for |
1416 | /// annotations being present even if they can be inferred from the body. |
1417 | static DiagnosedSilenceableFailure |
1418 | verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, |
1419 | bool alsoVerifyInternal = false) { |
1420 | auto transformOp = cast<transform::TransformOpInterface>(op.getOperation()); |
1421 | llvm::SmallDenseSet<unsigned> consumedArguments; |
1422 | if (!op.isExternal()) { |
1423 | transform::getConsumedBlockArguments(block&: op.getFunctionBody().front(), |
1424 | consumedArguments); |
1425 | } |
1426 | for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) { |
1427 | bool isConsumed = |
1428 | op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) != |
1429 | nullptr; |
1430 | bool isReadOnly = |
1431 | op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) != |
1432 | nullptr; |
1433 | if (isConsumed && isReadOnly) { |
1434 | return transformOp.emitSilenceableError() |
1435 | << "argument #"<< i << " cannot be both readonly and consumed"; |
1436 | } |
1437 | if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) { |
1438 | return transformOp.emitSilenceableError() |
1439 | << "must provide consumed/readonly status for arguments of " |
1440 | "external or called ops"; |
1441 | } |
1442 | if (op.isExternal()) |
1443 | continue; |
1444 | |
1445 | if (consumedArguments.contains(V: i) && !isConsumed && isReadOnly) { |
1446 | return transformOp.emitSilenceableError() |
1447 | << "argument #"<< i |
1448 | << " is consumed in the body but is not marked as such"; |
1449 | } |
1450 | if (emitWarnings && !consumedArguments.contains(V: i) && isConsumed) { |
1451 | // Cannot use op.emitWarning() here as it would attempt to verify the op |
1452 | // before printing, resulting in infinite recursion. |
1453 | emitWarning(op->getLoc()) |
1454 | << "op argument #"<< i |
1455 | << " is not consumed in the body but is marked as consumed"; |
1456 | } |
1457 | } |
1458 | return DiagnosedSilenceableFailure::success(); |
1459 | } |
1460 | |
1461 | LogicalResult transform::ForeachMatchOp::verifySymbolUses( |
1462 | SymbolTableCollection &symbolTable) { |
1463 | assert(getMatchers().size() == getActions().size()); |
1464 | auto consumedAttr = |
1465 | StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName); |
1466 | for (auto &&[matcher, action] : |
1467 | llvm::zip_equal(getMatchers(), getActions())) { |
1468 | // Presence and typing. |
1469 | auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( |
1470 | symbolTable.lookupNearestSymbolFrom(getOperation(), |
1471 | cast<SymbolRefAttr>(matcher))); |
1472 | auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>( |
1473 | symbolTable.lookupNearestSymbolFrom(getOperation(), |
1474 | cast<SymbolRefAttr>(action))); |
1475 | if (!matcherSymbol || |
1476 | !isa<TransformOpInterface>(matcherSymbol.getOperation())) |
1477 | return emitError() << "unresolved matcher symbol "<< matcher; |
1478 | if (!actionSymbol || |
1479 | !isa<TransformOpInterface>(actionSymbol.getOperation())) |
1480 | return emitError() << "unresolved action symbol "<< action; |
1481 | |
1482 | if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol, |
1483 | /*emitWarnings=*/false, |
1484 | /*alsoVerifyInternal=*/true) |
1485 | .checkAndReport())) { |
1486 | return failure(); |
1487 | } |
1488 | if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol, |
1489 | /*emitWarnings=*/false, |
1490 | /*alsoVerifyInternal=*/true) |
1491 | .checkAndReport())) { |
1492 | return failure(); |
1493 | } |
1494 | |
1495 | // Input -> matcher forwarding. |
1496 | TypeRange operandTypes = getOperandTypes(); |
1497 | TypeRange matcherArguments = matcherSymbol.getArgumentTypes(); |
1498 | if (operandTypes.size() != matcherArguments.size()) { |
1499 | InFlightDiagnostic diag = |
1500 | emitError() << "the number of operands ("<< operandTypes.size() |
1501 | << ") doesn't match the number of matcher arguments (" |
1502 | << matcherArguments.size() << ") for "<< matcher; |
1503 | diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; |
1504 | return diag; |
1505 | } |
1506 | for (auto &&[i, operand, argument] : |
1507 | llvm::enumerate(operandTypes, matcherArguments)) { |
1508 | if (matcherSymbol.getArgAttr(i, consumedAttr)) { |
1509 | InFlightDiagnostic diag = |
1510 | emitOpError() |
1511 | << "does not expect matcher symbol to consume its operand #"<< i; |
1512 | diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; |
1513 | return diag; |
1514 | } |
1515 | |
1516 | if (implementSameTransformInterface(operand, argument)) |
1517 | continue; |
1518 | |
1519 | InFlightDiagnostic diag = |
1520 | emitError() |
1521 | << "mismatching type interfaces for operand and matcher argument #" |
1522 | << i << " of matcher "<< matcher; |
1523 | diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; |
1524 | return diag; |
1525 | } |
1526 | |
1527 | // Matcher -> action forwarding. |
1528 | TypeRange matcherResults = matcherSymbol.getResultTypes(); |
1529 | TypeRange actionArguments = actionSymbol.getArgumentTypes(); |
1530 | if (matcherResults.size() != actionArguments.size()) { |
1531 | return emitError() << "mismatching number of matcher results and " |
1532 | "action arguments between " |
1533 | << matcher << " ("<< matcherResults.size() << ") and " |
1534 | << action << " ("<< actionArguments.size() << ")"; |
1535 | } |
1536 | for (auto &&[i, matcherType, actionType] : |
1537 | llvm::enumerate(matcherResults, actionArguments)) { |
1538 | if (implementSameTransformInterface(matcherType, actionType)) |
1539 | continue; |
1540 | |
1541 | return emitError() << "mismatching type interfaces for matcher result " |
1542 | "and action argument #" |
1543 | << i << "of matcher "<< matcher << " and action " |
1544 | << action; |
1545 | } |
1546 | |
1547 | // Action -> result forwarding. |
1548 | TypeRange actionResults = actionSymbol.getResultTypes(); |
1549 | auto resultTypes = TypeRange(getResultTypes()).drop_front(); |
1550 | if (actionResults.size() != resultTypes.size()) { |
1551 | InFlightDiagnostic diag = |
1552 | emitError() << "the number of action results (" |
1553 | << actionResults.size() << ") for "<< action |
1554 | << " doesn't match the number of extra op results (" |
1555 | << resultTypes.size() << ")"; |
1556 | diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; |
1557 | return diag; |
1558 | } |
1559 | for (auto &&[i, resultType, actionType] : |
1560 | llvm::enumerate(resultTypes, actionResults)) { |
1561 | if (implementSameTransformInterface(resultType, actionType)) |
1562 | continue; |
1563 | |
1564 | InFlightDiagnostic diag = |
1565 | emitError() << "mismatching type interfaces for action result #"<< i |
1566 | << " of action "<< action << " and op result"; |
1567 | diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; |
1568 | return diag; |
1569 | } |
1570 | } |
1571 | return success(); |
1572 | } |
1573 | |
1574 | //===----------------------------------------------------------------------===// |
1575 | // ForeachOp |
1576 | //===----------------------------------------------------------------------===// |
1577 | |
1578 | DiagnosedSilenceableFailure |
1579 | transform::ForeachOp::apply(transform::TransformRewriter &rewriter, |
1580 | transform::TransformResults &results, |
1581 | transform::TransformState &state) { |
1582 | // We store the payloads before executing the body as ops may be removed from |
1583 | // the mapping by the TrackingRewriter while iteration is in progress. |
1584 | SmallVector<SmallVector<MappedValue>> payloads; |
1585 | detail::prepareValueMappings(payloads, getTargets(), state); |
1586 | size_t numIterations = payloads.empty() ? 0 : payloads.front().size(); |
1587 | bool withZipShortest = getWithZipShortest(); |
1588 | |
1589 | // In case of `zip_shortest`, set the number of iterations to the |
1590 | // smallest payload in the targets. |
1591 | if (withZipShortest) { |
1592 | numIterations = |
1593 | llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A, |
1594 | const SmallVector<MappedValue> &B) { |
1595 | return A.size() < B.size(); |
1596 | })->size(); |
1597 | |
1598 | for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++) |
1599 | payloads[argIdx].resize(numIterations); |
1600 | } |
1601 | |
1602 | // As we will be "zipping" over them, check all payloads have the same size. |
1603 | // `zip_shortest` adjusts all payloads to the same size, so skip this check |
1604 | // when true. |
1605 | for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size(); |
1606 | argIdx++) { |
1607 | if (payloads[argIdx].size() != numIterations) { |
1608 | return emitSilenceableError() |
1609 | << "prior targets' payload size ("<< numIterations |
1610 | << ") differs from payload size ("<< payloads[argIdx].size() |
1611 | << ") of target "<< getTargets()[argIdx]; |
1612 | } |
1613 | } |
1614 | |
1615 | // Start iterating, indexing into payloads to obtain the right arguments to |
1616 | // call the body with - each slice of payloads at the same argument index |
1617 | // corresponding to a tuple to use as the body's block arguments. |
1618 | ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments(); |
1619 | SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {}); |
1620 | for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) { |
1621 | auto scope = state.make_region_scope(getBody()); |
1622 | // Set up arguments to the region's block. |
1623 | for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) { |
1624 | MappedValue argument = payloads[argIdx][iterIdx]; |
1625 | // Note that each blockArg's handle gets associated with just a single |
1626 | // element from the corresponding target's payload. |
1627 | if (failed(state.mapBlockArgument(blockArg, {argument}))) |
1628 | return DiagnosedSilenceableFailure::definiteFailure(); |
1629 | } |
1630 | |
1631 | // Execute loop body. |
1632 | for (Operation &transform : getBody().front().without_terminator()) { |
1633 | DiagnosedSilenceableFailure result = state.applyTransform( |
1634 | llvm::cast<transform::TransformOpInterface>(transform)); |
1635 | if (!result.succeeded()) |
1636 | return result; |
1637 | } |
1638 | |
1639 | // Append yielded payloads to corresponding results from prior iterations. |
1640 | OperandRange yieldOperands = getYieldOp().getOperands(); |
1641 | for (auto &&[result, yieldOperand, resTuple] : |
1642 | llvm::zip_equal(getResults(), yieldOperands, zippedResults)) |
1643 | // NB: each iteration we add any number of ops/vals/params to a result. |
1644 | if (isa<TransformHandleTypeInterface>(result.getType())) |
1645 | llvm::append_range(resTuple, state.getPayloadOps(yieldOperand)); |
1646 | else if (isa<TransformValueHandleTypeInterface>(result.getType())) |
1647 | llvm::append_range(resTuple, state.getPayloadValues(yieldOperand)); |
1648 | else if (isa<TransformParamTypeInterface>(result.getType())) |
1649 | llvm::append_range(resTuple, state.getParams(yieldOperand)); |
1650 | else |
1651 | assert(false && "unhandled handle type"); |
1652 | } |
1653 | |
1654 | // Associate the accumulated result payloads to the op's actual results. |
1655 | for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults)) |
1656 | results.setMappedValues(llvm::cast<OpResult>(result), resPayload); |
1657 | |
1658 | return DiagnosedSilenceableFailure::success(); |
1659 | } |
1660 | |
1661 | void transform::ForeachOp::getEffects( |
1662 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
1663 | // NB: this `zip` should be `zip_equal` - while this op's verifier catches |
1664 | // arity errors, this method might get called before/in absence of `verify()`. |
1665 | for (auto &&[target, blockArg] : |
1666 | llvm::zip(getTargetsMutable(), getBody().front().getArguments())) { |
1667 | BlockArgument blockArgument = blockArg; |
1668 | if (any_of(getBody().front().without_terminator(), [&](Operation &op) { |
1669 | return isHandleConsumed(blockArgument, |
1670 | cast<TransformOpInterface>(&op)); |
1671 | })) { |
1672 | consumesHandle(target, effects); |
1673 | } else { |
1674 | onlyReadsHandle(target, effects); |
1675 | } |
1676 | } |
1677 | |
1678 | if (any_of(getBody().front().without_terminator(), [&](Operation &op) { |
1679 | return doesModifyPayload(cast<TransformOpInterface>(&op)); |
1680 | })) { |
1681 | modifiesPayload(effects); |
1682 | } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) { |
1683 | return doesReadPayload(cast<TransformOpInterface>(&op)); |
1684 | })) { |
1685 | onlyReadsPayload(effects); |
1686 | } |
1687 | |
1688 | producesHandle(getOperation()->getOpResults(), effects); |
1689 | } |
1690 | |
1691 | void transform::ForeachOp::getSuccessorRegions( |
1692 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
1693 | Region *bodyRegion = &getBody(); |
1694 | if (point.isParent()) { |
1695 | regions.emplace_back(bodyRegion, bodyRegion->getArguments()); |
1696 | return; |
1697 | } |
1698 | |
1699 | // Branch back to the region or the parent. |
1700 | assert(point == getBody() && "unexpected region index"); |
1701 | regions.emplace_back(bodyRegion, bodyRegion->getArguments()); |
1702 | regions.emplace_back(); |
1703 | } |
1704 | |
1705 | OperandRange |
1706 | transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
1707 | // Each block argument handle is mapped to a subset (one op to be precise) |
1708 | // of the payload of the corresponding `targets` operand of ForeachOp. |
1709 | assert(point == getBody() && "unexpected region index"); |
1710 | return getOperation()->getOperands(); |
1711 | } |
1712 | |
1713 | transform::YieldOp transform::ForeachOp::getYieldOp() { |
1714 | return cast<transform::YieldOp>(getBody().front().getTerminator()); |
1715 | } |
1716 | |
1717 | LogicalResult transform::ForeachOp::verify() { |
1718 | for (auto [targetOpt, bodyArgOpt] : |
1719 | llvm::zip_longest(getTargets(), getBody().front().getArguments())) { |
1720 | if (!targetOpt || !bodyArgOpt) |
1721 | return emitOpError() << "expects the same number of targets as the body " |
1722 | "has block arguments"; |
1723 | if (targetOpt.value().getType() != bodyArgOpt.value().getType()) |
1724 | return emitOpError( |
1725 | "expects co-indexed targets and the body's " |
1726 | "block arguments to have the same op/value/param type"); |
1727 | } |
1728 | |
1729 | for (auto [resultOpt, yieldOperandOpt] : |
1730 | llvm::zip_longest(getResults(), getYieldOp().getOperands())) { |
1731 | if (!resultOpt || !yieldOperandOpt) |
1732 | return emitOpError() << "expects the same number of results as the " |
1733 | "yield terminator has operands"; |
1734 | if (resultOpt.value().getType() != yieldOperandOpt.value().getType()) |
1735 | return emitOpError("expects co-indexed results and yield " |
1736 | "operands to have the same op/value/param type"); |
1737 | } |
1738 | |
1739 | return success(); |
1740 | } |
1741 | |
1742 | //===----------------------------------------------------------------------===// |
1743 | // GetParentOp |
1744 | //===----------------------------------------------------------------------===// |
1745 | |
1746 | DiagnosedSilenceableFailure |
1747 | transform::GetParentOp::apply(transform::TransformRewriter &rewriter, |
1748 | transform::TransformResults &results, |
1749 | transform::TransformState &state) { |
1750 | SmallVector<Operation *> parents; |
1751 | DenseSet<Operation *> resultSet; |
1752 | for (Operation *target : state.getPayloadOps(getTarget())) { |
1753 | Operation *parent = target; |
1754 | for (int64_t i = 0, e = getNthParent(); i < e; ++i) { |
1755 | parent = parent->getParentOp(); |
1756 | while (parent) { |
1757 | bool checkIsolatedFromAbove = |
1758 | !getIsolatedFromAbove() || |
1759 | parent->hasTrait<OpTrait::IsIsolatedFromAbove>(); |
1760 | bool checkOpName = !getOpName().has_value() || |
1761 | parent->getName().getStringRef() == *getOpName(); |
1762 | if (checkIsolatedFromAbove && checkOpName) |
1763 | break; |
1764 | parent = parent->getParentOp(); |
1765 | } |
1766 | if (!parent) { |
1767 | if (getAllowEmptyResults()) { |
1768 | results.set(llvm::cast<OpResult>(getResult()), parents); |
1769 | return DiagnosedSilenceableFailure::success(); |
1770 | } |
1771 | DiagnosedSilenceableFailure diag = |
1772 | emitSilenceableError() |
1773 | << "could not find a parent op that matches all requirements"; |
1774 | diag.attachNote(target->getLoc()) << "target op"; |
1775 | return diag; |
1776 | } |
1777 | } |
1778 | if (getDeduplicate()) { |
1779 | if (resultSet.insert(parent).second) |
1780 | parents.push_back(parent); |
1781 | } else { |
1782 | parents.push_back(parent); |
1783 | } |
1784 | } |
1785 | results.set(llvm::cast<OpResult>(getResult()), parents); |
1786 | return DiagnosedSilenceableFailure::success(); |
1787 | } |
1788 | |
1789 | //===----------------------------------------------------------------------===// |
1790 | // GetConsumersOfResult |
1791 | //===----------------------------------------------------------------------===// |
1792 | |
1793 | DiagnosedSilenceableFailure |
1794 | transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, |
1795 | transform::TransformResults &results, |
1796 | transform::TransformState &state) { |
1797 | int64_t resultNumber = getResultNumber(); |
1798 | auto payloadOps = state.getPayloadOps(getTarget()); |
1799 | if (std::empty(payloadOps)) { |
1800 | results.set(cast<OpResult>(getResult()), {}); |
1801 | return DiagnosedSilenceableFailure::success(); |
1802 | } |
1803 | if (!llvm::hasSingleElement(payloadOps)) |
1804 | return emitDefiniteFailure() |
1805 | << "handle must be mapped to exactly one payload op"; |
1806 | |
1807 | Operation *target = *payloadOps.begin(); |
1808 | if (target->getNumResults() <= resultNumber) |
1809 | return emitDefiniteFailure() << "result number overflow"; |
1810 | results.set(llvm::cast<OpResult>(getResult()), |
1811 | llvm::to_vector(target->getResult(resultNumber).getUsers())); |
1812 | return DiagnosedSilenceableFailure::success(); |
1813 | } |
1814 | |
1815 | //===----------------------------------------------------------------------===// |
1816 | // GetDefiningOp |
1817 | //===----------------------------------------------------------------------===// |
1818 | |
1819 | DiagnosedSilenceableFailure |
1820 | transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, |
1821 | transform::TransformResults &results, |
1822 | transform::TransformState &state) { |
1823 | SmallVector<Operation *> definingOps; |
1824 | for (Value v : state.getPayloadValues(getTarget())) { |
1825 | if (llvm::isa<BlockArgument>(v)) { |
1826 | DiagnosedSilenceableFailure diag = |
1827 | emitSilenceableError() << "cannot get defining op of block argument"; |
1828 | diag.attachNote(v.getLoc()) << "target value"; |
1829 | return diag; |
1830 | } |
1831 | definingOps.push_back(v.getDefiningOp()); |
1832 | } |
1833 | results.set(llvm::cast<OpResult>(getResult()), definingOps); |
1834 | return DiagnosedSilenceableFailure::success(); |
1835 | } |
1836 | |
1837 | //===----------------------------------------------------------------------===// |
1838 | // GetProducerOfOperand |
1839 | //===----------------------------------------------------------------------===// |
1840 | |
1841 | DiagnosedSilenceableFailure |
1842 | transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, |
1843 | transform::TransformResults &results, |
1844 | transform::TransformState &state) { |
1845 | int64_t operandNumber = getOperandNumber(); |
1846 | SmallVector<Operation *> producers; |
1847 | for (Operation *target : state.getPayloadOps(getTarget())) { |
1848 | Operation *producer = |
1849 | target->getNumOperands() <= operandNumber |
1850 | ? nullptr |
1851 | : target->getOperand(operandNumber).getDefiningOp(); |
1852 | if (!producer) { |
1853 | DiagnosedSilenceableFailure diag = |
1854 | emitSilenceableError() |
1855 | << "could not find a producer for operand number: "<< operandNumber |
1856 | << " of "<< *target; |
1857 | diag.attachNote(target->getLoc()) << "target op"; |
1858 | return diag; |
1859 | } |
1860 | producers.push_back(producer); |
1861 | } |
1862 | results.set(llvm::cast<OpResult>(getResult()), producers); |
1863 | return DiagnosedSilenceableFailure::success(); |
1864 | } |
1865 | |
1866 | //===----------------------------------------------------------------------===// |
1867 | // GetOperandOp |
1868 | //===----------------------------------------------------------------------===// |
1869 | |
1870 | DiagnosedSilenceableFailure |
1871 | transform::GetOperandOp::apply(transform::TransformRewriter &rewriter, |
1872 | transform::TransformResults &results, |
1873 | transform::TransformState &state) { |
1874 | SmallVector<Value> operands; |
1875 | for (Operation *target : state.getPayloadOps(getTarget())) { |
1876 | SmallVector<int64_t> operandPositions; |
1877 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
1878 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
1879 | target->getNumOperands(), operandPositions); |
1880 | if (diag.isSilenceableFailure()) { |
1881 | diag.attachNote(target->getLoc()) |
1882 | << "while considering positions of this payload operation"; |
1883 | return diag; |
1884 | } |
1885 | llvm::append_range(operands, |
1886 | llvm::map_range(operandPositions, [&](int64_t pos) { |
1887 | return target->getOperand(pos); |
1888 | })); |
1889 | } |
1890 | results.setValues(cast<OpResult>(getResult()), operands); |
1891 | return DiagnosedSilenceableFailure::success(); |
1892 | } |
1893 | |
1894 | LogicalResult transform::GetOperandOp::verify() { |
1895 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
1896 | getIsInverted(), getIsAll()); |
1897 | } |
1898 | |
1899 | //===----------------------------------------------------------------------===// |
1900 | // GetResultOp |
1901 | //===----------------------------------------------------------------------===// |
1902 | |
1903 | DiagnosedSilenceableFailure |
1904 | transform::GetResultOp::apply(transform::TransformRewriter &rewriter, |
1905 | transform::TransformResults &results, |
1906 | transform::TransformState &state) { |
1907 | SmallVector<Value> opResults; |
1908 | for (Operation *target : state.getPayloadOps(getTarget())) { |
1909 | SmallVector<int64_t> resultPositions; |
1910 | DiagnosedSilenceableFailure diag = expandTargetSpecification( |
1911 | getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), |
1912 | target->getNumResults(), resultPositions); |
1913 | if (diag.isSilenceableFailure()) { |
1914 | diag.attachNote(target->getLoc()) |
1915 | << "while considering positions of this payload operation"; |
1916 | return diag; |
1917 | } |
1918 | llvm::append_range(opResults, |
1919 | llvm::map_range(resultPositions, [&](int64_t pos) { |
1920 | return target->getResult(pos); |
1921 | })); |
1922 | } |
1923 | results.setValues(cast<OpResult>(getResult()), opResults); |
1924 | return DiagnosedSilenceableFailure::success(); |
1925 | } |
1926 | |
1927 | LogicalResult transform::GetResultOp::verify() { |
1928 | return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), |
1929 | getIsInverted(), getIsAll()); |
1930 | } |
1931 | |
1932 | //===----------------------------------------------------------------------===// |
1933 | // GetTypeOp |
1934 | //===----------------------------------------------------------------------===// |
1935 | |
1936 | void transform::GetTypeOp::getEffects( |
1937 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
1938 | onlyReadsHandle(getValueMutable(), effects); |
1939 | producesHandle(getOperation()->getOpResults(), effects); |
1940 | onlyReadsPayload(effects); |
1941 | } |
1942 | |
1943 | DiagnosedSilenceableFailure |
1944 | transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, |
1945 | transform::TransformResults &results, |
1946 | transform::TransformState &state) { |
1947 | SmallVector<Attribute> params; |
1948 | for (Value value : state.getPayloadValues(getValue())) { |
1949 | Type type = value.getType(); |
1950 | if (getElemental()) { |
1951 | if (auto shaped = dyn_cast<ShapedType>(type)) { |
1952 | type = shaped.getElementType(); |
1953 | } |
1954 | } |
1955 | params.push_back(TypeAttr::get(type)); |
1956 | } |
1957 | results.setParams(cast<OpResult>(getResult()), params); |
1958 | return DiagnosedSilenceableFailure::success(); |
1959 | } |
1960 | |
1961 | //===----------------------------------------------------------------------===// |
1962 | // IncludeOp |
1963 | //===----------------------------------------------------------------------===// |
1964 | |
1965 | /// Applies the transform ops contained in `block`. Maps `results` to the same |
1966 | /// values as the operands of the block terminator. |
1967 | static DiagnosedSilenceableFailure |
1968 | applySequenceBlock(Block &block, transform::FailurePropagationMode mode, |
1969 | transform::TransformState &state, |
1970 | transform::TransformResults &results) { |
1971 | // Apply the sequenced ops one by one. |
1972 | for (Operation &transform : block.without_terminator()) { |
1973 | DiagnosedSilenceableFailure result = |
1974 | state.applyTransform(transform: cast<transform::TransformOpInterface>(transform)); |
1975 | if (result.isDefiniteFailure()) |
1976 | return result; |
1977 | |
1978 | if (result.isSilenceableFailure()) { |
1979 | if (mode == transform::FailurePropagationMode::Propagate) { |
1980 | // Propagate empty results in case of early exit. |
1981 | forwardEmptyOperands(block: &block, state, results); |
1982 | return result; |
1983 | } |
1984 | (void)result.silence(); |
1985 | } |
1986 | } |
1987 | |
1988 | // Forward the operation mapping for values yielded from the sequence to the |
1989 | // values produced by the sequence op. |
1990 | transform::detail::forwardTerminatorOperands(block: &block, state, results); |
1991 | return DiagnosedSilenceableFailure::success(); |
1992 | } |
1993 | |
1994 | DiagnosedSilenceableFailure |
1995 | transform::IncludeOp::apply(transform::TransformRewriter &rewriter, |
1996 | transform::TransformResults &results, |
1997 | transform::TransformState &state) { |
1998 | auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( |
1999 | getOperation(), getTarget()); |
2000 | assert(callee && "unverified reference to unknown symbol"); |
2001 | |
2002 | if (callee.isExternal()) |
2003 | return emitDefiniteFailure() << "unresolved external named sequence"; |
2004 | |
2005 | // Map operands to block arguments. |
2006 | SmallVector<SmallVector<MappedValue>> mappings; |
2007 | detail::prepareValueMappings(mappings, getOperands(), state); |
2008 | auto scope = state.make_region_scope(callee.getBody()); |
2009 | for (auto &&[arg, map] : |
2010 | llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) { |
2011 | if (failed(state.mapBlockArgument(arg, map))) |
2012 | return DiagnosedSilenceableFailure::definiteFailure(); |
2013 | } |
2014 | |
2015 | DiagnosedSilenceableFailure result = applySequenceBlock( |
2016 | callee.getBody().front(), getFailurePropagationMode(), state, results); |
2017 | mappings.clear(); |
2018 | detail::prepareValueMappings( |
2019 | mappings, callee.getBody().front().getTerminator()->getOperands(), state); |
2020 | for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings)) |
2021 | results.setMappedValues(result, mapping); |
2022 | return result; |
2023 | } |
2024 | |
2025 | static DiagnosedSilenceableFailure |
2026 | verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings); |
2027 | |
2028 | void transform::IncludeOp::getEffects( |
2029 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
2030 | // Always mark as modifying the payload. |
2031 | // TODO: a mechanism to annotate effects on payload. Even when all handles are |
2032 | // only read, the payload may still be modified, so we currently stay on the |
2033 | // conservative side and always indicate modification. This may prevent some |
2034 | // code reordering. |
2035 | modifiesPayload(effects); |
2036 | |
2037 | // Results are always produced. |
2038 | producesHandle(getOperation()->getOpResults(), effects); |
2039 | |
2040 | // Adds default effects to operands and results. This will be added if |
2041 | // preconditions fail so the trait verifier doesn't complain about missing |
2042 | // effects and the real precondition failure is reported later on. |
2043 | auto defaultEffects = [&] { |
2044 | onlyReadsHandle(getOperation()->getOpOperands(), effects); |
2045 | }; |
2046 | |
2047 | // Bail if the callee is unknown. This may run as part of the verification |
2048 | // process before we verified the validity of the callee or of this op. |
2049 | auto target = |
2050 | getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName()); |
2051 | if (!target) |
2052 | return defaultEffects(); |
2053 | auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>( |
2054 | getOperation(), getTarget()); |
2055 | if (!callee) |
2056 | return defaultEffects(); |
2057 | DiagnosedSilenceableFailure earlyVerifierResult = |
2058 | verifyNamedSequenceOp(callee, /*emitWarnings=*/false); |
2059 | if (!earlyVerifierResult.succeeded()) { |
2060 | (void)earlyVerifierResult.silence(); |
2061 | return defaultEffects(); |
2062 | } |
2063 | |
2064 | for (unsigned i = 0, e = getNumOperands(); i < e; ++i) { |
2065 | if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName)) |
2066 | consumesHandle(getOperation()->getOpOperand(i), effects); |
2067 | else |
2068 | onlyReadsHandle(getOperation()->getOpOperand(i), effects); |
2069 | } |
2070 | } |
2071 | |
2072 | LogicalResult |
2073 | transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
2074 | // Access through indirection and do additional checking because this may be |
2075 | // running before the main op verifier. |
2076 | auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target"); |
2077 | if (!targetAttr) |
2078 | return emitOpError() << "expects a 'target' symbol reference attribute"; |
2079 | |
2080 | auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>( |
2081 | *this, targetAttr); |
2082 | if (!target) |
2083 | return emitOpError() << "does not reference a named transform sequence"; |
2084 | |
2085 | FunctionType fnType = target.getFunctionType(); |
2086 | if (fnType.getNumInputs() != getNumOperands()) |
2087 | return emitError("incorrect number of operands for callee"); |
2088 | |
2089 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { |
2090 | if (getOperand(i).getType() != fnType.getInput(i)) { |
2091 | return emitOpError("operand type mismatch: expected operand type ") |
2092 | << fnType.getInput(i) << ", but provided " |
2093 | << getOperand(i).getType() << " for operand number "<< i; |
2094 | } |
2095 | } |
2096 | |
2097 | if (fnType.getNumResults() != getNumResults()) |
2098 | return emitError("incorrect number of results for callee"); |
2099 | |
2100 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { |
2101 | Type resultType = getResult(i).getType(); |
2102 | Type funcType = fnType.getResult(i); |
2103 | if (!implementSameTransformInterface(resultType, funcType)) { |
2104 | return emitOpError() << "type of result #"<< i |
2105 | << " must implement the same transform dialect " |
2106 | "interface as the corresponding callee result"; |
2107 | } |
2108 | } |
2109 | |
2110 | return verifyFunctionLikeConsumeAnnotations( |
2111 | cast<FunctionOpInterface>(*target), /*emitWarnings=*/false, |
2112 | /*alsoVerifyInternal=*/true) |
2113 | .checkAndReport(); |
2114 | } |
2115 | |
2116 | //===----------------------------------------------------------------------===// |
2117 | // MatchOperationEmptyOp |
2118 | //===----------------------------------------------------------------------===// |
2119 | |
2120 | DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( |
2121 | ::std::optional<::mlir::Operation *> maybeCurrent, |
2122 | transform::TransformResults &results, transform::TransformState &state) { |
2123 | if (!maybeCurrent.has_value()) { |
2124 | DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; }); |
2125 | return DiagnosedSilenceableFailure::success(); |
2126 | } |
2127 | DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; }); |
2128 | return emitSilenceableError() << "operation is not empty"; |
2129 | } |
2130 | |
2131 | //===----------------------------------------------------------------------===// |
2132 | // MatchOperationNameOp |
2133 | //===----------------------------------------------------------------------===// |
2134 | |
2135 | DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation( |
2136 | Operation *current, transform::TransformResults &results, |
2137 | transform::TransformState &state) { |
2138 | StringRef currentOpName = current->getName().getStringRef(); |
2139 | for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) { |
2140 | if (acceptedAttr.getValue() == currentOpName) |
2141 | return DiagnosedSilenceableFailure::success(); |
2142 | } |
2143 | return emitSilenceableError() << "wrong operation name"; |
2144 | } |
2145 | |
2146 | //===----------------------------------------------------------------------===// |
2147 | // MatchParamCmpIOp |
2148 | //===----------------------------------------------------------------------===// |
2149 | |
2150 | DiagnosedSilenceableFailure |
2151 | transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, |
2152 | transform::TransformResults &results, |
2153 | transform::TransformState &state) { |
2154 | auto signedAPIntAsString = [&](const APInt &value) { |
2155 | std::string str; |
2156 | llvm::raw_string_ostream os(str); |
2157 | value.print(os, /*isSigned=*/true); |
2158 | return str; |
2159 | }; |
2160 | |
2161 | ArrayRef<Attribute> params = state.getParams(getParam()); |
2162 | ArrayRef<Attribute> references = state.getParams(getReference()); |
2163 | |
2164 | if (params.size() != references.size()) { |
2165 | return emitSilenceableError() |
2166 | << "parameters have different payload lengths ("<< params.size() |
2167 | << " vs "<< references.size() << ")"; |
2168 | } |
2169 | |
2170 | for (auto &&[i, param, reference] : llvm::enumerate(params, references)) { |
2171 | auto intAttr = llvm::dyn_cast<IntegerAttr>(param); |
2172 | auto refAttr = llvm::dyn_cast<IntegerAttr>(reference); |
2173 | if (!intAttr || !refAttr) { |
2174 | return emitDefiniteFailure() |
2175 | << "non-integer parameter value not expected"; |
2176 | } |
2177 | if (intAttr.getType() != refAttr.getType()) { |
2178 | return emitDefiniteFailure() |
2179 | << "mismatching integer attribute types in parameter #"<< i; |
2180 | } |
2181 | APInt value = intAttr.getValue(); |
2182 | APInt refValue = refAttr.getValue(); |
2183 | |
2184 | // TODO: this copy will not be necessary in C++20. |
2185 | int64_t position = i; |
2186 | auto reportError = [&](StringRef direction) { |
2187 | DiagnosedSilenceableFailure diag = |
2188 | emitSilenceableError() << "expected parameter to be "<< direction |
2189 | << " "<< signedAPIntAsString(refValue) |
2190 | << ", got "<< signedAPIntAsString(value); |
2191 | diag.attachNote(getParam().getLoc()) |
2192 | << "value # "<< position |
2193 | << " associated with the parameter defined here"; |
2194 | return diag; |
2195 | }; |
2196 | |
2197 | switch (getPredicate()) { |
2198 | case MatchCmpIPredicate::eq: |
2199 | if (value.eq(refValue)) |
2200 | break; |
2201 | return reportError("equal to"); |
2202 | case MatchCmpIPredicate::ne: |
2203 | if (value.ne(refValue)) |
2204 | break; |
2205 | return reportError("not equal to"); |
2206 | case MatchCmpIPredicate::lt: |
2207 | if (value.slt(refValue)) |
2208 | break; |
2209 | return reportError("less than"); |
2210 | case MatchCmpIPredicate::le: |
2211 | if (value.sle(refValue)) |
2212 | break; |
2213 | return reportError("less than or equal to"); |
2214 | case MatchCmpIPredicate::gt: |
2215 | if (value.sgt(refValue)) |
2216 | break; |
2217 | return reportError("greater than"); |
2218 | case MatchCmpIPredicate::ge: |
2219 | if (value.sge(refValue)) |
2220 | break; |
2221 | return reportError("greater than or equal to"); |
2222 | } |
2223 | } |
2224 | return DiagnosedSilenceableFailure::success(); |
2225 | } |
2226 | |
2227 | void transform::MatchParamCmpIOp::getEffects( |
2228 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
2229 | onlyReadsHandle(getParamMutable(), effects); |
2230 | onlyReadsHandle(getReferenceMutable(), effects); |
2231 | } |
2232 | |
2233 | //===----------------------------------------------------------------------===// |
2234 | // ParamConstantOp |
2235 | //===----------------------------------------------------------------------===// |
2236 | |
2237 | DiagnosedSilenceableFailure |
2238 | transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, |
2239 | transform::TransformResults &results, |
2240 | transform::TransformState &state) { |
2241 | results.setParams(cast<OpResult>(getParam()), {getValue()}); |
2242 | return DiagnosedSilenceableFailure::success(); |
2243 | } |
2244 | |
2245 | //===----------------------------------------------------------------------===// |
2246 | // MergeHandlesOp |
2247 | //===----------------------------------------------------------------------===// |
2248 | |
2249 | DiagnosedSilenceableFailure |
2250 | transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, |
2251 | transform::TransformResults &results, |
2252 | transform::TransformState &state) { |
2253 | ValueRange handles = getHandles(); |
2254 | if (isa<TransformHandleTypeInterface>(handles.front().getType())) { |
2255 | SmallVector<Operation *> operations; |
2256 | for (Value operand : handles) |
2257 | llvm::append_range(operations, state.getPayloadOps(operand)); |
2258 | if (!getDeduplicate()) { |
2259 | results.set(llvm::cast<OpResult>(getResult()), operations); |
2260 | return DiagnosedSilenceableFailure::success(); |
2261 | } |
2262 | |
2263 | SetVector<Operation *> uniqued(llvm::from_range, operations); |
2264 | results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef()); |
2265 | return DiagnosedSilenceableFailure::success(); |
2266 | } |
2267 | |
2268 | if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) { |
2269 | SmallVector<Attribute> attrs; |
2270 | for (Value attribute : handles) |
2271 | llvm::append_range(attrs, state.getParams(attribute)); |
2272 | if (!getDeduplicate()) { |
2273 | results.setParams(cast<OpResult>(getResult()), attrs); |
2274 | return DiagnosedSilenceableFailure::success(); |
2275 | } |
2276 | |
2277 | SetVector<Attribute> uniqued(llvm::from_range, attrs); |
2278 | results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef()); |
2279 | return DiagnosedSilenceableFailure::success(); |
2280 | } |
2281 | |
2282 | assert( |
2283 | llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) && |
2284 | "expected value handle type"); |
2285 | SmallVector<Value> payloadValues; |
2286 | for (Value value : handles) |
2287 | llvm::append_range(payloadValues, state.getPayloadValues(value)); |
2288 | if (!getDeduplicate()) { |
2289 | results.setValues(cast<OpResult>(getResult()), payloadValues); |
2290 | return DiagnosedSilenceableFailure::success(); |
2291 | } |
2292 | |
2293 | SetVector<Value> uniqued(llvm::from_range, payloadValues); |
2294 | results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef()); |
2295 | return DiagnosedSilenceableFailure::success(); |
2296 | } |
2297 | |
2298 | bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() { |
2299 | // Handles may be the same if deduplicating is enabled. |
2300 | return getDeduplicate(); |
2301 | } |
2302 | |
2303 | void transform::MergeHandlesOp::getEffects( |
2304 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
2305 | onlyReadsHandle(getHandlesMutable(), effects); |
2306 | producesHandle(getOperation()->getOpResults(), effects); |
2307 | |
2308 | // There are no effects on the Payload IR as this is only a handle |
2309 | // manipulation. |
2310 | } |
2311 | |
2312 | OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { |
2313 | if (getDeduplicate() || getHandles().size() != 1) |
2314 | return {}; |
2315 | |
2316 | // If deduplication is not required and there is only one operand, it can be |
2317 | // used directly instead of merging. |
2318 | return getHandles().front(); |
2319 | } |
2320 | |
2321 | //===----------------------------------------------------------------------===// |
2322 | // NamedSequenceOp |
2323 | //===----------------------------------------------------------------------===// |
2324 | |
2325 | DiagnosedSilenceableFailure |
2326 | transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, |
2327 | transform::TransformResults &results, |
2328 | transform::TransformState &state) { |
2329 | if (isExternal()) |
2330 | return emitDefiniteFailure() << "unresolved external named sequence"; |
2331 | |
2332 | // Map the entry block argument to the list of operations. |
2333 | // Note: this is the same implementation as PossibleTopLevelTransformOp but |
2334 | // without attaching the interface / trait since that is tailored to a |
2335 | // dangling top-level op that does not get "called". |
2336 | auto scope = state.make_region_scope(getBody()); |
2337 | if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments( |
2338 | state, this->getOperation(), getBody()))) |
2339 | return DiagnosedSilenceableFailure::definiteFailure(); |
2340 | |
2341 | return applySequenceBlock(getBody().front(), |
2342 | FailurePropagationMode::Propagate, state, results); |
2343 | } |
2344 | |
2345 | void transform::NamedSequenceOp::getEffects( |
2346 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} |
2347 | |
2348 | ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser, |
2349 | OperationState &result) { |
2350 | return function_interface_impl::parseFunctionOp( |
2351 | parser, result, /*allowVariadic=*/false, |
2352 | getFunctionTypeAttrName(result.name), |
2353 | [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results, |
2354 | function_interface_impl::VariadicFlag, |
2355 | std::string &) { return builder.getFunctionType(inputs, results); }, |
2356 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
2357 | } |
2358 | |
2359 | void transform::NamedSequenceOp::print(OpAsmPrinter &printer) { |
2360 | function_interface_impl::printFunctionOp( |
2361 | printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false, |
2362 | getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(), |
2363 | getResAttrsAttrName()); |
2364 | } |
2365 | |
2366 | /// Verifies that a symbol function-like transform dialect operation has the |
2367 | /// signature and the terminator that have conforming types, i.e., types |
2368 | /// implementing the same transform dialect type interface. If `allowExternal` |
2369 | /// is set, allow external symbols (declarations) and don't check the terminator |
2370 | /// as it may not exist. |
2371 | static DiagnosedSilenceableFailure |
2372 | verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) { |
2373 | if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { |
2374 | DiagnosedSilenceableFailure diag = |
2375 | emitSilenceableFailure(op) |
2376 | << "cannot be defined inside another transform op"; |
2377 | diag.attachNote(loc: parent.getLoc()) << "ancestor transform op"; |
2378 | return diag; |
2379 | } |
2380 | |
2381 | if (op.isExternal() || op.getFunctionBody().empty()) { |
2382 | if (allowExternal) |
2383 | return DiagnosedSilenceableFailure::success(); |
2384 | |
2385 | return emitSilenceableFailure(op) << "cannot be external"; |
2386 | } |
2387 | |
2388 | if (op.getFunctionBody().front().empty()) |
2389 | return emitSilenceableFailure(op) << "expected a non-empty body block"; |
2390 | |
2391 | Operation *terminator = &op.getFunctionBody().front().back(); |
2392 | if (!isa<transform::YieldOp>(terminator)) { |
2393 | DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) |
2394 | << "expected '" |
2395 | << transform::YieldOp::getOperationName() |
2396 | << "' as terminator"; |
2397 | diag.attachNote(loc: terminator->getLoc()) << "terminator"; |
2398 | return diag; |
2399 | } |
2400 | |
2401 | if (terminator->getNumOperands() != op.getResultTypes().size()) { |
2402 | return emitSilenceableFailure(op: terminator) |
2403 | << "expected terminator to have as many operands as the parent op " |
2404 | "has results"; |
2405 | } |
2406 | for (auto [i, operandType, resultType] : llvm::zip_equal( |
2407 | llvm::seq<unsigned>(0, terminator->getNumOperands()), |
2408 | terminator->getOperands().getType(), op.getResultTypes())) { |
2409 | if (operandType == resultType) |
2410 | continue; |
2411 | return emitSilenceableFailure(terminator) |
2412 | << "the type of the terminator operand #"<< i |
2413 | << " must match the type of the corresponding parent op result (" |
2414 | << operandType << " vs "<< resultType << ")"; |
2415 | } |
2416 | |
2417 | return DiagnosedSilenceableFailure::success(); |
2418 | } |
2419 | |
2420 | /// Verification of a NamedSequenceOp. This does not report the error |
2421 | /// immediately, so it can be used to check for op's well-formedness before the |
2422 | /// verifier runs, e.g., during trait verification. |
2423 | static DiagnosedSilenceableFailure |
2424 | verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) { |
2425 | if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) { |
2426 | if (!parent->getAttr( |
2427 | transform::TransformDialect::kWithNamedSequenceAttrName)) { |
2428 | DiagnosedSilenceableFailure diag = |
2429 | emitSilenceableFailure(op) |
2430 | << "expects the parent symbol table to have the '" |
2431 | << transform::TransformDialect::kWithNamedSequenceAttrName |
2432 | << "' attribute"; |
2433 | diag.attachNote(loc: parent->getLoc()) << "symbol table operation"; |
2434 | return diag; |
2435 | } |
2436 | } |
2437 | |
2438 | if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) { |
2439 | DiagnosedSilenceableFailure diag = |
2440 | emitSilenceableFailure(op) |
2441 | << "cannot be defined inside another transform op"; |
2442 | diag.attachNote(loc: parent.getLoc()) << "ancestor transform op"; |
2443 | return diag; |
2444 | } |
2445 | |
2446 | if (op.isExternal() || op.getBody().empty()) |
2447 | return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op), |
2448 | emitWarnings); |
2449 | |
2450 | if (op.getBody().front().empty()) |
2451 | return emitSilenceableFailure(op) << "expected a non-empty body block"; |
2452 | |
2453 | Operation *terminator = &op.getBody().front().back(); |
2454 | if (!isa<transform::YieldOp>(terminator)) { |
2455 | DiagnosedSilenceableFailure diag = emitSilenceableFailure(op) |
2456 | << "expected '" |
2457 | << transform::YieldOp::getOperationName() |
2458 | << "' as terminator"; |
2459 | diag.attachNote(loc: terminator->getLoc()) << "terminator"; |
2460 | return diag; |
2461 | } |
2462 | |
2463 | if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) { |
2464 | return emitSilenceableFailure(op: terminator) |
2465 | << "expected terminator to have as many operands as the parent op " |
2466 | "has results"; |
2467 | } |
2468 | for (auto [i, operandType, resultType] : |
2469 | llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()), |
2470 | terminator->getOperands().getType(), |
2471 | op.getFunctionType().getResults())) { |
2472 | if (operandType == resultType) |
2473 | continue; |
2474 | return emitSilenceableFailure(terminator) |
2475 | << "the type of the terminator operand #"<< i |
2476 | << " must match the type of the corresponding parent op result (" |
2477 | << operandType << " vs "<< resultType << ")"; |
2478 | } |
2479 | |
2480 | auto funcOp = cast<FunctionOpInterface>(*op); |
2481 | DiagnosedSilenceableFailure diag = |
2482 | verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings); |
2483 | if (!diag.succeeded()) |
2484 | return diag; |
2485 | |
2486 | return verifyYieldingSingleBlockOp(funcOp, |
2487 | /*allowExternal=*/true); |
2488 | } |
2489 | |
2490 | LogicalResult transform::NamedSequenceOp::verify() { |
2491 | // Actual verification happens in a separate function for reusability. |
2492 | return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport(); |
2493 | } |
2494 | |
2495 | template <typename FnTy> |
2496 | static void buildSequenceBody(OpBuilder &builder, OperationState &state, |
2497 | Type bbArgType, TypeRange extraBindingTypes, |
2498 | FnTy bodyBuilder) { |
2499 | SmallVector<Type> types; |
2500 | types.reserve(N: 1 + extraBindingTypes.size()); |
2501 | types.push_back(Elt: bbArgType); |
2502 | llvm::append_range(C&: types, R&: extraBindingTypes); |
2503 | |
2504 | OpBuilder::InsertionGuard guard(builder); |
2505 | Region *region = state.regions.back().get(); |
2506 | Block *bodyBlock = |
2507 | builder.createBlock(parent: region, insertPt: region->begin(), argTypes: types, |
2508 | locs: SmallVector<Location>(types.size(), state.location)); |
2509 | |
2510 | // Populate body. |
2511 | builder.setInsertionPointToStart(bodyBlock); |
2512 | if constexpr (llvm::function_traits<FnTy>::num_args == 3) { |
2513 | bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0)); |
2514 | } else { |
2515 | bodyBuilder(builder, state.location, bodyBlock->getArgument(i: 0), |
2516 | bodyBlock->getArguments().drop_front()); |
2517 | } |
2518 | } |
2519 | |
2520 | void transform::NamedSequenceOp::build(OpBuilder &builder, |
2521 | OperationState &state, StringRef symName, |
2522 | Type rootType, TypeRange resultTypes, |
2523 | SequenceBodyBuilderFn bodyBuilder, |
2524 | ArrayRef<NamedAttribute> attrs, |
2525 | ArrayRef<DictionaryAttr> argAttrs) { |
2526 | state.addAttribute(SymbolTable::getSymbolAttrName(), |
2527 | builder.getStringAttr(symName)); |
2528 | state.addAttribute(getFunctionTypeAttrName(state.name), |
2529 | TypeAttr::get(FunctionType::get(builder.getContext(), |
2530 | rootType, resultTypes))); |
2531 | state.attributes.append(attrs.begin(), attrs.end()); |
2532 | state.addRegion(); |
2533 | |
2534 | buildSequenceBody(builder, state, rootType, |
2535 | /*extraBindingTypes=*/TypeRange(), bodyBuilder); |
2536 | } |
2537 | |
2538 | //===----------------------------------------------------------------------===// |
2539 | // NumAssociationsOp |
2540 | //===----------------------------------------------------------------------===// |
2541 | |
2542 | DiagnosedSilenceableFailure |
2543 | transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter, |
2544 | transform::TransformResults &results, |
2545 | transform::TransformState &state) { |
2546 | size_t numAssociations = |
2547 | llvm::TypeSwitch<Type, size_t>(getHandle().getType()) |
2548 | .Case([&](TransformHandleTypeInterface opHandle) { |
2549 | return llvm::range_size(state.getPayloadOps(getHandle())); |
2550 | }) |
2551 | .Case([&](TransformValueHandleTypeInterface valueHandle) { |
2552 | return llvm::range_size(state.getPayloadValues(getHandle())); |
2553 | }) |
2554 | .Case([&](TransformParamTypeInterface param) { |
2555 | return llvm::range_size(state.getParams(getHandle())); |
2556 | }) |
2557 | .Default([](Type) { |
2558 | llvm_unreachable("unknown kind of transform dialect type"); |
2559 | return 0; |
2560 | }); |
2561 | results.setParams(cast<OpResult>(getNum()), |
2562 | rewriter.getI64IntegerAttr(numAssociations)); |
2563 | return DiagnosedSilenceableFailure::success(); |
2564 | } |
2565 | |
2566 | LogicalResult transform::NumAssociationsOp::verify() { |
2567 | // Verify that the result type accepts an i64 attribute as payload. |
2568 | auto resultType = cast<TransformParamTypeInterface>(getNum().getType()); |
2569 | return resultType |
2570 | .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)}) |
2571 | .checkAndReport(); |
2572 | } |
2573 | |
2574 | //===----------------------------------------------------------------------===// |
2575 | // SelectOp |
2576 | //===----------------------------------------------------------------------===// |
2577 | |
2578 | DiagnosedSilenceableFailure |
2579 | transform::SelectOp::apply(transform::TransformRewriter &rewriter, |
2580 | transform::TransformResults &results, |
2581 | transform::TransformState &state) { |
2582 | SmallVector<Operation *> result; |
2583 | auto payloadOps = state.getPayloadOps(getTarget()); |
2584 | for (Operation *op : payloadOps) { |
2585 | if (op->getName().getStringRef() == getOpName()) |
2586 | result.push_back(op); |
2587 | } |
2588 | results.set(cast<OpResult>(getResult()), result); |
2589 | return DiagnosedSilenceableFailure::success(); |
2590 | } |
2591 | |
2592 | //===----------------------------------------------------------------------===// |
2593 | // SplitHandleOp |
2594 | //===----------------------------------------------------------------------===// |
2595 | |
2596 | void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, |
2597 | Value target, int64_t numResultHandles) { |
2598 | result.addOperands(target); |
2599 | result.addTypes(SmallVector<Type>(numResultHandles, target.getType())); |
2600 | } |
2601 | |
2602 | DiagnosedSilenceableFailure |
2603 | transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, |
2604 | transform::TransformResults &results, |
2605 | transform::TransformState &state) { |
2606 | int64_t numPayloads = |
2607 | llvm::TypeSwitch<Type, int64_t>(getHandle().getType()) |
2608 | .Case<TransformHandleTypeInterface>([&](auto x) { |
2609 | return llvm::range_size(state.getPayloadOps(getHandle())); |
2610 | }) |
2611 | .Case<TransformValueHandleTypeInterface>([&](auto x) { |
2612 | return llvm::range_size(state.getPayloadValues(getHandle())); |
2613 | }) |
2614 | .Case<TransformParamTypeInterface>([&](auto x) { |
2615 | return llvm::range_size(state.getParams(getHandle())); |
2616 | }) |
2617 | .Default([](auto x) { |
2618 | llvm_unreachable("unknown transform dialect type interface"); |
2619 | return -1; |
2620 | }); |
2621 | |
2622 | auto produceNumOpsError = [&]() { |
2623 | return emitSilenceableError() |
2624 | << getHandle() << " expected to contain "<< this->getNumResults() |
2625 | << " payloads but it contains "<< numPayloads << " payloads"; |
2626 | }; |
2627 | |
2628 | // Fail if there are more payload ops than results and no overflow result was |
2629 | // specified. |
2630 | if (numPayloads > getNumResults() && !getOverflowResult().has_value()) |
2631 | return produceNumOpsError(); |
2632 | |
2633 | // Fail if there are more results than payload ops. Unless: |
2634 | // - "fail_on_payload_too_small" is set to "false", or |
2635 | // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops. |
2636 | if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() && |
2637 | (numPayloads != 0 || !getPassThroughEmptyHandle())) |
2638 | return produceNumOpsError(); |
2639 | |
2640 | // Distribute payloads. |
2641 | SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {}); |
2642 | if (getOverflowResult()) |
2643 | resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults()); |
2644 | |
2645 | auto container = [&]() { |
2646 | if (isa<TransformHandleTypeInterface>(getHandle().getType())) { |
2647 | return llvm::map_to_vector( |
2648 | state.getPayloadOps(getHandle()), |
2649 | [](Operation *op) -> MappedValue { return op; }); |
2650 | } |
2651 | if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) { |
2652 | return llvm::map_to_vector(state.getPayloadValues(getHandle()), |
2653 | [](Value v) -> MappedValue { return v; }); |
2654 | } |
2655 | assert(isa<TransformParamTypeInterface>(getHandle().getType()) && |
2656 | "unsupported kind of transform dialect type"); |
2657 | return llvm::map_to_vector(state.getParams(getHandle()), |
2658 | [](Attribute a) -> MappedValue { return a; }); |
2659 | }(); |
2660 | |
2661 | for (auto &&en : llvm::enumerate(container)) { |
2662 | int64_t resultNum = en.index(); |
2663 | if (resultNum >= getNumResults()) |
2664 | resultNum = *getOverflowResult(); |
2665 | resultHandles[resultNum].push_back(en.value()); |
2666 | } |
2667 | |
2668 | // Set transform op results. |
2669 | for (auto &&it : llvm::enumerate(resultHandles)) |
2670 | results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())), |
2671 | it.value()); |
2672 | |
2673 | return DiagnosedSilenceableFailure::success(); |
2674 | } |
2675 | |
2676 | void transform::SplitHandleOp::getEffects( |
2677 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
2678 | onlyReadsHandle(getHandleMutable(), effects); |
2679 | producesHandle(getOperation()->getOpResults(), effects); |
2680 | // There are no effects on the Payload IR as this is only a handle |
2681 | // manipulation. |
2682 | } |
2683 | |
2684 | LogicalResult transform::SplitHandleOp::verify() { |
2685 | if (getOverflowResult().has_value() && |
2686 | !(*getOverflowResult() < getNumResults())) |
2687 | return emitOpError("overflow_result is not a valid result index"); |
2688 | |
2689 | for (Type resultType : getResultTypes()) { |
2690 | if (implementSameTransformInterface(getHandle().getType(), resultType)) |
2691 | continue; |
2692 | |
2693 | return emitOpError("expects result types to implement the same transform " |
2694 | "interface as the operand type"); |
2695 | } |
2696 | |
2697 | return success(); |
2698 | } |
2699 | |
2700 | //===----------------------------------------------------------------------===// |
2701 | // ReplicateOp |
2702 | //===----------------------------------------------------------------------===// |
2703 | |
2704 | DiagnosedSilenceableFailure |
2705 | transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, |
2706 | transform::TransformResults &results, |
2707 | transform::TransformState &state) { |
2708 | unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); |
2709 | for (const auto &en : llvm::enumerate(getHandles())) { |
2710 | Value handle = en.value(); |
2711 | if (isa<TransformHandleTypeInterface>(handle.getType())) { |
2712 | SmallVector<Operation *> current = |
2713 | llvm::to_vector(state.getPayloadOps(handle)); |
2714 | SmallVector<Operation *> payload; |
2715 | payload.reserve(numRepetitions * current.size()); |
2716 | for (unsigned i = 0; i < numRepetitions; ++i) |
2717 | llvm::append_range(payload, current); |
2718 | results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload); |
2719 | } else { |
2720 | assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) && |
2721 | "expected param type"); |
2722 | ArrayRef<Attribute> current = state.getParams(handle); |
2723 | SmallVector<Attribute> params; |
2724 | params.reserve(numRepetitions * current.size()); |
2725 | for (unsigned i = 0; i < numRepetitions; ++i) |
2726 | llvm::append_range(params, current); |
2727 | results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]), |
2728 | params); |
2729 | } |
2730 | } |
2731 | return DiagnosedSilenceableFailure::success(); |
2732 | } |
2733 | |
2734 | void transform::ReplicateOp::getEffects( |
2735 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
2736 | onlyReadsHandle(getPatternMutable(), effects); |
2737 | onlyReadsHandle(getHandlesMutable(), effects); |
2738 | producesHandle(getOperation()->getOpResults(), effects); |
2739 | } |
2740 | |
2741 | //===----------------------------------------------------------------------===// |
2742 | // SequenceOp |
2743 | //===----------------------------------------------------------------------===// |
2744 | |
2745 | DiagnosedSilenceableFailure |
2746 | transform::SequenceOp::apply(transform::TransformRewriter &rewriter, |
2747 | transform::TransformResults &results, |
2748 | transform::TransformState &state) { |
2749 | // Map the entry block argument to the list of operations. |
2750 | auto scope = state.make_region_scope(*getBodyBlock()->getParent()); |
2751 | if (failed(mapBlockArguments(state))) |
2752 | return DiagnosedSilenceableFailure::definiteFailure(); |
2753 | |
2754 | return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state, |
2755 | results); |
2756 | } |
2757 | |
2758 | static ParseResult parseSequenceOpOperands( |
2759 | OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root, |
2760 | Type &rootType, |
2761 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings, |
2762 | SmallVectorImpl<Type> &extraBindingTypes) { |
2763 | OpAsmParser::UnresolvedOperand rootOperand; |
2764 | OptionalParseResult hasRoot = parser.parseOptionalOperand(result&: rootOperand); |
2765 | if (!hasRoot.has_value()) { |
2766 | root = std::nullopt; |
2767 | return success(); |
2768 | } |
2769 | if (failed(Result: hasRoot.value())) |
2770 | return failure(); |
2771 | root = rootOperand; |
2772 | |
2773 | if (succeeded(Result: parser.parseOptionalComma())) { |
2774 | if (failed(Result: parser.parseOperandList(result&: extraBindings))) |
2775 | return failure(); |
2776 | } |
2777 | if (failed(Result: parser.parseColon())) |
2778 | return failure(); |
2779 | |
2780 | // The paren is truly optional. |
2781 | (void)parser.parseOptionalLParen(); |
2782 | |
2783 | if (failed(Result: parser.parseType(result&: rootType))) { |
2784 | return failure(); |
2785 | } |
2786 | |
2787 | if (!extraBindings.empty()) { |
2788 | if (parser.parseComma() || parser.parseTypeList(result&: extraBindingTypes)) |
2789 | return failure(); |
2790 | } |
2791 | |
2792 | if (extraBindingTypes.size() != extraBindings.size()) { |
2793 | return parser.emitError(loc: parser.getNameLoc(), |
2794 | message: "expected types to be provided for all operands"); |
2795 | } |
2796 | |
2797 | // The paren is truly optional. |
2798 | (void)parser.parseOptionalRParen(); |
2799 | return success(); |
2800 | } |
2801 | |
2802 | static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, |
2803 | Value root, Type rootType, |
2804 | ValueRange extraBindings, |
2805 | TypeRange extraBindingTypes) { |
2806 | if (!root) |
2807 | return; |
2808 | |
2809 | printer << root; |
2810 | bool hasExtras = !extraBindings.empty(); |
2811 | if (hasExtras) { |
2812 | printer << ", "; |
2813 | printer.printOperands(container: extraBindings); |
2814 | } |
2815 | |
2816 | printer << " : "; |
2817 | if (hasExtras) |
2818 | printer << "("; |
2819 | |
2820 | printer << rootType; |
2821 | if (hasExtras) |
2822 | printer << ", "<< llvm::interleaved(R: extraBindingTypes) << ')'; |
2823 | } |
2824 | |
2825 | /// Returns `true` if the given op operand may be consuming the handle value in |
2826 | /// the Transform IR. That is, if it may have a Free effect on it. |
2827 | static bool isValueUsePotentialConsumer(OpOperand &use) { |
2828 | // Conservatively assume the effect being present in absence of the interface. |
2829 | auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner()); |
2830 | if (!iface) |
2831 | return true; |
2832 | |
2833 | return isHandleConsumed(use.get(), iface); |
2834 | } |
2835 | |
2836 | LogicalResult |
2837 | checkDoubleConsume(Value value, |
2838 | function_ref<InFlightDiagnostic()> reportError) { |
2839 | OpOperand *potentialConsumer = nullptr; |
2840 | for (OpOperand &use : value.getUses()) { |
2841 | if (!isValueUsePotentialConsumer(use)) |
2842 | continue; |
2843 | |
2844 | if (!potentialConsumer) { |
2845 | potentialConsumer = &use; |
2846 | continue; |
2847 | } |
2848 | |
2849 | InFlightDiagnostic diag = reportError() |
2850 | << " has more than one potential consumer"; |
2851 | diag.attachNote(noteLoc: potentialConsumer->getOwner()->getLoc()) |
2852 | << "used here as operand #"<< potentialConsumer->getOperandNumber(); |
2853 | diag.attachNote(noteLoc: use.getOwner()->getLoc()) |
2854 | << "used here as operand #"<< use.getOperandNumber(); |
2855 | return diag; |
2856 | } |
2857 | |
2858 | return success(); |
2859 | } |
2860 | |
2861 | LogicalResult transform::SequenceOp::verify() { |
2862 | assert(getBodyBlock()->getNumArguments() >= 1 && |
2863 | "the number of arguments must have been verified to be more than 1 by " |
2864 | "PossibleTopLevelTransformOpTrait"); |
2865 | |
2866 | if (!getRoot() && !getExtraBindings().empty()) { |
2867 | return emitOpError() |
2868 | << "does not expect extra operands when used as top-level"; |
2869 | } |
2870 | |
2871 | // Check if a block argument has more than one consuming use. |
2872 | for (BlockArgument arg : getBodyBlock()->getArguments()) { |
2873 | if (failed(checkDoubleConsume(arg, [this, arg]() { |
2874 | return (emitOpError() << "block argument #"<< arg.getArgNumber()); |
2875 | }))) { |
2876 | return failure(); |
2877 | } |
2878 | } |
2879 | |
2880 | // Check properties of the nested operations they cannot check themselves. |
2881 | for (Operation &child : *getBodyBlock()) { |
2882 | if (!isa<TransformOpInterface>(child) && |
2883 | &child != &getBodyBlock()->back()) { |
2884 | InFlightDiagnostic diag = |
2885 | emitOpError() |
2886 | << "expected children ops to implement TransformOpInterface"; |
2887 | diag.attachNote(child.getLoc()) << "op without interface"; |
2888 | return diag; |
2889 | } |
2890 | |
2891 | for (OpResult result : child.getResults()) { |
2892 | auto report = [&]() { |
2893 | return (child.emitError() << "result #"<< result.getResultNumber()); |
2894 | }; |
2895 | if (failed(checkDoubleConsume(result, report))) |
2896 | return failure(); |
2897 | } |
2898 | } |
2899 | |
2900 | if (!getBodyBlock()->mightHaveTerminator()) |
2901 | return emitOpError() << "expects to have a terminator in the body"; |
2902 | |
2903 | if (getBodyBlock()->getTerminator()->getOperandTypes() != |
2904 | getOperation()->getResultTypes()) { |
2905 | InFlightDiagnostic diag = emitOpError() |
2906 | << "expects the types of the terminator operands " |
2907 | "to match the types of the result"; |
2908 | diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; |
2909 | return diag; |
2910 | } |
2911 | return success(); |
2912 | } |
2913 | |
2914 | void transform::SequenceOp::getEffects( |
2915 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
2916 | getPotentialTopLevelEffects(effects); |
2917 | } |
2918 | |
2919 | OperandRange |
2920 | transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
2921 | assert(point == getBody() && "unexpected region index"); |
2922 | if (getOperation()->getNumOperands() > 0) |
2923 | return getOperation()->getOperands(); |
2924 | return OperandRange(getOperation()->operand_end(), |
2925 | getOperation()->operand_end()); |
2926 | } |
2927 | |
2928 | void transform::SequenceOp::getSuccessorRegions( |
2929 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
2930 | if (point.isParent()) { |
2931 | Region *bodyRegion = &getBody(); |
2932 | regions.emplace_back(bodyRegion, getNumOperands() != 0 |
2933 | ? bodyRegion->getArguments() |
2934 | : Block::BlockArgListType()); |
2935 | return; |
2936 | } |
2937 | |
2938 | assert(point == getBody() && "unexpected region index"); |
2939 | regions.emplace_back(getOperation()->getResults()); |
2940 | } |
2941 | |
2942 | void transform::SequenceOp::getRegionInvocationBounds( |
2943 | ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { |
2944 | (void)operands; |
2945 | bounds.emplace_back(1, 1); |
2946 | } |
2947 | |
2948 | void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, |
2949 | TypeRange resultTypes, |
2950 | FailurePropagationMode failurePropagationMode, |
2951 | Value root, |
2952 | SequenceBodyBuilderFn bodyBuilder) { |
2953 | build(builder, state, resultTypes, failurePropagationMode, root, |
2954 | /*extra_bindings=*/ValueRange()); |
2955 | Type bbArgType = root.getType(); |
2956 | buildSequenceBody(builder, state, bbArgType, |
2957 | /*extraBindingTypes=*/TypeRange(), bodyBuilder); |
2958 | } |
2959 | |
2960 | void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, |
2961 | TypeRange resultTypes, |
2962 | FailurePropagationMode failurePropagationMode, |
2963 | Value root, ValueRange extraBindings, |
2964 | SequenceBodyBuilderArgsFn bodyBuilder) { |
2965 | build(builder, state, resultTypes, failurePropagationMode, root, |
2966 | extraBindings); |
2967 | buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(), |
2968 | bodyBuilder); |
2969 | } |
2970 | |
2971 | void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, |
2972 | TypeRange resultTypes, |
2973 | FailurePropagationMode failurePropagationMode, |
2974 | Type bbArgType, |
2975 | SequenceBodyBuilderFn bodyBuilder) { |
2976 | build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), |
2977 | /*extra_bindings=*/ValueRange()); |
2978 | buildSequenceBody(builder, state, bbArgType, |
2979 | /*extraBindingTypes=*/TypeRange(), bodyBuilder); |
2980 | } |
2981 | |
2982 | void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, |
2983 | TypeRange resultTypes, |
2984 | FailurePropagationMode failurePropagationMode, |
2985 | Type bbArgType, TypeRange extraBindingTypes, |
2986 | SequenceBodyBuilderArgsFn bodyBuilder) { |
2987 | build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(), |
2988 | /*extra_bindings=*/ValueRange()); |
2989 | buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder); |
2990 | } |
2991 | |
2992 | //===----------------------------------------------------------------------===// |
2993 | // PrintOp |
2994 | //===----------------------------------------------------------------------===// |
2995 | |
2996 | void transform::PrintOp::build(OpBuilder &builder, OperationState &result, |
2997 | StringRef name) { |
2998 | if (!name.empty()) |
2999 | result.getOrAddProperties<Properties>().name = builder.getStringAttr(name); |
3000 | } |
3001 | |
3002 | void transform::PrintOp::build(OpBuilder &builder, OperationState &result, |
3003 | Value target, StringRef name) { |
3004 | result.addOperands({target}); |
3005 | build(builder, result, name); |
3006 | } |
3007 | |
3008 | DiagnosedSilenceableFailure |
3009 | transform::PrintOp::apply(transform::TransformRewriter &rewriter, |
3010 | transform::TransformResults &results, |
3011 | transform::TransformState &state) { |
3012 | llvm::outs() << "[[[ IR printer: "; |
3013 | if (getName().has_value()) |
3014 | llvm::outs() << *getName() << " "; |
3015 | |
3016 | OpPrintingFlags printFlags; |
3017 | if (getAssumeVerified().value_or(false)) |
3018 | printFlags.assumeVerified(); |
3019 | if (getUseLocalScope().value_or(false)) |
3020 | printFlags.useLocalScope(); |
3021 | if (getSkipRegions().value_or(false)) |
3022 | printFlags.skipRegions(); |
3023 | |
3024 | if (!getTarget()) { |
3025 | llvm::outs() << "top-level ]]]\n"; |
3026 | state.getTopLevel()->print(llvm::outs(), printFlags); |
3027 | llvm::outs() << "\n"; |
3028 | llvm::outs().flush(); |
3029 | return DiagnosedSilenceableFailure::success(); |
3030 | } |
3031 | |
3032 | llvm::outs() << "]]]\n"; |
3033 | for (Operation *target : state.getPayloadOps(getTarget())) { |
3034 | target->print(llvm::outs(), printFlags); |
3035 | llvm::outs() << "\n"; |
3036 | } |
3037 | |
3038 | llvm::outs().flush(); |
3039 | return DiagnosedSilenceableFailure::success(); |
3040 | } |
3041 | |
3042 | void transform::PrintOp::getEffects( |
3043 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
3044 | // We don't really care about mutability here, but `getTarget` now |
3045 | // unconditionally casts to a specific type before verification could run |
3046 | // here. |
3047 | if (!getTargetMutable().empty()) |
3048 | onlyReadsHandle(getTargetMutable()[0], effects); |
3049 | onlyReadsPayload(effects); |
3050 | |
3051 | // There is no resource for stderr file descriptor, so just declare print |
3052 | // writes into the default resource. |
3053 | effects.emplace_back(MemoryEffects::Write::get()); |
3054 | } |
3055 | |
3056 | //===----------------------------------------------------------------------===// |
3057 | // VerifyOp |
3058 | //===----------------------------------------------------------------------===// |
3059 | |
3060 | DiagnosedSilenceableFailure |
3061 | transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter, |
3062 | Operation *target, |
3063 | transform::ApplyToEachResultList &results, |
3064 | transform::TransformState &state) { |
3065 | if (failed(::mlir::verify(target))) { |
3066 | DiagnosedDefiniteFailure diag = emitDefiniteFailure() |
3067 | << "failed to verify payload op"; |
3068 | diag.attachNote(target->getLoc()) << "payload op"; |
3069 | return diag; |
3070 | } |
3071 | return DiagnosedSilenceableFailure::success(); |
3072 | } |
3073 | |
3074 | void transform::VerifyOp::getEffects( |
3075 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
3076 | transform::onlyReadsHandle(getTargetMutable(), effects); |
3077 | } |
3078 | |
3079 | //===----------------------------------------------------------------------===// |
3080 | // YieldOp |
3081 | //===----------------------------------------------------------------------===// |
3082 | |
3083 | void transform::YieldOp::getEffects( |
3084 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
3085 | onlyReadsHandle(getOperandsMutable(), effects); |
3086 | } |
3087 |
Definitions
- ensurePayloadIsSeparateFromTransform
- forwardEmptyOperands
- parseApplyRegisteredPassOptions
- printApplyRegisteredPassOptions
- matchBlock
- implementSameInterface
- implementSameTransformInterface
- parseForeachMatchSymbols
- printForeachMatchSymbols
- verifyFunctionLikeConsumeAnnotations
- applySequenceBlock
- verifyYieldingSingleBlockOp
- verifyNamedSequenceOp
- buildSequenceBody
- parseSequenceOpOperands
- printSequenceOpOperands
- isValueUsePotentialConsumer
Learn to use CMake with our Intro Training
Find out more