1//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
10
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/Operation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Interfaces/CastInterfaces.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/ScopeExit.h"
18#include "llvm/ADT/iterator.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/ErrorHandling.h"
21#include "llvm/Support/InterleavedRange.h"
22
23#define DEBUG_TYPE "transform-dialect"
24#define DEBUG_TYPE_FULL "transform-dialect-full"
25#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
26#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
27#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
28#define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
29
30using namespace mlir;
31
32//===----------------------------------------------------------------------===//
33// Helper functions
34//===----------------------------------------------------------------------===//
35
36/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
37/// properly dominates `b` and `b` is not inside `a`.
38static bool happensBefore(Operation *a, Operation *b) {
39 do {
40 if (a->isProperAncestor(other: b))
41 return false;
42 if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(op&: *b)) {
43 return a->isBeforeInBlock(other: bAncestor);
44 }
45 } while ((a = a->getParentOp()));
46 return false;
47}
48
49//===----------------------------------------------------------------------===//
50// TransformState
51//===----------------------------------------------------------------------===//
52
53constexpr const Value transform::TransformState::kTopLevelValue;
54
55transform::TransformState::TransformState(
56 Region *region, Operation *payloadRoot,
57 const RaggedArray<MappedValue> &extraMappings,
58 const TransformOptions &options)
59 : topLevel(payloadRoot), options(options) {
60 topLevelMappedValues.reserve(size: extraMappings.size());
61 for (ArrayRef<MappedValue> mapping : extraMappings)
62 topLevelMappedValues.push_back(elements&: mapping);
63 if (region) {
64 RegionScope *scope = new RegionScope(*this, *region);
65 topLevelRegionScope.reset(p: scope);
66 }
67}
68
69Operation *transform::TransformState::getTopLevel() const { return topLevel; }
70
71ArrayRef<Operation *>
72transform::TransformState::getPayloadOpsView(Value value) const {
73 const TransformOpMapping &operationMapping = getMapping(value).direct;
74 auto iter = operationMapping.find(Val: value);
75 assert(iter != operationMapping.end() &&
76 "cannot find mapping for payload handle (param/value handle "
77 "provided?)");
78 return iter->getSecond();
79}
80
81ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
82 const ParamMapping &mapping = getMapping(value).params;
83 auto iter = mapping.find(Val: value);
84 assert(iter != mapping.end() && "cannot find mapping for param handle "
85 "(operation/value handle provided?)");
86 return iter->getSecond();
87}
88
89ArrayRef<Value>
90transform::TransformState::getPayloadValuesView(Value handleValue) const {
91 const ValueMapping &mapping = getMapping(value: handleValue).values;
92 auto iter = mapping.find(Val: handleValue);
93 assert(iter != mapping.end() && "cannot find mapping for value handle "
94 "(param/operation handle provided?)");
95 return iter->getSecond();
96}
97
98LogicalResult transform::TransformState::getHandlesForPayloadOp(
99 Operation *op, SmallVectorImpl<Value> &handles,
100 bool includeOutOfScope) const {
101 bool found = false;
102 for (const auto &[region, mapping] : llvm::reverse(C: mappings)) {
103 auto iterator = mapping->reverse.find(Val: op);
104 if (iterator != mapping->reverse.end()) {
105 llvm::append_range(C&: handles, R&: iterator->getSecond());
106 found = true;
107 }
108 // Stop looking when reaching a region that is isolated from above.
109 if (!includeOutOfScope &&
110 region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
111 break;
112 }
113
114 return success(IsSuccess: found);
115}
116
117LogicalResult transform::TransformState::getHandlesForPayloadValue(
118 Value payloadValue, SmallVectorImpl<Value> &handles,
119 bool includeOutOfScope) const {
120 bool found = false;
121 for (const auto &[region, mapping] : llvm::reverse(C: mappings)) {
122 auto iterator = mapping->reverseValues.find(Val: payloadValue);
123 if (iterator != mapping->reverseValues.end()) {
124 llvm::append_range(C&: handles, R&: iterator->getSecond());
125 found = true;
126 }
127 // Stop looking when reaching a region that is isolated from above.
128 if (!includeOutOfScope &&
129 region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
130 break;
131 }
132
133 return success(IsSuccess: found);
134}
135
136/// Given a list of MappedValues, cast them to the value kind implied by the
137/// interface of the handle type, and dispatch to one of the callbacks.
138static DiagnosedSilenceableFailure dispatchMappedValues(
139 Value handle, ArrayRef<transform::MappedValue> values,
140 function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
141 function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
142 function_ref<LogicalResult(ValueRange)> valuesFn) {
143 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
144 SmallVector<Operation *> operations;
145 operations.reserve(N: values.size());
146 for (transform::MappedValue value : values) {
147 if (auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: value)) {
148 operations.push_back(Elt: op);
149 continue;
150 }
151 return emitSilenceableFailure(loc: handle.getLoc())
152 << "wrong kind of value provided for top-level operation handle";
153 }
154 if (failed(Result: operationsFn(operations)))
155 return DiagnosedSilenceableFailure::definiteFailure();
156 return DiagnosedSilenceableFailure::success();
157 }
158
159 if (llvm::isa<transform::TransformValueHandleTypeInterface>(
160 handle.getType())) {
161 SmallVector<Value> payloadValues;
162 payloadValues.reserve(N: values.size());
163 for (transform::MappedValue value : values) {
164 if (auto v = llvm::dyn_cast_if_present<Value>(Val&: value)) {
165 payloadValues.push_back(Elt: v);
166 continue;
167 }
168 return emitSilenceableFailure(loc: handle.getLoc())
169 << "wrong kind of value provided for the top-level value handle";
170 }
171 if (failed(Result: valuesFn(payloadValues)))
172 return DiagnosedSilenceableFailure::definiteFailure();
173 return DiagnosedSilenceableFailure::success();
174 }
175
176 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
177 "unsupported kind of block argument");
178 SmallVector<transform::Param> parameters;
179 parameters.reserve(N: values.size());
180 for (transform::MappedValue value : values) {
181 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: value)) {
182 parameters.push_back(Elt: attr);
183 continue;
184 }
185 return emitSilenceableFailure(loc: handle.getLoc())
186 << "wrong kind of value provided for top-level parameter";
187 }
188 if (failed(Result: paramsFn(parameters)))
189 return DiagnosedSilenceableFailure::definiteFailure();
190 return DiagnosedSilenceableFailure::success();
191}
192
193LogicalResult
194transform::TransformState::mapBlockArgument(BlockArgument argument,
195 ArrayRef<MappedValue> values) {
196 return dispatchMappedValues(
197 handle: argument, values,
198 operationsFn: [&](ArrayRef<Operation *> operations) {
199 return setPayloadOps(value: argument, targets: operations);
200 },
201 paramsFn: [&](ArrayRef<Param> params) {
202 return setParams(value: argument, params);
203 },
204 valuesFn: [&](ValueRange payloadValues) {
205 return setPayloadValues(handle: argument, payloadValues);
206 })
207 .checkAndReport();
208}
209
210LogicalResult transform::TransformState::mapBlockArguments(
211 Block::BlockArgListType arguments,
212 ArrayRef<SmallVector<MappedValue>> mapping) {
213 for (auto &&[argument, values] : llvm::zip_equal(t&: arguments, u&: mapping))
214 if (failed(Result: mapBlockArgument(argument, values)))
215 return failure();
216 return success();
217}
218
219LogicalResult
220transform::TransformState::setPayloadOps(Value value,
221 ArrayRef<Operation *> targets) {
222 assert(value != kTopLevelValue &&
223 "attempting to reset the transformation root");
224 assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
225 "wrong handle type");
226
227 for (Operation *target : targets) {
228 if (target)
229 continue;
230 return emitError(loc: value.getLoc())
231 << "attempting to assign a null payload op to this transform value";
232 }
233
234 auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
235 DiagnosedSilenceableFailure result =
236 iface.checkPayload(value.getLoc(), targets);
237 if (failed(Result: result.checkAndReport()))
238 return failure();
239
240 // Setting new payload for the value without cleaning it first is a misuse of
241 // the API, assert here.
242 SmallVector<Operation *> storedTargets(targets);
243 Mappings &mappings = getMapping(value);
244 bool inserted =
245 mappings.direct.insert(KV: {value, std::move(storedTargets)}).second;
246 assert(inserted && "value is already associated with another list");
247 (void)inserted;
248
249 for (Operation *op : targets)
250 mappings.reverse[op].push_back(Elt: value);
251
252 return success();
253}
254
255LogicalResult
256transform::TransformState::setPayloadValues(Value handle,
257 ValueRange payloadValues) {
258 assert(handle != nullptr && "attempting to set params for a null value");
259 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
260 "wrong handle type");
261
262 for (Value payload : payloadValues) {
263 if (payload)
264 continue;
265 return emitError(loc: handle.getLoc()) << "attempting to assign a null payload "
266 "value to this transform handle";
267 }
268
269 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
270 SmallVector<Value> payloadValueVector = llvm::to_vector(Range&: payloadValues);
271 DiagnosedSilenceableFailure result =
272 iface.checkPayload(handle.getLoc(), payloadValueVector);
273 if (failed(Result: result.checkAndReport()))
274 return failure();
275
276 Mappings &mappings = getMapping(value: handle);
277 bool inserted =
278 mappings.values.insert(KV: {handle, std::move(payloadValueVector)}).second;
279 assert(
280 inserted &&
281 "value handle is already associated with another list of payload values");
282 (void)inserted;
283
284 for (Value payload : payloadValues)
285 mappings.reverseValues[payload].push_back(Elt: handle);
286
287 return success();
288}
289
290LogicalResult transform::TransformState::setParams(Value value,
291 ArrayRef<Param> params) {
292 assert(value != nullptr && "attempting to set params for a null value");
293
294 for (Attribute attr : params) {
295 if (attr)
296 continue;
297 return emitError(loc: value.getLoc())
298 << "attempting to assign a null parameter to this transform value";
299 }
300
301 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
302 assert(value &&
303 "cannot associate parameter with a value of non-parameter type");
304 DiagnosedSilenceableFailure result =
305 valueType.checkPayload(value.getLoc(), params);
306 if (failed(Result: result.checkAndReport()))
307 return failure();
308
309 Mappings &mappings = getMapping(value);
310 bool inserted =
311 mappings.params.insert(KV: {value, llvm::to_vector(Range&: params)}).second;
312 assert(inserted && "value is already associated with another list of params");
313 (void)inserted;
314 return success();
315}
316
317template <typename Mapping, typename Key, typename Mapped>
318void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
319 auto it = mapping.find(key);
320 if (it == mapping.end())
321 return;
322
323 llvm::erase(it->getSecond(), mapped);
324 if (it->getSecond().empty())
325 mapping.erase(it);
326}
327
328void transform::TransformState::forgetMapping(Value opHandle,
329 ValueRange origOpFlatResults,
330 bool allowOutOfScope) {
331 Mappings &mappings = getMapping(value: opHandle, allowOutOfScope);
332 for (Operation *op : mappings.direct[opHandle])
333 dropMappingEntry(mapping&: mappings.reverse, key: op, mapped: opHandle);
334 mappings.direct.erase(Val: opHandle);
335#if LLVM_ENABLE_ABI_BREAKING_CHECKS
336 // Payload IR is removed from the mapping. This invalidates the respective
337 // iterators.
338 mappings.incrementTimestamp(value: opHandle);
339#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
340
341 for (Value opResult : origOpFlatResults) {
342 SmallVector<Value> resultHandles;
343 (void)getHandlesForPayloadValue(payloadValue: opResult, handles&: resultHandles);
344 for (Value resultHandle : resultHandles) {
345 Mappings &localMappings = getMapping(value: resultHandle);
346 dropMappingEntry(mapping&: localMappings.values, key: resultHandle, mapped: opResult);
347#if LLVM_ENABLE_ABI_BREAKING_CHECKS
348 // Payload IR is removed from the mapping. This invalidates the respective
349 // iterators.
350 mappings.incrementTimestamp(value: resultHandle);
351#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
352 dropMappingEntry(mapping&: localMappings.reverseValues, key: opResult, mapped: resultHandle);
353 }
354 }
355}
356
357void transform::TransformState::forgetValueMapping(
358 Value valueHandle, ArrayRef<Operation *> payloadOperations) {
359 Mappings &mappings = getMapping(value: valueHandle);
360 for (Value payloadValue : mappings.reverseValues[valueHandle])
361 dropMappingEntry(mapping&: mappings.reverseValues, key: payloadValue, mapped: valueHandle);
362 mappings.values.erase(Val: valueHandle);
363#if LLVM_ENABLE_ABI_BREAKING_CHECKS
364 // Payload IR is removed from the mapping. This invalidates the respective
365 // iterators.
366 mappings.incrementTimestamp(value: valueHandle);
367#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
368
369 for (Operation *payloadOp : payloadOperations) {
370 SmallVector<Value> opHandles;
371 (void)getHandlesForPayloadOp(op: payloadOp, handles&: opHandles);
372 for (Value opHandle : opHandles) {
373 Mappings &localMappings = getMapping(value: opHandle);
374 dropMappingEntry(mapping&: localMappings.direct, key: opHandle, mapped: payloadOp);
375 dropMappingEntry(mapping&: localMappings.reverse, key: payloadOp, mapped: opHandle);
376
377#if LLVM_ENABLE_ABI_BREAKING_CHECKS
378 // Payload IR is removed from the mapping. This invalidates the respective
379 // iterators.
380 localMappings.incrementTimestamp(value: opHandle);
381#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
382 }
383 }
384}
385
386LogicalResult
387transform::TransformState::replacePayloadOp(Operation *op,
388 Operation *replacement) {
389 // TODO: consider invalidating the handles to nested objects here.
390
391#ifndef NDEBUG
392 for (Value opResult : op->getResults()) {
393 SmallVector<Value> valueHandles;
394 (void)getHandlesForPayloadValue(payloadValue: opResult, handles&: valueHandles,
395 /*includeOutOfScope=*/true);
396 assert(valueHandles.empty() && "expected no mapping to old results");
397 }
398#endif // NDEBUG
399
400 // Drop the mapping between the op and all handles that point to it. Fail if
401 // there are no handles.
402 SmallVector<Value> opHandles;
403 if (failed(Result: getHandlesForPayloadOp(op, handles&: opHandles, /*includeOutOfScope=*/true)))
404 return failure();
405 for (Value handle : opHandles) {
406 Mappings &mappings = getMapping(value: handle, /*allowOutOfScope=*/true);
407 dropMappingEntry(mapping&: mappings.reverse, key: op, mapped: handle);
408 }
409
410 // Replace the pointed-to object of all handles with the replacement object.
411 // In case a payload op was erased (replacement object is nullptr), a nullptr
412 // is stored in the mapping. These nullptrs are removed after each transform.
413 // Furthermore, nullptrs are not enumerated by payload op iterators. The
414 // relative order of ops is preserved.
415 //
416 // Removing an op from the mapping would be problematic because removing an
417 // element from an array invalidates iterators; merely changing the value of
418 // elements does not.
419 for (Value handle : opHandles) {
420 Mappings &mappings = getMapping(value: handle, /*allowOutOfScope=*/true);
421 auto it = mappings.direct.find(Val: handle);
422 if (it == mappings.direct.end())
423 continue;
424
425 SmallVector<Operation *, 2> &association = it->getSecond();
426 // Note that an operation may be associated with the handle more than once.
427 for (Operation *&mapped : association) {
428 if (mapped == op)
429 mapped = replacement;
430 }
431
432 if (replacement) {
433 mappings.reverse[replacement].push_back(Elt: handle);
434 } else {
435 opHandlesToCompact.insert(V: handle);
436 }
437 }
438
439 return success();
440}
441
442LogicalResult
443transform::TransformState::replacePayloadValue(Value value, Value replacement) {
444 SmallVector<Value> valueHandles;
445 if (failed(Result: getHandlesForPayloadValue(payloadValue: value, handles&: valueHandles,
446 /*includeOutOfScope=*/true)))
447 return failure();
448
449 for (Value handle : valueHandles) {
450 Mappings &mappings = getMapping(value: handle, /*allowOutOfScope=*/true);
451 dropMappingEntry(mapping&: mappings.reverseValues, key: value, mapped: handle);
452
453 // If replacing with null, that is erasing the mapping, drop the mapping
454 // between the handles and the IR objects
455 if (!replacement) {
456 dropMappingEntry(mapping&: mappings.values, key: handle, mapped: value);
457#if LLVM_ENABLE_ABI_BREAKING_CHECKS
458 // Payload IR is removed from the mapping. This invalidates the respective
459 // iterators.
460 mappings.incrementTimestamp(value: handle);
461#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
462 } else {
463 auto it = mappings.values.find(Val: handle);
464 if (it == mappings.values.end())
465 continue;
466
467 SmallVector<Value> &association = it->getSecond();
468 for (Value &mapped : association) {
469 if (mapped == value)
470 mapped = replacement;
471 }
472 mappings.reverseValues[replacement].push_back(Elt: handle);
473 }
474 }
475
476 return success();
477}
478
479void transform::TransformState::recordOpHandleInvalidationOne(
480 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
481 Operation *payloadOp, Value otherHandle, Value throughValue,
482 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
483 // If the op is associated with invalidated handle, skip the check as it
484 // may be reading invalid IR. This also ensures we report the first
485 // invalidation and not the last one.
486 if (invalidatedHandles.count(Val: otherHandle) ||
487 newlyInvalidated.count(Val: otherHandle))
488 return;
489
490 FULL_LDBG("--recordOpHandleInvalidationOne\n");
491 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
492 (DBGS() << "--ancestors: "
493 << llvm::interleaved(llvm::make_pointee_range(potentialAncestors))
494 << "\n");
495 });
496
497 Operation *owner = consumingHandle.getOwner();
498 unsigned operandNo = consumingHandle.getOperandNumber();
499 for (Operation *ancestor : potentialAncestors) {
500 // clang-format off
501 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
502 { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
503 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
504 { (DBGS() << "----of payload with name: "
505 << payloadOp->getName().getIdentifier() << "\n"); });
506 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
507 { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
508 // clang-format on
509 if (!ancestor->isAncestor(other: payloadOp))
510 continue;
511
512 // Make sure the error-reporting lambda doesn't capture anything
513 // by-reference because it will go out of scope. Additionally, extract
514 // location from Payload IR ops because the ops themselves may be
515 // deleted before the lambda gets called.
516 Location ancestorLoc = ancestor->getLoc();
517 Location opLoc = payloadOp->getLoc();
518 std::optional<Location> throughValueLoc =
519 throughValue ? std::make_optional(t: throughValue.getLoc()) : std::nullopt;
520 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
521 otherHandle,
522 throughValueLoc](Location currentLoc) {
523 InFlightDiagnostic diag = emitError(loc: currentLoc)
524 << "op uses a handle invalidated by a "
525 "previously executed transform op";
526 diag.attachNote(noteLoc: otherHandle.getLoc()) << "handle to invalidated ops";
527 diag.attachNote(noteLoc: owner->getLoc())
528 << "invalidated by this transform op that consumes its operand #"
529 << operandNo
530 << " and invalidates all handles to payload IR entities associated "
531 "with this operand and entities nested in them";
532 diag.attachNote(noteLoc: ancestorLoc) << "ancestor payload op";
533 diag.attachNote(noteLoc: opLoc) << "nested payload op";
534 if (throughValueLoc) {
535 diag.attachNote(noteLoc: *throughValueLoc)
536 << "consumed handle points to this payload value";
537 }
538 };
539 }
540}
541
542void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
543 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
544 Value payloadValue, Value valueHandle,
545 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
546 // If the op is associated with invalidated handle, skip the check as it
547 // may be reading invalid IR. This also ensures we report the first
548 // invalidation and not the last one.
549 if (invalidatedHandles.count(Val: valueHandle) ||
550 newlyInvalidated.count(Val: valueHandle))
551 return;
552
553 for (Operation *ancestor : potentialAncestors) {
554 Operation *definingOp;
555 std::optional<unsigned> resultNo;
556 unsigned argumentNo = std::numeric_limits<unsigned>::max();
557 unsigned blockNo = std::numeric_limits<unsigned>::max();
558 unsigned regionNo = std::numeric_limits<unsigned>::max();
559 if (auto opResult = llvm::dyn_cast<OpResult>(Val&: payloadValue)) {
560 definingOp = opResult.getOwner();
561 resultNo = opResult.getResultNumber();
562 } else {
563 auto arg = llvm::cast<BlockArgument>(Val&: payloadValue);
564 definingOp = arg.getParentBlock()->getParentOp();
565 argumentNo = arg.getArgNumber();
566 blockNo = std::distance(first: arg.getOwner()->getParent()->begin(),
567 last: arg.getOwner()->getIterator());
568 regionNo = arg.getOwner()->getParent()->getRegionNumber();
569 }
570 assert(definingOp && "expected the value to be defined by an op as result "
571 "or block argument");
572 if (!ancestor->isAncestor(other: definingOp))
573 continue;
574
575 Operation *owner = opHandle.getOwner();
576 unsigned operandNo = opHandle.getOperandNumber();
577 Location ancestorLoc = ancestor->getLoc();
578 Location opLoc = definingOp->getLoc();
579 Location valueLoc = payloadValue.getLoc();
580 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
581 argumentNo, blockNo, regionNo, ancestorLoc,
582 opLoc, valueLoc](Location currentLoc) {
583 InFlightDiagnostic diag = emitError(loc: currentLoc)
584 << "op uses a handle invalidated by a "
585 "previously executed transform op";
586 diag.attachNote(noteLoc: valueHandle.getLoc()) << "invalidated handle";
587 diag.attachNote(noteLoc: owner->getLoc())
588 << "invalidated by this transform op that consumes its operand #"
589 << operandNo
590 << " and invalidates all handles to payload IR entities "
591 "associated with this operand and entities nested in them";
592 diag.attachNote(noteLoc: ancestorLoc)
593 << "ancestor op associated with the consumed handle";
594 if (resultNo) {
595 diag.attachNote(noteLoc: opLoc)
596 << "op defining the value as result #" << *resultNo;
597 } else {
598 diag.attachNote(noteLoc: opLoc)
599 << "op defining the value as block argument #" << argumentNo
600 << " of block #" << blockNo << " in region #" << regionNo;
601 }
602 diag.attachNote(noteLoc: valueLoc) << "payload value";
603 };
604 }
605}
606
607void transform::TransformState::recordOpHandleInvalidation(
608 OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
609 Value throughValue,
610 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
611
612 if (potentialAncestors.empty()) {
613 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
614 (DBGS() << "----recording invalidation for empty handle: " << handle.get()
615 << "\n");
616 });
617
618 Operation *owner = handle.getOwner();
619 unsigned operandNo = handle.getOperandNumber();
620 newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
621 InFlightDiagnostic diag = emitError(loc: currentLoc)
622 << "op uses a handle associated with empty "
623 "payload and invalidated by a "
624 "previously executed transform op";
625 diag.attachNote(noteLoc: owner->getLoc())
626 << "invalidated by this transform op that consumes its operand #"
627 << operandNo;
628 };
629 return;
630 }
631
632 // Iterate over the mapping and invalidate aliasing handles. This is quite
633 // expensive and only necessary for error reporting in case of transform
634 // dialect misuse with dangling handles. Iteration over the handles is based
635 // on the assumption that the number of handles is significantly less than the
636 // number of IR objects (operations and values). Alternatively, we could walk
637 // the IR nested in each payload op associated with the given handle and look
638 // for handles associated with each operation and value.
639 for (const auto &[region, mapping] : llvm::reverse(C: mappings)) {
640 // Go over all op handle mappings and mark as invalidated any handle
641 // pointing to any of the payload ops associated with the given handle or
642 // any op nested in them.
643 for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
644 for (Value otherHandle : otherHandles)
645 recordOpHandleInvalidationOne(consumingHandle&: handle, potentialAncestors, payloadOp,
646 otherHandle, throughValue,
647 newlyInvalidated);
648 }
649 // Go over all value handle mappings and mark as invalidated any handle
650 // pointing to any result of the payload op associated with the given handle
651 // or any op nested in them. Similarly invalidate handles to argument of
652 // blocks belonging to any region of any payload op associated with the
653 // given handle or any op nested in them.
654 for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
655 for (Value valueHandle : valueHandles)
656 recordValueHandleInvalidationByOpHandleOne(opHandle&: handle, potentialAncestors,
657 payloadValue, valueHandle,
658 newlyInvalidated);
659 }
660
661 // Stop lookup when reaching a region that is isolated from above.
662 if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
663 break;
664 }
665}
666
667void transform::TransformState::recordValueHandleInvalidation(
668 OpOperand &valueHandle,
669 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
670 // Invalidate other handles to the same value.
671 for (Value payloadValue : getPayloadValuesView(handleValue: valueHandle.get())) {
672 SmallVector<Value> otherValueHandles;
673 (void)getHandlesForPayloadValue(payloadValue, handles&: otherValueHandles);
674 for (Value otherHandle : otherValueHandles) {
675 Operation *owner = valueHandle.getOwner();
676 unsigned operandNo = valueHandle.getOperandNumber();
677 Location valueLoc = payloadValue.getLoc();
678 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
679 valueLoc](Location currentLoc) {
680 InFlightDiagnostic diag = emitError(loc: currentLoc)
681 << "op uses a handle invalidated by a "
682 "previously executed transform op";
683 diag.attachNote(noteLoc: otherHandle.getLoc()) << "invalidated handle";
684 diag.attachNote(noteLoc: owner->getLoc())
685 << "invalidated by this transform op that consumes its operand #"
686 << operandNo
687 << " and invalidates handles to the same values as associated with "
688 "it";
689 diag.attachNote(noteLoc: valueLoc) << "payload value";
690 };
691 }
692
693 if (auto opResult = llvm::dyn_cast<OpResult>(Val&: payloadValue)) {
694 Operation *payloadOp = opResult.getOwner();
695 recordOpHandleInvalidation(handle&: valueHandle, potentialAncestors: payloadOp, throughValue: payloadValue,
696 newlyInvalidated);
697 } else {
698 auto arg = llvm::dyn_cast<BlockArgument>(Val&: payloadValue);
699 for (Operation &payloadOp : *arg.getOwner())
700 recordOpHandleInvalidation(handle&: valueHandle, potentialAncestors: &payloadOp, throughValue: payloadValue,
701 newlyInvalidated);
702 }
703 }
704}
705
706/// Checks that the operation does not use invalidated handles as operands.
707/// Reports errors and returns failure if it does. Otherwise, invalidates the
708/// handles consumed by the operation as well as any handles pointing to payload
709/// IR operations nested in the operations associated with the consumed handles.
710LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
711 transform::TransformOpInterface transform,
712 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
713 FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
714 auto memoryEffectsIface =
715 cast<MemoryEffectOpInterface>(transform.getOperation());
716 SmallVector<MemoryEffects::EffectInstance> effects;
717 memoryEffectsIface.getEffectsOnResource(
718 transform::TransformMappingResource::get(), effects);
719
720 for (OpOperand &target : transform->getOpOperands()) {
721 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
722 (DBGS() << "----iterate on handle: " << target.get() << "\n");
723 });
724 // If the operand uses an invalidated handle, report it. If the operation
725 // allows handles to point to repeated payload operations, only report
726 // pre-existing invalidation errors. Otherwise, also report invalidations
727 // caused by the current transform operation affecting its other operands.
728 auto it = invalidatedHandles.find(target.get());
729 auto nit = newlyInvalidated.find(target.get());
730 if (it != invalidatedHandles.end()) {
731 FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
732 "invalidated -> FAILURE\n");
733 return it->getSecond()(transform->getLoc()), failure();
734 }
735 if (!transform.allowsRepeatedHandleOperands() &&
736 nit != newlyInvalidated.end()) {
737 FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
738 "invalidated (by this op) -> FAILURE\n");
739 return nit->getSecond()(transform->getLoc()), failure();
740 }
741
742 // Invalidate handles pointing to the operations nested in the operation
743 // associated with the handle consumed by this operation.
744 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
745 return isa<MemoryEffects::Free>(effect.getEffect()) &&
746 effect.getValue() == target.get();
747 };
748 if (llvm::any_of(effects, consumesTarget)) {
749 FULL_LDBG("----found consume effect\n");
750 if (llvm::isa<transform::TransformHandleTypeInterface>(
751 target.get().getType())) {
752 FULL_LDBG("----recordOpHandleInvalidation\n");
753 SmallVector<Operation *> payloadOps =
754 llvm::to_vector(getPayloadOps(target.get()));
755 recordOpHandleInvalidation(target, payloadOps, nullptr,
756 newlyInvalidated);
757 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
758 target.get().getType())) {
759 FULL_LDBG("----recordValueHandleInvalidation\n");
760 recordValueHandleInvalidation(target, newlyInvalidated);
761 } else {
762 FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
763 }
764 } else {
765 FULL_LDBG("----no consume effect -> SKIP\n");
766 }
767 }
768
769 FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
770 return success();
771}
772
773LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
774 transform::TransformOpInterface transform) {
775 InvalidatedHandleMap newlyInvalidated;
776 LogicalResult checkResult =
777 checkAndRecordHandleInvalidationImpl(transform: transform, newlyInvalidated);
778 invalidatedHandles.insert(I: std::make_move_iterator(i: newlyInvalidated.begin()),
779 E: std::make_move_iterator(i: newlyInvalidated.end()));
780 return checkResult;
781}
782
783template <typename T>
784DiagnosedSilenceableFailure
785checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
786 transform::TransformOpInterface transform,
787 unsigned operandNumber) {
788 DenseSet<T> seen;
789 for (T p : payload) {
790 if (!seen.insert(p).second) {
791 DiagnosedSilenceableFailure diag =
792 transform.emitSilenceableError()
793 << "a handle passed as operand #" << operandNumber
794 << " and consumed by this operation points to a payload "
795 "entity more than once";
796 if constexpr (std::is_pointer_v<T>)
797 diag.attachNote(loc: p->getLoc()) << "repeated target op";
798 else
799 diag.attachNote(loc: p.getLoc()) << "repeated target value";
800 return diag;
801 }
802 }
803 return DiagnosedSilenceableFailure::success();
804}
805
806void transform::TransformState::compactOpHandles() {
807 for (Value handle : opHandlesToCompact) {
808 Mappings &mappings = getMapping(value: handle, /*allowOutOfScope=*/true);
809#if LLVM_ENABLE_ABI_BREAKING_CHECKS
810 if (llvm::is_contained(Range&: mappings.direct[handle], Element: nullptr))
811 // Payload IR is removed from the mapping. This invalidates the respective
812 // iterators.
813 mappings.incrementTimestamp(value: handle);
814#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
815 llvm::erase(C&: mappings.direct[handle], V: nullptr);
816 }
817 opHandlesToCompact.clear();
818}
819
/// Applies `transform` to the payload and updates the handle mappings.
///
/// Steps, in order:
///  1. With expensive checks enabled: reports uses of invalidated handles and,
///     unless the op allows repeated handle operands, consumed handles whose
///     payload contains the same op/value more than once.
///  2. For every consumed operand, snapshots the results of its payload ops
///     (op handles) or the ops associated with its payload values (value
///     handles) before the transform can mutate or erase them. Consuming a
///     value that is not a handle is a definite failure.
///  3. Runs the transform through a TransformRewriter whose tracking listener
///     skips handles whose users are all the current transform op or happen
///     before it.
///  4. Merges tracking-listener failures into the result, but only for ops
///     with ReportTrackingListenerFailuresOpTrait that do not carry the
///     silencing attribute.
///  5. Returns definite failures immediately; on a silenceable failure, fills
///     any unset results with empty lists.
///  6. Forgets the mappings of consumed operands and records the op's results.
/// Returns the transform's result (success or silenceable failure), or a
/// definite failure from any of the steps above.
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
  LLVM_DEBUG({
    DBGS() << "applying: ";
    transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
    llvm::dbgs() << "\n";
  });
  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
                  DBGS() << "Top-level payload before application:\n"
                         << *getTopLevel() << "\n");
  // Prints the payload when leaving through any early return; released just
  // before the final return below.
  auto printOnFailureRAII = llvm::make_scope_exit(F: [this] {
    // Keeps `this` referenced when LLVM_DEBUG expands to nothing.
    (void)this;
    LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
        llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
  });

  // Set current transform op.
  regionStack.back()->currentTransform = transform;

  // Expensive checks to detect invalid transform IR.
  if (options.getExpensiveChecksEnabled()) {
    FULL_LDBG("ExpensiveChecksEnabled\n");
    if (failed(checkAndRecordHandleInvalidation(transform: transform)))
      return DiagnosedSilenceableFailure::definiteFailure();

    // A consumed handle must not point to the same payload entity twice.
    for (OpOperand &operand : transform->getOpOperands()) {
      DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
        (DBGS() << "iterate on handle: " << operand.get() << "\n");
      });
      if (!isHandleConsumed(operand.get(), transform)) {
        FULL_LDBG("--handle not consumed -> SKIP\n");
        continue;
      }
      if (transform.allowsRepeatedHandleOperands()) {
        FULL_LDBG("--op allows repeated handles -> SKIP\n");
        continue;
      }
      FULL_LDBG("--handle is consumed\n");

      Type operandType = operand.get().getType();
      if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
        DiagnosedSilenceableFailure check =
            checkRepeatedConsumptionInOperand<Operation *>(
                getPayloadOpsView(operand.get()), transform,
                operand.getOperandNumber());
        if (!check.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return check;
        }
      } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
        DiagnosedSilenceableFailure check =
            checkRepeatedConsumptionInOperand<Value>(
                getPayloadValuesView(operand.get()), transform,
                operand.getOperandNumber());
        if (!check.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return check;
        }
      } else {
        FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
      }
    }
  }

  // Find which operands are consumed.
  SmallVector<OpOperand *> consumedOperands =
      transform.getConsumedHandleOpOperands();

  // Remember the results of the payload ops associated with the consumed
  // op handles or the ops defining the value handles so we can drop the
  // association with them later. This must happen here because the
  // transformation may destroy or mutate them so we cannot traverse the payload
  // IR after that.
  SmallVector<Value> origOpFlatResults;
  SmallVector<Operation *> origAssociatedOps;
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      for (Operation *payloadOp : getPayloadOps(operand)) {
        llvm::append_range(origOpFlatResults, payloadOp->getResults());
      }
      continue;
    }
    if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
      for (Value payloadValue : getPayloadValuesView(operand)) {
        // An op result is associated with its defining op; a block argument
        // with every op in its owner block.
        if (llvm::isa<OpResult>(payloadValue)) {
          origAssociatedOps.push_back(payloadValue.getDefiningOp());
          continue;
        }
        llvm::append_range(
            origAssociatedOps,
            llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
                            [](Operation &op) { return &op; }));
      }
      continue;
    }
    DiagnosedDefiniteFailure diag =
        emitDefiniteFailure(transform->getLoc())
        << "unexpectedly consumed a value that is not a handle as operand #"
        << opOperand->getOperandNumber();
    diag.attachNote(operand.getLoc())
        << "value defined here with type " << operand.getType();
    return diag;
  }

  // Prepare rewriter and listener.
  TrackingListenerConfig config;
  config.skipHandleFn = [&](Value handle) {
    // Skip handle if it is dead.
    auto scopeIt =
        llvm::find_if(Range: llvm::reverse(C&: regionStack), P: [&](RegionScope *scope) {
          return handle.getParentRegion() == scope->region;
        });
    assert(scopeIt != regionStack.rend() &&
           "could not find region scope for handle");
    RegionScope *scope = *scopeIt;
    // Dead means: no user of the handle runs after the scope's current
    // transform op.
    return llvm::all_of(Range: handle.getUsers(), P: [&](Operation *user) {
      return user == scope->currentTransform ||
             happensBefore(user, scope->currentTransform);
    });
  };
  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
                                                            config);
  transform::TransformRewriter rewriter(transform->getContext(),
                                        &trackingListener);

  // Compute the result but do not short-circuit the silenceable failure case as
  // we still want the handles to propagate properly so the "suppress" mode can
  // proceed on a best effort basis.
  transform::TransformResults results(transform->getNumResults());
  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
  compactOpHandles();

  // Error handling: fail if transform or listener failed.
  DiagnosedSilenceableFailure trackingFailure =
      trackingListener.checkAndResetError();
  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
      transform->hasAttr(FindPayloadReplacementOpInterface::
                             kSilenceTrackingFailuresAttrName)) {
    // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
    // do not report failures if the above mentioned attribute is set.
    if (trackingFailure.isSilenceableFailure())
      (void)trackingFailure.silence();
    trackingFailure = DiagnosedSilenceableFailure::success();
  }
  if (!trackingFailure.succeeded()) {
    if (result.succeeded()) {
      result = std::move(trackingFailure);
    } else {
      // Transform op errors have precedence, report those first.
      if (result.isSilenceableFailure())
        result.attachNote() << "tracking listener also failed: "
                            << trackingFailure.getMessage();
      (void)trackingFailure.silence();
    }
  }
  if (result.isDefiniteFailure())
    return result;

  // If a silenceable failure was produced, some results may be unset, set them
  // to empty lists.
  if (result.isSilenceableFailure())
    results.setRemainingToEmpty(transform);

  // Remove the mapping for the operand if it is consumed by the operation. This
  // allows us to catch use-after-free with assertions later on.
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      forgetMapping(operand, origOpFlatResults);
    } else if (llvm::isa<TransformValueHandleTypeInterface>(
                   operand.getType())) {
      forgetValueMapping(operand, origAssociatedOps);
    }
  }

  if (failed(updateStateFromResults(results, opResults: transform->getResults())))
    return DiagnosedSilenceableFailure::definiteFailure();

  printOnFailureRAII.release();
  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
    DBGS() << "Top-level payload:\n";
    getTopLevel()->print(llvm::dbgs());
  });
  return result;
}
1008
1009LogicalResult transform::TransformState::updateStateFromResults(
1010 const TransformResults &results, ResultRange opResults) {
1011 for (OpResult result : opResults) {
1012 if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1013 assert(results.isParam(result.getResultNumber()) &&
1014 "expected parameters for the parameter-typed result");
1015 if (failed(
1016 Result: setParams(value: result, params: results.getParams(resultNumber: result.getResultNumber())))) {
1017 return failure();
1018 }
1019 } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1020 assert(results.isValue(result.getResultNumber()) &&
1021 "expected values for value-type-result");
1022 if (failed(Result: setPayloadValues(
1023 handle: result, payloadValues: results.getValues(resultNumber: result.getResultNumber())))) {
1024 return failure();
1025 }
1026 } else {
1027 assert(!results.isParam(result.getResultNumber()) &&
1028 "expected payload ops for the non-parameter typed result");
1029 if (failed(
1030 Result: setPayloadOps(value: result, targets: results.get(resultNumber: result.getResultNumber())))) {
1031 return failure();
1032 }
1033 }
1034 }
1035 return success();
1036}
1037
1038//===----------------------------------------------------------------------===//
1039// TransformState::Extension
1040//===----------------------------------------------------------------------===//
1041
1042transform::TransformState::Extension::~Extension() = default;
1043
1044LogicalResult
1045transform::TransformState::Extension::replacePayloadOp(Operation *op,
1046 Operation *replacement) {
1047 // TODO: we may need to invalidate handles to operations and values nested in
1048 // the operation being replaced.
1049 return state.replacePayloadOp(op, replacement);
1050}
1051
1052LogicalResult
1053transform::TransformState::Extension::replacePayloadValue(Value value,
1054 Value replacement) {
1055 return state.replacePayloadValue(value, replacement);
1056}
1057
1058//===----------------------------------------------------------------------===//
1059// TransformState::RegionScope
1060//===----------------------------------------------------------------------===//
1061
1062transform::TransformState::RegionScope::~RegionScope() {
1063 // Remove handle invalidation notices as handles are going out of scope.
1064 // The same region may be re-entered leading to incorrect invalidation
1065 // errors.
1066 for (Block &block : *region) {
1067 for (Value handle : block.getArguments()) {
1068 state.invalidatedHandles.erase(Val: handle);
1069 }
1070 for (Operation &op : block) {
1071 for (Value handle : op.getResults()) {
1072 state.invalidatedHandles.erase(Val: handle);
1073 }
1074 }
1075 }
1076
1077#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1078 // Remember pointers to payload ops referenced by the handles going out of
1079 // scope.
1080 SmallVector<Operation *> referencedOps =
1081 llvm::to_vector(Range: llvm::make_first_range(c&: state.mappings[region]->reverse));
1082#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1083
1084 state.mappings.erase(Key: region);
1085 state.regionStack.pop_back();
1086}
1087
1088//===----------------------------------------------------------------------===//
1089// TransformResults
1090//===----------------------------------------------------------------------===//
1091
1092transform::TransformResults::TransformResults(unsigned numSegments) {
1093 operations.appendEmptyRows(num: numSegments);
1094 params.appendEmptyRows(num: numSegments);
1095 values.appendEmptyRows(num: numSegments);
1096}
1097
1098void transform::TransformResults::setParams(
1099 OpResult value, ArrayRef<transform::TransformState::Param> params) {
1100 int64_t position = value.getResultNumber();
1101 assert(position < static_cast<int64_t>(this->params.size()) &&
1102 "setting params for a non-existent handle");
1103 assert(this->params[position].data() == nullptr && "params already set");
1104 assert(operations[position].data() == nullptr &&
1105 "another kind of results already set");
1106 assert(values[position].data() == nullptr &&
1107 "another kind of results already set");
1108 this->params.replace(pos: position, elements&: params);
1109}
1110
1111void transform::TransformResults::setMappedValues(
1112 OpResult handle, ArrayRef<MappedValue> values) {
1113 DiagnosedSilenceableFailure diag = dispatchMappedValues(
1114 handle, values,
1115 operationsFn: [&](ArrayRef<Operation *> operations) {
1116 return set(value: handle, ops&: operations), success();
1117 },
1118 paramsFn: [&](ArrayRef<Param> params) {
1119 return setParams(value: handle, params), success();
1120 },
1121 valuesFn: [&](ValueRange payloadValues) {
1122 return setValues(handle, values&: payloadValues), success();
1123 });
1124#ifndef NDEBUG
1125 if (!diag.succeeded())
1126 llvm::dbgs() << diag.getStatusString() << "\n";
1127 assert(diag.succeeded() && "incorrect mapping");
1128#endif // NDEBUG
1129 (void)diag.silence();
1130}
1131
1132void transform::TransformResults::setRemainingToEmpty(
1133 transform::TransformOpInterface transform) {
1134 for (OpResult opResult : transform->getResults()) {
1135 if (!isSet(opResult.getResultNumber()))
1136 setMappedValues(opResult, {});
1137 }
1138}
1139
1140ArrayRef<Operation *>
1141transform::TransformResults::get(unsigned resultNumber) const {
1142 assert(resultNumber < operations.size() &&
1143 "querying results for a non-existent handle");
1144 assert(operations[resultNumber].data() != nullptr &&
1145 "querying unset results (values or params expected?)");
1146 return operations[resultNumber];
1147}
1148
1149ArrayRef<transform::TransformState::Param>
1150transform::TransformResults::getParams(unsigned resultNumber) const {
1151 assert(resultNumber < params.size() &&
1152 "querying params for a non-existent handle");
1153 assert(params[resultNumber].data() != nullptr &&
1154 "querying unset params (ops or values expected?)");
1155 return params[resultNumber];
1156}
1157
1158ArrayRef<Value>
1159transform::TransformResults::getValues(unsigned resultNumber) const {
1160 assert(resultNumber < values.size() &&
1161 "querying values for a non-existent handle");
1162 assert(values[resultNumber].data() != nullptr &&
1163 "querying unset values (ops or params expected?)");
1164 return values[resultNumber];
1165}
1166
1167bool transform::TransformResults::isParam(unsigned resultNumber) const {
1168 assert(resultNumber < params.size() &&
1169 "querying association for a non-existent handle");
1170 return params[resultNumber].data() != nullptr;
1171}
1172
1173bool transform::TransformResults::isValue(unsigned resultNumber) const {
1174 assert(resultNumber < values.size() &&
1175 "querying association for a non-existent handle");
1176 return values[resultNumber].data() != nullptr;
1177}
1178
1179bool transform::TransformResults::isSet(unsigned resultNumber) const {
1180 assert(resultNumber < params.size() &&
1181 "querying association for a non-existent handle");
1182 return params[resultNumber].data() != nullptr ||
1183 operations[resultNumber].data() != nullptr ||
1184 values[resultNumber].data() != nullptr;
1185}
1186
1187//===----------------------------------------------------------------------===//
1188// TrackingListener
1189//===----------------------------------------------------------------------===//
1190
1191transform::TrackingListener::TrackingListener(TransformState &state,
1192 TransformOpInterface op,
1193 TrackingListenerConfig config)
1194 : TransformState::Extension(state), transformOp(op), config(config) {
1195 if (op) {
1196 for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1197 consumedHandles.insert(opOperand->get());
1198 }
1199 }
1200}
1201
1202Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
1203 Operation *defOp = nullptr;
1204 for (Value v : values) {
1205 // Skip empty values.
1206 if (!v)
1207 continue;
1208 if (!defOp) {
1209 defOp = v.getDefiningOp();
1210 continue;
1211 }
1212 if (defOp != v.getDefiningOp())
1213 return nullptr;
1214 }
1215 return defOp;
1216}
1217
/// Looks for a single op that can stand in for `op`, given the values
/// `newValues` it was replaced with, and stores it in `result` on success.
///
/// Starting from the common defining op of the current values, each iteration:
///  - steps through CastOpInterface ops (if `config.skipCastOps`) by
///    continuing with their operands;
///  - accepts the op if name matching is not required or the names match;
///  - accepts constant-like ops;
///  - otherwise continues with the operands returned by
///    FindPayloadReplacementOpInterface::getNextOperands, if implemented.
/// Notes describing each step are attached to `diag`, which is returned as a
/// silenceable failure only when no replacement is found.
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
    Operation *&result, Operation *op, ValueRange newValues) const {
  assert(op->getNumResults() == newValues.size() &&
         "invalid number of replacement values");
  SmallVector<Value> values(newValues.begin(), newValues.end());

  DiagnosedSilenceableFailure diag = emitSilenceableFailure(
      getTransformOp(), "tracking listener failed to find replacement op "
                        "during application of this transform op");

  // Note: `continue` in a do/while re-evaluates the loop condition, so the
  // search stops once `values` becomes empty.
  do {
    // If the replacement values belong to different ops, drop the mapping.
    Operation *defOp = getCommonDefiningOp(values);
    if (!defOp) {
      diag.attachNote() << "replacement values belong to different ops";
      return diag;
    }

    // Skip through ops that implement CastOpInterface.
    if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
      values.clear();
      values.assign(in_start: defOp->getOperands().begin(), in_end: defOp->getOperands().end());
      diag.attachNote(loc: defOp->getLoc())
          << "using output of 'CastOpInterface' op";
      continue;
    }

    // If the defining op has the same name or we do not care about the name of
    // op replacements at all, we take it as a replacement.
    if (!config.requireMatchingReplacementOpName ||
        op->getName() == defOp->getName()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    // Replacing an op with a constant-like equivalent is a common
    // canonicalization.
    if (defOp->hasTrait<OpTrait::ConstantLike>()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    // If the interface below does not apply, `values` stays empty and the loop
    // terminates.
    values.clear();

    // Skip through ops that implement FindPayloadReplacementOpInterface.
    if (auto findReplacementOpInterface =
            dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
      values.assign(findReplacementOpInterface.getNextOperands());
      diag.attachNote(loc: defOp->getLoc()) << "using operands provided by "
                                          "'FindPayloadReplacementOpInterface'";
      continue;
    }
  } while (!values.empty());

  diag.attachNote() << "ran out of suitable replacement values";
  return diag;
}
1275
1276void transform::TrackingListener::notifyMatchFailure(
1277 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1278 LLVM_DEBUG({
1279 Diagnostic diag(loc, DiagnosticSeverity::Remark);
1280 reasonCallback(diag);
1281 DBGS() << "Match Failure : " << diag.str() << "\n";
1282 });
1283}
1284
1285void transform::TrackingListener::notifyOperationErased(Operation *op) {
1286 // Remove mappings for result values.
1287 for (OpResult value : op->getResults())
1288 (void)replacePayloadValue(value, nullptr);
1289 // Remove mapping for op.
1290 (void)replacePayloadOp(op, nullptr);
1291}
1292
1293void transform::TrackingListener::notifyOperationReplaced(
1294 Operation *op, ValueRange newValues) {
1295 assert(op->getNumResults() == newValues.size() &&
1296 "invalid number of replacement values");
1297
1298 // Replace value handles.
1299 for (auto [oldValue, newValue] : llvm::zip(t: op->getResults(), u&: newValues))
1300 (void)replacePayloadValue(oldValue, newValue);
1301
1302 // Replace op handle.
1303 SmallVector<Value> opHandles;
1304 if (failed(getTransformState().getHandlesForPayloadOp(
1305 op, opHandles, /*includeOutOfScope=*/true))) {
1306 // Op is not tracked.
1307 return;
1308 }
1309
1310 // Helper function to check if the current transform op consumes any handle
1311 // that is mapped to `op`.
1312 //
1313 // Note: If a handle was consumed, there shouldn't be any alive users, so it
1314 // is not really necessary to check for consumed handles. However, in case
1315 // there are indeed alive handles that were consumed (which is undefined
1316 // behavior) and a replacement op could not be found, we want to fail with a
1317 // nicer error message: "op uses a handle invalidated..." instead of "could
1318 // not find replacement op". This nicer error is produced later.
1319 auto handleWasConsumed = [&] {
1320 return llvm::any_of(Range&: opHandles,
1321 P: [&](Value h) { return consumedHandles.contains(V: h); });
1322 };
1323
1324 // Check if there are any handles that must be updated.
1325 Value aliveHandle;
1326 if (config.skipHandleFn) {
1327 auto it = llvm::find_if(Range&: opHandles,
1328 P: [&](Value v) { return !config.skipHandleFn(v); });
1329 if (it != opHandles.end())
1330 aliveHandle = *it;
1331 } else if (!opHandles.empty()) {
1332 aliveHandle = opHandles.front();
1333 }
1334 if (!aliveHandle || handleWasConsumed()) {
1335 // The op is tracked but the corresponding handles are dead or were
1336 // consumed. Drop the op form the mapping.
1337 (void)replacePayloadOp(op, nullptr);
1338 return;
1339 }
1340
1341 Operation *replacement;
1342 DiagnosedSilenceableFailure diag =
1343 findReplacementOp(result&: replacement, op, newValues);
1344 // If the op is tracked but no replacement op was found, send a
1345 // notification.
1346 if (!diag.succeeded()) {
1347 diag.attachNote(loc: aliveHandle.getLoc())
1348 << "replacement is required because this handle must be updated";
1349 notifyPayloadReplacementNotFound(op, values: newValues, diag: std::move(diag));
1350 (void)replacePayloadOp(op, nullptr);
1351 return;
1352 }
1353
1354 (void)replacePayloadOp(op, replacement);
1355}
1356
1357transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
1358 // The state of the ErrorCheckingTrackingListener must be checked and reset
1359 // if there was an error. This is to prevent errors from accidentally being
1360 // missed.
1361 assert(status.succeeded() && "listener state was not checked");
1362}
1363
1364DiagnosedSilenceableFailure
1365transform::ErrorCheckingTrackingListener::checkAndResetError() {
1366 DiagnosedSilenceableFailure s = std::move(status);
1367 status = DiagnosedSilenceableFailure::success();
1368 errorCounter = 0;
1369 return s;
1370}
1371
1372bool transform::ErrorCheckingTrackingListener::failed() const {
1373 return !status.succeeded();
1374}
1375
1376void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
1377 Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
1378
1379 // Merge potentially existing diags and store the result in the listener.
1380 SmallVector<Diagnostic> diags;
1381 diag.takeDiagnostics(diags);
1382 if (!status.succeeded())
1383 status.takeDiagnostics(diags);
1384 status = DiagnosedSilenceableFailure::silenceableFailure(diag: std::move(diags));
1385
1386 // Report more details.
1387 status.attachNote(loc: op->getLoc()) << "[" << errorCounter << "] replaced op";
1388 for (auto &&[index, value] : llvm::enumerate(First&: values))
1389 status.attachNote(loc: value.getLoc())
1390 << "[" << errorCounter << "] replacement value " << index;
1391 ++errorCounter;
1392}
1393
1394std::string
1395transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() {
1396 if (!matchFailure) {
1397 return "";
1398 }
1399 return matchFailure->str();
1400}
1401
1402void transform::ErrorCheckingTrackingListener::notifyMatchFailure(
1403 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1404 Diagnostic diag(loc, DiagnosticSeverity::Remark);
1405 reasonCallback(diag);
1406 matchFailure = std::move(diag);
1407}
1408
1409//===----------------------------------------------------------------------===//
1410// TransformRewriter
1411//===----------------------------------------------------------------------===//
1412
1413transform::TransformRewriter::TransformRewriter(
1414 MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
1415 : RewriterBase(ctx), listener(listener) {
1416 setListener(listener);
1417}
1418
1419bool transform::TransformRewriter::hasTrackingFailures() const {
1420 return listener->failed();
1421}
1422
1423/// Silence all tracking failures that have been encountered so far.
1424void transform::TransformRewriter::silenceTrackingFailure() {
1425 if (hasTrackingFailures()) {
1426 DiagnosedSilenceableFailure status = listener->checkAndResetError();
1427 (void)status.silence();
1428 }
1429}
1430
1431LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
1432 Operation *op, Operation *replacement) {
1433 return listener->replacePayloadOp(op, replacement);
1434}
1435
1436//===----------------------------------------------------------------------===//
1437// Utilities for TransformEachOpTrait.
1438//===----------------------------------------------------------------------===//
1439
1440LogicalResult
1441transform::detail::checkNestedConsumption(Location loc,
1442 ArrayRef<Operation *> targets) {
1443 for (auto &&[position, parent] : llvm::enumerate(First&: targets)) {
1444 for (Operation *child : targets.drop_front(N: position + 1)) {
1445 if (parent->isAncestor(other: child)) {
1446 InFlightDiagnostic diag =
1447 emitError(loc)
1448 << "transform operation consumes a handle pointing to an ancestor "
1449 "payload operation before its descendant";
1450 diag.attachNote()
1451 << "the ancestor is likely erased or rewritten before the "
1452 "descendant is accessed, leading to undefined behavior";
1453 diag.attachNote(noteLoc: parent->getLoc()) << "ancestor payload op";
1454 diag.attachNote(noteLoc: child->getLoc()) << "descendant payload op";
1455 return diag;
1456 }
1457 }
1458 }
1459 return success();
1460}
1461
1462LogicalResult
1463transform::detail::checkApplyToOne(Operation *transformOp,
1464 Location payloadOpLoc,
1465 const ApplyToEachResultList &partialResult) {
1466 Location transformOpLoc = transformOp->getLoc();
1467 StringRef transformOpName = transformOp->getName().getStringRef();
1468 unsigned expectedNumResults = transformOp->getNumResults();
1469
1470 // Reuse the emission of the diagnostic note.
1471 auto emitDiag = [&]() {
1472 auto diag = mlir::emitError(loc: transformOpLoc);
1473 diag.attachNote(noteLoc: payloadOpLoc) << "when applied to this op";
1474 return diag;
1475 };
1476
1477 if (partialResult.size() != expectedNumResults) {
1478 auto diag = emitDiag() << "application of " << transformOpName
1479 << " expected to produce " << expectedNumResults
1480 << " results (actually produced "
1481 << partialResult.size() << ").";
1482 diag.attachNote(noteLoc: transformOpLoc)
1483 << "if you need variadic results, consider a generic `apply` "
1484 << "instead of the specialized `applyToOne`.";
1485 return failure();
1486 }
1487
1488 // Check that the right kind of value was produced.
1489 for (const auto &[ptr, res] :
1490 llvm::zip(t: partialResult, u: transformOp->getResults())) {
1491 if (ptr.isNull())
1492 continue;
1493 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1494 !isa<Operation *>(ptr)) {
1495 return emitDiag() << "application of " << transformOpName
1496 << " expected to produce an Operation * for result #"
1497 << res.getResultNumber();
1498 }
1499 if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1500 !isa<Attribute>(ptr)) {
1501 return emitDiag() << "application of " << transformOpName
1502 << " expected to produce an Attribute for result #"
1503 << res.getResultNumber();
1504 }
1505 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1506 !isa<Value>(ptr)) {
1507 return emitDiag() << "application of " << transformOpName
1508 << " expected to produce a Value for result #"
1509 << res.getResultNumber();
1510 }
1511 }
1512 return success();
1513}
1514
1515template <typename T>
1516static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
1517 return llvm::to_vector(llvm::map_range(
1518 range, [](transform::MappedValue value) { return cast<T>(value); }));
1519}
1520
1521void transform::detail::setApplyToOneResults(
1522 Operation *transformOp, TransformResults &transformResults,
1523 ArrayRef<ApplyToEachResultList> results) {
1524 SmallVector<SmallVector<MappedValue>> transposed;
1525 transposed.resize(N: transformOp->getNumResults());
1526 for (const ApplyToEachResultList &partialResults : results) {
1527 if (llvm::any_of(Range: partialResults,
1528 P: [](MappedValue value) { return value.isNull(); }))
1529 continue;
1530 assert(transformOp->getNumResults() == partialResults.size() &&
1531 "expected as many partial results as op as results");
1532 for (auto [i, value] : llvm::enumerate(First: partialResults))
1533 transposed[i].push_back(Elt: value);
1534 }
1535
1536 for (OpResult r : transformOp->getResults()) {
1537 unsigned position = r.getResultNumber();
1538 if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1539 transformResults.setParams(value: r,
1540 params: castVector<Attribute>(range: transposed[position]));
1541 } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1542 transformResults.setValues(handle: r, values: castVector<Value>(range: transposed[position]));
1543 } else {
1544 transformResults.set(value: r, ops: castVector<Operation *>(range: transposed[position]));
1545 }
1546 }
1547}
1548
1549//===----------------------------------------------------------------------===//
1550// Utilities for implementing transform ops with regions.
1551//===----------------------------------------------------------------------===//
1552
1553LogicalResult transform::detail::appendValueMappings(
1554 MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
1555 ValueRange values, const transform::TransformState &state, bool flatten) {
1556 assert(mappings.size() == values.size() && "mismatching number of mappings");
1557 for (auto &&[operand, mapped] : llvm::zip_equal(t&: values, u&: mappings)) {
1558 size_t mappedSize = mapped.size();
1559 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1560 llvm::append_range(C&: mapped, R: state.getPayloadOps(value: operand));
1561 } else if (llvm::isa<TransformValueHandleTypeInterface>(
1562 operand.getType())) {
1563 llvm::append_range(C&: mapped, R: state.getPayloadValues(handleValue: operand));
1564 } else {
1565 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1566 "unsupported kind of transform dialect value");
1567 llvm::append_range(C&: mapped, R: state.getParams(value: operand));
1568 }
1569
1570 if (mapped.size() - mappedSize != 1 && !flatten)
1571 return failure();
1572 }
1573 return success();
1574}
1575
1576void transform::detail::prepareValueMappings(
1577 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
1578 ValueRange values, const transform::TransformState &state) {
1579 mappings.resize(N: mappings.size() + values.size());
1580 (void)appendValueMappings(
1581 MutableArrayRef<SmallVector<transform::MappedValue>>(mappings).take_back(
1582 values.size()),
1583 values, state);
1584}
1585
1586void transform::detail::forwardTerminatorOperands(
1587 Block *block, transform::TransformState &state,
1588 transform::TransformResults &results) {
1589 for (auto &&[terminatorOperand, result] :
1590 llvm::zip(t: block->getTerminator()->getOperands(),
1591 u: block->getParentOp()->getOpResults())) {
1592 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1593 results.set(value: result, ops: state.getPayloadOps(value: terminatorOperand));
1594 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1595 result.getType())) {
1596 results.setValues(handle: result, values: state.getPayloadValues(handleValue: terminatorOperand));
1597 } else {
1598 assert(
1599 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1600 "unhandled transform type interface");
1601 results.setParams(value: result, params: state.getParams(value: terminatorOperand));
1602 }
1603 }
1604}
1605
1606transform::TransformState
1607transform::detail::makeTransformStateForTesting(Region *region,
1608 Operation *payloadRoot) {
1609 return TransformState(region, payloadRoot);
1610}
1611
1612//===----------------------------------------------------------------------===//
1613// Utilities for PossibleTopLevelTransformOpTrait.
1614//===----------------------------------------------------------------------===//
1615
1616/// Appends to `effects` the memory effect instances on `target` with the same
1617/// resource and effect as the ones the operation `iface` having on `source`.
1618static void
1619remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1620 OpOperand *target,
1621 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1622 SmallVector<MemoryEffects::EffectInstance> nestedEffects;
1623 iface.getEffectsOnValue(source, nestedEffects);
1624 for (const auto &effect : nestedEffects)
1625 effects.emplace_back(Args: effect.getEffect(), Args&: target, Args: effect.getResource());
1626}
1627
1628/// Appends to `effects` the same effects as the operations of `block` have on
1629/// block arguments but associated with `operands.`
1630static void
1631remapArgumentEffects(Block &block, MutableArrayRef<OpOperand> operands,
1632 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1633 for (Operation &op : block) {
1634 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1635 if (!iface)
1636 continue;
1637
1638 for (auto &&[source, target] : llvm::zip(t: block.getArguments(), u&: operands)) {
1639 remapEffects(iface, source, &target, effects);
1640 }
1641
1642 SmallVector<MemoryEffects::EffectInstance> nestedEffects;
1643 iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1644 nestedEffects);
1645 llvm::append_range(C&: effects, R&: nestedEffects);
1646 }
1647}
1648
1649void transform::detail::getPotentialTopLevelEffects(
1650 Operation *operation, Value root, Block &body,
1651 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1652 transform::onlyReadsHandle(handles: operation->getOpOperands(), effects);
1653 transform::producesHandle(handles: operation->getOpResults(), effects);
1654
1655 if (!root) {
1656 for (Operation &op : body) {
1657 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1658 if (!iface)
1659 continue;
1660
1661 iface.getEffects(effects);
1662 }
1663 return;
1664 }
1665
1666 // Carry over all effects on arguments of the entry block as those on the
1667 // operands, this is the same value just remapped.
1668 remapArgumentEffects(block&: body, operands: operation->getOpOperands(), effects);
1669}
1670
1671LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
1672 TransformState &state, Operation *op, Region &region) {
1673 SmallVector<Operation *> targets;
1674 SmallVector<SmallVector<MappedValue>> extraMappings;
1675 if (op->getNumOperands() != 0) {
1676 llvm::append_range(C&: targets, R: state.getPayloadOps(value: op->getOperand(idx: 0)));
1677 prepareValueMappings(mappings&: extraMappings, values: op->getOperands().drop_front(), state);
1678 } else {
1679 if (state.getNumTopLevelMappings() !=
1680 region.front().getNumArguments() - 1) {
1681 return emitError(loc: op->getLoc())
1682 << "operation expects " << region.front().getNumArguments() - 1
1683 << " extra value bindings, but " << state.getNumTopLevelMappings()
1684 << " were provided to the interpreter";
1685 }
1686
1687 targets.push_back(Elt: state.getTopLevel());
1688
1689 for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1690 extraMappings.push_back(Elt: llvm::to_vector(Range: state.getTopLevelMapping(position: i)));
1691 }
1692
1693 if (failed(Result: state.mapBlockArguments(argument: region.front().getArgument(i: 0), operations: targets)))
1694 return failure();
1695
1696 for (BlockArgument argument : region.front().getArguments().drop_front()) {
1697 if (failed(Result: state.mapBlockArgument(
1698 argument, values: extraMappings[argument.getArgNumber() - 1])))
1699 return failure();
1700 }
1701
1702 return success();
1703}
1704
1705LogicalResult
1706transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
1707 // Attaching this trait without the interface is a misuse of the API, but it
1708 // cannot be caught via a static_assert because interface registration is
1709 // dynamic.
1710 assert(isa<TransformOpInterface>(op) &&
1711 "should implement TransformOpInterface to have "
1712 "PossibleTopLevelTransformOpTrait");
1713
1714 if (op->getNumRegions() < 1)
1715 return op->emitOpError() << "expects at least one region";
1716
1717 Region *bodyRegion = &op->getRegion(index: 0);
1718 if (!llvm::hasNItems(C&: *bodyRegion, N: 1))
1719 return op->emitOpError() << "expects a single-block region";
1720
1721 Block *body = &bodyRegion->front();
1722 if (body->getNumArguments() == 0) {
1723 return op->emitOpError()
1724 << "expects the entry block to have at least one argument";
1725 }
1726 if (!llvm::isa<TransformHandleTypeInterface>(
1727 body->getArgument(0).getType())) {
1728 return op->emitOpError()
1729 << "expects the first entry block argument to be of type "
1730 "implementing TransformHandleTypeInterface";
1731 }
1732 BlockArgument arg = body->getArgument(i: 0);
1733 if (op->getNumOperands() != 0) {
1734 if (arg.getType() != op->getOperand(idx: 0).getType()) {
1735 return op->emitOpError()
1736 << "expects the type of the block argument to match "
1737 "the type of the operand";
1738 }
1739 }
1740 for (BlockArgument arg : body->getArguments().drop_front()) {
1741 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1742 TransformValueHandleTypeInterface>(arg.getType()))
1743 continue;
1744
1745 InFlightDiagnostic diag =
1746 op->emitOpError()
1747 << "expects trailing entry block arguments to be of type implementing "
1748 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1749 "TransformParamTypeInterface";
1750 diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1751 return diag;
1752 }
1753
1754 if (auto *parent =
1755 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
1756 if (op->getNumOperands() != body->getNumArguments()) {
1757 InFlightDiagnostic diag =
1758 op->emitOpError()
1759 << "expects operands to be provided for a nested op";
1760 diag.attachNote(noteLoc: parent->getLoc())
1761 << "nested in another possible top-level op";
1762 return diag;
1763 }
1764 }
1765
1766 return success();
1767}
1768
1769//===----------------------------------------------------------------------===//
// Utilities for ParamProducerTransformOpTrait.
1771//===----------------------------------------------------------------------===//
1772
1773void transform::detail::getParamProducerTransformOpTraitEffects(
1774 Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1775 producesHandle(handles: op->getResults(), effects);
1776 bool hasPayloadOperands = false;
1777 for (OpOperand &operand : op->getOpOperands()) {
1778 onlyReadsHandle(handles: operand, effects);
1779 if (llvm::isa<TransformHandleTypeInterface,
1780 TransformValueHandleTypeInterface>(operand.get().getType()))
1781 hasPayloadOperands = true;
1782 }
1783 if (hasPayloadOperands)
1784 onlyReadsPayload(effects);
1785}
1786
1787LogicalResult
1788transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
1789 // Interfaces can be attached dynamically, so this cannot be a static
1790 // assert.
1791 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1792 llvm::report_fatal_error(
1793 reason: Twine("ParamProducerTransformOpTrait must be attached to an op that "
1794 "implements MemoryEffectsOpInterface, found on ") +
1795 op->getName().getStringRef());
1796 }
1797 for (Value result : op->getResults()) {
1798 if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1799 continue;
1800 return op->emitOpError()
1801 << "ParamProducerTransformOpTrait attached to this op expects "
1802 "result types to implement TransformParamTypeInterface";
1803 }
1804 return success();
1805}
1806
1807//===----------------------------------------------------------------------===//
1808// Memory effects.
1809//===----------------------------------------------------------------------===//
1810
1811void transform::consumesHandle(
1812 MutableArrayRef<OpOperand> handles,
1813 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1814 for (OpOperand &handle : handles) {
1815 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &handle,
1816 Args: TransformMappingResource::get());
1817 effects.emplace_back(Args: MemoryEffects::Free::get(), Args: &handle,
1818 Args: TransformMappingResource::get());
1819 }
1820}
1821
1822/// Returns `true` if the given list of effects instances contains an instance
1823/// with the effect type specified as template parameter.
1824template <typename EffectTy, typename ResourceTy, typename Range>
1825static bool hasEffect(Range &&effects) {
1826 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1827 return isa<EffectTy>(effect.getEffect()) &&
1828 isa<ResourceTy>(effect.getResource());
1829 });
1830}
1831
1832bool transform::isHandleConsumed(Value handle,
1833 transform::TransformOpInterface transform) {
1834 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1835 SmallVector<MemoryEffects::EffectInstance> effects;
1836 iface.getEffectsOnValue(handle, effects);
1837 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1838 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1839}
1840
1841void transform::producesHandle(
1842 ResultRange handles,
1843 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1844 for (OpResult handle : handles) {
1845 effects.emplace_back(Args: MemoryEffects::Allocate::get(), Args&: handle,
1846 Args: TransformMappingResource::get());
1847 effects.emplace_back(Args: MemoryEffects::Write::get(), Args&: handle,
1848 Args: TransformMappingResource::get());
1849 }
1850}
1851
1852void transform::producesHandle(
1853 MutableArrayRef<BlockArgument> handles,
1854 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1855 for (BlockArgument handle : handles) {
1856 effects.emplace_back(Args: MemoryEffects::Allocate::get(), Args&: handle,
1857 Args: TransformMappingResource::get());
1858 effects.emplace_back(Args: MemoryEffects::Write::get(), Args&: handle,
1859 Args: TransformMappingResource::get());
1860 }
1861}
1862
1863void transform::onlyReadsHandle(
1864 MutableArrayRef<OpOperand> handles,
1865 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1866 for (OpOperand &handle : handles) {
1867 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &handle,
1868 Args: TransformMappingResource::get());
1869 }
1870}
1871
1872void transform::modifiesPayload(
1873 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1874 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: PayloadIRResource::get());
1875 effects.emplace_back(Args: MemoryEffects::Write::get(), Args: PayloadIRResource::get());
1876}
1877
1878void transform::onlyReadsPayload(
1879 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1880 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: PayloadIRResource::get());
1881}
1882
1883bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1884 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1885 SmallVector<MemoryEffects::EffectInstance> effects;
1886 iface.getEffects(effects);
1887 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1888}
1889
1890bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1891 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1892 SmallVector<MemoryEffects::EffectInstance> effects;
1893 iface.getEffects(effects);
1894 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1895}
1896
1897void transform::getConsumedBlockArguments(
1898 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1899 SmallVector<MemoryEffects::EffectInstance> effects;
1900 for (Operation &nested : block) {
1901 auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1902 if (!iface)
1903 continue;
1904
1905 effects.clear();
1906 iface.getEffects(effects);
1907 for (const MemoryEffects::EffectInstance &effect : effects) {
1908 BlockArgument argument =
1909 dyn_cast_or_null<BlockArgument>(Val: effect.getValue());
1910 if (!argument || argument.getOwner() != &block ||
1911 !isa<MemoryEffects::Free>(Val: effect.getEffect()) ||
1912 effect.getResource() != transform::TransformMappingResource::get()) {
1913 continue;
1914 }
1915 consumedArguments.insert(V: argument.getArgNumber());
1916 }
1917 }
1918}
1919
1920//===----------------------------------------------------------------------===//
1921// Utilities for TransformOpInterface.
1922//===----------------------------------------------------------------------===//
1923
1924SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
1925 TransformOpInterface transformOp) {
1926 SmallVector<OpOperand *> consumedOperands;
1927 consumedOperands.reserve(N: transformOp->getNumOperands());
1928 auto memEffectInterface =
1929 cast<MemoryEffectOpInterface>(transformOp.getOperation());
1930 SmallVector<MemoryEffects::EffectInstance, 2> effects;
1931 for (OpOperand &target : transformOp->getOpOperands()) {
1932 effects.clear();
1933 memEffectInterface.getEffectsOnValue(target.get(), effects);
1934 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1935 return isa<transform::TransformMappingResource>(
1936 effect.getResource()) &&
1937 isa<MemoryEffects::Free>(effect.getEffect());
1938 })) {
1939 consumedOperands.push_back(&target);
1940 }
1941 }
1942 return consumedOperands;
1943}
1944
1945LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
1946 auto iface = cast<MemoryEffectOpInterface>(op);
1947 SmallVector<MemoryEffects::EffectInstance> effects;
1948 iface.getEffects(effects);
1949
1950 auto effectsOn = [&](Value value) {
1951 return llvm::make_filter_range(
1952 Range&: effects, Pred: [value](const MemoryEffects::EffectInstance &instance) {
1953 return instance.getValue() == value;
1954 });
1955 };
1956
1957 std::optional<unsigned> firstConsumedOperand;
1958 for (OpOperand &operand : op->getOpOperands()) {
1959 auto range = effectsOn(operand.get());
1960 if (range.empty()) {
1961 InFlightDiagnostic diag =
1962 op->emitError() << "TransformOpInterface requires memory effects "
1963 "on operands to be specified";
1964 diag.attachNote() << "no effects specified for operand #"
1965 << operand.getOperandNumber();
1966 return diag;
1967 }
1968 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(effects&: range)) {
1969 InFlightDiagnostic diag = op->emitError()
1970 << "TransformOpInterface did not expect "
1971 "'allocate' memory effect on an operand";
1972 diag.attachNote() << "specified for operand #"
1973 << operand.getOperandNumber();
1974 return diag;
1975 }
1976 if (!firstConsumedOperand &&
1977 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects&: range)) {
1978 firstConsumedOperand = operand.getOperandNumber();
1979 }
1980 }
1981
1982 if (firstConsumedOperand &&
1983 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1984 InFlightDiagnostic diag =
1985 op->emitError()
1986 << "TransformOpInterface expects ops consuming operands to have a "
1987 "'write' effect on the payload resource";
1988 diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1989 return diag;
1990 }
1991
1992 for (OpResult result : op->getResults()) {
1993 auto range = effectsOn(result);
1994 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1995 effects&: range)) {
1996 InFlightDiagnostic diag =
1997 op->emitError() << "TransformOpInterface requires 'allocate' memory "
1998 "effect to be specified for results";
1999 diag.attachNote() << "no 'allocate' effect specified for result #"
2000 << result.getResultNumber();
2001 return diag;
2002 }
2003 }
2004
2005 return success();
2006}
2007
2008//===----------------------------------------------------------------------===//
2009// Entry point.
2010//===----------------------------------------------------------------------===//
2011
2012LogicalResult transform::applyTransforms(
2013 Operation *payloadRoot, TransformOpInterface transform,
2014 const RaggedArray<MappedValue> &extraMapping,
2015 const TransformOptions &options, bool enforceToplevelTransformOp,
2016 function_ref<void(TransformState &)> stateInitializer,
2017 function_ref<LogicalResult(TransformState &)> stateExporter) {
2018 if (enforceToplevelTransformOp) {
2019 if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2020 transform->getNumOperands() != 0) {
2021 return transform->emitError()
2022 << "expected transform to start at the top-level transform op";
2023 }
2024 } else if (failed(
2025 detail::verifyPossibleTopLevelTransformOpTrait(op: transform))) {
2026 return failure();
2027 }
2028
2029 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2030 options);
2031 if (stateInitializer)
2032 stateInitializer(state);
2033 if (state.applyTransform(transform: transform).checkAndReport().failed())
2034 return failure();
2035 if (stateExporter)
2036 return stateExporter(state);
2037 return success();
2038}
2039
2040//===----------------------------------------------------------------------===//
2041// Generated interface implementation.
2042//===----------------------------------------------------------------------===//
2043
2044#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2045#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
2046

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp