1//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
10
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/Operation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Interfaces/CastInterfaces.h"
15#include "mlir/Support/LogicalResult.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/ErrorHandling.h"
21
22#define DEBUG_TYPE "transform-dialect"
23#define DEBUG_TYPE_FULL "transform-dialect-full"
24#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
25#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
26#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
27#define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
28
29using namespace mlir;
30
31//===----------------------------------------------------------------------===//
32// Helper functions
33//===----------------------------------------------------------------------===//
34
35/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
36/// properly dominates `b` and `b` is not inside `a`.
37static bool happensBefore(Operation *a, Operation *b) {
38 do {
39 if (a->isProperAncestor(other: b))
40 return false;
41 if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(op&: *b)) {
42 return a->isBeforeInBlock(other: bAncestor);
43 }
44 } while ((a = a->getParentOp()));
45 return false;
46}
47
48//===----------------------------------------------------------------------===//
49// TransformState
50//===----------------------------------------------------------------------===//
51
52constexpr const Value transform::TransformState::kTopLevelValue;
53
54transform::TransformState::TransformState(
55 Region *region, Operation *payloadRoot,
56 const RaggedArray<MappedValue> &extraMappings,
57 const TransformOptions &options)
58 : topLevel(payloadRoot), options(options) {
59 topLevelMappedValues.reserve(size: extraMappings.size());
60 for (ArrayRef<MappedValue> mapping : extraMappings)
61 topLevelMappedValues.push_back(elements&: mapping);
62 if (region) {
63 RegionScope *scope = new RegionScope(*this, *region);
64 topLevelRegionScope.reset(p: scope);
65 }
66}
67
68Operation *transform::TransformState::getTopLevel() const { return topLevel; }
69
70ArrayRef<Operation *>
71transform::TransformState::getPayloadOpsView(Value value) const {
72 const TransformOpMapping &operationMapping = getMapping(value).direct;
73 auto iter = operationMapping.find(Val: value);
74 assert(iter != operationMapping.end() &&
75 "cannot find mapping for payload handle (param/value handle "
76 "provided?)");
77 return iter->getSecond();
78}
79
80ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
81 const ParamMapping &mapping = getMapping(value).params;
82 auto iter = mapping.find(Val: value);
83 assert(iter != mapping.end() && "cannot find mapping for param handle "
84 "(operation/value handle provided?)");
85 return iter->getSecond();
86}
87
88ArrayRef<Value>
89transform::TransformState::getPayloadValuesView(Value handleValue) const {
90 const ValueMapping &mapping = getMapping(value: handleValue).values;
91 auto iter = mapping.find(Val: handleValue);
92 assert(iter != mapping.end() && "cannot find mapping for value handle "
93 "(param/operation handle provided?)");
94 return iter->getSecond();
95}
96
97LogicalResult transform::TransformState::getHandlesForPayloadOp(
98 Operation *op, SmallVectorImpl<Value> &handles,
99 bool includeOutOfScope) const {
100 bool found = false;
101 for (const auto &[region, mapping] : llvm::reverse(C: mappings)) {
102 auto iterator = mapping->reverse.find(Val: op);
103 if (iterator != mapping->reverse.end()) {
104 llvm::append_range(C&: handles, R&: iterator->getSecond());
105 found = true;
106 }
107 // Stop looking when reaching a region that is isolated from above.
108 if (!includeOutOfScope &&
109 region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
110 break;
111 }
112
113 return success(isSuccess: found);
114}
115
116LogicalResult transform::TransformState::getHandlesForPayloadValue(
117 Value payloadValue, SmallVectorImpl<Value> &handles,
118 bool includeOutOfScope) const {
119 bool found = false;
120 for (const auto &[region, mapping] : llvm::reverse(C: mappings)) {
121 auto iterator = mapping->reverseValues.find(Val: payloadValue);
122 if (iterator != mapping->reverseValues.end()) {
123 llvm::append_range(C&: handles, R&: iterator->getSecond());
124 found = true;
125 }
126 // Stop looking when reaching a region that is isolated from above.
127 if (!includeOutOfScope &&
128 region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
129 break;
130 }
131
132 return success(isSuccess: found);
133}
134
135/// Given a list of MappedValues, cast them to the value kind implied by the
136/// interface of the handle type, and dispatch to one of the callbacks.
137static DiagnosedSilenceableFailure dispatchMappedValues(
138 Value handle, ArrayRef<transform::MappedValue> values,
139 function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
140 function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
141 function_ref<LogicalResult(ValueRange)> valuesFn) {
142 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
143 SmallVector<Operation *> operations;
144 operations.reserve(N: values.size());
145 for (transform::MappedValue value : values) {
146 if (auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: value)) {
147 operations.push_back(Elt: op);
148 continue;
149 }
150 return emitSilenceableFailure(loc: handle.getLoc())
151 << "wrong kind of value provided for top-level operation handle";
152 }
153 if (failed(result: operationsFn(operations)))
154 return DiagnosedSilenceableFailure::definiteFailure();
155 return DiagnosedSilenceableFailure::success();
156 }
157
158 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
159 handle.getType())) {
160 SmallVector<Value> payloadValues;
161 payloadValues.reserve(N: values.size());
162 for (transform::MappedValue value : values) {
163 if (auto v = llvm::dyn_cast_if_present<Value>(Val&: value)) {
164 payloadValues.push_back(Elt: v);
165 continue;
166 }
167 return emitSilenceableFailure(loc: handle.getLoc())
168 << "wrong kind of value provided for the top-level value handle";
169 }
170 if (failed(result: valuesFn(payloadValues)))
171 return DiagnosedSilenceableFailure::definiteFailure();
172 return DiagnosedSilenceableFailure::success();
173 }
174
175 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
176 "unsupported kind of block argument");
177 SmallVector<transform::Param> parameters;
178 parameters.reserve(N: values.size());
179 for (transform::MappedValue value : values) {
180 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: value)) {
181 parameters.push_back(Elt: attr);
182 continue;
183 }
184 return emitSilenceableFailure(loc: handle.getLoc())
185 << "wrong kind of value provided for top-level parameter";
186 }
187 if (failed(result: paramsFn(parameters)))
188 return DiagnosedSilenceableFailure::definiteFailure();
189 return DiagnosedSilenceableFailure::success();
190}
191
192LogicalResult
193transform::TransformState::mapBlockArgument(BlockArgument argument,
194 ArrayRef<MappedValue> values) {
195 return dispatchMappedValues(
196 handle: argument, values,
197 operationsFn: [&](ArrayRef<Operation *> operations) {
198 return setPayloadOps(value: argument, targets: operations);
199 },
200 paramsFn: [&](ArrayRef<Param> params) {
201 return setParams(value: argument, params);
202 },
203 valuesFn: [&](ValueRange payloadValues) {
204 return setPayloadValues(handle: argument, payloadValues);
205 })
206 .checkAndReport();
207}
208
209LogicalResult
210transform::TransformState::setPayloadOps(Value value,
211 ArrayRef<Operation *> targets) {
212 assert(value != kTopLevelValue &&
213 "attempting to reset the transformation root");
214 assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
215 "wrong handle type");
216
217 for (Operation *target : targets) {
218 if (target)
219 continue;
220 return emitError(loc: value.getLoc())
221 << "attempting to assign a null payload op to this transform value";
222 }
223
224 auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
225 DiagnosedSilenceableFailure result =
226 iface.checkPayload(value.getLoc(), targets);
227 if (failed(result: result.checkAndReport()))
228 return failure();
229
230 // Setting new payload for the value without cleaning it first is a misuse of
231 // the API, assert here.
232 SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
233 Mappings &mappings = getMapping(value);
234 bool inserted =
235 mappings.direct.insert(KV: {value, std::move(storedTargets)}).second;
236 assert(inserted && "value is already associated with another list");
237 (void)inserted;
238
239 for (Operation *op : targets)
240 mappings.reverse[op].push_back(Elt: value);
241
242 return success();
243}
244
245LogicalResult
246transform::TransformState::setPayloadValues(Value handle,
247 ValueRange payloadValues) {
248 assert(handle != nullptr && "attempting to set params for a null value");
249 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
250 "wrong handle type");
251
252 for (Value payload : payloadValues) {
253 if (payload)
254 continue;
255 return emitError(loc: handle.getLoc()) << "attempting to assign a null payload "
256 "value to this transform handle";
257 }
258
259 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
260 SmallVector<Value> payloadValueVector = llvm::to_vector(Range&: payloadValues);
261 DiagnosedSilenceableFailure result =
262 iface.checkPayload(handle.getLoc(), payloadValueVector);
263 if (failed(result: result.checkAndReport()))
264 return failure();
265
266 Mappings &mappings = getMapping(value: handle);
267 bool inserted =
268 mappings.values.insert(KV: {handle, std::move(payloadValueVector)}).second;
269 assert(
270 inserted &&
271 "value handle is already associated with another list of payload values");
272 (void)inserted;
273
274 for (Value payload : payloadValues)
275 mappings.reverseValues[payload].push_back(Elt: handle);
276
277 return success();
278}
279
280LogicalResult transform::TransformState::setParams(Value value,
281 ArrayRef<Param> params) {
282 assert(value != nullptr && "attempting to set params for a null value");
283
284 for (Attribute attr : params) {
285 if (attr)
286 continue;
287 return emitError(loc: value.getLoc())
288 << "attempting to assign a null parameter to this transform value";
289 }
290
291 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
292 assert(value &&
293 "cannot associate parameter with a value of non-parameter type");
294 DiagnosedSilenceableFailure result =
295 valueType.checkPayload(value.getLoc(), params);
296 if (failed(result: result.checkAndReport()))
297 return failure();
298
299 Mappings &mappings = getMapping(value);
300 bool inserted =
301 mappings.params.insert(KV: {value, llvm::to_vector(Range&: params)}).second;
302 assert(inserted && "value is already associated with another list of params");
303 (void)inserted;
304 return success();
305}
306
307template <typename Mapping, typename Key, typename Mapped>
308void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
309 auto it = mapping.find(key);
310 if (it == mapping.end())
311 return;
312
313 llvm::erase(it->getSecond(), mapped);
314 if (it->getSecond().empty())
315 mapping.erase(it);
316}
317
318void transform::TransformState::forgetMapping(Value opHandle,
319 ValueRange origOpFlatResults,
320 bool allowOutOfScope) {
321 Mappings &mappings = getMapping(value: opHandle, allowOutOfScope);
322 for (Operation *op : mappings.direct[opHandle])
323 dropMappingEntry(mapping&: mappings.reverse, key: op, mapped: opHandle);
324 mappings.direct.erase(Val: opHandle);
325#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
326 // Payload IR is removed from the mapping. This invalidates the respective
327 // iterators.
328 mappings.incrementTimestamp(value: opHandle);
329#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
330
331 for (Value opResult : origOpFlatResults) {
332 SmallVector<Value> resultHandles;
333 (void)getHandlesForPayloadValue(payloadValue: opResult, handles&: resultHandles);
334 for (Value resultHandle : resultHandles) {
335 Mappings &localMappings = getMapping(value: resultHandle);
336 dropMappingEntry(mapping&: localMappings.values, key: resultHandle, mapped: opResult);
337#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
338 // Payload IR is removed from the mapping. This invalidates the respective
339 // iterators.
340 mappings.incrementTimestamp(value: resultHandle);
341#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
342 dropMappingEntry(mapping&: localMappings.reverseValues, key: opResult, mapped: resultHandle);
343 }
344 }
345}
346
347void transform::TransformState::forgetValueMapping(
348 Value valueHandle, ArrayRef<Operation *> payloadOperations) {
349 Mappings &mappings = getMapping(value: valueHandle);
350 for (Value payloadValue : mappings.reverseValues[valueHandle])
351 dropMappingEntry(mapping&: mappings.reverseValues, key: payloadValue, mapped: valueHandle);
352 mappings.values.erase(Val: valueHandle);
353#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
354 // Payload IR is removed from the mapping. This invalidates the respective
355 // iterators.
356 mappings.incrementTimestamp(value: valueHandle);
357#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
358
359 for (Operation *payloadOp : payloadOperations) {
360 SmallVector<Value> opHandles;
361 (void)getHandlesForPayloadOp(op: payloadOp, handles&: opHandles);
362 for (Value opHandle : opHandles) {
363 Mappings &localMappings = getMapping(value: opHandle);
364 dropMappingEntry(mapping&: localMappings.direct, key: opHandle, mapped: payloadOp);
365 dropMappingEntry(mapping&: localMappings.reverse, key: payloadOp, mapped: opHandle);
366
367#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
368 // Payload IR is removed from the mapping. This invalidates the respective
369 // iterators.
370 localMappings.incrementTimestamp(value: opHandle);
371#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
372 }
373 }
374}
375
376LogicalResult
377transform::TransformState::replacePayloadOp(Operation *op,
378 Operation *replacement) {
379 // TODO: consider invalidating the handles to nested objects here.
380
381#ifndef NDEBUG
382 for (Value opResult : op->getResults()) {
383 SmallVector<Value> valueHandles;
384 (void)getHandlesForPayloadValue(payloadValue: opResult, handles&: valueHandles,
385 /*includeOutOfScope=*/true);
386 assert(valueHandles.empty() && "expected no mapping to old results");
387 }
388#endif // NDEBUG
389
390 // Drop the mapping between the op and all handles that point to it. Fail if
391 // there are no handles.
392 SmallVector<Value> opHandles;
393 if (failed(result: getHandlesForPayloadOp(op, handles&: opHandles, /*includeOutOfScope=*/true)))
394 return failure();
395 for (Value handle : opHandles) {
396 Mappings &mappings = getMapping(value: handle, /*allowOutOfScope=*/true);
397 dropMappingEntry(mapping&: mappings.reverse, key: op, mapped: handle);
398 }
399
400 // Replace the pointed-to object of all handles with the replacement object.
401 // In case a payload op was erased (replacement object is nullptr), a nullptr
402 // is stored in the mapping. These nullptrs are removed after each transform.
403 // Furthermore, nullptrs are not enumerated by payload op iterators. The
404 // relative order of ops is preserved.
405 //
406 // Removing an op from the mapping would be problematic because removing an
407 // element from an array invalidates iterators; merely changing the value of
408 // elements does not.
409 for (Value handle : opHandles) {
410 Mappings &mappings = getMapping(value: handle, /*allowOutOfScope=*/true);
411 auto it = mappings.direct.find(Val: handle);
412 if (it == mappings.direct.end())
413 continue;
414
415 SmallVector<Operation *, 2> &association = it->getSecond();
416 // Note that an operation may be associated with the handle more than once.
417 for (Operation *&mapped : association) {
418 if (mapped == op)
419 mapped = replacement;
420 }
421
422 if (replacement) {
423 mappings.reverse[replacement].push_back(Elt: handle);
424 } else {
425 opHandlesToCompact.insert(V: handle);
426 }
427 }
428
429 return success();
430}
431
432LogicalResult
433transform::TransformState::replacePayloadValue(Value value, Value replacement) {
434 SmallVector<Value> valueHandles;
435 if (failed(result: getHandlesForPayloadValue(payloadValue: value, handles&: valueHandles,
436 /*includeOutOfScope=*/true)))
437 return failure();
438
439 for (Value handle : valueHandles) {
440 Mappings &mappings = getMapping(value: handle, /*allowOutOfScope=*/true);
441 dropMappingEntry(mapping&: mappings.reverseValues, key: value, mapped: handle);
442
443 // If replacing with null, that is erasing the mapping, drop the mapping
444 // between the handles and the IR objects
445 if (!replacement) {
446 dropMappingEntry(mapping&: mappings.values, key: handle, mapped: value);
447#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
448 // Payload IR is removed from the mapping. This invalidates the respective
449 // iterators.
450 mappings.incrementTimestamp(value: handle);
451#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
452 } else {
453 auto it = mappings.values.find(Val: handle);
454 if (it == mappings.values.end())
455 continue;
456
457 SmallVector<Value> &association = it->getSecond();
458 for (Value &mapped : association) {
459 if (mapped == value)
460 mapped = replacement;
461 }
462 mappings.reverseValues[replacement].push_back(Elt: handle);
463 }
464 }
465
466 return success();
467}
468
469void transform::TransformState::recordOpHandleInvalidationOne(
470 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
471 Operation *payloadOp, Value otherHandle, Value throughValue,
472 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
473 // If the op is associated with invalidated handle, skip the check as it
474 // may be reading invalid IR. This also ensures we report the first
475 // invalidation and not the last one.
476 if (invalidatedHandles.count(Val: otherHandle) ||
477 newlyInvalidated.count(Val: otherHandle))
478 return;
479
480 FULL_LDBG("--recordOpHandleInvalidationOne\n");
481 DEBUG_WITH_TYPE(
482 DEBUG_TYPE_FULL,
483 llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
484 [](Operation *op) { llvm::dbgs() << *op; });
485 llvm::dbgs() << "\n");
486
487 Operation *owner = consumingHandle.getOwner();
488 unsigned operandNo = consumingHandle.getOperandNumber();
489 for (Operation *ancestor : potentialAncestors) {
490 // clang-format off
491 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
492 { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
493 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
494 { (DBGS() << "----of payload with name: "
495 << payloadOp->getName().getIdentifier() << "\n"); });
496 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
497 { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
498 // clang-format on
499 if (!ancestor->isAncestor(other: payloadOp))
500 continue;
501
502 // Make sure the error-reporting lambda doesn't capture anything
503 // by-reference because it will go out of scope. Additionally, extract
504 // location from Payload IR ops because the ops themselves may be
505 // deleted before the lambda gets called.
506 Location ancestorLoc = ancestor->getLoc();
507 Location opLoc = payloadOp->getLoc();
508 std::optional<Location> throughValueLoc =
509 throughValue ? std::make_optional(t: throughValue.getLoc()) : std::nullopt;
510 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
511 otherHandle,
512 throughValueLoc](Location currentLoc) {
513 InFlightDiagnostic diag = emitError(loc: currentLoc)
514 << "op uses a handle invalidated by a "
515 "previously executed transform op";
516 diag.attachNote(noteLoc: otherHandle.getLoc()) << "handle to invalidated ops";
517 diag.attachNote(noteLoc: owner->getLoc())
518 << "invalidated by this transform op that consumes its operand #"
519 << operandNo
520 << " and invalidates all handles to payload IR entities associated "
521 "with this operand and entities nested in them";
522 diag.attachNote(noteLoc: ancestorLoc) << "ancestor payload op";
523 diag.attachNote(noteLoc: opLoc) << "nested payload op";
524 if (throughValueLoc) {
525 diag.attachNote(noteLoc: *throughValueLoc)
526 << "consumed handle points to this payload value";
527 }
528 };
529 }
530}
531
532void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
533 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
534 Value payloadValue, Value valueHandle,
535 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
536 // If the op is associated with invalidated handle, skip the check as it
537 // may be reading invalid IR. This also ensures we report the first
538 // invalidation and not the last one.
539 if (invalidatedHandles.count(Val: valueHandle) ||
540 newlyInvalidated.count(Val: valueHandle))
541 return;
542
543 for (Operation *ancestor : potentialAncestors) {
544 Operation *definingOp;
545 std::optional<unsigned> resultNo;
546 unsigned argumentNo = std::numeric_limits<unsigned>::max();
547 unsigned blockNo = std::numeric_limits<unsigned>::max();
548 unsigned regionNo = std::numeric_limits<unsigned>::max();
549 if (auto opResult = llvm::dyn_cast<OpResult>(Val&: payloadValue)) {
550 definingOp = opResult.getOwner();
551 resultNo = opResult.getResultNumber();
552 } else {
553 auto arg = llvm::cast<BlockArgument>(Val&: payloadValue);
554 definingOp = arg.getParentBlock()->getParentOp();
555 argumentNo = arg.getArgNumber();
556 blockNo = std::distance(first: arg.getOwner()->getParent()->begin(),
557 last: arg.getOwner()->getIterator());
558 regionNo = arg.getOwner()->getParent()->getRegionNumber();
559 }
560 assert(definingOp && "expected the value to be defined by an op as result "
561 "or block argument");
562 if (!ancestor->isAncestor(other: definingOp))
563 continue;
564
565 Operation *owner = opHandle.getOwner();
566 unsigned operandNo = opHandle.getOperandNumber();
567 Location ancestorLoc = ancestor->getLoc();
568 Location opLoc = definingOp->getLoc();
569 Location valueLoc = payloadValue.getLoc();
570 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
571 argumentNo, blockNo, regionNo, ancestorLoc,
572 opLoc, valueLoc](Location currentLoc) {
573 InFlightDiagnostic diag = emitError(loc: currentLoc)
574 << "op uses a handle invalidated by a "
575 "previously executed transform op";
576 diag.attachNote(noteLoc: valueHandle.getLoc()) << "invalidated handle";
577 diag.attachNote(noteLoc: owner->getLoc())
578 << "invalidated by this transform op that consumes its operand #"
579 << operandNo
580 << " and invalidates all handles to payload IR entities "
581 "associated with this operand and entities nested in them";
582 diag.attachNote(noteLoc: ancestorLoc)
583 << "ancestor op associated with the consumed handle";
584 if (resultNo) {
585 diag.attachNote(noteLoc: opLoc)
586 << "op defining the value as result #" << *resultNo;
587 } else {
588 diag.attachNote(noteLoc: opLoc)
589 << "op defining the value as block argument #" << argumentNo
590 << " of block #" << blockNo << " in region #" << regionNo;
591 }
592 diag.attachNote(noteLoc: valueLoc) << "payload value";
593 };
594 }
595}
596
597void transform::TransformState::recordOpHandleInvalidation(
598 OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
599 Value throughValue,
600 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
601
602 if (potentialAncestors.empty()) {
603 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
604 (DBGS() << "----recording invalidation for empty handle: " << handle.get()
605 << "\n");
606 });
607
608 Operation *owner = handle.getOwner();
609 unsigned operandNo = handle.getOperandNumber();
610 newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
611 InFlightDiagnostic diag = emitError(loc: currentLoc)
612 << "op uses a handle associated with empty "
613 "payload and invalidated by a "
614 "previously executed transform op";
615 diag.attachNote(noteLoc: owner->getLoc())
616 << "invalidated by this transform op that consumes its operand #"
617 << operandNo;
618 };
619 return;
620 }
621
622 // Iterate over the mapping and invalidate aliasing handles. This is quite
623 // expensive and only necessary for error reporting in case of transform
624 // dialect misuse with dangling handles. Iteration over the handles is based
625 // on the assumption that the number of handles is significantly less than the
626 // number of IR objects (operations and values). Alternatively, we could walk
627 // the IR nested in each payload op associated with the given handle and look
628 // for handles associated with each operation and value.
629 for (const auto &[region, mapping] : llvm::reverse(C: mappings)) {
630 // Go over all op handle mappings and mark as invalidated any handle
631 // pointing to any of the payload ops associated with the given handle or
632 // any op nested in them.
633 for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
634 for (Value otherHandle : otherHandles)
635 recordOpHandleInvalidationOne(consumingHandle&: handle, potentialAncestors, payloadOp,
636 otherHandle, throughValue,
637 newlyInvalidated);
638 }
639 // Go over all value handle mappings and mark as invalidated any handle
640 // pointing to any result of the payload op associated with the given handle
641 // or any op nested in them. Similarly invalidate handles to argument of
642 // blocks belonging to any region of any payload op associated with the
643 // given handle or any op nested in them.
644 for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
645 for (Value valueHandle : valueHandles)
646 recordValueHandleInvalidationByOpHandleOne(opHandle&: handle, potentialAncestors,
647 payloadValue, valueHandle,
648 newlyInvalidated);
649 }
650
651 // Stop lookup when reaching a region that is isolated from above.
652 if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
653 break;
654 }
655}
656
657void transform::TransformState::recordValueHandleInvalidation(
658 OpOperand &valueHandle,
659 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
660 // Invalidate other handles to the same value.
661 for (Value payloadValue : getPayloadValuesView(handleValue: valueHandle.get())) {
662 SmallVector<Value> otherValueHandles;
663 (void)getHandlesForPayloadValue(payloadValue, handles&: otherValueHandles);
664 for (Value otherHandle : otherValueHandles) {
665 Operation *owner = valueHandle.getOwner();
666 unsigned operandNo = valueHandle.getOperandNumber();
667 Location valueLoc = payloadValue.getLoc();
668 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
669 valueLoc](Location currentLoc) {
670 InFlightDiagnostic diag = emitError(loc: currentLoc)
671 << "op uses a handle invalidated by a "
672 "previously executed transform op";
673 diag.attachNote(noteLoc: otherHandle.getLoc()) << "invalidated handle";
674 diag.attachNote(noteLoc: owner->getLoc())
675 << "invalidated by this transform op that consumes its operand #"
676 << operandNo
677 << " and invalidates handles to the same values as associated with "
678 "it";
679 diag.attachNote(noteLoc: valueLoc) << "payload value";
680 };
681 }
682
683 if (auto opResult = llvm::dyn_cast<OpResult>(Val&: payloadValue)) {
684 Operation *payloadOp = opResult.getOwner();
685 recordOpHandleInvalidation(handle&: valueHandle, potentialAncestors: payloadOp, throughValue: payloadValue,
686 newlyInvalidated);
687 } else {
688 auto arg = llvm::dyn_cast<BlockArgument>(Val&: payloadValue);
689 for (Operation &payloadOp : *arg.getOwner())
690 recordOpHandleInvalidation(handle&: valueHandle, potentialAncestors: &payloadOp, throughValue: payloadValue,
691 newlyInvalidated);
692 }
693 }
694}
695
696/// Checks that the operation does not use invalidated handles as operands.
697/// Reports errors and returns failure if it does. Otherwise, invalidates the
698/// handles consumed by the operation as well as any handles pointing to payload
699/// IR operations nested in the operations associated with the consumed handles.
700LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
701 transform::TransformOpInterface transform,
702 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
703 FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
704 auto memoryEffectsIface =
705 cast<MemoryEffectOpInterface>(transform.getOperation());
706 SmallVector<MemoryEffects::EffectInstance> effects;
707 memoryEffectsIface.getEffectsOnResource(
708 transform::TransformMappingResource::get(), effects);
709
710 for (OpOperand &target : transform->getOpOperands()) {
711 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
712 (DBGS() << "----iterate on handle: " << target.get() << "\n");
713 });
714 // If the operand uses an invalidated handle, report it. If the operation
715 // allows handles to point to repeated payload operations, only report
716 // pre-existing invalidation errors. Otherwise, also report invalidations
717 // caused by the current transform operation affecting its other operands.
718 auto it = invalidatedHandles.find(target.get());
719 auto nit = newlyInvalidated.find(target.get());
720 if (it != invalidatedHandles.end()) {
721 FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
722 "invalidated -> FAILURE\n");
723 return it->getSecond()(transform->getLoc()), failure();
724 }
725 if (!transform.allowsRepeatedHandleOperands() &&
726 nit != newlyInvalidated.end()) {
727 FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
728 "invalidated (by this op) -> FAILURE\n");
729 return nit->getSecond()(transform->getLoc()), failure();
730 }
731
732 // Invalidate handles pointing to the operations nested in the operation
733 // associated with the handle consumed by this operation.
734 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
735 return isa<MemoryEffects::Free>(effect.getEffect()) &&
736 effect.getValue() == target.get();
737 };
738 if (llvm::any_of(effects, consumesTarget)) {
739 FULL_LDBG("----found consume effect\n");
740 if (llvm::isa<transform::TransformHandleTypeInterface>(
741 target.get().getType())) {
742 FULL_LDBG("----recordOpHandleInvalidation\n");
743 SmallVector<Operation *> payloadOps =
744 llvm::to_vector(getPayloadOps(target.get()));
745 recordOpHandleInvalidation(target, payloadOps, nullptr,
746 newlyInvalidated);
747 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
748 target.get().getType())) {
749 FULL_LDBG("----recordValueHandleInvalidation\n");
750 recordValueHandleInvalidation(target, newlyInvalidated);
751 } else {
752 FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
753 }
754 } else {
755 FULL_LDBG("----no consume effect -> SKIP\n");
756 }
757 }
758
759 FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
760 return success();
761}
762
763LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
764 transform::TransformOpInterface transform) {
765 InvalidatedHandleMap newlyInvalidated;
766 LogicalResult checkResult =
767 checkAndRecordHandleInvalidationImpl(transform: transform, newlyInvalidated);
768 invalidatedHandles.insert(I: std::make_move_iterator(i: newlyInvalidated.begin()),
769 E: std::make_move_iterator(i: newlyInvalidated.end()));
770 return checkResult;
771}
772
773template <typename T>
774DiagnosedSilenceableFailure
775checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
776 transform::TransformOpInterface transform,
777 unsigned operandNumber) {
778 DenseSet<T> seen;
779 for (T p : payload) {
780 if (!seen.insert(p).second) {
781 DiagnosedSilenceableFailure diag =
782 transform.emitSilenceableError()
783 << "a handle passed as operand #" << operandNumber
784 << " and consumed by this operation points to a payload "
785 "entity more than once";
786 if constexpr (std::is_pointer_v<T>)
787 diag.attachNote(loc: p->getLoc()) << "repeated target op";
788 else
789 diag.attachNote(loc: p.getLoc()) << "repeated target value";
790 return diag;
791 }
792 }
793 return DiagnosedSilenceableFailure::success();
794}
795
796void transform::TransformState::compactOpHandles() {
797 for (Value handle : opHandlesToCompact) {
798 Mappings &mappings = getMapping(value: handle, /*allowOutOfScope=*/true);
799#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
800 if (llvm::find(Range&: mappings.direct[handle], Val: nullptr) !=
801 mappings.direct[handle].end())
802 // Payload IR is removed from the mapping. This invalidates the respective
803 // iterators.
804 mappings.incrementTimestamp(value: handle);
805#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
806 llvm::erase(C&: mappings.direct[handle], V: nullptr);
807 }
808 opHandlesToCompact.clear();
809}
810
811DiagnosedSilenceableFailure
812transform::TransformState::applyTransform(TransformOpInterface transform) {
813 LLVM_DEBUG({
814 DBGS() << "applying: ";
815 transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
816 llvm::dbgs() << "\n";
817 });
818 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
819 DBGS() << "Top-level payload before application:\n"
820 << *getTopLevel() << "\n");
821 auto printOnFailureRAII = llvm::make_scope_exit(F: [this] {
822 (void)this;
823 LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
824 llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
825 });
826
827 // Set current transform op.
828 regionStack.back()->currentTransform = transform;
829
830 // Expensive checks to detect invalid transform IR.
831 if (options.getExpensiveChecksEnabled()) {
832 FULL_LDBG("ExpensiveChecksEnabled\n");
833 if (failed(checkAndRecordHandleInvalidation(transform: transform)))
834 return DiagnosedSilenceableFailure::definiteFailure();
835
836 for (OpOperand &operand : transform->getOpOperands()) {
837 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
838 (DBGS() << "iterate on handle: " << operand.get() << "\n");
839 });
840 if (!isHandleConsumed(operand.get(), transform)) {
841 FULL_LDBG("--handle not consumed -> SKIP\n");
842 continue;
843 }
844 if (transform.allowsRepeatedHandleOperands()) {
845 FULL_LDBG("--op allows repeated handles -> SKIP\n");
846 continue;
847 }
848 FULL_LDBG("--handle is consumed\n");
849
850 Type operandType = operand.get().getType();
851 if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
852 FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
853 DiagnosedSilenceableFailure check =
854 checkRepeatedConsumptionInOperand<Operation *>(
855 getPayloadOpsView(operand.get()), transform,
856 operand.getOperandNumber());
857 if (!check.succeeded()) {
858 FULL_LDBG("----FAILED\n");
859 return check;
860 }
861 } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
862 FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
863 DiagnosedSilenceableFailure check =
864 checkRepeatedConsumptionInOperand<Value>(
865 getPayloadValuesView(operand.get()), transform,
866 operand.getOperandNumber());
867 if (!check.succeeded()) {
868 FULL_LDBG("----FAILED\n");
869 return check;
870 }
871 } else {
872 FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
873 }
874 }
875 }
876
877 // Find which operands are consumed.
878 SmallVector<OpOperand *> consumedOperands =
879 transform.getConsumedHandleOpOperands();
880
881 // Remember the results of the payload ops associated with the consumed
882 // op handles or the ops defining the value handles so we can drop the
883 // association with them later. This must happen here because the
884 // transformation may destroy or mutate them so we cannot traverse the payload
885 // IR after that.
886 SmallVector<Value> origOpFlatResults;
887 SmallVector<Operation *> origAssociatedOps;
888 for (OpOperand *opOperand : consumedOperands) {
889 Value operand = opOperand->get();
890 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
891 for (Operation *payloadOp : getPayloadOps(operand)) {
892 llvm::append_range(origOpFlatResults, payloadOp->getResults());
893 }
894 continue;
895 }
896 if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
897 for (Value payloadValue : getPayloadValuesView(operand)) {
898 if (llvm::isa<OpResult>(payloadValue)) {
899 origAssociatedOps.push_back(payloadValue.getDefiningOp());
900 continue;
901 }
902 llvm::append_range(
903 origAssociatedOps,
904 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
905 [](Operation &op) { return &op; }));
906 }
907 continue;
908 }
909 DiagnosedDefiniteFailure diag =
910 emitDefiniteFailure(transform->getLoc())
911 << "unexpectedly consumed a value that is not a handle as operand #"
912 << opOperand->getOperandNumber();
913 diag.attachNote(operand.getLoc())
914 << "value defined here with type " << operand.getType();
915 return diag;
916 }
917
918 // Prepare rewriter and listener.
919 TrackingListenerConfig config;
920 config.skipHandleFn = [&](Value handle) {
921 // Skip handle if it is dead.
922 auto scopeIt =
923 llvm::find_if(Range: llvm::reverse(C&: regionStack), P: [&](RegionScope *scope) {
924 return handle.getParentRegion() == scope->region;
925 });
926 assert(scopeIt != regionStack.rend() &&
927 "could not find region scope for handle");
928 RegionScope *scope = *scopeIt;
929 for (Operation *user : handle.getUsers()) {
930 if (user != scope->currentTransform &&
931 !happensBefore(user, scope->currentTransform))
932 return false;
933 }
934 return true;
935 };
936 transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
937 config);
938 transform::TransformRewriter rewriter(transform->getContext(),
939 &trackingListener);
940
941 // Compute the result but do not short-circuit the silenceable failure case as
942 // we still want the handles to propagate properly so the "suppress" mode can
943 // proceed on a best effort basis.
944 transform::TransformResults results(transform->getNumResults());
945 DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
946 compactOpHandles();
947
948 // Error handling: fail if transform or listener failed.
949 DiagnosedSilenceableFailure trackingFailure =
950 trackingListener.checkAndResetError();
951 if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
952 transform->hasAttr(FindPayloadReplacementOpInterface::
953 kSilenceTrackingFailuresAttrName)) {
954 // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
955 // do not report failures if the above mentioned attribute is set.
956 if (trackingFailure.isSilenceableFailure())
957 (void)trackingFailure.silence();
958 trackingFailure = DiagnosedSilenceableFailure::success();
959 }
960 if (!trackingFailure.succeeded()) {
961 if (result.succeeded()) {
962 result = std::move(trackingFailure);
963 } else {
964 // Transform op errors have precedence, report those first.
965 if (result.isSilenceableFailure())
966 result.attachNote() << "tracking listener also failed: "
967 << trackingFailure.getMessage();
968 (void)trackingFailure.silence();
969 }
970 }
971 if (result.isDefiniteFailure())
972 return result;
973
974 // If a silenceable failure was produced, some results may be unset, set them
975 // to empty lists.
976 if (result.isSilenceableFailure())
977 results.setRemainingToEmpty(transform);
978
979 // Remove the mapping for the operand if it is consumed by the operation. This
980 // allows us to catch use-after-free with assertions later on.
981 for (OpOperand *opOperand : consumedOperands) {
982 Value operand = opOperand->get();
983 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
984 forgetMapping(operand, origOpFlatResults);
985 } else if (llvm::isa<TransformValueHandleTypeInterface>(
986 operand.getType())) {
987 forgetValueMapping(operand, origAssociatedOps);
988 }
989 }
990
991 if (failed(updateStateFromResults(results, opResults: transform->getResults())))
992 return DiagnosedSilenceableFailure::definiteFailure();
993
994 printOnFailureRAII.release();
995 DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
996 DBGS() << "Top-level payload:\n";
997 getTopLevel()->print(llvm::dbgs());
998 });
999 return result;
1000}
1001
1002LogicalResult transform::TransformState::updateStateFromResults(
1003 const TransformResults &results, ResultRange opResults) {
1004 for (OpResult result : opResults) {
1005 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1006 assert(results.isParam(result.getResultNumber()) &&
1007 "expected parameters for the parameter-typed result");
1008 if (failed(
1009 result: setParams(value: result, params: results.getParams(resultNumber: result.getResultNumber())))) {
1010 return failure();
1011 }
1012 } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1013 assert(results.isValue(result.getResultNumber()) &&
1014 "expected values for value-type-result");
1015 if (failed(result: setPayloadValues(
1016 handle: result, payloadValues: results.getValues(resultNumber: result.getResultNumber())))) {
1017 return failure();
1018 }
1019 } else {
1020 assert(!results.isParam(result.getResultNumber()) &&
1021 "expected payload ops for the non-parameter typed result");
1022 if (failed(
1023 result: setPayloadOps(value: result, targets: results.get(resultNumber: result.getResultNumber())))) {
1024 return failure();
1025 }
1026 }
1027 }
1028 return success();
1029}
1030
1031//===----------------------------------------------------------------------===//
1032// TransformState::Extension
1033//===----------------------------------------------------------------------===//
1034
1035transform::TransformState::Extension::~Extension() = default;
1036
1037LogicalResult
1038transform::TransformState::Extension::replacePayloadOp(Operation *op,
1039 Operation *replacement) {
1040 // TODO: we may need to invalidate handles to operations and values nested in
1041 // the operation being replaced.
1042 return state.replacePayloadOp(op, replacement);
1043}
1044
1045LogicalResult
1046transform::TransformState::Extension::replacePayloadValue(Value value,
1047 Value replacement) {
1048 return state.replacePayloadValue(value, replacement);
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// TransformState::RegionScope
1053//===----------------------------------------------------------------------===//
1054
1055transform::TransformState::RegionScope::~RegionScope() {
1056 // Remove handle invalidation notices as handles are going out of scope.
1057 // The same region may be re-entered leading to incorrect invalidation
1058 // errors.
1059 for (Block &block : *region) {
1060 for (Value handle : block.getArguments()) {
1061 state.invalidatedHandles.erase(Val: handle);
1062 }
1063 for (Operation &op : block) {
1064 for (Value handle : op.getResults()) {
1065 state.invalidatedHandles.erase(Val: handle);
1066 }
1067 }
1068 }
1069
1070#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1071 // Remember pointers to payload ops referenced by the handles going out of
1072 // scope.
1073 SmallVector<Operation *> referencedOps =
1074 llvm::to_vector(Range: llvm::make_first_range(c&: state.mappings[region]->reverse));
1075#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1076
1077 state.mappings.erase(Key: region);
1078 state.regionStack.pop_back();
1079}
1080
1081//===----------------------------------------------------------------------===//
1082// TransformResults
1083//===----------------------------------------------------------------------===//
1084
1085transform::TransformResults::TransformResults(unsigned numSegments) {
1086 operations.appendEmptyRows(num: numSegments);
1087 params.appendEmptyRows(num: numSegments);
1088 values.appendEmptyRows(num: numSegments);
1089}
1090
1091void transform::TransformResults::setParams(
1092 OpResult value, ArrayRef<transform::TransformState::Param> params) {
1093 int64_t position = value.getResultNumber();
1094 assert(position < static_cast<int64_t>(this->params.size()) &&
1095 "setting params for a non-existent handle");
1096 assert(this->params[position].data() == nullptr && "params already set");
1097 assert(operations[position].data() == nullptr &&
1098 "another kind of results already set");
1099 assert(values[position].data() == nullptr &&
1100 "another kind of results already set");
1101 this->params.replace(pos: position, elements&: params);
1102}
1103
1104void transform::TransformResults::setMappedValues(
1105 OpResult handle, ArrayRef<MappedValue> values) {
1106 DiagnosedSilenceableFailure diag = dispatchMappedValues(
1107 handle, values,
1108 operationsFn: [&](ArrayRef<Operation *> operations) {
1109 return set(value: handle, ops&: operations), success();
1110 },
1111 paramsFn: [&](ArrayRef<Param> params) {
1112 return setParams(value: handle, params), success();
1113 },
1114 valuesFn: [&](ValueRange payloadValues) {
1115 return setValues(handle, values&: payloadValues), success();
1116 });
1117#ifndef NDEBUG
1118 if (!diag.succeeded())
1119 llvm::dbgs() << diag.getStatusString() << "\n";
1120 assert(diag.succeeded() && "incorrect mapping");
1121#endif // NDEBUG
1122 (void)diag.silence();
1123}
1124
1125void transform::TransformResults::setRemainingToEmpty(
1126 transform::TransformOpInterface transform) {
1127 for (OpResult opResult : transform->getResults()) {
1128 if (!isSet(opResult.getResultNumber()))
1129 setMappedValues(opResult, {});
1130 }
1131}
1132
1133ArrayRef<Operation *>
1134transform::TransformResults::get(unsigned resultNumber) const {
1135 assert(resultNumber < operations.size() &&
1136 "querying results for a non-existent handle");
1137 assert(operations[resultNumber].data() != nullptr &&
1138 "querying unset results (values or params expected?)");
1139 return operations[resultNumber];
1140}
1141
1142ArrayRef<transform::TransformState::Param>
1143transform::TransformResults::getParams(unsigned resultNumber) const {
1144 assert(resultNumber < params.size() &&
1145 "querying params for a non-existent handle");
1146 assert(params[resultNumber].data() != nullptr &&
1147 "querying unset params (ops or values expected?)");
1148 return params[resultNumber];
1149}
1150
1151ArrayRef<Value>
1152transform::TransformResults::getValues(unsigned resultNumber) const {
1153 assert(resultNumber < values.size() &&
1154 "querying values for a non-existent handle");
1155 assert(values[resultNumber].data() != nullptr &&
1156 "querying unset values (ops or params expected?)");
1157 return values[resultNumber];
1158}
1159
1160bool transform::TransformResults::isParam(unsigned resultNumber) const {
1161 assert(resultNumber < params.size() &&
1162 "querying association for a non-existent handle");
1163 return params[resultNumber].data() != nullptr;
1164}
1165
1166bool transform::TransformResults::isValue(unsigned resultNumber) const {
1167 assert(resultNumber < values.size() &&
1168 "querying association for a non-existent handle");
1169 return values[resultNumber].data() != nullptr;
1170}
1171
1172bool transform::TransformResults::isSet(unsigned resultNumber) const {
1173 assert(resultNumber < params.size() &&
1174 "querying association for a non-existent handle");
1175 return params[resultNumber].data() != nullptr ||
1176 operations[resultNumber].data() != nullptr ||
1177 values[resultNumber].data() != nullptr;
1178}
1179
1180//===----------------------------------------------------------------------===//
1181// TrackingListener
1182//===----------------------------------------------------------------------===//
1183
1184transform::TrackingListener::TrackingListener(TransformState &state,
1185 TransformOpInterface op,
1186 TrackingListenerConfig config)
1187 : TransformState::Extension(state), transformOp(op), config(config) {
1188 if (op) {
1189 for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1190 consumedHandles.insert(opOperand->get());
1191 }
1192 }
1193}
1194
1195Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
1196 Operation *defOp = nullptr;
1197 for (Value v : values) {
1198 // Skip empty values.
1199 if (!v)
1200 continue;
1201 if (!defOp) {
1202 defOp = v.getDefiningOp();
1203 continue;
1204 }
1205 if (defOp != v.getDefiningOp())
1206 return nullptr;
1207 }
1208 return defOp;
1209}
1210
1211DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
1212 Operation *&result, Operation *op, ValueRange newValues) const {
1213 assert(op->getNumResults() == newValues.size() &&
1214 "invalid number of replacement values");
1215 SmallVector<Value> values(newValues.begin(), newValues.end());
1216
1217 DiagnosedSilenceableFailure diag = emitSilenceableFailure(
1218 getTransformOp(), "tracking listener failed to find replacement op "
1219 "during application of this transform op");
1220
1221 do {
1222 // If the replacement values belong to different ops, drop the mapping.
1223 Operation *defOp = getCommonDefiningOp(values);
1224 if (!defOp) {
1225 diag.attachNote() << "replacement values belong to different ops";
1226 return diag;
1227 }
1228
1229 // Skip through ops that implement CastOpInterface.
1230 if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1231 values.clear();
1232 values.assign(in_start: defOp->getOperands().begin(), in_end: defOp->getOperands().end());
1233 diag.attachNote(loc: defOp->getLoc())
1234 << "using output of 'CastOpInterface' op";
1235 continue;
1236 }
1237
1238 // If the defining op has the same name or we do not care about the name of
1239 // op replacements at all, we take it as a replacement.
1240 if (!config.requireMatchingReplacementOpName ||
1241 op->getName() == defOp->getName()) {
1242 result = defOp;
1243 return DiagnosedSilenceableFailure::success();
1244 }
1245
1246 // Replacing an op with a constant-like equivalent is a common
1247 // canonicalization.
1248 if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1249 result = defOp;
1250 return DiagnosedSilenceableFailure::success();
1251 }
1252
1253 values.clear();
1254
1255 // Skip through ops that implement FindPayloadReplacementOpInterface.
1256 if (auto findReplacementOpInterface =
1257 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
1258 values.assign(findReplacementOpInterface.getNextOperands());
1259 diag.attachNote(loc: defOp->getLoc()) << "using operands provided by "
1260 "'FindPayloadReplacementOpInterface'";
1261 continue;
1262 }
1263 } while (!values.empty());
1264
1265 diag.attachNote() << "ran out of suitable replacement values";
1266 return diag;
1267}
1268
1269void transform::TrackingListener::notifyMatchFailure(
1270 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1271 LLVM_DEBUG({
1272 Diagnostic diag(loc, DiagnosticSeverity::Remark);
1273 reasonCallback(diag);
1274 DBGS() << "Match Failure : " << diag.str() << "\n";
1275 });
1276}
1277
1278void transform::TrackingListener::notifyOperationErased(Operation *op) {
1279 // Remove mappings for result values.
1280 for (OpResult value : op->getResults())
1281 (void)replacePayloadValue(value, nullptr);
1282 // Remove mapping for op.
1283 (void)replacePayloadOp(op, nullptr);
1284}
1285
1286void transform::TrackingListener::notifyOperationReplaced(
1287 Operation *op, ValueRange newValues) {
1288 assert(op->getNumResults() == newValues.size() &&
1289 "invalid number of replacement values");
1290
1291 // Replace value handles.
1292 for (auto [oldValue, newValue] : llvm::zip(t: op->getResults(), u&: newValues))
1293 (void)replacePayloadValue(oldValue, newValue);
1294
1295 // Replace op handle.
1296 SmallVector<Value> opHandles;
1297 if (failed(getTransformState().getHandlesForPayloadOp(
1298 op, opHandles, /*includeOutOfScope=*/true))) {
1299 // Op is not tracked.
1300 return;
1301 }
1302
1303 // Helper function to check if the current transform op consumes any handle
1304 // that is mapped to `op`.
1305 //
1306 // Note: If a handle was consumed, there shouldn't be any alive users, so it
1307 // is not really necessary to check for consumed handles. However, in case
1308 // there are indeed alive handles that were consumed (which is undefined
1309 // behavior) and a replacement op could not be found, we want to fail with a
1310 // nicer error message: "op uses a handle invalidated..." instead of "could
1311 // not find replacement op". This nicer error is produced later.
1312 auto handleWasConsumed = [&] {
1313 return llvm::any_of(Range&: opHandles,
1314 P: [&](Value h) { return consumedHandles.contains(V: h); });
1315 };
1316
1317 // Check if there are any handles that must be updated.
1318 Value aliveHandle;
1319 if (config.skipHandleFn) {
1320 auto it = llvm::find_if(Range&: opHandles,
1321 P: [&](Value v) { return !config.skipHandleFn(v); });
1322 if (it != opHandles.end())
1323 aliveHandle = *it;
1324 } else if (!opHandles.empty()) {
1325 aliveHandle = opHandles.front();
1326 }
1327 if (!aliveHandle || handleWasConsumed()) {
1328 // The op is tracked but the corresponding handles are dead or were
1329 // consumed. Drop the op form the mapping.
1330 (void)replacePayloadOp(op, nullptr);
1331 return;
1332 }
1333
1334 Operation *replacement;
1335 DiagnosedSilenceableFailure diag =
1336 findReplacementOp(result&: replacement, op, newValues);
1337 // If the op is tracked but no replacement op was found, send a
1338 // notification.
1339 if (!diag.succeeded()) {
1340 diag.attachNote(loc: aliveHandle.getLoc())
1341 << "replacement is required because this handle must be updated";
1342 notifyPayloadReplacementNotFound(op, values: newValues, diag: std::move(diag));
1343 (void)replacePayloadOp(op, nullptr);
1344 return;
1345 }
1346
1347 (void)replacePayloadOp(op, replacement);
1348}
1349
1350transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
1351 // The state of the ErrorCheckingTrackingListener must be checked and reset
1352 // if there was an error. This is to prevent errors from accidentally being
1353 // missed.
1354 assert(status.succeeded() && "listener state was not checked");
1355}
1356
1357DiagnosedSilenceableFailure
1358transform::ErrorCheckingTrackingListener::checkAndResetError() {
1359 DiagnosedSilenceableFailure s = std::move(status);
1360 status = DiagnosedSilenceableFailure::success();
1361 errorCounter = 0;
1362 return s;
1363}
1364
1365bool transform::ErrorCheckingTrackingListener::failed() const {
1366 return !status.succeeded();
1367}
1368
1369void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
1370 Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
1371
1372 // Merge potentially existing diags and store the result in the listener.
1373 SmallVector<Diagnostic> diags;
1374 diag.takeDiagnostics(diags);
1375 if (!status.succeeded())
1376 status.takeDiagnostics(diags);
1377 status = DiagnosedSilenceableFailure::silenceableFailure(diag: std::move(diags));
1378
1379 // Report more details.
1380 status.attachNote(loc: op->getLoc()) << "[" << errorCounter << "] replaced op";
1381 for (auto &&[index, value] : llvm::enumerate(First&: values))
1382 status.attachNote(loc: value.getLoc())
1383 << "[" << errorCounter << "] replacement value " << index;
1384 ++errorCounter;
1385}
1386
1387//===----------------------------------------------------------------------===//
1388// TransformRewriter
1389//===----------------------------------------------------------------------===//
1390
1391transform::TransformRewriter::TransformRewriter(
1392 MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
1393 : RewriterBase(ctx), listener(listener) {
1394 setListener(listener);
1395}
1396
1397bool transform::TransformRewriter::hasTrackingFailures() const {
1398 return listener->failed();
1399}
1400
1401/// Silence all tracking failures that have been encountered so far.
1402void transform::TransformRewriter::silenceTrackingFailure() {
1403 if (hasTrackingFailures()) {
1404 DiagnosedSilenceableFailure status = listener->checkAndResetError();
1405 (void)status.silence();
1406 }
1407}
1408
1409LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
1410 Operation *op, Operation *replacement) {
1411 return listener->replacePayloadOp(op, replacement);
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// Utilities for TransformEachOpTrait.
1416//===----------------------------------------------------------------------===//
1417
1418LogicalResult
1419transform::detail::checkNestedConsumption(Location loc,
1420 ArrayRef<Operation *> targets) {
1421 for (auto &&[position, parent] : llvm::enumerate(First&: targets)) {
1422 for (Operation *child : targets.drop_front(N: position + 1)) {
1423 if (parent->isAncestor(other: child)) {
1424 InFlightDiagnostic diag =
1425 emitError(loc)
1426 << "transform operation consumes a handle pointing to an ancestor "
1427 "payload operation before its descendant";
1428 diag.attachNote()
1429 << "the ancestor is likely erased or rewritten before the "
1430 "descendant is accessed, leading to undefined behavior";
1431 diag.attachNote(noteLoc: parent->getLoc()) << "ancestor payload op";
1432 diag.attachNote(noteLoc: child->getLoc()) << "descendant payload op";
1433 return diag;
1434 }
1435 }
1436 }
1437 return success();
1438}
1439
1440LogicalResult
1441transform::detail::checkApplyToOne(Operation *transformOp,
1442 Location payloadOpLoc,
1443 const ApplyToEachResultList &partialResult) {
1444 Location transformOpLoc = transformOp->getLoc();
1445 StringRef transformOpName = transformOp->getName().getStringRef();
1446 unsigned expectedNumResults = transformOp->getNumResults();
1447
1448 // Reuse the emission of the diagnostic note.
1449 auto emitDiag = [&]() {
1450 auto diag = mlir::emitError(loc: transformOpLoc);
1451 diag.attachNote(noteLoc: payloadOpLoc) << "when applied to this op";
1452 return diag;
1453 };
1454
1455 if (partialResult.size() != expectedNumResults) {
1456 auto diag = emitDiag() << "application of " << transformOpName
1457 << " expected to produce " << expectedNumResults
1458 << " results (actually produced "
1459 << partialResult.size() << ").";
1460 diag.attachNote(noteLoc: transformOpLoc)
1461 << "if you need variadic results, consider a generic `apply` "
1462 << "instead of the specialized `applyToOne`.";
1463 return failure();
1464 }
1465
1466 // Check that the right kind of value was produced.
1467 for (const auto &[ptr, res] :
1468 llvm::zip(t: partialResult, u: transformOp->getResults())) {
1469 if (ptr.isNull())
1470 continue;
1471 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1472 !ptr.is<Operation *>()) {
1473 return emitDiag() << "application of " << transformOpName
1474 << " expected to produce an Operation * for result #"
1475 << res.getResultNumber();
1476 }
1477 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1478 !ptr.is<Attribute>()) {
1479 return emitDiag() << "application of " << transformOpName
1480 << " expected to produce an Attribute for result #"
1481 << res.getResultNumber();
1482 }
1483 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1484 !ptr.is<Value>()) {
1485 return emitDiag() << "application of " << transformOpName
1486 << " expected to produce a Value for result #"
1487 << res.getResultNumber();
1488 }
1489 }
1490 return success();
1491}
1492
1493template <typename T>
1494static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
1495 return llvm::to_vector(llvm::map_range(
1496 range, [](transform::MappedValue value) { return value.get<T>(); }));
1497}
1498
1499void transform::detail::setApplyToOneResults(
1500 Operation *transformOp, TransformResults &transformResults,
1501 ArrayRef<ApplyToEachResultList> results) {
1502 SmallVector<SmallVector<MappedValue>> transposed;
1503 transposed.resize(N: transformOp->getNumResults());
1504 for (const ApplyToEachResultList &partialResults : results) {
1505 if (llvm::any_of(Range: partialResults,
1506 P: [](MappedValue value) { return value.isNull(); }))
1507 continue;
1508 assert(transformOp->getNumResults() == partialResults.size() &&
1509 "expected as many partial results as op as results");
1510 for (auto [i, value] : llvm::enumerate(First: partialResults))
1511 transposed[i].push_back(Elt: value);
1512 }
1513
1514 for (OpResult r : transformOp->getResults()) {
1515 unsigned position = r.getResultNumber();
1516 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1517 transformResults.setParams(value: r,
1518 params: castVector<Attribute>(range: transposed[position]));
1519 } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1520 transformResults.setValues(handle: r, values: castVector<Value>(range: transposed[position]));
1521 } else {
1522 transformResults.set(value: r, ops: castVector<Operation *>(range: transposed[position]));
1523 }
1524 }
1525}
1526
1527//===----------------------------------------------------------------------===//
1528// Utilities for implementing transform ops with regions.
1529//===----------------------------------------------------------------------===//
1530
1531void transform::detail::prepareValueMappings(
1532 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
1533 ValueRange values, const transform::TransformState &state) {
1534 for (Value operand : values) {
1535 SmallVector<MappedValue> &mapped = mappings.emplace_back();
1536 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1537 llvm::append_range(C&: mapped, R: state.getPayloadOps(value: operand));
1538 } else if (llvm::isa<TransformValueHandleTypeInterface>(
1539 operand.getType())) {
1540 llvm::append_range(C&: mapped, R: state.getPayloadValues(handleValue: operand));
1541 } else {
1542 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1543 "unsupported kind of transform dialect value");
1544 llvm::append_range(C&: mapped, R: state.getParams(value: operand));
1545 }
1546 }
1547}
1548
1549void transform::detail::forwardTerminatorOperands(
1550 Block *block, transform::TransformState &state,
1551 transform::TransformResults &results) {
1552 for (auto &&[terminatorOperand, result] :
1553 llvm::zip(t: block->getTerminator()->getOperands(),
1554 u: block->getParentOp()->getOpResults())) {
1555 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1556 results.set(value: result, ops: state.getPayloadOps(value: terminatorOperand));
1557 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1558 result.getType())) {
1559 results.setValues(handle: result, values: state.getPayloadValues(handleValue: terminatorOperand));
1560 } else {
1561 assert(
1562 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1563 "unhandled transform type interface");
1564 results.setParams(value: result, params: state.getParams(value: terminatorOperand));
1565 }
1566 }
1567}
1568
1569transform::TransformState
1570transform::detail::makeTransformStateForTesting(Region *region,
1571 Operation *payloadRoot) {
1572 return TransformState(region, payloadRoot);
1573}
1574
1575//===----------------------------------------------------------------------===//
1576// Utilities for PossibleTopLevelTransformOpTrait.
1577//===----------------------------------------------------------------------===//
1578
1579/// Appends to `effects` the memory effect instances on `target` with the same
1580/// resource and effect as the ones the operation `iface` having on `source`.
1581static void
1582remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
1583 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1584 SmallVector<MemoryEffects::EffectInstance> nestedEffects;
1585 iface.getEffectsOnValue(source, nestedEffects);
1586 for (const auto &effect : nestedEffects)
1587 effects.emplace_back(Args: effect.getEffect(), Args&: target, Args: effect.getResource());
1588}
1589
1590/// Appends to `effects` the same effects as the operations of `block` have on
1591/// block arguments but associated with `operands.`
1592static void
1593remapArgumentEffects(Block &block, ValueRange operands,
1594 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1595 for (Operation &op : block) {
1596 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1597 if (!iface)
1598 continue;
1599
1600 for (auto &&[source, target] : llvm::zip(t: block.getArguments(), u&: operands)) {
1601 remapEffects(iface, source, target, effects);
1602 }
1603
1604 SmallVector<MemoryEffects::EffectInstance> nestedEffects;
1605 iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1606 nestedEffects);
1607 llvm::append_range(C&: effects, R&: nestedEffects);
1608 }
1609}
1610
1611void transform::detail::getPotentialTopLevelEffects(
1612 Operation *operation, Value root, Block &body,
1613 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1614 transform::onlyReadsHandle(handles: operation->getOperands(), effects);
1615 transform::producesHandle(handles: operation->getResults(), effects);
1616
1617 if (!root) {
1618 for (Operation &op : body) {
1619 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1620 if (!iface)
1621 continue;
1622
1623 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
1624 iface.getEffects(effects);
1625 }
1626 return;
1627 }
1628
1629 // Carry over all effects on arguments of the entry block as those on the
1630 // operands, this is the same value just remapped.
1631 remapArgumentEffects(block&: body, operands: operation->getOperands(), effects);
1632}
1633
1634LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
1635 TransformState &state, Operation *op, Region &region) {
1636 SmallVector<Operation *> targets;
1637 SmallVector<SmallVector<MappedValue>> extraMappings;
1638 if (op->getNumOperands() != 0) {
1639 llvm::append_range(C&: targets, R: state.getPayloadOps(value: op->getOperand(idx: 0)));
1640 prepareValueMappings(mappings&: extraMappings, values: op->getOperands().drop_front(), state);
1641 } else {
1642 if (state.getNumTopLevelMappings() !=
1643 region.front().getNumArguments() - 1) {
1644 return emitError(loc: op->getLoc())
1645 << "operation expects " << region.front().getNumArguments() - 1
1646 << " extra value bindings, but " << state.getNumTopLevelMappings()
1647 << " were provided to the interpreter";
1648 }
1649
1650 targets.push_back(Elt: state.getTopLevel());
1651
1652 for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1653 extraMappings.push_back(Elt: llvm::to_vector(Range: state.getTopLevelMapping(position: i)));
1654 }
1655
1656 if (failed(result: state.mapBlockArguments(argument: region.front().getArgument(i: 0), operations: targets)))
1657 return failure();
1658
1659 for (BlockArgument argument : region.front().getArguments().drop_front()) {
1660 if (failed(result: state.mapBlockArgument(
1661 argument, values: extraMappings[argument.getArgNumber() - 1])))
1662 return failure();
1663 }
1664
1665 return success();
1666}
1667
1668LogicalResult
1669transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
1670 // Attaching this trait without the interface is a misuse of the API, but it
1671 // cannot be caught via a static_assert because interface registration is
1672 // dynamic.
1673 assert(isa<TransformOpInterface>(op) &&
1674 "should implement TransformOpInterface to have "
1675 "PossibleTopLevelTransformOpTrait");
1676
1677 if (op->getNumRegions() < 1)
1678 return op->emitOpError() << "expects at least one region";
1679
1680 Region *bodyRegion = &op->getRegion(index: 0);
1681 if (!llvm::hasNItems(C&: *bodyRegion, N: 1))
1682 return op->emitOpError() << "expects a single-block region";
1683
1684 Block *body = &bodyRegion->front();
1685 if (body->getNumArguments() == 0) {
1686 return op->emitOpError()
1687 << "expects the entry block to have at least one argument";
1688 }
1689 if (!llvm::isa<TransformHandleTypeInterface>(
1690 body->getArgument(0).getType())) {
1691 return op->emitOpError()
1692 << "expects the first entry block argument to be of type "
1693 "implementing TransformHandleTypeInterface";
1694 }
1695 BlockArgument arg = body->getArgument(i: 0);
1696 if (op->getNumOperands() != 0) {
1697 if (arg.getType() != op->getOperand(idx: 0).getType()) {
1698 return op->emitOpError()
1699 << "expects the type of the block argument to match "
1700 "the type of the operand";
1701 }
1702 }
1703 for (BlockArgument arg : body->getArguments().drop_front()) {
1704 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1705 TransformValueHandleTypeInterface>(arg.getType()))
1706 continue;
1707
1708 InFlightDiagnostic diag =
1709 op->emitOpError()
1710 << "expects trailing entry block arguments to be of type implementing "
1711 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1712 "TransformParamTypeInterface";
1713 diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1714 return diag;
1715 }
1716
1717 if (auto *parent =
1718 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
1719 if (op->getNumOperands() != body->getNumArguments()) {
1720 InFlightDiagnostic diag =
1721 op->emitOpError()
1722 << "expects operands to be provided for a nested op";
1723 diag.attachNote(noteLoc: parent->getLoc())
1724 << "nested in another possible top-level op";
1725 return diag;
1726 }
1727 }
1728
1729 return success();
1730}
1731
1732//===----------------------------------------------------------------------===//
1733// Utilities for ParamProducedTransformOpTrait.
1734//===----------------------------------------------------------------------===//
1735
1736void transform::detail::getParamProducerTransformOpTraitEffects(
1737 Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1738 producesHandle(handles: op->getResults(), effects);
1739 bool hasPayloadOperands = false;
1740 for (Value operand : op->getOperands()) {
1741 onlyReadsHandle(handles: operand, effects);
1742 if (llvm::isa<TransformHandleTypeInterface,
1743 TransformValueHandleTypeInterface>(operand.getType()))
1744 hasPayloadOperands = true;
1745 }
1746 if (hasPayloadOperands)
1747 onlyReadsPayload(effects);
1748}
1749
1750LogicalResult
1751transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
1752 // Interfaces can be attached dynamically, so this cannot be a static
1753 // assert.
1754 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1755 llvm::report_fatal_error(
1756 reason: Twine("ParamProducerTransformOpTrait must be attached to an op that "
1757 "implements MemoryEffectsOpInterface, found on ") +
1758 op->getName().getStringRef());
1759 }
1760 for (Value result : op->getResults()) {
1761 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1762 continue;
1763 return op->emitOpError()
1764 << "ParamProducerTransformOpTrait attached to this op expects "
1765 "result types to implement TransformParamTypeInterface";
1766 }
1767 return success();
1768}
1769
1770//===----------------------------------------------------------------------===//
1771// Memory effects.
1772//===----------------------------------------------------------------------===//
1773
1774void transform::consumesHandle(
1775 ValueRange handles,
1776 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1777 for (Value handle : handles) {
1778 effects.emplace_back(Args: MemoryEffects::Read::get(), Args&: handle,
1779 Args: TransformMappingResource::get());
1780 effects.emplace_back(Args: MemoryEffects::Free::get(), Args&: handle,
1781 Args: TransformMappingResource::get());
1782 }
1783}
1784
1785/// Returns `true` if the given list of effects instances contains an instance
1786/// with the effect type specified as template parameter.
1787template <typename EffectTy, typename ResourceTy, typename Range>
1788static bool hasEffect(Range &&effects) {
1789 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1790 return isa<EffectTy>(effect.getEffect()) &&
1791 isa<ResourceTy>(effect.getResource());
1792 });
1793}
1794
1795bool transform::isHandleConsumed(Value handle,
1796 transform::TransformOpInterface transform) {
1797 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1798 SmallVector<MemoryEffects::EffectInstance> effects;
1799 iface.getEffectsOnValue(handle, effects);
1800 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1801 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1802}
1803
1804void transform::producesHandle(
1805 ValueRange handles,
1806 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1807 for (Value handle : handles) {
1808 effects.emplace_back(Args: MemoryEffects::Allocate::get(), Args&: handle,
1809 Args: TransformMappingResource::get());
1810 effects.emplace_back(Args: MemoryEffects::Write::get(), Args&: handle,
1811 Args: TransformMappingResource::get());
1812 }
1813}
1814
1815void transform::onlyReadsHandle(
1816 ValueRange handles,
1817 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1818 for (Value handle : handles) {
1819 effects.emplace_back(Args: MemoryEffects::Read::get(), Args&: handle,
1820 Args: TransformMappingResource::get());
1821 }
1822}
1823
1824void transform::modifiesPayload(
1825 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1826 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: PayloadIRResource::get());
1827 effects.emplace_back(Args: MemoryEffects::Write::get(), Args: PayloadIRResource::get());
1828}
1829
1830void transform::onlyReadsPayload(
1831 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1832 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: PayloadIRResource::get());
1833}
1834
1835bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1836 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1837 SmallVector<MemoryEffects::EffectInstance> effects;
1838 iface.getEffects(effects);
1839 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1840}
1841
1842bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1843 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1844 SmallVector<MemoryEffects::EffectInstance> effects;
1845 iface.getEffects(effects);
1846 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1847}
1848
1849void transform::getConsumedBlockArguments(
1850 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1851 SmallVector<MemoryEffects::EffectInstance> effects;
1852 for (Operation &nested : block) {
1853 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1854 if (!iface)
1855 continue;
1856
1857 effects.clear();
1858 iface.getEffects(effects);
1859 for (const MemoryEffects::EffectInstance &effect : effects) {
1860 BlockArgument argument =
1861 dyn_cast_or_null<BlockArgument>(Val: effect.getValue());
1862 if (!argument || argument.getOwner() != &block ||
1863 !isa<MemoryEffects::Free>(Val: effect.getEffect()) ||
1864 effect.getResource() != transform::TransformMappingResource::get()) {
1865 continue;
1866 }
1867 consumedArguments.insert(V: argument.getArgNumber());
1868 }
1869 }
1870}
1871
1872//===----------------------------------------------------------------------===//
1873// Utilities for TransformOpInterface.
1874//===----------------------------------------------------------------------===//
1875
1876SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
1877 TransformOpInterface transformOp) {
1878 SmallVector<OpOperand *> consumedOperands;
1879 consumedOperands.reserve(N: transformOp->getNumOperands());
1880 auto memEffectInterface =
1881 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1882 SmallVector<MemoryEffects::EffectInstance, 2> effects;
1883 for (OpOperand &target : transformOp->getOpOperands()) {
1884 effects.clear();
1885 memEffectInterface.getEffectsOnValue(target.get(), effects);
1886 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1887 return isa<transform::TransformMappingResource>(
1888 effect.getResource()) &&
1889 isa<MemoryEffects::Free>(effect.getEffect());
1890 })) {
1891 consumedOperands.push_back(&target);
1892 }
1893 }
1894 return consumedOperands;
1895}
1896
1897LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
1898 auto iface = cast<MemoryEffectOpInterface>(op);
1899 SmallVector<MemoryEffects::EffectInstance> effects;
1900 iface.getEffects(effects);
1901
1902 auto effectsOn = [&](Value value) {
1903 return llvm::make_filter_range(
1904 Range&: effects, Pred: [value](const MemoryEffects::EffectInstance &instance) {
1905 return instance.getValue() == value;
1906 });
1907 };
1908
1909 std::optional<unsigned> firstConsumedOperand;
1910 for (OpOperand &operand : op->getOpOperands()) {
1911 auto range = effectsOn(operand.get());
1912 if (range.empty()) {
1913 InFlightDiagnostic diag =
1914 op->emitError() << "TransformOpInterface requires memory effects "
1915 "on operands to be specified";
1916 diag.attachNote() << "no effects specified for operand #"
1917 << operand.getOperandNumber();
1918 return diag;
1919 }
1920 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(effects&: range)) {
1921 InFlightDiagnostic diag = op->emitError()
1922 << "TransformOpInterface did not expect "
1923 "'allocate' memory effect on an operand";
1924 diag.attachNote() << "specified for operand #"
1925 << operand.getOperandNumber();
1926 return diag;
1927 }
1928 if (!firstConsumedOperand &&
1929 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects&: range)) {
1930 firstConsumedOperand = operand.getOperandNumber();
1931 }
1932 }
1933
1934 if (firstConsumedOperand &&
1935 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1936 InFlightDiagnostic diag =
1937 op->emitError()
1938 << "TransformOpInterface expects ops consuming operands to have a "
1939 "'write' effect on the payload resource";
1940 diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1941 return diag;
1942 }
1943
1944 for (OpResult result : op->getResults()) {
1945 auto range = effectsOn(result);
1946 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1947 effects&: range)) {
1948 InFlightDiagnostic diag =
1949 op->emitError() << "TransformOpInterface requires 'allocate' memory "
1950 "effect to be specified for results";
1951 diag.attachNote() << "no 'allocate' effect specified for result #"
1952 << result.getResultNumber();
1953 return diag;
1954 }
1955 }
1956
1957 return success();
1958}
1959
1960//===----------------------------------------------------------------------===//
1961// Entry point.
1962//===----------------------------------------------------------------------===//
1963
1964LogicalResult transform::applyTransforms(
1965 Operation *payloadRoot, TransformOpInterface transform,
1966 const RaggedArray<MappedValue> &extraMapping,
1967 const TransformOptions &options, bool enforceToplevelTransformOp) {
1968 if (enforceToplevelTransformOp) {
1969 if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
1970 transform->getNumOperands() != 0) {
1971 return transform->emitError()
1972 << "expected transform to start at the top-level transform op";
1973 }
1974 } else if (failed(
1975 detail::verifyPossibleTopLevelTransformOpTrait(op: transform))) {
1976 return failure();
1977 }
1978
1979 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
1980 options);
1981 return state.applyTransform(transform: transform).checkAndReport();
1982}
1983
1984//===----------------------------------------------------------------------===//
1985// Generated interface implementation.
1986//===----------------------------------------------------------------------===//
1987
1988#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
1989#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
1990

source code of mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp