//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
10#include "mlir/Analysis/SliceAnalysis.h"
11#include "mlir/Dialect/Linalg/IR/Linalg.h"
12#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
13#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
14#include "mlir/Dialect/Linalg/Utils/Utils.h"
15#include "mlir/Dialect/Transform/IR/TransformTypes.h"
16#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "llvm/Support/Debug.h"
19#include "llvm/Support/FormatVariadic.h"
20#include "llvm/Support/InterleavedRange.h"
21
22using namespace mlir;
23
24#define DEBUG_TYPE "linalg-transforms"
25#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26
27//===----------------------------------------------------------------------===//
// MatchStructuredOp
29//===----------------------------------------------------------------------===//
30
31DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
32 Operation *current, transform::TransformResults &results,
33 transform::TransformState &state) {
34 // First, check if the payload operation is a structured Linalg operation.
35 if (!isa<linalg::LinalgOp>(Val: current)) {
36 if (getFailurePropagationMode().value_or(
37 u: FailurePropagationMode::Propagate) ==
38 FailurePropagationMode::Propagate) {
39 return emitSilenceableError() << "expected a Linalg op";
40 }
41 // If errors are suppressed, succeed and set all results to empty lists.
42 LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
43 results.setRemainingToEmpty(cast<TransformOpInterface>(Val: getOperation()));
44 return DiagnosedSilenceableFailure::success();
45 }
46
47 // Bind `current` to the block argument.
48 auto scope = state.make_region_scope(region&: getBodyRegion());
49 if (failed(Result: state.mapBlockArgument(argument: getBody()->getArgument(i: 0),
50 values: MappedValue(current)))) {
51 return DiagnosedSilenceableFailure::definiteFailure();
52 }
53
54 for (Operation &nested : getBody()->without_terminator()) {
55 DiagnosedSilenceableFailure diag =
56 state.applyTransform(transform: cast<TransformOpInterface>(Val&: nested));
57 if (diag.isDefiniteFailure())
58 return diag;
59 if (diag.succeeded())
60 continue;
61
62 // If propagating errors, do this immediately.
63 assert(diag.isSilenceableFailure());
64 if (getFailurePropagationMode().value_or(
65 u: FailurePropagationMode::Propagate) ==
66 FailurePropagationMode::Propagate) {
67 return diag;
68 }
69
70 // If suppressing errors, print the message into the debug stream before
71 // silencing it. Then set all results value that are already known.
72 // Results come from the terminator operands, which may be defined in the
73 // (single) block of this operation or above it. When they are defined
74 // above, they are known to be mapped at this point per SSA dominance.
75 // When they are defined in this block, we additionally check if we have
76 // already applied the operation that defines them. If not, the
77 // corresponding results will be set to empty lists.
78 LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
79 << "\n");
80 (void)diag.silence();
81 SmallVector<OpOperand *> undefinedOperands;
82 for (OpOperand &terminatorOperand :
83 getBody()->getTerminator()->getOpOperands()) {
84 Operation *definingOp = terminatorOperand.get().getDefiningOp();
85 if (!definingOp)
86 continue;
87 if (definingOp->getBlock() != getBody())
88 continue;
89 if (definingOp->isBeforeInBlock(other: &nested))
90 continue;
91
92 undefinedOperands.push_back(Elt: &terminatorOperand);
93 }
94
95 SmallVector<SmallVector<transform::MappedValue>> mappings;
96 auto filtered = llvm::make_filter_range(
97 Range: getBody()->getTerminator()->getOpOperands(), Pred: [&](OpOperand &opOperand) {
98 return !llvm::is_contained(Range&: undefinedOperands, Element: &opOperand);
99 });
100 SmallVector<Value> definedOperands = llvm::to_vector(Range: llvm::map_range(
101 C&: filtered, F: [](OpOperand &opOperand) { return opOperand.get(); }));
102 detail::prepareValueMappings(mappings, values: definedOperands, state);
103 for (auto &&[operand, mapping] : llvm::zip_equal(t&: filtered, u&: mappings)) {
104 results.setMappedValues(handle: getResults()[operand.getOperandNumber()],
105 values: mapping);
106 }
107 results.setRemainingToEmpty(cast<TransformOpInterface>(Val: getOperation()));
108 return DiagnosedSilenceableFailure::success();
109 }
110
111 // Set the results.
112 detail::forwardTerminatorOperands(block: getBody(), state, results);
113 return DiagnosedSilenceableFailure::success();
114}
115
116void transform::MatchStructuredOp::getEffects(
117 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
118 onlyReadsHandle(handles: getCurrentMutable(), effects);
119 onlyReadsPayload(effects);
120 producesHandle(handles: getOperation()->getOpResults(), effects);
121}
122
123LogicalResult transform::MatchStructuredOp::verify() {
124 if (getBody()->getNumArguments() != 1)
125 return emitOpError() << "expected one body argument";
126 if (!isa<TransformHandleTypeInterface>(Val: getBody()->getArgument(i: 0).getType())) {
127 return emitOpError() << "expected body argument to implement "
128 "TransformHandleTypeInterface";
129 }
130 for (Operation &nested : getBody()->without_terminator()) {
131 if (isa<MatchOpInterface>(Val: nested))
132 continue;
133 InFlightDiagnostic diag =
134 emitOpError()
135 << "expects nested operations to implement MatchOpInterface";
136 diag.attachNote(noteLoc: nested.getLoc()) << "offending operation";
137 return diag;
138 }
139 return success();
140}
141
142//===----------------------------------------------------------------------===//
143// StructuredOpPredicateOpTrait
144//===----------------------------------------------------------------------===//
145
146LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
147 Operation *op, Value structuredOpHandle) {
148 if (!isa_and_nonnull<MatchStructuredOp>(Val: op->getParentOp())) {
149 return op->emitOpError() << "expects parent op to be '"
150 << MatchStructuredOp::getOperationName() << "'";
151 }
152
153 // Bail out here, let the verifier of the parent complain.
154 Operation *parent = op->getParentOp();
155 if (parent->getNumRegions() < 1 || parent->getRegion(index: 0).empty() ||
156 parent->getRegion(index: 0).front().getNumArguments() < 1)
157 return success();
158
159 if (structuredOpHandle != parent->getRegion(index: 0).front().getArgument(i: 0)) {
160 return op->emitOpError()
161 << "expected predicate to apply to the surrounding structured op";
162 }
163 return success();
164}
165
166//===----------------------------------------------------------------------===//
167// MatchStructuredBodyOp
168//===----------------------------------------------------------------------===//
169
170DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
171 Operation *current, transform::TransformResults &results,
172 transform::TransformState &state) {
173 auto linalgOp = cast<linalg::LinalgOp>(Val: current);
174 if (std::optional<uint64_t> position = getReductionPosition()) {
175 SmallVector<Operation *> combinerOps;
176 if (!matchReduction(iterCarriedArgs: linalgOp.getRegionOutputArgs(), redPos: *position,
177 combinerOps)) {
178 return emitSilenceableError() << "could not match reduction";
179 }
180 if (combinerOps.size() != 1) {
181 return emitSilenceableError() << "reduction combiner is not a single op";
182 }
183 return DiagnosedSilenceableFailure::success();
184 }
185 if (getPassthrough()) {
186 Block &body = linalgOp->getRegion(index: 0).front();
187 if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
188 return emitSilenceableError() << "not a passthrough";
189 }
190 return DiagnosedSilenceableFailure::success();
191 }
192 if (getElementwise()) {
193 if (!isElementwise(op: linalgOp))
194 return emitSilenceableError() << "not elementwise";
195 return DiagnosedSilenceableFailure::success();
196 }
197 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
198 Block &body = linalgOp->getRegion(index: 0).front();
199 std::string message;
200 llvm::raw_string_ostream os(message);
201 bool result = linalg::detail::isContractionBody(
202 block&: body,
203 isaPair: [&](Operation *elem, Operation *red) {
204 return elem->getName().getStringRef() ==
205 cast<StringAttr>(Val: (*contractionOps)[0]).getValue() &&
206 red->getName().getStringRef() ==
207 cast<StringAttr>(Val: (*contractionOps)[1]).getValue();
208 },
209 errs&: os);
210 if (result)
211 return DiagnosedSilenceableFailure::success();
212 return emitSilenceableError() << "contraction: " << message;
213 }
214 return emitDefiniteFailure() << "unknown body condition";
215}
216
217LogicalResult transform::MatchStructuredBodyOp::verify() {
218 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
219 getElementwise() + getContraction().has_value();
220
221 if (numOptions > 1) {
222 StringAttr attributeNames[] = {
223 getReductionPositionAttrName(), getPassthroughAttrName(),
224 getElementwiseAttrName(), getContractionAttrName()};
225 return emitOpError() << "only one of {" << llvm::interleaved(R: attributeNames)
226 << "} is allowed";
227 }
228
229 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
230 if (contractionAttr->size() != 2) {
231 return emitOpError() << "expects " << getContractionAttrName()
232 << " to contain two elements";
233 }
234 }
235 return success();
236}
237
238//===----------------------------------------------------------------------===//
239// MatchStructuredClassifyContractionDimsOp
240//===----------------------------------------------------------------------===//
241
242DiagnosedSilenceableFailure
243transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
244 Operation *current, transform::TransformResults &results,
245 transform::TransformState &state) {
246 FailureOr<linalg::ContractionDimensions> contractionDims =
247 linalg::inferContractionDims(linalgOp: cast<linalg::LinalgOp>(Val: current));
248 if (failed(Result: contractionDims))
249 return emitSilenceableError() << "could not infer contraction dimensions";
250
251 MLIRContext *context = current->getContext();
252 Builder builder(context);
253 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
254 return llvm::to_vector(
255 Range: llvm::map_range(C&: values, F: [&](unsigned value) -> Attribute {
256 return builder.getI64IntegerAttr(value);
257 }));
258 };
259 results.setParams(value: cast<OpResult>(Val: getBatch()),
260 params: makeI64Attrs(contractionDims->batch));
261 results.setParams(value: cast<OpResult>(Val: getM()), params: makeI64Attrs(contractionDims->m));
262 results.setParams(value: cast<OpResult>(Val: getN()), params: makeI64Attrs(contractionDims->n));
263 results.setParams(value: cast<OpResult>(Val: getK()), params: makeI64Attrs(contractionDims->k));
264 return DiagnosedSilenceableFailure::success();
265}
266
267//===----------------------------------------------------------------------===//
268// MatchStructuredClassifyConvolutionDimsOp
269//===----------------------------------------------------------------------===//
270
271DiagnosedSilenceableFailure
272transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
273 Operation *current, transform::TransformResults &results,
274 transform::TransformState &state) {
275 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
276 linalg::inferConvolutionDims(linalgOp: cast<linalg::LinalgOp>(Val: current));
277 if (failed(Result: convolutionDims))
278 return emitSilenceableError() << "could not infer convolution dimensions";
279
280 MLIRContext *context = current->getContext();
281 Builder builder(context);
282 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
283 return llvm::to_vector(
284 Range: llvm::map_range(C&: values, F: [&](unsigned value) -> Attribute {
285 return builder.getI64IntegerAttr(value);
286 }));
287 };
288 results.setParams(value: cast<OpResult>(Val: getBatch()),
289 params: makeI64Attrs(convolutionDims->batch));
290 results.setParams(value: cast<OpResult>(Val: getOutputImage()),
291 params: makeI64Attrs(convolutionDims->outputImage));
292 results.setParams(value: cast<OpResult>(Val: getOutputChannel()),
293 params: makeI64Attrs(convolutionDims->outputChannel));
294 results.setParams(value: cast<OpResult>(Val: getFilterLoop()),
295 params: makeI64Attrs(convolutionDims->filterLoop));
296 results.setParams(value: cast<OpResult>(Val: getInputChannel()),
297 params: makeI64Attrs(convolutionDims->inputChannel));
298 results.setParams(value: cast<OpResult>(Val: getDepth()),
299 params: makeI64Attrs(convolutionDims->depth));
300
301 auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
302 return llvm::to_vector(
303 Range: llvm::map_range(C&: values, F: [&](int64_t value) -> Attribute {
304 return builder.getI64IntegerAttr(value);
305 }));
306 };
307 results.setParams(value: cast<OpResult>(Val: getStrides()),
308 params: makeI64AttrsFromI64(convolutionDims->strides));
309 results.setParams(value: cast<OpResult>(Val: getDilations()),
310 params: makeI64AttrsFromI64(convolutionDims->dilations));
311 return DiagnosedSilenceableFailure::success();
312}
313
314//===----------------------------------------------------------------------===//
315// Utilities for structured match predicates.
316//===----------------------------------------------------------------------===//
317
318/// Checks if all values from `list` are also contained in `reference`. Returns
319/// a silenceable error with the given message at the given location when it is
320/// not the case. The error message must contain the "{0}" placeholder that
321/// will be substituted with the value from `list` that is not contained in
322/// `reference`.
323static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
324 ArrayRef<int64_t> list,
325 Location loc,
326 const char *message) {
327 for (int64_t value : list) {
328 if (llvm::any_of(Range&: reference, P: [&](unsigned ref) {
329 return static_cast<int64_t>(ref) == value;
330 })) {
331 continue;
332 }
333 return emitSilenceableFailure(loc) << llvm::formatv(Fmt: message, Vals&: value);
334 }
335 return DiagnosedSilenceableFailure::success();
336}
337
338//===----------------------------------------------------------------------===//
339// MatchStructuredDimOp
340//===----------------------------------------------------------------------===//
341
342DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
343 Operation *current, transform::TransformResults &results,
344 transform::TransformState &state) {
345 auto linalgOp = cast<linalg::LinalgOp>(Val: current);
346 SmallVector<int64_t> dimensions;
347 DiagnosedSilenceableFailure diag = getDimensionsFor(op: linalgOp, dims&: dimensions);
348 if (!diag.succeeded())
349 return diag;
350
351 // If asked to check for the kind of dimension, perform the check.
352 if (getParallel() || getReduction()) {
353 SmallVector<unsigned> reference;
354 if (getParallel())
355 linalgOp.getParallelDims(res&: reference);
356 else if (getReduction())
357 linalgOp.getReductionDims(res&: reference);
358
359 DiagnosedSilenceableFailure diag =
360 containsAll(reference, list: dimensions, loc: getLoc(),
361 message: getParallel() ? "expects dimension #{0} to be parallel"
362 : "expects dimension #{0} to be reduction");
363 if (!diag.succeeded())
364 return diag;
365 }
366
367 // If not capturing, we are done here.
368 if (!getResult())
369 return diag;
370
371 SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
372 Builder builder(current);
373 SmallVector<Attribute> captured = llvm::to_vector(
374 Range: llvm::map_range(C&: dimensions, F: [&](int64_t dim) -> Attribute {
375 return builder.getI64IntegerAttr(value: ranges[dim]);
376 }));
377 results.setParams(value: cast<OpResult>(Val: getResult()), params: captured);
378 return DiagnosedSilenceableFailure::success();
379}
380
381DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
382 linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
383 DiagnosedSilenceableFailure diag =
384 expandTargetSpecification(loc: getLoc(), isAll: getIsAll(), isInverted: getIsInverted(),
385 rawList: getRawDimList(), maxNumber: op.getNumLoops(), result&: dims);
386 if (diag.isSilenceableFailure()) {
387 diag.attachNote(loc: op->getLoc())
388 << "while considering dimensions of this payload operation";
389 }
390 return diag;
391}
392
393LogicalResult transform::MatchStructuredDimOp::verify() {
394 if (getParallel() && getReduction()) {
395 return emitOpError() << "cannot request the same dimension to be both "
396 "parallel and reduction";
397 }
398 return verifyTransformMatchDimsOp(op: getOperation(), raw: getRawDimList(),
399 inverted: getIsInverted(), all: getIsAll());
400}
401
402//===----------------------------------------------------------------------===//
403// MatchStructuredElementalBitwidthOp
404//===----------------------------------------------------------------------===//
405
406DiagnosedSilenceableFailure
407transform::MatchStructuredElementalBitwidthOp::matchValue(
408 Value current, transform::TransformResults &results,
409 transform::TransformState &state) {
410 auto setupResult = [&](int64_t bitwidth) {
411 Attribute attr = Builder(current.getContext()).getI64IntegerAttr(value: bitwidth);
412 results.setParams(value: cast<OpResult>(Val: getResult()), params: {attr});
413 return DiagnosedSilenceableFailure::success();
414 };
415
416 Type type = current.getType();
417 if (type.isIntOrFloat())
418 return setupResult(type.getIntOrFloatBitWidth());
419
420 if (auto shapedType = dyn_cast<ShapedType>(Val&: type)) {
421 if (shapedType.getElementType().isIntOrFloat())
422 return setupResult(shapedType.getElementTypeBitWidth());
423 }
424 return emitSilenceableError()
425 << "unsupported type for bitwidth extraction: " << type;
426}
427
428//===----------------------------------------------------------------------===//
429// MatchStructuredInputOp
430//===----------------------------------------------------------------------===//
431
432DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
433 Operation *current, transform::TransformResults &results,
434 transform::TransformState &state) {
435 auto linalgOp = cast<linalg::LinalgOp>(Val: current);
436 SmallVector<int64_t> positions;
437 DiagnosedSilenceableFailure diag = getPositionsFor(op: linalgOp, positions);
438 if (!diag.succeeded())
439 return diag;
440
441 SmallVector<MappedValue> operandMapping;
442 operandMapping.reserve(N: positions.size());
443 for (int64_t position : positions) {
444 AffineMap indexingMap =
445 linalgOp.getMatchingIndexingMap(opOperand: linalgOp.getDpsInputOperand(i: position));
446 if (getPermutation() && !indexingMap.isPermutation()) {
447 return emitSilenceableError() << "the indexing map for input #"
448 << position << " is not a permutation";
449 }
450 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
451 return emitSilenceableError()
452 << "the indexing map for input #" << position
453 << " is not a projected permutation";
454 }
455
456 // If capture not requested, skip it.
457 if (!getResult())
458 continue;
459
460 if (isa<AffineMapParamType>(Val: getResult().getType())) {
461 operandMapping.emplace_back(Args: AffineMapAttr::get(value: indexingMap));
462 continue;
463 }
464
465 Value operand = linalgOp.getDpsInputOperand(i: position)->get();
466 if (isa<TransformValueHandleTypeInterface>(Val: getResult().getType())) {
467 operandMapping.emplace_back(Args&: operand);
468 continue;
469 }
470
471 Operation *operandProducer = operand.getDefiningOp();
472 if (!operandProducer) {
473 return emitSilenceableError()
474 << "input #" << position << " is not produced by an operation";
475 }
476 operandMapping.emplace_back(Args&: operandProducer);
477 }
478 if (getResult())
479 results.setMappedValues(handle: cast<OpResult>(Val: getResult()), values: operandMapping);
480 return DiagnosedSilenceableFailure::success();
481}
482
483DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
484 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
485 DiagnosedSilenceableFailure diag = expandTargetSpecification(
486 loc: getLoc(), isAll: getIsAll(), isInverted: getIsInverted(), rawList: getRawPositionList(),
487 maxNumber: op.getNumDpsInputs(), result&: positions);
488 if (diag.isSilenceableFailure()) {
489 diag.attachNote(loc: op->getLoc())
490 << "while considering DPS inputs of this payload operation";
491 }
492 return diag;
493}
494
495/// Verifies a matcher op for structured input or output, specifically the
496/// attributes specifying the operand positions.
497template <typename OpTy>
498LogicalResult verifyStructuredOperandOp(OpTy op) {
499 if (op.getPermutation() && op.getProjectedPermutation()) {
500 return op.emitOpError()
501 << op.getPermutationAttrName() << " and "
502 << op.getProjectedPermutationAttrName() << " are mutually exclusive";
503 }
504 if (op.getRawPositionList().size() > 1 && op.getResult()) {
505 return op.emitOpError()
506 << "cannot bind multiple inputs/inits to the same value";
507 }
508
509 return success();
510}
511
512LogicalResult transform::MatchStructuredInputOp::verify() {
513 if (failed(Result: verifyStructuredOperandOp(op: *this)))
514 return failure();
515 return verifyTransformMatchDimsOp(op: getOperation(), raw: getRawPositionList(),
516 inverted: getIsInverted(), all: getIsAll());
517}
518
519//===----------------------------------------------------------------------===//
520// MatchStructuredInitOp
521//===----------------------------------------------------------------------===//
522
523DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
524 Operation *current, transform::TransformResults &results,
525 transform::TransformState &state) {
526 auto linalgOp = cast<linalg::LinalgOp>(Val: current);
527 SmallVector<int64_t> positions;
528 DiagnosedSilenceableFailure diag = getPositionsFor(op: linalgOp, positions);
529 if (!diag.succeeded())
530 return diag;
531
532 SmallVector<MappedValue> operandMapping;
533 operandMapping.reserve(N: positions.size());
534 for (int64_t position : positions) {
535 AffineMap indexingMap =
536 linalgOp.getMatchingIndexingMap(opOperand: linalgOp.getDpsInitOperand(i: position));
537 if (getPermutation() && !indexingMap.isPermutation()) {
538 return emitSilenceableError() << "the indexing map for output(init) #"
539 << position << " is not a permutation";
540 }
541 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
542 return emitSilenceableError() << "the indexing map for output(init) #"
543 << position << " is not a permutation";
544 }
545
546 // If capture not requested, skip it.
547 if (!getResult())
548 continue;
549
550 if (isa<AffineMapParamType>(Val: getResult().getType())) {
551 operandMapping.emplace_back(Args: AffineMapAttr::get(value: indexingMap));
552 continue;
553 }
554
555 Value operand = linalgOp.getDpsInitOperand(i: position)->get();
556 if (isa<TransformValueHandleTypeInterface>(Val: getResult().getType())) {
557 operandMapping.emplace_back(Args&: operand);
558 continue;
559 }
560
561 Operation *operandProducer = operand.getDefiningOp();
562 if (!operandProducer) {
563 return emitSilenceableError() << "output(init) #" << position
564 << " is not produced by an operation";
565 }
566 operandMapping.emplace_back(Args&: operandProducer);
567 }
568 if (getResult())
569 results.setMappedValues(handle: cast<OpResult>(Val: getResult()), values: operandMapping);
570 return DiagnosedSilenceableFailure::success();
571}
572
573DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
574 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
575 DiagnosedSilenceableFailure diag = expandTargetSpecification(
576 loc: getLoc(), isAll: getIsAll(), isInverted: getIsInverted(), rawList: getRawPositionList(),
577 maxNumber: op.getNumDpsInits(), result&: positions);
578 if (diag.isSilenceableFailure()) {
579 diag.attachNote(loc: op->getLoc())
580 << "while considering DPS inits (outputs) of this payload operation";
581 }
582 return diag;
583}
584
585LogicalResult transform::MatchStructuredInitOp::verify() {
586 if (failed(Result: verifyStructuredOperandOp(op: *this)))
587 return failure();
588 return verifyTransformMatchDimsOp(op: getOperation(), raw: getRawPositionList(),
589 inverted: getIsInverted(), all: getIsAll());
590}
591
592//===----------------------------------------------------------------------===//
593// MatchStructuredNumInputsOp
594//===----------------------------------------------------------------------===//
595
596DiagnosedSilenceableFailure
597transform::MatchStructuredNumInputsOp::matchOperation(
598 Operation *current, transform::TransformResults &results,
599 transform::TransformState &state) {
600 auto linalgOp = cast<linalg::LinalgOp>(Val: current);
601 Attribute attr =
602 Builder(current).getI64IntegerAttr(value: linalgOp.getNumDpsInputs());
603 results.setParams(value: cast<OpResult>(Val: getResult()), params: {attr});
604 return DiagnosedSilenceableFailure::success();
605}
606
607//===----------------------------------------------------------------------===//
608// MatchStructuredNumInitsOp
609//===----------------------------------------------------------------------===//
610
611DiagnosedSilenceableFailure
612transform::MatchStructuredNumInitsOp::matchOperation(
613 Operation *current, transform::TransformResults &results,
614 transform::TransformState &state) {
615 auto linalgOp = cast<linalg::LinalgOp>(Val: current);
616 Attribute attr =
617 Builder(current).getI64IntegerAttr(value: linalgOp.getNumDpsInits());
618 results.setParams(value: cast<OpResult>(Val: getResult()), params: {attr});
619 return DiagnosedSilenceableFailure::success();
620}
621
622//===----------------------------------------------------------------------===//
623// MatchStructuredRankOp
624//===----------------------------------------------------------------------===//
625
626DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
627 Operation *current, transform::TransformResults &results,
628 transform::TransformState &state) {
629 auto linalgOp = cast<linalg::LinalgOp>(Val: current);
630 int64_t numLoops = linalgOp.getNumLoops();
631 Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(value: numLoops);
632 results.setParams(value: cast<OpResult>(Val: getRank()), params: {attr});
633 return DiagnosedSilenceableFailure::success();
634}
635
636//===----------------------------------------------------------------------===//
637// MatchStructuredResultOp
638//===----------------------------------------------------------------------===//
639
640DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
641 Operation *op, transform::TransformResults &results,
642 transform::TransformState &state) {
643 auto linalgOp = cast<linalg::LinalgOp>(Val: op);
644 int64_t position;
645 DiagnosedSilenceableFailure diag = getPositionFor(op: linalgOp, position);
646 if (!diag.succeeded())
647 return diag;
648
649 Value result = linalgOp.getTiedOpResult(opOperand: linalgOp.getDpsInitOperand(i: position));
650 if (isa<TransformValueHandleTypeInterface>(Val: getResult().getType())) {
651 results.setValues(handle: cast<OpResult>(Val: getResult()), values: {result});
652 return DiagnosedSilenceableFailure::success();
653 }
654
655 if (result.getUsers().empty()) {
656 return emitSilenceableError()
657 << "no users of the result #" << getPosition();
658 }
659 Operation *firstUser = *result.getUsers().begin();
660 if (getAny()) {
661 results.set(value: cast<OpResult>(Val: getResult()), ops: {firstUser});
662 return DiagnosedSilenceableFailure::success();
663 }
664 if (getSingle()) {
665 if (!llvm::hasSingleElement(C: result.getUsers())) {
666 return emitSilenceableError()
667 << "more than one result user with single user requested";
668 }
669 results.set(value: cast<OpResult>(Val: getResult()), ops: {firstUser});
670 return DiagnosedSilenceableFailure::success();
671 }
672
673 return emitDefiniteFailure() << "unknown sub-predicate";
674}
675
676DiagnosedSilenceableFailure
677transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
678 int64_t &position) {
679 auto rawPosition = static_cast<int64_t>(getPosition());
680 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
681 if (position >= op.getNumDpsInits() || position < 0) {
682 return emitSilenceableError()
683 << "position " << rawPosition
684 << " overflows the number of results(ints) of the payload operation";
685 }
686 return DiagnosedSilenceableFailure::success();
687}
688
689LogicalResult transform::MatchStructuredResultOp::verify() {
690 if ((getAny() || getSingle()) ^
691 isa<TransformHandleTypeInterface>(Val: getResult().getType())) {
692 return emitOpError() << "expects either the any/single keyword or the type "
693 "value handle result type";
694 }
695 if (getAny() && getSingle()) {
696 return emitOpError() << "'any' and 'single' are mutually exclusive";
697 }
698 return success();
699}
700
701//===----------------------------------------------------------------------===//
702// MatchStructuredYieldOp
703//===----------------------------------------------------------------------===//
704
705void transform::MatchStructuredYieldOp::getEffects(
706 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
707 onlyReadsHandle(handles: getHandlesMutable(), effects);
708 onlyReadsPayload(effects);
709}
710
711void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
712 OperationState &state) {
713 build(odsBuilder&: builder, odsState&: state, handles: ValueRange());
714}
715
716#define GET_OP_CLASSES
717#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
718

source code of mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp